# tot-talk / display/table.py
# Author: grungecoder
# Commit ea2601f: "Initial commit: real-time multi-model baby cry classifier"
"""Rich-based live terminal table for displaying predictions."""
from __future__ import annotations
from collections import deque
from rich.live import Live
from rich.table import Table
from rich.text import Text
from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import compute_consensus
# ── Helpers ───────────────────────────────────────────────────────────────────

_BAR_FULL = "█"
_BAR_EMPTY = "░"
_BAR_WIDTH = 5


def _clamp_unit(value: float) -> float:
    """Clamp *value* into [0.0, 1.0], mapping NaN to 0.0."""
    # ``not value >= 0.0`` is also true for NaN, which fails every comparison.
    if not value >= 0.0:
        return 0.0
    return min(value, 1.0)


def confidence_bar(value: float) -> str:
    """Render a ``_BAR_WIDTH``-cell Unicode bar for a 0.0–1.0 confidence value.

    Out-of-range values are clamped (NaN counts as 0.0), so the result is
    always exactly ``_BAR_WIDTH`` characters wide and a misbehaving model
    cannot widen or break the table column.
    """
    filled = round(_clamp_unit(value) * _BAR_WIDTH)
    return _BAR_FULL * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)


def format_confidence(value: float) -> str:
    """Return the confidence bar followed by a right-aligned percentage.

    The percentage uses the same clamping as the bar and is rounded rather
    than truncated: ``0.57 * 100`` is ``56.99999999999999`` in binary
    floating point, which truncation would show as 56%.
    """
    pct = round(_clamp_unit(value) * 100)
    return f"{confidence_bar(value)} {pct:>3}%"
