Spaces:
Sleeping
Sleeping
| # import re | |
| # from dataclasses import dataclass | |
| # from typing import Any, Dict, List, Tuple, Optional | |
| # import gradio as gr | |
| # from huggingface_hub import list_repo_files, hf_hub_download | |
| # from pydub import AudioSegment | |
| # import numpy as np | |
| # # ========================================================= | |
| # # Config | |
| # # ========================================================= | |
| # MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi") | |
| # VTT_EXTS = (".vtt",) | |
| # DEFAULT_MAX_MID_DIFF = 1.5 | |
| # # Normalize audio for stable playback in browsers | |
| # TARGET_SR = 48000 | |
| # TARGET_CH = 1 # mono | |
| # TARGET_SW = 2 # 16-bit PCM | |
| # # ========================================================= | |
| # # Data structures | |
| # # ========================================================= | |
| # @dataclass | |
| # class Cue: | |
| # start: float | |
| # end: float | |
| # text: str | |
| # # ========================================================= | |
| # # VTT parsing | |
| # # ========================================================= | |
| # _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE) | |
| # _VTT_TIME_RE = re.compile( | |
| # r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*" | |
| # r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})" | |
| # ) | |
| # def _strip_tags(text: str) -> str: | |
| # return _TAG_RE.sub("", text).strip() | |
| # def _time_to_seconds(t: str) -> float: | |
| # parts = t.split(":") | |
| # if len(parts) == 3: | |
| # return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2]) | |
| # if len(parts) == 2: | |
| # return int(parts[0]) * 60 + float(parts[1]) | |
| # raise ValueError(f"Bad VTT timestamp: {t}") | |
| # def parse_vtt_file(path: str) -> List[Cue]: | |
| # with open(path, "r", encoding="utf-8") as f: | |
| # content = f.read() | |
| # # Remove BOM / WEBVTT header (if any) | |
| # content = content.replace("\ufeff", "") | |
| # content = re.sub(r"^\s*WEBVTT.*?\n", "", content, flags=re.IGNORECASE) | |
| # blocks = re.split(r"\r?\n\r?\n", content.strip()) | |
| # cues: List[Cue] = [] | |
| # for block in blocks: | |
| # lines = [l.strip() for l in block.splitlines() if l.strip()] | |
| # if not lines: | |
| # continue | |
| # # Locate the timestamp line (must contain "-->") | |
| # time_idx: Optional[int] = None | |
| # for i, line in enumerate(lines): | |
| # if "-->" in line: | |
| # time_idx = i | |
| # break | |
| # if time_idx is None: | |
| # continue | |
| # m = _VTT_TIME_RE.search(lines[time_idx]) | |
| # if not m: | |
| # continue | |
| # start = _time_to_seconds(m.group("start")) | |
| # end = _time_to_seconds(m.group("end")) | |
| # if end <= start: | |
| # continue | |
| # # Only take lines after the timestamp line as subtitle text | |
| # text_lines = lines[time_idx + 1 :] | |
| # if not text_lines: | |
| # continue | |
| # text = _strip_tags("\n".join(text_lines)) | |
| # if text: | |
| # cues.append(Cue(start=start, end=end, text=text)) | |
| # return sorted(cues, key=lambda x: x.start) | |
| # # ========================================================= | |
| # # Alignment (match by mid time), preserve per-track windows | |
| # # ========================================================= | |
| # def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]: | |
| # out: List[Dict[str, Any]] = [] | |
| # i, j, idx = 0, 0, 1 | |
| # while i < len(a) and j < len(b): | |
| # ma = (a[i].start + a[i].end) / 2 | |
| # mb = (b[j].start + b[j].end) / 2 | |
| # if abs(ma - mb) <= th: | |
| # out.append( | |
| # { | |
| # "idx": idx, | |
| # # Per-track time window (recommended for playback) | |
| # "a_start": a[i].start, | |
| # "a_end": a[i].end, | |
| # "b_start": b[j].start, | |
| # "b_end": b[j].end, | |
| # # Optional global time window (for comparison/debug) | |
| # "start": min(a[i].start, b[j].start), | |
| # "end": max(a[i].end, b[j].end), | |
| # "a_text": a[i].text, | |
| # "b_text": b[j].text, | |
| # } | |
| # ) | |
| # idx += 1 | |
| # i += 1 | |
| # j += 1 | |
| # elif ma < mb: | |
| # i += 1 | |
| # else: | |
| # j += 1 | |
| # return out | |
| # # ========================================================= | |
| # # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy") | |
| # # ========================================================= | |
| # def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]: | |
| # """ | |
| # Robust segment export for gr.Audio(type="numpy"). | |
| # Key points: | |
| # - Clamp start/end (after any offsets) to valid range. | |
| # - Use *rounded* ms boundaries to avoid systematic truncation drift. | |
| # - Slice via pydub (ms-accurate) using the original stream timeline. | |
| # - Normalize to mono/48k/int16 for stable browser playback. | |
| # - Return (sr, int16 ndarray) to avoid float32 scaling pitfalls. | |
| # """ | |
| # # Clamp and ensure minimum duration | |
| # start = float(start) | |
| # end = float(end) | |
| # if end < start: | |
| # start, end = end, start | |
| # start = max(0.0, start) | |
| # end = max(start + 0.05, end) | |
| # # Round to milliseconds (avoid int() truncation bias) | |
| # start_ms = int(round(start * 1000.0)) | |
| # end_ms = int(round(end * 1000.0)) | |
| # seg = audio[start_ms:end_ms] | |
| # # Normalize to mono/48k/int16 | |
| # seg = seg.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW) | |
| # arr = np.asarray(seg.get_array_of_samples()) | |
| # if arr.dtype != np.int16: | |
| # arr = arr.astype(np.int16, copy=False) | |
| # return TARGET_SR, np.ascontiguousarray(arr) | |
| # # ========================================================= | |
| # # Helper: robustly read seg_idx from gr.Dataframe value | |
| # # ========================================================= | |
| # def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]: | |
| # if df_value is None: | |
| # return None | |
| # # pandas DataFrame in some Gradio versions | |
| # try: | |
| # import pandas as pd # type: ignore | |
| # if isinstance(df_value, pd.DataFrame): | |
| # if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1: | |
| # return None | |
| # return int(df_value.iloc[row, 0]) | |
| # except Exception: | |
| # pass | |
| # # list-of-lists | |
| # try: | |
| # if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1: | |
| # return int(df_value[row][0]) | |
| # except Exception: | |
| # return None | |
| # return None | |
| # # ========================================================= | |
| # # Gradio callbacks | |
| # # ========================================================= | |
| # def scan_dataset(repo_id: str, repo_type: str): | |
| # if not repo_id: | |
| # raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。") | |
| # files = list_repo_files(repo_id, repo_type=repo_type) | |
| # media_files = sorted([f for f in files if f.lower().endswith(MEDIA_EXTS)]) | |
| # vtt_files = sorted([f for f in files if f.lower().endswith(VTT_EXTS)]) | |
| # if not media_files: | |
| # raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。") | |
| # if not vtt_files: | |
| # raise gr.Error("未找到 VTT 字幕文件。") | |
| # return ( | |
| # gr.update(choices=media_files, value=media_files[0]), | |
| # gr.update(choices=media_files, value=media_files[0]), | |
| # gr.update(choices=vtt_files, value=vtt_files[0]), | |
| # gr.update(choices=vtt_files, value=vtt_files[0]), | |
| # ) | |
| # def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th): | |
| # if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]): | |
| # raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。") | |
| # local_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type) | |
| # local_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type) | |
| # local_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type) | |
| # local_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type) | |
| # try: | |
| # audio_a = AudioSegment.from_file(local_media_a) | |
| # audio_b = AudioSegment.from_file(local_media_b) | |
| # except Exception as e: | |
| # raise gr.Error( | |
| # "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n" | |
| # f"原始错误: {repr(e)}" | |
| # ) | |
| # cues_a = parse_vtt_file(local_vtt_a) | |
| # cues_b = parse_vtt_file(local_vtt_b) | |
| # if not cues_a or not cues_b: | |
| # raise gr.Error("VTT 解析为空,请检查字幕文件内容。") | |
| # # ---- Drift fix: estimate time-scale (linear) between VTT timeline and audio timeline ---- | |
| # # If you observe increasing offset over time, it is usually a *scale* mismatch rather than a constant offset. | |
| # # We estimate per-track scale by comparing audio duration to the last cue end time. | |
| # a_vtt_end = max(c.end for c in cues_a) if cues_a else 0.0 | |
| # b_vtt_end = max(c.end for c in cues_b) if cues_b else 0.0 | |
| # a_dur = float(audio_a.duration_seconds) | |
| # b_dur = float(audio_b.duration_seconds) | |
| # # Default scale = 1.0 when we cannot estimate reliably. | |
| # scale_a_suggest = (a_dur / a_vtt_end) if a_vtt_end > 1.0 and a_dur > 1.0 else 1.0 | |
| # scale_b_suggest = (b_dur / b_vtt_end) if b_vtt_end > 1.0 and b_dur > 1.0 else 1.0 | |
| # aligned = align_by_time(cues_a, cues_b, float(th)) | |
| # if not aligned: | |
| # raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。") | |
| # rows = [ | |
| # [ | |
| # x["idx"], | |
| # f'{x["a_start"]:.2f}-{x["a_end"]:.2f}', | |
| # f'{x["b_start"]:.2f}-{x["b_end"]:.2f}', | |
| # x["a_text"], | |
| # x["b_text"], | |
| # ] | |
| # for x in aligned | |
| # ] | |
| # # Critical: build idx -> seg map to survive dataframe sorting/reordering | |
| # idx_map = {int(x["idx"]): x for x in aligned} | |
| # state = { | |
| # "aligned": aligned, | |
| # "idx_map": idx_map, | |
| # "audio_a": audio_a, | |
| # "audio_b": audio_b, | |
| # "scale_a_suggest": scale_a_suggest, | |
| # "scale_b_suggest": scale_b_suggest, | |
| # "a_vtt_end": a_vtt_end, | |
| # "b_vtt_end": b_vtt_end, | |
| # "a_dur": a_dur, | |
| # "b_dur": b_dur, | |
| # } | |
| # # Clear old playback outputs | |
| # return rows, state, None, None, {}, gr.update(value=scale_a_suggest), gr.update(value=scale_b_suggest) | |
| # def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, scale_a, scale_b, state): | |
| # if not state or "aligned" not in state: | |
| # raise gr.Error("请先加载并对齐。") | |
| # # evt.index: int or (row, col) | |
| # idx_raw = evt.index | |
| # row = int(idx_raw[0] if isinstance(idx_raw, (tuple, list)) else idx_raw) | |
| # offset_a = float(offset_a) | |
| # offset_b = float(offset_b) | |
| # scale_a = float(scale_a) | |
| # scale_b = float(scale_b) | |
| # # Prefer seg_idx from the clicked row's first column; then resolve via idx_map. | |
| # seg_idx = _get_seg_idx_from_df(df_value, row) | |
| # seg = None | |
| # idx_map = state.get("idx_map", {}) or {} | |
| # if seg_idx is not None and seg_idx in idx_map: | |
| # seg = idx_map[seg_idx] | |
| # else: | |
| # # Fallback to row->aligned if idx missing (should be rare) | |
| # aligned = state["aligned"] | |
| # if row < 0 or row >= len(aligned): | |
| # raise gr.Error("选中行越界,请重试或重新对齐。") | |
| # seg = aligned[row] | |
| # seg_idx = int(seg.get("idx", row + 1)) | |
| # if crop_mode == "global": | |
| # a_start, a_end = seg["start"] * scale_a + offset_a, seg["end"] * scale_a + offset_a | |
| # b_start, b_end = seg["start"] * scale_b + offset_b, seg["end"] * scale_b + offset_b | |
| # else: | |
| # # per_track playback (recommended) | |
| # a_start, a_end = seg["a_start"] * scale_a + offset_a, seg["a_end"] * scale_a + offset_a | |
| # b_start, b_end = seg["b_start"] * scale_b + offset_b, seg["b_end"] * scale_b + offset_b | |
| # a_np = export_segment_numpy(state["audio_a"], a_start, a_end) | |
| # b_np = export_segment_numpy(state["audio_b"], b_start, b_end) | |
| # info = { | |
| # "segment": seg_idx, | |
| # "row": row, | |
| # "crop_mode": crop_mode, | |
| # "A_time": f"{a_start:.2f}-{a_end:.2f}", | |
| # "B_time": f"{b_start:.2f}-{b_end:.2f}", | |
| # "scale_a": scale_a, | |
| # "scale_b": scale_b, | |
| # "scale_a_suggest": state.get("scale_a_suggest", 1.0), | |
| # "scale_b_suggest": state.get("scale_b_suggest", 1.0), | |
| # } | |
| # return a_np, b_np, info | |
| # # ========================================================= | |
| # # UI | |
| # # ========================================================= | |
| # with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo: | |
| # gr.Markdown( | |
| # "# 双语音频字幕对齐(点击表格即播放)\n" | |
| # "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n" | |
| # "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。" | |
| # ) | |
| # state = gr.State() | |
| # with gr.Row(): | |
| # repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset") | |
| # repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型") | |
| # btn_scan = gr.Button("扫描 Dataset", variant="primary") | |
| # with gr.Row(): | |
| # media_a = gr.Dropdown(label="Track A 媒体") | |
| # media_b = gr.Dropdown(label="Track B 媒体") | |
| # with gr.Row(): | |
| # vtt_a = gr.Dropdown(label="Track A 字幕") | |
| # vtt_b = gr.Dropdown(label="Track B 字幕") | |
| # btn_scan.click( | |
| # scan_dataset, | |
| # inputs=[repo_id, repo_type], | |
| # outputs=[media_a, media_b, vtt_a, vtt_b], | |
| # ) | |
| # th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)") | |
| # btn_align = gr.Button("加载并对齐", variant="primary") | |
| # df = gr.Dataframe( | |
| # headers=["#", "A Time", "B Time", "Track A", "Track B"], | |
| # interactive=True, # can be sorted/edited; mapping is stable due to idx_map | |
| # wrap=True, | |
| # max_height=520, | |
| # ) | |
| # with gr.Row(): | |
| # crop_mode = gr.Radio( | |
| # choices=["per_track", "global"], | |
| # value="per_track", | |
| # label="裁剪方式(建议 per_track)", | |
| # ) | |
| # offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)") | |
| # offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)") | |
| # scale_a = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track A 时间缩放(scale)") | |
| # scale_b = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track B 时间缩放(scale)") | |
| # with gr.Row(): | |
| # a_out = gr.Audio(label="Track A 片段", type="numpy") | |
| # b_out = gr.Audio(label="Track B 片段", type="numpy") | |
| # play_info = gr.JSON(label="当前片段") | |
| # btn_align.click( | |
| # load_and_align, | |
| # inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th], | |
| # outputs=[df, state, a_out, b_out, play_info, scale_a, scale_b], | |
| # ) | |
| # df.select( | |
| # play_on_select, | |
| # inputs=[df, crop_mode, offset_a, offset_b, scale_a, scale_b, state], | |
| # outputs=[a_out, b_out, play_info], | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Tuple, Optional | |
| import gradio as gr | |
| from huggingface_hub import list_repo_files, hf_hub_download | |
| from pydub import AudioSegment | |
| import numpy as np | |
# =========================================================
# Config
# =========================================================
# Extensions treated as playable media when scanning a repo.
MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
# Subtitles must be WebVTT files.
VTT_EXTS = (".vtt",)
# Default threshold (seconds) for matching cue midpoints across tracks.
DEFAULT_MAX_MID_DIFF = 1.5
# Normalize audio for stable playback in browsers
TARGET_SR = 48000  # target sample rate (Hz)
TARGET_CH = 1  # mono
TARGET_SW = 2  # 16-bit PCM
# =========================================================
# Data structures
# =========================================================
@dataclass
class Cue:
    """A single subtitle cue: a [start, end] window in seconds plus its text.

    NOTE: the @dataclass decorator is required — parse_vtt_file constructs
    cues with Cue(start=..., end=..., text=...), which needs the generated
    __init__; without the decorator this class has only annotations and
    instantiation with keyword arguments raises TypeError.
    """
    # Absolute seconds on the VTT timeline.
    start: float
    end: float
    # Tag-stripped subtitle text (may span multiple lines joined by "\n").
    text: str
