| from fastapi import FastAPI, UploadFile, File, HTTPException |
| from fastapi.responses import JSONResponse |
| import subprocess |
| import tempfile |
| import os |
| import shutil |
| from pydantic import BaseModel |
| import sys |
| import numpy as np |
| import soundfile as sf |
| from typing import Optional, List |
| import librosa |
|
|
# The FastAPI application object; its title is the API's display name.
app = FastAPI(
    title="ViSQOL 音频质量 API",
)
|
|
# Root of the ViSQOL build artifacts. Defaults to a path relative to the current
# working directory; set the VISQOL_DIR environment variable to relocate it
# without editing the code.
VISQOL_DIR = os.environ.get("VISQOL_DIR") or "./build/visqol"
VISQOL_LIB_PATH = os.path.join(VISQOL_DIR, "visqol_lib_py.so")  # compiled Python extension
PB2_DIR = os.path.join(VISQOL_DIR, "pb2")  # directory holding the *_pb2 modules
MODEL_DIR = os.path.join(VISQOL_DIR, "model")

# Model files, named after the ViSQOL mode they belong to. ViSQOL's own Python
# example pairs the libsvm nu-SVR model with audio mode (48 kHz) and the
# deep-lattice TFLite model with speech mode (16 kHz).
LIBSVM_AUDIO_MODEL_PATH = os.path.join(MODEL_DIR, "libsvm_nu_svr_model.txt")
LATTICE_SPEECH_MODEL_PATH = os.path.join(
    MODEL_DIR,
    "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite",
)

# Legacy names, kept so existing references keep working. WARNING: they are
# swapped relative to the files they point at -- SPEECH_MODEL_PATH is the
# audio-mode libsvm model and AUDIO_MODEL_PATH is the speech-mode lattice model.
SPEECH_MODEL_PATH = LIBSVM_AUDIO_MODEL_PATH
AUDIO_MODEL_PATH = LATTICE_SPEECH_MODEL_PATH
|
|
# Fail fast at import time when any ViSQOL build artifact is missing,
# reporting every missing file in a single error.
required_files = [VISQOL_LIB_PATH, SPEECH_MODEL_PATH, AUDIO_MODEL_PATH]
if not all(map(os.path.exists, required_files)):
    missing = [path for path in required_files if not os.path.exists(path)]
    raise FileNotFoundError(f"ViSQOL 必需文件未找到: {', '.join(missing)}")
# os.path.isdir() is already False for a path that does not exist, so this one
# test covers both "missing" and "not a directory".
if not os.path.isdir(PB2_DIR):
    raise FileNotFoundError(f"ViSQOL pb2 目录未找到: {PB2_DIR}")
|
|
# Make the compiled ViSQOL extension and its generated protobuf modules
# importable by putting their directories at the front of sys.path.
try:
    sys.path.insert(0, os.path.abspath(PB2_DIR))
    sys.path.insert(0, os.path.abspath(VISQOL_DIR))

    import visqol_lib_py
    # Not referenced directly in this module; presumably imported so its message
    # types are available alongside the extension -- confirm before removing.
    import similarity_result_pb2
    import visqol_config_pb2
    print("ViSQOL 库和 pb2 文件导入成功。")
except ImportError as e:
    print("错误:无法导入 ViSQOL 库或 pb2 文件。")
    print(f"Python 搜索路径: {sys.path}")
    print(f"错误详情: {e}")
    # Failure sentinels (only on this path): the endpoints check
    # `visqol_lib_py is None` and report the library as unavailable, instead of
    # the whole app failing to import.
    visqol_lib_py = None
    similarity_result_pb2 = None
    visqol_config_pb2 = None
