| |
| from fastapi import FastAPI, UploadFile, File, HTTPException |
| import traceback |
| import numpy as np |
| import librosa |
| import joblib |
| import tempfile |
| import os |
| import tensorflow as tf |
| import tensorflow_hub as hub |
|
|
| |
| |
| |
# Sample rate (Hz) passed to librosa whenever audio is loaded.
SR = 16000

# Serialized artifacts for the cry / not-cry detection stage.
DETECTOR_MODEL_PATH = "detection_models/yamnet_lr_model.joblib"
DETECTOR_SCALER_PATH = "detection_models/scaler_yamnet.pkl"
DETECTOR_PCA_PATH = "detection_models/pca_yamnet.pkl"

# Serialized artifacts for the cry-type classification stage.
CLASS_ENSEMBLE_PATH = "classification_models/babycry_ensemble.pkl"
CLASS_SCALER_PATH = "classification_models/scaler.pkl"
CLASS_SELECTOR_PATH = "classification_models/feature_selector.pkl"
CLASS_LE_PATH = "classification_models/label_encoder.pkl"

# Everything below is loaded once, at import time, and shared by all requests.
yamnet = hub.load("https://tfhub.dev/google/yamnet/1")

# Detection stage: model, scaler and PCA, loaded in that order.
det_model, det_scaler, det_pca = map(
    joblib.load,
    (DETECTOR_MODEL_PATH, DETECTOR_SCALER_PATH, DETECTOR_PCA_PATH),
)

# Classification stage: ensemble, scaler, feature selector and label encoder.
ensemble, cls_scaler, feature_selector, label_encoder = map(
    joblib.load,
    (
        CLASS_ENSEMBLE_PATH,
        CLASS_SCALER_PATH,
        CLASS_SELECTOR_PATH,
        CLASS_LE_PATH,
    ),
)
|
|
| |
| |
| |
def extract_yamnet_embedding(path):
    """Summarize an audio file as one row of YAMNet embedding statistics.

    The file is decoded by librosa as mono audio at ``SR`` and passed to the
    module-level ``yamnet`` model. The model's second output (the embeddings)
    is reduced along axis 0 to its mean and its standard deviation.

    Args:
        path: Location of the audio file to load.

    Returns:
        A 2-D NumPy array with a single row: the axis-0 mean of the
        embeddings followed by their axis-0 standard deviation.
    """
    samples, _ = librosa.load(path, sr=SR, mono=True)
    tensor = tf.convert_to_tensor(samples, dtype=tf.float32)

    # The model returns three values; only the middle one is used here.
    _scores, frame_embeddings, _spectrogram = yamnet(tensor)
    frames = frame_embeddings.numpy()

    summary = np.hstack([frames.mean(axis=0), frames.std(axis=0)])
    return summary[np.newaxis, :]
