"""Ensemble runner — loads all models and orchestrates per-window inference."""

from __future__ import annotations

import concurrent.futures
import threading
import time
from typing import Sequence

import numpy as np
from rich.console import Console

from models.base import CryClassifier, CryPrediction, display_label
from models.foduucom_svc import FoduucomSVC
from models.kibalama import KibalamaCry
from models.wiam_wav2vec2 import DistilHuBERTCry
from models.yamnet import YAMNetDetector

console = Console(stderr=True)

# Map of short names → classes for CLI --models filtering
MODEL_REGISTRY: dict[str, type[CryClassifier]] = {
    "svc": FoduucomSVC,
    "hubert": DistilHuBERTCry,
    "kibalama": KibalamaCry,
    "yamnet": YAMNetDetector,
}

# ``model_name`` excluded from the consensus vote. NOTE(review): assumed to
# equal ``YAMNetDetector.name`` (defined in models.yamnet) — confirm.
_YAMNET_MODEL_NAME = "YAMNet-detector"


def _describe_exception(exc: BaseException) -> str:
    """Return a non-empty description of *exc*.

    ``str(exc)`` is ``""`` for exceptions raised without a message, which a
    caller testing ``if error:`` would mistake for success.
    """
    return str(exc) or type(exc).__name__


def _failed_prediction(
    model_name: str,
    label: str,
    display: str,
    error: str,
    latency_ms: float = 0.0,
) -> CryPrediction:
    """Build a zero-confidence prediction that records a failure for *model_name*."""
    return CryPrediction(
        model_name=model_name,
        label=label,
        display_label=display,
        confidence=0.0,
        latency_ms=latency_ms,
        error=error,
    )


def _safe_predict(clf: CryClassifier, audio_np: np.ndarray, sr: int) -> CryPrediction:
    """Run ``clf.predict`` and turn any exception into an ``"error"`` prediction.

    One misbehaving model then costs only its own row instead of aborting
    the whole window's inference.
    """
    try:
        return clf.predict(audio_np, sr)
    except Exception as exc:  # per-model boundary: surface as an error row
        return _failed_prediction(
            clf.name, "error", "⚠️ Inference Error", _describe_exception(exc)
        )


