# app.py
"""Baby cry detection and classification HTTP API.

Pipeline for one uploaded audio file:

1. Detection: YAMNet embeddings -> scaler -> PCA -> probabilistic detector.
2. Classification (only when detection says "cry"): hand-crafted librosa
   features -> scaler -> feature selector -> ensemble classifier.

All models are loaded once at import time.
"""
import contextlib
import logging
import os
import tempfile
import traceback
from typing import List, Optional, Tuple

import joblib
import librosa
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from fastapi import FastAPI, File, HTTPException, UploadFile

logger = logging.getLogger(__name__)

# =========================
# Configuration
# =========================
SR = 16000  # sample rate every clip is resampled to before feature extraction

DETECTOR_MODEL_PATH = "detection_models/yamnet_lr_model.joblib"
DETECTOR_SCALER_PATH = "detection_models/scaler_yamnet.pkl"
DETECTOR_PCA_PATH = "detection_models/pca_yamnet.pkl"

CLASS_ENSEMBLE_PATH = "classification_models/babycry_ensemble.pkl"
CLASS_SCALER_PATH = "classification_models/scaler.pkl"
CLASS_SELECTOR_PATH = "classification_models/feature_selector.pkl"
CLASS_LE_PATH = "classification_models/label_encoder.pkl"

# Upload extensions accepted by /predict (lower-case, with leading dot).
ALLOWED_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg")

# =========================
# Load models (ONCE)
# =========================
# NOTE: joblib/pickle artifacts execute code on load; these paths must only
# ever point at trusted, locally produced files.
yamnet = hub.load("https://tfhub.dev/google/yamnet/1")
det_model = joblib.load(DETECTOR_MODEL_PATH)
det_scaler = joblib.load(DETECTOR_SCALER_PATH)
det_pca = joblib.load(DETECTOR_PCA_PATH)

ensemble = joblib.load(CLASS_ENSEMBLE_PATH)
cls_scaler = joblib.load(CLASS_SCALER_PATH)
feature_selector = joblib.load(CLASS_SELECTOR_PATH)
label_encoder = joblib.load(CLASS_LE_PATH)


# =========================
# Feature Extraction
# =========================
def extract_yamnet_embedding(path: str) -> np.ndarray:
    """Return the detector input for the audio file at *path*.

    The clip is loaded as mono at ``SR`` Hz and passed through YAMNet; the
    per-frame embeddings are summarised by their mean and standard deviation
    over frames, concatenated into a single row.

    Returns:
        Array of shape ``(1, 2 * embedding_dim)``.

    Raises:
        ValueError: If the file decodes to zero samples.
    """
    wav, _ = librosa.load(path, sr=SR, mono=True)
    if wav.size == 0:
        # Without this, the frame statistics below would be NaN and fail
        # later with a much less helpful message.
        raise ValueError("Audio file contains no samples")

    waveform = tf.convert_to_tensor(wav, dtype=tf.float32)
    _, embeddings, _ = yamnet(waveform)
    emb = embeddings.numpy()

    mean_emb = np.mean(emb, axis=0)
    std_emb = np.std(emb, axis=0)
    return np.concatenate([mean_emb, std_emb]).reshape(1, -1)


def extract_classification_features(path: str) -> np.ndarray:
    """Return the classifier input for the audio file at *path*.

    Layout of the returned row (111 values in total):
    40 MFCC means, 12 chroma means, first 40 mel-band means, 7 spectral
    contrast means, 6 tonnetz means, then zero-crossing rate, RMS energy,
    spectral centroid, bandwidth, roll-off and flatness (one value each).

    Returns:
        Array of shape ``(1, 111)``.

    Raises:
        ValueError: If the file decodes to zero samples.
    """
    y, sr = librosa.load(path, sr=SR)
    if y.size == 0:
        raise ValueError("Audio file contains no samples")

    stft = np.abs(librosa.stft(y))

    # Per-band features, averaged over time frames.
    mfcc = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40), axis=1)
    chroma = np.mean(librosa.feature.chroma_stft(S=stft, sr=sr), axis=1)
    mel = np.mean(librosa.feature.melspectrogram(y=y, sr=sr), axis=1)
    contrast = np.mean(librosa.feature.spectral_contrast(S=stft, sr=sr), axis=1)
    tonnetz = np.mean(
        librosa.feature.tonnetz(y=librosa.effects.harmonic(y), sr=sr), axis=1
    )

    # Time-domain features, collapsed to scalars.
    zero_crossing = np.mean(librosa.feature.zero_crossing_rate(y))
    energy = np.mean(librosa.feature.rms(y=y))

    # Spectral features, collapsed to scalars.
    spec_centroid = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
    spec_bandwidth = np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr))
    spec_rolloff = np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))
    spec_flatness = np.mean(librosa.feature.spectral_flatness(y=y))

    combined_features = np.concatenate([
        mfcc[:40],         # first 40 MFCCs
        chroma[:12],       # 12 chroma features
        mel[:40],          # first 40 mel bands only
        contrast[:7],      # 7 contrast features
        tonnetz[:6],       # 6 tonnetz features
        [zero_crossing],   # 1 feature
        [energy],          # 1 feature
        [spec_centroid],   # 1 feature
        [spec_bandwidth],  # 1 feature
        [spec_rolloff],    # 1 feature
        [spec_flatness],   # 1 feature
    ])
    return combined_features.reshape(1, -1)


