Spaces:
Running
Running
| import os | |
| import re | |
| import sys | |
| import tempfile | |
| import subprocess | |
| import traceback | |
| from typing import List, Tuple, Dict | |
| import torch | |
| import soundfile as sf | |
| import gradio as gr | |
# ---------- Device detection ----------
# Use the GPU whenever PyTorch can see one; every model below follows DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"运行设备: {DEVICE}")
# ---------- Optional dependency: CTC Forced Aligner ----------
# When the package cannot be imported, CTC_AVAILABLE stays False and the
# CTC code path is refused at request time instead of crashing at startup.
CTC_AVAILABLE = False
try:
    import ctc_forced_aligner
    from ctc_forced_aligner import (
        generate_emissions,
        get_alignments,
        load_alignment_model,
        load_audio,
        preprocess_text,
    )
    import ctc_forced_aligner.alignment_utils as ctc_au
    import ctc_forced_aligner.text_utils as ctc_tu
except ImportError:
    print("⚠️ CTC Forced Aligner 不可用")
else:
    CTC_AVAILABLE = True
    print("✅ CTC Forced Aligner 已就绪")
# ---------- Optional dependency: Qwen3-ForcedAligner ----------
# QWEN_AVAILABLE records whether the qwen_asr package could be imported.
try:
    from qwen_asr import Qwen3ForcedAligner
except ImportError:
    QWEN_AVAILABLE = False
    print("⚠️ Qwen3-ForcedAligner 不可用")
else:
    QWEN_AVAILABLE = True
    print("✅ Qwen3-ForcedAligner 已就绪")
| # ================== 核心算法 ================== | |
# Runs of non-word characters (punctuation, whitespace, control characters).
# In Python 3 `\w` is Unicode-aware for str patterns, so CJK ideographs, kana
# and Hangul already count as word characters; no explicit ranges are needed.
_NON_TEXT_RE = re.compile(r"\W+")


def get_pure_text_length(text: str) -> int:
    """Return how many "pure" characters *text* contains.

    Punctuation, whitespace and control characters are discarded and the
    remaining word characters are counted. The same measure is applied to
    the target sentences and to the aligner tokens, so the two counts can be
    compared when distributing tokens over sentences.

    Args:
        text: Text to measure; non-string values are converted with ``str()``.

    Returns:
        Number of remaining characters (0 for empty or punctuation-only text).
    """
    # No .lower(): case folding can change the length (e.g. "İ" -> "i̇").
    return len(_NON_TEXT_RE.sub("", str(text)))
def merge_token_timestamps_to_sentences(
    token_timestamps: List[Tuple[str, float, float]],
    target_sentences: List[str],
    debug: bool = False
) -> List[Dict]:
    """Map word/character-level timestamps onto pre-segmented sentences.

    Tokens are consumed in order. Each sentence greedily takes tokens until
    their accumulated pure-character count (``get_pure_text_length``) reaches
    the sentence's own count; the sentence then spans from the first taken
    token's start to the last taken token's end.

    Sentences with no countable characters, or that receive no tokens, get a
    zero-length span placed at the end of the nearest earlier timed sentence
    (or, if none, at the start of the next timed one). Finally every start is
    clamped to be no earlier than the previous end.

    Args:
        token_timestamps: ``(text, start_sec, end_sec)`` triples, in order.
        target_sentences: Sentences to time, in order.
        debug: Print matching details for the first five sentences and a
            coverage summary.

    Returns:
        One ``{"text", "start", "end"}`` dict per target sentence, in order;
        matched times are rounded to milliseconds.
    """
    if not target_sentences:
        return []

    results = []
    token_idx = 0
    total_tokens = len(token_timestamps)
    for sent_idx, sentence in enumerate(target_sentences):
        t_len = get_pure_text_length(sentence)
        if t_len == 0:
            results.append({"text": sentence, "start": 0.0, "end": 0.0})
            continue

        acc_len = 0
        st, et = None, None
        while token_idx < total_tokens and acc_len < t_len:
            seg_text, seg_start, seg_end = token_timestamps[token_idx]
            if st is None:
                st = seg_start
            et = seg_end
            acc_len += get_pure_text_length(seg_text)
            token_idx += 1

        if debug and sent_idx < 5:
            # Only the time part depends on whether tokens were found; a start
            # of 0.0 is a real time and must not suppress the whole line.
            time_part = f"time={st:.2f}s-{et:.2f}s" if st is not None else "time=N/A"
            print(f" [{sent_idx}] \"{sentence[:50]}\" -> "
                  f"tokens char_cnt={acc_len}/{t_len} {time_part}")

        if st is not None and et is not None:
            results.append({"text": sentence, "start": round(st, 3), "end": round(et, 3)})
        else:
            results.append({"text": sentence, "start": 0.0, "end": 0.0})

    _fill_missing_timestamps(results)
    _enforce_monotonic_timestamps(results)

    if debug:
        non_zero = sum(1 for r in results if r["start"] > 0 or r["end"] > 0)
        print(f"时间戳覆盖率: {non_zero}/{len(results)} 句")
    return results


