tot-talk / models /kibalama.py
grungecoder's picture
Initial commit: real-time multi-model baby cry classifier
ea2601f
"""Model — Kibalama/baby_cry_classification_model (Wav2Vec2, 9-class)."""
from __future__ import annotations
import time
import numpy as np
from models.base import CryClassifier, CryPrediction, display_label
# Map Kibalama's raw labels to our canonical label set.
# A value of None marks a non-cry label; predict() skips those entries when
# choosing the label to report. Labels missing from this map are passed
# through unchanged by the `.get(raw, raw)` lookup in predict().
_LABEL_MAP: dict[str, str | None] = {
    "belly pain": "belly_pain",
    "burping": "burping",
    "cold_hot": "cold_hot",
    "discomfort": "discomfort",
    "hungry": "hungry",
    "tired": "tired",
    # Non-cry labels — mapped to None so they are never selected as the result
    "laugh": None,
    "noise": None,
    "silence": None,
}
class KibalamaCry(CryClassifier):
    """Cry classifier backed by Kibalama/baby_cry_classification_model.

    The model is wrapped in a Hugging Face ``audio-classification`` pipeline
    that is created by :meth:`load`. :meth:`predict` reports the
    highest-ranked label that maps to a cry category in ``_LABEL_MAP``.
    """

    name = "Kibalama-9c"
    description = (
        "Wav2Vec2 fine-tuned on 9-class baby cry dataset "
        "(HuggingFace: Kibalama/baby_cry_classification_model)"
    )
    MODEL_ID = "Kibalama/baby_cry_classification_model"

    def __init__(self) -> None:
        super().__init__()
        # Set by load(); remains None until the pipeline has been built.
        self._pipe = None

    def load(self) -> None:
        """Create the audio-classification pipeline for ``MODEL_ID`` on CPU."""
        # Imported here rather than at module top, so transformers is only
        # required once load() is actually called.
        from transformers import pipeline

        self._pipe = pipeline(
            "audio-classification",
            model=self.MODEL_ID,
            device="cpu",
        )
        self._loaded = True

    @staticmethod
    def _top_cry_result(results) -> tuple[str, float]:
        """Return ``(label, score)`` for the first cry-related entry in *results*.

        *results* is iterated in order; each entry must provide ``"label"``
        and ``"score"`` keys. Entries whose label maps to ``None`` in
        ``_LABEL_MAP`` (laugh/noise/silence) are skipped, and labels absent
        from the map are used unchanged. If no entry qualifies, including
        when *results* is empty, ``("no_cry", 0.0)`` is returned.
        """
        for entry in results:
            raw = entry["label"]
            mapped = _LABEL_MAP.get(raw, raw)
            if mapped is not None:
                # float() keeps the confidence a plain builtin number even if
                # the pipeline hands back a numpy scalar.
                return mapped, float(entry["score"])
        return "no_cry", 0.0

    def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction:
        """Classify one audio buffer and return the top cry-related label.

        Args:
            audio_np: Audio samples; resampled to ``SAMPLE_RATE`` when *sr*
                differs from it.
            sr: Sampling rate of *audio_np*.

        Returns:
            A ``CryPrediction`` carrying the chosen label, its score and the
            elapsed time in milliseconds. Any exception raised while
            predicting (including calling this before :meth:`load`) is
            reported as a prediction with ``label="error"`` and the message
            in ``error`` instead of propagating.
        """
        from audio.preprocess import SAMPLE_RATE, resample

        t0 = time.perf_counter()
        try:
            if self._pipe is None:
                # Without this guard the failure surfaces as the opaque
                # "'NoneType' object is not callable".
                raise RuntimeError(
                    f"{self.name} model is not loaded; call load() before predict()"
                )
            if sr != SAMPLE_RATE:
                audio_np = resample(audio_np, sr, SAMPLE_RATE)
            # NOTE(review): top_k=9 presumably returns every class of this
            # 9-class model, so a cry label is always found and the "no_cry"
            # fallback would only fire on an empty result — confirm whether a
            # dominant laugh/noise/silence score should yield "no_cry" instead.
            results = self._pipe(
                {"raw": audio_np, "sampling_rate": SAMPLE_RATE},
                top_k=9,
            )
            # Pick the top *cry-related* label (skip laugh/noise/silence)
            label_raw, confidence = self._top_cry_result(results)
            latency = (time.perf_counter() - t0) * 1000
            return CryPrediction(
                model_name=self.name,
                label=label_raw,
                display_label=display_label(label_raw),
                confidence=confidence,
                latency_ms=latency,
            )
        except Exception as exc:
            # Deliberately broad: a failing model is reported to the caller
            # as an error prediction rather than raised.
            latency = (time.perf_counter() - t0) * 1000
            return CryPrediction(
                model_name=self.name,
                label="error",
                display_label="⚠️ Error",
                confidence=0.0,
                latency_ms=latency,
                error=str(exc),
            )