File size: 7,530 Bytes
da30535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
import os
import sys
import logging
from pathlib import Path
import shutil
import time
import json
import psutil
import gc
import threading

# Send INFO-and-above records to stdout, each prefixed with a timestamp and
# the level name.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Put the sibling "src" directory first on the import path so the
# project-local modules imported below resolve from there.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

# Working directories (relative to the current working directory), created
# at import time if they are missing.
Path("models").mkdir(parents=True, exist_ok=True)
Path("outputs").mkdir(parents=True, exist_ok=True)
# Text file that receives the most recent status line.
STATUS_FILE = "outputs/current_status.txt"

from download_models import download_vad_model, download_hf_model, download_whisper_base_for_feature_extractor
from faster_whisper_transwithai_chickenrice.infer import Inference, Segment, _require_faster_whisper

# Display label -> model repository id.
# The labels populate the model dropdown (in this order); backend_worker passes
# the repository id to download_hf_model and uses its last path component as
# the folder name under "models/".
MODELS = {
    "海南鸡 v2 (JP->ZH Optimized)": "chickenrice0721/whisper-large-v2-translate-zh-v0.2-st-ct2",
    "Whisper Large-v3": "Systran/faster-whisper-large-v3",
    "Whisper Medium": "Systran/faster-whisper-medium",
    "Whisper Small": "Systran/faster-whisper-small",
}

# --- GLOBAL STATE (Survives page refreshes) ---
class GlobalState:
    """Mutable task status shared between the worker thread and UI callbacks.

    The fields are plain attributes; ``set_state`` and ``ui_poll_status``
    access them while holding ``lock``.
    """

    def __init__(self):
        # Mutex protecting the three status fields below.
        self.lock = threading.Lock()
        # Output paths recorded by the most recent task that reported files.
        self.last_files = []
        # Status line shown in the monitoring textbox.
        self.msg = "就绪,等待任务..."
        # True while a task is considered to be in progress.
        self.is_running = False

# Single shared status object, created once when the module is imported.
state = GlobalState()
# Cache for the loaded model. backend_worker stores the key
# "<model_path>-<compute_type>" together with the model object built for it,
# and reuses the object while the requested key is unchanged.
GLOBAL_MODEL = {"key": None, "instance": None}

def get_mem():
    """Return the resident memory of this process as a display string.

    Returns:
        The RSS of the current process divided by 1024**3, formatted with two
        decimals and a ``GB`` suffix (e.g. ``"1.23GB"``), or ``"N/A"`` when
        the value cannot be obtained.
    """
    try:
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    except Exception:
        # Best-effort helper used only for status text, so any ordinary
        # failure degrades to a placeholder. Unlike a bare ``except:``, this
        # lets KeyboardInterrupt / SystemExit propagate.
        return "N/A"
    return f"{rss_bytes / 1024**3:.2f}GB"

def set_state(msg, is_running=True, files=None):
    """Publish a new status line to the shared state, the status file and the log.

    Args:
        msg: Status text; the current memory usage is appended to it.
        is_running: Value stored in ``state.is_running``.
        files: Output paths to record as the latest results. A falsy value
            (``None`` or an empty list) leaves the previous list untouched.
    """
    # Build the line once so the in-memory state, the status file and the log
    # all carry exactly the same text even if another thread updates the state
    # in between.
    status_line = f"{msg} | RAM: {get_mem()}"
    with state.lock:
        state.is_running = is_running
        state.msg = status_line
        if files:
            state.last_files = files

    # Persistence is best-effort: a failure to write the status file must not
    # interrupt the task, but it is reported instead of silently dropped.
    try:
        with open(STATUS_FILE, "w", encoding="utf-8") as f:
            f.write(f"[{time.strftime('%H:%M:%S')}] {status_line}")
    except (OSError, UnicodeError) as exc:
        logger.warning("Could not write status file %s: %s", STATUS_FILE, exc)
    logger.info("BACKEND: %s", status_line)

def _format_clock(seconds):
    """Format a non-negative offset in seconds as ``MM:SS`` or ``HH:MM:SS``.

    Offsets below one hour keep the short ``MM:SS`` form; from one hour on the
    hour field is included so the minutes no longer wrap around.
    """
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"


