Spaces:
Sleeping
Sleeping
File size: 7,576 Bytes
"""Ensemble runner — loads all models and orchestrates per-window inference."""
from __future__ import annotations
import concurrent.futures
import threading
from typing import Sequence
import numpy as np
from rich.console import Console
from models.base import CryClassifier, CryPrediction, display_label
from models.foduucom_svc import FoduucomSVC
from models.kibalama import KibalamaCry
from models.wiam_wav2vec2 import DistilHuBERTCry
from models.yamnet import YAMNetDetector
# Shared Rich console; constructed with stderr=True so messages go to stderr.
console = Console(stderr=True)

# Short model names (as accepted by the CLI --models filter) mapped to the
# classifier class to instantiate.  Insertion order matters: it is the order
# used when no explicit model list is supplied.
MODEL_REGISTRY: dict[str, type[CryClassifier]] = {
    "svc": FoduucomSVC,
    "hubert": DistilHuBERTCry,
    "kibalama": KibalamaCry,
    "yamnet": YAMNetDetector,
}
class EnsembleClassifier:
    """Instantiate several cry classifiers and run them on one audio window.

    Classifiers are created from ``MODEL_REGISTRY`` by short name.  A
    ``YAMNetDetector`` instance, when present, is kept apart from the other
    ("reason") classifiers so that it can gate them: when gating is enabled
    and it reports ``not_cry`` with a confidence below ``GATE_CRY_THRESHOLD``,
    the reason classifiers are skipped for that window.
    """

    # A YAMNet "not_cry" result whose confidence is below this value closes
    # the gate.  NOTE(review): the original inline comment calls this value a
    # "cry-score" -- confirm what YAMNetDetector stores in ``confidence``.
    GATE_CRY_THRESHOLD = 0.4

    # Seconds to wait on each threaded classifier before reporting a timeout.
    INFERENCE_TIMEOUT_S = 2.0

    def __init__(
        self,
        model_names: Sequence[str] | None = None,
        use_yamnet_gate: bool = True,
    ) -> None:
        """Instantiate (but do not load) the requested classifiers.

        Args:
            model_names: Short names to look up in ``MODEL_REGISTRY``
                (lower-cased before lookup).  ``None`` selects every
                registered model.  Unknown names are reported on the console
                and skipped.
            use_yamnet_gate: When true, YAMNet is added to the front of the
                list if it was not requested, and ``predict_all`` may skip
                the reason classifiers based on its result.
        """
        self.use_yamnet_gate = use_yamnet_gate

        # Decide which models to instantiate.
        if model_names is None:
            names = list(MODEL_REGISTRY.keys())
        else:
            names = [n.lower() for n in model_names]

        # Gating needs YAMNet even if the caller did not ask for it.
        if use_yamnet_gate and "yamnet" not in names:
            names.insert(0, "yamnet")

        self._classifiers: list[CryClassifier] = []
        for n in names:
            cls = MODEL_REGISTRY.get(n)
            if cls is None:
                console.print(f"[yellow]β Unknown model '{n}' β skipping[/yellow]")
                continue
            self._classifiers.append(cls())

        # Split the detector from the classifiers it gates.
        self._yamnet: YAMNetDetector | None = None
        self._reason_classifiers: list[CryClassifier] = []
        for c in self._classifiers:
            if isinstance(c, YAMNetDetector):
                self._yamnet = c
            else:
                self._reason_classifiers.append(c)

    # -- Loading -----------------------------------------------------------

    def load_all(self) -> dict[str, str | None]:
        """Load every classifier in parallel.

        Returns:
            A mapping of classifier name to ``None`` on success or to the
            stringified exception on failure, in classifier order.  Empty
            when no classifier was instantiated.
        """
        # ThreadPoolExecutor rejects max_workers=0, so bail out early.
        if not self._classifiers:
            return {}

        def _load(clf: CryClassifier) -> tuple[str, str | None]:
            # A failure is recorded rather than raised so one broken model
            # does not prevent the others from loading.
            try:
                clf.load()
            except Exception as exc:
                return clf.name, str(exc)
            return clf.name, None

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(self._classifiers)
        ) as pool:
            # map() yields results in input order, so no lock or shared dict
            # is needed and the result order is stable.
            return dict(pool.map(_load, self._classifiers))

    # -- Inference ---------------------------------------------------------

    @staticmethod
    def _not_loaded(clf: CryClassifier) -> CryPrediction:
        """Build the placeholder reported for a classifier that is not loaded."""
        return CryPrediction(
            model_name=clf.name,
            label="error",
            display_label="β οΈ Load Error",
            confidence=0.0,
            latency_ms=0.0,
            error="Model not loaded",
        )

    def predict_all(
        self,
        audio_np: np.ndarray,
        sr: int,
    ) -> list[CryPrediction]:
        """Run the ensemble on one audio window.

        Args:
            audio_np: Audio samples, passed unchanged to each classifier.
            sr: Sample rate, passed unchanged to each classifier.

        Returns:
            One prediction per instantiated classifier: YAMNet first (when
            present), then placeholders for unloaded reason classifiers, then
            the inline SVC result, then the threaded classifiers.  If the
            gate closes, every reason classifier gets a ``no_cry`` placeholder
            instead.

        Exceptions raised by a classifier's ``predict`` propagate to the
        caller.
        """
        predictions: list[CryPrediction] = []

        # 1. YAMNet gate
        if self._yamnet is not None:
            if not self._yamnet.is_loaded():
                predictions.append(self._not_loaded(self._yamnet))
            else:
                yamnet_pred = self._yamnet.predict(audio_np, sr)
                predictions.append(yamnet_pred)
                gate_closed = (
                    self.use_yamnet_gate
                    and yamnet_pred.label == "not_cry"
                    and yamnet_pred.confidence < self.GATE_CRY_THRESHOLD
                )
                if gate_closed:
                    # Skip reason classifiers: no cry detected.
                    predictions.extend(
                        CryPrediction(
                            model_name=rc.name,
                            label="no_cry",
                            display_label="β No cry",
                            confidence=0.0,
                            latency_ms=0.0,
                        )
                        for rc in self._reason_classifiers
                    )
                    return predictions

        # 2. Run reason classifiers.  FoduucomSVC runs inline (the original
        # comment describes it as sub-millisecond); every other classifier
        # runs in a worker thread with a timeout.
        timeout_s = self.INFERENCE_TIMEOUT_S
        inline_results: list[CryPrediction] = []
        pending: list[tuple[CryClassifier, concurrent.futures.Future[CryPrediction]]] = []

        # The executor is deliberately not used as a context manager: leaving
        # a ``with`` block calls shutdown(wait=True), which would block until
        # every inference finished and defeat the timeout below.
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
        try:
            for clf in self._reason_classifiers:
                if not clf.is_loaded():
                    predictions.append(self._not_loaded(clf))
                elif isinstance(clf, FoduucomSVC):
                    inline_results.append(clf.predict(audio_np, sr))
                else:
                    pending.append((clf, pool.submit(clf.predict, audio_np, sr)))

            predictions.extend(inline_results)

            for clf, fut in pending:
                try:
                    predictions.append(fut.result(timeout=timeout_s))
                except concurrent.futures.TimeoutError:
                    predictions.append(
                        CryPrediction(
                            model_name=clf.name,
                            label="timeout",
                            display_label="β³ Timeout",
                            confidence=0.0,
                            latency_ms=timeout_s * 1000.0,
                            error=f"Inference timed out (>{timeout_s:g} s)",
                        )
                    )
        finally:
            # Do not wait for stragglers; a timed-out inference keeps running
            # in its worker thread and its result is discarded.
            pool.shutdown(wait=False)

        return predictions

    @property
    def classifiers(self) -> list[CryClassifier]:
        """A shallow copy of the instantiated classifiers, in creation order."""
        return list(self._classifiers)
