File size: 4,064 Bytes
07f9af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1ae526
07f9af1
 
 
 
 
 
 
d1ae526
07f9af1
 
 
b3b007c
02707b5
07f9af1
4fd2d31
07f9af1
 
 
 
 
 
 
d1ae526
 
 
b3b007c
 
 
d1ae526
b3b007c
d1ae526
 
07f9af1
 
 
 
 
d1ae526
07f9af1
 
 
d1ae526
07f9af1
 
d1ae526
07f9af1
 
 
 
 
 
d1ae526
07f9af1
 
 
d1ae526
07f9af1
 
d1ae526
07f9af1
 
 
d1ae526
 
07f9af1
 
 
 
d1ae526
 
 
 
 
07f9af1
 
 
d1ae526
07f9af1
 
d1ae526
 
 
07f9af1
d1ae526
07f9af1
 
 
d1ae526
07f9af1
 
 
d1ae526
07f9af1
 
 
 
 
 
 
d1ae526
07f9af1
d1ae526
07f9af1
 
 
 
d1ae526
07f9af1
 
d1ae526
07f9af1
d1ae526
07f9af1
 
 
 
 
 
 
d1ae526
07f9af1
 
d1ae526
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
from fastapi import FastAPI, HTTPException, Body
from fastapi.responses import JSONResponse
from typing import List, Optional
import logging
from SenseVoiceAx import SenseVoiceAx
import os
import librosa

# Initialize logging (root logger at INFO level) and this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application exposing the ASR endpoints defined below
app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")

# Global holding the ASR model; None until load_model() assigns it at startup.
# Handlers must treat None as "model not loaded".
asr_model = None


@app.on_event("startup")
async def load_model():
    """
    Load the SenseVoice ASR model when the service starts.

    The model file is ``sensevoice_ax650/sensevoice.axmodel``, resolved relative
    to the current working directory. The constructed ``SenseVoiceAx`` instance
    is stored in the module-level ``asr_model`` for the request handlers.

    Raises:
    - FileNotFoundError: if the model file does not exist.
    - Any exception raised while constructing the model. It is logged with its
      traceback and re-raised so the failure propagates out of the startup hook.
    """
    global asr_model
    logger.info("Loading ASR model...")

    try:
        # Model configuration
        language = "auto"
        use_itn = True  # inverse text normalization
        max_len = 256

        model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"model {model_path} not exist")

        logger.info("language: %s", language)
        logger.info("use_itn: %s", use_itn)
        logger.info("model_path: %s", model_path)

        asr_model = SenseVoiceAx(
            model_path,
            max_len=max_len,
            beam_size=3,
            language=language,
            hot_words=None,
            use_itn=use_itn,
            streaming=False,
        )

        logger.info("ASR model loaded successfully")
    except Exception as e:
        # logger.exception records the traceback; re-raise so startup does not
        # silently continue with asr_model still None.
        logger.exception("Failed to load ASR model: %s", e)
        raise


def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """
    Validate audio samples and convert them to a float32 numpy array.

    Parameters:
    - audio_data: audio samples as a list of floats (any array-like that
      ``np.array`` can convert to float32 is accepted)

    Returns:
    - a 1-dimensional ``np.float32`` array containing the samples

    Raises:
    - ValueError: with a message prefixed by "Invalid audio data: " when the
      input cannot be converted to float32, is not 1-dimensional, is empty,
      or contains NaN/infinite values
    """
    try:
        np_array = np.array(audio_data, dtype=np.float32)
    except (TypeError, ValueError, OverflowError) as e:
        # Non-numeric, ragged, or out-of-range input.
        raise ValueError(f"Invalid audio data: {e}") from e

    if np_array.ndim != 1:
        raise ValueError("Invalid audio data: Audio data must be 1-dimensional")

    if np_array.size == 0:
        raise ValueError("Invalid audio data: Audio data cannot be empty")

    # Python's json parser accepts NaN/Infinity literals, and float32 casting can
    # overflow to inf; such samples would propagate into resampling/inference.
    if not np.isfinite(np_array).all():
        raise ValueError(
            "Invalid audio data: Audio data contains NaN or infinite values"
        )

    return np_array


@app.get("/get_language", summary="Get current language")
async def get_language():
    """
    Return the language setting of the loaded ASR model.

    Returns:
    - JSON ``{"language": ...}`` holding ``asr_model.language``

    Raises:
    - HTTPException(503): if the model has not been loaded yet
    """
    # Without this guard an unloaded model surfaces as an AttributeError (HTTP 500).
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language": asr_model.language})


@app.get(
    "/get_language_options",
    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
)
async def get_language_options():
    """
    Return the language options exposed by the loaded ASR model.

    Returns:
    - JSON ``{"language_options": ...}`` holding ``asr_model.language_options``

    Raises:
    - HTTPException(503): if the model has not been loaded yet
    """
    # Without this guard an unloaded model surfaces as an AttributeError (HTTP 500).
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language_options": asr_model.language_options})


@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(
        ..., embed=True, description="Audio data as list of floats"
    ),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """
    Transcribe audio samples sent as a JSON list of floats.

    Parameters:
    - audio_data: audio samples as a list of floats
    - sample_rate: sample rate of ``audio_data`` in Hz (default 16000); the audio
      is resampled to ``asr_model.sample_rate`` when the two differ
    - language: language passed to the model; an explicit null falls back to "auto"

    Returns:
    - JSON ``{"text": ...}`` containing the model's recognition result

    Raises:
    - HTTPException(503): the model has not been loaded
    - HTTPException(400): invalid audio data or sample rate (any ValueError)
    - HTTPException(500): any other failure during resampling or inference
    """
    # Checked outside the try block: HTTPException is an Exception subclass, so
    # the generic handler below would otherwise turn this 503 into a 500.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    if language is None:
        language = "auto"

    try:
        logger.info("Received audio data with length: %d", len(audio_data))

        # An explicit null or non-positive rate would otherwise crash resampling.
        if sample_rate is None or sample_rate <= 0:
            raise ValueError("sample_rate must be a positive integer")

        # Validate and convert the samples
        np_audio = validate_audio_data(audio_data)
        if sample_rate != asr_model.sample_rate:
            # orig_sr/target_sr are keyword-only in librosa >= 0.10.
            np_audio = librosa.resample(
                np_audio, orig_sr=sample_rate, target_sr=asr_model.sample_rate
            )

        # NOTE(review): inference runs synchronously inside an async handler and
        # therefore blocks the event loop while it runs.
        result = asr_model.infer_waveform(np_audio, language)

        return JSONResponse(content={"text": result})

    except ValueError as e:
        logger.error("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback for server-side debugging.
        logger.exception("Recognition error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    # Serve the API with uvicorn on all interfaces, port 8000, when run as a script.
    from uvicorn import run as run_server

    run_server(app, port=8000, host="0.0.0.0")