tot-talk / models /yamnet.py
grungecoder's picture
Initial commit: real-time multi-model baby cry classifier
ea2601f
"""Model 3 — YAMNet binary baby-cry detector via TensorFlow Hub."""
from __future__ import annotations
import time
import numpy as np
from models.base import CryClassifier, CryPrediction, display_label
# Column of YAMNet's per-frame score matrix that is read as the
# "Baby cry, infant cry" class of the AudioSet ontology used by YAMNet.
# NOTE(review): the index is hard-coded rather than looked up by name —
# confirm it against the class map shipped with the model version in use.
_BABY_CRY_CLASS_INDEX = 20
class YAMNetDetector(CryClassifier):
    """Binary cry / not-cry detector built on the TF Hub YAMNet model.

    The detector averages YAMNet's per-frame score for the baby-cry class
    over the whole clip and compares that average with a threshold.
    """

    name = "YAMNet-detector"
    description = "YAMNet (TF Hub) binary cry detector — gates the reason classifiers"

    def __init__(self, threshold: float = 0.4) -> None:
        """Create an unloaded detector.

        Args:
            threshold: Minimum clip-averaged baby-cry score for the clip to
                be labelled ``"cry"``. Defaults to 0.4.
        """
        super().__init__()
        self._model = None
        self._threshold = threshold

    def load(self) -> None:
        """Fetch the YAMNet model from TF Hub and mark the detector loaded.

        Calling this again once the model is held is a no-op, so the hub
        module is not resolved a second time.
        """
        if self._model is None:
            # Imported lazily so the module can be imported without
            # tensorflow_hub being installed.
            import tensorflow_hub as hub

            self._model = hub.load("https://tfhub.dev/google/yamnet/1")
        self._loaded = True

    def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction:
        """Classify a clip as ``"cry"`` or ``"not_cry"``.

        Args:
            audio_np: Audio samples. They are resampled to ``SAMPLE_RATE``
                when ``sr`` differs and cast to float32 before inference.
            sr: Sample rate of ``audio_np`` in Hz.

        Returns:
            A ``CryPrediction`` whose ``confidence`` is the clip-averaged
            baby-cry score. Any exception raised during inference (including
            the model not being loaded) is reported as a prediction with
            ``label="error"`` and the message in ``error`` instead of being
            propagated.
        """
        import tensorflow as tf
        from audio.preprocess import SAMPLE_RATE, resample

        t0 = time.perf_counter()
        try:
            if self._model is None:
                # Fail with an actionable message rather than the opaque
                # "'NoneType' object is not callable".
                raise RuntimeError("YAMNet model is not loaded; call load() first")

            if sr != SAMPLE_RATE:
                audio_np = resample(audio_np, sr, SAMPLE_RATE)

            waveform = tf.cast(audio_np, tf.float32)
            # The model returns (scores, embeddings, spectrogram); only the
            # scores are needed here.
            scores, _, _ = self._model(waveform)

            # scores shape: (num_frames, 521)
            cry_scores = scores.numpy()[:, _BABY_CRY_CLASS_INDEX]
            # Guard against a zero-frame result, where the mean would be NaN.
            avg_cry_score = float(np.mean(cry_scores)) if cry_scores.size else 0.0

            label_raw = "cry" if avg_cry_score >= self._threshold else "not_cry"
            latency = (time.perf_counter() - t0) * 1000

            return CryPrediction(
                model_name=self.name,
                label=label_raw,
                display_label=display_label(label_raw),
                confidence=avg_cry_score,
                latency_ms=latency,
            )
        except Exception as exc:
            # Deliberate boundary: callers receive an error prediction
            # instead of an exception.
            latency = (time.perf_counter() - t0) * 1000
            return CryPrediction(
                model_name=self.name,
                label="error",
                display_label="⚠️ Error",
                confidence=0.0,
                latency_ms=latency,
                error=str(exc),
            )