# NOTE(review): the three stray lines below ("unknown", "play25", "bcc5768")
# appear to be paste/merge artifacts; as bare identifiers they would raise
# NameError at import time, so they are commented out here — confirm and delete.
# unknown
# play25
# bcc5768
# import re
# from dataclasses import dataclass
# from typing import Any, Dict, List, Tuple, Optional
# import gradio as gr
# from huggingface_hub import list_repo_files, hf_hub_download
# from pydub import AudioSegment
# import numpy as np
# # =========================================================
# # Config
# # =========================================================
# MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
# VTT_EXTS = (".vtt",)
# DEFAULT_MAX_MID_DIFF = 1.5
# # Normalize audio for stable playback in browsers
# TARGET_SR = 48000
# TARGET_CH = 1 # mono
# TARGET_SW = 2 # 16-bit PCM
# # =========================================================
# # Data structures
# # =========================================================
# @dataclass
# class Cue:
# start: float
# end: float
# text: str
# # =========================================================
# # VTT parsing
# # =========================================================
# _TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
# _VTT_TIME_RE = re.compile(
# r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
# r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
# )
# def _strip_tags(text: str) -> str:
# return _TAG_RE.sub("", text).strip()
# def _time_to_seconds(t: str) -> float:
# parts = t.split(":")
# if len(parts) == 3:
# return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
# if len(parts) == 2:
# return int(parts[0]) * 60 + float(parts[1])
# raise ValueError(f"Bad VTT timestamp: {t}")
# def parse_vtt_file(path: str) -> List[Cue]:
# with open(path, "r", encoding="utf-8") as f:
# content = f.read()
# # Remove BOM / WEBVTT header (if any)
# content = content.replace("\ufeff", "")
# content = re.sub(r"^\s*WEBVTT.*?\n", "", content, flags=re.IGNORECASE)
# blocks = re.split(r"\r?\n\r?\n", content.strip())
# cues: List[Cue] = []
# for block in blocks:
# lines = [l.strip() for l in block.splitlines() if l.strip()]
# if not lines:
# continue
# # Locate the timestamp line (must contain "-->")
# time_idx: Optional[int] = None
# for i, line in enumerate(lines):
# if "-->" in line:
# time_idx = i
# break
# if time_idx is None:
# continue
# m = _VTT_TIME_RE.search(lines[time_idx])
# if not m:
# continue
# start = _time_to_seconds(m.group("start"))
# end = _time_to_seconds(m.group("end"))
# if end <= start:
# continue
# # Only take lines after the timestamp line as subtitle text
# text_lines = lines[time_idx + 1 :]
# if not text_lines:
# continue
# text = _strip_tags("\n".join(text_lines))
# if text:
# cues.append(Cue(start=start, end=end, text=text))
# return sorted(cues, key=lambda x: x.start)
# # =========================================================
# # Alignment (match by mid time), preserve per-track windows
# # =========================================================
# def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
# out: List[Dict[str, Any]] = []
# i, j, idx = 0, 0, 1
# while i < len(a) and j < len(b):
# ma = (a[i].start + a[i].end) / 2
# mb = (b[j].start + b[j].end) / 2
# if abs(ma - mb) <= th:
# out.append(
# {
# "idx": idx,
# # Per-track time window (recommended for playback)
# "a_start": a[i].start,
# "a_end": a[i].end,
# "b_start": b[j].start,
# "b_end": b[j].end,
# # Optional global time window (for comparison/debug)
# "start": min(a[i].start, b[j].start),
# "end": max(a[i].end, b[j].end),
# "a_text": a[i].text,
# "b_text": b[j].text,
# }
# )
# idx += 1
# i += 1
# j += 1
# elif ma < mb:
# i += 1
# else:
# j += 1
# return out
# # =========================================================
# # Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
# # =========================================================
# def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
# """
# Robust segment export for gr.Audio(type="numpy").
# Key points:
# - Clamp start/end (after any offsets) to valid range.
# - Use *rounded* ms boundaries to avoid systematic truncation drift.
# - Slice via pydub (ms-accurate) using the original stream timeline.
# - Normalize to mono/48k/int16 for stable browser playback.
# - Return (sr, int16 ndarray) to avoid float32 scaling pitfalls.
# """
# # Clamp and ensure minimum duration
# start = float(start)
# end = float(end)
# if end < start:
# start, end = end, start
# start = max(0.0, start)
# end = max(start + 0.05, end)
# # Round to milliseconds (avoid int() truncation bias)
# start_ms = int(round(start * 1000.0))
# end_ms = int(round(end * 1000.0))
# seg = audio[start_ms:end_ms]
# # Normalize to mono/48k/int16
# seg = seg.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)
# arr = np.asarray(seg.get_array_of_samples())
# if arr.dtype != np.int16:
# arr = arr.astype(np.int16, copy=False)
# return TARGET_SR, np.ascontiguousarray(arr)
# # =========================================================
# # Helper: robustly read seg_idx from gr.Dataframe value
# # =========================================================
# def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
# if df_value is None:
# return None
# # pandas DataFrame in some Gradio versions
# try:
# import pandas as pd # type: ignore
# if isinstance(df_value, pd.DataFrame):
# if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
# return None
# return int(df_value.iloc[row, 0])
# except Exception:
# pass
# # list-of-lists
# try:
# if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
# return int(df_value[row][0])
# except Exception:
# return None
# return None
# # =========================================================
# # Gradio callbacks
# # =========================================================
# def scan_dataset(repo_id: str, repo_type: str):
# if not repo_id:
# raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
# files = list_repo_files(repo_id, repo_type=repo_type)
# media_files = sorted([f for f in files if f.lower().endswith(MEDIA_EXTS)])
# vtt_files = sorted([f for f in files if f.lower().endswith(VTT_EXTS)])
# if not media_files:
# raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
# if not vtt_files:
# raise gr.Error("未找到 VTT 字幕文件。")
# return (
# gr.update(choices=media_files, value=media_files[0]),
# gr.update(choices=media_files, value=media_files[0]),
# gr.update(choices=vtt_files, value=vtt_files[0]),
# gr.update(choices=vtt_files, value=vtt_files[0]),
# )
# def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
# if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
# raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")
# local_media_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
# local_media_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
# local_vtt_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
# local_vtt_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)
# try:
# audio_a = AudioSegment.from_file(local_media_a)
# audio_b = AudioSegment.from_file(local_media_b)
# except Exception as e:
# raise gr.Error(
# "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
# f"原始错误: {repr(e)}"
# )
# cues_a = parse_vtt_file(local_vtt_a)
# cues_b = parse_vtt_file(local_vtt_b)
# if not cues_a or not cues_b:
# raise gr.Error("VTT 解析为空,请检查字幕文件内容。")
# # ---- Drift fix: estimate time-scale (linear) between VTT timeline and audio timeline ----
# # If you observe increasing offset over time, it is usually a *scale* mismatch rather than a constant offset.
# # We estimate per-track scale by comparing audio duration to the last cue end time.
# a_vtt_end = max(c.end for c in cues_a) if cues_a else 0.0
# b_vtt_end = max(c.end for c in cues_b) if cues_b else 0.0
# a_dur = float(audio_a.duration_seconds)
# b_dur = float(audio_b.duration_seconds)
# # Default scale = 1.0 when we cannot estimate reliably.
# scale_a_suggest = (a_dur / a_vtt_end) if a_vtt_end > 1.0 and a_dur > 1.0 else 1.0
# scale_b_suggest = (b_dur / b_vtt_end) if b_vtt_end > 1.0 and b_dur > 1.0 else 1.0
# aligned = align_by_time(cues_a, cues_b, float(th))
# if not aligned:
# raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")
# rows = [
# [
# x["idx"],
# f'{x["a_start"]:.2f}-{x["a_end"]:.2f}',
# f'{x["b_start"]:.2f}-{x["b_end"]:.2f}',
# x["a_text"],
# x["b_text"],
# ]
# for x in aligned
# ]
# # Critical: build idx -> seg map to survive dataframe sorting/reordering
# idx_map = {int(x["idx"]): x for x in aligned}
# state = {
# "aligned": aligned,
# "idx_map": idx_map,
# "audio_a": audio_a,
# "audio_b": audio_b,
# "scale_a_suggest": scale_a_suggest,
# "scale_b_suggest": scale_b_suggest,
# "a_vtt_end": a_vtt_end,
# "b_vtt_end": b_vtt_end,
# "a_dur": a_dur,
# "b_dur": b_dur,
# }
# # Clear old playback outputs
# return rows, state, None, None, {}, gr.update(value=scale_a_suggest), gr.update(value=scale_b_suggest)
# def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, scale_a, scale_b, state):
# if not state or "aligned" not in state:
# raise gr.Error("请先加载并对齐。")
# # evt.index: int or (row, col)
# idx_raw = evt.index
# row = int(idx_raw[0] if isinstance(idx_raw, (tuple, list)) else idx_raw)
# offset_a = float(offset_a)
# offset_b = float(offset_b)
# scale_a = float(scale_a)
# scale_b = float(scale_b)
# # Prefer seg_idx from the clicked row's first column; then resolve via idx_map.
# seg_idx = _get_seg_idx_from_df(df_value, row)
# seg = None
# idx_map = state.get("idx_map", {}) or {}
# if seg_idx is not None and seg_idx in idx_map:
# seg = idx_map[seg_idx]
# else:
# # Fallback to row->aligned if idx missing (should be rare)
# aligned = state["aligned"]
# if row < 0 or row >= len(aligned):
# raise gr.Error("选中行越界,请重试或重新对齐。")
# seg = aligned[row]
# seg_idx = int(seg.get("idx", row + 1))
# if crop_mode == "global":
# a_start, a_end = seg["start"] * scale_a + offset_a, seg["end"] * scale_a + offset_a
# b_start, b_end = seg["start"] * scale_b + offset_b, seg["end"] * scale_b + offset_b
# else:
# # per_track playback (recommended)
# a_start, a_end = seg["a_start"] * scale_a + offset_a, seg["a_end"] * scale_a + offset_a
# b_start, b_end = seg["b_start"] * scale_b + offset_b, seg["b_end"] * scale_b + offset_b
# a_np = export_segment_numpy(state["audio_a"], a_start, a_end)
# b_np = export_segment_numpy(state["audio_b"], b_start, b_end)
# info = {
# "segment": seg_idx,
# "row": row,
# "crop_mode": crop_mode,
# "A_time": f"{a_start:.2f}-{a_end:.2f}",
# "B_time": f"{b_start:.2f}-{b_end:.2f}",
# "scale_a": scale_a,
# "scale_b": scale_b,
# "scale_a_suggest": state.get("scale_a_suggest", 1.0),
# "scale_b_suggest": state.get("scale_b_suggest", 1.0),
# }
# return a_np, b_np, info
# # =========================================================
# # UI
# # =========================================================
# with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
# gr.Markdown(
# "# 双语音频字幕对齐(点击表格即播放)\n"
# "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
# "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
# )
# state = gr.State()
# with gr.Row():
# repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
# repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
# btn_scan = gr.Button("扫描 Dataset", variant="primary")
# with gr.Row():
# media_a = gr.Dropdown(label="Track A 媒体")
# media_b = gr.Dropdown(label="Track B 媒体")
# with gr.Row():
# vtt_a = gr.Dropdown(label="Track A 字幕")
# vtt_b = gr.Dropdown(label="Track B 字幕")
# btn_scan.click(
# scan_dataset,
# inputs=[repo_id, repo_type],
# outputs=[media_a, media_b, vtt_a, vtt_b],
# )
# th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
# btn_align = gr.Button("加载并对齐", variant="primary")
# df = gr.Dataframe(
# headers=["#", "A Time", "B Time", "Track A", "Track B"],
# interactive=True, # can be sorted/edited; mapping is stable due to idx_map
# wrap=True,
# max_height=520,
# )
# with gr.Row():
# crop_mode = gr.Radio(
# choices=["per_track", "global"],
# value="per_track",
# label="裁剪方式(建议 per_track)",
# )
# offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
# offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
# scale_a = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track A 时间缩放(scale)")
# scale_b = gr.Slider(0.95, 1.05, value=1.0, step=0.0005, label="Track B 时间缩放(scale)")
# with gr.Row():
# a_out = gr.Audio(label="Track A 片段", type="numpy")
# b_out = gr.Audio(label="Track B 片段", type="numpy")
# play_info = gr.JSON(label="当前片段")
# btn_align.click(
# load_and_align,
# inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
# outputs=[df, state, a_out, b_out, play_info, scale_a, scale_b],
# )
# df.select(
# play_on_select,
# inputs=[df, crop_mode, offset_a, offset_b, scale_a, scale_b, state],
# outputs=[a_out, b_out, play_info],
# )
# if __name__ == "__main__":
# demo.launch()
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
import gradio as gr
from huggingface_hub import list_repo_files, hf_hub_download
from pydub import AudioSegment
import numpy as np
# =========================================================
# Config
# =========================================================
# File extensions treated as playable media when scanning a repo
# (matched case-insensitively in scan_dataset).
MEDIA_EXTS = (".mp4", ".m4a", ".mp3", ".wav", ".flac", ".ogg", ".aac", ".mov", ".avi")
# File extensions treated as subtitle files.
VTT_EXTS = (".vtt",)
# Default maximum difference (seconds) between two cues' midpoints for
# align_by_time to pair them.
DEFAULT_MAX_MID_DIFF = 1.5
# Normalize audio for stable playback in browsers
TARGET_SR = 48000  # target sample rate in Hz
TARGET_CH = 1  # mono
TARGET_SW = 2  # 16-bit PCM
# =========================================================
# Data structures
# =========================================================
@dataclass
class Cue:
    """A single subtitle cue: a time window in seconds plus its text."""
    # cue start time in seconds
    start: float
    # cue end time in seconds (parse_vtt_file guarantees end > start)
    end: float
    # cue text with markup tags stripped
    text: str
