File size: 12,205 Bytes
dc89cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a202ab8
 
 
 
 
 
 
 
 
40a81d6
 
a202ab8
 
 
 
dc89cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a202ab8
dc89cfc
 
 
 
 
40a81d6
dc89cfc
 
40a81d6
dc89cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a81d6
 
 
 
dc89cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289

import os
import tempfile
import torch
import whisperx
from flask import Flask, request, jsonify, render_template
from waitress import serve
import logging
import webbrowser
from threading import Timer
import shutil
import sys
import ffmpeg
try:
    from whisperx.diarize import DiarizationPipeline
except Exception:
    DiarizationPipeline = None

# --- Global configuration & initialization ---

# Root-logger setup used by every module-level and request-handler log call below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def _configure_torch_load_compat():
    """
    PyTorch >=2.6 changed torch.load default `weights_only=True`, which can break
    some third-party checkpoints (e.g. pyannote VAD models used by whisperx).
    We only allowlist OmegaConf config objects required by those checkpoints.
    """
    try:
        import torch.serialization
        from omegaconf import DictConfig, ListConfig
        from omegaconf.base import ContainerMetadata
        torch.serialization.add_safe_globals([DictConfig, ListConfig, ContainerMetadata])
        logging.info("已配置 torch.load 安全全局白名单 (omegaconf DictConfig/ListConfig)。")
    except Exception as e:
        logging.warning(f"torch.load 兼容性配置跳过: {e}")

def _env_bool(name: str, default: bool) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in {"1", "true", "yes", "y", "on"}

def get_hf_token():
    """Resolve the Hugging Face access token.

    A non-empty ``token.txt`` in the working directory wins; otherwise the
    ``HUGGING_FACE_TOKEN`` environment variable is consulted. Returns the
    token string, or None when neither source provides one (diarization is
    then disabled).
    """
    token_file = 'token.txt'
    if os.path.exists(token_file):
        try:
            with open(token_file, 'r', encoding='utf-8') as f:
                file_token = f.read().strip()
        except Exception as e:
            logging.warning(f"无法从 {token_file} 读取令牌: {e}")
        else:
            if file_token:
                logging.info(f"成功从 {token_file} 文件中读取 Hugging Face 令牌。")
                return file_token

    env_token = os.environ.get("HUGGING_FACE_TOKEN")
    if env_token:
        logging.info("成功从环境变量中读取 Hugging Face 令牌。")
    else:
        logging.warning("在 token.txt 或环境变量中均未找到 Hugging Face 令牌。说话人分离功能将被禁用。")
    return env_token

HF_TOKEN = get_hf_token()
_configure_torch_load_compat()

# Device and compute-type configuration: fp16 on GPU, int8 quantization on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
BATCH_SIZE = 16 if DEVICE == "cuda" else 8
# VAD backend: env override wins; otherwise silero on CPU, pyannote on GPU.
VAD_METHOD = os.environ.get("VAD_METHOD") or ("silero" if DEVICE == "cpu" else "pyannote")

logging.info(f"使用设备: {DEVICE},计算类型: {COMPUTE_TYPE}")
logging.info(f"VAD 方法: {VAD_METHOD}")

# Model configuration: which names the endpoint accepts, and the fallback default.
ALLOWED_MODELS = ['tiny', 'base', 'small', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large-v3-turbo']
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL") or ("small" if DEVICE == "cpu" else "large-v3")
ALLOW_LARGE_ON_CPU = _env_bool("ALLOW_LARGE_ON_CPU", False)

# Process-lifetime model caches, filled lazily by the getters below.
whisper_models_cache = {}  # model name -> loaded whisperx model
diarize_model = None  # diarization pipeline, or None when unavailable
diarize_model_loaded = False  # True once a load attempt has been made (success or not)
align_models_cache = {}  # language code -> (alignment model, metadata)

def get_whisper_model(model_name: str):
    """Return a cached Whisper model, loading and caching it on first use.

    The configured VAD method is passed when the installed whisperx version
    supports it; a TypeError triggers a fallback call without ``vad_method``.

    Raises:
        Whatever exception ``whisperx.load_model`` raised on failure (after
        logging and, for huggingface-related errors, printing a hint).
    """
    if model_name not in whisper_models_cache:
        logging.info(f"正在加载 Whisper 模型 '{model_name}'...")
        try:
            try:
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE, vad_method=VAD_METHOD)
            except TypeError:
                # Older whisperx signatures do not accept vad_method.
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE)
            whisper_models_cache[model_name] = model
            logging.info(f"模型 '{model_name}' 加载成功。")
        except Exception as e:
            logging.error(f"加载 Whisper 模型 '{model_name}' 失败: {e}")
            # BUG FIX: str.find() returns -1 (truthy) when the substring is
            # absent, so the old `if str(e).find('huggingface'):` printed the
            # download hint for nearly every error. Use a membership test.
            if 'huggingface' in str(e):
                print(f"\n\n=======可能模型下载失败,请尝试科学上网后再次重试=======\n\n")
            raise
    return whisper_models_cache[model_name]