def _fill_missing_timestamps(results: List[Dict]) -> None:
    """Give untimed (start == end == 0.0) entries a zero-length span in place."""
    last_end = 0.0  # end of the most recent entry with a positive end time
    for i, cur in enumerate(results):
        if cur["start"] == 0.0 and cur["end"] == 0.0:
            if last_end > 0:
                cur["start"] = cur["end"] = last_end
            else:
                # Nothing timed before this entry: borrow the next timed start.
                next_start = next(
                    (results[j]["start"] for j in range(i + 1, len(results))
                     if results[j]["start"] > 0),
                    None,
                )
                if next_start is not None:
                    cur["start"] = cur["end"] = next_start
        if cur["end"] > 0:
            last_end = cur["end"]


def _enforce_monotonic_timestamps(results: List[Dict]) -> None:
    """Clamp spans in place so none starts before the previous one ends."""
    prev_end = None
    for cur in results:
        if prev_end is not None and cur["start"] < prev_end:
            cur["start"] = prev_end
        if cur["end"] < cur["start"]:
            # Keep a minimal 1 ms duration; round to avoid float noise.
            cur["end"] = round(cur["start"] + 0.001, 3)
        prev_end = cur["end"]
def seconds_to_srt_time(seconds: float) -> str:
    """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative values are clamped to zero. The value is rounded to the nearest
    millisecond *before* it is split into fields. Float noise therefore cannot
    drop a millisecond (``1.001 % 1`` is ``0.000999...``), and a value such as
    59.9996 correctly carries over to ``00:01:00,000``.
    """
    total_ms = int(round(max(0.0, seconds) * 1000))
    total_s, ms = divmod(total_ms, 1000)
    total_m, s = divmod(total_s, 60)
    h, m = divmod(total_m, 60)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def format_srt(segments: List[Dict]) -> str:
    """Render aligned segments as the text of an SRT subtitle file.

    Segments whose text is blank after stripping are skipped and do not use
    up a cue number. Each cue is its 1-based number, a ``start --> end``
    line, the caption and a blank line.
    """
    cues = []
    for seg in segments:
        caption = seg["text"].strip()
        if not caption:
            continue
        timing = " --> ".join(
            (seconds_to_srt_time(seg["start"]), seconds_to_srt_time(seg["end"]))
        )
        # Each cue ends with "\n"; joining with "\n" yields the blank separator line.
        cues.append(f"{len(cues) + 1}\n{timing}\n{caption}\n")
    return "\n".join(cues)
| # ================== CTC 对齐封装(含容错补丁) ================== | |
# Language codes whose text is passed to preprocess_text with romanize=True.
_CTC_ROMANIZE_LANGS = frozenset({
    "cmn", "zho", "chi", "jpn", "ja", "kor", "ko", "ara", "ar", "rus", "ru",
})


