File size: 6,990 Bytes
a96f2f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Author: Sathvik Udupa (2026)
# Email:  udupa@fit.vutbr.cz
# Paper:  Streaming Endpointer for Spoken Dialogue using Neural Audio Codecs and Label-Delayed Training, https://arxiv.org/abs/2506.07081, ASRU 2025

"""Mimi Endpointer — DiscriminativeModel for the TURN benchmark.

Two-stream LSTM over Mimi embeddings, fed one 20 ms harness step at a time.
Mimi operates in 1920-sample (80ms) chunks → 2 LSTM frames per chunk.
Four harness steps are buffered before each Mimi run; floor bit is held
between updates. Inherent latency: ~80ms.

  floor = 1  if  P(user) > threshold  else  0

subject is always fed as channel 0 (user); other as channel 1 (system).

Debug mode (MIMI_DEBUG=1): saves debug_pass{N}.npz per conversation pass;
run plot_debug.py afterwards to render PNGs.

Sweep mode (MIMI_SWEEP=1): evaluates thresholds 0.05–0.95 (step 0.05) in a
single inference pass; each step returns a list[int] of floor bits, one per
threshold. Otherwise the single threshold comes from MIMI_THRESHOLD
(default 0.1).
"""
from __future__ import annotations

import atexit
import os
import sys
from pathlib import Path

import numpy as np
import torch

_HERE = Path(__file__).resolve().parent

# Make `model` importable from either layout: the HF flat layout keeps model.py
# next to this file, the local nested layout (turn-bench-submission/ inside
# baselines/mimi_endpointer/) keeps it one directory up. Both directories are
# prepended with the parent first, so the parent wins if both hold model.py.
sys.path[:0] = [str(_HERE.parent), str(_HERE)]

from model import (  # noqa: E402
    AudioFeatureExtractor,
    AUDIO_DEFAULTS,
    load_model as load_mimi_model,
)

IDX_USER = 4  # from training config: {bos:0, system_end:1, user_end:2, system:3, user:4}

# checkpoint.pt is alongside predict.py (HF) or one level up (local).
_CHECKPOINT_CANDIDATES = (_HERE / "checkpoint.pt", _HERE.parent / "checkpoint.pt")
# Use an explicit default so a missing checkpoint raises a descriptive
# FileNotFoundError instead of a bare StopIteration at import time.
CHECKPOINT = next((p for p in _CHECKPOINT_CANDIDATES if p.exists()), None)
if CHECKPOINT is None:
    raise FileNotFoundError(
        "checkpoint.pt not found; looked in: "
        + ", ".join(str(p) for p in _CHECKPOINT_CANDIDATES)
    )

_CHUNK_STEPS = 4   # 4 × 20ms = 80ms = one Mimi frame_size (1920 samples at 24kHz)
_SR = 24_000
_FRAME_RATE = 50   # harness step rate (Hz)