# =========================================================
# VTT parsing
# =========================================================
# Matches any markup tag such as <c>, </c>, <i> so it can be stripped from cue text.
_TAG_RE = re.compile(r"</?[^>]+?>", re.IGNORECASE)
# Matches a VTT timing line in either "HH:MM:SS.mmm --> HH:MM:SS.mmm"
# or "M:SS.mmm --> M:SS.mmm" form; named groups "start" and "end".
_VTT_TIME_RE = re.compile(
    r"(?P<start>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})\s*-->\s*"
    r"(?P<end>\d{2}:\d{2}:\d{2}\.\d{3}|\d{1,2}:\d{2}\.\d{3})"
)
def _strip_tags(text: str) -> str:
    """Remove markup tags from *text* and trim surrounding whitespace."""
    cleaned = _TAG_RE.sub("", text)
    return cleaned.strip()
def _time_to_seconds(t: str) -> float:
parts = t.split(":")
if len(parts) == 3:
return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
if len(parts) == 2:
return int(parts[0]) * 60 + float(parts[1])
raise ValueError(f"Bad VTT timestamp: {t}")
def parse_vtt_file(path: str) -> List[Cue]:
    """Parse a WebVTT file into a chronologically sorted list of Cue objects.

    Blocks without a timing line, with a non-positive duration, or with no
    text after the timing line are skipped; markup tags are stripped.
    """
    with open(path, "r", encoding="utf-8") as fh:
        raw = fh.read()
    # Drop the BOM and the leading "WEBVTT" header line, if present.
    raw = raw.replace("\ufeff", "")
    raw = re.sub(r"^\s*WEBVTT.*?\n", "", raw, flags=re.IGNORECASE)

    cues: List[Cue] = []
    for block in re.split(r"\r?\n\r?\n", raw.strip()):
        lines = [ln.strip() for ln in block.splitlines() if ln.strip()]
        # Locate the timing line: the first one containing "-->".
        time_idx = next((k for k, ln in enumerate(lines) if "-->" in ln), None)
        if time_idx is None:
            continue
        match = _VTT_TIME_RE.search(lines[time_idx])
        if match is None:
            continue
        start = _time_to_seconds(match.group("start"))
        end = _time_to_seconds(match.group("end"))
        if end <= start:
            continue
        # Everything after the timing line is subtitle text.
        body = _strip_tags("\n".join(lines[time_idx + 1:]))
        if body:
            cues.append(Cue(start=start, end=end, text=body))
    return sorted(cues, key=lambda c: c.start)