def get_align_model(language_code: str):
    """Return the cached ``(model, metadata)`` alignment pair for *language_code*, loading it on first request."""
    cached = align_models_cache.get(language_code)
    if cached is not None:
        return cached
    logging.info(f"正在加载对齐模型 (language={language_code})...")
    model, meta = whisperx.load_align_model(language_code=language_code, device=DEVICE)
    align_models_cache[language_code] = (model, meta)
    logging.info("对齐模型加载成功。")
    return align_models_cache[language_code]

def get_diarize_model():
    """Lazily load the speaker-diarization pipeline, at most once.

    Returns the pipeline, or None when the diarization dependency is
    missing, no Hugging Face token is configured, or loading failed.
    The outcome of the single load attempt is memoized via the module-level
    ``diarize_model`` / ``diarize_model_loaded`` globals.
    """
    global diarize_model, diarize_model_loaded

    if diarize_model_loaded:
        return diarize_model

    diarize_model_loaded = True
    logging.info("正在尝试加载说话人分离模型...")
    if DiarizationPipeline is None:
        logging.warning("未检测到说话人分离依赖 (DiarizationPipeline),此功能将被禁用。")
        return None
    if not HF_TOKEN:
        return None
    try:
        diarize_model = DiarizationPipeline(use_auth_token=HF_TOKEN, device=DEVICE)
        logging.info("说话人分离模型加载成功。")
    except Exception as e:
        logging.error(f"严重错误: 说话人分离模型加载失败。此功能将被禁用。错误信息: {e}")
        diarize_model = None
    return diarize_model

# --- Flask application ---
# template_folder='.' so index.html is served from the script's own directory.
app = Flask(__name__, template_folder='.')

@app.route('/', methods=['GET'])
def index():
    """Serve the single-page UI (index.html)."""
    return render_template('index.html')

