File size: 4,064 Bytes
07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 b3b007c 02707b5 07f9af1 4fd2d31 07f9af1 d1ae526 b3b007c d1ae526 b3b007c d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 07f9af1 d1ae526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import numpy as np
from fastapi import FastAPI, HTTPException, Body
from fastapi.responses import JSONResponse
from typing import List, Optional
import logging
from SenseVoiceAx import SenseVoiceAx
import os
import librosa
# Logging: route INFO-and-above records through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ASR Server",
    description="Automatic Speech Recognition API",
)

# Shared ASR model instance. It stays None until the startup handler
# (load_model) assigns it.
asr_model: Optional[SenseVoiceAx] = None
@app.on_event("startup")
async def load_model():
    """Load the SenseVoice ASR model once when the server starts.

    The model is stored in the module-level ``asr_model`` global so that the
    request handlers can reuse a single instance.

    Raises:
        FileNotFoundError: If the model file does not exist.
        Exception: Any error raised while constructing ``SenseVoiceAx`` is
            logged (with traceback) and re-raised, aborting startup.
    """
    # NOTE: ``on_event`` is deprecated in recent FastAPI releases in favour of
    # lifespan handlers; kept here to avoid changing how ``app`` is built.
    global asr_model
    logger.info("Loading ASR model...")
    try:
        language = "auto"
        use_itn = True  # inverse text normalization
        max_len = 256
        # Relative path: resolved against the process's current working directory.
        model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"model {model_path} not exist")
        logger.info("language: %s", language)
        logger.info("use_itn: %s", use_itn)
        logger.info("model_path: %s", model_path)
        asr_model = SenseVoiceAx(
            model_path,
            max_len=max_len,
            beam_size=3,
            language=language,  # previously hard-coded, ignoring the variable above
            hot_words=None,
            use_itn=use_itn,
            streaming=False,
        )
        logger.info("ASR model loaded successfully")
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Failed to load ASR model: %s", e)
        raise
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """Validate audio samples and convert them to a 1-D float32 NumPy array.

    Args:
        audio_data: Audio samples given as a list of floats.

    Returns:
        A new one-dimensional ``np.float32`` array holding the samples.

    Raises:
        ValueError: If the data cannot be converted to float32, is not
            one-dimensional, is empty, or contains NaN/infinite values.
            Every message starts with ``"Invalid audio data: "``.
    """
    try:
        np_array = np.array(audio_data, dtype=np.float32)
    except (TypeError, ValueError, OverflowError) as e:
        # Chain the cause so the original conversion error stays visible.
        raise ValueError(f"Invalid audio data: {str(e)}") from e

    if np_array.ndim != 1:
        raise ValueError("Invalid audio data: Audio data must be 1-dimensional")
    if np_array.size == 0:
        raise ValueError("Invalid audio data: Audio data cannot be empty")
    # Request parsing may accept NaN/Infinity literals, and values beyond the
    # float32 range become inf on conversion; neither is usable audio.
    if not np.all(np.isfinite(np_array)):
        raise ValueError(
            "Invalid audio data: Audio data contains NaN or infinite values"
        )
    return np_array
@app.get("/get_language", summary="Get current language")
async def get_language():
    """Return the loaded ASR model's ``language`` attribute.

    Returns:
        JSONResponse: ``{"language": asr_model.language}``.

    Raises:
        HTTPException: 503 if the model has not been loaded.
    """
    # Same guard as /asr: without it a missing model surfaces as an opaque 500
    # (AttributeError on None) instead of "service unavailable".
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language": asr_model.language})
@app.get(
    "/get_language_options",
    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
)
async def get_language_options():
    """Return the loaded ASR model's ``language_options`` attribute.

    Returns:
        JSONResponse: ``{"language_options": asr_model.language_options}``.

    Raises:
        HTTPException: 503 if the model has not been loaded.
    """
    # Same guard as /asr: without it a missing model surfaces as an opaque 500
    # (AttributeError on None) instead of "service unavailable".
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language_options": asr_model.language_options})
@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(
        ..., embed=True, description="Audio data as list of floats"
    ),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """Recognize speech from raw audio samples sent as a JSON list of floats.

    Args:
        audio_data: Audio samples as a list of floats.
        sample_rate: Sample rate of ``audio_data`` in Hz (default 16000). The
            audio is resampled when it differs from ``asr_model.sample_rate``.
        language: Language value forwarded to ``asr_model.infer_waveform``
            (default ``"auto"``).

    Returns:
        JSONResponse: ``{"text": <result of asr_model.infer_waveform>}``.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for invalid input
            (any ``ValueError``); 500 for any other failure.
    """
    # Checked outside the try: the generic ``except Exception`` below would
    # otherwise catch this HTTPException and turn the 503 into a 500.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    # NOTE(review): resampling and inference are synchronous calls made inside
    # an ``async def``, so they block the event loop for the whole request.
    # Moving them to a threadpool would need the model's thread-safety
    # confirmed first.
    try:
        logger.info(f"Received audio data with length: {len(audio_data)}")
        if sample_rate is None or sample_rate <= 0:
            raise ValueError(f"Invalid sample_rate: {sample_rate}")
        np_audio = validate_audio_data(audio_data)
        if sample_rate != asr_model.sample_rate:
            # librosa >= 0.10 makes orig_sr/target_sr keyword-only; passing
            # them positionally raises TypeError there.
            np_audio = librosa.resample(
                np_audio, orig_sr=sample_rate, target_sr=asr_model.sample_rate
            )
        result = asr_model.infer_waveform(np_audio, language)
        return JSONResponse(content={"text": result})
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback for server-side debugging.
        logger.exception(f"Recognition error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Serve the API with uvicorn when this module is run as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000  # listen on every interface
    uvicorn.run(app, host=bind_host, port=bind_port)
|