def _relaxed_get_spans(tokens_starred, segments, blank_token):
    """Fault-tolerant replacement for ``ctc_au.get_spans``.

    Walks *segments* once, skipping blank-label segments, and greedily matches
    the space-separated letters of each token in order. A token that matches
    no segment gets a one-element placeholder span (a ``ctc_au.Segment``
    labelled with the token at the nearest valid index). As a result,
    ``spans[i]`` exists for every ``tokens_starred[i]``.

    Returns:
        One span (list of segments) per entry of *tokens_starred*.
    """
    n_seg = len(segments)
    spans = []
    si = 0
    for token in tokens_starred:
        while si < n_seg and segments[si].label == blank_token:
            si += 1
        start_seg_idx = end_seg_idx = si
        matched_any = False
        for ltr in token.split(" "):
            while si < n_seg and segments[si].label == blank_token:
                si += 1
            if si < n_seg and segments[si].label == ltr:
                if not matched_any:
                    start_seg_idx = si
                end_seg_idx = si
                matched_any = True
                si += 1
        if matched_any:
            spans.append(segments[start_seg_idx:end_seg_idx + 1])
        else:
            # Clamp into range so later .start/.end lookups never fail.
            safe_idx = min(start_seg_idx, n_seg - 1) if n_seg > 0 else 0
            spans.append([ctc_au.Segment(token, safe_idx, safe_idx)])
    return spans


def _safe_postprocess_results(text_starred, spans, stride, scores, merge_threshold=0.0):
    """Fault-tolerant replacement for ``ctc_tu.postprocess_results``.

    Skips ``<star>`` entries and empty spans, and converts segment indices to
    seconds as ``index * stride / 1000``. Each result is a dict with
    ``start``/``end``/``text``/``score``. The result list is then passed to
    ``ctc_tu.merge_segments`` with *merge_threshold*.
    """
    results = []
    for i, t in enumerate(text_starred):
        if t == "<star>":
            continue
        span = spans[i]
        if not span:
            continue
        seg_start_idx = span[0].start
        seg_end_idx = span[-1].end
        if seg_end_idx >= seg_start_idx:
            score = scores[seg_start_idx:seg_end_idx + 1].sum()
        else:
            score = 0.0
        results.append({
            "start": seg_start_idx * stride / 1000.0,
            "end": seg_end_idx * stride / 1000.0,
            "text": t,
            "score": score.item() if hasattr(score, "item") else float(score),
        })
    ctc_tu.merge_segments(results, merge_threshold)
    return results


def run_ctc_alignment(
    audio_path: str,
    full_text: str,
    target_sentences: List[str],
    language: str = "eng"
) -> List[Dict]:
    """Force-align *audio_path* against *full_text* with CTC Forced Aligner.

    The fault-tolerant span/post-processing helpers are called directly. The
    library's module attributes are no longer monkeypatched, so a failure
    can't leave ``ctc_au``/``ctc_tu`` patched.

    Args:
        audio_path: Path of a WAV file readable by ``load_audio``.
        full_text: All sentences joined with spaces.
        target_sentences: Sentences to time, in order.
        language: Language code for ``preprocess_text`` (e.g. ``"eng"``,
            ``"cmn"``); codes in ``_CTC_ROMANIZE_LANGS`` are romanized.

    Returns:
        Per-sentence ``{"text", "start", "end"}`` dicts from
        ``merge_token_timestamps_to_sentences``.
    """
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32

    print(f"🚀 加载 CTC 对齐模型 (设备: {DEVICE})")
    alignment_model, alignment_tokenizer = load_alignment_model(DEVICE, dtype=dtype)
    try:
        print("🔄 加载音频...")
        audio_waveform = load_audio(audio_path, alignment_model.dtype, alignment_model.device)
        print("🔄 生成发射矩阵...")
        emissions, stride = generate_emissions(alignment_model, audio_waveform, batch_size=8)
        del audio_waveform
    finally:
        # Only the emissions are needed from here on; release the acoustic
        # model (also when loading/inference fails).
        del alignment_model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    needs_romanize = language in _CTC_ROMANIZE_LANGS
    tokens_starred, text_starred = preprocess_text(
        full_text, romanize=needs_romanize, language=language
    )
    print("🔄 CTC 解码...")
    segments_raw, scores, blank_token = get_alignments(emissions, tokens_starred, alignment_tokenizer)
    del emissions

    print("🔄 获取时间跨度 (容错模式)...")
    spans = _relaxed_get_spans(tokens_starred, segments_raw, blank_token)
    results = _safe_postprocess_results(text_starred, spans, stride, scores)
    token_timestamps = [(seg["text"], seg["start"], seg["end"]) for seg in results]
    print(f"模型输出 {len(token_timestamps)} 个词/字级时间戳")

    segments = merge_token_timestamps_to_sentences(token_timestamps, target_sentences, debug=True)
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    return segments
| # ================== Qwen3 对齐封装 ================== | |
def _qwen_items_to_tuples(items) -> List[Tuple[str, float, float]]:
    """Normalise Qwen3 alignment items into ``(text, start, end)`` tuples.

    Items with ``text``/``start_time``/``end_time`` attributes are used
    directly. Otherwise the item itself (if it is a ``dict``) or its attribute
    dict is searched for common alternative key names, defaulting to
    ``("", 0.0, 0.0)``.
    """
    token_data = []
    for item in items:
        try:
            triple = (item.text, item.start_time, item.end_time)
        except AttributeError:
            if isinstance(item, dict):
                d = item
            elif hasattr(item, "__dict__"):
                d = vars(item)
            else:
                d = {}
            triple = (
                d.get("text", d.get("token", d.get("word", ""))),
                d.get("start_time", d.get("start", 0.0)),
                d.get("end_time", d.get("end", 0.0)),
            )
        token_data.append(triple)
    return token_data


