# NOTE: scraped page banner ("Spaces: Sleeping / Sleeping") - not part of the program.
| import gradio as gr | |
| import os | |
| import sys | |
| import logging | |
| from pathlib import Path | |
| import shutil | |
| import time | |
| import json | |
| import psutil | |
| import gc | |
| import threading | |
# Send every log record to stdout using a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Put the sibling "src" directory first on sys.path so the project-local
# imports that follow can be resolved.
sys.path.insert(0, str(Path(__file__).parent / "src"))

# Working directories: "models" for model files, "outputs" for results.
# Both are created relative to the current working directory if missing.
Path("models").mkdir(exist_ok=True)
Path("outputs").mkdir(exist_ok=True)

# Text file that mirrors the most recent status line (written by set_state).
STATUS_FILE = "outputs/current_status.txt"
| from download_models import download_vad_model, download_hf_model, download_whisper_base_for_feature_extractor | |
| from faster_whisper_transwithai_chickenrice.infer import Inference, Segment, _require_faster_whisper | |
# Maps the label shown in the model dropdown to the model repository id that
# is handed to download_hf_model(). Insertion order is the dropdown order.
MODELS = dict(
    [
        (
            "海南鸡 v2 (JP->ZH Optimized)",
            "chickenrice0721/whisper-large-v2-translate-zh-v0.2-st-ct2",
        ),
        ("Whisper Large-v3", "Systran/faster-whisper-large-v3"),
        ("Whisper Medium", "Systran/faster-whisper-medium"),
        ("Whisper Small", "Systran/faster-whisper-small"),
    ]
)
# --- GLOBAL STATE ---
# Held at module level, so it is shared by every UI session and is not reset
# when a browser page is reloaded.
class GlobalState:
    """Mutable status record shared between the worker thread and the UI."""

    def __init__(self):
        # Guards every read/write of the fields below.
        self.lock = threading.Lock()
        # Subtitle files produced by the most recent successful task.
        self.last_files = []
        # Human-readable status line shown in the monitor textbox.
        self.msg = "就绪,等待任务..."
        # True while a transcription task is in progress.
        self.is_running = False


# Process-wide singleton used by set_state() and the UI callbacks.
state = GlobalState()

# Cache for the loaded model: "key" identifies what "instance" currently holds.
GLOBAL_MODEL = {"key": None, "instance": None}
def get_mem():
    """Return this process's resident memory as a string such as ``"1.23GB"``.

    The lookup is best-effort: if it fails for any ordinary reason the
    function returns ``"N/A"`` instead of raising.
    """
    try:
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt and
        # SystemExit. Only ordinary errors should degrade to "N/A".
        return "N/A"
    return f"{rss_bytes / 1024**3:.2f}GB"
def set_state(msg, is_running=True, files=None):
    """Publish a new status line to the shared state, the status file and the log.

    Args:
        msg: Status text; the current RAM usage is appended to it.
        is_running: Whether a task should be considered in progress.
        files: Optional list of result files. When truthy it replaces
            ``state.last_files``; otherwise the previous list is kept.

    Must not be called while holding ``state.lock`` (the lock is not
    re-entrant).
    """
    # Build the line once so the shared state, the status file and the log all
    # carry exactly the same text, even if another thread updates the state
    # right after the lock is released.
    status_line = f"{msg} | RAM: {get_mem()}"

    with state.lock:
        state.is_running = is_running
        state.msg = status_line
        if files:
            state.last_files = files

    # Persistence is best-effort: a failed write must not abort the task, but
    # it is logged instead of being silently discarded.
    try:
        with open(STATUS_FILE, "w", encoding="utf-8") as f:
            f.write(f"[{time.strftime('%H:%M:%S')}] {status_line}")
    except (OSError, ValueError) as exc:
        # ValueError covers text that cannot be encoded as UTF-8.
        logger.warning("Could not write status file %s: %s", STATUS_FILE, exc)

    logger.info("BACKEND: %s", status_line)
