tot-talk / models /ensemble.py
grungecoder's picture
Initial commit: real-time multi-model baby cry classifier
ea2601f
"""Ensemble runner — loads all models and orchestrates per-window inference."""
from __future__ import annotations
import concurrent.futures
import threading
from typing import Sequence
import numpy as np
from rich.console import Console
from models.base import CryClassifier, CryPrediction, display_label
from models.foduucom_svc import FoduucomSVC
from models.kibalama import KibalamaCry
from models.wiam_wav2vec2 import DistilHuBERTCry
from models.yamnet import YAMNetDetector
# Rich console bound to stderr for status / warning output.
console = Console(stderr=True)

# Short CLI name (as given to --models) → classifier class.
# Key order is significant: it is the model order used when no explicit
# model list is requested.
MODEL_REGISTRY: dict[str, type[CryClassifier]] = dict(
    svc=FoduucomSVC,
    hubert=DistilHuBERTCry,
    kibalama=KibalamaCry,
    yamnet=YAMNetDetector,
)
class EnsembleClassifier:
    """Loads and runs multiple cry classifiers, aggregating results.

    An optional YAMNet detector acts as a cry / no-cry gate; every other
    classifier is a "reason" classifier whose prediction is collected for
    each audio window.
    """

    # Shared wall-clock budget (seconds) for all threaded reason classifiers
    # within a single predict_all() call.
    INFERENCE_TIMEOUT_S: float = 2.0
    # A YAMNet "not_cry" prediction whose confidence is below this value skips
    # the reason classifiers. NOTE(review): the original code described this
    # value as the cry-score — confirm what YAMNetDetector stores in it.
    GATE_CONFIDENCE_THRESHOLD: float = 0.4

    def __init__(
        self,
        model_names: Sequence[str] | None = None,
        use_yamnet_gate: bool = True,
    ) -> None:
        """Instantiate the requested classifiers (weights are not loaded yet).

        Args:
            model_names: Case-insensitive short names from ``MODEL_REGISTRY``;
                ``None`` selects every registered model. Unknown names are
                reported on the console and skipped; duplicates are ignored.
            use_yamnet_gate: If true, YAMNet is always instantiated and a
                low-confidence "not_cry" verdict from it skips the reason
                classifiers in ``predict_all``.
        """
        self.use_yamnet_gate = use_yamnet_gate

        # Decide which models to instantiate
        if model_names is None:
            requested = list(MODEL_REGISTRY)
        else:
            requested = [n.strip().lower() for n in model_names]
        # Drop duplicates (keeping first-seen order) so a model is not built
        # twice and does not get a double vote in the consensus.
        names = list(dict.fromkeys(requested))

        # Always include YAMNet if gating is enabled and it's not already listed
        if use_yamnet_gate and "yamnet" not in names:
            names.insert(0, "yamnet")

        self._classifiers: list[CryClassifier] = []
        for n in names:
            cls = MODEL_REGISTRY.get(n)
            if cls is None:
                console.print(f"[yellow]⚠ Unknown model '{n}' — skipping[/yellow]")
                continue
            self._classifiers.append(cls())

        self._yamnet: YAMNetDetector | None = None
        self._reason_classifiers: list[CryClassifier] = []
        for c in self._classifiers:
            if isinstance(c, YAMNetDetector):
                self._yamnet = c
            else:
                self._reason_classifiers.append(c)

    # ── Loading ───────────────────────────────────────────────────────────
    def load_all(self) -> dict[str, str | None]:
        """Load every classifier concurrently.

        Returns:
            ``{classifier.name: None}`` for each successful ``load()`` and
            ``{classifier.name: str(exception)}`` for each failure. Empty when
            there are no classifiers.
        """
        results: dict[str, str | None] = {}
        if not self._classifiers:
            # ThreadPoolExecutor rejects max_workers=0.
            return results
        lock = threading.Lock()

        def _load(clf: CryClassifier) -> None:
            # Best-effort: a failing model is reported, not raised, so the
            # remaining models still load.
            try:
                clf.load()
                error = None
            except Exception as exc:
                error = str(exc)
            with lock:
                results[clf.name] = error

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(self._classifiers)
        ) as pool:
            # Drain the iterator so an unexpected error is not silently lost.
            list(pool.map(_load, self._classifiers))
        return results

    # ── Inference ─────────────────────────────────────────────────────────
    def predict_all(
        self,
        audio_np: np.ndarray,
        sr: int,
        timeout_s: float | None = None,
    ) -> list[CryPrediction]:
        """Run the YAMNet gate and the reason classifiers on one audio window.

        Args:
            audio_np: Audio for the window; passed unchanged to each
                classifier's ``predict``.
            sr: Sample rate passed alongside ``audio_np``.
            timeout_s: Shared wall-clock budget for the threaded (non-SVC)
                classifiers; defaults to ``INFERENCE_TIMEOUT_S``.

        Returns:
            One prediction per classifier, ordered: YAMNet (if configured),
            then load errors, then inline (SVC) results, then threaded results.
            A threaded classifier that misses the budget yields a "timeout"
            prediction; one whose ``predict`` raises yields an "error" one.
            If the gate fires, every reason classifier gets a "no_cry"
            placeholder instead of being run.
        """
        budget = self.INFERENCE_TIMEOUT_S if timeout_s is None else timeout_s
        predictions: list[CryPrediction] = []

        # 1. YAMNet gate
        if self._yamnet is not None:
            if not self._yamnet.is_loaded():
                predictions.append(self._load_error(self._yamnet))
            else:
                yamnet_pred = self._safe_predict(self._yamnet, audio_np, sr)
                predictions.append(yamnet_pred)
                if (
                    self.use_yamnet_gate
                    and yamnet_pred.label == "not_cry"
                    and yamnet_pred.confidence < self.GATE_CONFIDENCE_THRESHOLD
                ):
                    # Skip reason classifiers — no cry detected
                    predictions.extend(
                        self._placeholder(rc, "no_cry", "— No cry")
                        for rc in self._reason_classifiers
                    )
                    return predictions

        # 2. Reason classifiers. Partition first so the slow models can start
        # before the fast one runs; the output order stays:
        # load errors, inline (SVC) results, threaded results.
        inline: list[CryClassifier] = []
        threaded: list[CryClassifier] = []
        for clf in self._reason_classifiers:
            if not clf.is_loaded():
                predictions.append(self._load_error(clf))
            elif isinstance(clf, FoduucomSVC):
                inline.append(clf)  # sub-millisecond — run synchronously
            else:
                threaded.append(clf)  # transformer models — run in worker threads

        if not threaded:
            predictions.extend(self._safe_predict(c, audio_np, sr) for c in inline)
            return predictions

        # Deliberately NOT a ``with`` block: leaving one calls
        # shutdown(wait=True), which blocks until every worker returns and
        # would make the timeout below meaningless.
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len(threaded)))
        try:
            futures = [
                (clf, pool.submit(self._safe_predict, clf, audio_np, sr))
                for clf in threaded
            ]
            predictions.extend(self._safe_predict(c, audio_np, sr) for c in inline)
            # One shared deadline for all threaded models, not one per model.
            done, _ = concurrent.futures.wait(
                [fut for _, fut in futures], timeout=budget
            )
            for clf, fut in futures:
                if fut in done:
                    predictions.append(fut.result())
                    continue
                fut.cancel()  # only succeeds if the task never started
                predictions.append(
                    self._placeholder(
                        clf,
                        "timeout",
                        "⏳ Timeout",
                        latency_ms=budget * 1000.0,
                        error=f"Inference timed out (>{budget:g} s)",
                    )
                )
        finally:
            # Python threads cannot be killed: a timed-out predict() keeps
            # running in the background until it returns. NOTE(review): the
            # next window may then call the same model concurrently — confirm
            # the classifiers tolerate that.
            pool.shutdown(wait=False)
        return predictions

    # ── Helpers ───────────────────────────────────────────────────────────
    @staticmethod
    def _placeholder(
        clf: CryClassifier,
        label: str,
        display: str,
        *,
        latency_ms: float = 0.0,
        error: str | None = None,
    ) -> CryPrediction:
        """Build a zero-confidence prediction for a classifier with no real result."""
        return CryPrediction(
            model_name=clf.name,
            label=label,
            display_label=display,
            confidence=0.0,
            latency_ms=latency_ms,
            error=error,
        )

    @classmethod
    def _load_error(cls, clf: CryClassifier) -> CryPrediction:
        """Prediction reported for a classifier whose model is not loaded."""
        return cls._placeholder(clf, "error", "⚠️ Load Error", error="Model not loaded")

    @classmethod
    def _safe_predict(
        cls, clf: CryClassifier, audio_np: np.ndarray, sr: int
    ) -> CryPrediction:
        """Run ``clf.predict``, converting an exception into an "error" prediction."""
        try:
            return clf.predict(audio_np, sr)
        except Exception as exc:
            return cls._placeholder(
                clf, "error", "⚠️ Inference Error", error=f"{type(exc).__name__}: {exc}"
            )

    @property
    def classifiers(self) -> list[CryClassifier]:
        """A shallow copy of every instantiated classifier (YAMNet included)."""
        return list(self._classifiers)
