Spaces:
Running
Running
File size: 12,205 Bytes
dc89cfc a202ab8 40a81d6 a202ab8 dc89cfc a202ab8 dc89cfc 40a81d6 dc89cfc 40a81d6 dc89cfc 40a81d6 dc89cfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import os
import tempfile
import torch
import whisperx
from flask import Flask, request, jsonify, render_template
from waitress import serve
import logging
import webbrowser
from threading import Timer
import shutil
import sys
import ffmpeg
try:
from whisperx.diarize import DiarizationPipeline
except Exception:
DiarizationPipeline = None
# --- 全局配置与初始化 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def _configure_torch_load_compat():
"""
PyTorch >=2.6 changed torch.load default `weights_only=True`, which can break
some third-party checkpoints (e.g. pyannote VAD models used by whisperx).
We only allowlist OmegaConf config objects required by those checkpoints.
"""
try:
import torch.serialization
from omegaconf import DictConfig, ListConfig
from omegaconf.base import ContainerMetadata
torch.serialization.add_safe_globals([DictConfig, ListConfig, ContainerMetadata])
logging.info("已配置 torch.load 安全全局白名单 (omegaconf DictConfig/ListConfig)。")
except Exception as e:
logging.warning(f"torch.load 兼容性配置跳过: {e}")
def _env_bool(name: str, default: bool) -> bool:
val = os.environ.get(name)
if val is None:
return default
return val.strip().lower() in {"1", "true", "yes", "y", "on"}
def get_hf_token():
    """Locate the Hugging Face access token.

    A local 'token.txt' (current working directory) takes priority; a
    readable file with non-empty content wins immediately. Otherwise the
    HUGGING_FACE_TOKEN environment variable is consulted. Returns the
    token string, or None when neither source provides one (speaker
    diarization is then unavailable).
    """
    token_file = 'token.txt'
    if os.path.exists(token_file):
        try:
            with open(token_file, 'r', encoding='utf-8') as f:
                file_token = f.read().strip()
        except Exception as e:
            logging.warning(f"无法从 {token_file} 读取令牌: {e}")
        else:
            if file_token:
                logging.info(f"成功从 {token_file} 文件中读取 Hugging Face 令牌。")
                return file_token
    env_token = os.environ.get("HUGGING_FACE_TOKEN")
    if env_token:
        logging.info("成功从环境变量中读取 Hugging Face 令牌。")
    else:
        logging.warning("在 token.txt 或环境变量中均未找到 Hugging Face 令牌。说话人分离功能将被禁用。")
    return env_token
# Resolve the Hugging Face token once at import time and register the
# torch.load compatibility shims before any model is loaded.
HF_TOKEN = get_hf_token()
_configure_torch_load_compat()
# Device / compute-type selection: CUDA with fp16 when available, otherwise
# CPU with int8 to keep memory and latency manageable.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
BATCH_SIZE = 16 if DEVICE == "cuda" else 8
# VAD backend: silero on CPU (lighter), pyannote otherwise; the VAD_METHOD
# environment variable overrides this heuristic.
VAD_METHOD = os.environ.get("VAD_METHOD") or ("silero" if DEVICE == "cpu" else "pyannote")
logging.info(f"使用设备: {DEVICE},计算类型: {COMPUTE_TYPE}")
logging.info(f"VAD 方法: {VAD_METHOD}")
# Whitelist of Whisper checkpoint names the API accepts.
ALLOWED_MODELS = ['tiny', 'base', 'small', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large-v3-turbo']
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL") or ("small" if DEVICE == "cpu" else "large-v3")
ALLOW_LARGE_ON_CPU = _env_bool("ALLOW_LARGE_ON_CPU", False)
# In-process model caches so repeated requests reuse loaded weights.
whisper_models_cache = {}     # model name -> whisperx ASR pipeline
diarize_model = None          # lazily-loaded diarization pipeline (or None)
diarize_model_loaded = False  # True once a load attempt (success or failure) was made
align_models_cache = {}       # language code -> (align model, metadata)
def get_whisper_model(model_name: str):
    """Return a cached whisperx ASR pipeline for *model_name*, loading on first use.

    Falls back to calling load_model without `vad_method` for older whisperx
    versions whose signature lacks that keyword. On failure the original
    exception is re-raised after logging (and, for apparent Hugging Face
    download problems, a console hint).
    """
    if model_name not in whisper_models_cache:
        logging.info(f"正在加载 Whisper 模型 '{model_name}'...")
        try:
            try:
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE, vad_method=VAD_METHOD)
            except TypeError:
                # Older whisperx releases do not accept the vad_method kwarg.
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE)
            whisper_models_cache[model_name] = model
            logging.info(f"模型 '{model_name}' 加载成功。")
        except Exception as e:
            logging.error(f"加载 Whisper 模型 '{model_name}' 失败: {e}")
            # BUG FIX: str.find() returns -1 (truthy!) when the substring is
            # absent and 0 (falsy) when it matches at position 0, so the old
            # `if str(e).find('huggingface'):` fired almost always except when
            # it actually should. Use a membership test instead.
            if 'huggingface' in str(e):
                print("\n\n=======可能模型下载失败,请尝试科学上网后再次重试=======\n\n")
            raise
    return whisper_models_cache[model_name]