# =========================================================
# Alignment (match by mid time), preserve per-track windows
# =========================================================
def align_by_time(a: List[Cue], b: List[Cue], th: float) -> List[Dict[str, Any]]:
    """Pair cues from tracks *a* and *b* whose midpoints differ by <= *th* seconds.

    Uses a two-pointer merge over the (already sorted) cue lists; unmatched
    cues on either side are skipped. Each pair keeps both the per-track
    windows (recommended for playback) and a global spanning window.
    """
    pairs: List[Dict[str, Any]] = []
    i = j = 0
    seg_no = 1
    while i < len(a) and j < len(b):
        cue_a, cue_b = a[i], b[j]
        mid_a = 0.5 * (cue_a.start + cue_a.end)
        mid_b = 0.5 * (cue_b.start + cue_b.end)
        if abs(mid_a - mid_b) <= th:
            pairs.append(
                {
                    "idx": seg_no,
                    # Per-track time window (recommended for playback)
                    "a_start": cue_a.start,
                    "a_end": cue_a.end,
                    "b_start": cue_b.start,
                    "b_end": cue_b.end,
                    # Global time window spanning both tracks (comparison/debug)
                    "start": min(cue_a.start, cue_b.start),
                    "end": max(cue_a.end, cue_b.end),
                    "a_text": cue_a.text,
                    "b_text": cue_b.text,
                }
            )
            seg_no += 1
            i += 1
            j += 1
        elif mid_a < mid_b:
            # A's cue is earlier and unmatched: advance A.
            i += 1
        else:
            j += 1
    return pairs
