# "Spaces: Running" — Hugging Face Spaces UI status banner captured along with
# this paste; it is page chrome, not part of the program source.
import os
import tempfile
import torch
import whisperx
from flask import Flask, request, jsonify, render_template
from waitress import serve
import logging
import webbrowser
from threading import Timer
import shutil
import sys
import ffmpeg
# Speaker diarization is optional: some whisperx installs do not expose
# DiarizationPipeline, in which case the feature is disabled gracefully
# (see get_diarize_model below).
try:
    from whisperx.diarize import DiarizationPipeline
except Exception:
    DiarizationPipeline = None
# --- Global configuration & initialization ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
| def _configure_torch_load_compat(): | |
| """ | |
| PyTorch >=2.6 changed torch.load default `weights_only=True`, which can break | |
| some third-party checkpoints (e.g. pyannote VAD models used by whisperx). | |
| We only allowlist OmegaConf config objects required by those checkpoints. | |
| """ | |
| try: | |
| import torch.serialization | |
| from omegaconf import DictConfig, ListConfig | |
| from omegaconf.base import ContainerMetadata | |
| torch.serialization.add_safe_globals([DictConfig, ListConfig, ContainerMetadata]) | |
| logging.info("已配置 torch.load 安全全局白名单 (omegaconf DictConfig/ListConfig)。") | |
| except Exception as e: | |
| logging.warning(f"torch.load 兼容性配置跳过: {e}") | |
| def _env_bool(name: str, default: bool) -> bool: | |
| val = os.environ.get(name) | |
| if val is None: | |
| return default | |
| return val.strip().lower() in {"1", "true", "yes", "y", "on"} | |
def get_hf_token():
    """Resolve the Hugging Face access token.

    A non-empty 'token.txt' in the current working directory takes priority;
    otherwise the HUGGING_FACE_TOKEN environment variable is consulted.
    Returns None (and logs a warning) when neither source yields a token,
    which disables speaker diarization downstream.
    """
    token_file = 'token.txt'
    if os.path.exists(token_file):
        try:
            with open(token_file, 'r', encoding='utf-8') as f:
                file_token = f.read().strip()
            # An empty file deliberately falls through to the env variable.
            if file_token:
                logging.info(f"成功从 {token_file} 文件中读取 Hugging Face 令牌。")
                return file_token
        except Exception as e:
            logging.warning(f"无法从 {token_file} 读取令牌: {e}")
    env_token = os.environ.get("HUGGING_FACE_TOKEN")
    if env_token:
        logging.info("成功从环境变量中读取 Hugging Face 令牌。")
        return env_token
    logging.warning("在 token.txt 或环境变量中均未找到 Hugging Face 令牌。说话人分离功能将被禁用。")
    return env_token
HF_TOKEN = get_hf_token()
_configure_torch_load_compat()
# Device and compute-type configuration: float16 on CUDA, int8 quantized on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
BATCH_SIZE = 16 if DEVICE == "cuda" else 8
# VAD backend: overridable via the VAD_METHOD env var; defaults to silero on
# CPU and pyannote on GPU.
VAD_METHOD = os.environ.get("VAD_METHOD") or ("silero" if DEVICE == "cpu" else "pyannote")
logging.info(f"使用设备: {DEVICE},计算类型: {COMPUTE_TYPE}")
logging.info(f"VAD 方法: {VAD_METHOD}")
# Model configuration
ALLOWED_MODELS = ['tiny', 'base', 'small', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large-v3-turbo']
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL") or ("small" if DEVICE == "cpu" else "large-v3")
ALLOW_LARGE_ON_CPU = _env_bool("ALLOW_LARGE_ON_CPU", False)
# Model caches, populated lazily by the get_* helpers below.
whisper_models_cache = {}
diarize_model = None
diarize_model_loaded = False
align_models_cache = {}
def get_whisper_model(model_name: str):
    """Return a cached WhisperX model, loading it on first use.

    Tries to pass the configured VAD method; whisperx versions without a
    ``vad_method`` parameter raise TypeError and fall back to the plain
    signature. On load failure the original exception is re-raised after
    logging (and, for Hugging Face download errors, printing a hint).
    """
    if model_name not in whisper_models_cache:
        logging.info(f"正在加载 Whisper 模型 '{model_name}'...")
        try:
            try:
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE, vad_method=VAD_METHOD)
            except TypeError:
                # Older whisperx releases do not accept vad_method.
                model = whisperx.load_model(model_name, DEVICE, compute_type=COMPUTE_TYPE)
            whisper_models_cache[model_name] = model
            logging.info(f"模型 '{model_name}' 加载成功。")
        except Exception as e:
            logging.error(f"加载 Whisper 模型 '{model_name}' 失败: {e}")
            # BUG FIX: str.find() returns -1 (truthy) when the substring is
            # absent and 0 (falsy) when it starts the string, so the old
            # ``if str(e).find('huggingface'):`` printed the hint for almost
            # every error and skipped it when it actually applied. Use a
            # membership test instead.
            if 'huggingface' in str(e):
                print(f"\n\n=======可能模型下载失败,请尝试科学上网后再次重试=======\n\n")
            raise
    return whisper_models_cache[model_name]
