"""Gradio tool that force-aligns an audio file with line-per-sentence text and emits SRT subtitles.

Two optional alignment backends are supported: CTC Forced Aligner and
Qwen3-ForcedAligner-0.6B. Each is used only if its package imports successfully.
"""
import os
import re
import sys
import tempfile
import subprocess
import traceback
from typing import Dict, List, Optional, Tuple

import torch
import soundfile as sf
import gradio as gr

# ---------- Device detection ----------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"运行设备: {DEVICE}")

# ---------- Optional backend: CTC Forced Aligner ----------
try:
    import ctc_forced_aligner
    from ctc_forced_aligner import (
        load_audio,
        load_alignment_model,
        generate_emissions,
        preprocess_text,
        get_alignments,
    )
    import ctc_forced_aligner.alignment_utils as ctc_au
    import ctc_forced_aligner.text_utils as ctc_tu
    CTC_AVAILABLE = True
    print("✅ CTC Forced Aligner 已就绪")
except ImportError:
    CTC_AVAILABLE = False
    print("⚠️ CTC Forced Aligner 不可用")

# ---------- Optional backend: Qwen3-ForcedAligner ----------
QWEN_AVAILABLE = False
try:
    from qwen_asr import Qwen3ForcedAligner
    QWEN_AVAILABLE = True
    print("✅ Qwen3-ForcedAligner 已就绪")
except ImportError:
    print("⚠️ Qwen3-ForcedAligner 不可用")


# ================== Small utilities ==================

def _remove_quietly(path: Optional[str]) -> None:
    """Best-effort delete of a temporary file; a missing path or an OS error is ignored.

    Used from cleanup paths (``finally`` blocks), where raising would mask the
    real result or error of the surrounding operation.
    """
    if not path:
        return
    try:
        os.unlink(path)
    except OSError:
        pass


# ================== Core algorithm ==================

# Everything that is NOT a word character or in the listed CJK / kana ranges
# (i.e. punctuation, whitespace, control characters) is stripped before counting.
_NON_TEXT_CHARS = re.compile(r'[^\w一-鿿぀-ゟ゠-ヿ]')


def get_pure_text_length(text: str) -> int:
    """Return the number of content characters in *text*.

    Punctuation, whitespace and control characters are removed first (see
    ``_NON_TEXT_CHARS``). Sentences and aligner tokens are both measured with
    this function, so their counts can be compared directly.
    """
    return len(_NON_TEXT_CHARS.sub('', str(text)).lower())


def _consume_tokens(
    tokens: List[Tuple[str, float, float]],
    start: int,
    target_len: int,
) -> Tuple[Optional[float], Optional[float], int, int]:
    """Consume ``(text, start, end)`` tokens from index *start* until *target_len* characters are covered.

    Returns:
        ``(first_start, last_end, next_index, consumed_len)``. The two times are
        ``None`` when no token was left to consume.
    """
    consumed = 0
    first_start: Optional[float] = None
    last_end: Optional[float] = None
    idx = start
    while idx < len(tokens) and consumed < target_len:
        tok_text, tok_start, tok_end = tokens[idx]
        if first_start is None:
            first_start = tok_start
        last_end = tok_end
        consumed += get_pure_text_length(tok_text)
        idx += 1
    return first_start, last_end, idx, consumed