def backend_worker(audio_file, model_name, compute_type, sub_formats):
    """Transcribe one audio file and write the requested subtitle files.

    Intended to run in a background thread. Progress, completion and failure
    are all reported through ``set_state``; nothing is returned and no
    exception escapes.

    Args:
        audio_file: Path of the media file to transcribe.
        model_name: A key of ``MODELS`` selecting the model repository.
        compute_type: Compute type forwarded to the model constructor.
        sub_formats: Iterable of subtitle format names; formats without a
            writer in ``inference.sub_writers`` are skipped.
    """
    global GLOBAL_MODEL
    try:
        model_id = MODELS[model_name]
        repo_name = model_id.split("/")[-1]
        model_path = os.path.join("models", repo_name)

        if not os.path.exists(model_path):
            set_state("正在下载模型文件...")
            download_hf_model(model_id)

        # Ad-hoc options object carrying the attributes handed to Inference.
        args = type('Args', (), {
            'model_name_or_path': model_path, 'device': 'cpu', 'compute_type': compute_type,
            'overwrite': True, 'audio_suffixes': 'wav,flac,mp3,m4a,mp4,mkv,avi,webm,mov,flv',
            'sub_formats': ",".join(sub_formats), 'output_dir': 'outputs', 'generation_config': 'generation_config.json5',
            'enable_batching': False, 'batch_size': 1, 'max_batch_size': 1,
            'vad_threshold': None, 'vad_min_speech_duration_ms': None, 'vad_min_silence_duration_ms': None,
            'vad_speech_pad_ms': None, 'merge_segments': None, 'merge_max_gap_ms': None, 'merge_max_duration_ms': None,
            'base_dirs': []
        })()

        inference = Inference(args)
        model_key = f"{model_path}-{compute_type}"

        if GLOBAL_MODEL["key"] != model_key:
            set_state("加载模型引擎中...")
            # Invalidate the key together with the instance. Otherwise a
            # failed load would leave the old key paired with a None instance,
            # and the next request for the old model would skip loading and
            # crash on ``None.transcribe``.
            GLOBAL_MODEL["key"] = None
            GLOBAL_MODEL["instance"] = None
            # Drop the previous model before building the next one.
            gc.collect()
            WhisperModelCls, _ = _require_faster_whisper()
            GLOBAL_MODEL["instance"] = WhisperModelCls(model_path, device="cpu", compute_type=compute_type, cpu_threads=2)
            GLOBAL_MODEL["key"] = model_key

        model = GLOBAL_MODEL["instance"]
        set_state("转录开始...")

        segments, _ = model.transcribe(audio_file, **inference.generation_config)

        # Collect the segments one by one so progress can be reported as each
        # one is produced.
        all_segments = []
        for s in segments:
            all_segments.append(s)
            ts = f"{_format_clock(s.start)} -> {_format_clock(s.end)}"
            set_state(f"处理中: [{ts}] {s.text}")

        if not all_segments:
            set_state("识别为空", is_running=False)
            return

        set_state("正在保存字幕文件...")
        base_name = os.path.splitext(os.path.basename(audio_file))[0]
        output_paths = []
        # Segment takes integer milliseconds; the model reports seconds.
        app_segments = [Segment(start=int(s.start * 1000), end=int(s.end * 1000), text=s.text) for s in all_segments]

        for fmt in sub_formats:
            out_file = os.path.join("outputs", f"{base_name}.{fmt}")
            writer = inference.sub_writers.get(fmt)
            if writer:
                writer(app_segments, out_file)
                output_paths.append(out_file)

        set_state("任务完成!", is_running=False, files=output_paths)

    except Exception as e:
        # Keep the full traceback in the log; the UI only gets the message.
        logger.exception("Background transcription task failed")
        set_state(f"致命错误: {str(e)}", is_running=False)

def ui_trigger_task(audio_file, model_name, compute_type, sub_formats):
    """Start a background transcription unless one is already running.

    Args:
        audio_file: Path of the uploaded media file; falsy when nothing was
            uploaded.
        model_name: A key of ``MODELS``.
        compute_type: Compute type forwarded to the worker.
        sub_formats: Subtitle formats forwarded to the worker.

    Returns:
        A short status message for the UI.

    Raises:
        RuntimeError: If the worker thread cannot be started; the running
            flag is cleared before the error propagates.
    """
    # Check and claim the running flag in one critical section, so two
    # near-simultaneous clicks cannot both pass the check and start two
    # workers.
    with state.lock:
        if state.is_running:
            return "警告:已有任务正在后台运行。"
        if not audio_file:
            return "错误:请先上传音频。"
        state.is_running = True

    # set_state takes the (non-reentrant) lock itself, so call it only after
    # the critical section above has been left.
    set_state("初始化后台线程...")
    worker = threading.Thread(
        target=backend_worker,
        args=(audio_file, model_name, compute_type, sub_formats),
        daemon=False,
    )
    try:
        worker.start()
    except RuntimeError:
        # The worker never ran, so nothing else will clear the flag.
        with state.lock:
            state.is_running = False
        raise
    return "后台任务已启动!"