def get_align_model(language_code: str):
    """Return the cached ``(align_model, metadata)`` pair for *language_code*,
    loading it via whisperx on the first request."""
    if language_code in align_models_cache:
        return align_models_cache[language_code]
    logging.info(f"正在加载对齐模型 (language={language_code})...")
    model, meta = whisperx.load_align_model(language_code=language_code, device=DEVICE)
    align_models_cache[language_code] = (model, meta)
    logging.info("对齐模型加载成功。")
    return align_models_cache[language_code]
def get_diarize_model():
    """Lazily build the speaker-diarization pipeline, once.

    Returns the pipeline, or None when the DiarizationPipeline dependency
    is missing, no HF token is configured, or construction failed. The
    attempt is made at most one time per process.
    """
    global diarize_model, diarize_model_loaded
    if diarize_model_loaded:
        return diarize_model
    logging.info("正在尝试加载说话人分离模型...")
    diarize_model_loaded = True  # exactly one attempt, whatever the outcome
    if DiarizationPipeline is None:
        logging.warning("未检测到说话人分离依赖 (DiarizationPipeline),此功能将被禁用。")
        return None
    if not HF_TOKEN:
        return None
    try:
        diarize_model = DiarizationPipeline(use_auth_token=HF_TOKEN, device=DEVICE)
        logging.info("说话人分离模型加载成功。")
    except Exception as e:
        logging.error(f"严重错误: 说话人分离模型加载失败。此功能将被禁用。错误信息: {e}")
        diarize_model = None
    return diarize_model
# --- Flask application ---
# template_folder='.' lets index.html live next to this script.
app = Flask(__name__, template_folder='.')
# NOTE(review): no @app.route decorator is visible on this view — presumably
# it is registered elsewhere (or the decorator was lost in this copy);
# verify that '/' is actually wired to index().
def index():
    """Serve the single-page UI from ./index.html."""
    return render_template('index.html')
