Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """TotTalk Cry Eval — real-time multi-model baby cry classifier.""" | |
| from __future__ import annotations | |
| import json | |
| import queue | |
| from pathlib import Path | |
| import click | |
| import numpy as np | |
| import sounddevice as sd | |
| from rich.console import Console | |
| from rich.progress import Progress, SpinnerColumn, TextColumn | |
| from audio.capture import FileCapture, MicCapture | |
| from audio.preprocess import SAMPLE_RATE, compute_rms, is_silent, normalize_audio | |
| from display.table import CryDisplay | |
| from models.ensemble import EnsembleClassifier | |
# Shared Rich console used for every status message in this script;
# stderr=True routes that output to stderr instead of stdout.
console = Console(stderr=True)
def _print_audio_devices() -> None:
    """Print the input-capable audio devices for reference.

    Lists every device reported by ``sounddevice`` that has at least one
    input channel, flagging the current default input. Any failure while
    querying is reported on the console instead of being raised, because
    this listing is purely informational.
    """
    # Local import: device names and error text are arbitrary strings and
    # must not be parsed as Rich markup (e.g. a name containing "[usb]").
    from rich.markup import escape

    console.print("\n[bold]Audio devices:[/bold]")
    try:
        devices = sd.query_devices()
        default_in = sd.default.device[0]
        for i, d in enumerate(devices):
            # Output-only devices are irrelevant for capture; skip them.
            if d["max_input_channels"] <= 0:
                continue
            marker = " ← default input" if i == default_in else ""
            console.print(
                f" [{i}] {escape(str(d['name']))} (in:{d['max_input_channels']}){marker}"
            )
    except Exception as exc:
        # Deliberate best-effort: a broken audio backend should not stop the
        # program before it even tries to open a capture.
        console.print(f" [red]Could not query devices: {escape(str(exc))}[/red]")
def _load_models(ensemble: EnsembleClassifier) -> None:
    """Load all models of *ensemble* behind a spinner and report the outcome.

    Calls ``ensemble.load_all()`` while an indeterminate Rich progress
    spinner is shown, then prints one line per entry of the returned
    mapping: a red cross with the error when the value is truthy, a green
    tick otherwise.
    """
    # Local import: model names and error messages are free-form text and
    # must not be interpreted as Rich markup (e.g. "shape [batch, 16000]").
    from rich.markup import escape

    console.print()
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        # total=None gives an indeterminate spinner rather than a bar.
        task = progress.add_task("Loading models…", total=None)
        results = ensemble.load_all()
        progress.update(task, description="Models loaded.")

    for name, error in results.items():
        if error:
            console.print(f" [red]✗ {escape(str(name))}: {escape(str(error))}[/red]")
        else:
            console.print(f" [green]✓ {escape(str(name))}[/green]")
    console.print()
def cli(
    audio_file: str | None,
    model_names: str | None,
    no_yamnet_gate: bool,
    save_log: str | None,
    sensitivity: float | None,
) -> None:
    """🍼 TotTalk Cry Eval — real-time multi-model baby cry classifier."""
    # NOTE(review): the docstring above is intentionally left unchanged. The
    # file imports click, and if this function is registered as a click
    # command (decorators are not visible here) the docstring is the --help
    # text — TODO confirm.
    #
    # Parameters:
    #   audio_file     -- path of an audio file to analyse; the microphone is
    #                     used when this is falsy.
    #   model_names    -- comma-separated model names; all models when None/empty.
    #   no_yamnet_gate -- when True the ensemble is built with
    #                     use_yamnet_gate=False.
    #   save_log       -- path of a JSONL file to append one record per
    #                     non-silent window to; no logging when falsy.
    #   sensitivity    -- replacement for audio.preprocess.SILENCE_RMS_THRESHOLD.
    #
    # Runs until interrupted with Ctrl+C.

    # Local import: the source label may contain a user-supplied file name,
    # which must not be parsed as Rich markup.
    from rich.markup import escape

    console.print("[bold cyan]🍼 TotTalk Cry Eval[/bold cyan]")

    # Override silence threshold if requested.
    # NOTE(review): this only takes effect if is_silent() reads the module
    # global at call time — confirm in audio.preprocess.
    if sensitivity is not None:
        import audio.preprocess as _ap

        _ap.SILENCE_RMS_THRESHOLD = sensitivity
        console.print(f"[dim]Silence threshold set to {sensitivity}[/dim]")

    ensemble = EnsembleClassifier(
        model_names=_parse_model_names(model_names),
        use_yamnet_gate=not no_yamnet_gate,
    )

    # Device listing is only useful when capturing from a microphone.
    if audio_file is None:
        _print_audio_devices()

    _load_models(ensemble)

    if audio_file:
        source_label = f"file: {Path(audio_file).name}"
        capture = FileCapture(audio_file)
    else:
        source_label = "mic"
        capture = MicCapture()

    display = CryDisplay()

    # Opened last, directly before the try block, so that nothing can raise
    # between acquiring the handle and the finally clause that closes it.
    log_fh = open(save_log, "a", encoding="utf-8") if save_log else None  # noqa: SIM115

    try:
        capture.start()
        display.start()
        console.print(
            f"[dim]Listening ({escape(source_label)})… Press Ctrl+C to stop.[/dim]\n"
        )

        while True:
            try:
                window: np.ndarray = capture.window_queue.get(timeout=3.0)
            except queue.Empty:
                # NOTE(review): for file input this keeps polling forever once
                # the file is exhausted, unless FileCapture signals the end
                # some other way — TODO confirm.
                continue

            # RMS is measured on the raw window, before normalization.
            rms = compute_rms(window)
            if is_silent(window):
                display.update([], rms, source_label=source_label, is_silent=True)
                continue

            # Peak-normalize so quiet phone playback reaches model-friendly levels
            window = normalize_audio(window)
            predictions = ensemble.predict_all(window, SAMPLE_RATE)
            display.update(predictions, rms, source_label=source_label)

            # Optional JSONL log: one record per classified window, flushed
            # immediately so the file is usable while the program runs.
            if log_fh is not None:
                record = {
                    # NOTE(review): reads a private attribute of CryDisplay;
                    # a public accessor would be preferable if one exists.
                    "window": display._window_count,
                    "rms": _json_safe(rms),
                    "predictions": [
                        {
                            "model": p.model_name,
                            "label": p.label,
                            "confidence": _json_safe(p.confidence),
                            "latency_ms": _json_safe(p.latency_ms),
                            "error": p.error,
                        }
                        for p in predictions
                    ],
                }
                log_fh.write(json.dumps(record) + "\n")
                log_fh.flush()
    except KeyboardInterrupt:
        console.print("\n[yellow]Stopped.[/yellow]")
    finally:
        # Nested so that a failure in one teardown step cannot skip the rest.
        try:
            capture.stop()
        finally:
            try:
                display.stop()
            finally:
                if log_fh is not None:
                    log_fh.close()


def _parse_model_names(model_names: str | None) -> list[str] | None:
    """Split a comma-separated model list, dropping blanks; None means "all models"."""
    if not model_names:
        return None
    cleaned = [part.strip() for part in model_names.split(",")]
    return [part for part in cleaned if part] or None


def _json_safe(value):
    """Return NumPy scalars as builtin Python scalars so json.dumps accepts them."""
    return value.item() if isinstance(value, np.generic) else value
# Run the CLI only when this file is executed as a script, not on import.
# NOTE(review): cli() is called without arguments, so it presumably relies on
# click decorators (not visible in this view) to parse argv — TODO confirm.
if __name__ == "__main__":
    cli()