def _format_ts(seconds):
    """Format a position in seconds as ``MM:SS``, or ``HH:MM:SS`` from one hour on."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"


def backend_worker(audio_file, model_name, compute_type, sub_formats):
    """Transcribe ``audio_file`` and write subtitle files; runs in a worker thread.

    Progress, completion and failure are all reported through ``set_state``;
    nothing is returned and no exception derived from ``Exception`` escapes.

    Args:
        audio_file: Path of the audio/video file to transcribe.
        model_name: A key of ``MODELS``.
        compute_type: Compute type passed to the model constructor.
        sub_formats: Iterable of subtitle format names (e.g. ``"srt"``).
    """
    try:
        model_id = MODELS[model_name]
        model_path = os.path.join("models", model_id.split("/")[-1])
        if not os.path.exists(model_path):
            set_state("正在下载模型文件...")
            # NOTE(review): assumes download_hf_model stores the model under
            # models/<repo name> - confirm against download_models.
            download_hf_model(model_id)

        # Ad-hoc attribute bag standing in for the argument object that
        # Inference expects.
        args = type('Args', (), {
            'model_name_or_path': model_path, 'device': 'cpu', 'compute_type': compute_type,
            'overwrite': True, 'audio_suffixes': 'wav,flac,mp3,m4a,mp4,mkv,avi,webm,mov,flv',
            'sub_formats': ",".join(sub_formats), 'output_dir': 'outputs', 'generation_config': 'generation_config.json5',
            'enable_batching': False, 'batch_size': 1, 'max_batch_size': 1,
            'vad_threshold': None, 'vad_min_speech_duration_ms': None, 'vad_min_silence_duration_ms': None,
            'vad_speech_pad_ms': None, 'merge_segments': None, 'merge_max_gap_ms': None, 'merge_max_duration_ms': None,
            'base_dirs': []
        })()
        # Used here for its generation_config and its subtitle writers.
        inference = Inference(args)

        model_key = f"{model_path}-{compute_type}"
        if GLOBAL_MODEL["key"] != model_key or GLOBAL_MODEL["instance"] is None:
            set_state("加载模型引擎中...")
            # Clear the key together with the instance. If the load below
            # fails, the cache must not keep claiming that the previous model
            # is still loaded, or the next run would call None.transcribe.
            GLOBAL_MODEL["key"] = None
            GLOBAL_MODEL["instance"] = None
            gc.collect()  # release the old model before allocating the new one
            WhisperModelCls, _ = _require_faster_whisper()
            GLOBAL_MODEL["instance"] = WhisperModelCls(
                model_path, device="cpu", compute_type=compute_type, cpu_threads=2
            )
            GLOBAL_MODEL["key"] = model_key
        model = GLOBAL_MODEL["instance"]

        set_state("转录开始...")
        segments, _ = model.transcribe(audio_file, **inference.generation_config)
        all_segments = []
        for s in segments:
            all_segments.append(s)
            # Timestamps keep the hour component, so recordings longer than
            # 60 minutes no longer wrap around to 00:00.
            ts = f"{_format_ts(s.start)} -> {_format_ts(s.end)}"
            set_state(f"处理中: [{ts}] {s.text}")

        if not all_segments:
            set_state("识别为空", is_running=False)
            return

        set_state("正在保存字幕文件...")
        base_name = os.path.splitext(os.path.basename(audio_file))[0]
        # Segment takes integer milliseconds; the model reports float seconds.
        app_segments = [
            Segment(start=int(s.start * 1000), end=int(s.end * 1000), text=s.text)
            for s in all_segments
        ]
        output_paths = []
        for fmt in sub_formats:
            writer = inference.sub_writers.get(fmt)
            if not writer:
                logger.warning("No subtitle writer available for format %r; skipped", fmt)
                continue
            out_file = os.path.join("outputs", f"{base_name}.{fmt}")
            writer(app_segments, out_file)
            output_paths.append(out_file)

        set_state("任务完成!", is_running=False, files=output_paths)
    except Exception as e:
        # Keep the traceback in the log; the UI only gets the short message.
        logger.exception("Transcription task failed")
        set_state(f"致命错误: {str(e)}", is_running=False)
def ui_trigger_task(audio_file, model_name, compute_type, sub_formats):
    """Start a background transcription unless one is already running.

    Args:
        audio_file: Path of the uploaded file; falsy when nothing was uploaded.
        model_name: A key of ``MODELS``.
        compute_type: Compute type for the model.
        sub_formats: Selected subtitle formats.

    Returns:
        A short status message for the monitor textbox.
    """
    # The "already running?" check and the claim of the running flag happen
    # under one lock acquisition, so two near-simultaneous clicks cannot both
    # pass the check and start two workers.
    with state.lock:
        if state.is_running:
            return "警告:已有任务正在后台运行。"
        if not audio_file:
            return "错误:请先上传音频。"
        state.is_running = True

    # set_state takes state.lock itself, so it must be called after release.
    set_state("初始化后台线程...")
    worker = threading.Thread(
        target=backend_worker,
        args=(audio_file, model_name, compute_type, sub_formats),
    )
    worker.daemon = False
    try:
        worker.start()
    except RuntimeError as exc:
        # The thread could not be started: release the running flag so the
        # application does not stay blocked forever.
        set_state(f"致命错误: {exc}", is_running=False)
        return f"致命错误: {exc}"
    return "后台任务已启动!"
def ui_poll_status():
    """Timer callback: return ``(status message, files)`` for the UI.

    The second element is ``state.last_files`` when no task is running and
    ``None`` while a task is in progress.
    """
    with state.lock:
        if state.is_running:
            latest_files = None
        else:
            latest_files = state.last_files
        return state.msg, latest_files
def ui_get_history():
    """Return the subtitle files in ``outputs``, most recently modified first.

    Returns:
        A list of paths (``outputs/<name>``) whose names end in ``.srt``,
        ``.vtt``, ``.lrc`` or ``.txt``.
    """
    subtitle_suffixes = (".srt", ".vtt", ".lrc", ".txt")
    # The status file lives in the same directory and ends in ".txt"; it is
    # internal bookkeeping, not a generated subtitle, so leave it out.
    status_name = os.path.basename(STATUS_FILE)
    files = [
        os.path.join("outputs", name)
        for name in os.listdir("outputs")
        if name.endswith(subtitle_suffixes) and name != status_name
    ]
    files.sort(key=os.path.getmtime, reverse=True)
    return files
# NOTE(review): the original indentation was lost, so the nesting below is a
# reconstruction - confirm which components belong to the inner Row and
# whether the Accordion sits inside the right-hand column.
with gr.Blocks(title="Faster Whisper ChickenRice") as demo:
    gr.Markdown("# 🎙️ Faster Whisper ChickenRice (稳定监控版)")

    with gr.Row():
        # Left column: input file and task options.
        with gr.Column(scale=2):
            input_audio = gr.Audio(type="filepath", label="上传音视频")
            with gr.Row():
                sel_model = gr.Dropdown(
                    label="模型",
                    choices=list(MODELS),
                    value="海南鸡 v2 (JP->ZH Optimized)",
                )
                sel_compute = gr.Dropdown(
                    label="精度",
                    choices=["int8", "float32"],
                    value="int8",
                )
            sel_formats = gr.CheckboxGroup(
                label="格式",
                choices=["srt", "vtt", "lrc", "txt"],
                value=["srt", "vtt"],
            )
            btn_run = gr.Button("🔥 启动后台转录", variant="primary")

        # Right column: live status and results.
        with gr.Column(scale=3):
            # Ticks every 2 seconds to drive the status polling below.
            status_timer = gr.Timer(value=2, active=True)
            out_status = gr.Textbox(label="📡 后台实时监控", interactive=False)
            out_files = gr.File(label="最新生成的字幕")
            with gr.Accordion("历史记录 & 找回文件", open=False):
                btn_refresh = gr.Button("🔄 刷新磁盘文件列表")
                out_history = gr.File(label="所有已生成文件", file_count="multiple")

    # --- UI Logic ---
    # Start button: launch the background task and show its immediate reply.
    btn_run.click(
        ui_trigger_task,
        inputs=[input_audio, sel_model, sel_compute, sel_formats],
        outputs=[out_status],
    )
    # Each timer tick refreshes the status text and the latest-files widget.
    status_timer.tick(ui_poll_status, outputs=[out_status, out_files])
    # Refresh button: list the subtitle files currently on disk.
    btn_refresh.click(ui_get_history, outputs=[out_history])
if __name__ == "__main__":
    # Fetch auxiliary models before serving. This stays best-effort (the UI is
    # launched regardless), but a failure is now logged with its traceback
    # instead of being silently discarded, and Ctrl-C is no longer swallowed.
    try:
        download_vad_model()
        download_whisper_base_for_feature_extractor()
    except Exception:
        logger.exception("Pre-download of auxiliary models failed; launching UI anyway")
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)