unknown commited on
Commit
bcc5768
·
1 Parent(s): 36b339e
Files changed (1) hide show
  1. app.py +665 -665
app.py CHANGED
@@ -1,823 +1,823 @@
1
- import re
2
- from dataclasses import dataclass
3
- from typing import Any, Dict, List, Tuple, Optional
4
-
5
- import gradio as gr
6
- from huggingface_hub import list_repo_files, hf_hub_download
7
- from pydub import AudioSegment
8
- import numpy as np
9
-
10
- # =========================================================
11
- # Config
12
- # =========================================================
13
- MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
14
- VTT_EXTS = (".vtt",)
15
-
16
- DEFAULT_MAX_MID_DIFF = 1.5
17
-
18
- # Normalize audio for stable playback in browsers
19
- TARGET_SR = 48000
20
- TARGET_CH = 1 # mono
21
- TARGET_SW = 2 # 16-bit PCM
22
 
 
 
 
 
23
 
24
- # =========================================================
25
- # Data structures
26
- # =========================================================
27
- @dataclass
28
- class Cue:
29
- start: float
30
- end: float
31
- text: str
32
 
 
33
 
34
- # =========================================================
35
- # VTT parsing
36
- # =========================================================
37
- _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
38
- _VTT_TIME_RE = re.compile(
39
- r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
40
- r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
41
- )
42
 
43
 
44
- def _strip_tags(text: str) -> str:
45
- return _TAG_RE.sub("", text).strip()
 
 
 
 
 
 
46
 
47
 
48
- def _time_to_seconds(t: str) -> float:
49
- parts = t.split(":")
50
- if len(parts) == 3:
51
- return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
52
- if len(parts) == 2:
53
- return int(parts[0]) * 60 + float(parts[1])
54
- raise ValueError(f"Bad VTT timestamp: {t}")
 
55
 
56
 
57
- def parse_vtt_file(path: str) -> List[Cue]:
58
- with open(path, "r", encoding="utf-8") as f:
59
- content = f.read()
60
 
61
- # Remove BOM / WEBVTT header (if any)
62
- content = content.replace("\ufeff", "")
63
- content = re.sub(r"^\s*WEBVTT.*?\n", "", content, flags=re.IGNORECASE)
64
 
65
- blocks = re.split(r"\r?\n\r?\n", content.strip())
66
- cues: List[Cue] = []
 
 
 
 
 
67
 
68
- for block in blocks:
69
- lines = [l.strip() for l in block.splitlines() if l.strip()]
70
- if not lines:
71
- continue
72
 
73
- # Locate the timestamp line (must contain "-->")
74
- time_idx: Optional[int] = None
75
- for i, line in enumerate(lines):
76
- if "-->" in line:
77
- time_idx = i
78
- break
79
- if time_idx is None:
80
- continue
81
 
82
- m = _VTT_TIME_RE.search(lines[time_idx])
83
- if not m:
84
- continue
85
 
86
- start = _time_to_seconds(m.group("start"))
87
- end = _time_to_seconds(m.group("end"))
88
- if end <= start:
89
- continue
90
 
91
- # Only take lines after the timestamp line as subtitle text
92
- text_lines = lines[time_idx + 1 :]
93
- if not text_lines:
94
- continue
95
 
96
- text = _strip_tags("\n".join(text_lines))
97
- if text:
98
- cues.append(Cue(start=start, end=end, text=text))
 
 
 
 
 
99
 
100
- return sorted(cues, key=lambda x: x.start)
 
 
101
 
 
 
 
 
102
 
103
- # =========================================================
104
- # Alignment (match by mid time), preserve per-track windows
105
- # =========================================================
106
- def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
107
- out: List[Dict[str, Any]] = []
108
- i, j, idx = 0, 0, 1
109
 
110
- while i < len(a) and j < len(b):
111
- ma = (a[i].start + a[i].end) / 2
112
- mb = (b[j].start + b[j].end) / 2
113
 
114
- if abs(ma - mb) <= th:
115
- out.append(
116
- {
117
- "idx": idx,
118
- # Per-track time window (recommended for playback)
119
- "a_start": a[i].start,
120
- "a_end": a[i].end,
121
- "b_start": b[j].start,
122
- "b_end": b[j].end,
123
- # Optional global time window (for comparison/debug)
124
- "start": min(a[i].start, b[j].start),
125
- "end": max(a[i].end, b[j].end),
126
- "a_text": a[i].text,
127
- "b_text": b[j].text,
128
- }
129
- )
130
- idx += 1
131
- i += 1
132
- j += 1
133
- elif ma < mb:
134
- i += 1
135
- else:
136
- j += 1
137
 
138
- return out
139
 
 
 
 
 
 
 
140
 
141
- # =========================================================
142
- # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
143
- # =========================================================
144
- def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
145
- """
146
- Robust segment export for gr.Audio(type="numpy").
147
 
148
- Key points:
149
- - Clamp start/end (after any offsets) to valid range.
150
- - Use *rounded* ms boundaries to avoid systematic truncation drift.
151
- - Slice via pydub (ms-accurate) using the original stream timeline.
152
- - Normalize to mono/48k/int16 for stable browser playback.
153
- - Return (sr, int16 ndarray) to avoid float32 scaling pitfalls.
154
- """
155
- # Clamp and ensure minimum duration
156
- start = float(start)
157
- end = float(end)
158
- if end < start:
159
- start, end = end, start
160
- start = max(0.0, start)
161
- end = max(start + 0.05, end)
 
 
 
 
 
 
 
 
 
162
 
163
- # Round to milliseconds (avoid int() truncation bias)
164
- start_ms = int(round(start * 1000.0))
165
- end_ms = int(round(end * 1000.0))
166
 
167
- seg = audio[start_ms:end_ms]
168
 
169
- # Normalize to mono/48k/int16
170
- seg = seg.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)
 
 
 
 
171
 
172
- arr = np.asarray(seg.get_array_of_samples())
173
- if arr.dtype != np.int16:
174
- arr = arr.astype(np.int16, copy=False)
175
-
176
- return TARGET_SR, np.ascontiguousarray(arr)
 
 
 
 
 
 
 
 
 
177
 
 
 
 
178
 
179
- # =========================================================
180
- # Helper: robustly read seg_idx from gr.Dataframe value
181
- # =========================================================
182
- def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
183
- if df_value is None:
184
- return None
185
 
186
- # pandas DataFrame in some Gradio versions
187
- try:
188
- import pandas as pd # type: ignore
189
- if isinstance(df_value, pd.DataFrame):
190
- if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
191
- return None
192
- return int(df_value.iloc[row, 0])
193
- except Exception:
194
- pass
195
 
196
- # list-of-lists
197
- try:
198
- if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
199
- return int(df_value[row][0])
200
- except Exception:
201
- return None
202
 
203
- return None
204
 
205
 
206
- # =========================================================
207
- # Gradio callbacks
208
- # =========================================================
209
- def scan_dataset(repo_id: str, repo_type: str):
210
- if not repo_id:
211
- raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
212
 
213
- files = list_repo_files(repo_id, repo_type=repo_type)
214
- media_files = sorted([f for f in files if f.lower().endswith(MEDIA_EXTS)])
215
- vtt_files = sorted([f for f in files if f.lower().endswith(VTT_EXTS)])
 
 
 
 
 
 
216
 
217
- if not media_files:
218
- raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
219
- if not vtt_files:
220
- raise gr.Error("未找到 VTT 字幕文件。")
 
 
221
 
222
- return (
223
- gr.update(choices=media_files, value=media_files[0]),
224
- gr.update(choices=media_files, value=media_files[0]),
225
- gr.update(choices=vtt_files, value=vtt_files[0]),
226
- gr.update(choices=vtt_files, value=vtt_files[0]),
227
- )
228
 