@app.route('/v1/audio/transcriptions', methods=['POST'])
def audio_transcriptions():
    """OpenAI-style transcription endpoint.

    Multipart form fields:
        file         -- required audio/video upload.
        model        -- Whisper model id; unknown values fall back to DEFAULT_MODEL.
        language     -- optional language code (auto-detected when absent).
        prompt       -- optional decoder prompt.
        max_speakers -- diarization runs when > -1; a positive value caps speakers.
        min_speakers -- optional lower bound for diarization (used when > 0).

    Returns JSON ``{"segments": [{start, end, text[, speaker]}]}`` on
    success, otherwise ``{"error": ...}`` with a 400/500 status.
    """
    if 'file' not in request.files:
        return jsonify({"error": "请求中未包含文件部分"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "未选择任何文件"}), 400

    # FIX: removed debug `print(request.form)` -- it leaked raw form data to stdout.
    model_id = request.form.get('model', DEFAULT_MODEL)
    # 'large-v3-turbo' is served by the large-v3 weights.
    model_name = 'large-v3' if model_id == 'large-v3-turbo' else model_id
    if model_name not in ALLOWED_MODELS:
        model_name = DEFAULT_MODEL
    if DEVICE == "cpu" and (model_name.startswith("large-") or model_name == "large") and not ALLOW_LARGE_ON_CPU:
        logging.warning(f"CPU 环境下请求大模型 '{model_name}',将自动降级为 'small' (可通过 ALLOW_LARGE_ON_CPU=1 关闭降级)。")
        model_name = "small"

    language = request.form.get('language') or None
    prompt = request.form.get('prompt')
    # FIX: these int() conversions used to run unguarded, so malformed form
    # values raised an unhandled ValueError (HTML 500). Fall back to defaults.
    try:
        max_speakers = int(request.form.get('max_speakers', -1))
    except (TypeError, ValueError):
        max_speakers = -1
    try:
        min_speakers = int(request.form.get('min_speakers', 0))
    except (TypeError, ValueError):
        min_speakers = 0

    logging.info(f"收到请求: 模型='{model_id}', 语言='{language or '自动检测'}', 提示词='{'有' if prompt else '无'}'")

    input_file_path = None
    processed_wav_path = None
    try:
        # Persist the upload to a temp file so ffmpeg can read it by path.
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            file.save(tmp.name)
            input_file_path = tmp.name

        logging.info(f"正在将上传的文件 '{file.filename}' 转换为标准的 16kHz 单声道 WAV 格式...")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
            processed_wav_path = tmp_wav.name

        try:
            # Normalize to 16 kHz mono s16 PCM; vn=None drops any video stream.
            (
                ffmpeg
                .input(input_file_path)
                .output(processed_wav_path, ac=1, ar=16000, acodec='pcm_s16le', vn=None)
                .run(capture_stdout=True, capture_stderr=True, overwrite_output=True)
            )
            logging.info("文件格式转换成功。")
        except ffmpeg.Error as e:
            error_details = e.stderr.decode('utf-8', errors='ignore')
            logging.error(f"FFmpeg 文件转换失败: {error_details}")
            return jsonify({"error": "音频/视频文件处理失败,可能是文件已损坏或格式不受支持。"}), 400

        audio = whisperx.load_audio(processed_wav_path)
        model = get_whisper_model(model_name)

        transcribe_options = {}
        if language:
            transcribe_options['language'] = language
        if prompt:
            # whisperx expects the parameter name 'prompt'.
            transcribe_options['prompt'] = prompt
        logging.info("开始转录")
        result = model.transcribe(audio, batch_size=BATCH_SIZE, **transcribe_options)
        logging.info("转录结束,准备对齐")
        model_a, metadata = get_align_model(result["language"])
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)

        if max_speakers > -1:
            logging.info("进入说话人识别")
            diar_model = get_diarize_model()
            if diar_model:
                try:
                    diarize_segments = diar_model(
                        audio,
                        max_speakers=max_speakers if max_speakers > 0 else None,
                        min_speakers=min_speakers if min_speakers > 0 else None,
                    )
                    result = whisperx.assign_word_speakers(diarize_segments, result)
                except Exception as e:
                    logging.error(f"说话人分离运行时失败: {e}。将回退到单说话人模式。")

        speakers = {segment.get('speaker') for segment in result["segments"] if 'speaker' in segment}
        is_single_speaker = len(speakers) <= 1
        logging.info(f"检测到的说话人: {speakers}。单说话人模式: {'是' if is_single_speaker else '否'}")

        # FIX: removed debug `print(result)` -- it dumped the full transcript to stdout.
        # Generalized: the old fixed mapping only covered SPEAKER_00..SPEAKER_19;
        # this renames any "SPEAKER_NN" label and passes everything else through.
        def _friendly_speaker(raw):
            prefix = "SPEAKER_"
            if raw.startswith(prefix) and raw[len(prefix):].isdigit():
                return f"Speaker{int(raw[len(prefix):]) + 1}"
            return raw

        formatted_segments = []
        for segment in result["segments"]:
            text = segment['text'].strip()
            if not text:
                continue
            entry = {
                "start": segment['start'],
                "end": segment['end'],
                "text": text
            }
            # Only tag speakers when more than one speaker was detected.
            if not is_single_speaker:
                entry['speaker'] = _friendly_speaker(segment.get("speaker", "SPEAKER_00"))
            formatted_segments.append(entry)

        return jsonify({"segments": formatted_segments})

    except Exception as e:
        logging.error(f"处理流程中发生未知错误: {e}", exc_info=True)
        return jsonify({"error": "处理过程中发生内部错误。"}), 500
    finally:
        # Best-effort cleanup of both temp files regardless of outcome.
        if input_file_path and os.path.exists(input_file_path):
            os.remove(input_file_path)
            logging.info(f"已清理临时上传文件: {input_file_path}")
        if processed_wav_path and os.path.exists(processed_wav_path):
            os.remove(processed_wav_path)
            logging.info(f"已清理临时WAV文件: {processed_wav_path}")

# --- Service startup ---
def check_ffmpeg():
    """Verify an ffmpeg binary is on PATH; exit the process (code 1) if not."""
    if shutil.which("ffmpeg"):
        logging.info("FFmpeg 环境检查通过。")
        return
    logging.error("错误: 系统 PATH 中未找到 FFmpeg。")
    print("\n错误: 系统 PATH 中未找到 FFmpeg。")
    print("请确保您已安装 FFmpeg 并且其路径已添加到系统环境变量中。")
    print("Windows 安装指南: https://www.wikihow.com/Install-FFmpeg-on-Windows")
    print("macOS (使用 Homebrew): brew install ffmpeg")
    print("Linux (Ubuntu/Debian): sudo apt update && sudo apt install ffmpeg")
    sys.exit(1)

def open_browser(url):
    """Open *url* in a new browser window (scheduled shortly after server start)."""
    webbrowser.open_new(url)

if __name__ == '__main__':
    check_ffmpeg()
    # Bind address/port are overridable via HOST/PORT environment variables.
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "9092"))
    url = f"http://{host}:{port}"
    # Detect Hugging Face Spaces hosting, where opening a local browser is pointless.
    running_in_space = bool(os.environ.get("SPACE_ID")) or bool(os.environ.get("HF_SPACE")) or bool(os.environ.get("SYSTEM") == "spaces")
    if _env_bool("OPEN_BROWSER", True) and not running_in_space:
        # Delay so the server is listening before the browser hits it.
        Timer(1, lambda: open_browser(url)).start()
    logging.info(f"服务已启动,正在监听 http://{host}:{port}")
    serve(app, host=host, port=port, threads=10)