def get_align_model(language_code: str):
    """Return the (model, metadata) alignment pair for *language_code*, cached per language."""
    pair = align_models_cache.get(language_code)
    if pair is None:
        logging.info(f"正在加载对齐模型 (language={language_code})...")
        pair = whisperx.load_align_model(language_code=language_code, device=DEVICE)
        align_models_cache[language_code] = pair
        logging.info("对齐模型加载成功。")
    return pair
def get_diarize_model():
    """Lazily load and return the speaker-diarization pipeline, or None when unavailable.

    Exactly one load attempt is ever made: the outcome (pipeline object or
    None) is memoized through the module-level diarize_model /
    diarize_model_loaded pair, so later calls return immediately.
    """
    global diarize_model, diarize_model_loaded
    if diarize_model_loaded:
        return diarize_model
    logging.info("正在尝试加载说话人分离模型...")
    diarize_model_loaded = True
    if DiarizationPipeline is None:
        # Optional dependency missing — feature stays disabled.
        logging.warning("未检测到说话人分离依赖 (DiarizationPipeline),此功能将被禁用。")
        return None
    if not HF_TOKEN:
        # Already warned about the missing token at startup.
        return None
    try:
        diarize_model = DiarizationPipeline(use_auth_token=HF_TOKEN, device=DEVICE)
        logging.info("说话人分离模型加载成功。")
    except Exception as e:
        logging.error(f"严重错误: 说话人分离模型加载失败。此功能将被禁用。错误信息: {e}")
        diarize_model = None
    return diarize_model
# --- Flask application ---
# Templates are served from the current directory: index.html lives next to
# this script rather than in a templates/ folder.
app = Flask(__name__, template_folder='.')
@app.route('/', methods=['GET'])
def index():
    # Serve the single-page front-end.
    return render_template('index.html')
