| import os |
| import tempfile |
| import logging |
| import re |
| import soundfile as sf |
| from typing import Optional, List |
| from fastapi import FastAPI, HTTPException, Body |
| from fastapi.responses import JSONResponse |
| import numpy as np |
|
|
| from funasr import AutoModel |
| from dotenv import load_dotenv |
|
|
# Read variables from a local .env file into the process environment before
# anything below consults os.getenv.
load_dotenv()

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ASR Server",
    description="Automatic Speech Recognition API",
)

# Module-wide model handle; stays None until the startup hook assigns it.
asr_model = None
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the ASR model when the service starts.

    The model location is read from the ``ASR_MODEL_PATH`` environment
    variable and the inference device from ``ASR_DEVICE`` (defaults to
    ``"cuda"``, the value that was previously hard-coded). The resulting
    ``AutoModel`` instance is stored in the module-level ``asr_model``.

    Raises:
        Exception: any error raised while constructing the model is logged
            with its traceback and re-raised unchanged.
    """
    global asr_model
    logger.info("Loading ASR model...")

    model_path = os.getenv("ASR_MODEL_PATH", "")
    device = os.getenv("ASR_DEVICE", "cuda")
    if not model_path:
        # Surface the misconfiguration explicitly instead of leaving only a
        # downstream model-loading error to explain it.
        logger.warning("ASR_MODEL_PATH is not set; passing an empty model path")

    try:
        asr_model = AutoModel(
            model=model_path,
            device=device,
            disable_update=True,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Failed to load ASR model: %s", e)
        raise

    logger.info("ASR model loaded successfully")
|
|
|
|
def clean_text(text: str) -> str:
    """Remove ``<|...|>`` marker tokens and normalise whitespace.

    Every ``<|...|>`` span whose interior contains no ``|`` is deleted, runs
    of whitespace are collapsed to a single space, and leading/trailing
    whitespace is trimmed.
    """
    without_markers = re.sub(r'<\|[^|]*\|>', '', text)
    single_spaced = re.sub(r'\s+', ' ', without_markers)
    return single_spaced.strip()
|
|
|
|
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """Validate raw audio samples and convert them to a NumPy array.

    Args:
        audio_data: audio samples given as a list of floats.

    Returns:
        A 1-dimensional, non-empty ``float32`` array holding the samples.

    Raises:
        ValueError: if the data cannot be converted to ``float32``, is not
            1-dimensional, is empty, or contains NaN/Inf values. Every
            message is prefixed with ``"Invalid audio data: "``.
    """
    try:
        samples = np.array(audio_data, dtype=np.float32)
    except (TypeError, ValueError, OverflowError) as exc:
        # Non-numeric, ragged or out-of-range input ends up here.
        raise ValueError(f"Invalid audio data: {exc}") from exc

    if samples.ndim != 1:
        raise ValueError("Invalid audio data: Audio data must be 1-dimensional")

    if samples.size == 0:
        raise ValueError("Invalid audio data: Audio data cannot be empty")

    # NaN/Inf samples (including values that overflowed float32) are not
    # meaningful audio, so reject them here rather than pass them on.
    if not np.isfinite(samples).all():
        raise ValueError("Invalid audio data: Audio data must contain only finite values")

    return samples
|
|
|
|
@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(..., embed=True, description="Audio data as list of floats"),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """Transcribe audio samples sent as a list of floats.

    Args:
        audio_data: audio samples as a list of floats.
        sample_rate: sample rate in Hz; ``None`` falls back to 16000.
        language: language hint (auto, zh, en, yue, ja, ko); empty or
            ``None`` falls back to ``"auto"``.

    Returns:
        A JSON response of the form ``{"text": <recognized text>}``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for invalid input
            (or any other ``ValueError``) and for an empty model result;
            500 for any other failure.
    """
    # Checked outside the try block so the 503 is not rewritten into a 500
    # by the generic handler below.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    logger.info(f"Received audio data with length: {len(audio_data)}")

    tmp_path = None
    try:
        np_audio = validate_audio_data(audio_data)

        rate = 16000 if sample_rate is None else sample_rate
        if rate <= 0:
            raise ValueError("Sample rate must be a positive integer")

        # Reserve a unique path, close the handle, then let soundfile write
        # to it. The path is recorded before writing so the finally clause
        # removes the file even when the write itself fails.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
            tmp_path = tmp_file.name
        sf.write(tmp_path, np_audio, rate)

        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is blocked for its duration. Left as-is because it
        # is unclear whether the model tolerates concurrent calls.
        res = asr_model.generate(
            input=tmp_path,
            language=language or "auto",
            use_itn=True,
            batch_size_s=60,
        )

        if not res:
            raise HTTPException(status_code=400, detail="No result generated")

        text = clean_text(res[0].get("text", ""))
        return JSONResponse(content={"text": text})

    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Recognition error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError as e:
                # Best-effort cleanup: report the failure, never mask the
                # request outcome.
                logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service liveness and whether the ASR model has been loaded."""
    model_ready = asr_model is not None
    return {"status": "ok", "model_loaded": model_ready}
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    # The port comes from ASR_API_PORT and defaults to 8007.
    import uvicorn

    port = int(os.environ.get("ASR_API_PORT", 8007))
    uvicorn.run(app, port=port, host="0.0.0.0")
|
|