ChristophSchuhmann commited on
Commit
bcceffa
·
verified ·
1 Parent(s): 9376900

Upload segmentation_infer_html.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. segmentation_infer_html.py +835 -0
segmentation_infer_html.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ segmentation_infer_html.py
6
+
7
+ - Loads WhisperOddEven checkpoint
8
+ /home/user/outs/segmentation_gemini_medium_no_overlap_4epochs_model_best.pt
9
+ (override via CKPT env var).
10
+
11
+ - For each audio file in AUDIO_INPUT_DIR:
12
+ * load, resample to 16 kHz mono
13
+ * split into 30 s chunks
14
+ * run segmentation
15
+ * SMOOTH each track so that no segment (incl. background 0) is shorter than
16
+ MIN_SEGMENT_SEC seconds
17
+ * extract per-track segments (odd/even) and cut audio snippets
18
+ * build a MERGED timeline that starts/ends segments whenever either track
19
+ changes label, then smooth that merged timeline so that each merged
20
+ segment is also at least MIN_SEGMENT_SEC long, merging short segments
21
+ with neighbors using the rules described below.
22
+
23
+ - Writes a single HTML report with:
24
+ * smoothed per-track heatmap
25
+ * merged-timeline heatmap
26
+ * tables of per-track segments (with audio players)
27
+ * tables of merged segments (with audio players)
28
+
29
+ Merging rule for short merged segments:
30
+ - If a merged segment is shorter than MIN_SEGMENT_SEC, merge it with one of its
31
+ immediate neighbors.
32
+ - Prefer the neighbor whose (odd_label, even_label) matches this segment best
33
+ (majority vote over the two labels).
34
+ - If similarity is equal (or one neighbor is missing), merge with the neighbor
35
+ that has the shorter duration. If still equal, merge with the left neighbor.
36
+ """
37
+
38
+ from __future__ import annotations
39
+ import os
40
+ import io
41
+ import sys
42
+ import time
43
+ import math
44
+ import base64
45
+ import shutil
46
+ from pathlib import Path
47
+ from typing import List, Dict, Any, Tuple
48
+
49
+ import numpy as np
50
+ import torch
51
+ import torch.nn as nn
52
+ import torch.nn.functional as F
53
+
54
+ # plotting
55
+ import matplotlib
56
+ matplotlib.use("Agg")
57
+ import matplotlib.pyplot as plt
58
+
59
+ # audio
60
+ import soundfile as sf
61
+ import librosa
62
+ from pydub import AudioSegment # requires ffmpeg
63
+
64
+ from transformers import WhisperFeatureExtractor, WhisperModel
65
+
66
+ # =========================
67
+ # ========== CONFIG =======
68
+ # =========================
69
+
70
# Directory scanned recursively for input audio; main() lets sys.argv[1] override it.
AUDIO_INPUT_DIR = Path(os.getenv("AUDIO_INPUT_DIR", "./infer-audio"))
# Output directory: receives the HTML report and the cache folders created by setup_dirs().
OUT_DIR = Path(os.getenv("OUT_DIR", "./outs_infer"))
# Checkpoint loaded into WhisperOddEven; override with the CKPT env var.
CKPT_PATH = Path(os.getenv("CKPT", "/home/user/outs/segmentation_gemini_medium_no_overlap_4epochs_model_best.pt"))
# Base Whisper model id used for both the feature extractor and the encoder.
# NOTE(review): the default checkpoint name mentions "medium" while this default
# is whisper-small -- confirm the two actually belong together.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "openai/whisper-small")

# USE_LOCAL_MODELS=1 makes _model_resolved_name() look for
# "<MODELS_SNAPSHOT_DIR>/<model id with '/' replaced by '__'>" before using the hub id.
USE_LOCAL_MODELS = bool(int(os.getenv("USE_LOCAL_MODELS", "0")))
MODELS_SNAPSHOT_DIR = Path(os.getenv("MODELS_SNAPSHOT_DIR", "")) if USE_LOCAL_MODELS else None
# Cache locations; setup_dirs() exports them as environment defaults.
HF_HOME = Path(os.getenv("HF_HOME", (OUT_DIR / ".hf")))
TRANSFORMERS_CACHE = Path(os.getenv("TRANSFORMERS_CACHE", (OUT_DIR / ".hf" / "hub")))

# "bf16" / "fp16" / "fp32" force a dtype in preferred_dtype(); any other value
# (default "auto") lets preferred_dtype() choose based on CUDA availability.
MIXED_PRECISION = os.getenv("MIXED_PRECISION", "auto").lower()

# constants (must match training)
SAMPLE_RATE = 16000    # Hz; audio is resampled to this rate on load
CLIP_SECONDS = 30.0    # length of each chunk fed to the model
NUM_FRAMES = 1500      # frames per chunk (see the shape comment in WhisperOddEven.forward)
NUM_TRACKS = 2         # "odd" and "even" tracks
MAX_SEGMENTS = 20      # the head emits MAX_SEGMENTS + 1 classes per track

# --- MINIMUM SEGMENT LENGTH (seconds) for both per-track and merged segments ---
MIN_SEGMENT_SEC = float(os.getenv("MIN_SEGMENT_SEC", "1.0"))
# Same threshold expressed in frames (never below 1).
MIN_SEGMENT_FRAMES = max(1, int(round(MIN_SEGMENT_SEC * NUM_FRAMES / CLIP_SECONDS)))

# True when an ffmpeg executable is on PATH (needed for MP3 export).
FFMPEG_AVAILABLE = shutil.which("ffmpeg") is not None
# Set to True by wav_chunk_to_audio_bytes() after it has printed its fallback warning once.
WARNED_NO_FFMPEG = False
95
+
96
+ # =========================
97
+ # ====== BASIC SETUP ======
98
+ # =========================
99
+
100
def setup_dirs():
    """
    Create the output/cache directories and export cache-related env defaults.

    Existing environment values win: every variable is set with ``setdefault``.
    """
    mpl_config_dir = OUT_DIR / ".mplconfig"
    for folder in (OUT_DIR, mpl_config_dir):
        folder.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("MPLCONFIGDIR", str(mpl_config_dir.resolve()))

    HF_HOME.mkdir(parents=True, exist_ok=True)
    env_defaults = (
        ("HF_HOME", str(HF_HOME.resolve())),
        ("TRANSFORMERS_CACHE", str(TRANSFORMERS_CACHE.resolve())),
        ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128"),
    )
    for name, value in env_defaults:
        os.environ.setdefault(name, value)
108
+
109
def preferred_dtype():
    """
    Pick the torch dtype used for autocast.

    An explicit MIXED_PRECISION of "bf16", "fp16" or "fp32" is honoured as-is.
    Any other value selects automatically: bfloat16 on CUDA devices that
    support it, float16 on other CUDA devices, float32 without CUDA.
    """
    forced = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    if MIXED_PRECISION in forced:
        return forced[MIXED_PRECISION]

    if not torch.cuda.is_available():
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
119
+
120
def _model_resolved_name(model_id: str) -> Tuple[str, bool]:
    """
    Resolve a model id to what should be passed to ``from_pretrained``.

    Returns ``(path_or_id, is_local)``. A local snapshot directory named after
    the id (with "/" replaced by "__") is used when local models are enabled
    and that directory exists; otherwise the id is returned unchanged.
    """
    local_lookup_enabled = (
        USE_LOCAL_MODELS and MODELS_SNAPSHOT_DIR and MODELS_SNAPSHOT_DIR.is_dir()
    )
    if not local_lookup_enabled:
        return model_id, False

    snapshot = MODELS_SNAPSHOT_DIR / model_id.replace("/", "__")
    if not snapshot.is_dir():
        return model_id, False
    return str(snapshot), True
127
+
128
+ # =========================
129
+ # ========= MODEL =========
130
+ # =========================
131
+
132
class WhisperOddEven(nn.Module):
    """
    Whisper encoder with a per-frame classification head for NUM_TRACKS tracks.

    Each frame gets MAX_SEGMENTS + 1 logits per track.
    """

    def __init__(self, base_id: str, freeze_encoder: bool = False):
        super().__init__()
        model_ref, local_only = _model_resolved_name(base_id)
        self.whisper = WhisperModel.from_pretrained(model_ref, local_files_only=local_only)

        # The decoder is never used, so it is always frozen.
        self.whisper.decoder.requires_grad_(False)
        # The encoder stays trainable unless explicitly frozen.
        self.whisper.encoder.requires_grad_(not freeze_encoder)

        width = self.whisper.config.d_model
        bottleneck = max(256, width // 2)
        n_outputs = NUM_TRACKS * (MAX_SEGMENTS + 1)
        # Layer order is kept so state-dict keys stay head.0 / head.2.
        self.head = nn.Sequential(
            nn.Linear(width, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, n_outputs),
        )

    def forward(self, input_features: torch.FloatTensor):
        """Return logits reshaped to [batch, NUM_TRACKS, frames, MAX_SEGMENTS + 1]."""
        encoded = self.whisper.encoder(input_features=input_features).last_hidden_state
        batch, frames, _ = encoded.shape
        n_classes = MAX_SEGMENTS + 1
        flat = self.head(encoded)  # [batch, frames, NUM_TRACKS * n_classes]
        per_track = flat.view(batch, frames, NUM_TRACKS, n_classes)
        return per_track.permute(0, 2, 1, 3).contiguous()
160
+
161
+ # =========================
162
+ # ====== AUDIO UTILS ======
163
+ # =========================
164
+
165
def load_audio_mono_16k(path: Path) -> np.ndarray:
    """Load an audio file through librosa at SAMPLE_RATE, mono, as float32."""
    samples, _rate = librosa.load(str(path), sr=SAMPLE_RATE, mono=True)
    # Defensive down-mix in case a multi-channel array comes back anyway.
    mono = samples.mean(axis=0) if samples.ndim > 1 else samples
    return mono.astype(np.float32, copy=False)
170
+
171
def split_into_chunks(wav: np.ndarray, sr: int, clip_seconds: float):
    """
    Cut a waveform into fixed-length chunks, zero-padding the last one.

    Returns a list of ``(chunk_index, start_sample, float32_samples)`` tuples;
    an empty waveform yields an empty list.
    """
    samples_per_chunk = int(clip_seconds * sr)
    n_samples = len(wav)
    if n_samples == 0:
        return []

    pieces = []
    for index in range(math.ceil(n_samples / samples_per_chunk)):
        offset = index * samples_per_chunk
        piece = wav[offset:offset + samples_per_chunk]
        shortfall = samples_per_chunk - len(piece)
        if shortfall > 0:
            # Only the final chunk can be short; pad it with silence.
            piece = np.pad(piece, (0, shortfall), mode="constant")
        pieces.append((index, offset, piece.astype(np.float32, copy=False)))
    return pieces
186
+
187
def wav_chunk_to_audio_bytes(wav: np.ndarray, sr: int):
    """
    Encode samples for embedding in the report.

    MP3 is produced when ffmpeg is available and the encode succeeds; in every
    other case the WAV encoding is returned. A fallback warning is printed at
    most once per process. Returns ``(audio_bytes, mime_type)``.
    """
    global WARNED_NO_FFMPEG

    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, wav, sr, format="WAV")
    wav_payload = wav_buffer.getvalue()

    if FFMPEG_AVAILABLE:
        try:
            wav_buffer.seek(0)
            clip = AudioSegment.from_file(wav_buffer, format="wav")
            mp3_buffer = io.BytesIO()
            clip.export(mp3_buffer, format="mp3", bitrate="128k")
            mp3_buffer.seek(0)
            return mp3_buffer.read(), "audio/mpeg"
        except Exception as e:
            # Best effort: any encoder failure degrades to WAV.
            if not WARNED_NO_FFMPEG:
                print(f"[audio] Failed to encode MP3, falling back to WAV: {e}", flush=True)
                WARNED_NO_FFMPEG = True
    elif not WARNED_NO_FFMPEG:
        print("[audio] ffmpeg not found; embedding WAV instead of MP3.", flush=True)
        WARNED_NO_FFMPEG = True

    return wav_payload, "audio/wav"
216
+
217
+ # =========================
218
+ # ====== SEGMENT OPS ======
219
+ # =========================
220
+
221
def smooth_min_duration(ids: np.ndarray, min_frames: int, max_iter: int = 10) -> np.ndarray:
    """
    Enforce a minimum run length (in frames) on a 1D label sequence.

    Every run shorter than ``min_frames`` is absorbed by the longer of its two
    neighbours (the left one on a tie, the only one at the sequence edges).
    Runs are merged one at a time on a live run list, so a decision never
    relies on a neighbour that has already been relabelled. The result
    therefore either has all runs >= ``min_frames`` or is a single run.

    Args:
        ids: 1D array of labels. It is not modified; a copy is returned.
        min_frames: minimum allowed run length.
        max_iter: kept for backward compatibility. The single-pass merge
            always converges, so the value only matters when it is <= 0, in
            which case no smoothing is applied (as before).

    Returns:
        A smoothed copy of ``ids``.
    """
    ids = ids.copy()
    n = len(ids)
    if n == 0 or max_iter <= 0:
        return ids

    # Run-length encode as mutable [label, start, end) triples.
    runs = []
    start = 0
    for i in range(1, n + 1):
        if i == n or ids[i] != ids[start]:
            runs.append([ids[start], start, i])
            start = i

    # Invariant: every run left of `i` is already >= min_frames.
    i = 0
    while len(runs) > 1 and i < len(runs):
        run = runs[i]
        if run[2] - run[1] >= min_frames:
            i += 1
            continue

        left = runs[i - 1] if i > 0 else None
        right = runs[i + 1] if i + 1 < len(runs) else None
        if left is None:
            new_label = right[0]
        elif right is None:
            new_label = left[0]
        else:
            len_left = left[2] - left[1]
            len_right = right[2] - right[1]
            new_label = left[0] if len_left >= len_right else right[0]

        run[0] = new_label
        # Coalesce with whichever neighbours now share the label. At least one
        # does, so the run count strictly decreases and the loop terminates.
        if right is not None and right[0] == new_label:
            run[2] = right[2]
            del runs[i + 1]
        if left is not None and left[0] == new_label:
            left[2] = run[2]
            del runs[i]
            i -= 1
        # Do not advance: the merged run at index i is re-checked next.

    for label, s, e in runs:
        ids[s:e] = label
    return ids
270
+
271
def extract_segments(ids: np.ndarray, include_bg: bool = False):
    """
    Run-length encode a label sequence.

    Returns ``(label, frame_start, frame_end)`` tuples in order, with
    ``frame_end`` exclusive. Runs labelled 0 (background) are dropped unless
    ``include_bg`` is true.
    """
    total = len(ids)
    if total == 0:
        return []

    labels = np.asarray(ids)
    # Positions where a label differs from its predecessor start a new run.
    change_points = (np.flatnonzero(labels[1:] != labels[:-1]) + 1).tolist()
    starts = [0] + change_points
    ends = change_points + [total]
    segments = [(ids[s], s, e) for s, e in zip(starts, ends)]

    if include_bg:
        return segments
    return [segment for segment in segments if segment[0] != 0]
291
+
292
def frames_to_times(s: int, e: int):
    """Convert a frame range to ``(start_seconds, end_seconds)`` within a chunk."""
    def to_seconds(frame):
        return frame / NUM_FRAMES * CLIP_SECONDS

    return to_seconds(s), to_seconds(e)
296
+
297
def cut_wav(seg_wav: np.ndarray, start_t: float, end_t: float) -> np.ndarray:
    """
    Slice ``seg_wav`` between two times given in seconds.

    Sample positions are clamped to the array, and the end is forced to lie at
    least one sample after the start.
    """
    n_samples = len(seg_wav)
    first = int(round(start_t * SAMPLE_RATE))
    last = int(round(end_t * SAMPLE_RATE))

    first = min(max(first, 0), n_samples)
    last = max(first + 1, min(last, n_samples))
    return seg_wav[first:last]
303
+
304
+ # =========================
305
+ # ==== MERGED TIMELINE ====
306
+ # =========================
307
+
308
def smooth_merged_segments(merged: List[Tuple[int,int,int,int]], min_frames: int) -> List[Tuple[int,int,int,int]]:
    """
    Merge away merged-timeline segments shorter than ``min_frames``.

    ``merged`` holds ``(frame_start, frame_end, odd_label, even_label)``
    tuples. A short segment is absorbed by one neighbour and takes over that
    neighbour's labels:
      - with two neighbours, the one sharing more labels (0..2) wins;
      - on equal similarity the shorter neighbour wins, then the left one;
      - with a single neighbour, that neighbour is used.

    A list with at most one segment is returned as-is; otherwise a new list
    is returned.
    """
    if len(merged) <= 1:
        return merged

    segs = list(merged)

    def n_frames(seg):
        return seg[1] - seg[0]

    def shared_labels(a, b):
        # Number of matching labels between two segments (odd, even).
        return int(a[2] == b[2]) + int(a[3] == b[3])

    # Segments left of `pos` are already long enough, so the scan can resume
    # in place after each merge instead of restarting from the beginning.
    pos = 0
    while len(segs) > 1 and pos < len(segs):
        current = segs[pos]
        if n_frames(current) >= min_frames:
            pos += 1
            continue

        before = segs[pos - 1] if pos > 0 else None
        after = segs[pos + 1] if pos + 1 < len(segs) else None

        if before is None:
            into_left = False
        elif after is None:
            into_left = True
        else:
            sim_before = shared_labels(current, before)
            sim_after = shared_labels(current, after)
            if sim_before != sim_after:
                into_left = sim_before > sim_after
            else:
                # Similarity tie: shorter neighbour, left on a full tie.
                into_left = n_frames(before) <= n_frames(after)

        if into_left:
            segs[pos - 1] = (before[0], current[1], before[2], before[3])
        else:
            segs[pos + 1] = (current[0], after[1], after[2], after[3])
        del segs[pos]
        # `pos` now addresses the next candidate (the merged segment itself
        # after a right merge), so it is deliberately not advanced.

    return segs
397
+
398
def build_merged_segments(ids_odd: np.ndarray, ids_even: np.ndarray, min_frames: int):
    """
    Build the merged timeline from the two tracks and smooth it.

    - A boundary is placed at 0, at NUM_FRAMES, and at every frame where
      either track changes label.
    - Each raw merged segment carries the odd/even label found inside it.
      Both tracks are constant within a segment by construction, so the first
      frame's labels are also the majority labels.
    - Minimum segment length is then enforced via smooth_merged_segments.

    Args:
        ids_odd: 1D label array of length NUM_FRAMES for the odd track.
        ids_even: 1D label array of length NUM_FRAMES for the even track.
        min_frames: minimum merged-segment length in frames.

    Returns:
        List of ``(frame_start, frame_end, odd_label, even_label)`` tuples.

    Raises:
        ValueError: if either track does not have exactly NUM_FRAMES frames.
            This is an explicit check rather than ``assert``, so it still
            runs under ``python -O``.
    """
    ids_odd = np.asarray(ids_odd)
    ids_even = np.asarray(ids_even)
    if not (len(ids_odd) == len(ids_even) == NUM_FRAMES):
        raise ValueError(
            f"expected two tracks of {NUM_FRAMES} frames, "
            f"got {len(ids_odd)} and {len(ids_even)}"
        )

    # A frame starts a new merged segment when either track differs from the
    # previous frame.
    changed = (ids_odd[1:] != ids_odd[:-1]) | (ids_even[1:] != ids_even[:-1])
    cuts = [0, *(np.flatnonzero(changed) + 1).tolist(), NUM_FRAMES]

    merged = [
        (s, e, int(ids_odd[s]), int(ids_even[s]))
        for s, e in zip(cuts[:-1], cuts[1:])
    ]

    # Now enforce min length also on merged segments
    return smooth_merged_segments(merged, min_frames)
436
+
437
+ # =========================
438
+ # ======= PLOTTING ========
439
+ # =========================
440
+
441
def _plot_tracks_seconds(pred_ids: torch.Tensor, title: str) -> bytes:
    """
    Render the per-track label heatmap and return it as PNG bytes.

    Args:
        pred_ids: [2, NUM_FRAMES] LongTensor; row 0 is the odd track and
            row 1 the even track.
        title: figure title.

    Returns:
        PNG-encoded image bytes.
    """
    secs = np.linspace(0.0, CLIP_SECONDS, NUM_FRAMES)
    fig = plt.figure(figsize=(10, 2.8))
    ax = plt.gca()
    im = ax.imshow(
        pred_ids.numpy(),
        aspect="auto",
        interpolation="nearest",
        origin="upper",
        # With origin="upper", row 0 is drawn at the *top* edge of the extent.
        # The extent is (left, right, bottom, top), so bottom=1.5 / top=-0.5
        # centres row 0 on y=0 and row 1 on y=1, which matches the tick
        # labels below. The reverse order labelled the odd row "even".
        extent=[secs[0], secs[-1], 1.5, -0.5],
    )
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_yticks([0, 1])
    ax.set_yticklabels(["odd", "even"])
    cb = plt.colorbar(im, fraction=0.046, pad=0.04)
    cb.set_label("Segment ID")
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure; many are created per run
    buf.seek(0)
    return buf.read()
466
+
467
def _plot_merged_segments(seg_ids: np.ndarray, title: str) -> bytes:
    """
    Render the merged timeline as a one-row heatmap and return PNG bytes.

    seg_ids: [NUM_FRAMES] array where each frame holds a merged-segment index.
    """
    time_axis = np.linspace(0.0, CLIP_SECONDS, NUM_FRAMES)
    fig, ax = plt.subplots(figsize=(10, 2.8))

    single_row = seg_ids[np.newaxis, :]
    image = ax.imshow(
        single_row,
        aspect="auto",
        interpolation="nearest",
        origin="upper",
        extent=[time_axis[0], time_axis[-1], -0.5, 0.5],
    )
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_yticks([0])
    ax.set_yticklabels(["merged"])

    colorbar = plt.colorbar(image, fraction=0.046, pad=0.04)
    colorbar.set_label("Merged seg ID")

    png = io.BytesIO()
    fig.savefig(png, format="png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    png.seek(0)
    return png.read()
492
+
493
+ # =========================
494
+ # ========= HTML ==========
495
+ # =========================
496
+
497
def write_html_report(out_dir: Path, chunks: List[Dict[str, Any]]) -> Path:
    """
    Write one self-contained HTML report covering every processed chunk.

    Args:
        out_dir: directory the report is written to.
        chunks: one dict per chunk with the keys ``file_name``, ``chunk_idx``,
            ``chunk_offset``, ``png_tracks``, ``png_merged`` (base64 PNG
            strings), ``track_segments`` and ``merged_segments`` (lists of
            segment dicts carrying ``start``/``end``/``dur``/``audio_b64``/
            ``audio_mime`` plus ``track``/``label`` or ``odd_label``/
            ``even_label``).

    Returns:
        Path of the written file, ``seg_infer_smooth_<timestamp>.html``.
    """
    # Local import: file names come from disk and must be escaped before they
    # are placed into markup.
    from html import escape

    def audio_cell(seg: Dict[str, Any]) -> str:
        # Inline <audio> player, or an empty cell when encoding failed.
        if seg["audio_b64"] and seg["audio_mime"]:
            return (
                '<audio controls preload="none">'
                f'<source src="data:{seg["audio_mime"]};base64,{seg["audio_b64"]}" '
                f'type="{seg["audio_mime"]}"></audio>'
            )
        return ""

    ts = time.strftime("%Y%m%d_%H%M%S")
    parts = [f"""<!doctype html><html><head><meta charset="utf-8">
