Spaces:
Sleeping
Sleeping
"""Ensemble runner — loads all models and orchestrates per-window inference."""
| from __future__ import annotations | |
| import concurrent.futures | |
| import threading | |
| from typing import Sequence | |
| import numpy as np | |
| from rich.console import Console | |
| from models.base import CryClassifier, CryPrediction, display_label | |
| from models.foduucom_svc import FoduucomSVC | |
| from models.kibalama import KibalamaCry | |
| from models.wiam_wav2vec2 import DistilHuBERTCry | |
| from models.yamnet import YAMNetDetector | |
# Module-wide rich console; diagnostics are written to stderr.
console = Console(stderr=True)

# Short model name -> classifier class, used for CLI ``--models`` filtering.
# Insertion order matters: it is the instantiation order used by
# ``EnsembleClassifier`` when no explicit model list is supplied.
MODEL_REGISTRY: dict[str, type[CryClassifier]] = {
    "svc": FoduucomSVC,
    "hubert": DistilHuBERTCry,
    "kibalama": KibalamaCry,
    "yamnet": YAMNetDetector,
}
class EnsembleClassifier:
    """Loads and runs multiple cry classifiers, aggregating results.

    Classifiers are looked up by short name in ``MODEL_REGISTRY``. A
    ``YAMNetDetector`` instance, when present, is run first and can gate the
    remaining ("reason") classifiers off for windows it labels as non-cry.
    """

    # Gate trips when YAMNet says "not_cry" with confidence below this value.
    # NOTE(review): the original comment calls this a "cry-score"; confirm
    # against YAMNetDetector what ``confidence`` means for a "not_cry" label.
    _GATE_CONFIDENCE_THRESHOLD = 0.4
    # Seconds to wait for each threaded classifier before reporting a timeout.
    _INFERENCE_TIMEOUT_S = 2.0
    # Upper bound on concurrently running threaded classifiers per window.
    _MAX_INFERENCE_WORKERS = 4

    def __init__(
        self,
        model_names: Sequence[str] | None = None,
        use_yamnet_gate: bool = True,
    ) -> None:
        """Instantiate (but do not load) the requested classifiers.

        Args:
            model_names: Short names from ``MODEL_REGISTRY`` (case-insensitive).
                ``None`` selects every registered model. Unknown names are
                reported on the console and skipped; repeated names are
                instantiated only once.
            use_yamnet_gate: When true, YAMNet is added even if it was not
                requested, and its verdict may short-circuit the reason
                classifiers in ``predict_all``.
        """
        self.use_yamnet_gate = use_yamnet_gate

        # Decide which models to instantiate.
        if model_names is None:
            names = list(MODEL_REGISTRY)
        else:
            # De-duplicate while preserving order: a repeated name would load
            # the same model twice, double-count its vote and collide on the
            # name-keyed result of ``load_all``.
            names = list(dict.fromkeys(n.lower() for n in model_names))

        # Gating needs YAMNet, so force it in (first) when it was left out.
        if use_yamnet_gate and "yamnet" not in names:
            names.insert(0, "yamnet")

        self._classifiers: list[CryClassifier] = []
        for n in names:
            cls = MODEL_REGISTRY.get(n)
            if cls is None:
                # NOTE(review): the two "β" glyphs in this message look like
                # mis-encoded symbols; restore them from the upstream file.
                console.print(f"[yellow]β Unknown model '{n}' β skipping[/yellow]")
                continue
            self._classifiers.append(cls())

        # Split the detector from the classifiers that explain *why* the cry.
        self._yamnet: YAMNetDetector | None = None
        self._reason_classifiers: list[CryClassifier] = []
        for c in self._classifiers:
            if isinstance(c, YAMNetDetector):
                self._yamnet = c
            else:
                self._reason_classifiers.append(c)

    # ── Loading ──────────────────────────────────────────────────────────

    def load_all(self) -> dict[str, str | None]:
        """Load every model in parallel.

        Returns:
            ``{classifier.name: error_message_or_None}`` with one entry per
            classifier, in instantiation order. An empty dict is returned
            when there is nothing to load.
        """
        # ThreadPoolExecutor rejects max_workers=0, so bail out early.
        if not self._classifiers:
            return {}

        def _load(clf: CryClassifier) -> tuple[str, str | None]:
            # Any load failure is reported to the caller as a message rather
            # than raised, so one broken model cannot abort the others.
            try:
                clf.load()
            except Exception as exc:
                return clf.name, str(exc)
            return clf.name, None

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(self._classifiers)
        ) as pool:
            return dict(pool.map(_load, self._classifiers))

    # ── Inference ────────────────────────────────────────────────────────

    @staticmethod
    def _load_error_prediction(model_name: str) -> CryPrediction:
        """Build the placeholder prediction for a model that is not loaded."""
        return CryPrediction(
            model_name=model_name,
            label="error",
            display_label="⚠️ Load Error",
            confidence=0.0,
            latency_ms=0.0,
            error="Model not loaded",
        )

    def predict_all(
        self,
        audio_np: np.ndarray,
        sr: int,
    ) -> list[CryPrediction]:
        """Run the ensemble on one audio window.

        Args:
            audio_np: Audio samples, passed unchanged to each classifier.
            sr: Sample rate, passed unchanged to each classifier.

        Returns:
            One prediction per configured classifier: YAMNet first (when
            configured), then placeholders for unloaded reason classifiers,
            then the inline SVC result, then the threaded results. When the
            YAMNet gate trips, every reason classifier gets a zero-confidence
            ``"no_cry"`` placeholder instead of being run.

        Raises:
            Exception: Whatever a classifier's ``predict`` raises, apart from
                the timeout handled here.
        """
        predictions: list[CryPrediction] = []

        # 1. YAMNet gate
        if self._yamnet is not None:
            if self._yamnet.is_loaded():
                yamnet_pred = self._yamnet.predict(audio_np, sr)
                predictions.append(yamnet_pred)
                gate_tripped = (
                    self.use_yamnet_gate
                    and yamnet_pred.label == "not_cry"
                    and yamnet_pred.confidence < self._GATE_CONFIDENCE_THRESHOLD
                )
                if gate_tripped:
                    # No cry detected: skip the reason classifiers entirely.
                    # NOTE(review): the leading "β" in the display label looks
                    # like a mis-encoded symbol; restore it from upstream.
                    for rc in self._reason_classifiers:
                        predictions.append(
                            CryPrediction(
                                model_name=rc.name,
                                label="no_cry",
                                display_label="β No cry",
                                confidence=0.0,
                                latency_ms=0.0,
                            )
                        )
                    return predictions
            else:
                predictions.append(self._load_error_prediction(self._yamnet.name))

        # 2. Run reason classifiers.
        #    SVC is sub-millisecond, so it runs synchronously; the transformer
        #    models (HuBERT, Kibalama) run in worker threads with a timeout.
        timeout_s = self._INFERENCE_TIMEOUT_S
        inline_results: list[CryPrediction] = []
        thread_futures: list[
            tuple[CryClassifier, concurrent.futures.Future[CryPrediction]]
        ] = []

        # The pool is deliberately NOT used as a context manager: leaving a
        # ``with`` block calls shutdown(wait=True), which would block on a
        # hung model and defeat the timeout below.
        pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=self._MAX_INFERENCE_WORKERS
        )
        try:
            for clf in self._reason_classifiers:
                if not clf.is_loaded():
                    predictions.append(self._load_error_prediction(clf.name))
                    continue
                if isinstance(clf, FoduucomSVC):
                    # Fast: run inline.
                    inline_results.append(clf.predict(audio_np, sr))
                else:
                    # Slow: run in a worker thread.
                    thread_futures.append(
                        (clf, pool.submit(clf.predict, audio_np, sr))
                    )

            predictions.extend(inline_results)

            for clf, fut in thread_futures:
                try:
                    predictions.append(fut.result(timeout=timeout_s))
                except concurrent.futures.TimeoutError:
                    predictions.append(
                        CryPrediction(
                            model_name=clf.name,
                            label="timeout",
                            display_label="⏳ Timeout",
                            confidence=0.0,
                            latency_ms=timeout_s * 1000.0,
                            error=f"Inference timed out (>{timeout_s:g} s)",
                        )
                    )
        finally:
            # Return without waiting for stragglers. A timed-out call keeps
            # running in its worker thread until it finishes on its own.
            pool.shutdown(wait=False)

        return predictions

    def classifiers(self) -> list[CryClassifier]:
        """Return a shallow copy of the instantiated classifiers, in order."""
        return list(self._classifiers)
