Spaces:
Sleeping
Sleeping
| """Model — Kibalama/baby_cry_classification_model (Wav2Vec2, 9-class).""" | |
| from __future__ import annotations | |
| import time | |
| import numpy as np | |
| from models.base import CryClassifier, CryPrediction, display_label | |
| # Map Kibalama's raw labels to our canonical label set | |
| _LABEL_MAP: dict[str, str | None] = { | |
| "belly pain": "belly_pain", | |
| "burping": "burping", | |
| "cold_hot": "cold_hot", | |
| "discomfort": "discomfort", | |
| "hungry": "hungry", | |
| "tired": "tired", | |
| # Non-cry labels — skip (return as-is but with low relevance) | |
| "laugh": None, | |
| "noise": None, | |
| "silence": None, | |
| } | |
class KibalamaCry(CryClassifier):
    """Cry classifier backed by the Kibalama Wav2Vec2 9-class checkpoint.

    Wraps a Hugging Face ``audio-classification`` pipeline and reports the
    highest-ranked *cry-related* class. Raw labels are translated through
    ``_LABEL_MAP``, and classes mapped to ``None`` (laugh / noise / silence)
    are skipped when choosing the label.
    """

    name = "Kibalama-9c"
    description = (
        "Wav2Vec2 fine-tuned on 9-class baby cry dataset "
        "(HuggingFace: Kibalama/baby_cry_classification_model)"
    )
    MODEL_ID = "Kibalama/baby_cry_classification_model"
    # Ask the pipeline for every class score (the model has 9), so the
    # selection loop can always fall past non-cry classes to a cry class.
    TOP_K = 9

    def __init__(self) -> None:
        super().__init__()
        # Set by load(); None means the pipeline has not been built yet.
        self._pipe = None

    def load(self) -> None:
        """Build the Hugging Face audio-classification pipeline on CPU."""
        # Deferred import: transformers is heavy and only needed once loaded.
        from transformers import pipeline

        self._pipe = pipeline(
            "audio-classification",
            model=self.MODEL_ID,
            device="cpu",
        )
        self._loaded = True

    @staticmethod
    def _pick_cry_label(results: list[dict]) -> tuple[str, float]:
        """Return ``(label, score)`` of the first cry-related pipeline result.

        ``results`` is treated as ordered best-first, so the first entry
        whose mapped label is not ``None`` wins. Labels missing from
        ``_LABEL_MAP`` pass through unchanged. Returns ``("no_cry", 0.0)``
        when no cry-related entry is present (e.g. an empty result list).
        """
        for res in results:
            raw = res["label"]
            mapped = _LABEL_MAP.get(raw, raw)
            if mapped is not None:
                return mapped, res["score"]
        return "no_cry", 0.0

    def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction:
        """Classify one audio clip.

        Args:
            audio_np: waveform samples. NOTE(review): the HF audio pipeline
                expects single-channel audio; confirm callers pass mono.
            sr: sampling rate of ``audio_np`` in Hz. The clip is resampled
                to ``SAMPLE_RATE`` when this differs.

        Returns:
            A ``CryPrediction`` with the best cry-related label and its raw
            pipeline score. Any exception during inference (including calling
            this before ``load()``) is not raised. Instead it is reported as
            ``label="error"`` with the message in ``error``.
        """
        from audio.preprocess import SAMPLE_RATE, resample

        t0 = time.perf_counter()

        def elapsed_ms() -> float:
            return (time.perf_counter() - t0) * 1000

        try:
            if self._pipe is None:
                # Without this guard the failure surfaces as the opaque
                # "'NoneType' object is not callable".
                raise RuntimeError(
                    f"{self.name}: model not loaded; call load() first"
                )
            if sr != SAMPLE_RATE:
                audio_np = resample(audio_np, sr, SAMPLE_RATE)
            results = self._pipe(
                {"raw": audio_np, "sampling_rate": SAMPLE_RATE},
                top_k=self.TOP_K,
            )
            label_raw, confidence = self._pick_cry_label(results)
            return CryPrediction(
                model_name=self.name,
                label=label_raw,
                display_label=display_label(label_raw),
                confidence=confidence,
                latency_ms=elapsed_ms(),
            )
        except Exception as exc:
            # Deliberate best-effort boundary: one failing model must not
            # abort the caller, so the failure is returned as a prediction.
            return CryPrediction(
                model_name=self.name,
                label="error",
                display_label="⚠️ Error",
                confidence=0.0,
                latency_ms=elapsed_ms(),
                error=str(exc),
            )