# =========================================================
# Audio slicing -> return (sr, np.int16) for gr.Audio(type="numpy")
# =========================================================
def export_segment_numpy(audio: AudioSegment, start: float, end: float) -> Tuple[int, np.ndarray]:
    """Slice [start, end] seconds out of *audio* for gr.Audio(type="numpy").

    - Swaps reversed bounds, clamps to >= 0, and enforces a minimum 50 ms
      duration so playback never gets an empty clip.
    - Rounds boundaries to whole milliseconds (pydub slices by ms) instead of
      truncating, avoiding a systematic drift.
    - Normalizes to mono / 48 kHz / 16-bit PCM for stable browser playback.

    Returns:
        (sample_rate, contiguous int16 ndarray)
    """
    lo, hi = float(start), float(end)
    if hi < lo:
        lo, hi = hi, lo
    lo = max(0.0, lo)
    hi = max(lo + 0.05, hi)  # minimum segment length of 50 ms

    # Millisecond-accurate slice on the original stream timeline.
    clip = audio[int(round(lo * 1000.0)):int(round(hi * 1000.0))]
    clip = clip.set_channels(TARGET_CH).set_frame_rate(TARGET_SR).set_sample_width(TARGET_SW)

    samples = np.asarray(clip.get_array_of_samples())
    if samples.dtype != np.int16:
        samples = samples.astype(np.int16, copy=False)
    return TARGET_SR, np.ascontiguousarray(samples)
