translator / app.py
felix1968839's picture
配置自动保存功能
728fd98 verified
import streamlit as st
import os
from utils import VideoUtils, SettingsManager
from stt_module import STTManager
from translator import Translator
import tempfile
from dotenv import load_dotenv
# Pull overrides from a local .env file into the process environment.
load_dotenv()

# Restore whatever configuration an earlier run persisted locally.
saved_settings = SettingsManager.load_settings()

# About the upload size limit: Streamlit's server options have to be given on
# the command line or in config.toml, so this script cannot lift the limit
# itself and only reminds the user through the UI.
st.set_page_config(page_title="AI 媒体翻译专家", layout="wide")

# Seed every session_state slot the page relies on. Slots that already hold a
# value from a previous rerun are left untouched.
_session_defaults = {
    "process_results": {},
    "srt_results": {},
    "awaiting_synthesis": False,
    "processing": False,
    "last_uploaded_file": None,
    "last_uploaded_srt": None,
}
for _slot, _initial in _session_defaults.items():
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
# Callback: start media processing.
def start_processing_callback():
    """Flag the app as busy and discard the previous media results.

    Wired as the ``on_click`` handler of the "开始处理" button. It sets
    ``processing`` to True, resets ``process_results`` to an empty dict and
    clears ``awaiting_synthesis`` in ``st.session_state``.
    """
    state = st.session_state
    state["processing"] = True
    state["process_results"] = {}
    state["awaiting_synthesis"] = False
# Callback: start subtitle synthesis.
def start_synthesis_callback():
    """Flag the app as busy; existing results and flags are left as they are."""
    st.session_state["processing"] = True
# Callback: start subtitle translation.
def start_translation_callback():
    """Flag the app as busy and discard the previous SRT translation result.

    Sets ``processing`` to True and resets ``srt_results`` to an empty dict in
    ``st.session_state``.
    """
    state = st.session_state
    state["processing"] = True
    state["srt_results"] = {}
st.title("🎬 AI 媒体字幕提取与翻译")

# Sidebar configuration.
# Defines the values the rest of the script reads: api_key, base_url,
# model_name, device, compute_type, model_size, target_lang and
# downloaded_models.
with st.sidebar:
    st.header("⚙️ 配置参数")

    # Every control is locked while a task is running.
    controls_locked = st.session_state.processing

    def _stored_default(setting_key, fallback):
        """Return the saved value for *setting_key*, else the env var, else *fallback*."""
        return saved_settings.get(setting_key, os.getenv(setting_key, fallback))

    # Initial widget values: locally saved settings win over environment
    # variables, which win over the hard-coded fallbacks.
    default_api_key = _stored_default("API_KEY", "")
    default_base_url = _stored_default("BASE_URL", "https://api.siliconflow.cn/v1")
    default_model = _stored_default("MODEL_NAME", "THUDM/glm-4-9b-chat")
    default_stt_size = _stored_default("STT_MODEL_SIZE", "base")
    default_lang = _stored_default("TARGET_LANG", "中文")
    default_device = _stored_default("DEVICE", "cpu")

    api_key = st.text_input(
        "API Key",
        value=default_api_key,
        type="password",
        help="输入 API Key",
        disabled=controls_locked,
    )
    base_url = st.text_input("API Base URL", value=default_base_url, disabled=controls_locked)
    model_name = st.text_input("模型名称", value=default_model, disabled=controls_locked)

    st.divider()

    # Device selection: "cuda" is only offered when the STT backend reports it.
    cuda_available = STTManager.is_cuda_available()
    device_options = ["cpu", "cuda"] if cuda_available else ["cpu"]
    # A saved device that is not offered here falls back to the first entry.
    device_index = device_options.index(default_device) if default_device in device_options else 0
    device = st.selectbox(
        "计算设备 (Device)",
        device_options,
        index=device_index,
        help="如果有 NVIDIA 显卡且安装了 CUDA,选择 'cuda' 会大幅提升速度。",
        disabled=controls_locked,
    )

    # Pre-select the precision recommended for the chosen device. The index is
    # derived from the recommendation so the two cannot drift apart.
    compute_type_options = ["int8", "float16", "int8_float16"]
    default_compute_type = "float16" if device == "cuda" else "int8"
    compute_type = st.selectbox(
        "计算精度 (Compute Type)",
        compute_type_options,
        index=compute_type_options.index(default_compute_type),
        help="CPU 推荐 int8,GPU 推荐 float16。",
        disabled=controls_locked,
    )

    stt_options = ["tiny", "base", "small", "medium", "large-v3"]
    # Models that are already downloaded get a marker in their label.
    downloaded_models = STTManager.get_downloaded_models()

    def _stt_label(option):
        """Return the display label for an STT model size."""
        return f"{option} (已下载)" if option in downloaded_models else option

    # An unknown saved size falls back to index 1 ("base").
    stt_index = stt_options.index(default_stt_size) if default_stt_size in stt_options else 1
    # format_func only changes what is displayed; the widget returns the real
    # model name, so nothing has to be parsed back out of the label.
    model_size = st.selectbox(
        "本地 STT 模型大小",
        stt_options,
        index=stt_index,
        format_func=_stt_label,
        help="越大越准,但速度越慢。首次使用未下载的模型时会自动下载。",
        disabled=controls_locked,
    )

    lang_options = ["中文", "English", "日本語", "Français"]
    lang_index = lang_options.index(default_lang) if default_lang in lang_options else 0
    target_lang = st.selectbox("目标语言", lang_options, index=lang_index, disabled=controls_locked)
