| | import numpy as np |
| | from fastapi import FastAPI, HTTPException, Body |
| | from fastapi.responses import JSONResponse |
| | from typing import List, Optional |
| | import logging |
| | from SenseVoiceAx import SenseVoiceAx |
| | import os |
| | import librosa |
| |
|
| | |
# Configure root logging once at import time; use a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")

# Global model handle; None until the startup hook loads the model.
asr_model = None
| |
|
| |
|
@app.on_event("startup")
async def load_model():
    """
    Load the ASR model when the service starts.

    Reads the model artifacts from a fixed directory relative to the server
    and constructs the global SenseVoiceAx instance. Re-raises any failure
    so the server does not start without a working model.

    Raises:
        FileNotFoundError: If the .axmodel file does not exist.
        Exception: Propagated from SenseVoiceAx construction.
    """
    global asr_model
    logger.info("Loading ASR model...")

    try:
        language = "auto"
        model_root = "../sensevoice_ax650"
        max_seq_len = 256
        model_path = os.path.join(model_root, "sensevoice.axmodel")

        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"model {model_path} not exist")

        cmvn_file = os.path.join(model_root, "am.mvn")
        bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        token_file = os.path.join(model_root, "tokens.txt")

        asr_model = SenseVoiceAx(
            model_path,
            cmvn_file,
            token_file,
            bpe_model,
            max_seq_len=max_seq_len,
            beam_size=3,
            hot_words=None,
            streaming=False,
        )

        # Use the module logger instead of print so output honors the
        # logging configuration (level, handlers, format).
        logger.info("language: %s", language)
        logger.info("model_path: %s", model_path)

        logger.info("ASR model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load ASR model: {str(e)}")
        raise
| |
|
| |
|
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """
    Validate audio data and convert it to a numpy array.

    Args:
        audio_data: Audio samples as a list of floats.

    Returns:
        A 1-D float32 numpy array containing the samples.

    Raises:
        ValueError: If the data cannot be converted to float32, is not
            1-dimensional, or is empty.
    """
    try:
        # Keep the try narrow: only the conversion can fail unexpectedly.
        np_array = np.asarray(audio_data, dtype=np.float32)
    except Exception as e:
        # Chain the cause so the original conversion error is preserved.
        raise ValueError(f"Invalid audio data: {str(e)}") from e

    if np_array.ndim != 1:
        raise ValueError("Invalid audio data: Audio data must be 1-dimensional")

    if np_array.size == 0:
        raise ValueError("Invalid audio data: Audio data cannot be empty")

    return np_array
| |
|
| |
|
@app.get("/get_language", summary="Get current language")
async def get_language():
    """Return the language currently configured on the ASR model."""
    # Guard against requests before startup finishes; otherwise the
    # attribute access below raises AttributeError (a 500). This matches
    # the 503 behavior of the /asr endpoint.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language": asr_model.language})
| |
|
| |
|
@app.get(
    "/get_language_options",
    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
)
async def get_language_options():
    """Return the list of language options supported by the ASR model."""
    # Same not-loaded guard as /asr: return 503 instead of crashing with
    # AttributeError when the startup hook has not completed yet.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language_options": asr_model.language_options})
| |
|
| |
|
@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(
        ..., embed=True, description="Audio data as list of floats"
    ),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """
    Accept audio data as a list of floats and return the recognition result.

    Args:
        audio_data: Audio samples as a list of floats.
        sample_rate: Audio sample rate in Hz (default 16000).
        language: Language hint forwarded to the model (default "auto").

    Returns:
        JSON of the form {"text": <recognized text>}.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 on invalid
            audio data, 500 on any other recognition failure.
    """
    # Bug fix: this guard used to sit inside the try below, so the 503
    # HTTPException was caught by `except Exception` and rewritten as a 500.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    logger.info(f"Received audio data with length: {len(audio_data)}")

    try:
        np_audio = validate_audio_data(audio_data)

        result = asr_model.infer_waveform((np_audio, sample_rate), language)

        return JSONResponse(content={"text": result})

    except HTTPException:
        # Never downgrade a deliberate HTTP error into a 500.
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Recognition error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
if __name__ == "__main__":
    import uvicorn

    # Serve the API on all interfaces, port 8000, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)
| |
|