# =========================================================
# Helper: robustly read seg_idx from gr.Dataframe value
# =========================================================
def _get_seg_idx_from_df(df_value: Any, row: int) -> Optional[int]:
if df_value is None:
return None
# pandas DataFrame in some Gradio versions
try:
import pandas as pd # type: ignore
if isinstance(df_value, pd.DataFrame):
if row < 0 or row >= len(df_value.index) or df_value.shape[1] < 1:
return None
return int(df_value.iloc[row, 0])
except Exception:
pass
# list-of-lists
try:
if isinstance(df_value, list) and row >= 0 and row < len(df_value) and len(df_value[row]) >= 1:
return int(df_value[row][0])
except Exception:
return None
return None
# =========================================================
# Gradio callbacks
# =========================================================
def scan_dataset(repo_id: str, repo_type: str):
    """List media and VTT files in a Hub repo and fill the four dropdowns
    (A/B media, A/B subtitles), preselecting the first entry of each.

    Raises:
        gr.Error: when the repo id is empty or no matching files are found.
    """
    if not repo_id:
        raise gr.Error("请填写 Dataset / Repo 名称(例如 org/dataset)。")
    all_files = list_repo_files(repo_id, repo_type=repo_type)
    media_files = sorted(f for f in all_files if f.lower().endswith(MEDIA_EXTS))
    vtt_files = sorted(f for f in all_files if f.lower().endswith(VTT_EXTS))
    if not media_files:
        raise gr.Error("未找到媒体文件(mp4/mp3/wav 等)。")
    if not vtt_files:
        raise gr.Error("未找到 VTT 字幕文件。")
    # One update per dropdown: (media_a, media_b, vtt_a, vtt_b).
    return (
        gr.update(choices=media_files, value=media_files[0]),
        gr.update(choices=media_files, value=media_files[0]),
        gr.update(choices=vtt_files, value=vtt_files[0]),
        gr.update(choices=vtt_files, value=vtt_files[0]),
    )