def compute_consensus(predictions: list[CryPrediction]) -> str | None:
    """Summarise the confidence-weighted vote of the reason classifiers.

    Predictions from the model named ``"YAMNet-detector"`` are ignored, as
    are predictions that carry an error or whose label is ``no_cry``,
    ``timeout`` or ``error``.  Every remaining prediction adds its confidence
    to the tally of its label; the label with the largest tally wins (the
    label seen first wins a tie).

    Returns:
        ``"<display label> (<k>/<n> agree)"`` where ``k`` is the number of
        valid voters that chose the winning label and ``n`` the number of
        valid voters, or ``None`` when there are no valid voters.
    """
    non_votes = ("no_cry", "timeout", "error")
    ballots = [
        pred
        for pred in predictions
        if pred.model_name != "YAMNet-detector"
        and not (pred.error or pred.label in non_votes)
    ]
    if not ballots:
        return None

    weight_by_label: dict[str, float] = {}
    count_by_label: dict[str, int] = {}
    for ballot in ballots:
        label = ballot.label
        weight_by_label[label] = weight_by_label.get(label, 0.0) + ballot.confidence
        count_by_label[label] = count_by_label.get(label, 0) + 1

    # max() over the dict walks labels in first-seen order, so ties resolve
    # to the earliest label.
    winner = max(weight_by_label, key=weight_by_label.get)
    return f"{display_label(winner)} ({count_by_label[winner]}/{len(ballots)} agree)"
|