@app.route('/v1/audio/transcriptions', methods=['POST'])
def audio_transcriptions():
    """OpenAI-compatible transcription endpoint.

    Expects a multipart upload under 'file' plus optional form fields:
    model, language, prompt, min_speakers / max_speakers. Diarization is
    attempted only when max_speakers > -1. Responds with JSON of the form
    {"segments": [{"start", "end", "text"[, "speaker"]}, ...]}.
    """
    if 'file' not in request.files:
        return jsonify({"error": "请求中未包含文件部分"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "未选择任何文件"}), 400
    # FIX: was a bare print(request.form) debug leftover on stdout.
    logging.debug(f"请求表单: {dict(request.form)}")
    model_id = request.form.get('model', DEFAULT_MODEL)
    # 'large-v3-turbo' is served by the large-v3 weights.
    model_name = 'large-v3' if model_id == 'large-v3-turbo' else model_id
    if model_name not in ALLOWED_MODELS:
        model_name = DEFAULT_MODEL
    if DEVICE == "cpu" and (model_name.startswith("large-") or model_name == "large") and not ALLOW_LARGE_ON_CPU:
        logging.warning(f"CPU 环境下请求大模型 '{model_name}',将自动降级为 'small' (可通过 ALLOW_LARGE_ON_CPU=1 关闭降级)。")
        model_name = "small"
    language = request.form.get('language') or None
    prompt = request.form.get('prompt')
    # FIX: reject malformed speaker counts with 400 instead of letting
    # int() blow up into the generic 500 handler below.
    try:
        max_speakers = int(request.form.get('max_speakers', -1))
        min_speakers = int(request.form.get('min_speakers', 0))
    except ValueError:
        return jsonify({"error": "min_speakers/max_speakers 必须是整数。"}), 400
    logging.info(f"收到请求: 模型='{model_id}', 语言='{language or '自动检测'}', 提示词='{'有' if prompt else '无'}'")
    input_file_path = None
    processed_wav_path = None
    try:
        # Persist the upload to disk so ffmpeg can read it.
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            file.save(tmp.name)
            input_file_path = tmp.name
        logging.info(f"正在将上传的文件 '{file.filename}' 转换为标准的 16kHz 单声道 WAV 格式...")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
            processed_wav_path = tmp_wav.name
        try:
            # Normalize to 16 kHz mono 16-bit PCM WAV; vn drops any video stream.
            (
                ffmpeg
                .input(input_file_path)
                .output(processed_wav_path, ac=1, ar=16000, acodec='pcm_s16le', vn=None)
                .run(capture_stdout=True, capture_stderr=True, overwrite_output=True)
            )
            logging.info("文件格式转换成功。")
        except ffmpeg.Error as e:
            error_details = e.stderr.decode('utf-8', errors='ignore')
            logging.error(f"FFmpeg 文件转换失败: {error_details}")
            return jsonify({"error": "音频/视频文件处理失败,可能是文件已损坏或格式不受支持。"}), 400
        audio = whisperx.load_audio(processed_wav_path)
        model = get_whisper_model(model_name)
        transcribe_options = {}
        if language:
            transcribe_options['language'] = language
        if prompt:
            transcribe_options['prompt'] = prompt
        # FIX: progress prints ('开始转录' etc.) routed through logging.
        logging.info("开始转录")
        result = model.transcribe(audio, batch_size=BATCH_SIZE, **transcribe_options)
        logging.info("转录结束,准备对齐")
        model_a, metadata = get_align_model(result["language"])
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        if max_speakers > -1:
            logging.info("进入说话人识别")
            diar_model = get_diarize_model()
            if diar_model:
                try:
                    diarize_segments = diar_model(
                        audio,
                        max_speakers=max_speakers if max_speakers > 0 else None,
                        min_speakers=min_speakers if min_speakers > 0 else None,
                    )
                    result = whisperx.assign_word_speakers(diarize_segments, result)
                except Exception as e:
                    logging.error(f"说话人分离运行时失败: {e}。将回退到单说话人模式。")
        # Collapse to single-speaker output when diarization found at most
        # one distinct voice (or was skipped entirely).
        speakers = {segment.get('speaker') for segment in result["segments"] if 'speaker' in segment}
        is_single_speaker = len(speakers) <= 1
        logging.info(f"检测到的说话人: {speakers}。单说话人模式: {'是' if is_single_speaker else '否'}")
        # Map raw pyannote labels (SPEAKER_00, ...) to friendlier names.
        speaker_mapping = {f"SPEAKER_{i:02d}": f"Speaker{i+1}" for i in range(20)}
        # FIX: was print(result), dumping the full transcript to stdout.
        logging.debug(f"对齐结果: {result}")
        formatted_segments = []
        for segment in result["segments"]:
            speaker_raw = segment.get("speaker", "SPEAKER_00")
            speaker_name = speaker_mapping.get(speaker_raw, speaker_raw)
            text = segment['text'].strip()
            if not text:
                continue  # drop empty segments
            entry = {
                "start": segment['start'],
                "end": segment['end'],
                "text": text,
            }
            if not is_single_speaker:
                entry['speaker'] = speaker_name
            formatted_segments.append(entry)
        return jsonify({"segments": formatted_segments})
    except Exception as e:
        logging.error(f"处理流程中发生未知错误: {e}", exc_info=True)
        return jsonify({"error": "处理过程中发生内部错误。"}), 500
    finally:
        # Best-effort cleanup of both temp files, whatever the outcome.
        if input_file_path and os.path.exists(input_file_path):
            os.remove(input_file_path)
            logging.info(f"已清理临时上传文件: {input_file_path}")
        if processed_wav_path and os.path.exists(processed_wav_path):
            os.remove(processed_wav_path)
            logging.info(f"已清理临时WAV文件: {processed_wav_path}")
# --- Service startup ---
def check_ffmpeg():
    """Verify an ffmpeg binary is on PATH; exit the process (code 1) if not."""
    if shutil.which("ffmpeg"):
        logging.info("FFmpeg 环境检查通过。")
        return
    logging.error("错误: 系统 PATH 中未找到 FFmpeg。")
    guidance = (
        "\n错误: 系统 PATH 中未找到 FFmpeg。",
        "请确保您已安装 FFmpeg 并且其路径已添加到系统环境变量中。",
        "Windows 安装指南: https://www.wikihow.com/Install-FFmpeg-on-Windows",
        "macOS (使用 Homebrew): brew install ffmpeg",
        "Linux (Ubuntu/Debian): sudo apt update && sudo apt install ffmpeg",
    )
    for line in guidance:
        print(line)
    sys.exit(1)
def open_browser(url):
    """Open *url* in a new browser window (invoked on a timer after startup)."""
    webbrowser.open_new(url)
if __name__ == '__main__':
    check_ffmpeg()
    # Bind address and port are overridable via HOST / PORT env vars.
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "9092"))
    url = f"http://{host}:{port}"
    # Detect Hugging Face Spaces (SPACE_ID / HF_SPACE / SYSTEM=spaces) so we
    # don't try to pop a browser inside a headless container.
    running_in_space = bool(os.environ.get("SPACE_ID")) or bool(os.environ.get("HF_SPACE")) or bool(os.environ.get("SYSTEM") == "spaces")
    if _env_bool("OPEN_BROWSER", True) and not running_in_space:
        # Give the server a second to start listening before opening the UI.
        Timer(1, lambda: open_browser(url)).start()
    logging.info(f"服务已启动,正在监听 http://{host}:{port}")
    serve(app, host=host, port=port, threads=10)
|