def _match_chunk_sentences(sentences: List[str], token_data) -> List[Dict]:
    """Greedily assign tokens to *sentences* by pure-character count.

    Sentences without countable characters, and sentences left over once the
    tokens run out, are omitted.

    Returns:
        ``{"idx", "text", "start", "end"}`` dicts in order. ``idx`` is the
        sentence's position in *sentences*, and times are relative to the
        chunk.
    """
    matched = []
    ti = 0
    n_tokens = len(token_data)
    for idx, sentence in enumerate(sentences):
        t_len = get_pure_text_length(sentence)
        if t_len == 0:
            continue
        acc = 0
        st, et = None, None
        while ti < n_tokens and acc < t_len:
            seg_text, seg_start, seg_end = token_data[ti]
            if st is None:
                st = seg_start
            et = seg_end
            acc += get_pure_text_length(seg_text)
            ti += 1
        if st is not None and et is not None:
            matched.append({"idx": idx, "text": sentence, "start": st, "end": et})
    return matched


def _qwen_align_chunks(model, audio_path: str, target_sentences: List[str], language: str) -> List[Dict]:
    """Align *target_sentences* to the audio in windows of at most 4 minutes.

    Each window is aligned against all still-unaligned text. In non-final
    windows, only sentences ending before the last 15 s are kept (at least
    one). The next window starts where the last kept sentence ends.
    """
    MAX_CHUNK_DUR = 240.0      # at most 4 minutes of audio per align() call
    SAFE_TAIL_MARGIN = 15.0    # discard possibly-incomplete sentences in the last 15 s

    audio_data, sr = sf.read(audio_path)
    total_duration = len(audio_data) / sr
    print(f"📊 音频总时长: {total_duration:.1f}s")

    remaining = list(target_sentences)
    time_offset = 0.0
    all_segments = []
    chunk_idx = 0
    while remaining and time_offset < total_duration:
        chunk_idx += 1
        chunk_dur = min(MAX_CHUNK_DUR, total_duration - time_offset)
        is_last = time_offset + chunk_dur >= total_duration - 1.0
        chunk_audio = audio_data[int(time_offset * sr):int((time_offset + chunk_dur) * sr)]
        chunk_text = " ".join(remaining)
        print(f"\n▶️ Chunk {chunk_idx}: 音频[{time_offset:.0f}s-{time_offset + chunk_dur:.0f}s] "
              f"剩余{len(remaining)}句")

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            chunk_path = f.name
        try:
            sf.write(chunk_path, chunk_audio, sr)
            results = model.align(audio=chunk_path, text=chunk_text, language=language)
        finally:
            # Remove the chunk file even when writing or aligning fails.
            os.unlink(chunk_path)

        token_data = _qwen_items_to_tuples(results[0])
        matched = _match_chunk_sentences(remaining, token_data)

        if is_last:
            valid = matched
            remaining = []
        else:
            cutoff = chunk_dur - SAFE_TAIL_MARGIN
            n_valid = 0
            for m in matched:
                if m["end"] >= cutoff:
                    break
                n_valid += 1
            if n_valid == 0 and matched:
                n_valid = 1
            valid = matched[:n_valid]
            # Slice by the index in `remaining`, not the position in `matched`:
            # `matched` skips punctuation-only sentences, and slicing by its
            # position would re-align already-emitted sentences next chunk.
            remaining = remaining[valid[-1]["idx"] + 1:] if valid else []

        print(f" 本段对齐 {len(valid)} 句(共{len(matched)}句匹配)")
        for m in valid:
            all_segments.append({
                "text": m["text"],
                "start": round(m["start"] + time_offset, 3),
                "end": round(m["end"] + time_offset, 3),
            })

        if valid:
            time_offset = time_offset + valid[-1]["end"]
        else:
            time_offset = total_duration
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    return all_segments