class MimiEndpointerModel:
    """Streaming endpointer: two-stream LSTM over Mimi codec embeddings.

    ``step`` is called once per 20 ms harness step with the subject (user)
    and other (system) audio. Steps are buffered until ``_CHUNK_STEPS`` have
    arrived. The chunk is then encoded by Mimi for both channels at once and
    the LSTM is advanced. The floor decision is held between updates.

    Single mode sets ``self.threshold`` and ``step`` returns an ``int``.
    Sweep mode sets ``self.thresholds`` (the harness detects it with
    ``hasattr``) and ``step`` returns a ``list[int]``, one bit per threshold.
    """

    input_sample_rate = _SR  # Mimi native rate; 24000 % 50 == 0 → 480 samples/step

    def __init__(
        self,
        threshold: float = 0.5,
        thresholds: list[float] | None = None,
        debug: bool = False,
    ) -> None:
        """Load the model and Mimi encoder, then open a fresh stream.

        Args:
            threshold: P(user) threshold for single mode.
            thresholds: when given, enables sweep mode with these thresholds;
                ``threshold`` is then ignored.
            debug: log per-step audio, floor bits and class probabilities and
                save them as ``debug_pass{N}.npz`` on each reset and at exit.

        Raises:
            ValueError: if ``thresholds`` is given but empty (debug logging
                and ``_save_npz`` read the first threshold).
        """
        # Assigned first so __del__ is safe even if loading below raises.
        self._ctx = None
        if thresholds is not None:
            if len(thresholds) == 0:
                raise ValueError("thresholds must contain at least one value")
            self.thresholds = thresholds      # harness detects sweep mode via hasattr
            self._thresholds_arr = thresholds
        else:
            self.threshold = threshold        # single operating point
            self._thresholds_arr = [threshold]
        self._sweep = thresholds is not None
        self.debug = debug
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = device
        self._model = load_mimi_model(str(CHECKPOINT), device=device)
        self._extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)
        self._debug_idx = 0
        self._log_subj: list[np.ndarray] = []
        self._log_other: list[np.ndarray] = []
        self._log_floor: list[int] = []
        self._log_probs: list[np.ndarray] = []  # all 5 class probs per step (T, 5)
        if debug:
            # Flush the last conversation (which is never followed by reset()).
            atexit.register(self._save_npz)
        self.reset()

    def reset(self) -> None:
        """Begin a new conversation.

        In debug mode, the previous conversation's logs are saved first. The
        Mimi streaming context (batch of 2: subject + other), both recurrent
        states, the step buffers and the held floor bits are reinitialised.
        """
        if self.debug:
            self._save_npz()
        self._close_stream()
        ctx = self._extractor.mimi.streaming(batch_size=2)
        ctx.__enter__()
        # Store only after a successful __enter__ so we never __exit__ a
        # context that was not entered.
        self._ctx = ctx
        self._h1, self._c1 = self._model.init_hidden(1, self._device)
        self._h2, self._c2 = self._model.init_hidden(1, self._device)
        self._buf_subj: list[np.ndarray] = []
        self._buf_other: list[np.ndarray] = []
        self._floor_bits: list[int] = [0] * len(self._thresholds_arr)
        self._log_subj = []
        self._log_other = []
        self._log_floor = []
        self._log_probs = []

    def _close_stream(self) -> None:
        """Exit the active Mimi streaming context, if any; safe to call twice."""
        # getattr: __del__ may run on an instance whose __init__ never ran.
        ctx = getattr(self, "_ctx", None)
        self._ctx = None
        if ctx is not None:
            ctx.__exit__(None, None, None)

    def __del__(self) -> None:
        self._close_stream()

    def _save_npz(self) -> None:
        """Write the current conversation's debug logs to debug_pass{N}.npz.

        Does nothing when no steps were logged. Only the first threshold is
        recorded, matching the floor bits kept in ``_log_floor``.
        """
        if not self._log_floor:
            return
        out = _HERE / f"debug_pass{self._debug_idx}.npz"
        np.savez(
            out,
            subj=np.concatenate(self._log_subj),
            other=np.concatenate(self._log_other),
            floor=np.array(self._log_floor, dtype=np.int8),
            probs=np.array(self._log_probs, dtype=np.float32),  # (T, 5)
            threshold=np.float32(self._thresholds_arr[0]),
            sr=np.int32(_SR),
            frame_rate=np.int32(_FRAME_RATE),
        )
        sys.stderr.write(f"[debug] saved → {out}\n")
        self._debug_idx += 1

    def _infer_chunk(self) -> np.ndarray:
        """Encode the buffered chunk with Mimi, advance the LSTM, return probs (5,).

        Consumes (and clears) the step buffers. Raises RuntimeError if Mimi
        yields no frames for the chunk.
        """
        chunk_s = torch.from_numpy(np.concatenate(self._buf_subj)).to(self._device)
        chunk_o = torch.from_numpy(np.concatenate(self._buf_other)).to(self._device)
        self._buf_subj.clear()
        self._buf_other.clear()

        # (2, 1, 1920) — subject=channel 0 (user), other=channel 1 (system)
        chunk = torch.stack([chunk_s, chunk_o]).unsqueeze(1)
        mimi = self._extractor.mimi
        with torch.no_grad():
            emb = mimi.encode_to_latent(chunk, quantize=True)  # (2, feat, 1)
            emb = mimi.upsample(emb)                            # (2, feat, 2)
            logits = None
            for frame in range(emb.shape[-1]):
                logits, self._h1, self._c1, self._h2, self._c2 = \
                    self._model.infer_ar_step(
                        emb[0:1, :, frame], emb[1:2, :, frame],
                        self._h1, self._c1, self._h2, self._c2,
                    )
        if logits is None:
            raise RuntimeError("Mimi produced no frames for a full input chunk")
        return torch.softmax(logits[0], dim=-1).cpu().numpy()  # (5,)

    def step(self, subject_audio: np.ndarray, other_audio: np.ndarray):
        """Consume one harness step and return the current floor decision.

        Args:
            subject_audio: one step of the user channel (presumably 480
                samples at 24 kHz — confirm against the harness).
            other_audio: the matching step of the system channel.

        Returns:
            ``int`` floor bit in single mode, or a fresh ``list[int]`` (one bit
            per threshold) in sweep mode. The decision only changes every
            ``_CHUNK_STEPS`` calls; in between, the last one is held.
        """
        # Copy: these arrays are kept across several calls before the chunk is
        # processed, and the caller may reuse its step buffers in place.
        subject_audio = np.array(subject_audio, copy=True)
        other_audio = np.array(other_audio, copy=True)
        self._buf_subj.append(subject_audio)
        self._buf_other.append(other_audio)

        if self.debug:
            self._log_subj.append(subject_audio)
            self._log_other.append(other_audio)

        new_probs: np.ndarray | None = None

        if len(self._buf_subj) == _CHUNK_STEPS:
            new_probs = self._infer_chunk()
            p_user = new_probs[IDX_USER]
            self._floor_bits = [1 if p_user > thr else 0 for thr in self._thresholds_arr]

        if self.debug:
            # Hold the last probabilities between updates (zeros before the first).
            if new_probs is not None:
                p = new_probs
            elif self._log_probs:
                p = self._log_probs[-1]
            else:
                p = np.zeros(5, dtype=np.float32)
            self._log_probs.append(p)
            self._log_floor.append(self._floor_bits[0])  # first threshold for debug plot

        # Return a copy in sweep mode so callers never share our internal list.
        return list(self._floor_bits) if self._sweep else self._floor_bits[0]


def _sweep_thresholds() -> list[float]:
    """Return the sweep grid 0.05, 0.10, …, 0.95 (19 values, step 0.05).

    Built from integers (k / 20) rather than a float-step ``np.arange``, so
    the number of points cannot drift with floating-point rounding. Each value
    is the double closest to its two-decimal literal.
    """
    return [k / 20 for k in range(1, 20)]


def load_model() -> MimiEndpointerModel:
    """Build the endpointer from environment variables.

    MIMI_DEBUG=1 enables debug logging. MIMI_SWEEP=1 enables sweep mode over
    ``_sweep_thresholds()``. Otherwise MIMI_THRESHOLD (default 0.1) sets the
    single operating point; a non-numeric value raises ValueError.
    """
    debug = os.environ.get("MIMI_DEBUG", "0") == "1"
    if os.environ.get("MIMI_SWEEP", "0") == "1":
        return MimiEndpointerModel(thresholds=_sweep_thresholds(), debug=debug)
    thr = float(os.environ.get("MIMI_THRESHOLD", "0.1"))
    return MimiEndpointerModel(threshold=thr, debug=debug)