# Helper that persists the current configuration to the local settings store.
def save_current_settings():
    """Hand the current sidebar selections to ``SettingsManager.save_settings``.

    The module-level widget values are read when the function is called, not
    when it is defined. ``compute_type`` is not part of the saved payload.
    """
    current_settings = dict(
        API_KEY=api_key,
        BASE_URL=base_url,
        MODEL_NAME=model_name,
        STT_MODEL_SIZE=model_size,
        TARGET_LANG=target_lang,
        DEVICE=device,
    )
    SettingsManager.save_settings(current_settings)
# While a task is running, offer a button that aborts it.
if st.session_state.processing:
    st.divider()
    stop_requested = st.button("⏹️ 强制停止任务", type="secondary", use_container_width=True)
    if stop_requested:
        # Clear both workflow flags, then rerun so the widgets that are
        # disabled while `processing` is set become usable again.
        st.session_state.processing = False
        st.session_state.awaiting_synthesis = False
        st.rerun()
# Main area
tab1, tab2 = st.tabs(["📁 媒体处理", "📜 字幕翻译"])

# Media tab: upload -> speech-to-text + translation -> optional subtitle
# embedding -> downloads.
# NOTE(review): this copy of the file had lost its indentation. Placing the
# synthesis and result sections under `if uploaded_file:` is a reconstruction;
# confirm against the deployed app.
with tab1:
    st.info("💡 提示:支持视频和音频文件。如果文件非常大,处理会很慢请耐心等待。")
    uploaded_file = st.file_uploader(
        "选择视频或音频文件",
        type=["mp4", "mkv", "avi", "mov", "mp3", "wav", "m4a", "flac", "aac"],
        disabled=st.session_state.processing,
    )

    if uploaded_file:
        # Classify the upload by extension to pick the matching player.
        video_extensions = [".mp4", ".mkv", ".avi", ".mov"]
        file_ext = os.path.splitext(uploaded_file.name)[1].lower()
        is_video = file_ext in video_extensions

        # A different file name means a new upload: drop the stale results.
        if "last_uploaded_file" not in st.session_state or st.session_state.last_uploaded_file != uploaded_file.name:
            st.session_state.process_results = {}
            st.session_state.awaiting_synthesis = False
            st.session_state.last_uploaded_file = uploaded_file.name

        # Copy the upload to a temp file, because everything downstream works
        # on a path. The handle is closed before the path is used so all bytes
        # are flushed, and getvalue() returns the whole buffer whatever the
        # current read position is.
        # NOTE(review): this runs on every script rerun and nothing here
        # deletes the file, so each rerun leaves another temp copy behind.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tfile:
            tfile.write(uploaded_file.getvalue())
        file_path = tfile.name
        # Shared stem for every file derived from this upload.
        base_path = os.path.splitext(file_path)[0]

        if is_video:
            st.video(file_path)
        else:
            st.audio(file_path)

        if st.button(
            "开始处理",
            disabled=not api_key or st.session_state.awaiting_synthesis or st.session_state.processing,
            on_click=start_processing_callback,
        ):
            try:
                with st.status("正在处理中...", expanded=True) as status:
                    # 1. Extract / prepare the audio
                    st.write("🎵 正在准备音频...")
                    # The "_audio" suffix keeps the output distinct from the
                    # input. A plain ".wav" suffix would collide with an
                    # uploaded .wav file, making the conversion read and write
                    # the same path.
                    audio_path = base_path + '_audio.wav'
                    try:
                        # Video or audio, everything goes through prepare_audio
                        # (ffmpeg) to get a standard format the STT step accepts.
                        VideoUtils.prepare_audio(file_path, audio_path)
                    except Exception as e:
                        st.error(str(e))
                        status.update(label="处理出错", state="error", expanded=True)
                        st.session_state.processing = False
                        st.stop()

                    # 2. Local speech-to-text
                    if model_size not in downloaded_models:
                        st.write(f"📥 正在下载 {model_size} 模型,请稍候...")
                    st.write(f"✍️ 正在识别语音 (使用 {model_size} 模型,设备: {device})...")
                    stt_manager = STTManager(model_size=model_size, device=device, compute_type=compute_type)
                    stt_manager.load_model()
                    segments_gen, info = stt_manager.transcribe(audio_path)
                    st.write(f"检测到语言: {info.language} (置信度: {info.language_probability:.2f})")

                    # Translate segment by segment and refresh the preview as
                    # results arrive.
                    st.write("---")
                    st.write("实时识别与翻译预览:")
                    preview_container = st.empty()
                    all_segments = []
                    all_translated_segments = []
                    translator = Translator(api_key, base_url, model_name)
                    # Rows backing the preview table
                    display_data = []

                    for segment in segments_gen:
                        # Translate the current segment
                        trans_text = translator.translate_text(segment.text, target_lang)

                        # Keep the original and a translated twin that carries
                        # the same timing. The throwaway class is only used as
                        # an attribute holder exposing start / end / text.
                        all_segments.append(segment)
                        new_trans_seg = type('Segment', (), {
                            'start': segment.start,
                            'end': segment.end,
                            'text': trans_text
                        })
                        all_translated_segments.append(new_trans_seg)

                        # Refresh the preview
                        time_str = f"{VideoUtils.format_timestamp(segment.start)} -> {VideoUtils.format_timestamp(segment.end)}"
                        display_data.append({
                            "时间轴": time_str,
                            "原文": segment.text,
                            "翻译": trans_text
                        })
                        # Only the five most recent rows are shown to keep the
                        # page short.
                        preview_container.table(display_data[-5:])

                    # Write the original-language subtitles
                    orig_srt_path = base_path + '_orig.srt'
                    VideoUtils.write_srt(all_segments, orig_srt_path)

                    # Write the translated subtitles
                    trans_srt_path = base_path + '_trans.srt'
                    VideoUtils.write_srt(all_translated_segments, trans_srt_path)

                    status.update(label="处理完成!", state="complete", expanded=False)

                    # Persist the configuration after a successful run
                    save_current_settings()

                    # Publish the results through session_state so they
                    # survive the rerun below.
                    st.session_state.process_results = {
                        "orig_srt": orig_srt_path,
                        "trans_srt": trans_srt_path,
                        "output_video": None,
                        "is_video": is_video
                    }

                    # Videos get a follow-up step: remember the inputs needed
                    # to embed the subtitles later.
                    if is_video:
                        st.session_state.temp_video_path = file_path
                        st.session_state.temp_trans_srt_path = trans_srt_path
                        st.session_state.temp_orig_srt_path = orig_srt_path
                        st.session_state.awaiting_synthesis = True

                    st.session_state.processing = False
                    st.rerun()
            except Exception as e:
                st.session_state.processing = False
                st.error(f"处理过程中发生未预期的错误: {e}")
                st.stop()

        # 4. Embed subtitles into the video (videos only, while awaiting the
        # user's decision)
        if st.session_state.awaiting_synthesis and st.session_state.process_results.get("is_video"):
            st.success("✅ 字幕翻译已完成!")
            col_synth_1, col_synth_2 = st.columns(2)
            with col_synth_1:
                if st.button(
                    "🚀 开始合成视频字幕",
                    type="primary",
                    disabled=st.session_state.processing,
                    on_click=start_synthesis_callback,
                ):
                    try:
                        with st.status("🎬 正在合成视频字幕...", expanded=True) as status:
                            v_path = st.session_state.temp_video_path
                            s_path = st.session_state.temp_trans_srt_path
                            output_video_path = os.path.splitext(v_path)[0] + '_translated.mp4'

                            synthesis_error = None
                            try:
                                VideoUtils.embed_subtitles(v_path, s_path, output_video_path)
                            except Exception as e:
                                synthesis_error = f"视频合成失败 (请确保已安装 FFmpeg): {e}"
                                st.error(synthesis_error)
                                status.update(label="处理出错", state="error", expanded=True)
                            else:
                                st.write("✨ 视频合成成功!")
                                status.update(label="全部处理完成!", state="complete", expanded=False)

                            # Everything drawn in this run is discarded by the
                            # rerun below, so the failure message is stored
                            # with the results and shown again in the result
                            # section.
                            final_results = st.session_state.process_results
                            final_results["output_video"] = None if synthesis_error else output_video_path
                            final_results["synthesis_error"] = synthesis_error
                            st.session_state.awaiting_synthesis = False
                            st.session_state.processing = False
                            st.rerun()
                    except Exception as e:
                        st.session_state.processing = False
                        st.error(f"合成过程中发生未预期的错误: {e}")
                        st.stop()

            with col_synth_2:
                if st.button("📂 仅保存字幕"):
                    # Skip embedding; the subtitle downloads stay available.
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

        # Results and downloads. Driven by session_state rather than by a
        # button's return value, so they are still shown after a rerun.
        if st.session_state.process_results:
            st.divider()
            col_title, col_clear = st.columns([5, 1])
            with col_title:
                st.subheader("🎉 处理结果")
            with col_clear:
                if st.button("🗑️ 清除结果"):
                    st.session_state.process_results = {}
                    st.session_state.awaiting_synthesis = False
                    st.rerun()

            results = st.session_state.process_results

            # Surface a subtitle-embedding failure recorded by the synthesis step.
            if results.get("synthesis_error"):
                st.error(results["synthesis_error"])

            col1, col2, col3 = st.columns(3)

            # Each download is only offered while its file still exists on disk.
            if os.path.exists(results.get("orig_srt", "")):
                with col1:
                    with open(results["orig_srt"], "rb") as f:
                        st.download_button("⬇️ 下载原始字幕", f, file_name="original.srt", key="dl_orig")

            if os.path.exists(results.get("trans_srt", "")):
                with col2:
                    with open(results["trans_srt"], "rb") as f:
                        st.download_button("⬇️ 下载翻译字幕", f, file_name="translated.srt", key="dl_trans")

            if results.get("output_video") and os.path.exists(results["output_video"]):
                with col3:
                    with open(results["output_video"], "rb") as f:
                        st.download_button("⬇️ 下载翻译视频", f, file_name="video_with_subtitles.mp4", key="dl_video")