229
 
230
- def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
231
- if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
232
- raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")
 
 
 
233
 
234
- local_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
235
- local_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
236
- local_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
237
- local_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)
238
 
239
- try:
240
- audio_a = AudioSegment.from_file(local_media_a)
241
- audio_b = AudioSegment.from_file(local_media_b)
242
- except Exception as e:
243
- raise gr.Error(
244
- "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
245
- f"原始错误: {repr(e)}"
246
- )
247
 
248
- cues_a = parse_vtt_file(local_vtt_a)
249
- cues_b = parse_vtt_file(local_vtt_b)
250
- if not cues_a or not cues_b:
251
- raise gr.Error("VTT 解析为空,请检查字幕文件内容。")
 
 
252
 
253
- # ---- Drift fix: estimate time-scale (linear) between VTT timeline and audio timeline ----
254
- # If you observe increasing offset over time, it is usually a *scale* mismatch rather than a constant offset.
255
- # We estimate per-track scale by comparing audio duration to the last cue end time.
256
- a_vtt_end = max(c.end for c in cues_a) if cues_a else 0.0
257
- b_vtt_end = max(c.end for c in cues_b) if cues_b else 0.0
258
- a_dur = float(audio_a.duration_seconds)
259
- b_dur = float(audio_b.duration_seconds)
260
 
261
- # Default scale = 1.0 when we cannot estimate reliably.
262
- scale_a_suggest = (a_dur / a_vtt_end) if a_vtt_end > 1.0 and a_dur > 1.0 else 1.0
263
- scale_b_suggest = (b_dur / b_vtt_end) if b_vtt_end > 1.0 and b_dur > 1.0 else 1.0
264
 
265
- aligned = align_by_time(cues_a, cues_b, float(th))
266
- if not aligned:
267
- raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")
 
268
 
269
- rows = [
270
- [
271
- x["idx"],
272
- f'{x["a_start"]:.2f}-{x["a_end"]:.2f}',
273
- f'{x["b_start"]:.2f}-{x["b_end"]:.2f}',
274
- x["a_text"],
275
- x["b_text"],
276
- ]
277
- for x in aligned
278
- ]
279
 
280
- # Critical: build idx -> seg map to survive dataframe sorting/reordering
281
- idx_map = {int(x["idx"]): x for x in aligned}
 
 
282
 
283
- state = {
284
- "aligned": aligned,
285
- "idx_map": idx_map,
286
- "audio_a": audio_a,
287
- "audio_b": audio_b,
288
- "scale_a_suggest": scale_a_suggest,
289
- "scale_b_suggest": scale_b_suggest,
290
- "a_vtt_end": a_vtt_end,
291
- "b_vtt_end": b_vtt_end,
292
- "a_dur": a_dur,
293
- "b_dur": b_dur,
294
- }
295
 
296
- # Clear old playback outputs
297
- return rows, state, None, None, {}, gr.update(value=scale_a_suggest), gr.update(value=scale_b_suggest)
 
298
 
 
 
 
299
 
300
- def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, scale_a, scale_b, state):
301
- if not state or "aligned" not in state:
302
- raise gr.Error("请先加载并对齐。")
 
 
 
 
 
 
 
303
 
304
- # evt.index: int or (row, col)
305
- idx_raw = evt.index
306
- row = int(idx_raw[0] if isinstance(idx_raw, (tuple, list)) else idx_raw)
307
 
308
- offset_a = float(offset_a)
309
- offset_b = float(offset_b)
310
- scale_a = float(scale_a)
311
- scale_b = float(scale_b)
 
 
 
 
 
 
 
 
312
 
313
- # Prefer seg_idx from the clicked row's first column; then resolve via idx_map.
314
- seg_idx = _get_seg_idx_from_df(df_value, row)
315
- seg = None
316
- idx_map = state.get("idx_map", {}) or {}
317
- if seg_idx is not None and seg_idx in idx_map:
318
- seg = idx_map[seg_idx]
319
- else:
320
- # Fallback to row->aligned if idx missing (should be rare)
321
- aligned = state["aligned"]
322
- if row < 0 or row >= len(aligned):
323
- raise gr.Error("选中行越界,请重试或重新对齐。")
324
- seg = aligned[row]
325
- seg_idx = int(seg.get("idx", row + 1))
326
 
327
- if crop_mode == "global":
328
- a_start, a_end = seg["start"] * scale_a + offset_a, seg["end"] * scale_a + offset_a
329
- b_start, b_end = seg["start"] * scale_b + offset_b, seg["end"] * scale_b + offset_b
330
- else:
331
- # per_track playback (recommended)
332
- a_start, a_end = seg["a_start"] * scale_a + offset_a, seg["a_end"] * scale_a + offset_a
333
- b_start, b_end = seg["b_start"] * scale_b + offset_b, seg["b_end"] * scale_b + offset_b
334
 
335
- a_np = export_segment_numpy(state["audio_a"], a_start, a_end)
336
- b_np = export_segment_numpy(state["audio_b"], b_start, b_end)
 
337
 
338
- info = {
339
- "segment": seg_idx,
340
- "row": row,
341
- "crop_mode": crop_mode,
342
- "A_time": f"{a_start:.2f}-{a_end:.2f}",
343
- "B_time": f"{b_start:.2f}-{b_end:.2f}",
344
- "scale_a": scale_a,
345
- "scale_b": scale_b,
346
- "scale_a_suggest": state.get("scale_a_suggest", 1.0),
347
- "scale_b_suggest": state.get("scale_b_suggest", 1.0),
348
- }
349
- return a_np, b_np, info
350
 
 
 
 
 
351
 
352
- # =========================================================
353
- # UI
354
- # =========================================================
355
- with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
356
- gr.Markdown(
357
- "# 双语音频字幕对齐(点击表格即播放)\n"
358
- "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
359
- "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
360
- )
 
 
 
 
361
 
362
- state = gr.State()
 
 
 
 
 
 
363
 
364
- with gr.Row():
365
- repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
366
- repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
367
 