def ui_poll_status():
    """Timer callback: return the status line and, when idle, the latest files.

    Returns:
        A ``(message, files)`` tuple. ``files`` is ``None`` while a task is
        running and ``state.last_files`` otherwise.
    """
    with state.lock:
        if state.is_running:
            finished_files = None
        else:
            finished_files = state.last_files
        return state.msg, finished_files

def ui_get_history():
    """List subtitle-like files in ``outputs``, most recently modified first.

    Returns:
        Paths of the form ``outputs/<name>`` for every directory entry ending
        in ``.srt``, ``.vtt``, ``.lrc`` or ``.txt``, sorted by modification
        time in descending order.

    NOTE(review): the status file ``outputs/current_status.txt`` also ends in
    ``.txt``, so it is part of the returned list.
    """
    subtitle_suffixes = (".srt", ".vtt", ".lrc", ".txt")
    matches = []
    for entry in os.listdir("outputs"):
        if entry.endswith(subtitle_suffixes):
            matches.append(os.path.join("outputs", entry))
    return sorted(matches, key=os.path.getmtime, reverse=True)

# Build the web UI: inputs and the start button on the left, live status,
# latest results and the on-disk history on the right.
with gr.Blocks(title="Faster Whisper ChickenRice") as demo:
    gr.Markdown("# 🎙️ Faster Whisper ChickenRice (稳定监控版)")

    with gr.Row():
        # Left column: task inputs.
        with gr.Column(scale=2):
            # The worker treats this value as a file path.
            input_audio = gr.Audio(label="上传音视频", type="filepath")
            with gr.Row():
                sel_model = gr.Dropdown(choices=list(MODELS.keys()), value="海南鸡 v2 (JP->ZH Optimized)", label="模型")
                sel_compute = gr.Dropdown(choices=["int8", "float32"], value="int8", label="精度")
            sel_formats = gr.CheckboxGroup(choices=["srt", "vtt", "lrc", "txt"], value=["srt", "vtt"], label="格式")
            btn_run = gr.Button("🔥 启动后台转录", variant="primary")

        # Right column: monitoring and results.
        with gr.Column(scale=3):
            # Timer that drives the status polling below
            # (value=2 is presumably the interval in seconds; confirm against the Gradio docs).
            status_timer = gr.Timer(value=2, active=True)
            out_status = gr.Textbox(label="📡 后台实时监控", interactive=False)
            out_files = gr.File(label="最新生成的字幕")
            with gr.Accordion("历史记录 & 找回文件", open=False):
                btn_refresh = gr.Button("🔄 刷新磁盘文件列表")
                out_history = gr.File(label="所有已生成文件", file_count="multiple")

    # --- Event wiring ---
    # The start button only launches the background task and shows its
    # immediate reply in the status box.
    btn_run.click(fn=ui_trigger_task, inputs=[input_audio, sel_model, sel_compute, sel_formats], outputs=[out_status])

    # Poll the shared state on every timer tick, refreshing the status box and
    # the latest-files widget. gr.Timer is a newer Gradio component
    # (reportedly 4.40+; confirm the minimum version).
    status_timer.tick(fn=ui_poll_status, outputs=[out_status, out_files])

    # Re-scan the outputs directory on demand.
    btn_refresh.click(fn=ui_get_history, outputs=[out_history])

if __name__ == "__main__":
    # Fetch the auxiliary assets before serving. This stays best-effort: a
    # failure is logged and the UI is launched anyway.
    try:
        download_vad_model()
        download_whisper_base_for_feature_extractor()
    except (Exception, SystemExit):
        # SystemExit is included in case the download helpers exit on failure
        # (not visible from here), which the original bare ``except`` also
        # tolerated. KeyboardInterrupt is deliberately not caught, so Ctrl-C
        # during the downloads stops the program.
        logger.exception("Pre-launch asset download failed; starting the UI anyway")
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)