def run_qwen_alignment(
    audio_path: str,
    full_text: str,
    target_sentences: List[str],
    language: str = "Chinese"
) -> List[Dict]:
    """Force-align sentences to audio with Qwen3-ForcedAligner-0.6B.

    Works on CPU (float32) or GPU (bfloat16). Long audio is processed in
    sliding windows (see ``_qwen_align_chunks``). Sentences without countable
    characters are not emitted.

    Args:
        audio_path: Path of a WAV file readable by ``soundfile``.
        full_text: Unused; kept so the signature mirrors ``run_ctc_alignment``.
        target_sentences: Sentences to time, in order.
        language: Language name understood by the model (e.g. ``"English"``).

    Returns:
        ``{"text", "start", "end"}`` dicts with absolute times in seconds.
    """
    if DEVICE == "cuda":
        dtype = torch.bfloat16
        device_map = "cuda:0"
    else:
        dtype = torch.float32  # per the original note: no bfloat16 inference on CPU
        device_map = "cpu"

    print(f"🚀 加载 Qwen3-ForcedAligner-0.6B (设备: {DEVICE}, dtype: {dtype})")
    model = Qwen3ForcedAligner.from_pretrained(
        "Qwen/Qwen3-ForcedAligner-0.6B",
        dtype=dtype,
        device_map=device_map,
    )
    try:
        all_segments = _qwen_align_chunks(model, audio_path, target_sentences, language)
    finally:
        # Release the model even if alignment fails part-way.
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    print(f"\n✅ Qwen 对齐完成:{len(all_segments)} 句")
    return all_segments
