# whisperx-api / app.py
import os
import tempfile
import torch
import whisperx
from flask import Flask, request, jsonify, render_template
from waitress import serve
import logging
import webbrowser
from threading import Timer
import shutil
import sys
import ffmpeg
try:
    from whisperx.diarize import DiarizationPipeline
except Exception:
    DiarizationPipeline = None
# --- Global configuration and initialization ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def _configure_torch_load_compat():
    """
    PyTorch >= 2.6 changed the torch.load default to `weights_only=True`, which can
    break some third-party checkpoints (e.g. the pyannote VAD models used by
    whisperx). We only allowlist the OmegaConf config objects those checkpoints need.
    """
    try:
        import torch.serialization
        from omegaconf import DictConfig, ListConfig
        from omegaconf.base import ContainerMetadata
        torch.serialization.add_safe_globals([DictConfig, ListConfig, ContainerMetadata])
        logging.info("Configured the torch.load safe-globals allowlist (omegaconf DictConfig/ListConfig).")
    except Exception as e:
        logging.warning(f"Skipping torch.load compatibility configuration: {e}")
def _env_bool(name: str, default: bool) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in {"1", "true", "yes", "y", "on"}
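# Example: with OPEN_BROWSER=0 set, _env_bool("OPEN_BROWSER", True) returns False;
# with the variable unset, it returns the supplied default.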
def get_hf_token():
    """
    Fetch the Hugging Face token.

    Prefers 'token.txt' in the current directory; falls back to the
    'HUGGING_FACE_TOKEN' environment variable.
    """
    token = None
    token_file = 'token.txt'
    if os.path.exists(token_file):
        try:
            with open(token_file, 'r', encoding='utf-8') as f:
                token = f.read().strip()
            if token:
                logging.info(f"Read Hugging Face token from {token_file}.")
                return token
        except Exception as e:
            logging.warning(f"Could not read token from {token_file}: {e}")
    token = os.environ.get("HUGGING_FACE_TOKEN")
    if token:
        logging.info("Read Hugging Face token from the environment variable.")
    else:
        logging.warning("No Hugging Face token found in token.txt or the environment; speaker diarization will be disabled.")
    return token
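# token.txt, if used, is expected to contain just the token on a single line
# (e.g. a Hub access token of the form "hf_..."; format inferred from the
# read-and-strip logic above).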
HF_TOKEN = get_hf_token()
_configure_torch_load_compat()
# Device and compute-type configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
BATCH_SIZE = 16 if DEVICE == "cuda" else 8
VAD_METHOD = os.environ.get("VAD_METHOD") or ("silero" if DEVICE == "cpu" else "pyannote")
logging.info(f"Using device: {DEVICE}, compute type: {COMPUTE_TYPE}")
logging.info(f"VAD method: {VAD_METHOD}")
# Model configuration
ALLOWED_MODELS = ['tiny', 'base', 'small', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large-v3-turbo']
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL") or ("small" if DEVICE == "cpu" else "large-v3")
ALLOW_LARGE_ON_CPU = _env_bool("ALLOW_LARGE_ON_CPU", False)
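# Example configuration via environment variables (names match the lookups above;
# values illustrative):
#   DEFAULT_MODEL=medium VAD_METHOD=silero ALLOW_LARGE_ON_CPU=1 python app.py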
# Model caches
whisper_models_cache = {}
diarize_model = None
diarize_model_loaded = False
align_models_cache = {}
def get_whisper_model(model_name: str):
    if model_name not in whisper_models_cache:
        logging.info(f"Loading Whisper model '{model_name}'...")
        try:
            try:
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE, vad_method=VAD_METHOD)
            except TypeError:
                # Fallback for whisperx versions whose load_model lacks vad_method.
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE)
            whisper_models_cache[model_name] = model
            logging.info(f"Model '{model_name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load Whisper model '{model_name}': {e}")
            if 'huggingface' in str(e):
                print("\n\n======= The model download may have failed; check your network/proxy and retry. =======\n\n")
            raise
    return whisper_models_cache[model_name]
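# Optional warm-up (a sketch, not part of the original flow): preloading the
# default model at startup trades memory for lower first-request latency.
#
#   get_whisper_model(DEFAULT_MODEL)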
def get_align_model(language_code: str):
    if language_code not in align_models_cache:
        logging.info(f"Loading alignment model (language={language_code})...")
        model_a, metadata = whisperx.load_align_model(language_code=language_code, device=DEVICE)
        align_models_cache[language_code] = (model_a, metadata)
        logging.info("Alignment model loaded successfully.")
    return align_models_cache[language_code]
def get_diarize_model():
    global diarize_model, diarize_model_loaded
    if not diarize_model_loaded:
        logging.info("Attempting to load the speaker diarization model...")
        if DiarizationPipeline is None:
            logging.warning("Diarization dependency (DiarizationPipeline) not found; this feature will be disabled.")
            diarize_model_loaded = True
            return None
        if not HF_TOKEN:
            diarize_model_loaded = True
            return None
        try:
            diarize_model = DiarizationPipeline(use_auth_token=HF_TOKEN, device=DEVICE)
            diarize_model_loaded = True
            logging.info("Speaker diarization model loaded successfully.")
        except Exception as e:
            logging.error(f"Fatal: failed to load the speaker diarization model; this feature will be disabled. Error: {e}")
            diarize_model = None
            diarize_model_loaded = True
    return diarize_model
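# Note: the pyannote diarization models that DiarizationPipeline downloads are
# gated on the Hugging Face Hub, so the token's account typically needs to have
# accepted the model terms beforehand (an assumption about the upstream setup).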
# --- Flask application ---
app = Flask(__name__, template_folder='.')
@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')
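# Minimal client sketch for the endpoint below. It assumes the default host/port,
# an existing 'sample.wav', and the third-party 'requests' package (not a
# dependency of this app):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:9092/v1/audio/transcriptions",
#       files={"file": open("sample.wav", "rb")},
#       data={"model": "small", "language": "en", "max_speakers": "2"},
#   )
#   print(resp.json())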
@app.route('/v1/audio/transcriptions', methods=['POST'])
def audio_transcriptions():
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request."}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No file selected."}), 400
    logging.debug(f"Form data: {request.form}")
    model_id = request.form.get('model', DEFAULT_MODEL)
    # 'large-v3-turbo' is served by the large-v3 weights here.
    model_name = 'large-v3' if model_id == 'large-v3-turbo' else model_id
    if model_name not in ALLOWED_MODELS:
        model_name = DEFAULT_MODEL
    if DEVICE == "cpu" and (model_name.startswith("large-") or model_name == "large") and not ALLOW_LARGE_ON_CPU:
        logging.warning(f"Large model '{model_name}' requested on CPU; falling back to 'small' (set ALLOW_LARGE_ON_CPU=1 to disable this fallback).")
        model_name = "small"
    language = request.form.get('language') or None
    prompt = request.form.get('prompt')
    max_speakers = int(request.form.get('max_speakers', -1))
    min_speakers = int(request.form.get('min_speakers', 0))
    logging.info(
        "Request received: model='%s', language='%s', prompt=%s",
        model_id, language or "auto-detect", "yes" if prompt else "no",
    )
    input_file_path = None
    processed_wav_path = None
    try:
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            file.save(tmp.name)
            input_file_path = tmp.name
        logging.info(f"Converting uploaded file '{file.filename}' to standard 16 kHz mono WAV...")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
            processed_wav_path = tmp_wav.name
        try:
            (
                ffmpeg
                .input(input_file_path)
                .output(processed_wav_path, ac=1, ar=16000, acodec='pcm_s16le', vn=None)
                .run(capture_stdout=True, capture_stderr=True, overwrite_output=True)
            )
            logging.info("File conversion succeeded.")
        except ffmpeg.Error as e:
            error_details = e.stderr.decode('utf-8', errors='ignore')
            logging.error(f"FFmpeg conversion failed: {error_details}")
            return jsonify({"error": "Failed to process the audio/video file; it may be corrupted or in an unsupported format."}), 400
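        # The conversion above corresponds to this FFmpeg invocation (illustrative):
        #   ffmpeg -y -i <input> -vn -ac 1 -ar 16000 -acodec pcm_s16le <output>.wav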
        audio = whisperx.load_audio(processed_wav_path)
        model = get_whisper_model(model_name)
        # Per-request transcription options. Note: the parameter name whisperx
        # expects for the initial prompt is 'prompt'.
        transcribe_options = {}
        if language:
            transcribe_options['language'] = language
        if prompt:
            transcribe_options['prompt'] = prompt
        logging.info("Starting transcription...")
        result = model.transcribe(audio, batch_size=BATCH_SIZE, **transcribe_options)
        logging.info("Transcription finished; starting alignment.")
        model_a, metadata = get_align_model(result["language"])
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        if max_speakers > -1:
            logging.info("Running speaker diarization...")
            diar_model = get_diarize_model()
            if diar_model:
                try:
                    diarize_segments = diar_model(
                        audio,
                        max_speakers=max_speakers if max_speakers > 0 else None,
                        min_speakers=min_speakers if min_speakers > 0 else None,
                    )
                    result = whisperx.assign_word_speakers(diarize_segments, result)
                except Exception as e:
                    logging.error(f"Speaker diarization failed at runtime: {e}. Falling back to single-speaker mode.")
        speakers = {segment.get('speaker') for segment in result["segments"] if 'speaker' in segment}
        is_single_speaker = len(speakers) <= 1
        logging.info(f"Detected speakers: {speakers}. Single-speaker mode: {'yes' if is_single_speaker else 'no'}")
        speaker_mapping = {f"SPEAKER_{i:02d}": f"Speaker{i+1}" for i in range(20)}
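        # e.g. the raw diarization label "SPEAKER_00" becomes "Speaker1" in the
        # response; labels beyond SPEAKER_19 pass through unchanged.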
        logging.debug(result)
        formatted_segments = []
        for segment in result["segments"]:
            speaker_raw = segment.get("speaker", "SPEAKER_00")
            speaker_name = speaker_mapping.get(speaker_raw, speaker_raw)
            text = segment['text'].strip()
            if not text:
                continue
            entry = {
                "start": segment['start'],
                "end": segment['end'],
                "text": text,
            }
            segment_speaker = speaker_name if not is_single_speaker else None
            if segment_speaker:
                entry['speaker'] = segment_speaker
            formatted_segments.append(entry)
        response_data = {"segments": formatted_segments}
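        # Response shape (illustrative values):
        #   {"segments": [{"start": 0.0, "end": 2.5, "text": "Hello there.",
        #                  "speaker": "Speaker1"}]}
        # The "speaker" key is omitted when only one speaker was detected.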
        return jsonify(response_data)
    except Exception as e:
        logging.error(f"Unexpected error in the processing pipeline: {e}", exc_info=True)
        return jsonify({"error": "An internal error occurred during processing."}), 500
    finally:
        if input_file_path and os.path.exists(input_file_path):
            os.remove(input_file_path)
            logging.info(f"Cleaned up temporary upload file: {input_file_path}")
        if processed_wav_path and os.path.exists(processed_wav_path):
            os.remove(processed_wav_path)
            logging.info(f"Cleaned up temporary WAV file: {processed_wav_path}")
# --- Server startup ---
def check_ffmpeg():
    if not shutil.which("ffmpeg"):
        logging.error("Error: FFmpeg was not found on the system PATH.")
        print("\nError: FFmpeg was not found on the system PATH.")
        print("Please install FFmpeg and make sure its location is on your PATH.")
        print("Windows install guide: https://www.wikihow.com/Install-FFmpeg-on-Windows")
        print("macOS (Homebrew): brew install ffmpeg")
        print("Linux (Ubuntu/Debian): sudo apt update && sudo apt install ffmpeg")
        sys.exit(1)
    logging.info("FFmpeg environment check passed.")
def open_browser(url):
    webbrowser.open_new(url)
if __name__ == '__main__':
    check_ffmpeg()
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "9092"))
    url = f"http://{host}:{port}"
    running_in_space = bool(os.environ.get("SPACE_ID")) or bool(os.environ.get("HF_SPACE")) or os.environ.get("SYSTEM") == "spaces"
    if _env_bool("OPEN_BROWSER", True) and not running_in_space:
        Timer(1, lambda: open_browser(url)).start()
    logging.info(f"Server started; listening on http://{host}:{port}")
    serve(app, host=host, port=port, threads=10)
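# Example launch (variable names match the lookups above; values illustrative):
#   HOST=0.0.0.0 PORT=7860 OPEN_BROWSER=0 python app.py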