File size: 6,198 Bytes
ea2601f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Rich-based live terminal table for displaying predictions."""

from __future__ import annotations

from collections import deque

from rich.live import Live
from rich.table import Table
from rich.text import Text

from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction
from models.ensemble import compute_consensus

# ── Helpers ───────────────────────────────────────────────────────────────────

_BAR_FULL = "β–ˆ"
_BAR_EMPTY = "β–‘"
_BAR_WIDTH = 5


def confidence_bar(value: float) -> str:
    """Render a ``_BAR_WIDTH``-glyph Unicode bar for a 0.0–1.0 confidence value.

    Out-of-range input is clamped so the bar always has exactly
    ``_BAR_WIDTH`` glyphs.  Values <= 0 (and NaN) give an empty bar, and
    values >= 1 (including +inf) give a full bar.  Values in between are
    scaled and passed through ``round()``, which rounds halves to even.
    """
    # ``not value > 0.0`` is also true for NaN, which round() would reject.
    if not value > 0.0:
        return _BAR_EMPTY * _BAR_WIDTH
    # min() caps overshoot (and +inf, which round() would reject) at a full bar.
    filled = round(min(value, 1.0) * _BAR_WIDTH)
    return _BAR_FULL * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)


def format_confidence(value: float) -> str:
    """Return the confidence bar followed by a right-aligned percentage.

    The percentage is rounded to the nearest integer rather than truncated.
    For example, ``int(0.29 * 100)`` is 28 because
    ``0.29 * 100 == 28.999999999999996``, whereas the ``%`` format prints
    ``29%``.  The percentage field is 4 characters wide (``"  5%"``,
    ``" 42%"``, ``"100%"``), the same layout as before.
    """
    # ``.0%`` multiplies by 100 and formats with zero decimals (i.e. rounds).
    return f"{confidence_bar(value)} {value:>4.0%}"


# ── Display state ─────────────────────────────────────────────────────────────

class CryDisplay:
    """Manages a ``rich.live.Live`` context showing model predictions.

    Typical use: :meth:`start` once, :meth:`update` once per audio window,
    then :meth:`stop`.  Each refresh renders:

    * a header line with the RMS level, window count and YAMNet status;
    * one row per model prediction;
    * a consensus row;
    * the recent-window history;
    * a legend for the current reason label.
    """

    # Name of the cry/no-cry detector model.  It is summarised in the header
    # line and is never used as the legend's reason label.
    _DETECTOR_NAME = "YAMNet-detector"
    # Labels that do not name a reason for crying, so they get no legend.
    _NON_REASON_LABELS = frozenset({"no_cry", "timeout", "error"})

    def __init__(self, max_history: int = 5) -> None:
        """Create an idle display.

        Args:
            max_history: Number of recent windows listed under
                "Last detections"; older entries are dropped.
        """
        self._window_count = 0
        self._rms = 0.0
        self._yamnet_status = ""
        self._source_label = "mic"
        self._predictions: list[CryPrediction] = []
        # Most recent window first; the deque drops the oldest entry when full.
        self._history: deque[str] = deque(maxlen=max_history)
        self._live: Live | None = None

    # ── Public API ────────────────────────────────────────────────────────

    def start(self) -> Live:
        """Start live rendering and return the underlying ``Live`` object.

        If the display is already started, the running ``Live`` is returned
        unchanged.  Previously, a second call replaced ``self._live``, which
        orphaned the first display so that :meth:`stop` could no longer
        stop it.
        """
        if self._live is None:
            live = Live(self._build_table(), refresh_per_second=4)
            live.start()
            # Only remember it once start() has succeeded.
            self._live = live
        return self._live

    def stop(self) -> None:
        """Stop live rendering; does nothing if not started."""
        if self._live is not None:
            self._live.stop()
            self._live = None

    def update(
        self,
        predictions: list[CryPrediction],
        rms: float,
        source_label: str = "mic",
        is_silent: bool = False,
    ) -> None:
        """Record one audio window's results and refresh the live view.

        Args:
            predictions: Per-model predictions for this window.
            rms: RMS level of the window, shown in the header.
            source_label: Input source name shown in the title.
            is_silent: If true, the window is logged as ``[silence]`` in the
                history instead of its consensus label.
        """
        self._window_count += 1
        self._rms = rms
        self._source_label = source_label
        self._predictions = predictions

        # YAMNet status line: taken from the first detector prediction, if any.
        detector = next(
            (p for p in predictions if p.model_name == self._DETECTOR_NAME), None
        )
        if detector is None:
            self._yamnet_status = "YAMNet: n/a"
        else:
            icon = "✅" if detector.label == "cry" else "❌"
            self._yamnet_status = (
                f"YAMNet: {icon} {detector.label.upper()} ({detector.confidence:.2f})"
            )

        # History entry for this window (newest first).
        if is_silent:
            self._history.appendleft(f"#{self._window_count}  [silence]")
        else:
            consensus = compute_consensus(predictions)
            tag = consensus if consensus else "—"
            self._history.appendleft(f"#{self._window_count}  {tag}")

        if self._live is not None:
            self._live.update(self._build_table())

    # ── Table builder ─────────────────────────────────────────────────────

    def _build_table(self) -> Table:
        """Build the complete renderable for the current display state."""
        outer = Table(
            title=(
                f"🍼 TotTalk Cry Eval — listening ({self._source_label})"
                "  (1s windows, 16 kHz)"
            ),
            title_style="bold cyan",
            show_header=False,
            show_edge=True,
            pad_edge=True,
            expand=True,
        )
        outer.add_column(ratio=1)

        # Header row
        header = (
            f"  RMS: {self._rms:.4f}  |  Window #{self._window_count}  "
            f"|  {self._yamnet_status}"
        )
        outer.add_row(Text(header, style="dim"))

        # Predictions table
        pred_table = Table(show_edge=False, expand=True, padding=(0, 1))
        pred_table.add_column("Model", style="bold", min_width=18)
        pred_table.add_column("Label", min_width=14)
        pred_table.add_column("Confidence", min_width=12)
        pred_table.add_column("Latency", justify="right", min_width=10)

        for p in self._predictions:
            if p.error:
                # Errored models show a truncated error in place of a label.
                pred_table.add_row(
                    p.model_name,
                    Text(f"⚠️ {p.error[:30]}", style="red"),
                    "",
                    "",
                )
            else:
                pred_table.add_row(
                    p.model_name,
                    p.display_label,
                    format_confidence(p.confidence),
                    f"{p.latency_ms:.1f} ms",
                )

        # Consensus row (omitted when there is no consensus)
        consensus = compute_consensus(self._predictions)
        if consensus:
            pred_table.add_row(
                Text("CONSENSUS", style="bold magenta"),
                Text(consensus, style="bold"),
                "",
                "",
            )

        outer.add_row(pred_table)

        # History
        if self._history:
            hist_str = "  ".join(self._history)
            outer.add_row(Text(f"  Last detections: {hist_str}", style="dim"))

        # Cry meaning legend: explains the first usable reason label from a
        # non-detector model (see _current_reason_label), if it has a meaning.
        shown_label = self._current_reason_label()
        if shown_label and shown_label in LABEL_MEANING:
            emoji = LABEL_EMOJI.get(shown_label, "")
            outer.add_row(
                Text(
                    f"  {emoji} {shown_label.replace('_', ' ').title()}: "
                    f"{LABEL_MEANING[shown_label]}",
                    style="italic yellow",
                )
            )

        return outer

    def _current_reason_label(self) -> str | None:
        """Return the label of the first usable reason prediction, if any.

        Skips the detector model, predictions that carry an error, and
        labels in ``_NON_REASON_LABELS``.  Returns ``None`` when nothing
        qualifies.
        """
        for p in self._predictions:
            if p.model_name == self._DETECTOR_NAME:
                continue
            if p.error or p.label in self._NON_REASON_LABELS:
                continue
            return p.label
        return None