def merge_token_timestamps_to_sentences(
    token_timestamps: List[Tuple[str, float, float]],
    target_sentences: List[str],
    debug: bool = False
) -> List[Dict]:
    """Map word/character-level timestamps onto pre-segmented sentences.

    Tokens are consumed in order. Each sentence takes tokens until their
    cumulative content-character count reaches the sentence's own count. Its
    span then runs from the first consumed token's start to the last
    consumed token's end.

    A sentence that received no tokens (it has no content characters, or the
    tokens ran out) gets a zero-length span at the previous sentence's end.
    If there is no previous timestamp, the next sentence's start is used.
    Finally, spans are made non-overlapping and every end is forced to be
    at least its start (+1 ms when it was earlier).

    Args:
        token_timestamps: ``(text, start_sec, end_sec)`` tuples in audio order.
        target_sentences: Subtitle lines, in the same order as the tokens.
        debug: Print per-sentence details for the first 5 sentences and a
            coverage summary.

    Returns:
        One ``{"text", "start", "end"}`` dict per input sentence, in order.
    """
    if not target_sentences:
        return []

    results: List[Dict] = []
    # Tracked explicitly so that a real alignment at 0.0 s is not mistaken
    # for a gap.
    missing: List[bool] = []
    token_idx = 0

    for sent_idx, sentence in enumerate(target_sentences):
        t_len = get_pure_text_length(sentence)
        if t_len == 0:
            results.append({"text": sentence, "start": 0.0, "end": 0.0})
            missing.append(True)
            continue

        st, et, token_idx, acc_len = _consume_tokens(token_timestamps, token_idx, t_len)

        if debug and sent_idx < 5:
            # Only the timing part is optional; the header is always printed,
            # including for sentences that start at 0.0 s.
            timing = f"time={st:.2f}s-{et:.2f}s " if st is not None else ""
            print(f" [{sent_idx}] \"{sentence[:50]}\" -> "
                  f"tokens char_cnt={acc_len}/{t_len} " + timing)

        if st is not None and et is not None:
            results.append({
                "text": sentence,
                "start": round(st, 3),
                "end": round(et, 3),
            })
            missing.append(False)
        else:
            results.append({"text": sentence, "start": 0.0, "end": 0.0})
            missing.append(True)

    # Post-processing, step 1: give sentences without tokens a zero-length span
    # next to the nearest neighbour that has a positive timestamp.
    for i, res in enumerate(results):
        if not missing[i]:
            continue
        prev_end = next(
            (results[j]["end"] for j in range(i - 1, -1, -1) if results[j]["end"] > 0),
            None,
        )
        if prev_end is not None:
            res["start"] = res["end"] = prev_end
            continue
        next_start = next(
            (results[j]["start"] for j in range(i + 1, len(results)) if results[j]["start"] > 0),
            None,
        )
        if next_start is not None:
            res["start"] = res["end"] = next_start

    # Post-processing, step 2: enforce monotonic, non-inverted spans.
    for i, res in enumerate(results):
        if i > 0 and res["start"] < results[i - 1]["end"]:
            res["start"] = results[i - 1]["end"]
        if res["end"] < res["start"]:
            res["end"] = res["start"] + 0.001

    if debug:
        non_zero = sum(1 for r in results if r["start"] > 0 or r["end"] > 0)
        print(f"时间戳覆盖率: {non_zero}/{len(results)} 句")
    return results


def seconds_to_srt_time(seconds: float) -> str:
    """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``; negative input is clamped to 0.

    The value is rounded to whole milliseconds before splitting. Deriving the
    milliseconds from ``seconds % 1`` would truncate float error, e.g. turning
    2.3 s into ``,299``.
    """
    total_ms = int(round(max(0.0, seconds) * 1000))
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def format_srt(segments: List[Dict]) -> str:
    """Render ``{"text", "start", "end"}`` segments as SRT text.

    Segments whose text is blank after ``strip()`` are skipped, and cue
    numbers stay consecutive.
    """
    lines = []
    index = 1
    for seg in segments:
        text = seg["text"].strip()
        if not text:
            continue
        lines.append(str(index))
        lines.append(
            f"{seconds_to_srt_time(seg['start'])} --> {seconds_to_srt_time(seg['end'])}"
        )
        lines.append(text)
        lines.append("")
        index += 1
    return "\n".join(lines)


# ================== CTC alignment wrapper (with tolerance patches) ==================

