Spaces:
Running
Running
import os
import tempfile
from types import SimpleNamespace

import streamlit as st
from dotenv import load_dotenv

from stt_module import STTManager
from translator import Translator
from utils import VideoUtils, SettingsManager
# Load configuration from a .env file into the process environment.
load_dotenv()

# Settings persisted locally by SettingsManager; the sidebar prefers them over env vars.
saved_settings = SettingsManager.load_settings()

# NOTE: the upload size limit cannot be changed from inside this script. Streamlit reads it
# from its server configuration (`server.maxUploadSize` in .streamlit/config.toml, or the
# `--server.maxUploadSize` command-line flag).
st.set_page_config(page_title="AI 媒体翻译专家", layout="wide")

# Per-session state, seeded on the first run only; later reruns keep what is already stored.
_SESSION_DEFAULTS = {
    "process_results": {},        # output paths produced by the media-processing tab
    "srt_results": {},            # output path produced by the subtitle-translation tab
    "awaiting_synthesis": False,  # True once subtitles exist and video synthesis is offered
    "processing": False,          # True while a job runs; used to disable the inputs
    "last_uploaded_file": None,   # identity of the last media upload (detects a new file)
    "last_uploaded_srt": None,    # identity of the last SRT upload
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Callback: start media processing
def start_processing_callback():
    """``on_click`` handler for the "开始处理" button.

    Flags a job as running and discards the results and the pending
    synthesis step left over from any earlier run.
    """
    session = st.session_state
    session["processing"] = True
    session["process_results"] = {}
    session["awaiting_synthesis"] = False
# Callback: start video/subtitle synthesis
def start_synthesis_callback():
    """``on_click`` handler for the synthesis button; flags a job as running."""
    st.session_state["processing"] = True
# Callback: start subtitle-only translation
def start_translation_callback():
    """``on_click`` handler for the "开始翻译字幕" button.

    Flags a job as running and drops the previous translation result.
    """
    session = st.session_state
    session["processing"] = True
    session["srt_results"] = {}
st.title("🎬 AI 媒体字幕提取与翻译")

# Sidebar: API, device and model configuration.
with st.sidebar:
    st.header("⚙️ 配置参数")

    # Defaults: locally saved settings first, then environment variables, then built-in fallbacks.
    default_api_key = saved_settings.get("API_KEY", os.getenv("API_KEY", ""))
    default_base_url = saved_settings.get("BASE_URL", os.getenv("BASE_URL", "https://api.siliconflow.cn/v1"))
    default_model = saved_settings.get("MODEL_NAME", os.getenv("MODEL_NAME", "THUDM/glm-4-9b-chat"))
    default_stt_size = saved_settings.get("STT_MODEL_SIZE", os.getenv("STT_MODEL_SIZE", "base"))
    default_lang = saved_settings.get("TARGET_LANG", os.getenv("TARGET_LANG", "中文"))
    default_device = saved_settings.get("DEVICE", os.getenv("DEVICE", "cpu"))

    # Every input is disabled while a job is running.
    api_key = st.text_input("API Key", value=default_api_key, type="password", help="输入 API Key", disabled=st.session_state.processing)
    base_url = st.text_input("API Base URL", value=default_base_url, disabled=st.session_state.processing)
    model_name = st.text_input("模型名称", value=default_model, disabled=st.session_state.processing)

    st.divider()

    # Device selection: "cuda" is offered only when STTManager reports it as available.
    cuda_available = STTManager.is_cuda_available()
    device_options = ["cpu", "cuda"] if cuda_available else ["cpu"]
    device_index = device_options.index(default_device) if default_device in device_options else 0
    device = st.selectbox(
        "计算设备 (Device)",
        device_options,
        index=device_index,
        help="如果有 NVIDIA 显卡且安装了 CUDA,选择 'cuda' 会大幅提升速度。",
        disabled=st.session_state.processing,
    )

    # Preselect the precision recommended for the chosen device (float16 on GPU, int8 on CPU).
    compute_type_options = ["int8", "float16", "int8_float16"]
    default_compute_type = "float16" if device == "cuda" else "int8"
    compute_type = st.selectbox(
        "计算精度 (Compute Type)",
        compute_type_options,
        index=compute_type_options.index(default_compute_type),
        help="CPU 推荐 int8,GPU 推荐 float16。",
        disabled=st.session_state.processing,
    )

    # Local STT model size. Models already on disk get a "(已下载)" label. format_func only
    # changes the displayed label, so the widget still returns the bare model name, and no
    # label parsing is needed.
    stt_options = ["tiny", "base", "small", "medium", "large-v3"]
    downloaded_models = STTManager.get_downloaded_models()
    stt_index = stt_options.index(default_stt_size) if default_stt_size in stt_options else 1
    model_size = st.selectbox(
        "本地 STT 模型大小",
        stt_options,
        index=stt_index,
        format_func=lambda opt: f"{opt} (已下载)" if opt in downloaded_models else opt,
        help="越大越准,但速度越慢。首次使用未下载的模型时会自动下载。",
        disabled=st.session_state.processing,
    )

    lang_options = ["中文", "English", "日本語", "Français"]
    lang_index = lang_options.index(default_lang) if default_lang in lang_options else 0
    target_lang = st.selectbox("目标语言", lang_options, index=lang_index, disabled=st.session_state.processing)
# Helper: persist the current configuration to the local settings store
def save_current_settings():
    """Persist the current sidebar values via ``SettingsManager.save_settings``.

    Reads the module-level widget values (``api_key``, ``base_url``,
    ``model_name``, ``model_size``, ``target_lang``, ``device``) at call time.
    NOTE(review): the API key is passed along as-is; how SettingsManager
    stores it is not visible here.
    """
    settings_snapshot = dict(
        API_KEY=api_key,
        BASE_URL=base_url,
        MODEL_NAME=model_name,
        STT_MODEL_SIZE=model_size,
        TARGET_LANG=target_lang,
        DEVICE=device,
    )
    SettingsManager.save_settings(settings_snapshot)
# While a job is flagged as running, the sidebar offers a button that clears the job flags
# and reruns the script so the inputs become usable again.
with st.sidebar:
    if st.session_state.processing:
        st.divider()
        stop_clicked = st.button("⏹️ 强制停止任务", type="secondary", use_container_width=True)
        if stop_clicked:
            st.session_state["processing"] = False
            st.session_state["awaiting_synthesis"] = False
            st.rerun()
# Main area: media-processing tab and subtitle-only translation tab.
tab1, tab2 = st.tabs(["📁 媒体处理", "📜 字幕翻译"])

with tab1:
    st.info("💡 提示:支持视频和音频文件。如果文件非常大,处理会很慢请耐心等待。")
    uploaded_file = st.file_uploader(
        "选择视频或音频文件",
        type=["mp4", "mkv", "avi", "mov", "mp3", "wav", "m4a", "flac", "aac"],
        disabled=st.session_state.processing,
    )

    if uploaded_file:
        # The file extension decides between the video and audio code paths.
        video_extensions = [".mp4", ".mkv", ".avi", ".mov"]
        file_ext = os.path.splitext(uploaded_file.name)[1].lower()
        is_video = file_ext in video_extensions

        # A different upload (name or size changed) invalidates previous results and the
        # cached on-disk copy. Comparing the name alone would miss a new file with the same name.
        upload_id = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("last_uploaded_file") != upload_id:
            st.session_state.process_results = {}
            st.session_state.awaiting_synthesis = False
            st.session_state.last_uploaded_file = upload_id
            st.session_state.uploaded_file_path = None

        # Streamlit reruns this script on every interaction, so copy the upload to disk once
        # per file instead of leaking a new temp file (and an open handle) on every rerun.
        file_path = st.session_state.get("uploaded_file_path")
        if not file_path or not os.path.exists(file_path):
            # Leaving the `with` closes the handle, which flushes the bytes before ffmpeg /
            # st.video read the path. getvalue() returns the whole upload regardless of the
            # stream position.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tfile:
                tfile.write(uploaded_file.getvalue())
                file_path = tfile.name
            st.session_state.uploaded_file_path = file_path

        if is_video:
            st.video(file_path)
        else:
            st.audio(file_path)

        if st.button(
            "开始处理",
            disabled=not api_key or st.session_state.awaiting_synthesis or st.session_state.processing,
            on_click=start_processing_callback,
        ):
            try:
                with st.status("正在处理中...", expanded=True) as status:
                    # 1. Extract / normalize the audio track.
                    st.write("🎵 正在准备音频...")
                    # Use "<stem>_audio.wav" rather than "<stem>.wav": for .wav uploads the latter
                    # is the input path itself, and ffmpeg would be asked to overwrite its own input.
                    audio_path = os.path.splitext(file_path)[0] + "_audio.wav"
                    try:
                        # Audio uploads also go through prepare_audio (ffmpeg), so STT always
                        # receives the same format.
                        VideoUtils.prepare_audio(file_path, audio_path)
                    except Exception as e:
                        st.error(str(e))
                        status.update(label="处理出错", state="error", expanded=True)
                        st.session_state.processing = False
                        st.stop()

                    # 2. Local speech-to-text.
                    if model_size not in downloaded_models:
                        st.write(f"📥 正在下载 {model_size} 模型,请稍候...")
                    st.write(f"✍️ 正在识别语音 (使用 {model_size} 模型,设备: {device})...")
                    stt_manager = STTManager(model_size=model_size, device=device, compute_type=compute_type)
                    stt_manager.load_model()
                    segments_gen, info = stt_manager.transcribe(audio_path)
                    st.write(f"检测到语言: {info.language} (置信度: {info.language_probability:.2f})")

                    # 3. Translate segment by segment while showing a rolling preview.
                    st.write("---")
                    st.write("实时识别与翻译预览:")
                    preview_container = st.empty()
                    all_segments = []
                    all_translated_segments = []
                    display_data = []  # rows for the preview table
                    translator = Translator(api_key, base_url, model_name)

                    for segment in segments_gen:
                        trans_text = translator.translate_text(segment.text, target_lang)
                        all_segments.append(segment)
                        # Plain record carrying the same start/end/text attributes as the STT segment.
                        all_translated_segments.append(
                            SimpleNamespace(start=segment.start, end=segment.end, text=trans_text)
                        )

                        time_str = f"{VideoUtils.format_timestamp(segment.start)} -> {VideoUtils.format_timestamp(segment.end)}"
                        display_data.append({
                            "时间轴": time_str,
                            "原文": segment.text,
                            "翻译": trans_text,
                        })
                        # Only the 5 most recent rows are shown, to keep the page short.
                        preview_container.table(display_data[-5:])

                    # Write original and translated subtitles next to the uploaded temp file.
                    orig_srt_path = os.path.splitext(file_path)[0] + "_orig.srt"
                    VideoUtils.write_srt(all_segments, orig_srt_path)
                    trans_srt_path = os.path.splitext(file_path)[0] + "_trans.srt"
                    VideoUtils.write_srt(all_translated_segments, trans_srt_path)

                    status.update(label="处理完成!", state="complete", expanded=False)

                # Remember the configuration only after a successful run.
                save_current_settings()

                st.session_state.process_results = {
                    "orig_srt": orig_srt_path,
                    "trans_srt": trans_srt_path,
                    "output_video": None,
                    "is_video": is_video,
                }

                # Videos get an optional subtitle-synthesis step; keep its inputs around.
                if is_video:
                    st.session_state.temp_video_path = file_path
                    st.session_state.temp_trans_srt_path = trans_srt_path
                    st.session_state.temp_orig_srt_path = orig_srt_path
                    st.session_state.awaiting_synthesis = True

                st.session_state.processing = False
                st.rerun()
            except Exception as e:
                st.session_state.processing = False
                st.error(f"处理过程中发生未预期的错误: {e}")
                st.stop()

        # 4. Optionally burn the translated subtitles into the video (video uploads only).
        if st.session_state.awaiting_synthesis and st.session_state.process_results.get("is_video"):
            st.success("✅ 字幕翻译已完成!")
            col_synth_1, col_synth_2 = st.columns(2)
            with col_synth_1:
                if st.button(
                    "🚀 开始合成视频字幕",
                    type="primary",
                    disabled=st.session_state.processing,
                    on_click=start_synthesis_callback,
                ):
                    try:
                        with st.status("🎬 正在合成视频字幕...", expanded=True) as status:
                            v_path = st.session_state.temp_video_path
                            s_path = st.session_state.temp_trans_srt_path
                            output_video_path = os.path.splitext(v_path)[0] + "_translated.mp4"
                            synthesis_error = None
                            try:
                                VideoUtils.embed_subtitles(v_path, s_path, output_video_path)
                                st.write("✨ 视频合成成功!")
                            except Exception as e:
                                synthesis_error = f"视频合成失败 (请确保已安装 FFmpeg): {e}"
                                st.error(synthesis_error)
                            # Report failure as failure instead of "complete".
                            if synthesis_error is None:
                                status.update(label="全部处理完成!", state="complete", expanded=False)
                            else:
                                status.update(label="视频合成失败", state="error", expanded=True)

                        # st.rerun() below discards this run's output, including the error
                        # above, so the outcome is stored for the results section to show.
                        st.session_state.process_results["output_video"] = (
                            output_video_path if synthesis_error is None else None
                        )
                        if synthesis_error is not None:
                            st.session_state.process_results["synthesis_error"] = synthesis_error
                        st.session_state.awaiting_synthesis = False
                        st.session_state.processing = False
                        st.rerun()
                    except Exception as e:
                        st.session_state.processing = False
                        st.error(f"合成过程中发生未预期的错误: {e}")
                        st.stop()
            with col_synth_2:
                if st.button("📂 仅保存字幕", disabled=st.session_state.processing):
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

        # Results and downloads, rendered from session_state on every rerun.
        if st.session_state.process_results:
            st.divider()
            col_title, col_clear = st.columns([5, 1])
            with col_title:
                st.subheader("🎉 处理结果")
            with col_clear:
                if st.button("🗑️ 清除结果"):
                    st.session_state.process_results = {}
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

            results = st.session_state.process_results
            if results.get("synthesis_error"):
                st.error(results["synthesis_error"])

            col1, col2, col3 = st.columns(3)
            if os.path.exists(results.get("orig_srt", "")):
                with col1:
                    with open(results["orig_srt"], "rb") as f:
                        st.download_button("⬇️ 下载原始字幕", f, file_name="original.srt", key="dl_orig")
            if os.path.exists(results.get("trans_srt", "")):
                with col2:
                    with open(results["trans_srt"], "rb") as f:
                        st.download_button("⬇️ 下载翻译字幕", f, file_name="translated.srt", key="dl_trans")
            if results.get("output_video") and os.path.exists(results["output_video"]):
                with col3:
                    with open(results["output_video"], "rb") as f:
                        st.download_button("⬇️ 下载翻译视频", f, file_name="video_with_subtitles.mp4", key="dl_video")
with tab2:
    st.info("如果你已经有原始语言的字幕文件(SRT),可以在这里直接进行翻译。")
    uploaded_srt = st.file_uploader("上传原始 SRT 字幕", type=["srt"], disabled=st.session_state.processing)

    if uploaded_srt:
        # A different subtitle upload (name or size changed) invalidates the previous translation.
        srt_upload_id = (uploaded_srt.name, uploaded_srt.size)
        if st.session_state.get("last_uploaded_srt") != srt_upload_id:
            st.session_state.srt_results = {}
            st.session_state.last_uploaded_srt = srt_upload_id

        # getvalue() returns the whole upload regardless of the stream position.
        # "utf-8-sig" strips a leading BOM, which many SRT editors write. A non-UTF-8 file is
        # reported to the user instead of crashing the script.
        try:
            srt_content = uploaded_srt.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            srt_content = None
            st.error("字幕文件不是 UTF-8 编码,请转换为 UTF-8 后重新上传。")

        if srt_content is not None:
            st.text_area("字幕预览", srt_content, height=200)

            if st.button(
                "开始翻译字幕",
                disabled=not api_key or st.session_state.processing,
                on_click=start_translation_callback,
            ):
                try:
                    with st.spinner("正在翻译字幕..."):
                        segments = VideoUtils.parse_srt(srt_content)
                        translator = Translator(api_key, base_url, model_name)
                        translated_segments = translator.translate_segments(segments, target_lang)
                        # Reserve a unique path, closing the handle before write_srt opens it.
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt") as temp_trans_srt:
                            trans_srt_path = temp_trans_srt.name
                        VideoUtils.write_srt(translated_segments, trans_srt_path)
                        st.session_state.srt_results = {"trans_srt": trans_srt_path}
                        # Remember the configuration only after a successful run.
                        save_current_settings()
                    st.session_state.processing = False
                    st.rerun()
                except Exception as e:
                    st.session_state.processing = False
                    st.error(f"翻译过程中发生错误: {e}")
                    st.stop()

        # Result download, rendered from session_state on every rerun. The success message lives
        # here because a message emitted right before st.rerun() would be discarded immediately.
        if st.session_state.srt_results:
            st.divider()
            res = st.session_state.srt_results
            if os.path.exists(res.get("trans_srt", "")):
                st.success("翻译完成!")
                with open(res["trans_srt"], "rb") as f:
                    st.download_button("⬇️ 下载翻译后的字幕", f, file_name="translated_only.srt", key="dl_srt_tab")
# Remind the user that an API key is required before any job can start.
missing_api_key = not api_key
if missing_api_key:
    st.warning("请在左侧边栏输入 API Key 后开始。")