368
- btn_scan = gr.Button("扫描 Dataset", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
369
 
370
- with gr.Row():
371
- media_a = gr.Dropdown(label="Track A 媒体")
372
- media_b = gr.Dropdown(label="Track B 媒体")
373
 
374
- with gr.Row():
375
- vtt_a = gr.Dropdown(label="Track A 字幕")
376
- vtt_b = gr.Dropdown(label="Track B 字幕")
 
 
 
 
 
 
377
 
378
- btn_scan.click(
379
- scan_dataset,
380
- inputs=[repo_id, repo_type],
381
- outputs=[media_a, media_b, vtt_a, vtt_b],
382
- )
383
 
384
- th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
385
- btn_align = gr.Button("加载并对齐", variant="primary")
 
386
 
387
- df = gr.Dataframe(
388
- headers=["#", "A Time", "B Time", "Track A", "Track B"],
389
- interactive=True, # can be sorted/edited; mapping is stable due to idx_map
390
- wrap=True,
391
- max_height=520,
392
- )
393
 
394
- with gr.Row():
395
- crop_mode = gr.Radio(
396
- choices=["per_track", "global"],
397
- value="per_track",
398
- label="裁剪方式(建议 per_track)",
399
- )
400
- offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
401
- offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
402
- scale_a = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track A 时间缩放(scale)")
403
- scale_b = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track B 时间缩放(scale)")
404
 
405
- with gr.Row():
406
- a_out = gr.Audio(label="Track A 片段", type="numpy")
407
- b_out = gr.Audio(label="Track B 片段", type="numpy")
408
 
409
- play_info = gr.JSON(label="当前片段")
 
 
 
 
410
 
411
- btn_align.click(
412
- load_and_align,
413
- inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
414
- outputs=[df, state, a_out, b_out, play_info, scale_a, scale_b],
415
- )
416
 
417
- df.select(
418
- play_on_select,
419
- inputs=[df, crop_mode, offset_a, offset_b, scale_a, scale_b, state],
420
- outputs=[a_out, b_out, play_info],
421
- )
 
422
 
423
- if __name__ == "__main__":
424
- demo.launch()
 
 
 
 
 
 
 
 
425
 
426
- # import re
427
- # from dataclasses import dataclass
428
- # from typing import Any, Dict, List, Tuple, Optional
429
 
430
- # import gradio as gr
431
- # from huggingface_hub import list_repo_files, hf_hub_download
432
- # from pydub import AudioSegment
433
- # import numpy as np
434
 
435
- # # =========================================================
436
- # # Config
437
- # # =========================================================
438
- # MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
439
- # VTT_EXTS = (".vtt",)
440
 
441
- # DEFAULT_MAX_MID_DIFF = 1.5
 
 
 
 
442
 
443
- # # Normalize audio for stable playback in browsers
444
- # TARGET_SR = 48000
445
- # TARGET_CH = 1 # mono
446
- # TARGET_SW = 2 # 16-bit PCM
447
 
 
 
 
448
 
449
- # # =========================================================
450
- # # Data structures
451
- # # =========================================================
452
- # @dataclass
453
- # class Cue:
454
- # start: float
455
- # end: float
456
- # text: str
457
 
 
 
 
 
 
458
 
459
- # # =========================================================
460
- # # VTT parsing
461
- # # =========================================================
462
- # _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
463
- # _VTT_TIME_RE = re.compile(
464
- # r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
465
- # r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
466
- # )
467
 
 
 
 
 
468
 
469
- # def _strip_tags(text: str) -> str:
470
- # return _TAG_RE.sub("", text).strip()
471
 
 
 
 
 
 
 
 
 
472
 
473
- # def _time_to_seconds(t: str) -> float:
474
- # parts = t.split(":")
475
- # if len(parts) == 3:
476
- # return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
477
- # if len(parts) == 2:
478
- # return int(parts[0]) * 60 + float(parts[1])
479
- # raise ValueError(f"Bad VTT timestamp: {t}")
480
 
 
 
 
 
 
 
 
 
481
 
482
- # def parse_vtt_file(path: str) -> List[Cue]:
483
- # with open(path, "r", encoding="utf-8") as f:
484
- # content = f.read()
485
 
486
- # # Remove BOM / WEBVTT header (if any)
487
- # content = content.replace("\ufeff", "")
488
- # content = re.sub(r"^\s*WEBVTT.*?\n", "", content, flags=re.IGNORECASE)
489
 
490
- # blocks = re.split(r"\r?\n\r?\n", content.strip())
491
- # cues: List[Cue] = []
492
 
493
- # for block in blocks:
494
- # lines = [l.strip() for l in block.splitlines() if l.strip()]
495
- # if not lines:
496
- # continue
 
 
 
497
 
498
- # # Locate the timestamp line (must contain "-->")
499
- # time_idx: Optional[int] = None
500
- # for i, line in enumerate(lines):
501
- # if "-->" in line:
502
- # time_idx = i
503
- # break
504
- # if time_idx is None:
505
- # continue
506
 
507
- # m = _VTT_TIME_RE.search(lines[time_idx])
508
- # if not m:
509
- # continue
510
 
511
- # start = _time_to_seconds(m.group("start"))
512
- # end = _time_to_seconds(m.group("end"))
513
- # if end <= start:
514
- # continue
515
 
516
- # # Only take lines after the timestamp line as subtitle text
517
- # text_lines = lines[time_idx + 1 :]
518
- # if not text_lines:
519
- # continue
520
 
521
- # text = _strip_tags("\n".join(text_lines))
522
- # if text:
523
- # cues.append(Cue(start=start, end=end, text=text))
 
524
 
525
- # return sorted(cues, key=lambda x: x.start)
 
 
 
 
 
 
 
526
 
 
 
 
527
 
528
- # # =========================================================
529
- # # Alignment (match by mid time), preserve per-track windows
530
- # # =========================================================
531
- # def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
532
- # out: List[Dict[str, Any]] = []
533
- # i, j, idx = 0, 0, 1
534
 
535
- # while i < len(a) and j < len(b):
536
- # ma = (a[i].start + a[i].end) / 2
537
- # mb = (b[j].start + b[j].end) / 2
 
538
 
539
- # if abs(ma - mb) <= th:
540
- # out.append(
541
- # {
542
- # "idx": idx,
543
- # # Per-track time window (recommended for playback)
544
- # "a_start": a[i].start,
545
- # "a_end": a[i].end,
546
- # "b_start": b[j].start,
547
- # "b_end": b[j].end,
548
- # # Optional global time window (for comparison/debug)
549
- # "start": min(a[i].start, b[j].start),
550
- # "end": max(a[i].end, b[j].end),
551
- # "a_text": a[i].text,
552
- # "b_text": b[j].text,
553
- # }
554
- # )
555
- # idx += 1
556
- # i += 1
557
- # j += 1
558
- # elif ma < mb:
559
- # i += 1
560
- # else:
561
- # j += 1
562
 
563
- # return out
564
 
565
 
566
- # # =========================================================
567
- # # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
568
- # # =========================================================
569
- # def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
570
- # """
571
- # Robust segment export for gr.Audio(type="numpy").
572
 
573
- # Key points:
574
- # - Clamp start/end (after any offsets) to valid range.
575
- # - Use *rounded* ms boundaries to avoid systematic truncation drift.
576
- # - Slice via pydub (ms-accurate) using the original stream timeline.
577
- # - Normalize to mono/48k/int16 for stable browser playback.
578
- # - Return (sr, int16 ndarray) to avoid float32 scaling pitfalls.
579
- # """
580
- # # Clamp and ensure minimum duration
581
- # start = float(start)
582
- # end = float(end)
583
- # if end < start:
584
- # start, end = end, start
585
- # start = max(0.0, start)
586
- # end = max(start + 0.05, end)
587
 
588
- # # Round to milliseconds (avoid int() truncation bias)
589
- # start_ms = int(round(start * 1000.0))
590
- # end_ms = int(round(end * 1000.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
- # seg = audio[start_ms:end_ms]
593
 
594
- # # Normalize to mono/48k/int16
595
- # seg = seg.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)
596
 
597
- # arr = np.asarray(seg.get_array_of_samples())
598
- # if arr.dtype != np.int16:
599
- # arr = arr.astype(np.int16, copy=False)
 
 
 
600
 
601
- # return TARGET_SR, np.ascontiguousarray(arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
602
 
 
 
 
603
 
604
- # # =========================================================
605
- # # Helper: robustly read seg_idx from gr.Dataframe value
606
- # # =========================================================
607
- # def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
608
- # if df_value is None:
609
- # return None
610
 
611
- # # pandas DataFrame in some Gradio versions
612
- # try:
613
- # import pandas as pd # type: ignore
614
- # if isinstance(df_value, pd.DataFrame):
615
- # if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
616
- # return None
617
- # return int(df_value.iloc[row, 0])
618
- # except Exception:
619
- # pass
620
 
621
- # # list-of-lists
622
- # try:
623
- # if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
624
- # return int(df_value[row][0])
625
- # except Exception:
626
- # return None
627
 
628
- # return None
629
 
630
 
631
- # # =========================================================
632
- # # Gradio callbacks
633
- # # =========================================================
634
- # def scan_dataset(repo_id: str, repo_type: str):
635
- # if not repo_id:
636
- # raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
637
 
638
- # files = list_repo_files(repo_id, repo_type=repo_type)
639
- # media_files = sorted([f for f in files if f.lower().endswith(MEDIA_EXTS)])
640
- # vtt_files = sorted([f for f in files if f.lower().endswith(VTT_EXTS)])
 
 
 
 
 
 
641
 
642
- # if not media_files:
643
- # raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
644
- # if not vtt_files:
645
- # raise gr.Error("未找到 VTT 字幕文件。")
 
 
646
 
647
- # return (
648
- # gr.update(choices=media_files, value=media_files[0]),
649
- # gr.update(choices=media_files, value=media_files[0]),
650
- # gr.update(choices=vtt_files, value=vtt_files[0]),
651
- # gr.update(choices=vtt_files, value=vtt_files[0]),
652
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
 
655
- # def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
656
- # if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
657
- # raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")
658
 
659
- # local_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
660
- # local_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
661
- # local_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
662
- # local_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)
663
 
664
- # try:
665
- # audio_a = AudioSegment.from_file(local_media_a)
666
- # audio_b = AudioSegment.from_file(local_media_b)
667
- # except Exception as e:
668
- # raise gr.Error(
669
- # "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
670
- # f"原始错误: {repr(e)}"
671
- # )
672
 
673
- # cues_a = parse_vtt_file(local_vtt_a)
674
- # cues_b = parse_vtt_file(local_vtt_b)
675
- # if not cues_a or not cues_b:
676
- # raise gr.Error("VTT 解析为空,请检查字幕文件内容。")
677
 
678
- # aligned = align_by_time(cues_a, cues_b, float(th))
679
- # if not aligned:
680
- # raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")
681
 
682
- # rows = [
683
- # [
684
- # x["idx"],
685
- # f'{x["a_start"]:.2f}-{x["a_end"]:.2f}',
686
- # f'{x["b_start"]:.2f}-{x["b_end"]:.2f}',
687
- # x["a_text"],
688
- # x["b_text"],
689
- # ]
690
- # for x in aligned
691
- # ]
692
 
693
- # # Critical: build idx -> seg map to survive dataframe sorting/reordering
694
- # idx_map = {int(x["idx"]): x for x in aligned}
695
 
696
- # state = {
697
- # "aligned": aligned,
698
- # "idx_map": idx_map,
699
- # "audio_a": audio_a,
700
- # "audio_b": audio_b,
701
- # }
702
 
703
- # # Clear old playback outputs
704
- # return rows, state, None, None, {}
705
 
706
 
707
- # def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, state):
708
- # if not state or "aligned" not in state:
709
- # raise gr.Error("请先加载并对齐。")
710
 
711
- # # evt.index: int or (row, col)
712
- # idx_raw = evt.index
713
- # row = int(idx_raw[0] if isinstance(idx_raw, (tuple, list)) else idx_raw)
714
 
715
- # offset_a = float(offset_a)
716
- # offset_b = float(offset_b)
717
 
718
- # # Prefer seg_idx from the clicked row's first column; then resolve via idx_map.
719
- # seg_idx = _get_seg_idx_from_df(df_value, row)
720
- # seg = None
721
- # idx_map = state.get("idx_map", {}) or {}
722
- # if seg_idx is not None and seg_idx in idx_map:
723
- # seg = idx_map[seg_idx]
724
- # else:
725
- # # Fallback to row->aligned if idx missing (should be rare)
726
- # aligned = state["aligned"]
727
- # if row < 0 or row >= len(aligned):
728
- # raise gr.Error("选中行越界,请重试或重新对齐。")
729
- # seg = aligned[row]
730
- # seg_idx = int(seg.get("idx", row + 1))
731
 
732
- # if crop_mode == "global":
733
- # a_start, a_end = seg["start"] + offset_a, seg["end"] + offset_a
734
- # b_start, b_end = seg["start"] + offset_b, seg["end"] + offset_b
735
- # else:
736
- # # per_track playback (recommended)
737
- # a_start, a_end = seg["a_start"] + offset_a, seg["a_end"] + offset_a
738
- # b_start, b_end = seg["b_start"] + offset_b, seg["b_end"] + offset_b
739
 
740
- # a_np = export_segment_numpy(state["audio_a"], a_start, a_end)
741
- # b_np = export_segment_numpy(state["audio_b"], b_start, b_end)
742
 
743
- # info = {
744
- # "segment": seg_idx,
745
- # "row": row,
746
- # "crop_mode": crop_mode,
747
- # "A_time": f"{a_start:.2f}-{a_end:.2f}",
748
- # "B_time": f"{b_start:.2f}-{b_end:.2f}",
749
- # }
750
- # return a_np, b_np, info
751
 
752
 
753
- # # =========================================================
754
- # # UI
755
- # # =========================================================
756
- # with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
757
- # gr.Markdown(
758
- # "# 双语音频字幕对齐(点击表格即播放)\n"
759
- # "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
760
- # "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
761
- # )
762
 
763
- # state = gr.State()
764
 
765
- # with gr.Row():
766
- # repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
767
- # repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
768
 
769
- # btn_scan = gr.Button("扫描 Dataset", variant="primary")
770
 
771
- # with gr.Row():
772
- # media_a = gr.Dropdown(label="Track A 媒体")
773
- # media_b = gr.Dropdown(label="Track B 媒体")
774
 
775
- # with gr.Row():
776
- # vtt_a = gr.Dropdown(label="Track A 字幕")
777
- # vtt_b = gr.Dropdown(label="Track B 字幕")
778
 
779
- # btn_scan.click(
780
- # scan_dataset,
781
- # inputs=[repo_id, repo_type],
782
- # outputs=[media_a, media_b, vtt_a, vtt_b],
783
- # )
784
 
785
- # th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
786
- # btn_align = gr.Button("加载并对齐", variant="primary")
787
 
788
- # df = gr.Dataframe(
789
- # headers=["#", "A Time", "B Time", "Track A", "Track B"],
790
- # interactive=True, # can be sorted/edited; mapping is stable due to idx_map
791
- # wrap=True,
792
- # max_height=520,
793
- # )
794
 
795
- # with gr.Row():
796
- # crop_mode = gr.Radio(
797
- # choices=["per_track", "global"],
798
- # value="per_track",
799
- # label="裁剪方式(建议 per_track)",
800
- # )
801
- # offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
802
- # offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
803
 
804
- # with gr.Row():
805
- # a_out = gr.Audio(label="Track A 片段", type="numpy")
806
- # b_out = gr.Audio(label="Track B 片段", type="numpy")
807
 
808
- # play_info = gr.JSON(label="当前片段")
809
 
810
- # btn_align.click(
811
- # load_and_align,
812
- # inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
813
- # outputs=[df, state, a_out, b_out, play_info],
814
- # )
815
 
816
- # df.select(
817
- # play_on_select,
818
- # inputs=[df, crop_mode, offset_a, offset_b, state],
819
- # outputs=[a_out, b_out, play_info],
820
- # )
821
 
822
- # if __name__ == "__main__":
823
- # demo.launch()
 
1
+ # import re
2
+ # from dataclasses import dataclass
3
+ # from typing import Any, Dict, List, Tuple, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # import gradio as gr
6
+ # from huggingface_hub import list_repo_files, hf_hub_download
7
+ # from pydub import AudioSegment
8
+ # import numpy as np
9
 
10
+ # # =========================================================
11
+ # # Config
12
+ # # =========================================================
13
+ # MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
14
+ # VTT_EXTS = (".vtt",)
 
 
 
15
 
16
+ # DEFAULT_MAX_MID_DIFF = 1.5
17
 
18
+ # # Normalize audio for stable playback in browsers
19
+ # TARGET_SR = 48000
20
+ # TARGET_CH = 1 # mono
21
+ # TARGET_SW = 2 # 16-bit PCM
 
 
 
 
22
 
23
 
24
+ # # =========================================================
25
+ # # Data structures
26
+ # # =========================================================
27
+ # @dataclass
28
+ # class Cue:
29
+ # start: float
30
+ # end: float
31
+ # text: str
32
 
33
 
34
+ # # =========================================================
35
+ # # VTT parsing
36
+ # # =========================================================
37
+ # _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
38
+ # _VTT_TIME_RE = re.compile(
39
+ # r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
40
+ # r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
41
+ # )
42
 
43
 
44
+ # def _strip_tags(text: str) -> str:
45
+ # return _TAG_RE.sub("", text).strip()
 
46
 
 
 
 
47
 
48
+ # def _time_to_seconds(t: str) -> float:
49
+ # parts = t.split(":")
50
+ # if len(parts) == 3:
51
+ # return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
52
+ # if len(parts) == 2:
53
+ # return int(parts[0]) * 60 + float(parts[1])
54
+ # raise ValueError(f"Bad VTT timestamp: {t}")
55
 
 
 
 
 
56
 
57
+ # def parse_vtt_file(path: str) -> List[Cue]:
58
+ # with open(path, "r", encoding="utf-8") as f:
59
+ # content = f.read()
 
 
 
 
 
60
 
61
+ # # Remove BOM / WEBVTT header (if any)
62
+ # content = content.replace("\ufeff", "")
63
+ # content = re.sub(r"^\s*WEBVTT.*?\n", "", content, flags=re.IGNORECASE)
64
 
65
+ # blocks = re.split(r"\r?\n\r?\n", content.strip())
66
+ # cues: List[Cue] = []
 
 
67
 
68
+ # for block in blocks:
69
+ # lines = [l.strip() for l in block.splitlines() if l.strip()]
70
+ # if not lines:
71
+ # continue
72
 
73
+ # # Locate the timestamp line (must contain "-->")
74
+ # time_idx: Optional[int] = None
75
+ # for i, line in enumerate(lines):
76
+ # if "-->" in line:
77
+ # time_idx = i
78
+ # break
79
+ # if time_idx is None:
80
+ # continue
81
 
82
+ # m = _VTT_TIME_RE.search(lines[time_idx])
83
+ # if not m:
84
+ # continue
85
 
86
+ # start = _time_to_seconds(m.group("start"))
87
+ # end = _time_to_seconds(m.group("end"))
88
+ # if end <= start:
89
+ # continue
90
 
91
+ # # Only take lines after the timestamp line as subtitle text
92
+ # text_lines = lines[time_idx + 1 :]
93
+ # if not text_lines:
94
+ # continue
 
 
95
 
96
+ # text = _strip_tags("\n".join(text_lines))
97
+ # if text:
98
+ # cues.append(Cue(start=start, end=end, text=text))
99
 
100
+ # return sorted(cues, key=lambda x: x.start)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
102
 
103
+ # # =========================================================
104
+ # # Alignment (match by mid time), preserve per-track windows
105
+ # # =========================================================
106
+ # def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
107
+ # out: List[Dict[str, Any]] = []
108
+ # i, j, idx = 0, 0, 1
109
 
110
+ # while i < len(a) and j < len(b):
111
+ # ma = (a[i].start + a[i].end) / 2
112
+ # mb = (b[j].start + b[j].end) / 2
 
 
 
113
 
114
+ # if abs(ma - mb) <= th:
115
+ # out.append(
116
+ # {
117
+ # "idx": idx,
118
+ # # Per-track time window (recommended for playback)
119
+ # "a_start": a[i].start,
120
+ # "a_end": a[i].end,
121
+ # "b_start": b[j].start,
122
+ # "b_end": b[j].end,
123
+ # # Optional global time window (for comparison/debug)
124
+ # "start": min(a[i].start, b[j].start),
125
+ # "end": max(a[i].end, b[j].end),
126
+ # "a_text": a[i].text,
127
+ # "b_text": b[j].text,
128
+ # }
129
+ # )
130
+ # idx += 1
131
+ # i += 1
132
+ # j += 1
133
+ # elif ma < mb:
134
+ # i += 1
135
+ # else:
136
+ # j += 1
137
 
138
+ # return out
 
 
139
 
 
140
 
141
+ # # =========================================================
142
+ # # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
143
+ # # =========================================================
144
+ # def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
145
+ # """
146
+ # Robust segment export for gr.Audio(type="numpy").
147
 
148
+ # Key points:
149
+ # - Clamp start/end (after any offsets) to valid range.
150
+ # - Use *rounded* ms boundaries to avoid systematic truncation drift.
151
+ # - Slice via pydub (ms-accurate) using the original stream timeline.
152
+ # - Normalize to mono/48k/int16 for stable browser playback.
153
+ # - Return (sr, int16 ndarray) to avoid float32 scaling pitfalls.
154
+ # """
155
+ # # Clamp and ensure minimum duration
156
+ # start = float(start)
157
+ # end = float(end)
158
+ # if end < start:
159
+ # start, end = end, start
160
+ # start = max(0.0, start)
161
+ # end = max(start + 0.05, end)
162
 
163
+ # # Round to milliseconds (avoid int() truncation bias)
164
+ # start_ms = int(round(start * 1000.0))
165
+ # end_ms = int(round(end * 1000.0))
166
 
167
+ # seg = audio[start_ms:end_ms]
 
 
 
 
 
168
 
169
+ # # Normalize to mono/48k/int16
170
+ # seg = seg.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)
 
 
 
 
 
 
 
171
 
172
+ # arr = np.asarray(seg.get_array_of_samples())
173
+ # if arr.dtype != np.int16:
174
+ # arr = arr.astype(np.int16, copy=False)
 
 
 
175
 
176
+ # return TARGET_SR, np.ascontiguousarray(arr)
177
 
178
 
179
+ # # =========================================================
180
+ # # Helper: robustly read seg_idx from gr.Dataframe value
181
+ # # =========================================================
182
+ # def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
183
+ # if df_value is None:
184
+ # return None
185
 
186
+ # # pandas DataFrame in some Gradio versions
187
+ # try:
188
+ # import pandas as pd # type: ignore
189
+ # if isinstance(df_value, pd.DataFrame):
190
+ # if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
191
+ # return None
192
+ # return int(df_value.iloc[row, 0])
193
+ # except Exception:
194
+ # pass
195
 
196
+ # # list-of-lists
197
+ # try:
198
+ # if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
199
+ # return int(df_value[row][0])
200
+ # except Exception:
201
+ # return None
202
 
203
+ # return None
 
 
 
 
 
204
 
205
 
206
+ # # =========================================================
207
+ # # Gradio callbacks
208
+ # # =========================================================
209
+ # def scan_dataset(repo_id: str, repo_type: str):
210
+ # if not repo_id:
211
+ # raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
212
 
213
+ # files = list_repo_files(repo_id, repo_type=repo_type)
214
+ # media_files = sorted([f for f in files if f.lower().endswith(MEDIA_EXTS)])
215
+ # vtt_files = sorted([f for f in files if f.lower().endswith(VTT_EXTS)])
 
216
 
217
+ # if not media_files:
218
+ # raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
219
+ # if not vtt_files:
220
+ # raise gr.Error("未找到 VTT 字幕文件。")
 
 
 
 
221
 
222
+ # return (
223
+ # gr.update(choices=media_files, value=media_files[0]),
224
+ # gr.update(choices=media_files, value=media_files[0]),
225
+ # gr.update(choices=vtt_files, value=vtt_files[0]),
226
+ # gr.update(choices=vtt_files, value=vtt_files[0]),
227
+ # )
228
 
 
 
 
 
 
 
 
229
 
230
+ # def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
231
+ # if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
232
+ # raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")
233
 
234
+ # local_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
235
+ # local_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
236
+ # local_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
237
+ # local_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)
238
 
239
+ # try:
240
+ # audio_a = AudioSegment.from_file(local_media_a)
241
+ # audio_b = AudioSegment.from_file(local_media_b)
242
+ # except Exception as e:
243
+ # raise gr.Error(
244
+ # "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
245
+ # f"原始错误: {repr(e)}"
246
+ # )
 
 
247
 
248
+ # cues_a = parse_vtt_file(local_vtt_a)
249
+ # cues_b = parse_vtt_file(local_vtt_b)
250
+ # if not cues_a or not cues_b:
251
+ # raise gr.Error("VTT 解析为空,请检查字幕文件内容。")
252
 
253
+ # # ---- Drift fix: estimate time-scale (linear) between VTT timeline and audio timeline ----
254
+ # # If you observe increasing offset over time, it is usually a *scale* mismatch rather than a constant offset.
255
+ # # We estimate per-track scale by comparing audio duration to the last cue end time.
256
+ # a_vtt_end = max(c.end for c in cues_a) if cues_a else 0.0
257
+ # b_vtt_end = max(c.end for c in cues_b) if cues_b else 0.0
258
+ # a_dur = float(audio_a.duration_seconds)
259
+ # b_dur = float(audio_b.duration_seconds)
 
 
 
 
 
260
 
261
+ # # Default scale = 1.0 when we cannot estimate reliably.
262
+ # scale_a_suggest = (a_dur / a_vtt_end) if a_vtt_end > 1.0 and a_dur > 1.0 else 1.0
263
+ # scale_b_suggest = (b_dur / b_vtt_end) if b_vtt_end > 1.0 and b_dur > 1.0 else 1.0
264
 
265
+ # aligned = align_by_time(cues_a, cues_b, float(th))
266
+ # if not aligned:
267
+ # raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")
268
 
269
+ # rows = [
270
+ # [
271
+ # x["idx"],
272
+ # f'{x["a_start"]:.2f}-{x["a_end"]:.2f}',
273
+ # f'{x["b_start"]:.2f}-{x["b_end"]:.2f}',
274
+ # x["a_text"],
275
+ # x["b_text"],
276
+ # ]
277
+ # for x in aligned
278
+ # ]
279
 
280
+ # # Critical: build idx -> seg map to survive dataframe sorting/reordering
281
+ # idx_map = {int(x["idx"]): x for x in aligned}
 
282
 
283
+ # state = {
284
+ # "aligned": aligned,
285
+ # "idx_map": idx_map,
286
+ # "audio_a": audio_a,
287
+ # "audio_b": audio_b,
288
+ # "scale_a_suggest": scale_a_suggest,
289
+ # "scale_b_suggest": scale_b_suggest,
290
+ # "a_vtt_end": a_vtt_end,
291
+ # "b_vtt_end": b_vtt_end,
292
+ # "a_dur": a_dur,
293
+ # "b_dur": b_dur,
294
+ # }
295
 
296
+ # # Clear old playback outputs
297
+ # return rows, state, None, None, {}, gr.update(value=scale_a_suggest), gr.update(value=scale_b_suggest)
 
 
 
 
 
 
 
 
 
 
 
298
 
 
 
 
 
 
 
 
299
 
300
+ # def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, scale_a, scale_b, state):
301
+ # if not state or "aligned" not in state:
302
+ # raise gr.Error("请先加载并对齐。")
303
 
304
+ # # evt.index: int or (row, col)
305
+ # idx_raw = evt.index
306
+ # row = int(idx_raw[0] if isinstance(idx_raw, (tuple, list)) else idx_raw)
 
 
 
 
 
 
 
 
 
307
 
308
+ # offset_a = float(offset_a)
309
+ # offset_b = float(offset_b)
310
+ # scale_a = float(scale_a)
311
+ # scale_b = float(scale_b)
312
 
313
+ # # Prefer seg_idx from the clicked row's first column; then resolve via idx_map.
314
+ # seg_idx = _get_seg_idx_from_df(df_value, row)
315
+ # seg = None
316
+ # idx_map = state.get("idx_map", {}) or {}
317
+ # if seg_idx is not None and seg_idx in idx_map:
318
+ # seg = idx_map[seg_idx]
319
+ # else:
320
+ # # Fallback to row->aligned if idx missing (should be rare)
321
+ # aligned = state["aligned"]
322
+ # if row < 0 or row >= len(aligned):
323
+ # raise gr.Error("选中行越界,请重试或重新对齐。")
324
+ # seg = aligned[row]
325
+ # seg_idx = int(seg.get("idx", row + 1))
326
 
327
+ # if crop_mode == "global":
328
+ # a_start, a_end = seg["start"] * scale_a + offset_a, seg["end"] * scale_a + offset_a
329
+ # b_start, b_end = seg["start"] * scale_b + offset_b, seg["end"] * scale_b + offset_b
330
+ # else:
331
+ # # per_track playback (recommended)
332
+ # a_start, a_end = seg["a_start"] * scale_a + offset_a, seg["a_end"] * scale_a + offset_a
333
+ # b_start, b_end = seg["b_start"] * scale_b + offset_b, seg["b_end"] * scale_b + offset_b
334
 
335
+ # a_np = export_segment_numpy(state["audio_a"], a_start, a_end)
336
+ # b_np = export_segment_numpy(state["audio_b"], b_start, b_end)
 
337
 
338
+ # info = {
339
+ # "segment": seg_idx,
340
+ # "row": row,
341
+ # "crop_mode": crop_mode,
342
+ # "A_time": f"{a_start:.2f}-{a_end:.2f}",
343
+ # "B_time": f"{b_start:.2f}-{b_end:.2f}",
344
+ # "scale_a": scale_a,
345
+ # "scale_b": scale_b,
346
+ # "scale_a_suggest": state.get("scale_a_suggest", 1.0),
347
+ # "scale_b_suggest": state.get("scale_b_suggest", 1.0),
348
+ # }
349
+ # return a_np, b_np, info
350
 
 
 
 
351
 
352
+ # # =========================================================
353
+ # # UI
354
+ # # =========================================================
355
+ # with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
356
+ # gr.Markdown(
357
+ # "# 双语音频字幕对齐(点击表格即播放)\n"
358
+ # "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
359
+ # "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
360
+ # )
361
 
362
+ # state = gr.State()
 
 
 
 
363
 
364
+ # with gr.Row():
365
+ # repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
366
+ # repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
367
 
368
+ # btn_scan = gr.Button("扫描 Dataset", variant="primary")
 
 
 
 
 
369
 
370
+ # with gr.Row():
371
+ # media_a = gr.Dropdown(label="Track A 媒体")
372
+ # media_b = gr.Dropdown(label="Track B 媒体")
 
 
 
 
 
 
 
373
 
374
+ # with gr.Row():
375
+ # vtt_a = gr.Dropdown(label="Track A 字幕")
376
+ # vtt_b = gr.Dropdown(label="Track B 字幕")
377
 
378
+ # btn_scan.click(
379
+ # scan_dataset,
380
+ # inputs=[repo_id, repo_type],
381
+ # outputs=[media_a, media_b, vtt_a, vtt_b],
382
+ # )
383
 
384
+ # th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
385
+ # btn_align = gr.Button("加载并对齐", variant="primary")
 
 
 
386
 
387
+ # df = gr.Dataframe(
388
+ # headers=["#", "A Time", "B Time", "Track A", "Track B"],
389
+ # interactive=True, # can be sorted/edited; mapping is stable due to idx_map
390
+ # wrap=True,
391
+ # max_height=520,
392
+ # )
393
 
394
+ # with gr.Row():
395
+ # crop_mode = gr.Radio(
396
+ # choices=["per_track", "global"],
397
+ # value="per_track",
398
+ # label="裁剪方式(建议 per_track)",
399
+ # )
400
+ # offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
401
+ # offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
402
+ # scale_a = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track A 时间缩放(scale)")
403
+ # scale_b = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track B 时间缩放(scale)")
404
 
405
+ # with gr.Row():
406
+ # a_out = gr.Audio(label="Track A 片段", type="numpy")
407
+ # b_out = gr.Audio(label="Track B 片段", type="numpy")
408
 
409
+ # play_info = gr.JSON(label="当前片段")
 
 
 
410
 
411
+ # btn_align.click(
412
+ # load_and_align,
413
+ # inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
414
+ # outputs=[df, state, a_out, b_out, play_info, scale_a, scale_b],
415
+ # )
416
 
417
+ # df.select(
418
+ # play_on_select,
419
+ # inputs=[df, crop_mode, offset_a, offset_b, scale_a, scale_b, state],
420
+ # outputs=[a_out, b_out, play_info],
421
+ # )
422
 
423
+ # if __name__ == "__main__":
424
+ # demo.launch()
 
 
425
 
426
+ import re
427
+ from dataclasses import dataclass
428
+ from typing import Any, Dict, List, Tuple, Optional
429
 
430
+ import gradio as gr
431
+ from huggingface_hub import list_repo_files, hf_hub_download
432
+ from pydub import AudioSegment
433
+ import numpy as np
 
 
 
 
434
 
435
# =========================================================
# Config
# =========================================================
# File extensions recognized when scanning the Hugging Face repo file list.
MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
VTT_EXTS = (".vtt",)

# Default max |mid(A) - mid(B)| in seconds for two cues to count as aligned.
DEFAULT_MAX_MID_DIFF = 1.5

# Normalize audio for stable playback in browsers
TARGET_SR = 48000  # sample rate in Hz
TARGET_CH = 1  # mono
TARGET_SW = 2  # 16-bit PCM
+ TARGET_SW = 2 # 16-bit PCM
447
 
 
 
448
 
449
+ # =========================================================
450
+ # Data structures
451
+ # =========================================================
452
@dataclass
class Cue:
    """One subtitle cue: an absolute time window (seconds) plus its text."""

    # start/end are positions on the subtitle timeline, in seconds
    start: float
    end: float
    text: str
457
 
 
 
 
 
 
 
 
458
 
459
+ # =========================================================
460
+ # VTT parsing
461
+ # =========================================================
462
# Matches any inline markup tag (e.g. <b>, </i>, <c.color>) so it can be stripped.
_TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
# Matches a cue timing line: "HH:MM:SS.mmm --> HH:MM:SS.mmm" or "M:SS.mmm --> M:SS.mmm".
_VTT_TIME_RE = re.compile(
    r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
    r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
)
467
 
 
 
 
468
 
469
def _strip_tags(text: str) -> str:
    """Remove all markup tags from *text* and trim surrounding whitespace."""
    without_tags = _TAG_RE.sub("", text)
    return without_tags.strip()
 
471
 
 
 
472
 
473
+ def _time_to_seconds(t: str) -> float:
474
+ parts = t.split(":")
475
+ if len(parts) == 3:
476
+ return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
477
+ if len(parts) == 2:
478
+ return int(parts[0]) * 60 + float(parts[1])
479
+ raise ValueError(f"Bad VTT timestamp: {t}")
480
 
 
 
 
 
 
 
 
 
481
 
482
def parse_vtt_file(path: str) -> List[Cue]:
    """Parse a WebVTT file into Cue objects, sorted by start time.

    Cue identifiers, NOTE/STYLE blocks, zero/negative-length cues, and
    cues with no text (after tag stripping) are all skipped.
    """
    with open(path, "r", encoding="utf-8") as fh:
        raw = fh.read()

    # Drop BOM and the leading "WEBVTT" header line, if present.
    raw = raw.replace("\ufeff", "")
    raw = re.sub(r"^\s*WEBVTT.*?\n", "", raw, flags=re.IGNORECASE)

    cues: List[Cue] = []

    for block in re.split(r"\r?\n\r?\n", raw.strip()):
        lines = [ln.strip() for ln in block.splitlines() if ln.strip()]

        # Find the timing line; blocks without one are not cues.
        time_idx = next((i for i, ln in enumerate(lines) if "-->" in ln), None)
        if time_idx is None:
            continue

        match = _VTT_TIME_RE.search(lines[time_idx])
        if match is None:
            continue

        start = _time_to_seconds(match.group("start"))
        end = _time_to_seconds(match.group("end"))
        if end <= start:
            continue

        # Everything after the timing line is the subtitle text.
        payload = lines[time_idx + 1 :]
        if not payload:
            continue

        text = _strip_tags("\n".join(payload))
        if text:
            cues.append(Cue(start=start, end=end, text=text))

    return sorted(cues, key=lambda c: c.start)
526
 
527
 
528
+ # =========================================================
529
+ # Alignment (match by mid time), preserve per-track windows
530
+ # =========================================================
531
def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
    """Pair cues from tracks A and B by midpoint time (two-pointer merge).

    Two cues match when their midpoints differ by at most *th* seconds;
    otherwise the track whose current cue is earlier advances. Each match
    records both per-track windows plus a merged global window.
    """
    pairs: List[Dict[str, Any]] = []
    ia = ib = 0
    next_idx = 1

    while ia < len(a) and ib < len(b):
        cue_a, cue_b = a[ia], b[ib]
        mid_a = (cue_a.start + cue_a.end) / 2
        mid_b = (cue_b.start + cue_b.end) / 2

        if abs(mid_a - mid_b) <= th:
            pairs.append(
                {
                    "idx": next_idx,
                    # Per-track time window (recommended for playback)
                    "a_start": cue_a.start,
                    "a_end": cue_a.end,
                    "b_start": cue_b.start,
                    "b_end": cue_b.end,
                    # Global window spanning both cues (for comparison/debug)
                    "start": min(cue_a.start, cue_b.start),
                    "end": max(cue_a.end, cue_b.end),
                    "a_text": cue_a.text,
                    "b_text": cue_b.text,
                }
            )
            next_idx += 1
            ia += 1
            ib += 1
        elif mid_a < mid_b:
            ia += 1
        else:
            ib += 1

    return pairs
564
 
 
 
565
 
566
+ # =========================================================
567
+ # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
568
+ # =========================================================
569
def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
    """Cut [start, end] seconds out of *audio* for gr.Audio(type="numpy").

    Returns (sample_rate, int16 samples). The window is sanitized (ordered,
    clamped to >= 0, min 50 ms long), sliced on *rounded* millisecond
    boundaries to avoid truncation drift, and normalized to mono / 48 kHz /
    16-bit PCM for stable browser playback.
    """
    # Order the endpoints, clamp, and enforce a minimum duration.
    lo = float(start)
    hi = float(end)
    if hi < lo:
        lo, hi = hi, lo
    lo = max(0.0, lo)
    hi = max(lo + 0.05, hi)

    # Round (not truncate) to milliseconds before pydub slicing.
    clip = audio[int(round(lo * 1000.0)) : int(round(hi * 1000.0))]

    # Normalize to the fixed playback format (mono/48k/int16).
    clip = clip.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)

    samples = np.asarray(clip.get_array_of_samples())
    if samples.dtype != np.int16:
        samples = samples.astype(np.int16, copy=False)

    return TARGET_SR, np.ascontiguousarray(samples)