| # ========================================================= | |
| # VTT parsing | |
| # ========================================================= | |
| _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE) | |
| _VTT_TIME_RE = re.compile( | |
| r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*" | |
| r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})" | |
| ) | |
| def _strip_tags(text: str) -> str: | |
| return _TAG_RE.sub("", text).strip() | |
| def _time_to_seconds(t: str) -> float: | |
| parts = t.split(":") | |
| if len(parts) == 3: | |
| return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2]) | |
| if len(parts) == 2: | |
| return int(parts[0]) * 60 + float(parts[1]) | |
| raise ValueError(f"Bad VTT timestamp: {t}") | |
def parse_vtt_file(path: str) -> List[Cue]:
    """Parse a WebVTT file into a list of Cue objects sorted by start time.

    Blocks with no timing line, an unparseable timestamp, end <= start, or
    no text left after tag stripping are silently skipped.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()
    # Drop a BOM and the leading WEBVTT header line, if present.
    raw = raw.replace("\ufeff", "")
    raw = re.sub(r"^\s*WEBVTT.*?\n", "", raw, flags=re.IGNORECASE)
    cues: List[Cue] = []
    # Cue blocks are separated by blank lines (either line-ending style).
    for block in re.split(r"\r?\n\r?\n", raw.strip()):
        kept = [ln.strip() for ln in block.splitlines() if ln.strip()]
        if not kept:
            continue
        # The timing line is the first line containing "-->".
        time_idx: Optional[int] = next(
            (k for k, ln in enumerate(kept) if "-->" in ln), None
        )
        if time_idx is None:
            continue
        match = _VTT_TIME_RE.search(kept[time_idx])
        if match is None:
            continue
        start = _time_to_seconds(match.group("start"))
        end = _time_to_seconds(match.group("end"))
        if end <= start:
            continue
        # Only the lines after the timing line are subtitle text.
        body = kept[time_idx + 1 :]
        if not body:
            continue
        text = _strip_tags("\n".join(body))
        if text:
            cues.append(Cue(start=start, end=end, text=text))
    cues.sort(key=lambda c: c.start)
    return cues
# =========================================================
# Alignment (match by mid time), preserve per-track windows
# =========================================================
def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
    """Greedy two-pointer alignment of two cue lists by midpoint proximity.

    Two cues match when their midpoints differ by at most `th` seconds;
    when they don't, the pointer on the earlier track is advanced and that
    cue is dropped. Each match records the per-track windows (recommended
    for playback) plus a merged global window for comparison/debugging.
    """
    matches: List[Dict[str, Any]] = []
    ia = ib = 0
    next_idx = 1
    while ia < len(a) and ib < len(b):
        mid_a = (a[ia].start + a[ia].end) / 2
        mid_b = (b[ib].start + b[ib].end) / 2
        if abs(mid_a - mid_b) > th:
            # No match: advance whichever track lags behind.
            if mid_a < mid_b:
                ia += 1
            else:
                ib += 1
            continue
        matches.append(
            {
                "idx": next_idx,
                # Per-track time window (recommended for playback)
                "a_start": a[ia].start,
                "a_end": a[ia].end,
                "b_start": b[ib].start,
                "b_end": b[ib].end,
                # Merged global window (for comparison/debug)
                "start": min(a[ia].start, b[ib].start),
                "end": max(a[ia].end, b[ib].end),
                "a_text": a[ia].text,
                "b_text": b[ib].text,
            }
        )
        next_idx += 1
        ia += 1
        ib += 1
    return matches
# =========================================================
# Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
# =========================================================
def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
    """Cut the [start, end] second window out of *audio* for gr.Audio(type="numpy").

    The window is sanitized (swapped if reversed, clamped to >= 0, padded to
    at least 50 ms), sliced on *rounded* millisecond boundaries to avoid
    systematic truncation drift, normalized to mono/48 kHz/16-bit for stable
    browser playback, and returned as (sample_rate, int16 ndarray) to avoid
    float32 scaling pitfalls.
    """
    lo = float(start)
    hi = float(end)
    if hi < lo:
        lo, hi = hi, lo
    lo = max(0.0, lo)
    hi = max(lo + 0.05, hi)  # enforce a minimum 50 ms duration
    # Round (not truncate) to milliseconds; pydub slices are ms-indexed.
    clip = audio[int(round(lo * 1000.0)) : int(round(hi * 1000.0))]
    # Normalize to mono / 48 kHz / int16.
    clip = clip.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)
    samples = np.asarray(clip.get_array_of_samples())
    if samples.dtype != np.int16:
        samples = samples.astype(np.int16, copy=False)
    return TARGET_SR, np.ascontiguousarray(samples)
| # ========================================================= | |
| # Helper: robustly read seg_idx from gr.Dataframe value | |
| # ========================================================= | |
| def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]: | |
| if df_value is None: | |
| return None | |
| # pandas DataFrame in some Gradio versions | |
| try: | |
| import pandas as pd # type: ignore | |
| if isinstance(df_value, pd.DataFrame): | |
| if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1: | |
| return None | |
| return int(df_value.iloc[row, 0]) | |
| except Exception: | |
| pass | |
| # list-of-lists | |
| try: | |
| if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1: | |
| return int(df_value[row][0]) | |
| except Exception: | |
| return None | |
| return None | |
# =========================================================
# Gradio callbacks
# =========================================================
def scan_dataset(repo_id: str, repo_type: str):
    """List a Hub repo and populate the media/VTT dropdowns for both tracks.

    Returns four gr.update objects (media A, media B, vtt A, vtt B), each
    preselecting the first entry. Raises gr.Error when the repo id is empty
    or no media/VTT files are found.
    """
    if not repo_id:
        raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
    all_files = list_repo_files(repo_id, repo_type=repo_type)
    media_files = sorted(f for f in all_files if f.lower().endswith(MEDIA_EXTS))
    vtt_files = sorted(f for f in all_files if f.lower().endswith(VTT_EXTS))
    if not media_files:
        raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
    if not vtt_files:
        raise gr.Error("未找到 VTT 字幕文件。")
    def _dropdown(options):
        # Fresh update object per output component.
        return gr.update(choices=options, value=options[0])
    return (
        _dropdown(media_files),
        _dropdown(media_files),
        _dropdown(vtt_files),
        _dropdown(vtt_files),
    )
def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
    """Download both media/VTT pairs, align their cues, and build the table.

    Returns (table rows, playback state dict, None, None, {}) — the trailing
    values clear the previous audio/info outputs. Raises gr.Error on missing
    selections, decode failure, empty VTT parses, or zero alignments.
    """
    if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
        raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")
    # Fetch all four assets from the Hub (cached locally by huggingface_hub).
    path_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
    path_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
    path_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
    path_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)
    try:
        audio_a = AudioSegment.from_file(path_media_a)
        audio_b = AudioSegment.from_file(path_media_b)
    except Exception as e:
        # Decoding mp4/m4a generally requires ffmpeg on the host.
        raise gr.Error(
            "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
            f"原始错误: {repr(e)}"
        )
    cues_a = parse_vtt_file(path_vtt_a)
    cues_b = parse_vtt_file(path_vtt_b)
    if not cues_a or not cues_b:
        raise gr.Error("VTT 解析为空,请检查字幕文件内容。")
    aligned = align_by_time(cues_a, cues_b, float(th))
    if not aligned:
        raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")
    def _table_row(seg):
        # One row: #, A window, B window, A text, B text.
        return [
            seg["idx"],
            f'{seg["a_start"]:.2f}-{seg["a_end"]:.2f}',
            f'{seg["b_start"]:.2f}-{seg["b_end"]:.2f}',
            seg["a_text"],
            seg["b_text"],
        ]
    rows = [_table_row(seg) for seg in aligned]
    # Critical: the idx -> segment map keeps playback correct even if the
    # interactive dataframe gets sorted or reordered by the user.
    state = {
        "aligned": aligned,
        "idx_map": {int(seg["idx"]): seg for seg in aligned},
        "audio_a": audio_a,
        "audio_b": audio_b,
    }
    # Clear any previous playback outputs.
    return rows, state, None, None, {}
def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, state):
    """Play the clicked row's segment on both tracks.

    Resolves the clicked row to a segment via the idx printed in column 0
    (robust to table sorting), falls back to positional lookup, applies the
    per-track offsets, crops both audio streams, and returns
    (A clip, B clip, info dict).
    """
    if not state or "aligned" not in state:
        raise gr.Error("请先加载并对齐。")
    # evt.index may be a bare int or a (row, col) pair depending on version.
    raw_index = evt.index
    row = int(raw_index[0]) if isinstance(raw_index, (tuple, list)) else int(raw_index)
    offset_a = float(offset_a)
    offset_b = float(offset_b)
    # Prefer the seg idx from the clicked row's first column, then idx_map.
    seg_idx = _get_seg_idx_from_df(df_value, row)
    idx_map = state.get("idx_map", {}) or {}
    if seg_idx is not None and seg_idx in idx_map:
        seg = idx_map[seg_idx]
    else:
        # Rare fallback: positional lookup into the aligned list.
        aligned = state["aligned"]
        if not (0 <= row < len(aligned)):
            raise gr.Error("选中行越界,请重试或重新对齐。")
        seg = aligned[row]
        seg_idx = int(seg.get("idx", row + 1))
    if crop_mode == "global":
        a_start, a_end = seg["start"] + offset_a, seg["end"] + offset_a
        b_start, b_end = seg["start"] + offset_b, seg["end"] + offset_b
    else:
        # per_track playback (recommended)
        a_start, a_end = seg["a_start"] + offset_a, seg["a_end"] + offset_a
        b_start, b_end = seg["b_start"] + offset_b, seg["b_end"] + offset_b
    info = {
        "segment": seg_idx,
        "row": row,
        "crop_mode": crop_mode,
        "A_time": f"{a_start:.2f}-{a_end:.2f}",
        "B_time": f"{b_start:.2f}-{b_end:.2f}",
    }
    return (
        export_segment_numpy(state["audio_a"], a_start, a_end),
        export_segment_numpy(state["audio_b"], b_start, b_end),
        info,
    )
# =========================================================
# UI
# =========================================================
with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
    # Usage banner: scan repo -> pick A/B media + subtitles -> align -> click a table cell to play.
    gr.Markdown(
        "# 双语音频字幕对齐(点击表格即播放)\n"
        "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
        "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
    )
    # Per-session state: aligned segments, idx_map, and the decoded AudioSegments.
    state = gr.State()
    with gr.Row():
        repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
        repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
    btn_scan = gr.Button("扫描 Dataset", variant="primary")
    with gr.Row():
        media_a = gr.Dropdown(label="Track A 媒体")
        media_b = gr.Dropdown(label="Track B 媒体")
    with gr.Row():
        vtt_a = gr.Dropdown(label="Track A 字幕")
        vtt_b = gr.Dropdown(label="Track B 字幕")
    # Scan fills all four dropdowns from the repo's file listing.
    btn_scan.click(
        scan_dataset,
        inputs=[repo_id, repo_type],
        outputs=[media_a, media_b, vtt_a, vtt_b],
    )
    # Midpoint-match threshold (seconds) used by align_by_time.
    th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
    btn_align = gr.Button("加载并对齐", variant="primary")
    # Aligned-segments table; selecting any cell triggers playback.
    df = gr.Dataframe(
        headers=["#", "A Time", "B Time", "Track A", "Track B"],
        interactive=True,  # can be sorted/edited; mapping is stable due to idx_map
        wrap=True,
        max_height=520,
    )
    with gr.Row():
        crop_mode = gr.Radio(
            choices=["per_track", "global"],
            value="per_track",
            label="裁剪方式(建议 per_track)",
        )
        # Constant per-track time offsets applied before cropping.
        offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
        offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
    with gr.Row():
        a_out = gr.Audio(label="Track A 片段", type="numpy")
        b_out = gr.Audio(label="Track B 片段", type="numpy")
    play_info = gr.JSON(label="当前片段")
    # Align: fills the table + state, clears previous playback outputs.
    btn_align.click(
        load_and_align,
        inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
        outputs=[df, state, a_out, b_out, play_info],
    )
    # Cell click -> crop and play the corresponding segment on both tracks.
    df.select(
        play_on_select,
        inputs=[df, crop_mode, offset_a, offset_b, state],
        outputs=[a_out, b_out, play_info],
    )
# Entry point for local runs; Spaces launches the app itself.
if __name__ == "__main__":
    demo.launch()