import logging
import os
import re
import tempfile
import threading
from typing import List, Optional

import numpy as np
import soundfile as sf
from dotenv import load_dotenv
from fastapi import Body, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from funasr import AutoModel

load_dotenv()

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")

# Global handle to the loaded ASR model; populated by the startup hook.
asr_model = None

# Inference runs in a worker thread so the event loop stays responsive.
# Nothing here shows that AutoModel.generate is safe to call concurrently,
# so this lock keeps model access one-at-a-time, as it was when inference
# ran directly on the event loop.
_inference_lock = threading.Lock()

# Special markup tokens of the form "<|...|>" that appear in model output.
_SPECIAL_TOKEN_RE = re.compile(r'<\|[^|]*\|>')
_WHITESPACE_RE = re.compile(r'\s+')

DEFAULT_SAMPLE_RATE = 16000


# NOTE(review): on_event is deprecated in recent FastAPI in favor of a
# lifespan handler; kept here to avoid changing startup semantics.
@app.on_event("startup")
async def load_model():
    """Load the ASR model once when the service starts.

    The model path is read from ``ASR_MODEL_PATH`` and the device from
    ``ASR_DEVICE`` (default ``"cuda"``). Failures are logged with a
    traceback and re-raised so the app does not start without a model.
    """
    global asr_model
    logger.info("Loading ASR model...")
    try:
        model_path = os.getenv("ASR_MODEL_PATH", "")
        device = os.getenv("ASR_DEVICE", "cuda")
        asr_model = AutoModel(
            model=model_path,
            device=device,
            disable_update=True,
        )
        logger.info("ASR model loaded successfully")
    except Exception as e:
        logger.exception("Failed to load ASR model: %s", e)
        raise


def clean_text(text: str) -> str:
    """Remove ``<|...|>`` special tokens and collapse runs of whitespace."""
    text = _SPECIAL_TOKEN_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """Validate audio samples and convert them to a float32 numpy array.

    Args:
        audio_data: Audio samples as a list of floats.

    Returns:
        A 1-D ``np.float32`` array of the samples.

    Raises:
        ValueError: If the data cannot be converted, is not 1-dimensional,
            is empty, or contains NaN/infinite values (including values that
            overflow float32). Every message is prefixed with
            "Invalid audio data: ".
    """
    try:
        np_array = np.asarray(audio_data, dtype=np.float32)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Invalid audio data: {str(e)}") from e

    if np_array.ndim != 1:
        raise ValueError("Invalid audio data: Audio data must be 1-dimensional")
    if np_array.size == 0:
        raise ValueError("Invalid audio data: Audio data cannot be empty")
    # JSON parsing via Python's json module accepts NaN/Infinity literals,
    # and they would otherwise be written into the WAV file unchecked.
    if not np.all(np.isfinite(np_array)):
        raise ValueError("Invalid audio data: Audio data contains NaN or infinite values")
    return np_array


def _resolve_sample_rate(sample_rate: Optional[int]) -> int:
    """Return a usable sample rate: ``None`` means the default, non-positive is rejected."""
    if sample_rate is None:
        return DEFAULT_SAMPLE_RATE
    if sample_rate <= 0:
        raise ValueError(f"Invalid sample rate: {sample_rate} (must be a positive integer)")
    return sample_rate


def _run_inference(model, np_audio: np.ndarray, sample_rate: int, language: str):
    """Write audio to a temporary WAV file and run the model on it (blocking).

    Intended to run in a worker thread. The temporary file is removed even
    when writing it or inference fails. Returns ``model.generate``'s result.
    """
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        sf.write(tmp_path, np_audio, sample_rate)
        with _inference_lock:
            # SenseVoice-style generate call.
            return model.generate(
                input=tmp_path,
                language=language,
                use_itn=True,
                batch_size_s=60,
            )
    finally:
        try:
            os.remove(tmp_path)
        except OSError as e:
            # Best-effort cleanup: a leftover temp file must not fail the request.
            logger.warning("Failed to remove temporary file %s: %s", tmp_path, e)


@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(..., embed=True, description="Audio data as list of floats"),
    sample_rate: Optional[int] = Body(DEFAULT_SAMPLE_RATE, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """Transcribe raw audio samples and return the recognized text.

    Args:
        audio_data: Audio samples as a list of floats.
        sample_rate: Sample rate in Hz (default 16000; ``null`` also means 16000).
        language: Language code, e.g. auto, zh, en, yue, ja, ko (``null`` means "auto").

    Returns:
        JSON ``{"text": <recognized text>}``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for invalid input
            (any ``ValueError``) or when the model yields no result; 500 for
            any other failure.
    """
    try:
        if asr_model is None:
            raise HTTPException(status_code=503, detail="ASR model not loaded")

        logger.info("Received audio data with length: %d", len(audio_data))

        np_audio = validate_audio_data(audio_data)
        rate = _resolve_sample_rate(sample_rate)

        # Disk write and model inference are blocking; keep them off the event loop.
        res = await run_in_threadpool(
            _run_inference, asr_model, np_audio, rate, language or "auto"
        )

        if not res:
            raise HTTPException(status_code=400, detail="No result generated")

        text = clean_text(res[0].get("text", ""))
        return JSONResponse(content={"text": text})

    except HTTPException:
        # Deliberate HTTP errors (503/400 above) must pass through unchanged
        # instead of being re-wrapped as 500 by the generic handler below.
        raise
    except ValueError as e:
        logger.error("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Recognition error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Report liveness and whether the ASR model has been loaded."""
    return {"status": "ok", "model_loaded": asr_model is not None}


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("ASR_API_PORT", 8007))
    uvicorn.run(app, host="0.0.0.0", port=port)