def compute_consensus(
    predictions: list[CryPrediction],
    exclude_models: Sequence[str] = ("YAMNet-detector",),
) -> str | None:
    """Confidence-weighted vote across the *reason* classifiers.

    Each usable prediction adds its ``confidence`` to its label's total, and
    the label with the largest total wins (ties go to the label seen first).
    A prediction does not vote if its ``model_name`` is in *exclude_models*
    (the YAMNet gate by default), if it carries an ``error``, or if its label
    is one of the placeholders ``no_cry``, ``timeout`` or ``error``.

    Args:
        predictions: Per-model predictions for one audio window.
        exclude_models: ``model_name`` values that never vote.

    Returns:
        ``"<display label> (k/n agree)"``, where *k* models voted for the
        winning label and *n* models voted in total. Returns ``None`` only when
        no prediction voted. Disagreement still produces a winner and shows up
        only in *k/n*.
    """
    skipped_labels = ("no_cry", "timeout", "error")
    voters = [
        p
        for p in predictions
        if p.model_name not in exclude_models
        and not p.error
        and p.label not in skipped_labels
    ]
    if not voters:
        return None

    weighted_votes: dict[str, float] = {}
    vote_count: dict[str, int] = {}
    for p in voters:
        weighted_votes[p.label] = weighted_votes.get(p.label, 0.0) + p.confidence
        vote_count[p.label] = vote_count.get(p.label, 0) + 1

    top_label = max(weighted_votes, key=weighted_votes.__getitem__)
    return f"{display_label(top_label)} ({vote_count[top_label]}/{len(voters)} agree)"