|
|
# Response schema for POST /evaluate/. Keep the field order stable: pydantic
# serializes fields in declaration order, so it is part of the JSON output.
class VisqolResponse(BaseModel):
    reference_filename: str  # client-supplied name of the reference upload
    degraded_filename: str  # client-supplied name of the degraded upload
    mode: str  # the requested mode ("audio" or "speech")
    moslqo: float  # predicted MOS-LQO; the handler leaves it at -1.0 when no score was produced
    vnsim: Optional[float] = None  # vnsim value copied from the ViSQOL result, if present
    fvnsim: Optional[List[float]] = None  # fvnsim values from the ViSQOL result as a list, if non-empty
    status: str  # human-readable processing status
    error_message: Optional[str] = None  # error details on failure, otherwise None
|
|
def convert_and_resample_audio(input_path, output_path, target_sr, timeout=None):
    """Convert an audio file to a mono WAV at ``target_sr`` Hz using ffmpeg.

    Args:
        input_path: Path of the source file (any format ffmpeg can decode).
        output_path: Destination path; overwritten if it exists (``-y``).
        target_sr: Output sample rate in Hz.
        timeout: Optional limit in seconds for the ffmpeg run. ``None`` (the
            default) waits indefinitely, matching the previous behaviour.

    Returns:
        True if ffmpeg exited successfully, False otherwise. Failures are
        reported on stdout rather than raised.
    """
    cmd = [
        "ffmpeg",
        "-y",
        "-i", input_path,
        "-ar", str(target_sr),
        "-ac", "1",  # downmix to a single channel
        output_path,
    ]
    print(f"Running ffmpeg: {' '.join(cmd)}")
    try:
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            encoding="utf-8",
            # ffmpeg's log can contain non-UTF-8 bytes (e.g. legacy-encoded
            # metadata tags); strict decoding would turn a successful
            # conversion into a reported failure.
            errors="replace",
            # Never let ffmpeg read interactive commands from the server's stdin.
            stdin=subprocess.DEVNULL,
            timeout=timeout,
        )
        print("ffmpeg conversion successful.")
        return True
    except FileNotFoundError:
        print("错误: ffmpeg 未找到,无法转换音频。请确保已在 Docker 环境中安装 ffmpeg。")
        return False
    except subprocess.TimeoutExpired:
        print(f"错误: ffmpeg 执行超时 ({timeout} 秒)。")
        return False
    except subprocess.CalledProcessError as e:
        print(f"错误: ffmpeg 执行失败 (返回码 {e.returncode})。")
        print(f"ffmpeg stderr: {e.stderr}")
        return False
    except Exception as e:
        print(f"转换音频时发生未知错误: {e}")
        return False
|
|
def _upload_suffix(filename):
    """Return a sanitised file extension (e.g. ``".mp3"``) from an upload's name.

    Only the extension of the last path component is kept, so client-supplied
    names containing directory separators, ``..`` segments, or no name at all
    cannot influence where the temporary copy is written. The extension is kept
    (as the previous naming scheme did) because ffmpeg may use it as a format
    hint. Returns an empty string when there is no usable extension.
    """
    if not filename:
        return ""
    suffix = os.path.splitext(os.path.basename(filename))[1]
    # Accept only short, purely alphanumeric extensions.
    if 1 < len(suffix) <= 16 and suffix[1:].isalnum():
        return suffix
    return ""