def load_and_align(repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th):
    """Download both media and both VTT files, decode the audio, parse and
    align the subtitles, and build the table rows plus the session state.

    Returns:
        (rows, state, None, None, {}) — the trailing values clear any
        previous playback outputs (A audio, B audio, info JSON).

    Raises:
        gr.Error: on missing selections, decode failure, empty subtitles,
        or when no cue pairs align at threshold *th*.
    """
    if not all([repo_id, repo_type, media_a, media_b, vtt_a, vtt_b]):
        raise gr.Error("请先选择 A/B 的媒体文件与 VTT 文件。")

    media_path_a = hf_hub_download(repo_id, media_a, repo_type=repo_type)
    media_path_b = hf_hub_download(repo_id, media_b, repo_type=repo_type)
    vtt_path_a = hf_hub_download(repo_id, vtt_a, repo_type=repo_type)
    vtt_path_b = hf_hub_download(repo_id, vtt_b, repo_type=repo_type)

    try:
        audio_a = AudioSegment.from_file(media_path_a)
        audio_b = AudioSegment.from_file(media_path_b)
    except Exception as e:
        # pydub delegates container decoding to ffmpeg for mp4/m4a etc.
        raise gr.Error(
            "媒体解码失败。若是 mp4/m4a,通常需要 ffmpeg。\n"
            f"原始错误: {repr(e)}"
        )

    cues_a = parse_vtt_file(vtt_path_a)
    cues_b = parse_vtt_file(vtt_path_b)
    if not cues_a or not cues_b:
        raise gr.Error("VTT 解析为空,请检查字幕文件内容。")

    aligned = align_by_time(cues_a, cues_b, float(th))
    if not aligned:
        raise gr.Error("未对齐到任何字幕片段,请尝试增大对齐阈值。")

    # Table rows: [idx, "a_start-a_end", "b_start-b_end", a_text, b_text].
    rows = []
    for seg in aligned:
        rows.append(
            [
                seg["idx"],
                f'{seg["a_start"]:.2f}-{seg["a_end"]:.2f}',
                f'{seg["b_start"]:.2f}-{seg["b_end"]:.2f}',
                seg["a_text"],
                seg["b_text"],
            ]
        )

    # Critical: idx -> segment map so clicks survive dataframe sorting/reordering.
    state = {
        "aligned": aligned,
        "idx_map": {int(seg["idx"]): seg for seg in aligned},
        "audio_a": audio_a,
        "audio_b": audio_b,
    }
    # None/None/{} clear stale playback outputs.
    return rows, state, None, None, {}
def play_on_select(evt: gr.SelectData, df_value, crop_mode, offset_a, offset_b, state):
    """Handle a click on the alignment table: resolve the clicked segment,
    apply the per-track time offsets, slice both audio tracks, and return
    (A clip, B clip, info dict) for playback.

    Raises:
        gr.Error: when state is missing or the clicked row cannot be resolved.
    """
    if not state or "aligned" not in state:
        raise gr.Error("请先加载并对齐。")

    # evt.index may be a bare int or a (row, col) pair depending on version.
    raw_index = evt.index
    row = int(raw_index[0]) if isinstance(raw_index, (tuple, list)) else int(raw_index)

    off_a = float(offset_a)
    off_b = float(offset_b)

    # Prefer the seg idx stored in the clicked row's first column (stable
    # under table sorting); fall back to positional lookup into `aligned`.
    idx_map = state.get("idx_map", {}) or {}
    seg_idx = _get_seg_idx_from_df(df_value, row)
    if seg_idx is not None and seg_idx in idx_map:
        seg = idx_map[seg_idx]
    else:
        aligned = state["aligned"]
        if not (0 <= row < len(aligned)):
            raise gr.Error("选中行越界,请重试或重新对齐。")
        seg = aligned[row]
        seg_idx = int(seg.get("idx", row + 1))

    if crop_mode == "global":
        # Shared spanning window on both tracks.
        a_start, a_end = seg["start"] + off_a, seg["end"] + off_a
        b_start, b_end = seg["start"] + off_b, seg["end"] + off_b
    else:
        # "per_track" (recommended): each track keeps its own cue window.
        a_start, a_end = seg["a_start"] + off_a, seg["a_end"] + off_a
        b_start, b_end = seg["b_start"] + off_b, seg["b_end"] + off_b

    clip_a = export_segment_numpy(state["audio_a"], a_start, a_end)
    clip_b = export_segment_numpy(state["audio_b"], b_start, b_end)
    info = {
        "segment": seg_idx,
        "row": row,
        "crop_mode": crop_mode,
        "A_time": f"{a_start:.2f}-{a_end:.2f}",
        "B_time": f"{b_start:.2f}-{b_end:.2f}",
    }
    return clip_a, clip_b, info