# Subtitle tab: translate an existing SRT file without any media processing.
# NOTE(review): this copy of the file had lost its indentation. Placing the
# download section under `if uploaded_srt:` is a reconstruction; confirm
# against the deployed app.
with tab2:
    st.info("如果你已经有原始语言的字幕文件(SRT),可以在这里直接进行翻译。")
    uploaded_srt = st.file_uploader("上传原始 SRT 字幕", type=["srt"], disabled=st.session_state.processing)

    if uploaded_srt:
        # A different file name means a new upload: drop the stale result.
        if "last_uploaded_srt" not in st.session_state or st.session_state.last_uploaded_srt != uploaded_srt.name:
            st.session_state.srt_results = {}
            st.session_state.last_uploaded_srt = uploaded_srt.name

        # "utf-8-sig" decodes plain UTF-8 as well and drops a leading BOM, so
        # the BOM does not end up in front of the first subtitle entry.
        # getvalue() returns the whole buffer whatever the read position is.
        try:
            srt_content = uploaded_srt.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            # Report an undecodable file in the UI instead of letting the
            # exception escape the script.
            srt_content = None
            st.error("字幕文件不是 UTF-8 编码,请转换为 UTF-8 后重新上传。")

        if srt_content is not None:
            st.text_area("字幕预览", srt_content, height=200)

            if st.button(
                "开始翻译字幕",
                disabled=not api_key or st.session_state.processing,
                on_click=start_translation_callback,
            ):
                try:
                    with st.spinner("正在翻译字幕..."):
                        segments = VideoUtils.parse_srt(srt_content)
                        translator = Translator(api_key, base_url, model_name)
                        translated_segments = translator.translate_segments(segments, target_lang)

                        # Reserve a temp path and close the handle before
                        # write_srt writes to it, so the descriptor is not
                        # leaked or held open during the write.
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.srt') as temp_trans_srt:
                            trans_srt_path = temp_trans_srt.name
                        VideoUtils.write_srt(translated_segments, trans_srt_path)

                        st.session_state.srt_results = {
                            "trans_srt": trans_srt_path
                        }
                        st.success("翻译完成!")

                    # Persist the configuration after a successful run
                    save_current_settings()
                    st.session_state.processing = False
                    st.rerun()
                except Exception as e:
                    st.session_state.processing = False
                    st.error(f"翻译过程中发生错误: {e}")
                    st.stop()

        # Download of the translated subtitles; only offered while the file
        # still exists on disk.
        if st.session_state.srt_results:
            st.divider()
            res = st.session_state.srt_results
            if os.path.exists(res.get("trans_srt", "")):
                with open(res["trans_srt"], "rb") as f:
                    st.download_button("⬇️ 下载翻译后的字幕", f, file_name="translated_only.srt", key="dl_srt_tab")
# Both start buttons are disabled without an API key, so tell the user what
# is missing.
api_key_missing = not api_key
if api_key_missing:
    st.warning("请在左侧边栏输入 API Key 后开始。")