| # ================== 音频格式转换 ================== | |
def convert_to_wav(input_audio_path: str) -> str:
    """Convert an audio file to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    Args:
        input_audio_path: Any file ffmpeg can decode.

    Returns:
        Path of a new temporary ``.wav`` file; the caller must delete it.

    Raises:
        RuntimeError: If ffmpeg fails or cannot be executed. The temporary
            output file is removed before raising.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
        out_path = tmp_wav.name
    cmd = [
        "ffmpeg", "-y",
        "-i", input_audio_path,
        "-ar", "16000",
        "-ac", "1",
        "-c:a", "pcm_s16le",
        "-loglevel", "error",
        out_path,
    ]
    succeeded = False
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        succeeded = True
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"FFmpeg 转换失败: {e.stderr.decode('utf-8', errors='ignore')}") from e
    except OSError as e:  # e.g. the ffmpeg executable is not installed
        raise RuntimeError(f"FFmpeg 转换失败: {e}") from e
    finally:
        # Don't leak the placeholder/partial output when conversion failed.
        if not succeeded and os.path.exists(out_path):
            os.unlink(out_path)
    return out_path
| # ================== 主处理函数 ================== | |
# UI language label -> ISO 639-3 code for the CTC aligner.
_CTC_LANG_CODES = {
    "中文": "cmn", "英文": "eng", "日语": "jpn",
    "韩语": "kor", "法语": "fra", "德语": "deu",
    "俄语": "rus", "西班牙语": "spa", "意大利语": "ita",
    "葡萄牙语": "por",
}

# UI language label -> English language name for Qwen3-ForcedAligner.
_QWEN_LANG_NAMES = {
    "中文": "Chinese",
    "英文": "English",
    "日语": "Japanese",
    "韩语": "Korean",
    "法语": "French",
    "德语": "German",
    "俄语": "Russian",
    "西班牙语": "Spanish",
    "意大利语": "Italian",
    "葡萄牙语": "Portuguese",
}


def _upload_path(upload) -> str:
    """Return the filesystem path of a Gradio upload (str, dict or file-like)."""
    if isinstance(upload, str):
        return upload
    if isinstance(upload, dict):
        return upload.get("name") or upload.get("path") or ""
    return getattr(upload, "name", "")


def _read_uploaded_text(text_file, debug_lines: list) -> str:
    """Best-effort read of an uploaded UTF-8 text file; problems go to *debug_lines*."""
    if text_file is None:
        return ""
    try:
        file_path = _upload_path(text_file)
        if file_path and os.path.exists(file_path):
            # utf-8-sig strips a BOM so it cannot leak into the first subtitle.
            with open(file_path, "r", encoding="utf-8-sig") as f:
                raw_text = f.read()
            debug_lines.append(f"从文件读取文本 ({len(raw_text)} 字符)")
            return raw_text
    except Exception as e:  # deliberate best-effort: fall back to the textbox
        debug_lines.append(f"读取文本文件失败: {e}")
    return ""


def _write_srt_file(audio_file: str, srt_content: str) -> str:
    """Write *srt_content* as ``<audio stem>.srt`` in a fresh temp directory."""
    stem = os.path.splitext(os.path.basename(audio_file))[0]
    # A per-request directory keeps the user-friendly filename while avoiding
    # collisions between concurrent requests whose audio shares a basename.
    out_dir = tempfile.mkdtemp(prefix="srt_")
    srt_path = os.path.join(out_dir, stem + ".srt")
    with open(srt_path, "w", encoding="utf-8") as f:
        f.write(srt_content)
    return srt_path


def process_alignment(
    audio_file,
    text_input: str,
    text_file,
    language: str,
    model_choice: str
):
    """Gradio callback: align *audio_file* with the text and build an SRT.

    Args:
        audio_file: Path of the uploaded audio, or ``None``.
        text_input: Text typed in the UI, one short sentence per line.
        text_file: Optional uploaded ``.txt`` (path, dict with ``name``/
            ``path``, or object with ``.name``). Its non-empty content takes
            precedence over *text_input*.
        language: UI language label. Unknown labels fall back to ``cmn``
            (CTC) / ``English`` (Qwen3).
        model_choice: ``"CTC Forced Aligner"``; any other value selects Qwen3.

    Returns:
        ``(srt_text, status_message, debug_log, srt_file_path_or_None)``.
    """
    debug_lines = []
    if audio_file is None:
        return "", "请上传音频文件", "", None

    raw_text = _read_uploaded_text(text_file, debug_lines)
    if not raw_text and text_input:
        raw_text = text_input
    if not raw_text or not raw_text.strip():
        return "", "请输入文本或上传文本文件", "", None

    target_sentences = [line.strip() for line in raw_text.strip().splitlines() if line.strip()]
    if not target_sentences:
        return "", "文本为空或格式不正确(每行一个短句)", "", None
    full_text = " ".join(target_sentences)

    lang = _CTC_LANG_CODES.get(language, "cmn")
    qwen_lang = _QWEN_LANG_NAMES.get(language, "English")

    debug_lines.append(f"音频: {audio_file}")
    debug_lines.append(f"语言: {language} (内部代码: {lang})")
    debug_lines.append(f"选用模型: {model_choice}")
    debug_lines.append(f"句子数: {len(target_sentences)}")

    try:
        processed_audio_path = convert_to_wav(audio_file)
        debug_lines.append("✅ 音频格式转换完成")
    except Exception as e:
        debug_lines.append(f"❌ 音频转码失败: {e}")
        return "", "音频转码失败,请上传有效文件", "\n".join(debug_lines), None

    try:
        if model_choice == "CTC Forced Aligner":
            if not CTC_AVAILABLE:
                return "", "CTC 模型未安装,请检查依赖。", "\n".join(debug_lines), None
            segments = run_ctc_alignment(processed_audio_path, full_text, target_sentences, lang)
        else:  # Qwen3
            if not QWEN_AVAILABLE:
                return "", "Qwen3 模型未安装,请检查依赖。", "\n".join(debug_lines), None
            segments = run_qwen_alignment(processed_audio_path, full_text, target_sentences, qwen_lang)

        srt_content = format_srt(segments)
        debug_lines.append(f"\n🎉 对齐完成! 共 {len(segments)} 段")
        for seg in segments[:15]:
            debug_lines.append(f" [{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text'][:60]}")
        if len(segments) > 15:
            debug_lines.append(f" ... 共 {len(segments)} 段")

        srt_full_path = _write_srt_file(audio_file, srt_content)
        return srt_content, f"对齐完成! 共 {len(segments)} 段", "\n".join(debug_lines), srt_full_path
    except Exception as e:  # UI boundary: report the error instead of crashing the app
        debug_lines.append(f"\n❌ 错误: {e}\n{traceback.format_exc()}")
        return "", f"处理出错: {str(e)}", "\n".join(debug_lines), None
    finally:
        # Runs on every path, including the "model not installed" early returns.
        if os.path.exists(processed_audio_path):
            os.unlink(processed_audio_path)
| # ================== Gradio 界面 ================== | |
# Web UI. The left column holds the inputs (audio, text or .txt file,
# language, aligner) and the status box; the right column holds the SRT
# preview, the downloadable file and a collapsible debug log. The button
# passes the five inputs to process_alignment, whose 4-tuple return fills
# the four output components in order.
with gr.Blocks(title="字幕自动打轴工具(双模型)") as demo:
    gr.Markdown("""