602
 
603
 
604
+ # =========================================================
605
+ # Helper: robustly read seg_idx from gr.Dataframe value
606
+ # =========================================================
607
+ def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
608
+ if df_value is None:
609
+ return None
610
 
611
+ # pandas DataFrame in some Gradio versions
612
+ try:
613
+ import pandas as pd # type: ignore
614
+ if isinstance(df_value, pd.DataFrame):
615
+ if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
616
+ return None
617
+ return int(df_value.iloc[row, 0])
618
+ except Exception:
619
+ pass
620
 
621
+ # list-of-lists
622
+ try:
623
+ if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
624
+ return int(df_value[row][0])
625
+ except Exception:
626
+ return None
627
 
628
+ return None
629
+
630
+
631
+ # =========================================================
632
+ # Gradio callbacks
633
+ # =========================================================
634
def scan_dataset(repo_id: str, repo_type: str):
    """List media and VTT files in a HF repo and seed the four dropdowns.

    Returns gr.update payloads for (media_a, media_b, vtt_a, vtt_b), each
    preselecting the first file found. Raises gr.Error when the repo id is
    missing or no matching files exist.
    """
    if not repo_id:
        raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")

    all_files = list_repo_files(repo_id, repo_type=repo_type)
    media = sorted(f for f in all_files if f.lower().endswith(MEDIA_EXTS))
    subtitles = sorted(f for f in all_files if f.lower().endswith(VTT_EXTS))

    if not media:
        raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
    if not subtitles:
        raise gr.Error("未找到 VTT 字幕文件。")

    return (
        gr.update(choices=media, value=media[0]),
        gr.update(choices=media, value=media[0]),
        gr.update(choices=subtitles, value=subtitles[0]),
        gr.update(choices=subtitles, value=subtitles[0]),
    )