# =========================
# Detection & Classification
# =========================
def detect_is_cry(path: str, threshold: float) -> Tuple[bool, float]:
    """Run the cry detector on the audio file at *path*.

    Returns:
        ``(is_cry, prob)`` where ``prob`` is the detector probability that is
        compared against *threshold* and ``is_cry`` is ``prob >= threshold``.
    """
    feat = extract_yamnet_embedding(path)
    feat = det_scaler.transform(feat)
    feat = det_pca.transform(feat)

    # NOTE(review): column 0 is the probability of det_model.classes_[0].
    # This assumes the "cry" label was encoded as the first class when the
    # detector was trained -- confirm against the training code.
    prob = det_model.predict_proba(feat)[0][0]
    is_cry = bool(prob >= threshold)
    return is_cry, float(prob)


def classify_cry(
    path: str, conf_threshold: float
) -> Tuple[str, Optional[List[float]], float]:
    """Classify the type of cry in the audio file at *path*.

    Returns:
        ``(label, class_probabilities, confidence)``. When the top class
        probability is below *conf_threshold* the label is
        ``"Normal / Not a Cry"`` and ``class_probabilities`` is ``None``.

    Raises:
        HTTPException: 500 if the extracted feature vector length does not
            match what the fitted scaler expects.
    """
    feat = extract_classification_features(path)

    current_len = feat.shape[1]
    # Older scaler artifacts may not carry n_features_in_; skip the check then.
    expected_len = getattr(cls_scaler, "n_features_in_", None)
    if expected_len is not None and current_len != expected_len:
        raise HTTPException(
            status_code=500,
            detail=f"Feature length mismatch: got {current_len}, expected {expected_len}",
        )

    # Reuse expected_len here: reading the attribute directly would raise in
    # exactly the case the getattr() above is guarding against.
    logger.debug(
        "classify_cry: feat shape=%s, scaler expects=%s", feat.shape, expected_len
    )

    feat_scaled = cls_scaler.transform(feat)
    feat_selected = feature_selector.transform(feat_scaled)

    probs = ensemble.predict_proba(feat_selected)[0]
    max_prob = float(np.max(probs))

    if max_prob < conf_threshold:
        return "Normal / Not a Cry", None, max_prob

    label = label_encoder.inverse_transform([np.argmax(probs)])[0]
    return label, probs.tolist(), max_prob


# =========================
# FastAPI App
# =========================
app = FastAPI(
    title="Baby Cry Detection & Classification API",
    version="1.0"
)


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    detection_threshold: float = 0.5,
    classification_threshold: float = 0.6
):
    """Detect whether the uploaded clip is a baby cry and, if so, classify it.

    Args:
        file: Uploaded audio; the filename must end in one of
            ``ALLOWED_EXTENSIONS``.
        detection_threshold: Minimum detector probability to count as a cry.
        classification_threshold: Minimum top-class probability for the
            classifier's label to be reported.

    Returns:
        JSON object with ``filename``, ``cry_probability``, ``is_cry`` and
        ``result``; plus ``confidence`` and ``class_probabilities`` when the
        classifier was run.

    Raises:
        HTTPException: 400 for a missing/unsupported filename; 500 if any
            step of the prediction pipeline fails.
    """
    # UploadFile.filename may be None when the client omits it.
    filename = file.filename or ""
    suffix = os.path.splitext(filename)[1].lower()
    if suffix not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Invalid audio format")

    # NOTE(review): the model calls below are synchronous and CPU-bound, so
    # they block the event loop for the duration of a request. Consider
    # offloading them to a worker thread if concurrent load matters.
    tmp_path = None
    try:
        # Keep the real (validated) extension rather than forcing ".wav", so
        # any decoder that consults the file name is not misled.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(await file.read())

        is_cry, cry_prob = detect_is_cry(tmp_path, detection_threshold)

        response = {
            "filename": file.filename,
            "cry_probability": cry_prob,
            "is_cry": is_cry,
        }

        if not is_cry:
            response["result"] = "Not a cry"
            return response

        label, probs, confidence = classify_cry(
            tmp_path, classification_threshold
        )

        response.update({
            "result": label,
            "confidence": confidence,
            "class_probabilities": probs,
        })
        return response

    except HTTPException:
        # Already a well-formed HTTP error (e.g. feature length mismatch);
        # pass it through instead of re-wrapping its message.
        raise
    except Exception as e:
        # Log full traceback to the server console
        traceback.print_exc()
        # Return the error message so you see it in the client
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        ) from e
    finally:
        # Always clean up, including when reading/writing the upload failed.
        if tmp_path is not None:
            with contextlib.suppress(FileNotFoundError):
                os.remove(tmp_path)