"""Model 3 — YAMNet binary baby-cry detector via TensorFlow Hub.""" from __future__ import annotations import time import numpy as np from models.base import CryClassifier, CryPrediction, display_label # "Baby cry, infant cry" class index in the AudioSet ontology used by YAMNet. _BABY_CRY_CLASS_INDEX = 20 class YAMNetDetector(CryClassifier): name = "YAMNet-detector" description = "YAMNet (TF Hub) binary cry detector — gates the reason classifiers" def __init__(self) -> None: super().__init__() self._model = None def load(self) -> None: import tensorflow_hub as hub self._model = hub.load("https://tfhub.dev/google/yamnet/1") self._loaded = True def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction: import tensorflow as tf from audio.preprocess import SAMPLE_RATE, resample t0 = time.perf_counter() try: if sr != SAMPLE_RATE: audio_np = resample(audio_np, sr, SAMPLE_RATE) waveform = tf.cast(audio_np, tf.float32) scores, embeddings, spectrogram = self._model(waveform) # scores shape: (num_frames, 521) scores_np = scores.numpy() cry_scores = scores_np[:, _BABY_CRY_CLASS_INDEX] avg_cry_score = float(np.mean(cry_scores)) is_cry = avg_cry_score >= 0.4 label_raw = "cry" if is_cry else "not_cry" latency = (time.perf_counter() - t0) * 1000 return CryPrediction( model_name=self.name, label=label_raw, display_label=display_label(label_raw), confidence=avg_cry_score, latency_ms=latency, ) except Exception as exc: latency = (time.perf_counter() - t0) * 1000 return CryPrediction( model_name=self.name, label="error", display_label="⚠️ Error", confidence=0.0, latency_ms=latency, error=str(exc), )