653
 
654
 
655
def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
    """Download both tracks, parse their VTTs, align cues, and build the table.

    Returns (table rows, state dict, None, None, {}) matching the
    btn_align.click outputs; the trailing values clear stale playback
    widgets. Raises gr.Error on missing selections, decode failures,
    empty subtitles, or zero aligned segments.
    """
    if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
        raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")

    # Fetch media + subtitles from the hub (hf_hub_download caches locally).
    path_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
    path_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
    path_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
    path_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)

    try:
        track_a = AudioSegment.from_file(path_media_a)
        track_b = AudioSegment.from_file(path_media_b)
    except Exception as e:
        raise gr.Error(
            "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
            f"原始错误: {repr(e)}"
        )

    cues_a = parse_vtt_file(path_vtt_a)
    cues_b = parse_vtt_file(path_vtt_b)
    if not cues_a or not cues_b:
        raise gr.Error("VTT 解析为空,请检查字幕文件内容。")

    aligned = align_by_time(cues_a, cues_b, float(th))
    if not aligned:
        raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")

    rows = []
    for seg in aligned:
        rows.append(
            [
                seg["idx"],
                f'{seg["a_start"]:.2f}-{seg["a_end"]:.2f}',
                f'{seg["b_start"]:.2f}-{seg["b_end"]:.2f}',
                seg["a_text"],
                seg["b_text"],
            ]
        )

    # Critical: idx -> segment map keeps clicks correct even after the
    # user sorts or edits the dataframe.
    idx_map = {int(seg["idx"]): seg for seg in aligned}

    state = {
        "aligned": aligned,
        "idx_map": idx_map,
        "audio_a": track_a,
        "audio_b": track_b,
    }

    # Clear old playback outputs
    return rows, state, None, None, {}
