# Author: Sathvik Udupa (2026)
# Email: udupa@fit.vutbr.cz
# Paper: Streaming Endpointer for Spoken Dialogue using Neural Audio Codecs and Label-Delayed Training, https://arxiv.org/abs/2506.07081, ASRU 2025
"""Mimi Endpointer — DiscriminativeModel for the TURN benchmark.
Two-stream LSTM over Mimi embeddings, streamed one 20ms harness step at a time.
Mimi operates in 1920-sample (80ms) chunks → 2 LSTM frames per chunk.
Four harness steps are buffered before each Mimi run; floor bit is held
between updates. Inherent latency: ~80ms.
floor = 1 if P(user) > threshold else 0 (threshold from MIMI_THRESHOLD, default 0.1)
subject is always fed as channel 0 (user); other as channel 1 (system).
Debug mode (MIMI_DEBUG=1): saves debug_pass{N}.npz per conversation pass;
run plot_debug.py afterwards to render PNGs.
Sweep mode (MIMI_SWEEP=1): evaluates thresholds 0.05–0.95 (step 0.05) in a
single inference pass; each step returns a list[int] of floor bits, one per threshold.
"""
from __future__ import annotations
import atexit
import os
import sys
from pathlib import Path
import numpy as np
import torch
_HERE = Path(__file__).resolve().parent
# model.py may sit next to this file (HF flat layout) or one directory up (local
# nested layout: turn-bench-submission/ inside baselines/mimi_endpointer/).
# Prepend both search dirs; the parent ends up first, exactly as two successive
# sys.path.insert(0, ...) calls (here, then parent) would leave it.
sys.path[:0] = [str(_HERE.parent), str(_HERE)]
from model import ( # noqa: E402
AudioFeatureExtractor,
AUDIO_DEFAULTS,
load_model as load_mimi_model,
)
IDX_USER = 4  # from training config: {bos:0, system_end:1, user_end:2, system:3, user:4}
# checkpoint.pt is alongside predict.py (HF flat layout) or one level up (local
# nested layout); the first existing candidate wins.
_CHECKPOINT_CANDIDATES = (_HERE / "checkpoint.pt", _HERE.parent / "checkpoint.pt")
CHECKPOINT = next((p for p in _CHECKPOINT_CANDIDATES if p.exists()), None)
if CHECKPOINT is None:
    # A bare next() would surface as an unexplained StopIteration at import time.
    raise FileNotFoundError(
        "checkpoint.pt not found; looked in: "
        + ", ".join(str(p) for p in _CHECKPOINT_CANDIDATES)
    )
