Mohammedelhakim's picture
Update app.py
93a1b57 verified
# app.py
from fastapi import FastAPI, UploadFile, File, HTTPException
import traceback
import numpy as np
import librosa
import joblib
import tempfile
import os
import tensorflow as tf
import tensorflow_hub as hub
# =========================
# Configuration
# =========================
# Sample rate (Hz) passed to librosa.load for every audio file.
SR = 16000

# Directories holding the serialized artifacts for each pipeline stage.
_DETECTOR_DIR = "detection_models"
_CLASSIFIER_DIR = "classification_models"

# Cry / not-cry detector: model plus the scaler and PCA applied before it.
DETECTOR_MODEL_PATH = f"{_DETECTOR_DIR}/yamnet_lr_model.joblib"
DETECTOR_SCALER_PATH = f"{_DETECTOR_DIR}/scaler_yamnet.pkl"
DETECTOR_PCA_PATH = f"{_DETECTOR_DIR}/pca_yamnet.pkl"

# Cry-type classifier: ensemble plus its scaler, feature selector and
# label encoder.
CLASS_ENSEMBLE_PATH = f"{_CLASSIFIER_DIR}/babycry_ensemble.pkl"
CLASS_SCALER_PATH = f"{_CLASSIFIER_DIR}/scaler.pkl"
CLASS_SELECTOR_PATH = f"{_CLASSIFIER_DIR}/feature_selector.pkl"
CLASS_LE_PATH = f"{_CLASSIFIER_DIR}/label_encoder.pkl"
# =========================
# Load models (ONCE)
# =========================
# Everything below is loaded at import time and shared by all requests.

# YAMNet model fetched from TensorFlow Hub; used to produce embeddings.
yamnet = hub.load("https://tfhub.dev/google/yamnet/1")

# Detection stage artifacts (loaded in model, scaler, PCA order).
det_model, det_scaler, det_pca = (
    joblib.load(artifact_path)
    for artifact_path in (
        DETECTOR_MODEL_PATH,
        DETECTOR_SCALER_PATH,
        DETECTOR_PCA_PATH,
    )
)

# Classification stage artifacts (ensemble, scaler, selector, label encoder).
ensemble, cls_scaler, feature_selector, label_encoder = (
    joblib.load(artifact_path)
    for artifact_path in (
        CLASS_ENSEMBLE_PATH,
        CLASS_SCALER_PATH,
        CLASS_SELECTOR_PATH,
        CLASS_LE_PATH,
    )
)
# =========================
# Feature Extraction
# =========================
def extract_yamnet_embedding(path):
    """Build the detector's input vector from YAMNet embeddings of a file.

    The audio is read with librosa as mono at ``SR`` Hz, converted to a
    float32 tensor and passed to the module-level ``yamnet`` model. The
    second of the model's three outputs (the embeddings) is summarised by
    its mean and its standard deviation along axis 0.

    Args:
        path: Path of the audio file to read.

    Returns:
        A NumPy array of shape ``(1, N)``: the axis-0 mean of the embeddings
        followed by their axis-0 standard deviation.
    """
    samples, _ = librosa.load(path, sr=SR, mono=True)
    _, embedding_tensor, _ = yamnet(
        tf.convert_to_tensor(samples, dtype=tf.float32)
    )
    embedding_matrix = embedding_tensor.numpy()
    summary = np.concatenate(
        [embedding_matrix.mean(axis=0), embedding_matrix.std(axis=0)]
    )
    return summary.reshape(1, -1)