# NOTE(review): as with index(), no @app.route decorator is visible here —
# confirm how this view is registered (the name suggests an OpenAI-style
# audio-transcriptions endpoint, but that is not shown in this file).
def audio_transcriptions():
    """Transcribe an uploaded audio/video file with WhisperX.

    Expects a multipart form with a 'file' part plus optional fields:
    'model', 'language', 'prompt', 'min_speakers' and 'max_speakers'
    (max_speakers > -1 enables speaker diarization). The upload is
    normalized to 16 kHz mono WAV via ffmpeg, transcribed, word-aligned,
    optionally diarized, and returned as JSON:
    {"segments": [{"start", "end", "text"[, "speaker"]}, ...]}.
    Temporary files are always cleaned up in the finally block.
    """
    if 'file' not in request.files:
        return jsonify({"error": "请求中未包含文件部分"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "未选择任何文件"}), 400
    print(request.form)  # NOTE(review): debug print left in — consider logging.debug
    model_id = request.form.get('model', DEFAULT_MODEL)
    # 'large-v3-turbo' requests are served by the 'large-v3' weights here.
    model_name = 'large-v3' if model_id == 'large-v3-turbo' else model_id
    if model_name not in ALLOWED_MODELS:
        model_name = DEFAULT_MODEL
    # On CPU, downgrade large models to 'small' unless explicitly allowed.
    if DEVICE == "cpu" and (model_name.startswith("large-") or model_name == "large") and not ALLOW_LARGE_ON_CPU:
        logging.warning(f"CPU 环境下请求大模型 '{model_name}',将自动降级为 'small' (可通过 ALLOW_LARGE_ON_CPU=1 关闭降级)。")
        model_name = "small"
    language = request.form.get('language') or None
    prompt = request.form.get('prompt')
    # NOTE(review): non-numeric form values raise ValueError here, outside the
    # try below, so they surface as a generic 500 rather than a 400.
    max_speakers=int(request.form.get('max_speakers',-1))
    min_speakers=int(request.form.get('min_speakers',0))
    logging.info(f"收到请求: 模型='{model_id}', 语言='{language or '自动检测'}', 提示词='{'有' if prompt else '无'}'")
    input_file_path = None
    processed_wav_path = None
    try:
        # Persist the upload to a temp file keeping its original extension
        # so ffmpeg can sniff the container format.
        suffix = os.path.splitext(file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            file.save(tmp.name)
            input_file_path = tmp.name
        logging.info(f"正在将上传的文件 '{file.filename}' 转换为标准的 16kHz 单声道 WAV 格式...")
        # Only the path is needed; the handle is closed before ffmpeg writes.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
            processed_wav_path = tmp_wav.name
        try:
            # ac=1/ar=16000/pcm_s16le: mono 16 kHz 16-bit PCM; vn drops video.
            (
                ffmpeg
                .input(input_file_path)
                .output(processed_wav_path, ac=1, ar=16000, acodec='pcm_s16le', vn=None)
                .run(capture_stdout=True, capture_stderr=True, overwrite_output=True)
            )
            logging.info("文件格式转换成功。")
        except ffmpeg.Error as e:
            error_details = e.stderr.decode('utf-8', errors='ignore')
            logging.error(f"FFmpeg 文件转换失败: {error_details}")
            return jsonify({"error": f"音频/视频文件处理失败,可能是文件已损坏或格式不受支持。"}), 400
        audio = whisperx.load_audio(processed_wav_path)
        model = get_whisper_model(model_name)
        # Build per-request transcription options; an earlier fix switched
        # the keyword to 'prompt'.
        transcribe_options = {}
        if language:
            transcribe_options['language'] = language
        if prompt:
            # NOTE(review): whisperx versions differ on the prompt keyword
            # ('prompt' vs 'initial_prompt') — confirm against the pinned version.
            transcribe_options['prompt'] = prompt
        print('开始转录')
        result = model.transcribe(audio, batch_size=BATCH_SIZE, **transcribe_options)
        print('转录结束,准备对齐')
        # Word-level alignment for the language whisper detected/was given.
        model_a, metadata = get_align_model(result["language"])
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        # Diarization is opt-in: only runs when the client sent max_speakers > -1.
        if max_speakers>-1:
            print('进入说话人识别')
            diar_model = get_diarize_model()
            if diar_model:
                try:
                    # 0 / -1 mean "no bound" and are translated to None.
                    diarize_segments = diar_model(audio,max_speakers=max_speakers if max_speakers>0 else None,min_speakers=min_speakers if min_speakers>0 else None)
                    result = whisperx.assign_word_speakers(diarize_segments, result)
                except Exception as e:
                    logging.error(f"说话人分离运行时失败: {e}。将回退到单说话人模式。")
        # With zero or one distinct speakers, omit speaker labels entirely.
        speakers = {segment.get('speaker') for segment in result["segments"] if 'speaker' in segment}
        is_single_speaker = len(speakers) <= 1
        logging.info(f"检测到的说话人: {speakers}。单说话人模式: {'是' if is_single_speaker else '否'}")
        # Map pyannote-style labels (SPEAKER_00..SPEAKER_19) to Speaker1..Speaker20;
        # unknown labels pass through unchanged.
        speaker_mapping = {f"SPEAKER_{i:02d}": f"Speaker{i+1}" for i in range(20)}
        print(result)  # NOTE(review): debug print of the full result — noisy for long audio
        formatted_segments = []
        for segment in result["segments"]:
            speaker_raw = segment.get("speaker", "SPEAKER_00")
            speaker_name = speaker_mapping.get(speaker_raw, speaker_raw)
            text = segment['text'].strip()
            # Skip whitespace-only segments.
            if not text:
                continue
            tmp={
                "start": segment['start'],
                "end": segment['end'],
                "text": text
            }
            segment_speaker = speaker_name if not is_single_speaker else None
            if segment_speaker:
                tmp['speaker']=segment_speaker
            formatted_segments.append(tmp)
        response_data = {"segments": formatted_segments}
        return jsonify(response_data)
    except Exception as e:
        logging.error(f"处理流程中发生未知错误: {e}", exc_info=True)
        return jsonify({"error": "处理过程中发生内部错误。"}), 500
    finally:
        # Always remove both temp files, whatever path the request took.
        if input_file_path and os.path.exists(input_file_path):
            os.remove(input_file_path)
            logging.info(f"已清理临时上传文件: {input_file_path}")
        if processed_wav_path and os.path.exists(processed_wav_path):
            os.remove(processed_wav_path)
            logging.info(f"已清理临时WAV文件: {processed_wav_path}")
# --- Service startup ---
def check_ffmpeg():
    """Exit the process (code 1) unless an ffmpeg binary is on PATH."""
    if shutil.which("ffmpeg"):
        logging.info("FFmpeg 环境检查通过。")
        return
    logging.error("错误: 系统 PATH 中未找到 FFmpeg。")
    print("\n错误: 系统 PATH 中未找到 FFmpeg。")
    print("请确保您已安装 FFmpeg 并且其路径已添加到系统环境变量中。")
    print("Windows 安装指南: https://www.wikihow.com/Install-FFmpeg-on-Windows")
    print("macOS (使用 Homebrew): brew install ffmpeg")
    print("Linux (Ubuntu/Debian): sudo apt update && sudo apt install ffmpeg")
    sys.exit(1)
def open_browser(url):
    """Open *url* in a new browser window (scheduled via Timer at startup)."""
    webbrowser.open_new(url)
if __name__ == '__main__':
    # Fail fast if ffmpeg is missing — every request depends on it.
    check_ffmpeg()
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "9092"))
    url = f"http://{host}:{port}"
    # Heuristic Hugging Face Spaces detection: any of these env markers
    # suppresses the local browser launch below.
    running_in_space = bool(os.environ.get("SPACE_ID")) or bool(os.environ.get("HF_SPACE")) or bool(os.environ.get("SYSTEM") == "spaces")
    if _env_bool("OPEN_BROWSER", True) and not running_in_space:
        # One-second delay so the server is likely listening before the browser connects.
        Timer(1, lambda: open_browser(url)).start()
    logging.info(f"服务已启动,正在监听 http://{host}:{port}")
    serve(app, host=host, port=port, threads=10)