"""Model — Kibalama/baby_cry_classification_model (Wav2Vec2, 9-class).""" from __future__ import annotations import time import numpy as np from models.base import CryClassifier, CryPrediction, display_label # Map Kibalama's raw labels to our canonical label set _LABEL_MAP: dict[str, str | None] = { "belly pain": "belly_pain", "burping": "burping", "cold_hot": "cold_hot", "discomfort": "discomfort", "hungry": "hungry", "tired": "tired", # Non-cry labels — skip (return as-is but with low relevance) "laugh": None, "noise": None, "silence": None, } class KibalamaCry(CryClassifier): name = "Kibalama-9c" description = ( "Wav2Vec2 fine-tuned on 9-class baby cry dataset " "(HuggingFace: Kibalama/baby_cry_classification_model)" ) MODEL_ID = "Kibalama/baby_cry_classification_model" def __init__(self) -> None: super().__init__() self._pipe = None def load(self) -> None: from transformers import pipeline self._pipe = pipeline( "audio-classification", model=self.MODEL_ID, device="cpu", ) self._loaded = True def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction: from audio.preprocess import SAMPLE_RATE, resample t0 = time.perf_counter() try: if sr != SAMPLE_RATE: audio_np = resample(audio_np, sr, SAMPLE_RATE) results = self._pipe( {"raw": audio_np, "sampling_rate": SAMPLE_RATE}, top_k=9, ) # Pick the top *cry-related* label (skip laugh/noise/silence) for res in results: raw = res["label"] mapped = _LABEL_MAP.get(raw, raw) if mapped is not None: label_raw = mapped confidence = res["score"] break else: # All top results were non-cry categories label_raw = "no_cry" confidence = 0.0 latency = (time.perf_counter() - t0) * 1000 return CryPrediction( model_name=self.name, label=label_raw, display_label=display_label(label_raw), confidence=confidence, latency_ms=latency, ) except Exception as exc: latency = (time.perf_counter() - t0) * 1000 return CryPrediction( model_name=self.name, label="error", display_label="⚠️ Error", confidence=0.0, latency_ms=latency, error=str(exc), )