def run_ctc_alignment(
    audio_path: str,
    full_text: str,
    target_sentences: List[str],
    language: str = "eng"
) -> List[Dict]:
    """Force-align *full_text* to the audio with CTC Forced Aligner and return per-sentence spans.

    For the duration of the call, ``ctc_au.get_spans`` and
    ``ctc_tu.postprocess_results`` are replaced with the tolerant versions
    defined below. The originals are restored in a ``finally`` block, and the
    model is released there too, even if alignment raises.

    Args:
        audio_path: Audio file passed to ``load_audio``. In this app it is the
            16 kHz mono WAV produced by ``convert_to_wav``.
        full_text: All target sentences joined with spaces.
        target_sentences: One subtitle line per entry.
        language: Code forwarded to ``preprocess_text``. Codes in
            ``non_latin`` below turn on romanization.

    Returns:
        One ``{"text", "start", "end"}`` dict per target sentence, built by
        ``merge_token_timestamps_to_sentences``.
    """
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32

    _original_get_spans = ctc_au.get_spans
    _original_postprocess = ctc_tu.postprocess_results

    # Text entries that contain no spoken words and must not be fed to the
    # character-count matcher.
    # NOTE(review): "<star>" is assumed to be the wildcard entry that
    # ctc_forced_aligner.preprocess_text interleaves with the words; confirm
    # against the installed library version.
    skip_text = {"", "<star>"}

    def _relaxed_get_spans(tokens_starred, segments, blank_token):
        """Greedy span search that tolerates tokens whose letters the decoder did not emit.

        Blank segments are skipped. A letter that does not match the current
        segment is simply not consumed. A token with no matched letter gets a
        single placeholder ``Segment`` at the current position.
        """
        n_seg = len(segments)
        spans = []
        si = 0
        for token in tokens_starred:
            target_letters = token.split(" ")
            while si < n_seg and segments[si].label == blank_token:
                si += 1
            start_seg_idx = si
            end_seg_idx = si
            matched_any = False
            for ltr in target_letters:
                while si < n_seg and segments[si].label == blank_token:
                    si += 1
                if si < n_seg and segments[si].label == ltr:
                    if not matched_any:
                        start_seg_idx = si
                    end_seg_idx = si
                    matched_any = True
                    si += 1
            if not matched_any:
                safe_idx = min(start_seg_idx, n_seg - 1) if n_seg > 0 else 0
                spans.append([ctc_au.Segment(token, safe_idx, safe_idx)])
            else:
                spans.append(segments[start_seg_idx : end_seg_idx + 1])
        return spans

    def _safe_postprocess_results(text_starred, spans, stride, scores, merge_threshold=0.0):
        """Turn spans into ``{"start", "end", "text", "score"}`` dicts (seconds).

        *stride* is treated as milliseconds per emission frame. Entries in
        ``skip_text`` and empty spans are dropped.
        """
        results = []
        for i, t in enumerate(text_starred):
            if t in skip_text:
                continue
            span = spans[i]
            if not span:
                continue
            seg_start_idx = span[0].start
            seg_end_idx = span[-1].end
            audio_start_sec = seg_start_idx * stride / 1000.0
            audio_end_sec = seg_end_idx * stride / 1000.0
            score = scores[seg_start_idx : seg_end_idx + 1].sum() if seg_end_idx >= seg_start_idx else 0.0
            score_val = score.item() if hasattr(score, "item") else float(score)
            results.append({
                "start": audio_start_sec,
                "end": audio_end_sec,
                "text": t,
                "score": score_val,
            })
        ctc_tu.merge_segments(results, merge_threshold)
        return results

    ctc_au.get_spans = _relaxed_get_spans
    ctc_tu.postprocess_results = _safe_postprocess_results
    alignment_model = None
    try:
        print(f"🚀 加载 CTC 对齐模型 (设备: {DEVICE})")
        alignment_model, alignment_tokenizer = load_alignment_model(DEVICE, dtype=dtype)

        print("🔄 加载音频...")
        audio_waveform = load_audio(audio_path, alignment_model.dtype, alignment_model.device)

        print("🔄 生成发射矩阵...")
        emissions, stride = generate_emissions(alignment_model, audio_waveform, batch_size=8)

        non_latin = {"cmn", "zho", "chi", "jpn", "ja", "kor", "ko", "ara", "ar", "rus", "ru"}
        needs_romanize = language in non_latin
        tokens_starred, text_starred = preprocess_text(full_text, romanize=needs_romanize, language=language)

        print("🔄 CTC 解码...")
        segments_raw, scores, blank_token = get_alignments(emissions, tokens_starred, alignment_tokenizer)

        print("🔄 获取时间跨度 (容错模式)...")
        spans = ctc_au.get_spans(tokens_starred, segments_raw, blank_token)
        results = ctc_tu.postprocess_results(text_starred, spans, stride, scores)

        token_timestamps = [(seg["text"], seg["start"], seg["end"]) for seg in results]
        print(f"模型输出 {len(token_timestamps)} 个词/字级时间戳")

        return merge_token_timestamps_to_sentences(token_timestamps, target_sentences, debug=True)
    finally:
        # Always undo the global monkey-patches and drop the model reference,
        # even when alignment fails part-way.
        ctc_au.get_spans = _original_get_spans
        ctc_tu.postprocess_results = _original_postprocess
        alignment_model = None
        if DEVICE == "cuda":
            torch.cuda.empty_cache()