# ── Display state ─────────────────────────────────────────────────────────────
class CryDisplay:
    """Manage a ``rich.live.Live`` view of per-model cry predictions.

    Call ``start()`` once, ``update(...)`` once per audio window, and
    ``stop()`` when done. The instance can also be used as a context manager
    (``with CryDisplay() as display: ...``), which guarantees ``stop()`` runs
    even if the processing loop raises.
    """

    # Model name of the cry/no-cry detector; its prediction drives the
    # status line and is excluded from the reason legend.
    _DETECTOR_MODEL = "YAMNet-detector"
    # Labels that never get a legend entry.
    _NON_REASON_LABELS = frozenset({"no_cry", "timeout", "error"})

    def __init__(self, max_history: int = 5) -> None:
        """Create an idle (not yet rendering) display.

        Args:
            max_history: Number of recent per-window results kept for the
                "Last detections" row; older entries are discarded.
        """
        self._window_count = 0
        self._rms = 0.0
        self._yamnet_status = ""
        self._source_label = "mic"
        self._predictions: list[CryPrediction] = []
        self._history: deque[str] = deque(maxlen=max_history)
        self._live: Live | None = None

    # ── Public API ────────────────────────────────────────────────────────

    def start(self) -> Live:
        """Start rendering and return the underlying ``Live`` object.

        If the display is already running, the existing ``Live`` is returned
        instead of starting (and leaking) a second one.
        """
        if self._live is not None:
            return self._live
        live = Live(self._build_table(), refresh_per_second=4)
        live.start()
        # Only keep the Live once it has started, so a failed start() leaves
        # the display in a clean state that can be started again.
        self._live = live
        return live

    def stop(self) -> None:
        """Stop rendering; safe to call when not started or more than once."""
        # Detach first so the display is reset even if Live.stop() raises.
        live, self._live = self._live, None
        if live is not None:
            live.stop()

    def __enter__(self) -> CryDisplay:
        """Start rendering and return this display."""
        self.start()
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Stop rendering; exceptions from the ``with`` body propagate."""
        self.stop()

    def update(
        self,
        predictions: list[CryPrediction],
        rms: float,
        source_label: str = "mic",
        is_silent: bool = False,
    ) -> None:
        """Record the results for one audio window and refresh the view.

        Args:
            predictions: Predictions for this window, in display order. A
                copy is stored, so the caller may reuse its list.
            rms: Signal level shown in the header row.
            source_label: Input-source name shown in the title.
            is_silent: When true, the history records ``[silence]`` for this
                window instead of the consensus label.
        """
        self._window_count += 1
        self._rms = rms
        self._source_label = source_label
        self._predictions = list(predictions)

        # Status line from the detector model's prediction, if present.
        detector = next(
            (p for p in predictions if p.model_name == self._DETECTOR_MODEL),
            None,
        )
        if detector is None:
            self._yamnet_status = "YAMNet: n/a"
        else:
            icon = "✅" if detector.label == "cry" else "❌"
            self._yamnet_status = (
                f"YAMNet: {icon} {detector.label.upper()} "
                f"({detector.confidence:.2f})"
            )

        # History: newest entry first; the deque drops the oldest.
        if is_silent:
            tag = "[silence]"
        else:
            tag = compute_consensus(predictions) or "—"
        self._history.appendleft(f"#{self._window_count} {tag}")

        if self._live is not None:
            self._live.update(self._build_table())

    # ── Table builder ─────────────────────────────────────────────────────

    def _build_table(self) -> Table:
        """Compose the full view: header, per-model table, history, legend."""
        outer = Table(
            title=(
                f"🍼 TotTalk Cry Eval — listening ({self._source_label}) "
                "(1s windows, 16 kHz)"
            ),
            title_style="bold cyan",
            show_header=False,
            show_edge=True,
            pad_edge=True,
            expand=True,
        )
        outer.add_column(ratio=1)

        header = (
            f" RMS: {self._rms:.4f} | Window #{self._window_count} "
            f"| {self._yamnet_status}"
        )
        outer.add_row(Text(header, style="dim"))

        outer.add_row(self._build_predictions_table())

        if self._history:
            hist_str = " ".join(self._history)
            outer.add_row(Text(f" Last detections: {hist_str}", style="dim"))

        # Legend: meaning of the label chosen by _current_reason_label(),
        # which is not necessarily the consensus label.
        reason = self._current_reason_label()
        if reason and reason in LABEL_MEANING:
            emoji = LABEL_EMOJI.get(reason, "")
            outer.add_row(
                Text(
                    f" {emoji} {reason.replace('_', ' ').title()}: "
                    f"{LABEL_MEANING[reason]}",
                    style="italic yellow",
                )
            )
        return outer

    def _build_predictions_table(self) -> Table:
        """Build one row per prediction plus an optional CONSENSUS row."""
        table = Table(show_edge=False, expand=True, padding=(0, 1))
        table.add_column("Model", style="bold", min_width=18)
        table.add_column("Label", min_width=14)
        table.add_column("Confidence", min_width=12)
        table.add_column("Latency", justify="right", min_width=10)

        for p in self._predictions:
            if p.error:
                # Failed models keep a row (error truncated to 30 chars) so
                # they stay visible instead of silently disappearing.
                table.add_row(
                    p.model_name,
                    Text(f"⚠️ {p.error[:30]}", style="red"),
                    "",
                    "",
                )
            else:
                table.add_row(
                    p.model_name,
                    p.display_label,
                    format_confidence(p.confidence),
                    f"{p.latency_ms:.1f} ms",
                )

        consensus = compute_consensus(self._predictions)
        if consensus:
            table.add_row(
                Text("CONSENSUS", style="bold magenta"),
                Text(consensus, style="bold"),
                "",
                "",
            )
        return table

    def _current_reason_label(self) -> str | None:
        """Return the label whose meaning the legend row should explain.

        This is the label of the first prediction, in stored order, that is
        not from the detector model, has no error, and is not one of
        ``_NON_REASON_LABELS``. Returns ``None`` when no prediction qualifies.
        """
        for p in self._predictions:
            if p.model_name == self._DETECTOR_MODEL:
                continue
            if p.error or p.label in self._NON_REASON_LABELS:
                continue
            return p.label
        return None