@app.post("/evaluate/", response_model=VisqolResponse)
async def evaluate_audio(
    reference: UploadFile = File(..., description="参考音频文件"),
    degraded: UploadFile = File(..., description="待评估音频文件"),
    mode: str = "audio"
):
    """Score the perceptual similarity of two audio files with ViSQOL.

    Both uploads are converted with ffmpeg to mono WAV at the sample rate the
    chosen mode uses (48 kHz for "audio", 16 kHz for "speech") and compared
    with the ViSQOL library. Returns the predicted mean opinion score (MOS-LQO)
    plus any vnsim/fvnsim values ViSQOL reports.

    Failures after request validation are reported in the response body
    (``status`` / ``error_message``, with ``moslqo`` left at -1.0) rather than
    as HTTP error codes.

    Raises:
        HTTPException: 500 if the ViSQOL library failed to load; 400 if
            ``mode`` is not "audio" or "speech".
    """
    import asyncio  # used to run the blocking ffmpeg subprocess off the event loop

    if visqol_lib_py is None:
        raise HTTPException(status_code=500, detail="ViSQOL 库未成功加载。")

    if mode not in ["audio", "speech"]:
        raise HTTPException(status_code=400, detail="模式参数 'mode' 必须是 'audio' 或 'speech'")

    temp_dir = tempfile.mkdtemp()
    # Fixed temp names; only a sanitised extension is taken from the client.
    ref_temp_orig = os.path.join(temp_dir, "ref_upload" + _upload_suffix(reference.filename))
    deg_temp_orig = os.path.join(temp_dir, "deg_upload" + _upload_suffix(degraded.filename))
    ref_path_wav = os.path.join(temp_dir, "reference.wav")
    deg_path_wav = os.path.join(temp_dir, "degraded.wav")

    # Defaults describe a failed run; they are overwritten on success.
    mos = -1.0
    vnsim_val = None
    fvnsim_val = None
    status_msg = "处理失败"
    error_msg = None

    try:
        # Persist both uploads so ffmpeg can read them from disk.
        with open(ref_temp_orig, "wb") as f:
            f.write(await reference.read())
        with open(deg_temp_orig, "wb") as f:
            f.write(await degraded.read())

        target_sr = 48000 if mode == 'audio' else 16000
        print(f"目标采样率: {target_sr} Hz for mode '{mode}'")

        # ffmpeg is a blocking subprocess: run it in the default thread pool so
        # the event loop can keep serving other requests (e.g. /healthz).
        loop = asyncio.get_running_loop()
        conv_ref_ok = await loop.run_in_executor(
            None, convert_and_resample_audio, ref_temp_orig, ref_path_wav, target_sr
        )
        conv_deg_ok = await loop.run_in_executor(
            None, convert_and_resample_audio, deg_temp_orig, deg_path_wav, target_sr
        )
        if not (conv_ref_ok and conv_deg_ok):
            raise HTTPException(status_code=500, detail="使用 ffmpeg 转换或重采样音频文件失败。")

        try:
            ref_info = sf.info(ref_path_wav)
            deg_info = sf.info(deg_path_wav)
            if ref_info.samplerate != target_sr or deg_info.samplerate != target_sr:
                print(f"警告:ffmpeg 转换后的采样率 ({ref_info.samplerate}/{deg_info.samplerate}) 与目标 ({target_sr}) 不符,可能影响 ViSQOL 结果。")
        except Exception as audio_e:
            raise HTTPException(status_code=400, detail=f"无法读取转换后的 WAV 文件: {audio_e}") from audio_e

        try:
            print(f"从 WAV 加载音频数据: {ref_path_wav}, {deg_path_wav}")
            ref_data, sr_ref = sf.read(ref_path_wav, dtype='float64')
            deg_data, sr_deg = sf.read(deg_path_wav, dtype='float64')
            if sr_ref != target_sr or sr_deg != target_sr:
                print(f"警告:读取的 WAV 文件采样率 ({sr_ref}/{sr_deg}) 与目标 ({target_sr}) 不符。")
            print("音频数据加载成功。")
        except Exception as read_e:
            raise HTTPException(status_code=500, detail=f"读取转换后的 WAV 文件时出错: {read_e}") from read_e

        # Reject empty decodes up front rather than relying on how the native
        # library handles zero-length input.
        if ref_data.size == 0 or deg_data.size == 0:
            raise HTTPException(status_code=400, detail="转换后的音频为空,无法评估。")

        config = visqol_config_pb2.VisqolConfig()
        config.audio.sample_rate = target_sr

        # NOTE: the module-level constant names are swapped relative to their
        # files -- AUDIO_MODEL_PATH holds the deep-lattice .tflite model and
        # SPEECH_MODEL_PATH holds the libsvm model. ViSQOL's own Python example
        # pairs speech mode with the lattice model and audio mode with libsvm,
        # so the "crossed" use below is intentional.
        if mode == "speech":
            config.options.use_speech_scoring = True
            model_file_to_use = AUDIO_MODEL_PATH
        else:
            config.options.use_speech_scoring = False
            model_file_to_use = SPEECH_MODEL_PATH

        config.options.svr_model_path = os.path.abspath(model_file_to_use)
        print(f"使用模型: {model_file_to_use} for mode '{mode}'")

        api = visqol_lib_py.VisqolApi()
        api.Create(config)
        similarity_result_msg = api.Measure(ref_data, deg_data)

        if similarity_result_msg and hasattr(similarity_result_msg, 'moslqo'):
            mos = similarity_result_msg.moslqo
            status_msg = "处理成功"
            print(f"ViSQOL 评估完成: MOS-LQO = {mos}")

            if hasattr(similarity_result_msg, 'vnsim'):
                vnsim_val = similarity_result_msg.vnsim
                print(f"VNSIM = {vnsim_val}")
            else:
                print("ViSQOL 结果中未找到 vnsim 字段。")

            if hasattr(similarity_result_msg, 'fvnsim') and similarity_result_msg.fvnsim:
                fvnsim_val = list(similarity_result_msg.fvnsim)
                print(f"FVNSIM (第一个元素): {fvnsim_val[0] if fvnsim_val else 'N/A'}")
            else:
                print("ViSQOL 结果中未找到 fvnsim 字段或为空。")
        else:
            error_msg = "ViSQOL 未返回有效的 MOS-LQO 结果。"
            print(f"错误: {error_msg}")

    except ImportError as e:
        status_msg = "导入错误"
        error_msg = f"无法导入 ViSQOL 库或依赖: {e}"
        print(f"错误: {error_msg}")
    except FileNotFoundError as e:
        status_msg = "文件未找到错误"
        error_msg = f"必需文件丢失: {e}"
        print(f"错误: {error_msg}")
    except HTTPException as e:
        # Internal HTTPExceptions are deliberately folded into the response body.
        status_msg = "请求错误"
        error_msg = str(e.detail)
        print(f"错误: {error_msg}")
    except Exception as e:
        status_msg = "运行时错误"
        error_msg = f"处理过程中发生错误: {type(e).__name__} - {e}"
        print(f"错误: {error_msg}")
    finally:
        # Best-effort cleanup: a failure here must not replace the result.
        shutil.rmtree(temp_dir, ignore_errors=True)
        await reference.close()
        await degraded.close()

    return VisqolResponse(
        # UploadFile.filename may be None; the schema requires a string.
        reference_filename=reference.filename or "",
        degraded_filename=degraded.filename or "",
        mode=mode,
        moslqo=mos,
        vnsim=vnsim_val,
        fvnsim=fvnsim_val,
        status=status_msg,
        error_message=error_msg
    )
|
|
@app.get("/", include_in_schema=False)
async def root():
    """Return a short welcome message that points clients at POST /evaluate/."""
    welcome_text = "欢迎使用 ViSQOL 音频质量评估 API。请使用 POST 方法访问 /evaluate/ 端点。"
    return dict(message=welcome_text)
|
|
@app.get("/healthz", status_code=200)
async def health_check():
    """Hugging Face Spaces health check endpoint.

    Returns 200 with ``{"status": "ok"}`` when the ViSQOL library loaded.
    Otherwise it returns 503 with an error payload, so a probe can tell that
    the service cannot score audio. Previously that case was still reported
    as 200.
    """
    if visqol_lib_py is None:
        return JSONResponse(
            status_code=503,
            content={"status": "error", "detail": "ViSQOL library not loaded"},
        )
    return {"status": "ok"}
|
|
# Local development entry point: serve the app with uvicorn on the loopback interface.
if __name__ == "__main__":
    import uvicorn

    host, port = "127.0.0.1", 8000
    print(f"运行本地测试服务器: http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)