"""Streamlit front end for subtitle extraction, translation and embedding.

The script has two tabs:

* Media processing: convert an uploaded video/audio file to WAV, transcribe
  it with the local STT model, translate every segment, write original and
  translated SRT files and, for videos, optionally embed the subtitles.
* Subtitle translation: translate an already existing SRT file.

Streamlit re-executes this script from top to bottom on every interaction,
so everything that must survive a rerun is kept in ``st.session_state``.
"""

import os
import tempfile
from types import SimpleNamespace

import streamlit as st
from dotenv import load_dotenv

from utils import VideoUtils, SettingsManager
from stt_module import STTManager
from translator import Translator

VIDEO_EXTENSIONS = [".mp4", ".mkv", ".avi", ".mov"]
MEDIA_TYPES = ["mp4", "mkv", "avi", "mov", "mp3", "wav", "m4a", "flac", "aac"]
COMPUTE_TYPES = ["int8", "float16", "int8_float16"]
STT_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3"]
TARGET_LANGUAGES = ["中文", "English", "日本語", "Français"]

# Load configuration from the .env file.
load_dotenv()

# Load the settings persisted locally by a previous run.
saved_settings = SettingsManager.load_settings()

# Streamlit's upload size limit cannot be lifted from inside the script: it is
# a server option that has to be passed on the command line or written to
# config.toml, so the UI can only remind the user about it.
st.set_page_config(page_title="AI 媒体翻译专家", layout="wide")