# ================== Qwen3 alignment wrapper ==================

def _qwen_item_to_tuple(item) -> Tuple[str, float, float]:
    """Convert one Qwen3ForcedAligner result item into ``(text, start, end)``.

    Attributes ``text``/``start_time``/``end_time`` are tried first. If they
    are missing, the item's fields are searched instead: the dict itself, or
    ``vars()`` of an object. Alternative key names are tried there, with
    defaults of ``''`` and ``0.0``.
    NOTE(review): the fallback key names are guesses at the result schema,
    not confirmed against qwen_asr.
    """
    try:
        return (item.text, item.start_time, item.end_time)
    except AttributeError:
        if isinstance(item, dict):
            fields = item
        elif hasattr(item, "__dict__"):
            fields = vars(item)
        else:
            fields = {}
        return (
            fields.get('text', fields.get('token', fields.get('word', ''))),
            fields.get('start_time', fields.get('start', 0.0)),
            fields.get('end_time', fields.get('end', 0.0)),
        )


def _match_sentences_in_order(
    sentences: List[str],
    tokens: List[Tuple[str, float, float]],
) -> List[Tuple[int, Dict]]:
    """Assign tokens to *sentences* by character count.

    Returns:
        ``(index_in_sentences, {"text", "start", "end"})`` pairs. Sentences
        with no content characters, or with no tokens left, are omitted. The
        original index is kept so callers can slice *sentences* correctly
        despite those gaps.
    """
    matched: List[Tuple[int, Dict]] = []
    ti = 0
    for idx, sentence in enumerate(sentences):
        t_len = get_pure_text_length(sentence)
        if t_len == 0:
            continue
        st, et, ti, _ = _consume_tokens(tokens, ti, t_len)
        if st is not None and et is not None:
            matched.append((idx, {"text": sentence, "start": st, "end": et}))
    return matched