705
 
706
 
707
def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, state):
    """Dataframe select handler: slice and return both tracks' audio for a row.

    Returns (A segment, B segment, info dict) where segments are
    (sample_rate, int16 ndarray) pairs for gr.Audio(type="numpy").
    """
    if not state or "aligned" not in state:
        raise gr.Error("请先加载并对齐。")

    # evt.index is either a plain row int or a (row, col) pair.
    raw_index = evt.index
    if isinstance(raw_index, (tuple, list)):
        row = int(raw_index[0])
    else:
        row = int(raw_index)

    offset_a = float(offset_a)
    offset_b = float(offset_b)

    # Resolve the clicked row to a segment: prefer the "#" column value via
    # idx_map (stable under table sorting), fall back to positional lookup.
    idx_map = state.get("idx_map", {}) or {}
    seg_idx = _get_seg_idx_from_df(df_value, row)
    if seg_idx is not None and seg_idx in idx_map:
        seg = idx_map[seg_idx]
    else:
        aligned = state["aligned"]
        if row < 0 or row >= len(aligned):
            raise gr.Error("选中行越界,请重试或重新对齐。")
        seg = aligned[row]
        seg_idx = int(seg.get("idx", row + 1))

    if crop_mode == "global":
        a_start = seg["start"] + offset_a
        a_end = seg["end"] + offset_a
        b_start = seg["start"] + offset_b
        b_end = seg["end"] + offset_b
    else:
        # per_track playback (recommended)
        a_start = seg["a_start"] + offset_a
        a_end = seg["a_end"] + offset_a
        b_start = seg["b_start"] + offset_b
        b_end = seg["b_end"] + offset_b

    segment_a = export_segment_numpy(state["audio_a"], a_start, a_end)
    segment_b = export_segment_numpy(state["audio_b"], b_start, b_end)

    info = {
        "segment": seg_idx,
        "row": row,
        "crop_mode": crop_mode,
        "A_time": f"{a_start:.2f}-{a_end:.2f}",
        "B_time": f"{b_start:.2f}-{b_end:.2f}",
    }
    return segment_a, segment_b, info