# 字幕自动打轴工具(支持双模型)
将音频与文本自动对齐,生成带精准时间轴的 SRT 字幕文件。
""")
    with gr.Row():
        with gr.Column(scale=2):
            # type="filepath": process_alignment receives the upload as a path string.
            audio_input = gr.Audio(label="音频文件", type="filepath")
            # One short sentence per line; each line becomes one subtitle cue.
            text_input = gr.Textbox(
                label="文本内容(每行一个短句)",
                placeholder="今天天气真好。\n我们一起去公园吧。",
                lines=8, max_lines=20
            )
            # A non-empty uploaded file takes precedence over the textbox.
            text_file = gr.File(label="或上传文本文件 (.txt)", file_types=[".txt"])
            # These labels must match the keys of the language maps in process_alignment.
            language_choice = gr.Dropdown(
                label="音频语言",
                choices=["中文", "英文", "日语", "韩语", "法语", "德语", "俄语", "西班牙语", "意大利语", "葡萄牙语"],
                value="英文"
            )
            # process_alignment compares against "CTC Forced Aligner"; anything else selects Qwen3.
            model_choice = gr.Dropdown(
                label="对齐模型",
                choices=["CTC Forced Aligner", "Qwen3-ForcedAligner-0.6B"],
                value="CTC Forced Aligner"
            )
            submit_btn = gr.Button("开始对齐", variant="primary")
            status_output = gr.Textbox(label="状态", interactive=False)
        with gr.Column(scale=2):
            # elem_classes hooks this box to the monospace ".srt-output" CSS rule below.
            srt_output = gr.Textbox(
                label="生成的 SRT 字幕",
                lines=18, max_lines=30, interactive=False,
                elem_classes=["srt-output"]
            )
            srt_download = gr.File(label="下载 SRT 文件", interactive=False)
            with gr.Accordion("调试信息", open=False):
                debug_output = gr.Textbox(label="详细日志", lines=12, interactive=False)
    # Output order must match process_alignment's return tuple.
    submit_btn.click(
        fn=process_alignment,
        inputs=[audio_input, text_input, text_file, language_choice, model_choice],
        outputs=[srt_output, status_output, debug_output, srt_download]
    )
if __name__ == "__main__":
    # Queue holds at most 5 pending requests; listen on all interfaces on port 7860.
    # The CSS gives the SRT box a monospace font and hides the page footer.
    demo.queue(max_size=5).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        css="""
.srt-output textarea { font-family: "Courier New", monospace; font-size: 13px; }
footer { visibility: hidden; }
"""
    )