def run_qwen_alignment(
    audio_path: str,
    full_text: str,
    target_sentences: List[str],
    language: str = "Chinese"
) -> List[Dict]:
    """Force-align sentences to audio with Qwen3-ForcedAligner-0.6B, chunk by chunk.

    Uses bfloat16 on CUDA and float32 on CPU. The audio is processed in
    windows of at most 240 s, each aligned against all remaining sentences.
    For every window except the last, only the sentences ending more than
    15 s before the window end are accepted, with at least one per window
    so that progress is guaranteed. The next window starts where the last
    accepted sentence ended.

    Args:
        audio_path: Audio file readable by ``soundfile``.
        full_text: Unused; kept for signature parity with ``run_ctc_alignment``.
        target_sentences: One subtitle line per entry.
        language: Language name passed to ``model.align``.

    Returns:
        ``{"text", "start", "end"}`` dicts with absolute times in seconds.
        Sentences with no content characters, or left unmatched, are omitted.
    """
    # Choose precision and device_map for the current device.
    if DEVICE == "cuda":
        dtype = torch.bfloat16
        device_map = "cuda:0"
    else:
        dtype = torch.float32  # bfloat16 inference is not used on CPU
        device_map = "cpu"

    print(f"🚀 加载 Qwen3-ForcedAligner-0.6B (设备: {DEVICE}, dtype: {dtype})")
    model = Qwen3ForcedAligner.from_pretrained(
        "Qwen/Qwen3-ForcedAligner-0.6B",
        dtype=dtype,
        device_map=device_map,
    )

    try:
        audio_data, sr = sf.read(audio_path)
        total_duration = len(audio_data) / sr
        print(f"📊 音频总时长: {total_duration:.1f}s")

        # Chunking parameters.
        MAX_CHUNK_DUR = 240.0      # at most 4 minutes of audio per align() call
        SAFE_TAIL_MARGIN = 15.0    # sentences ending in the last 15 s are retried in the next chunk

        remaining = list(target_sentences)
        time_offset = 0.0
        all_segments: List[Dict] = []
        chunk_idx = 0

        while remaining and time_offset < total_duration:
            chunk_idx += 1
            chunk_dur = min(MAX_CHUNK_DUR, total_duration - time_offset)
            is_last = (time_offset + chunk_dur >= total_duration - 1.0)

            start_f = int(time_offset * sr)
            end_f = int((time_offset + chunk_dur) * sr)
            chunk_audio = audio_data[start_f:end_f]
            chunk_text = " ".join(remaining)

            print(f"\n▶️ Chunk {chunk_idx}: 音频[{time_offset:.0f}s-{time_offset + chunk_dur:.0f}s] "
                  f"剩余{len(remaining)}句")

            # Write the chunk to a closed temp file (reopenable on every OS)
            # and always delete it, even if align() raises.
            fd, chunk_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                sf.write(chunk_path, chunk_audio, sr)
                results = model.align(audio=chunk_path, text=chunk_text, language=language)
            finally:
                _remove_quietly(chunk_path)

            token_data = [_qwen_item_to_tuple(item) for item in results[0]]
            matched = _match_sentences_in_order(remaining, token_data)

            # Choose a safe cut point.
            if is_last:
                valid = matched
                remaining = []
            else:
                cutoff = chunk_dur - SAFE_TAIL_MARGIN
                n_valid = 0
                for _, m in matched:
                    if m["end"] >= cutoff:
                        break
                    n_valid += 1
                if n_valid == 0 and matched:
                    n_valid = 1  # always advance by at least one sentence
                valid = matched[:n_valid]
                # Slice by the sentence's index in `remaining`, not its position
                # in `matched`: skipped sentences would otherwise shift the cut
                # and re-align already accepted sentences.
                remaining = remaining[valid[-1][0] + 1:] if valid else []

            print(f" 本段对齐 {len(valid)} 句(共{len(matched)}句匹配)")
            for _, m in valid:
                all_segments.append({
                    "text": m["text"],
                    "start": round(m["start"] + time_offset, 3),
                    "end": round(m["end"] + time_offset, 3),
                })

            if valid:
                time_offset = time_offset + valid[-1][1]["end"]
            else:
                time_offset = total_duration

            if DEVICE == "cuda":
                torch.cuda.empty_cache()
    finally:
        model = None
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    print(f"\n✅ Qwen 对齐完成:{len(all_segments)} 句")
    return all_segments


# ================== Audio format conversion ==================