<style>
body{{font-family:system-ui,Segoe UI,Roboto,Arial,sans-serif;margin:20px}}
.card{{border:1px solid #ddd;border-radius:10px;padding:16px;margin:16px 0;
box-shadow:0 2px 6px rgba(0,0,0,.05)}}
.grid{{display:grid;grid-template-columns:1fr 1fr;gap:12px}}
figure{{margin:0}}
figcaption{{font-size:13px;color:#555;margin-top:6px}}
audio{{width:100%;min-width:200px;margin-top:4px}}
.meta{{font-size:13px;color:#666;margin-bottom:4px}}
table{{border-collapse:collapse;width:100%;margin-top:8px;font-size:13px;table-layout:fixed}}
th,td{{border:1px solid #ddd;padding:4px 6px;text-align:left;vertical-align:top;overflow:hidden;text-overflow:ellipsis;white-space:nowrap}}
th{{background:#f5f5f5}}
</style>
<title>Odd/Even Segmentation - Inference {ts}</title></head><body>
<h1>Odd/Even Segmentation - Inference</h1>
<p>
This report shows <b>smoothed</b> segmentations for each 30-second chunk of your audio files.
The model predicts two parallel time tracks ("odd" and "even") that can hold overlapping events.
We first smooth each track so that <b>no segment (including background 0) is shorter than {MIN_SEGMENT_SEC:.2f} seconds</b>.
Then:
</p>
<ul>
<li><b>Per-track segments</b>: segments for each track (odd/even) with duration &gt;= {MIN_SEGMENT_SEC:.2f}s, each with its own audio player.</li>
<li><b>Merged timeline</b>: a single segmentation where a new segment starts or ends whenever either track changes, and each merged segment is also at least {MIN_SEGMENT_SEC:.2f}s long by merging very short segments into their most similar neighbor.</li>
</ul>
"""]

    for ch in chunks:
        parts.append(f"""
<section class="card">
<h2>{escape(str(ch['file_name']))} - chunk {ch['chunk_idx']}</h2>
<div class="meta">
Chunk offset in file: {ch['chunk_offset']:.2f} - {ch['chunk_offset'] + CLIP_SECONDS:.2f} s
</div>
<div class="grid">
<figure>
<img src="data:image/png;base64,{ch['png_tracks']}" alt="smoothed tracks">
<figcaption>Smoothed per-track predictions (odd/even).</figcaption>
</figure>
<figure>
<img src="data:image/png;base64,{ch['png_merged']}" alt="merged timeline">
<figcaption>Merged timeline: segment borders whenever odd or even track changes label, then smoothed to enforce a minimum duration.</figcaption>
</figure>
</div>

<h3>Per-track segments (min {MIN_SEGMENT_SEC:.2f} s)</h3>
<p>Each row is one predicted event on the odd or even track. Times are relative to the start of this 30-second chunk.</p>
<table class="seg seg-track">
<colgroup>
<col style="width:5%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:45%">
</colgroup>
<tr><th>#</th><th>Track</th><th>Label ID</th><th>Start (s)</th><th>End (s)</th>
<th>Duration (s)</th><th>Audio</th></tr>
""")
        # per-track table
        for i, seg in enumerate(ch["track_segments"], start=1):
            parts.append(
                f"<tr><td>{i}</td>"
                f"<td>{seg['track']}</td>"
                f"<td>{seg['label']}</td>"
                f"<td>{seg['start']:.2f}</td>"
                f"<td>{seg['end']:.2f}</td>"
                f"<td>{seg['dur']:.2f}</td>"
                f"<td>{audio_cell(seg)}</td></tr>"
            )
        parts.append("</table>")

        # merged timeline table
        parts.append(f"""
<h3>Merged timeline segments</h3>
<p>
The merged timeline splits the 30-second chunk wherever either the odd or even track changes label.
Very short merged segments (shorter than {MIN_SEGMENT_SEC:.2f}s) are merged into their most similar neighbor
based on odd/even labels; if both neighbors are equally similar, they are merged into the shorter neighbor.
This yields a single sequence of non-overlapping segments that cover the entire chunk.
Each row shows the majority label on the odd and even tracks within that merged segment.
</p>
<table class="seg seg-merged">
<colgroup>
<col style="width:5%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:10%">
<col style="width:45%">
</colgroup>
<tr><th>#</th><th>Start (s)</th><th>End (s)</th><th>Duration (s)</th>
<th>Odd label</th><th>Even label</th><th>Audio</th></tr>
""")
        for i, seg in enumerate(ch["merged_segments"], start=1):
            parts.append(
                f"<tr><td>{i}</td>"
                f"<td>{seg['start']:.2f}</td>"
                f"<td>{seg['end']:.2f}</td>"
                f"<td>{seg['dur']:.2f}</td>"
                f"<td>{seg['odd_label']}</td>"
                f"<td>{seg['even_label']}</td>"
                f"<td>{audio_cell(seg)}</td></tr>"
            )
        parts.append("</table></section>")

    parts.append("</body></html>")
    out_path = out_dir / f"seg_infer_smooth_{ts}.html"
    out_path.write_text("\n".join(parts), encoding="utf-8")
    return out_path
626
+
627
+ # =========================
628
+ # ========= MAIN ==========
629
+ # =========================
630
+
631
def main() -> None:
    """
    Run segmentation inference over a directory of audio files and write the HTML report.

    Steps: prepare directories, resolve the input directory (sys.argv[1]
    overrides AUDIO_INPUT_DIR), load the feature extractor and the
    WhisperOddEven checkpoint, then for every 30 s chunk of every audio file
    predict per-frame labels, smooth them, build the merged timeline, cut
    audio snippets and collect everything for write_html_report.

    Exits with status 1 when the input directory or checkpoint is missing,
    when no audio files are found, or when no chunk produced a result.
    """
    setup_dirs()

    global AUDIO_INPUT_DIR
    if len(sys.argv) > 1:
        AUDIO_INPUT_DIR = Path(sys.argv[1])

    if not AUDIO_INPUT_DIR.is_dir():
        print(f"[ERR] AUDIO_INPUT_DIR not found or not a dir: {AUDIO_INPUT_DIR}", file=sys.stderr)
        sys.exit(1)
    if not CKPT_PATH.is_file():
        print(f"[ERR] Checkpoint not found: {CKPT_PATH}", file=sys.stderr)
        sys.exit(1)

    print(f"[cfg] AUDIO_INPUT_DIR = {AUDIO_INPUT_DIR}")
    print(f"[cfg] OUT_DIR = {OUT_DIR}")
    print(f"[cfg] CKPT_PATH = {CKPT_PATH}")
    print(f"[cfg] HF_MODEL_ID = {HF_MODEL_ID}")
    print(f"[cfg] ffmpeg available: {FFMPEG_AVAILABLE}")
    print(f"[cfg] MIN_SEGMENT_SEC = {MIN_SEGMENT_SEC:.2f} (frames >= {MIN_SEGMENT_FRAMES})")

    # find audio files (recursive, case-insensitive extension match, sorted for stable order)
    exts = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}
    audio_files: List[Path] = []
    for p in AUDIO_INPUT_DIR.rglob("*"):
        if p.is_file() and p.suffix.lower() in exts:
            audio_files.append(p)
    audio_files = sorted(audio_files)

    if not audio_files:
        print("[ERR] No audio files found.", file=sys.stderr)
        sys.exit(1)

    print(f"[scan] Found {len(audio_files)} audio files.")

    # feature extractor
    resolved, is_local = _model_resolved_name(HF_MODEL_ID)
    fe = WhisperFeatureExtractor.from_pretrained(resolved, local_files_only=is_local)

    # model + checkpoint
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = WhisperOddEven(HF_MODEL_ID, freeze_encoder=False).to(device)

    # NOTE(review): torch.load unpickles the file -- only point CKPT at trusted checkpoints.
    state = torch.load(CKPT_PATH, map_location="cpu")
    # accept full trainer_state dict or plain state_dict
    if isinstance(state, dict) and "model" in state and any(
        k.startswith("whisper.") for k in state["model"].keys()
    ):
        state = state["model"]

    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"[ckpt] Loaded checkpoint from {CKPT_PATH}")
    if missing:
        print(f"[ckpt] Missing keys: {missing}")
    if unexpected:
        print(f"[ckpt] Unexpected keys: {unexpected}")
    model.eval()

    use_dtype = preferred_dtype()
    # Autocast is only enabled for the reduced-precision dtypes.
    amp_enabled = use_dtype in (torch.float16, torch.bfloat16)

    # One entry per processed chunk; consumed by write_html_report.
    chunk_results: List[Dict[str, Any]] = []

    with torch.no_grad():
        for fpath in audio_files:
            print(f"[file] {fpath}")
            try:
                wav = load_audio_mono_16k(fpath)
            except Exception as e:
                # Best effort: an unreadable file is skipped, not fatal.
                print(f"[file] Failed to load {fpath}: {e}")
                continue

            chunks = split_into_chunks(wav, SAMPLE_RATE, CLIP_SECONDS)
            if not chunks:
                print(f"[file] No audio samples in {fpath}")
                continue

            for chunk_idx, start_sample, seg in chunks:
                chunk_offset_sec = start_sample / SAMPLE_RATE

                # features
                feat = fe(seg, sampling_rate=SAMPLE_RATE, return_tensors="pt")
                x = feat.input_features.to(device)

                # forward
                with torch.autocast(
                    device_type="cuda" if torch.cuda.is_available() else "cpu",
                    enabled=amp_enabled,
                    dtype=use_dtype,
                ):
                    logits = model(x)

                # raw argmax
                raw_ids = logits.argmax(dim=-1).squeeze(0).cpu().numpy()  # [2,1500]

                # aggressive smoothing with min duration per track
                sm_ids = np.zeros_like(raw_ids)
                for tr in range(NUM_TRACKS):
                    sm_ids[tr] = smooth_min_duration(raw_ids[tr], MIN_SEGMENT_FRAMES)

                sm_ids_t = torch.from_numpy(sm_ids)
                png_tracks = base64.b64encode(
                    _plot_tracks_seconds(
                        sm_ids_t,
                        f"Smoothed tracks - {fpath.name} - chunk {chunk_idx}",
                    )
                ).decode("ascii")

                # merged timeline with its own min-duration smoothing
                merged = build_merged_segments(sm_ids[0], sm_ids[1], MIN_SEGMENT_FRAMES)
                # Per-frame index (1-based) of the merged segment covering it, for plotting.
                # The end variable is named fe_ so it does not shadow the feature extractor `fe`.
                merged_index = np.zeros(NUM_FRAMES, dtype=np.int64)
                for idx, (fs, fe_, _ol, _el) in enumerate(merged, start=1):
                    merged_index[fs:fe_] = idx

                png_merged = base64.b64encode(
                    _plot_merged_segments(
                        merged_index,
                        f"Merged segments - {fpath.name} - chunk {chunk_idx}",
                    )
                ).decode("ascii")

                # per-track segments -> audio snippets
                track_segments: List[Dict[str, Any]] = []
                for tr, track_name in enumerate(("odd", "even")):
                    # Background (label 0) runs are not listed in the per-track table.
                    seg_runs = extract_segments(sm_ids[tr], include_bg=False)
                    for (lab, fs, fe_) in seg_runs:
                        start_t, end_t = frames_to_times(fs, fe_)
                        dur = end_t - start_t
                        if dur <= 0:
                            continue
                        sub_wav = cut_wav(seg, start_t, end_t)
                        if sub_wav.size == 0:
                            continue
                        try:
                            audio_bytes, audio_mime = wav_chunk_to_audio_bytes(sub_wav, SAMPLE_RATE)
                            audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
                        except Exception as e:
                            # Keep the row; the report shows an empty audio cell.
                            print(f"[audio] Failed per-track snippet for {fpath} chunk {chunk_idx}: {e}")
                            audio_b64 = None
                            audio_mime = None

                        track_segments.append(
                            {
                                "track": track_name,
                                "label": int(lab),
                                "start": float(start_t),
                                "end": float(end_t),
                                "dur": float(dur),
                                "audio_b64": audio_b64,
                                "audio_mime": audio_mime,
                            }
                        )

                # merged segments -> audio snippets
                merged_segments: List[Dict[str, Any]] = []
                for idx, (fs, fe_, odd_label, even_label) in enumerate(merged, start=1):
                    start_t, end_t = frames_to_times(fs, fe_)
                    dur = end_t - start_t
                    if dur <= 0:
                        continue
                    sub_wav = cut_wav(seg, start_t, end_t)
                    if sub_wav.size == 0:
                        continue
                    try:
                        audio_bytes, audio_mime = wav_chunk_to_audio_bytes(sub_wav, SAMPLE_RATE)
                        audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
                    except Exception as e:
                        # Keep the row; the report shows an empty audio cell.
                        print(f"[audio] Failed merged snippet for {fpath} chunk {chunk_idx}: {e}")
                        audio_b64 = None
                        audio_mime = None

                    merged_segments.append(
                        {
                            "idx": idx,
                            "start": float(start_t),
                            "end": float(end_t),
                            "dur": float(dur),
                            "odd_label": int(odd_label),
                            "even_label": int(even_label),
                            "audio_b64": audio_b64,
                            "audio_mime": audio_mime,
                        }
                    )

                chunk_results.append(
                    {
                        "file_name": fpath.name,
                        "chunk_idx": int(chunk_idx),
                        "chunk_offset": float(chunk_offset_sec),
                        "png_tracks": png_tracks,
                        "png_merged": png_merged,
                        "track_segments": track_segments,
                        "merged_segments": merged_segments,
                    }
                )

    if not chunk_results:
        print("[ERR] No chunk results; nothing to write.", file=sys.stderr)
        sys.exit(1)

    out_html = write_html_report(OUT_DIR, chunk_results)
    print(f"[done] Wrote HTML report: {out_html}")
833
+
834
# Run the inference pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()