def compute_consensus(predictions: list[CryPrediction]) -> str | None:
    """Pick the confidence-weighted majority label among reason classifiers.

    Predictions from the model named ``"YAMNet-detector"`` are ignored, as
    are predictions that carry an error or whose label is ``"no_cry"``,
    ``"timeout"`` or ``"error"``. Every remaining prediction adds its
    confidence to the score of its label; the label with the highest total
    wins (on a tie, the label seen first).

    Returns:
        ``"<display label> (<k>/<n> agree)"`` where ``k`` is the number of
        predictions for the winning label and ``n`` the number of counted
        predictions, or ``None`` when no prediction was counted.
    """
    ballots = [
        p
        for p in predictions
        if p.model_name != "YAMNet-detector"
        and not p.error
        and p.label not in ("no_cry", "timeout", "error")
    ]
    if not ballots:
        return None

    score_by_label: dict[str, float] = {}
    tally_by_label: dict[str, int] = {}
    for ballot in ballots:
        label = ballot.label
        score_by_label[label] = score_by_label.get(label, 0.0) + ballot.confidence
        tally_by_label[label] = tally_by_label.get(label, 0) + 1

    # max() returns the first maximal key, so ties go to the earliest label.
    winner = max(score_by_label, key=score_by_label.get)
    supporters = tally_by_label[winner]
    return f"{display_label(winner)} ({supporters}/{len(ballots)} agree)"