# Initialise session_state (fresh default objects are built on every run).
for _key, _default in (
    ("process_results", {}),
    ("srt_results", {}),
    ("awaiting_synthesis", False),
    ("processing", False),
    ("last_uploaded_file", None),
    ("last_uploaded_srt", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default


def start_processing_callback():
    """Button callback: mark the app busy and drop any previous media results."""
    st.session_state.processing = True
    st.session_state.process_results = {}
    st.session_state.awaiting_synthesis = False


def start_synthesis_callback():
    """Button callback: mark the app busy while subtitles are embedded."""
    st.session_state.processing = True


def start_translation_callback():
    """Button callback: mark the app busy and drop previous SRT results."""
    st.session_state.processing = True
    st.session_state.srt_results = {}


def persist_upload(uploaded, suffix):
    """Write an uploaded file to a closed temporary file and return its path.

    The path is cached in ``st.session_state.upload_cache`` so that reruns
    reuse the same file instead of copying the upload again on every widget
    interaction. The handle is closed before the path is returned, which
    guarantees the data is flushed before any other reader opens the path.

    Args:
        uploaded: The object returned by ``st.file_uploader``.
        suffix: File extension (including the dot) for the temporary file.

    Returns:
        Path of the temporary file holding the uploaded bytes.
    """
    # ``file_id`` is only present on newer Streamlit versions; name and size
    # are used as the fallback identity.
    signature = (uploaded.name, uploaded.size, getattr(uploaded, "file_id", None))
    cached = st.session_state.get("upload_cache")
    if (
        cached
        and cached["signature"] == signature
        and os.path.exists(cached["path"])
    ):
        return cached["path"]

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(uploaded.getvalue())
    st.session_state.upload_cache = {"signature": signature, "path": tmp.name}
    return tmp.name


st.title("🎬 AI 媒体字幕提取与翻译")

# Sidebar configuration.
with st.sidebar:
    st.header("⚙️ 配置参数")
    locked = st.session_state.processing

    # Defaults come from the locally saved settings, then the environment.
    default_api_key = saved_settings.get("API_KEY", os.getenv("API_KEY", ""))
    default_base_url = saved_settings.get(
        "BASE_URL", os.getenv("BASE_URL", "https://api.siliconflow.cn/v1")
    )
    default_model = saved_settings.get(
        "MODEL_NAME", os.getenv("MODEL_NAME", "THUDM/glm-4-9b-chat")
    )
    default_stt_size = saved_settings.get(
        "STT_MODEL_SIZE", os.getenv("STT_MODEL_SIZE", "base")
    )
    default_lang = saved_settings.get("TARGET_LANG", os.getenv("TARGET_LANG", "中文"))
    default_device = saved_settings.get("DEVICE", os.getenv("DEVICE", "cpu"))

    api_key = st.text_input(
        "API Key",
        value=default_api_key,
        type="password",
        help="输入 API Key",
        disabled=locked,
    )
    base_url = st.text_input("API Base URL", value=default_base_url, disabled=locked)
    model_name = st.text_input("模型名称", value=default_model, disabled=locked)

    st.divider()

    # Device selection: "cuda" is only offered when it is actually available.
    cuda_available = STTManager.is_cuda_available()
    device_options = ["cpu", "cuda"] if cuda_available else ["cpu"]
    device_index = (
        device_options.index(default_device)
        if default_device in device_options
        else 0
    )
    device = st.selectbox(
        "计算设备 (Device)",
        device_options,
        index=device_index,
        help="如果有 NVIDIA 显卡且安装了 CUDA,选择 'cuda' 会大幅提升速度。",
        disabled=locked,
    )

    # Preselect the precision recommended for the chosen device.
    default_compute_type = "float16" if device == "cuda" else "int8"
    compute_type = st.selectbox(
        "计算精度 (Compute Type)",
        COMPUTE_TYPES,
        index=COMPUTE_TYPES.index(default_compute_type),
        help="CPU 推荐 int8,GPU 推荐 float16。",
        disabled=locked,
    )

    # Models already downloaded get a marker appended to their label.
    downloaded_models = STTManager.get_downloaded_models()
    option_labels = [
        f"{opt} (已下载)" if opt in downloaded_models else opt
        for opt in STT_MODEL_SIZES
    ]
    stt_index = (
        STT_MODEL_SIZES.index(default_stt_size)
        if default_stt_size in STT_MODEL_SIZES
        else 1
    )
    selected_option = st.selectbox(
        "本地 STT 模型大小",
        option_labels,
        index=stt_index,
        help="越大越准,但速度越慢。首次使用未下载的模型时会自动下载。",
        disabled=locked,
    )
    # Strip the optional marker to recover the real model name.
    model_size = selected_option.split(" ")[0]

    lang_index = (
        TARGET_LANGUAGES.index(default_lang)
        if default_lang in TARGET_LANGUAGES
        else 0
    )
    target_lang = st.selectbox(
        "目标语言", TARGET_LANGUAGES, index=lang_index, disabled=locked
    )

    # While a task is running, offer a way to unlock the UI again.
    if st.session_state.processing:
        st.divider()
        if st.button("⏹️ 强制停止任务", type="secondary", use_container_width=True):
            st.session_state.processing = False
            st.session_state.awaiting_synthesis = False
            st.rerun()


def save_current_settings():
    """Persist the current sidebar values to the local settings file."""
    SettingsManager.save_settings({
        "API_KEY": api_key,
        "BASE_URL": base_url,
        "MODEL_NAME": model_name,
        "STT_MODEL_SIZE": model_size,
        "TARGET_LANG": target_lang,
        "DEVICE": device,
    })


# Main interface.
tab1, tab2 = st.tabs(["📁 媒体处理", "📜 字幕翻译"])

with tab1:
    st.info("💡 提示:支持视频和音频文件。如果文件非常大,处理会很慢请耐心等待。")
    uploaded_file = st.file_uploader(
        "选择视频或音频文件",
        type=MEDIA_TYPES,
        disabled=st.session_state.processing,
    )

    if uploaded_file:
        # Detect the media kind from the file extension.
        file_ext = os.path.splitext(uploaded_file.name)[1].lower()
        is_video = file_ext in VIDEO_EXTENSIONS

        # A different file name means a new upload: discard the old results.
        if st.session_state.last_uploaded_file != uploaded_file.name:
            st.session_state.process_results = {}
            st.session_state.awaiting_synthesis = False
            st.session_state.last_uploaded_file = uploaded_file.name

        # Store the upload on disk so the player and ffmpeg can read it.
        file_path = persist_upload(uploaded_file, file_ext)
        base_path = os.path.splitext(file_path)[0]

        if is_video:
            st.video(file_path)
        else:
            st.audio(file_path)

        if st.button(
            "开始处理",
            disabled=(
                not api_key
                or st.session_state.awaiting_synthesis
                or st.session_state.processing
            ),
            on_click=start_processing_callback,
        ):
            try:
                with st.status("正在处理中...", expanded=True) as status:
                    # 1. Extract / prepare the audio track.
                    st.write("🎵 正在准备音频...")
                    audio_path = base_path + '.wav'
                    try:
                        # Video and audio alike go through prepare_audio so
                        # the STT step always gets the same input format.
                        VideoUtils.prepare_audio(file_path, audio_path)
                    except Exception as e:
                        st.error(str(e))
                        status.update(label="处理出错", state="error", expanded=True)
                        st.session_state.processing = False
                        st.stop()

                    # 2. Local speech-to-text.
                    if model_size not in downloaded_models:
                        st.write(f"📥 正在下载 {model_size} 模型,请稍候...")

                    st.write(f"✍️ 正在识别语音 (使用 {model_size} 模型,设备: {device})...")
                    stt_manager = STTManager(
                        model_size=model_size,
                        device=device,
                        compute_type=compute_type,
                    )
                    stt_manager.load_model()
                    segments_gen, info = stt_manager.transcribe(audio_path)
                    st.write(
                        f"检测到语言: {info.language} "
                        f"(置信度: {info.language_probability:.2f})"
                    )

                    # Incremental processing with a live preview.
                    st.write("---")
                    st.write("实时识别与翻译预览:")
                    preview_container = st.empty()

                    all_segments = []
                    all_translated_segments = []
                    translator = Translator(api_key, base_url, model_name)

                    # Rows shown in the preview table.
                    display_data = []

                    for segment in segments_gen:
                        # Translate the current segment.
                        trans_text = translator.translate_text(segment.text, target_lang)

                        # Keep the original and a translated copy that shares
                        # the original timing.
                        all_segments.append(segment)
                        all_translated_segments.append(
                            SimpleNamespace(
                                start=segment.start,
                                end=segment.end,
                                text=trans_text,
                            )
                        )

                        # Refresh the preview.
                        time_str = f"{VideoUtils.format_timestamp(segment.start)} -> {VideoUtils.format_timestamp(segment.end)}"
                        display_data.append({
                            "时间轴": time_str,
                            "原文": segment.text,
                            "翻译": trans_text,
                        })
                        # Only the last 5 rows are shown to keep the page short.
                        preview_container.table(display_data[-5:])

                    # Write the original subtitles.
                    orig_srt_path = base_path + '_orig.srt'
                    VideoUtils.write_srt(all_segments, orig_srt_path)

                    # Write the translated subtitles.
                    trans_srt_path = base_path + '_trans.srt'
                    VideoUtils.write_srt(all_translated_segments, trans_srt_path)

                    status.update(label="处理完成!", state="complete", expanded=False)

                # Persist the configuration after a successful run.
                save_current_settings()

                # Keep the results across reruns.
                st.session_state.process_results = {
                    "orig_srt": orig_srt_path,
                    "trans_srt": trans_srt_path,
                    "output_video": None,
                    "is_video": is_video,
                }

                # Videos additionally keep the paths needed for synthesis.
                if is_video:
                    st.session_state.temp_video_path = file_path
                    st.session_state.temp_trans_srt_path = trans_srt_path
                    st.session_state.temp_orig_srt_path = orig_srt_path
                    st.session_state.awaiting_synthesis = True

                st.session_state.processing = False
                st.rerun()

            except Exception as e:
                st.session_state.processing = False
                st.error(f"处理过程中发生未预期的错误: {e}")
                st.stop()

        # 4. Embed the subtitles (videos only, while waiting for the decision).
        if (
            st.session_state.awaiting_synthesis
            and st.session_state.process_results.get("is_video")
        ):
            st.success("✅ 字幕翻译已完成!")
            col_synth_1, col_synth_2 = st.columns(2)

            with col_synth_1:
                if st.button(
                    "🚀 开始合成视频字幕",
                    type="primary",
                    disabled=st.session_state.processing,
                    on_click=start_synthesis_callback,
                ):
                    try:
                        with st.status("🎬 正在合成视频字幕...", expanded=True) as status:
                            v_path = st.session_state.temp_video_path
                            s_path = st.session_state.temp_trans_srt_path
                            output_video_path = os.path.splitext(v_path)[0] + '_translated.mp4'

                            video_ready = False
                            synthesis_error = None
                            try:
                                VideoUtils.embed_subtitles(v_path, s_path, output_video_path)
                                video_ready = True
                                st.write("✨ 视频合成成功!")
                            except Exception as e:
                                # Best effort: subtitles stay downloadable.
                                # The message is stored because the rerun
                                # below would otherwise wipe it from the page.
                                synthesis_error = f"视频合成失败 (请确保已安装 FFmpeg): {e}"
                                st.error(synthesis_error)

                            status.update(label="全部处理完成!", state="complete", expanded=False)

                        # Store the final outcome in session_state.
                        st.session_state.process_results["output_video"] = (
                            output_video_path if video_ready else None
                        )
                        st.session_state.process_results["synthesis_error"] = synthesis_error
                        st.session_state.awaiting_synthesis = False
                        st.session_state.processing = False
                        st.rerun()
                    except Exception as e:
                        st.session_state.processing = False
                        st.error(f"合成过程中发生未预期的错误: {e}")
                        st.stop()

            with col_synth_2:
                if st.button("📂 仅保存字幕"):
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

        # Results and downloads: driven by session_state rather than by the
        # button branch, so they stay visible on later reruns.
        if st.session_state.process_results:
            st.divider()
            col_title, col_clear = st.columns([5, 1])
            with col_title:
                st.subheader("🎉 处理结果")
            with col_clear:
                if st.button("🗑️ 清除结果"):
                    st.session_state.process_results = {}
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

            results = st.session_state.process_results

            # Surface a synthesis failure recorded before the last rerun.
            if results.get("synthesis_error"):
                st.error(results["synthesis_error"])

            col1, col2, col3 = st.columns(3)

            if os.path.exists(results.get("orig_srt", "")):
                with col1:
                    with open(results["orig_srt"], "rb") as f:
                        st.download_button(
                            "⬇️ 下载原始字幕", f, file_name="original.srt", key="dl_orig"
                        )

            if os.path.exists(results.get("trans_srt", "")):
                with col2:
                    with open(results["trans_srt"], "rb") as f:
                        st.download_button(
                            "⬇️ 下载翻译字幕", f, file_name="translated.srt", key="dl_trans"
                        )

            if results.get("output_video") and os.path.exists(results["output_video"]):
                with col3:
                    with open(results["output_video"], "rb") as f:
                        st.download_button(
                            "⬇️ 下载翻译视频",
                            f,
                            file_name="video_with_subtitles.mp4",
                            key="dl_video",
                        )

with tab2:
    st.info("如果你已经有原始语言的字幕文件(SRT),可以在这里直接进行翻译。")
    uploaded_srt = st.file_uploader(
        "上传原始 SRT 字幕", type=["srt"], disabled=st.session_state.processing
    )

    if uploaded_srt:
        # A different file name means a new upload: discard the old results.
        if st.session_state.last_uploaded_srt != uploaded_srt.name:
            st.session_state.srt_results = {}
            st.session_state.last_uploaded_srt = uploaded_srt.name

        # utf-8-sig drops a leading BOM so it cannot leak into the first cue;
        # getvalue() does not depend on the current read position.
        srt_content = uploaded_srt.getvalue().decode("utf-8-sig")
        st.text_area("字幕预览", srt_content, height=200)

        if st.button(
            "开始翻译字幕",
            disabled=not api_key or st.session_state.processing,
            on_click=start_translation_callback,
        ):
            try:
                with st.spinner("正在翻译字幕..."):
                    segments = VideoUtils.parse_srt(srt_content)
                    translator = Translator(api_key, base_url, model_name)
                    translated_segments = translator.translate_segments(segments, target_lang)

                    # Reserve a temp path and close the handle before
                    # write_srt opens the same path itself.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.srt') as tmp_srt:
                        temp_trans_srt_path = tmp_srt.name
                    VideoUtils.write_srt(translated_segments, temp_trans_srt_path)

                    st.session_state.srt_results = {
                        "trans_srt": temp_trans_srt_path
                    }
                    st.success("翻译完成!")

                # Persist the configuration after a successful run.
                save_current_settings()
                st.session_state.processing = False
                st.rerun()
            except Exception as e:
                st.session_state.processing = False
                st.error(f"翻译过程中发生错误: {e}")
                st.stop()

        # Download of the translated subtitles.
        if st.session_state.srt_results:
            st.divider()
            res = st.session_state.srt_results
            if os.path.exists(res.get("trans_srt", "")):
                with open(res["trans_srt"], "rb") as f:
                    st.download_button(
                        "⬇️ 下载翻译后的字幕",
                        f,
                        file_name="translated_only.srt",
                        key="dl_srt_tab",
                    )

if not api_key:
    st.warning("请在左侧边栏输入 API Key 后开始。")