Spaces:
Sleeping
Sleeping
File size: 7,530 Bytes
da30535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import gradio as gr
import os
import sys
import logging
from pathlib import Path
import shutil
import time
import json
import psutil
import gc
import threading
# Send log records to stdout at INFO level and above.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Make the packages under <directory of this file>/src importable. Inserting at
# index 0 lets them take precedence over same-named installed packages.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

# Working directories, resolved against the current working directory:
# "models" receives downloaded model folders, "outputs" the subtitle files
# and the status file below.
os.makedirs("models", exist_ok=True)
os.makedirs("outputs", exist_ok=True)

# Holds the most recent status line (rewritten by set_state).
STATUS_FILE = "outputs/current_status.txt"
from download_models import download_vad_model, download_hf_model, download_whisper_base_for_feature_extractor
from faster_whisper_transwithai_chickenrice.infer import Inference, Segment, _require_faster_whisper
# Display name (shown in the model dropdown) -> Hugging Face repo id, which is
# handed to download_hf_model and whose last path component names the local
# folder under models/. Insertion order is the dropdown order; the first key is
# also the dropdown's default value.
MODELS = {
    # Judging by label and repo id, a large-v2 variant tuned for JP->ZH
    # translation (not verified here).
    "海南鸡 v2 (JP->ZH Optimized)": "chickenrice0721/whisper-large-v2-translate-zh-v0.2-st-ct2",
    # Stock faster-whisper checkpoints, from largest to smallest.
    "Whisper Large-v3": "Systran/faster-whisper-large-v3",
    "Whisper Medium": "Systran/faster-whisper-medium",
    "Whisper Small": "Systran/faster-whisper-small",
}
# --- GLOBAL STATE (module-level, so it outlives any single page load) ---
class GlobalState:
    """Task status shared between the background worker and the UI callbacks.

    Attributes (all guarded by ``lock``):
        is_running: True while a task has been claimed / is in progress.
        msg: Latest human-readable status line.
        last_files: Output paths published by the most recent successful task.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.last_files = []
        self.msg = "就绪,等待任务..."
        self.is_running = False


# The single shared status object.
state = GlobalState()
# Cache of the loaded model: "key" identifies it (model path + compute type),
# "instance" is the loaded model object, or None before the first load.
GLOBAL_MODEL = {"key": None, "instance": None}
def get_mem():
    """Return this process's resident memory (RSS) formatted like ``"1.23GB"``.

    Best effort: if querying the process raises any ``Exception``, ``"N/A"`` is
    returned instead, because the value only decorates status messages.
    KeyboardInterrupt / SystemExit are deliberately NOT swallowed.
    """
    try:
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    except Exception:  # diagnostics only; never break a status update over it
        return "N/A"
    return f"{rss_bytes / 1024**3:.2f}GB"
def set_state(msg, is_running=True, files=None):
    """Publish a new status: update ``state``, mirror it to STATUS_FILE, and log it.

    Args:
        msg: Status text; `` | RAM: <get_mem()>`` is appended to it.
        is_running: New value for ``state.is_running`` (default True = busy).
        files: Output paths to publish. A falsy value (None or an empty list)
            leaves the previously published ``state.last_files`` untouched.
    """
    with state.lock:
        state.is_running = is_running
        state.msg = f"{msg} | RAM: {get_mem()}"
        if files:
            state.last_files = files
        # Snapshot under the lock so the log line below matches this update even
        # if another thread publishes a newer status right after we release it.
        published = state.msg
        # Persist the latest status line. Writing while holding the lock keeps
        # the file in the same order as the in-memory updates.
        try:
            with open(STATUS_FILE, "w", encoding="utf-8") as f:
                f.write(f"[{time.strftime('%H:%M:%S')}] {published}")
        except OSError as exc:
            # Best effort: a failed write must not abort the task, but should be visible.
            logger.warning("Could not write status file %s: %s", STATUS_FILE, exc)
    logger.info("BACKEND: %s", published)
def _format_timestamp(seconds):
    """Format an offset in seconds as ``MM:SS`` without wrapping minutes at 60.

    For example, 3725.5 s renders as ``62:05``; ``time.strftime('%M:%S', ...)``
    would have shown ``02:05`` for anything past the first hour.
    """
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def _ensure_model_downloaded(model_id):
    """Return the local folder ``models/<repo name>`` for *model_id*.

    If that folder does not exist, download_hf_model(model_id) is called first.
    NOTE(review): assumes download_hf_model populates exactly this folder —
    confirm in download_models.
    """
    model_path = os.path.join("models", model_id.split("/")[-1])
    if not os.path.exists(model_path):
        set_state("正在下载模型文件...")
        download_hf_model(model_id)
    return model_path


def _build_inference_args(model_path, compute_type, sub_formats):
    """Build the attribute-only settings object that Inference is constructed with.

    The settings live as class attributes of an ad-hoc ``Args`` class; this
    presumably stands in for the CLI's argument namespace (verify against Inference).
    """
    return type('Args', (), {
        'model_name_or_path': model_path, 'device': 'cpu', 'compute_type': compute_type,
        'overwrite': True, 'audio_suffixes': 'wav,flac,mp3,m4a,mp4,mkv,avi,webm,mov,flv',
        'sub_formats': ",".join(sub_formats), 'output_dir': 'outputs', 'generation_config': 'generation_config.json5',
        'enable_batching': False, 'batch_size': 1, 'max_batch_size': 1,
        'vad_threshold': None, 'vad_min_speech_duration_ms': None, 'vad_min_silence_duration_ms': None,
        'vad_speech_pad_ms': None, 'merge_segments': None, 'merge_max_gap_ms': None, 'merge_max_duration_ms': None,
        'base_dirs': [],
    })()


def _get_model(model_path, compute_type):
    """Return the cached model for (model_path, compute_type), loading it on a cache miss.

    Only one model is kept in GLOBAL_MODEL; a different key evicts the old one.
    """
    model_key = f"{model_path}-{compute_type}"
    if GLOBAL_MODEL["key"] != model_key:
        set_state("加载模型引擎中...")
        # Forget the previous model AND its key before loading. If the load below
        # raises, the cache is left empty instead of pairing the old key with a
        # None instance (which made later requests for that model crash).
        GLOBAL_MODEL["instance"] = None
        GLOBAL_MODEL["key"] = None
        gc.collect()  # release the evicted model before allocating the new one
        WhisperModelCls, _ = _require_faster_whisper()
        GLOBAL_MODEL["instance"] = WhisperModelCls(model_path, device="cpu", compute_type=compute_type, cpu_threads=2)
        GLOBAL_MODEL["key"] = model_key
    return GLOBAL_MODEL["instance"]


def _write_subtitles(inference, segments, audio_file, sub_formats):
    """Write *segments* to ``outputs/<audio file stem>.<fmt>`` for each requested format.

    Formats with no writer in ``inference.sub_writers`` are skipped silently.
    Returns the list of paths actually written.
    """
    base_name = os.path.splitext(os.path.basename(audio_file))[0]
    # Segment offsets are converted from seconds to integer milliseconds.
    app_segments = [
        Segment(start=int(s.start * 1000), end=int(s.end * 1000), text=s.text)
        for s in segments
    ]
    output_paths = []
    for fmt in sub_formats:
        writer = inference.sub_writers.get(fmt)
        if not writer:
            continue
        out_file = os.path.join("outputs", f"{base_name}.{fmt}")
        writer(app_segments, out_file)
        output_paths.append(out_file)
    return output_paths


def backend_worker(audio_file, model_name, compute_type, sub_formats):
    """Run one transcription task end to end (started on a background thread).

    Steps: make sure the model is on disk, (re)load it if the cached one differs,
    transcribe *audio_file* while publishing each segment as the live status, then
    write one subtitle file per entry in *sub_formats*.

    Progress and outcome are reported only through set_state; the function
    returns None. Exceptions raised by the task are logged with a traceback and
    reported as a "致命错误" status with ``is_running=False``.

    Args:
        audio_file: Path of the uploaded audio/video file.
        model_name: Display name; a key of MODELS.
        compute_type: Precision passed to the model, e.g. "int8" or "float32".
        sub_formats: Subtitle extensions to write, e.g. ["srt", "vtt"].
    """
    try:
        model_path = _ensure_model_downloaded(MODELS[model_name])
        inference = Inference(_build_inference_args(model_path, compute_type, sub_formats))
        model = _get_model(model_path, compute_type)

        set_state("转录开始...")
        segments, _ = model.transcribe(audio_file, **inference.generation_config)
        all_segments = []
        for seg in segments:  # consumed incrementally; each segment updates the live status
            all_segments.append(seg)
            ts = f"{_format_timestamp(seg.start)} -> {_format_timestamp(seg.end)}"
            set_state(f"处理中: [{ts}] {seg.text}")

        if not all_segments:
            set_state("识别为空", is_running=False)
            return

        set_state("正在保存字幕文件...")
        output_paths = _write_subtitles(inference, all_segments, audio_file, sub_formats)
        set_state("任务完成!", is_running=False, files=output_paths)
    except Exception as e:
        # Top-level boundary of the worker thread: keep the traceback in the log,
        # surface the error in the UI, and clear the running flag for the next task.
        logger.exception("Background transcription task failed")
        set_state(f"致命错误: {str(e)}", is_running=False)
def ui_trigger_task(audio_file, model_name, compute_type, sub_formats):
    """Validate a run request and start backend_worker on a background thread.

    Returns a short message for the status box immediately; the transcription
    itself reports progress through set_state.
    """
    # Check and claim the running flag in one critical section so two rapid
    # requests cannot both pass the check and start two workers.
    with state.lock:
        if state.is_running:
            return "警告:已有任务正在后台运行。"
        if not audio_file:
            return "错误:请先上传音频。"
        if not sub_formats:
            # Without any format the worker would transcribe and then write nothing.
            return "错误:请至少选择一种字幕格式。"
        state.is_running = True
    # set_state acquires state.lock itself, so it must run after the block above.
    set_state("初始化后台线程...")
    worker = threading.Thread(
        target=backend_worker,
        args=(audio_file, model_name, compute_type, sub_formats),
        daemon=False,  # non-daemon: interpreter shutdown waits for the task to finish
    )
    worker.start()
    return "后台任务已启动!"
def ui_poll_status():
    """Take a consistent snapshot of the shared state for the status timer.

    Returns:
        A ``(message, files)`` pair: the latest status text, plus the last
        published file list when no task is running, or None while one is.
    """
    with state.lock:
        message = state.msg
        files = None if state.is_running else state.last_files
    return message, files
def ui_get_history():
    """List subtitle files in ``outputs/``, most recently modified first.

    Only ``.srt``/``.vtt``/``.lrc``/``.txt`` files are returned. The app's own
    status file (STATUS_FILE, itself a ``.txt`` in ``outputs/``) is excluded so
    it is not offered as a generated subtitle. A missing ``outputs/`` yields [].
    """
    status_name = os.path.basename(STATUS_FILE)
    try:
        names = os.listdir("outputs")
    except FileNotFoundError:
        return []
    files = [
        os.path.join("outputs", name)
        for name in names
        if name.endswith((".srt", ".vtt", ".lrc", ".txt")) and name != status_name
    ]
    files.sort(key=os.path.getmtime, reverse=True)
    return files
# --- UI layout ---------------------------------------------------------------
# Builds the Gradio page. Components are placed in the order they are created
# inside the nested `with` contexts, so statement order here is significant.
# NOTE(review): the nesting below was reconstructed because the source's
# indentation was lost — confirm the intended layout.
with gr.Blocks(title="Faster Whisper ChickenRice") as demo:
    gr.Markdown("# 🎙️ Faster Whisper ChickenRice (稳定监控版)")
    with gr.Row():
        # Left column: task inputs.
        with gr.Column(scale=2):
            # type="filepath": callbacks receive the upload as a file path.
            input_audio = gr.Audio(label="上传音视频", type="filepath")
            with gr.Row():
                # The default value is the first MODELS key.
                sel_model = gr.Dropdown(choices=list(MODELS.keys()), value="海南鸡 v2 (JP->ZH Optimized)", label="模型")
                sel_compute = gr.Dropdown(choices=["int8", "float32"], value="int8", label="精度")
            sel_formats = gr.CheckboxGroup(choices=["srt", "vtt", "lrc", "txt"], value=["srt", "vtt"], label="格式")
            btn_run = gr.Button("🔥 启动后台转录", variant="primary")
        # Right column: live status, latest outputs, and the on-disk history.
        with gr.Column(scale=3):
            # Polling timer (value=2); its tick event is wired to ui_poll_status
            # below. gr.Timer requires Gradio 4.40+.
            status_timer = gr.Timer(value=2, active=True)
            out_status = gr.Textbox(label="📡 后台实时监控", interactive=False)
            out_files = gr.File(label="最新生成的字幕")
            with gr.Accordion("历史记录 & 找回文件", open=False):
                btn_refresh = gr.Button("🔄 刷新磁盘文件列表")
                out_history = gr.File(label="所有已生成文件", file_count="multiple")
    # --- UI Logic ---
    # Starts a background task; the handler returns right away with a short message.
    btn_run.click(fn=ui_trigger_task, inputs=[input_audio, sel_model, sel_compute, sel_formats], outputs=[out_status])
    # Each timer tick reads (status message, files) from the module-level shared
    # state, so a reloaded page picks up the progress of an already-running task.
    status_timer.tick(fn=ui_poll_status, outputs=[out_status, out_files])
    # Re-scans outputs/ on demand.
    btn_refresh.click(fn=ui_get_history, outputs=[out_history])
if __name__ == "__main__":
    # Pre-fetch auxiliary models at startup. A failure here is logged with its
    # traceback but is not fatal: the UI still launches, and the selected
    # transcription model is downloaded per task by backend_worker.
    try:
        download_vad_model()
        download_whisper_base_for_feature_extractor()
    except Exception:
        logger.exception("Pre-downloading auxiliary models failed; continuing startup")
    demo.launch(server_name="0.0.0.0", server_port=7860)
|