def convert_to_wav(input_audio_path: str) -> str:
    """Convert any ffmpeg-readable audio to a 16 kHz mono PCM-16 WAV temp file and return its path.

    The caller owns the returned file and must delete it. If ffmpeg fails,
    the temp file is removed and ``RuntimeError`` (with ffmpeg's stderr) is
    raised. If ffmpeg cannot be started (e.g. not on PATH), the temp file is
    removed and the ``OSError`` propagates.
    """
    tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_wav.close()
    cmd = [
        "ffmpeg", "-y", "-i", input_audio_path,
        "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
        "-loglevel", "error", tmp_wav.name
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        _remove_quietly(tmp_wav.name)
        raise RuntimeError(f"FFmpeg 转换失败: {e.stderr.decode('utf-8', errors='ignore')}") from e
    except OSError:
        _remove_quietly(tmp_wav.name)
        raise
    return tmp_wav.name


# ================== Main request handler ==================

# UI language label -> language code for ctc_forced_aligner.
CTC_LANG_CODES = {
    "中文": "cmn", "英文": "eng", "日语": "jpn", "韩语": "kor",
    "法语": "fra", "德语": "deu", "俄语": "rus", "西班牙语": "spa",
    "意大利语": "ita", "葡萄牙语": "por",
}

# UI language label -> language name expected by Qwen3ForcedAligner.align().
QWEN_LANG_NAMES = {
    "中文": "Chinese", "英文": "English", "日语": "Japanese", "韩语": "Korean",
    "法语": "French", "德语": "German", "俄语": "Russian", "西班牙语": "Spanish",
    "意大利语": "Italian", "葡萄牙语": "Portuguese",
}


def process_alignment(
    audio_file,
    text_input: str,
    text_file,
    language: str,
    model_choice: str
):
    """Gradio click handler: align the uploaded audio with the text and produce an SRT.

    Text comes from *text_file* if it can be read; otherwise *text_input* is
    used. Every non-blank line becomes one subtitle sentence.

    Returns:
        ``(srt_text, status_message, debug_log, srt_file_path_or_None)``.
        On failure the SRT text is empty and the path is ``None``.
    """
    debug_lines = []

    if audio_file is None:
        return "", "请上传音频文件", "", None

    # Read the text: an uploaded file takes precedence over the textbox.
    raw_text = ""
    if text_file is not None:
        try:
            if isinstance(text_file, str):
                file_path = text_file
            elif isinstance(text_file, dict):
                file_path = text_file.get("name", "")
            else:
                file_path = getattr(text_file, "name", "")
            if file_path and os.path.exists(file_path):
                # utf-8-sig strips a leading BOM so it does not end up in the first subtitle line.
                with open(file_path, "r", encoding="utf-8-sig") as f:
                    raw_text = f.read()
                debug_lines.append(f"从文件读取文本 ({len(raw_text)} 字符)")
        except Exception as e:
            # Best effort: fall back to the textbox and record why.
            debug_lines.append(f"读取文本文件失败: {e}")

    if not raw_text and text_input:
        raw_text = text_input

    if not raw_text or not raw_text.strip():
        return "", "请输入文本或上传文本文件", "", None

    target_sentences = [line.strip() for line in raw_text.strip().splitlines() if line.strip()]
    if not target_sentences:
        return "", "文本为空或格式不正确(每行一个短句)", "", None

    full_text = " ".join(target_sentences)

    lang = CTC_LANG_CODES.get(language, "cmn")
    # Unknown UI languages fall back to English for Qwen.
    qwen_lang = QWEN_LANG_NAMES.get(language, "English")

    debug_lines.append(f"音频: {audio_file}")
    debug_lines.append(f"语言: {language} (内部代码: {lang})")
    debug_lines.append(f"选用模型: {model_choice}")
    debug_lines.append(f"句子数: {len(target_sentences)}")

    # Convert the audio.
    try:
        processed_audio_path = convert_to_wav(audio_file)
        debug_lines.append("✅ 音频格式转换完成")
    except Exception as e:
        debug_lines.append(f"❌ 音频转码失败: {e}")
        return "", "音频转码失败,请上传有效文件", "\n".join(debug_lines), None

    # Run the selected aligner. The converted WAV is removed in `finally` on
    # every path, including the "model not installed" early returns.
    try:
        if model_choice == "CTC Forced Aligner":
            if not CTC_AVAILABLE:
                return "", "CTC 模型未安装,请检查依赖。", "\n".join(debug_lines), None
            segments = run_ctc_alignment(processed_audio_path, full_text, target_sentences, lang)
        else:  # Qwen3
            if not QWEN_AVAILABLE:
                return "", "Qwen3 模型未安装,请检查依赖。", "\n".join(debug_lines), None
            segments = run_qwen_alignment(processed_audio_path, full_text, target_sentences, qwen_lang)

        srt_content = format_srt(segments)

        debug_lines.append(f"\n🎉 对齐完成! 共 {len(segments)} 段")
        for seg in segments[:15]:
            debug_lines.append(f" [{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text'][:60]}")
        if len(segments) > 15:
            debug_lines.append(f" ... 共 {len(segments)} 段")

        # Save the SRT under the audio file's base name. A fresh per-request
        # directory keeps that name while preventing same-named uploads from
        # overwriting each other's file.
        audio_basename = os.path.basename(audio_file)
        srt_filename = os.path.splitext(audio_basename)[0] + ".srt"
        srt_full_path = os.path.join(tempfile.mkdtemp(), srt_filename)
        with open(srt_full_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

        return srt_content, f"对齐完成! 共 {len(segments)} 段", "\n".join(debug_lines), srt_full_path

    except Exception as e:
        # Top-level UI boundary: report the full traceback in the debug panel.
        error_detail = traceback.format_exc()
        debug_lines.append(f"\n❌ 错误: {e}\n{error_detail}")
        return "", f"处理出错: {str(e)}", "\n".join(debug_lines), None
    finally:
        _remove_quietly(processed_audio_path)


# ================== Gradio UI ==================
with gr.Blocks(title="字幕自动打轴工具(双模型)") as demo:
    gr.Markdown("""
    # 字幕自动打轴工具(支持双模型)
    将音频与文本自动对齐,生成带精准时间轴的 SRT 字幕文件。
    """)

    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(label="音频文件", type="filepath")
            text_input = gr.Textbox(
                label="文本内容(每行一个短句)",
                placeholder="今天天气真好。\n我们一起去公园吧。",
                lines=8,
                max_lines=20
            )
            text_file = gr.File(label="或上传文本文件 (.txt)", file_types=[".txt"])
            language_choice = gr.Dropdown(
                label="音频语言",
                choices=["中文", "英文", "日语", "韩语", "法语", "德语", "俄语", "西班牙语", "意大利语", "葡萄牙语"],
                value="英文"
            )
            model_choice = gr.Dropdown(
                label="对齐模型",
                choices=["CTC Forced Aligner", "Qwen3-ForcedAligner-0.6B"],
                value="CTC Forced Aligner"
            )
            submit_btn = gr.Button("开始对齐", variant="primary")
            status_output = gr.Textbox(label="状态", interactive=False)

        with gr.Column(scale=2):
            srt_output = gr.Textbox(
                label="生成的 SRT 字幕",
                lines=18,
                max_lines=30,
                interactive=False,
                elem_classes=["srt-output"]
            )
            srt_download = gr.File(label="下载 SRT 文件", interactive=False)
            with gr.Accordion("调试信息", open=False):
                debug_output = gr.Textbox(label="详细日志", lines=12, interactive=False)

    submit_btn.click(
        fn=process_alignment,
        inputs=[audio_input, text_input, text_file, language_choice, model_choice],
        outputs=[srt_output, status_output, debug_output, srt_download]
    )

if __name__ == "__main__":
    demo.queue(max_size=5).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        css="""
            .srt-output textarea { font-family: "Courier New", monospace; font-size: 13px; }
            footer { visibility: hidden; }
        """
    )