class EnsembleClassifier:
    """Loads and runs multiple cry classifiers, aggregating results.

    Classifiers are split into an optional YAMNet detector (used as a
    cry/no-cry gate) and the remaining "reason" classifiers.
    """

    def __init__(
        self,
        model_names: Sequence[str] | None = None,
        use_yamnet_gate: bool = True,
        *,
        yamnet_gate_threshold: float = 0.4,
        reason_timeout_s: float = 2.0,
    ) -> None:
        """Instantiate (but do not load) the requested classifiers.

        Args:
            model_names: Short names from ``MODEL_REGISTRY`` (case-insensitive,
                surrounding whitespace ignored, duplicates dropped). ``None``
                selects every registered model. Unknown names are reported on
                the console and skipped.
            use_yamnet_gate: Skip reason classifiers when YAMNet reports no
                cry. Forces YAMNet into the model list when enabled.
            yamnet_gate_threshold: A ``not_cry`` YAMNet prediction whose
                confidence is below this value triggers the gate.
            reason_timeout_s: Wall-clock budget shared by all threaded
                (slow) reason classifiers in one ``predict_all`` call.
        """
        self.use_yamnet_gate = use_yamnet_gate
        self.yamnet_gate_threshold = yamnet_gate_threshold
        self.reason_timeout_s = reason_timeout_s

        # Decide which models to instantiate
        if model_names is None:
            names = list(MODEL_REGISTRY)
        else:
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # same model is never loaded twice or double-counted in consensus.
            names = list(dict.fromkeys(n.strip().lower() for n in model_names))

        # Always include YAMNet if gating is enabled and it's not already in the list
        if use_yamnet_gate and "yamnet" not in names:
            names.insert(0, "yamnet")

        self._classifiers: list[CryClassifier] = []
        for n in names:
            cls = MODEL_REGISTRY.get(n)
            if cls is None:
                console.print(f"[yellow]⚠ Unknown model '{n}' — skipping[/yellow]")
                continue
            self._classifiers.append(cls())

        self._yamnet: YAMNetDetector | None = None
        self._reason_classifiers: list[CryClassifier] = []
        for c in self._classifiers:
            if isinstance(c, YAMNetDetector):
                self._yamnet = c
            else:
                self._reason_classifiers.append(c)

    # ── Loading ───────────────────────────────────────────────────────────

    def load_all(self) -> dict[str, str | None]:
        """Load every model in parallel.

        Returns:
            ``{classifier.name: None}`` for successful loads and
            ``{classifier.name: message}`` (non-empty) for failures. Load
            exceptions are captured, never raised. Empty dict when there are
            no classifiers.
        """
        results: dict[str, str | None] = {}
        if not self._classifiers:
            # ThreadPoolExecutor(max_workers=0) would raise ValueError.
            return results

        lock = threading.Lock()

        def _load(clf: CryClassifier) -> None:
            try:
                clf.load()
            except Exception as exc:
                outcome: str | None = _describe_exception(exc)
            else:
                outcome = None
            with lock:
                results[clf.name] = outcome

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(self._classifiers)
        ) as pool:
            # Drain the iterator so anything escaping _load is re-raised here
            # instead of being silently dropped.
            list(pool.map(_load, self._classifiers))
        return results

    # ── Inference ─────────────────────────────────────────────────────────

    def predict_all(
        self,
        audio_np: np.ndarray,
        sr: int,
    ) -> list[CryPrediction]:
        """Run the ensemble on one audio window.

        Returns one prediction per instantiated classifier, in this order:
        YAMNet (if present), then reason classifiers that are not loaded, then
        the inline SVC, then the threaded models. Failures appear as rows with
        ``label`` ``"error"`` / ``"timeout"`` rather than raising.
        """
        predictions: list[CryPrediction] = []

        # 1. YAMNet gate
        if self._yamnet is not None and self._yamnet.is_loaded():
            yamnet_pred = _safe_predict(self._yamnet, audio_np, sr)
            predictions.append(yamnet_pred)
            # Per the original note, confidence on a "not_cry" prediction is the
            # cry score. NOTE(review): confirm against models.yamnet.
            if (
                self.use_yamnet_gate
                and yamnet_pred.label == "not_cry"
                and yamnet_pred.confidence < self.yamnet_gate_threshold
            ):
                # Skip reason classifiers — no cry detected
                predictions.extend(
                    CryPrediction(
                        model_name=rc.name,
                        label="no_cry",
                        display_label="— No cry",
                        confidence=0.0,
                        latency_ms=0.0,
                    )
                    for rc in self._reason_classifiers
                )
                return predictions
        elif self._yamnet is not None:
            predictions.append(
                _failed_prediction(
                    self._yamnet.name, "error", "⚠️ Load Error", "Model not loaded"
                )
            )

        # 2. Run reason classifiers
        predictions.extend(self._run_reason_classifiers(audio_np, sr))
        return predictions

    def _run_reason_classifiers(
        self,
        audio_np: np.ndarray,
        sr: int,
    ) -> list[CryPrediction]:
        """Run reason classifiers; threaded ones share one ``reason_timeout_s`` deadline.

        SVC is sub-ms, so it runs synchronously. The transformer models
        (HuBERT, Kibalama) run in worker threads.
        """
        not_loaded: list[CryPrediction] = []
        inline: list[CryClassifier] = []
        threaded: list[CryClassifier] = []
        for clf in self._reason_classifiers:
            if not clf.is_loaded():
                not_loaded.append(
                    _failed_prediction(clf.name, "error", "⚠️ Load Error", "Model not loaded")
                )
            elif isinstance(clf, FoduucomSVC):
                inline.append(clf)
            else:
                threaded.append(clf)

        if not threaded:
            return not_loaded + [_safe_predict(c, audio_np, sr) for c in inline]

        timeout_s = self.reason_timeout_s
        # Deliberately not a ``with`` block: its exit calls shutdown(wait=True),
        # which would block until the slowest model finished and make the
        # timeout meaningless. Python threads cannot be killed, so a model
        # that overruns keeps running in the background after we return, and
        # may still be busy when the next window calls it.
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=len(threaded))
        try:
            deadline = time.monotonic() + timeout_s
            futures = [
                (clf, pool.submit(_safe_predict, clf, audio_np, sr)) for clf in threaded
            ]
            # Fast models run inline while the threaded ones compute.
            results = not_loaded + [_safe_predict(c, audio_np, sr) for c in inline]

            for clf, fut in futures:
                # One shared deadline: waiting on several futures must not
                # stretch the total wait to N × timeout.
                remaining = max(0.0, deadline - time.monotonic())
                try:
                    results.append(fut.result(timeout=remaining))
                except concurrent.futures.TimeoutError:
                    fut.cancel()  # only effective if it never started
                    results.append(
                        _failed_prediction(
                            clf.name,
                            "timeout",
                            "⏳ Timeout",
                            f"Inference timed out (>{timeout_s:g} s)",
                            latency_ms=timeout_s * 1000.0,
                        )
                    )
        finally:
            pool.shutdown(wait=False)
        return results

    @property
    def classifiers(self) -> list[CryClassifier]:
        """A shallow copy of all instantiated classifiers (YAMNet included)."""
        return list(self._classifiers)


def compute_consensus(predictions: list[CryPrediction]) -> str | None:
    """Weighted-vote consensus across *reason* classifiers (YAMNet excluded).

    Each valid prediction adds its confidence to its label's weight.
    Predictions carrying an ``error``, or labelled ``"no_cry"``,
    ``"timeout"`` or ``"error"``, do not vote.

    Returns:
        ``"<display label> (<k>/<n> agree)"`` for the label with the highest
        summed confidence, where *k* is the number of models that voted for
        it and *n* the number of valid voters. Returns ``None`` only when
        there are no valid votes; a winner is reported even without a
        majority (e.g. ``1/3 agree``). Ties go to the label seen first in
        *predictions*.
    """
    weighted_votes: dict[str, float] = {}
    vote_count: dict[str, int] = {}

    for p in predictions:
        if p.model_name == _YAMNET_MODEL_NAME:
            continue
        if p.error or p.label in ("no_cry", "timeout", "error"):
            continue
        weighted_votes[p.label] = weighted_votes.get(p.label, 0.0) + p.confidence
        vote_count[p.label] = vote_count.get(p.label, 0) + 1

    if not weighted_votes:
        return None

    # max() returns the first maximal key, i.e. the earliest-seen label on ties.
    top_label = max(weighted_votes, key=weighted_votes.__getitem__)
    total_voters = sum(vote_count.values())
    return f"{display_label(top_label)} ({vote_count[top_label]}/{total_voters} agree)"