751
 
752
 
753
+ # =========================================================
754
+ # UI
755
+ # =========================================================
756
# UI layout + event wiring. Component creation order defines the page layout.
with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
    gr.Markdown(
        "# 双语音频字幕对齐(点击表格即播放)\n"
        "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
        "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
    )

    # Per-session state: aligned segments, idx_map, and the decoded
    # AudioSegments produced by load_and_align.
    state = gr.State()

    with gr.Row():
        repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
        repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")

    btn_scan = gr.Button("扫描 Dataset", variant="primary")

    with gr.Row():
        media_a = gr.Dropdown(label="Track A 媒体")
        media_b = gr.Dropdown(label="Track B 媒体")

    with gr.Row():
        vtt_a = gr.Dropdown(label="Track A 字幕")
        vtt_b = gr.Dropdown(label="Track B 字幕")

    # Scan the repo and populate the four dropdowns above.
    btn_scan.click(
        scan_dataset,
        inputs=[repo_id, repo_type],
        outputs=[media_a, media_b, vtt_a, vtt_b],
    )

    th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
    btn_align = gr.Button("加载并对齐", variant="primary")

    df = gr.Dataframe(
        headers=["#", "A Time", "B Time", "Track A", "Track B"],
        interactive=True,  # can be sorted/edited; mapping is stable due to idx_map
        wrap=True,
        max_height=520,
    )

    with gr.Row():
        crop_mode = gr.Radio(
            choices=["per_track", "global"],
            value="per_track",
            label="裁剪方式(建议 per_track)",
        )
        offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
        offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")

    with gr.Row():
        a_out = gr.Audio(label="Track A 片段", type="numpy")
        b_out = gr.Audio(label="Track B 片段", type="numpy")

    play_info = gr.JSON(label="当前片段")

    # Load + align: fill the table and reset the playback widgets.
    btn_align.click(
        load_and_align,
        inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
        outputs=[df, state, a_out, b_out, play_info],
    )

    # Clicking any cell plays the corresponding A/B audio segments.
    df.select(
        play_on_select,
        inputs=[df, crop_mode, offset_a, offset_b, state],
        outputs=[a_out, b_out, play_info],
    )

if __name__ == "__main__":
    demo.launch()