def extract_classification_features(path):
    """Compute the hand-crafted feature vector used by the cry classifier.

    The audio is loaded with librosa at ``SR`` Hz. Five feature matrices
    (MFCC, chroma, mel spectrogram, spectral contrast, tonnetz of the
    harmonic component) are each averaged along axis 1, and six further
    features (zero-crossing rate, RMS energy, spectral centroid, bandwidth,
    roll-off and flatness) are each reduced to a single overall mean.

    Args:
        path: Path of the audio file to read.

    Returns:
        A NumPy array of shape ``(1, N)`` laid out as: up to 40 MFCC means,
        up to 12 chroma means, the first 40 mel means, up to 7 contrast
        means, up to 6 tonnetz means, then the six scalar means. Nominally
        N is 111.
    """
    signal, rate = librosa.load(path, sr=SR)
    magnitude = np.abs(librosa.stft(signal))

    # Matrix-valued features, collapsed to one mean per row.
    mfcc_means = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=40).mean(axis=1)
    chroma_means = librosa.feature.chroma_stft(S=magnitude, sr=rate).mean(axis=1)
    mel_means = librosa.feature.melspectrogram(y=signal, sr=rate).mean(axis=1)
    contrast_means = librosa.feature.spectral_contrast(
        S=magnitude, sr=rate
    ).mean(axis=1)
    harmonic_part = librosa.effects.harmonic(signal)
    tonnetz_means = librosa.feature.tonnetz(y=harmonic_part, sr=rate).mean(axis=1)

    # Features reduced to a single number each (mean over every element).
    scalar_means = [
        np.mean(librosa.feature.zero_crossing_rate(signal)),
        np.mean(librosa.feature.rms(y=signal)),
        np.mean(librosa.feature.spectral_centroid(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_bandwidth(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_rolloff(y=signal, sr=rate)),
        np.mean(librosa.feature.spectral_flatness(y=signal)),
    ]

    # The slices cap each group's length so the layout stays fixed.
    feature_vector = np.concatenate([
        mfcc_means[:40],
        chroma_means[:12],
        mel_means[:40],
        contrast_means[:7],
        tonnetz_means[:6],
        scalar_means,
    ])
    return feature_vector.reshape(1, -1)
# =========================
# Detection & Classification
# =========================
def detect_is_cry(path, threshold):
    """Decide whether an audio file contains a cry.

    The YAMNet summary vector for the file is scaled with ``det_scaler``,
    projected with ``det_pca`` and scored by ``det_model``.

    Args:
        path: Path of the audio file to analyse.
        threshold: Minimum probability at which the file counts as a cry.

    Returns:
        A tuple ``(is_cry, probability)`` where ``is_cry`` is a ``bool`` and
        ``probability`` is the ``float`` score compared with ``threshold``.
    """
    embedding = extract_yamnet_embedding(path)
    reduced = det_pca.transform(det_scaler.transform(embedding))
    # NOTE(review): column 0 of predict_proba is used as the cry probability,
    # which presumes "cry" is the first class of det_model - confirm against
    # the training labels.
    cry_probability = det_model.predict_proba(reduced)[0][0]
    return bool(cry_probability >= threshold), float(cry_probability)
def classify_cry(path, conf_threshold):
    """Classify the cry in an audio file with the ensemble model.

    Features from ``extract_classification_features`` are scaled with
    ``cls_scaler``, reduced with ``feature_selector`` and scored by
    ``ensemble``.

    Args:
        path: Path of the audio file to classify.
        conf_threshold: Minimum top-class probability required to report a
            class label.

    Returns:
        A tuple ``(label, probabilities, confidence)``. ``confidence`` is the
        highest class probability as a ``float``. When it is below
        ``conf_threshold`` the label is ``"Normal / Not a Cry"`` and
        ``probabilities`` is ``None``; otherwise the label comes from
        ``label_encoder`` and ``probabilities`` is the list of class
        probabilities.

    Raises:
        HTTPException: With status 500 when the number of extracted features
            differs from the count reported by ``cls_scaler.n_features_in_``.
    """
    feat = extract_classification_features(path)

    # The scaler may not expose n_features_in_; in that case the check is
    # skipped. The attribute is only ever read through getattr so a missing
    # attribute cannot raise here.
    expected_len = getattr(cls_scaler, "n_features_in_", None)
    current_len = feat.shape[1]
    if expected_len is not None and current_len != expected_len:
        raise HTTPException(
            status_code=500,
            detail=f"Feature length mismatch: got {current_len}, expected {expected_len}"
        )

    feat_scaled = cls_scaler.transform(feat)
    feat_selected = feature_selector.transform(feat_scaled)
    probs = ensemble.predict_proba(feat_selected)[0]

    max_prob = float(np.max(probs))
    if max_prob < conf_threshold:
        return "Normal / Not a Cry", None, max_prob

    # NOTE(review): the argmax index is decoded directly by label_encoder,
    # which presumes the ensemble's class order matches the encoder's
    # integer codes - confirm against training.
    label = label_encoder.inverse_transform([np.argmax(probs)])[0]
    return label, probs.tolist(), max_prob
# =========================
# FastAPI App
# =========================
# ASGI application object; routes are registered on it with decorators.
app = FastAPI(title="Baby Cry Detection & Classification API", version="1.0")
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    detection_threshold: float = 0.5,
    classification_threshold: float = 0.6
):
    """Detect whether an uploaded audio file is a cry and, if so, classify it.

    The upload is written to a temporary file, scored by ``detect_is_cry``
    and, when that reports a cry, passed on to ``classify_cry``. The
    temporary file is always removed before returning.

    Args:
        file: Uploaded audio file; its name must end in .wav, .mp3, .flac
            or .ogg (case-insensitive).
        detection_threshold: Threshold forwarded to ``detect_is_cry``.
        classification_threshold: Threshold forwarded to ``classify_cry``.

    Returns:
        A dict with ``filename``, ``cry_probability``, ``is_cry`` and
        ``result``. When the file is detected as a cry it also contains
        ``confidence`` and ``class_probabilities``.

    Raises:
        HTTPException: 400 for a missing/unsupported file name or an empty
            upload; 500 when detection or classification fails.
    """
    allowed_suffixes = (".wav", ".mp3", ".flac", ".ogg")

    # The upload's filename can be absent; treat that as an invalid format
    # instead of failing on None.lower().
    lowered_name = (file.filename or "").lower()
    suffix = next(
        (ext for ext in allowed_suffixes if lowered_name.endswith(ext)),
        None,
    )
    if suffix is None:
        raise HTTPException(status_code=400, detail="Invalid audio format")

    # Read before creating the temp file so a failed read leaves nothing
    # behind on disk.
    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Empty audio file")

    tmp_path = None
    try:
        # Keep the real extension so the temp file's name matches its
        # actual container format.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(payload)

        try:
            # NOTE(review): these calls run synchronously inside the
            # coroutine, so other requests wait while inference runs.
            is_cry, cry_prob = detect_is_cry(tmp_path, detection_threshold)
            response = {
                "filename": file.filename,
                "cry_probability": cry_prob,
                "is_cry": is_cry,
            }
            if not is_cry:
                response["result"] = "Not a cry"
                return response

            label, probs, confidence = classify_cry(
                tmp_path,
                classification_threshold
            )
            response.update({
                "result": label,
                "confidence": confidence,
                "class_probabilities": probs,
            })
            return response
        except HTTPException:
            # Already a well-formed HTTP error (e.g. from classify_cry);
            # pass it through unchanged rather than re-wrapping it.
            raise
        except Exception as e:
            # Log full traceback to the server console
            traceback.print_exc()
            # Return the error message so you see it in the client
            raise HTTPException(
                status_code=500,
                detail=f"Prediction failed: {e}"
            ) from e
    finally:
        # tmp_path stays None only if the temp file was never created.
        if tmp_path is not None:
            os.remove(tmp_path)