import gradio as gr
import os
import sys
import logging
import time
import psutil
import gc
import threading

# Set up logging to stdout
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

# Add src to path so the local packages below can be imported
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

# Directories
os.makedirs("models", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
STATUS_FILE = "outputs/current_status.txt"

from download_models import download_vad_model, download_hf_model, download_whisper_base_for_feature_extractor
from faster_whisper_transwithai_chickenrice.infer import Inference, Segment, _require_faster_whisper

MODELS = {
    "海南鸡 v2 (JP->ZH Optimized)": "chickenrice0721/whisper-large-v2-translate-zh-v0.2-st-ct2",
    "Whisper Large-v3": "Systran/faster-whisper-large-v3",
    "Whisper Medium": "Systran/faster-whisper-medium",
    "Whisper Small": "Systran/faster-whisper-small",
}


# --- GLOBAL STATE (lives in the server process, so it survives browser page refreshes) ---
class GlobalState:
    def __init__(self):
        self.is_running = False
        self.msg = "Ready, waiting for a task..."
        self.last_files = []
        self.lock = threading.Lock()


state = GlobalState()
GLOBAL_MODEL = {"key": None, "instance": None}


def get_mem():
    """Return this process's resident memory (RSS) in GB, or "N/A" if unavailable."""
    try:
        process = psutil.Process(os.getpid())
        return f"{process.memory_info().rss / 1024**3:.2f}GB"
    except Exception:
        return "N/A"


def set_state(msg, is_running=True, files=None):
    """Update the shared status (thread-safe), persist it to disk, and log it."""
    with state.lock:
        state.is_running = is_running
        state.msg = f"{msg} | RAM: {get_mem()}"
        if files:
            state.last_files = files
        # Persist the latest status so it can be inspected after a crash or restart
        try:
            with open(STATUS_FILE, "w", encoding="utf-8") as f:
                f.write(f"[{time.strftime('%H:%M:%S')}] {state.msg}")
        except OSError:
            pass
        logger.info(f"BACKEND: {state.msg}")


def backend_worker(audio_file, model_name, compute_type, sub_formats):
    """Download (if needed), load the model, transcribe, and write subtitles in a background thread."""
    global GLOBAL_MODEL
    try:
        model_id = MODELS[model_name]
        repo_name = model_id.split("/")[-1]
        model_path = os.path.join("models", repo_name)
        if not os.path.exists(model_path):
            set_state("Downloading model files...")
            download_hf_model(model_id)

        # Lightweight attribute-only object that mimics the CLI arguments Inference expects
        args = type('Args', (), {
            'model_name_or_path': model_path,
            'device': 'cpu',
            'compute_type': compute_type,
            'overwrite': True,
            'audio_suffixes': 'wav,flac,mp3,m4a,mp4,mkv,avi,webm,mov,flv',
            'sub_formats': ",".join(sub_formats),
            'output_dir': 'outputs',
            'generation_config': 'generation_config.json5',
            'enable_batching': False,
            'batch_size': 1,
            'max_batch_size': 1,
            'vad_threshold': None,
            'vad_min_speech_duration_ms': None,
            'vad_min_silence_duration_ms': None,
            'vad_speech_pad_ms': None,
            'merge_segments': None,
            'merge_max_gap_ms': None,
            'merge_max_duration_ms': None,
            'base_dirs': []
        })()
        inference = Inference(args)

        # Reuse the cached model unless the model path or compute type changed
        model_key = f"{model_path}-{compute_type}"
        if GLOBAL_MODEL["key"] != model_key:
            set_state("Loading model engine...")
            # Release the old model before loading the new one to limit peak RAM.
            # Reset the key too, so a failed load cannot leave a stale key paired with no instance.
            GLOBAL_MODEL["instance"] = None
            GLOBAL_MODEL["key"] = None
            gc.collect()
            WhisperModelCls, _ = _require_faster_whisper()
            GLOBAL_MODEL["instance"] = WhisperModelCls(model_path, device="cpu", compute_type=compute_type, cpu_threads=2)
            GLOBAL_MODEL["key"] = model_key
        model = GLOBAL_MODEL["instance"]

        set_state("Transcription started...")
        # faster-whisper returns a lazy generator: decoding happens while it is iterated below
        segments, _ = model.transcribe(audio_file, **inference.generation_config)
        all_segments = []
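        for s in segments:
            all_segments.append(s)
            # %H:%M:%S so recordings longer than an hour keep the hour field
            start_ts = time.strftime('%H:%M:%S', time.gmtime(s.start))
            end_ts = time.strftime('%H:%M:%S', time.gmtime(s.end))
            set_state(f"Processing: [{start_ts} -> {end_ts}] {s.text}")

        if not all_segments:
            set_state("No speech recognized", is_running=False)
            return

        set_state("Saving subtitle files...")
        base_name = os.path.splitext(os.path.basename(audio_file))[0]
        output_paths = []
        # Convert segments from float seconds to integer milliseconds;
        # round() avoids off-by-one-millisecond truncation caused by float error
        app_segments = [
            Segment(start=round(s.start * 1000), end=round(s.end * 1000), text=s.text)
            for s in all_segments
        ]
        for fmt in sub_formats:
            out_file = os.path.join("outputs", f"{base_name}.{fmt}")
            writer = inference.sub_writers.get(fmt)
            if writer:
                writer(app_segments, out_file)
                output_paths.append(out_file)

        set_state("Task complete!", is_running=False, files=output_paths)
    except Exception as e:
        logger.exception("Background task failed")
        set_state(f"Fatal error: {e}", is_running=False)


def ui_trigger_task(audio_file, model_name, compute_type, sub_formats):
    """Start backend_worker in a background thread unless a task is already running."""
    if not audio_file:
        return "Error: please upload an audio file first."
    # Check and set the running flag under the lock so two rapid clicks cannot start two workers
    with state.lock:
        if state.is_running:
            return "Warning: a task is already running in the background."
        state.is_running = True
    set_state("Initializing background thread...")
    t = threading.Thread(target=backend_worker, args=(audio_file, model_name, compute_type, sub_formats))
    t.daemon = False  # non-daemon: the process waits for an in-flight task before exiting
    t.start()
    return "Background task started!"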
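

# Hypothetical helper `format_hms`, a sketch that is not part of the original
# app and is not wired into it: it formats a non-negative offset in seconds as
# HH:MM:SS using integer arithmetic. Unlike time.strftime() applied to
# time.gmtime(t), whose hour field wraps back to 00 after 24 hours (and which
# omits hours entirely with a '%M:%S' pattern), the hour count here keeps
# growing, so it would also suit very long recordings. Fractional seconds are
# truncated.
def format_hms(seconds: float) -> str:
    total = int(seconds)  # drop the fractional part
    hours, rem = divmod(total, 3600)  # whole hours, leftover seconds
    minutes, secs = divmod(rem, 60)  # whole minutes, leftover seconds
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


# Usage (illustrative): format_hms(3725.9) returns "01:02:05".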


def ui_poll_status():
    """Timer callback: return the current status and, once idle, the latest output files."""
    with state.lock:
        # Hide the file list while a task runs; show the last results when idle
        return state.msg, (state.last_files if not state.is_running else None)


def ui_get_history():
    """List generated subtitle files in outputs/, newest first (excluding the status file)."""
    status_name = os.path.basename(STATUS_FILE)
    files = [
        os.path.join("outputs", f)
        for f in os.listdir("outputs")
        if f.endswith((".srt", ".vtt", ".lrc", ".txt")) and f != status_name
    ]
    files.sort(key=os.path.getmtime, reverse=True)
    return files


with gr.Blocks(title="Faster Whisper ChickenRice") as demo:
    gr.Markdown("# 🎙️ Faster Whisper ChickenRice (Stable Monitoring Edition)")

    with gr.Row():
        with gr.Column(scale=2):
            input_audio = gr.Audio(label="Upload audio/video", type="filepath")
            with gr.Row():
                sel_model = gr.Dropdown(choices=list(MODELS.keys()), value="海南鸡 v2 (JP->ZH Optimized)", label="Model")
                sel_compute = gr.Dropdown(choices=["int8", "float32"], value="int8", label="Precision")
            sel_formats = gr.CheckboxGroup(choices=["srt", "vtt", "lrc", "txt"], value=["srt", "vtt"], label="Formats")
            btn_run = gr.Button("🔥 Start background transcription", variant="primary")

        with gr.Column(scale=3):
            # gr.Timer (Gradio 4.40+) fires a tick event every `value` seconds while active
            status_timer = gr.Timer(value=2, active=True)
            out_status = gr.Textbox(label="📡 Live background status", interactive=False)
            out_files = gr.File(label="Latest generated subtitles")

            with gr.Accordion("History & file recovery", open=False):
                btn_refresh = gr.Button("🔄 Refresh file list from disk")
                out_history = gr.File(label="All generated files", file_count="multiple")

    # --- UI Logic ---
    btn_run.click(fn=ui_trigger_task, inputs=[input_audio, sel_model, sel_compute, sel_formats], outputs=[out_status])

    # Poll the shared server-side state on every timer tick, so a refreshed page
    # picks up the progress of a task that is already running
    status_timer.tick(fn=ui_poll_status, outputs=[out_status, out_files])
    btn_refresh.click(fn=ui_get_history, outputs=[out_history])


if __name__ == "__main__":
    try:
        download_vad_model()
        download_whisper_base_for_feature_extractor()
    except Exception as e:
        logger.warning(f"Model pre-download failed: {e}")
    demo.launch(server_name="0.0.0.0", server_port=7860)