# =========================================================
# UI
# =========================================================
# NOTE(review): the original indentation was lost in this file; the row
# grouping below (e.g. btn_scan inside the first Row) is reconstructed from
# reading order — confirm against the intended layout.
with gr.Blocks(title="双语音频字幕对齐(点击即播放)") as demo:
    gr.Markdown(
        "# 双语音频字幕对齐(点击表格即播放)\n"
        "流程:扫描 Dataset → 选择 A/B 媒体与字幕 → 加载并对齐 → 点击表格任意单元格播放对应片段。\n"
        "若字幕与音频整体存在固定延迟,可用 Track A/B 偏移进行校正。"
    )
    # Per-session state dict filled by load_and_align (aligned segments,
    # idx -> segment map, and both decoded AudioSegments).
    state = gr.State()
    with gr.Row():
        repo_id = gr.Textbox(label="Dataset / Repo 名称", placeholder="org/dataset")
        repo_type = gr.Radio(["dataset", "model"], value="dataset", label="Repo 类型")
        btn_scan = gr.Button("扫描 Dataset", variant="primary")
    with gr.Row():
        media_a = gr.Dropdown(label="Track A 媒体")
        media_b = gr.Dropdown(label="Track B 媒体")
    with gr.Row():
        vtt_a = gr.Dropdown(label="Track A 字幕")
        vtt_b = gr.Dropdown(label="Track B 字幕")
    # Scanning the repo fills all four dropdowns.
    btn_scan.click(
        scan_dataset,
        inputs=[repo_id, repo_type],
        outputs=[media_a, media_b, vtt_a, vtt_b],
    )
    th = gr.Slider(0.3, 5.0, value=DEFAULT_MAX_MID_DIFF, step=0.1, label="对齐阈值(秒)")
    btn_align = gr.Button("加载并对齐", variant="primary")
    df = gr.Dataframe(
        headers=["#", "A Time", "B Time", "Track A", "Track B"],
        interactive=True,  # can be sorted/edited; mapping is stable due to idx_map
        wrap=True,
        max_height=520,
    )
    with gr.Row():
        crop_mode = gr.Radio(
            choices=["per_track", "global"],
            value="per_track",
            label="裁剪方式(建议 per_track)",
        )
        offset_a = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track A 时间偏移(s)")
        offset_b = gr.Slider(-20, 20, value=0.0, step=0.05, label="Track B 时间偏移(s)")
    with gr.Row():
        a_out = gr.Audio(label="Track A 片段", type="numpy")
        b_out = gr.Audio(label="Track B 片段", type="numpy")
    play_info = gr.JSON(label="当前片段")
    # Align button rebuilds the table/state and clears previous playback.
    btn_align.click(
        load_and_align,
        inputs=[repo_id, repo_type, media_a, media_b, vtt_a, vtt_b, th],
        outputs=[df, state, a_out, b_out, play_info],
    )
    # Clicking any table cell plays the corresponding A/B audio segments.
    df.select(
        play_on_select,
        inputs=[df, crop_mode, offset_a, offset_b, state],
        outputs=[a_out, b_out, play_info],
    )

if __name__ == "__main__":
    demo.launch()