|
|
def extract_classification_features(path):
    """Compute the hand-crafted feature row used by the cry classifier.

    The audio is loaded at ``SR`` and described by, in this order:
    the first 40 MFCC means, 12 chroma means, 40 mel-spectrogram means,
    7 spectral-contrast means, 6 tonnetz means (computed on the harmonic
    component), followed by six scalar means: zero-crossing rate, RMS energy,
    spectral centroid, spectral bandwidth, spectral rolloff and spectral
    flatness.

    Args:
        path: Location of the audio file to load.

    Returns:
        A 2-D NumPy array with a single row holding the concatenated
        features (111 values when every slice above is full).
    """
    signal, rate = librosa.load(path, sr=SR)
    magnitude = np.abs(librosa.stft(signal))

    def row_means(matrix):
        # Collapse axis 1 so that one value remains per feature row.
        return np.mean(matrix, axis=1)

    mfcc = row_means(librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=40))
    chroma = row_means(librosa.feature.chroma_stft(S=magnitude, sr=rate))
    mel = row_means(librosa.feature.melspectrogram(y=signal, sr=rate))
    contrast = row_means(librosa.feature.spectral_contrast(S=magnitude, sr=rate))
    harmonic = librosa.effects.harmonic(signal)
    tonnetz = row_means(librosa.feature.tonnetz(y=harmonic, sr=rate))

    # Whole-clip scalar descriptors, in the order they are appended below.
    scalar_features = (
        np.mean(librosa.feature.zero_crossing_rate(signal)),
        np.mean(librosa.feature.rms(y=signal)),
        np.mean(librosa.feature.spectral_centroid(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_bandwidth(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_rolloff(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_flatness(y=signal)),
    )

    # The slices cap each group at a fixed width; only the first 40 mel
    # values are kept.
    pieces = [mfcc[:40], chroma[:12], mel[:40], contrast[:7], tonnetz[:6]]
    pieces.extend([value] for value in scalar_features)

    return np.concatenate(pieces).reshape(1, -1)
| |
|
|
| |
| |
| |
def detect_is_cry(path, threshold):
    """Decide whether an audio file contains a cry.

    The YAMNet embedding row of the file is scaled with ``det_scaler``,
    projected with ``det_pca`` and scored with ``det_model``.

    Args:
        path: Location of the audio file to analyse.
        threshold: Minimum probability for the file to count as a cry.

    Returns:
        A ``(is_cry, probability)`` tuple of ``bool`` and ``float``.
    """
    embedding = extract_yamnet_embedding(path)
    projected = det_pca.transform(det_scaler.transform(embedding))

    # First row, first column of the probability matrix.
    # NOTE(review): this assumes class index 0 is the "cry" class -- confirm
    # against det_model.classes_.
    prob = det_model.predict_proba(projected)[0][0]

    return bool(prob >= threshold), float(prob)
|
|
|
|
def classify_cry(path, conf_threshold):
    """Classify the type of cry in an audio file.

    Features from ``extract_classification_features`` are scaled with
    ``cls_scaler``, reduced with ``feature_selector`` and scored with
    ``ensemble``. The winning class index is mapped back to its label with
    ``label_encoder``.

    Args:
        path: Location of the audio file to classify.
        conf_threshold: Minimum top-class probability required to report a
            cry label.

    Returns:
        A ``(label, probabilities, confidence)`` tuple. When the top
        probability is below ``conf_threshold`` the label is
        ``"Normal / Not a Cry"`` and ``probabilities`` is ``None``; otherwise
        ``probabilities`` is the full list of class probabilities.

    Raises:
        HTTPException: With status 500 when the scaler reports an expected
            feature count that differs from the extracted one.
    """
    feat = extract_classification_features(path)
    current_len = feat.shape[1]
    # The scaler may not expose n_features_in_; the length check is skipped
    # in that case.
    expected_len = getattr(cls_scaler, "n_features_in_", None)

    if expected_len is not None and current_len != expected_len:
        raise HTTPException(
            status_code=500,
            detail=f"Feature length mismatch: got {current_len}, expected {expected_len}",
        )
    print("feat shape at classify_cry:", feat.shape)
    # Print the value fetched above instead of reading the attribute again,
    # which would raise AttributeError when the scaler lacks it.
    print("scaler expects:", expected_len)

    feat_scaled = cls_scaler.transform(feat)
    feat_selected = feature_selector.transform(feat_scaled)

    probs = ensemble.predict_proba(feat_selected)[0]
    max_prob = float(np.max(probs))

    # Low-confidence predictions are reported as "not a cry" without
    # per-class probabilities.
    if max_prob < conf_threshold:
        return "Normal / Not a Cry", None, max_prob

    # NOTE(review): assumes the column order of predict_proba matches the
    # label encoder's integer codes -- confirm against ensemble.classes_.
    label = label_encoder.inverse_transform([np.argmax(probs)])[0]
    return label, probs.tolist(), max_prob
|
|
| |
| |
| |
# FastAPI application object that the /predict route registers against.
app = FastAPI(title="Baby Cry Detection & Classification API", version="1.0")
|
|
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    detection_threshold: float = 0.5,
    classification_threshold: float = 0.6,
):
    """Detect whether an uploaded audio file is a cry and, if so, classify it.

    The upload is written to a temporary file, passed to ``detect_is_cry``
    and, when a cry is detected, to ``classify_cry``. The temporary file is
    always removed before returning.

    Args:
        file: Uploaded audio; its name must end in .wav, .mp3, .flac or .ogg
            (case-insensitive).
        detection_threshold: Probability threshold passed to
            ``detect_is_cry``.
        classification_threshold: Confidence threshold passed to
            ``classify_cry``.

    Returns:
        A dict with ``filename``, ``cry_probability``, ``is_cry`` and
        ``result``; when a cry is detected it also carries ``confidence``
        and ``class_probabilities``.

    Raises:
        HTTPException: 400 when the file name is missing or has an
            unsupported extension; 500 when processing fails.
    """
    # A missing file name is treated like an unsupported one (400) instead of
    # crashing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith((".wav", ".mp3", ".flac", ".ogg")):
        raise HTTPException(status_code=400, detail="Invalid audio format")

    # Keep the upload's real extension on the temporary file so the name does
    # not mislabel the format; fall back to the previous ".wav" default.
    suffix = os.path.splitext(filename)[1].lower() or ".wav"

    tmp_path = None
    try:
        # Created inside the try block so that a failed read/write cannot
        # leave the temporary file behind.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(await file.read())

        # NOTE(review): the calls below are synchronous and run directly
        # inside this coroutine.
        is_cry, cry_prob = detect_is_cry(tmp_path, detection_threshold)

        response = {
            "filename": file.filename,
            "cry_probability": cry_prob,
            "is_cry": is_cry,
        }

        if not is_cry:
            response["result"] = "Not a cry"
            return response

        label, probs, confidence = classify_cry(
            tmp_path,
            classification_threshold,
        )

        response.update({
            "result": label,
            "confidence": confidence,
            "class_probabilities": probs,
        })

        return response
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from classify_cry); do not
        # wrap it a second time.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}",
        ) from e
    finally:
        if tmp_path is not None:
            os.remove(tmp_path)