Spaces:
Sleeping
Sleeping
| """Model 3 — YAMNet binary baby-cry detector via TensorFlow Hub.""" | |
| from __future__ import annotations | |
| import time | |
| import numpy as np | |
| from models.base import CryClassifier, CryPrediction, display_label | |
| # Column index of the "Baby cry, infant cry" class in the AudioSet ontology used by YAMNet. | |
| # The detector slices the per-frame score matrix with it: scores[:, _BABY_CRY_CLASS_INDEX]. | |
| _BABY_CRY_CLASS_INDEX = 20 | |
class YAMNetDetector(CryClassifier):
    """Binary baby-cry detector built on the YAMNet model from TensorFlow Hub.

    The detector averages YAMNet's per-frame score for the
    "Baby cry, infant cry" class over the whole clip and compares that
    average against a decision threshold.

    Args:
        threshold: Minimum mean cry score (inclusive) for a clip to be
            labelled ``"cry"``. Must lie in ``[0.0, 1.0]``. Defaults to 0.4.

    Raises:
        ValueError: If ``threshold`` is outside ``[0.0, 1.0]`` (or is NaN).
    """

    name = "YAMNet-detector"
    description = "YAMNet (TF Hub) binary cry detector — gates the reason classifiers"

    def __init__(self, *, threshold: float = 0.4) -> None:
        super().__init__()
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be within [0.0, 1.0], got {threshold!r}")
        self._model = None
        self._threshold = float(threshold)

    def load(self) -> None:
        """Fetch the YAMNet model from TensorFlow Hub and mark the detector loaded."""
        # Imported lazily so importing this module does not require TF Hub.
        import tensorflow_hub as hub

        self._model = hub.load("https://tfhub.dev/google/yamnet/1")
        self._loaded = True

    def predict(self, audio_np: np.ndarray, sr: int) -> CryPrediction:
        """Classify a clip as ``"cry"`` or ``"not_cry"``.

        Args:
            audio_np: Audio samples. A 1-D array is expected; arrays whose
                extra axes all have length 1 (e.g. ``(1, n)`` or ``(n, 1)``)
                are flattened. Samples are cast to float32 without rescaling
                (assumes they are already in the range the model expects —
                TODO confirm against the callers).
            sr: Sample rate of ``audio_np`` in Hz; the audio is resampled
                when it differs from ``SAMPLE_RATE``.

        Returns:
            A ``CryPrediction`` whose ``confidence`` is the mean cry score
            of the clip for both labels. Any exception raised while
            predicting is caught and reported as a prediction with
            ``label="error"`` and the message in ``error``; this method
            does not propagate it.
        """
        import tensorflow as tf
        from audio.preprocess import SAMPLE_RATE, resample

        t0 = time.perf_counter()
        try:
            if self._model is None:
                # Fail with an actionable message instead of
                # "'NoneType' object is not callable".
                raise RuntimeError(
                    "YAMNet model is not loaded; call load() before predict()"
                )

            if sr != SAMPLE_RATE:
                audio_np = resample(audio_np, sr, SAMPLE_RATE)

            samples = np.asarray(audio_np)
            # Collapse singleton channel axes, e.g. (1, n) or (n, 1) -> (n,).
            if samples.ndim > 1 and samples.size == max(samples.shape):
                samples = samples.reshape(-1)
            if samples.ndim != 1:
                raise ValueError(
                    f"expected mono 1-D audio, got array with shape {samples.shape}"
                )

            waveform = tf.cast(samples, tf.float32)
            # The model also returns embeddings and a spectrogram; only the
            # scores, shaped (num_frames, 521), are used here.
            scores, _, _ = self._model(waveform)
            cry_scores = scores.numpy()[:, _BABY_CRY_CLASS_INDEX]

            # Averaging zero frames would yield NaN; report 0.0 instead so the
            # confidence stays a real number (the label is "not_cry" either way).
            avg_cry_score = float(np.mean(cry_scores)) if cry_scores.size else 0.0
            label_raw = "cry" if avg_cry_score >= self._threshold else "not_cry"

            latency = (time.perf_counter() - t0) * 1000
            return CryPrediction(
                model_name=self.name,
                label=label_raw,
                display_label=display_label(label_raw),
                confidence=avg_cry_score,
                latency_ms=latency,
            )
        except Exception as exc:
            # Deliberate boundary: a failing detector is reported as an error
            # prediction rather than raised to the caller.
            latency = (time.perf_counter() - t0) * 1000
            return CryPrediction(
                model_name=self.name,
                label="error",
                display_label="⚠️ Error",
                confidence=0.0,
                latency_ms=latency,
                # Some exceptions carry no message; fall back to their repr.
                error=str(exc) or repr(exc),
            )