_CHUNK_STEPS = 4  # 4 × 20ms = 80ms = one Mimi frame_size (1920 samples at 24kHz)
_SR = 24_000
_FRAME_RATE = 50  # harness step rate (Hz)
class MimiEndpointerModel:
    """Streaming floor predictor: Mimi codec embeddings fed to a two-stream LSTM.

    ``step`` is called once per 20ms harness step with one chunk of audio per
    speaker. Chunks are buffered until ``_CHUNK_STEPS`` of them are available,
    then encoded by Mimi as a batch of two streams (subject = channel 0 / user,
    other = channel 1 / system) and passed through the LSTM frame by frame. The
    floor decision from the last frame is held until the next Mimi run.

    Args:
        threshold: single operating point; floor is 1 when P(user) > threshold.
        thresholds: if given, enables sweep mode: ``step`` returns one floor bit
            per threshold (as a list) and ``threshold`` is ignored.
        debug: log inputs, floor bits (first threshold only) and class
            probabilities, and save them to ``debug_pass{N}.npz`` next to this
            file on each ``reset`` and at interpreter exit.
    """

    input_sample_rate = _SR  # Mimi native rate; 24000 % 50 == 0 → 480 samples/step

    def __init__(
        self,
        threshold: float = 0.5,
        thresholds: list[float] | None = None,
        debug: bool = False,
    ) -> None:
        # Assigned first so __del__ stays safe if anything below raises
        # (e.g. checkpoint or extractor loading).
        self._ctx = None
        # sweep mode: thresholds is a list; single mode: scalar threshold
        if thresholds is not None:
            # Keep our own copy so later mutation of the caller's list cannot
            # shift the operating points mid-run.
            self.thresholds = list(thresholds)  # harness detects sweep mode via hasattr
            self._thresholds_arr = self.thresholds
        else:
            self.threshold = threshold  # single operating point
            self._thresholds_arr = [threshold]
        self._sweep = thresholds is not None
        self.debug = debug
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = device
        self._model = load_mimi_model(str(CHECKPOINT), device=device)
        self._extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)
        self._debug_idx = 0
        self._log_subj: list[np.ndarray] = []
        self._log_other: list[np.ndarray] = []
        self._log_floor: list[int] = []
        self._log_probs: list[np.ndarray] = []  # all 5 class probs per step (T, 5)
        if debug:
            # Flush the final pass, which is never followed by another reset().
            atexit.register(self._save_npz)
        self.reset()

    def reset(self) -> None:
        """Start a new conversation pass.

        In debug mode, first saves the previous pass's logs. Then reopens the
        Mimi streaming context for a batch of two streams, re-initialises both
        LSTM states, and clears the step buffers, held floor bits and logs.
        """
        if self.debug:
            self._save_npz()
        if self._ctx is not None:
            ctx, self._ctx = self._ctx, None
            # Drop the reference before exiting so a failure while reopening
            # below never leaves an already-exited context to be exited again.
            ctx.__exit__(None, None, None)
        self._ctx = self._extractor.mimi.streaming(batch_size=2)
        self._ctx.__enter__()
        self._h1, self._c1 = self._model.init_hidden(1, self._device)
        self._h2, self._c2 = self._model.init_hidden(1, self._device)
        self._buf_subj: list[np.ndarray] = []
        self._buf_other: list[np.ndarray] = []
        self._floor_bits: list[int] = [0] * len(self._thresholds_arr)
        self._log_subj = []
        self._log_other = []
        self._log_floor = []
        self._log_probs = []

    def __del__(self) -> None:
        # getattr: the instance may be collected after a failed __init__ or
        # without __init__ having run at all.
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            self._ctx = None
            ctx.__exit__(None, None, None)

    def _save_npz(self) -> None:
        """Write this pass's debug logs to ``debug_pass{N}.npz`` beside this file.

        No-op when nothing has been logged since the last reset. Increments the
        pass counter after each successful save.
        """
        if not self._log_floor:
            return
        out = _HERE / f"debug_pass{self._debug_idx}.npz"
        np.savez(
            out,
            subj=np.concatenate(self._log_subj),
            other=np.concatenate(self._log_other),
            floor=np.array(self._log_floor, dtype=np.int8),
            probs=np.array(self._log_probs, dtype=np.float32),  # (T, 5)
            threshold=np.float32(self._thresholds_arr[0]),
            sr=np.int32(_SR),
            frame_rate=np.int32(_FRAME_RATE),
        )
        sys.stderr.write(f"[debug] saved → {out}\n")
        self._debug_idx += 1

    def step(self, subject_audio: np.ndarray, other_audio: np.ndarray) -> int | list[int]:
        """Consume one harness step of audio per speaker and return the floor.

        Args:
            subject_audio: this step's samples for the subject (fed as user).
            other_audio: this step's samples for the other speaker (system).

        Returns:
            In single mode, the floor bit (0/1) for ``threshold``. In sweep mode,
            a list with one floor bit per entry of ``thresholds``. The value only
            changes every ``_CHUNK_STEPS`` calls; in between it is held.
        """
        # Copy and coerce each step. The chunks are held across several calls,
        # so they must not alias arrays the caller may reuse between steps.
        # The float32 coercion presumably matches Mimi's weight dtype.
        subject_audio = np.array(subject_audio, dtype=np.float32)
        other_audio = np.array(other_audio, dtype=np.float32)
        self._buf_subj.append(subject_audio)
        self._buf_other.append(other_audio)
        if self.debug:
            self._log_subj.append(subject_audio)
            self._log_other.append(other_audio)
        new_probs: np.ndarray | None = None
        if len(self._buf_subj) == _CHUNK_STEPS:
            chunk_s = torch.from_numpy(np.concatenate(self._buf_subj)).to(self._device)
            chunk_o = torch.from_numpy(np.concatenate(self._buf_other)).to(self._device)
            self._buf_subj.clear()
            self._buf_other.clear()
            # (2, 1, 1920) — subject=channel 0 (user), other=channel 1 (system)
            chunk = torch.stack([chunk_s, chunk_o]).unsqueeze(1)
            with torch.no_grad():
                emb = self._extractor.mimi.encode_to_latent(chunk, quantize=True)  # (2, feat, 1)
                emb = self._extractor.mimi.upsample(emb)  # (2, feat, 2)
                logits = None
                # Advance the LSTM states through every upsampled frame; only
                # the last frame's logits drive the decision.
                for t in range(emb.shape[-1]):
                    logits, self._h1, self._c1, self._h2, self._c2 = \
                        self._model.infer_ar_step(
                            emb[0:1, :, t], emb[1:2, :, t],
                            self._h1, self._c1, self._h2, self._c2,
                        )
                new_probs = torch.softmax(logits[0], dim=-1).cpu().numpy()  # (5,)
            p_user = new_probs[IDX_USER]
            self._floor_bits = [1 if p_user > thr else 0 for thr in self._thresholds_arr]
        if self.debug:
            # Between Mimi runs, repeat the last probabilities (zeros before the first run).
            p = new_probs if new_probs is not None else (
                self._log_probs[-1] if self._log_probs else np.zeros(5, dtype=np.float32)
            )
            self._log_probs.append(p)
            self._log_floor.append(self._floor_bits[0])  # first threshold for debug plot
        return self._floor_bits if self._sweep else self._floor_bits[0]
def _sweep_thresholds() -> list[float]:
    """Return the sweep-mode operating points 0.05, 0.10, ..., 0.95 (19 values)."""
    # linspace with an explicit count rather than arange with a float step: the
    # length of a float-step arange depends on rounding of (stop - start) / step,
    # whereas here both end points and the count are fixed by construction.
    return np.round(np.linspace(0.05, 0.95, 19), 2).tolist()


def load_model() -> MimiEndpointerModel:
    """Build the endpointer configured from environment variables.

    MIMI_DEBUG=1 enables per-pass debug dumps. MIMI_SWEEP=1 selects sweep
    mode over thresholds 0.05..0.95 (step 0.05), in which case MIMI_THRESHOLD
    is ignored. Otherwise a single operating point is read from
    MIMI_THRESHOLD (default 0.1).

    Raises:
        ValueError: if MIMI_THRESHOLD is not parseable as a float.
    """
    debug = os.environ.get("MIMI_DEBUG", "0") == "1"
    if os.environ.get("MIMI_SWEEP", "0") == "1":
        return MimiEndpointerModel(thresholds=_sweep_thresholds(), debug=debug)
    thr = float(os.environ.get("MIMI_THRESHOLD", "0.1"))
    return MimiEndpointerModel(threshold=thr, debug=debug)