* optmize models path
Browse files- README.md +2 -89
- README_ZH.md +9 -0
- config.json +0 -0
- download_dataset.sh +0 -2
- download_utils.py +0 -33
- example/en.mp3 +0 -3
- example/ja.mp3 +0 -3
- example/ko.mp3 +0 -3
- example/yue.mp3 +0 -3
- example/zh.mp3 +0 -3
- gradio_demo.py +0 -62
- main.py +0 -79
- print_utils.py +0 -131
- requirements.txt +0 -11
- SenseVoiceAx.py → sensevoice-maixcam2/SenseVoiceAx.py +6 -6
- am.mvn → sensevoice-maixcam2/am.mvn +0 -0
- chn_jpn_yue_eng_ko_spectok.bpe.model → sensevoice-maixcam2/chn_jpn_yue_eng_ko_spectok.bpe.model +0 -0
- client.py → sensevoice-maixcam2/client.py +34 -9
- embeddings.npy → sensevoice-maixcam2/embeddings.npy +0 -0
- frontend.py → sensevoice-maixcam2/frontend.py +0 -0
- model.mud → sensevoice-maixcam2/model.mud +4 -3
- pe_nonstream.npy → sensevoice-maixcam2/pe_nonstream.npy +0 -0
- pe_streaming.npy → sensevoice-maixcam2/pe_streaming.npy +0 -0
- {sensevoice_ax630c → sensevoice-maixcam2}/sensevoice.axmodel +0 -0
- server.py → sensevoice-maixcam2/server.py +22 -6
- {sensevoice_ax630c → sensevoice-maixcam2}/streaming_sensevoice.axmodel +0 -0
- tokenizer.py → sensevoice-maixcam2/tokenizer.py +0 -0
- test_wer.py +0 -296
README.md
CHANGED
|
@@ -4,93 +4,6 @@ language:
|
|
| 4 |
- en
|
| 5 |
pipeline_tag: automatic-speech-recognition
|
| 6 |
---
|
| 7 |
-
#
|
| 8 |
-
FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
- [x] 支持AX630C
|
| 13 |
-
- [ ] 支持C++
|
| 14 |
-
- [x] 支持FastAPI
|
| 15 |
-
|
| 16 |
-
## 功能
|
| 17 |
-
- 语音识别
|
| 18 |
-
- 自动识别语言(支持中文、英文、粤语、日语、韩语)
|
| 19 |
-
- 情感识别
|
| 20 |
-
- 自动标点
|
| 21 |
-
- 支持流式识别
|
| 22 |
-
|
| 23 |
-
## 支持平台
|
| 24 |
-
|
| 25 |
-
- [x] AX650N
|
| 26 |
-
- [x] AX630C
|
| 27 |
-
|
| 28 |
-
## 环境安装
|
| 29 |
-
```
|
| 30 |
-
pip3 install -r requirements.txt
|
| 31 |
-
```
|
| 32 |
-
如果空间不足可以使用 --prefix 指定别的安装路径
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
## 使用
|
| 36 |
-
```
|
| 37 |
-
# 首次运行会自动从huggingface上下载模型, 保存到models中
|
| 38 |
-
python3 main.py -i 输入音频文件
|
| 39 |
-
```
|
| 40 |
-
运行参数说明:
|
| 41 |
-
| 参数名称 | 说明 | 默认值 |
|
| 42 |
-
| --- | --- | --- |
|
| 43 |
-
| --input/-i | 输入音频文件 | |
|
| 44 |
-
| --language/-l | 识别语言,支持auto, zh, en, yue, ja, ko | auto |
|
| 45 |
-
| --streaming | 流式识别 | |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
### 示例:
|
| 49 |
-
example下有测试音频
|
| 50 |
-
|
| 51 |
-
如 粤语测试
|
| 52 |
-
```
|
| 53 |
-
python3 main.py -i example/yue.mp3
|
| 54 |
-
```
|
| 55 |
-
输出
|
| 56 |
-
```
|
| 57 |
-
RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
|
| 58 |
-
['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
|
| 59 |
-
```
|
| 60 |
-
|
| 61 |
-
流式识别
|
| 62 |
-
|
| 63 |
-
```
|
| 64 |
-
python3 main.py -i example/zh.mp3 --streaming
|
| 65 |
-
```
|
| 66 |
-
输出
|
| 67 |
-
```
|
| 68 |
-
{'timestamps': [540], 'text': '开'}
|
| 69 |
-
{'timestamps': [540, 780, 1080], 'text': '开放时'}
|
| 70 |
-
{'timestamps': [540, 780, 1080, 1260, 1740], 'text': '开放时间早'}
|
| 71 |
-
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340], 'text': '开放时间早上9'}
|
| 72 |
-
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640], 'text': '开放时间早上9点'}
|
| 73 |
-
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060], 'text': '开放时间早上9点至'}
|
| 74 |
-
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020], 'text': '开放时间早上9点至下午'}
|
| 75 |
-
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020, 4440, 4620], 'text': '开放时间早上9点至下午五点'}
|
| 76 |
-
RTF: 0.03678379235444246
|
| 77 |
-
|
| 78 |
-
```
|
| 79 |
-
|
| 80 |
-
## 准确率
|
| 81 |
-
|
| 82 |
-
使用WER(Word-Error-Rate)作为评价标准
|
| 83 |
-
|
| 84 |
-
**WER = 0.0389**
|
| 85 |
-
|
| 86 |
-
### 复现测试结果
|
| 87 |
-
|
| 88 |
-
```
|
| 89 |
-
./download_datasets.sh
|
| 90 |
-
python test_wer.py -d datasets -l zh
|
| 91 |
-
```
|
| 92 |
-
|
| 93 |
-
## 技术讨论
|
| 94 |
-
|
| 95 |
-
- Github issues
|
| 96 |
-
- QQ 群: 139953715
|
|
|
|
| 4 |
- en
|
| 5 |
pipeline_tag: automatic-speech-recognition
|
| 6 |
---
|
| 7 |
+
# SenseVoice
|
|
|
|
| 8 |
|
| 9 |
+
Refer here[](https://wiki.sipeed.com/maixpy/doc/en/mllm/asr_sensevoice.html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README_ZH.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: automatic-speech-recognition
|
| 6 |
+
---
|
| 7 |
+
# SenseVoice
|
| 8 |
+
|
| 9 |
+
使用文档参考[这里](https://wiki.sipeed.com/maixpy/doc/zh/mllm/asr_sensevoice.html)
|
config.json
DELETED
|
File without changes
|
download_dataset.sh
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
wget https://github.com/ml-inory/whisper.axera/releases/download/v1.0/datasets.zip
|
| 2 |
-
unzip datasets.zip -d ./
|
|
|
|
|
|
|
|
|
download_utils.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
# Speed up hf download using mirror url
|
| 4 |
-
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 5 |
-
from huggingface_hub import snapshot_download
|
| 6 |
-
|
| 7 |
-
current_file_path = os.path.dirname(__file__)
|
| 8 |
-
REPO_ROOT = "AXERA-TECH"
|
| 9 |
-
CACHE_PATH = os.path.join(current_file_path, "models")
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def download_model(model_name: str) -> str:
|
| 13 |
-
"""
|
| 14 |
-
Download model from AXERA-TECH's huggingface space.
|
| 15 |
-
|
| 16 |
-
model_name: str
|
| 17 |
-
Available model names could be checked on https://huggingface.co/AXERA-TECH.
|
| 18 |
-
|
| 19 |
-
Returns:
|
| 20 |
-
str: Path to model_name
|
| 21 |
-
|
| 22 |
-
"""
|
| 23 |
-
os.makedirs(CACHE_PATH, exist_ok=True)
|
| 24 |
-
|
| 25 |
-
model_path = os.path.join(CACHE_PATH, model_name)
|
| 26 |
-
if not os.path.exists(model_path):
|
| 27 |
-
print(f"Downloading {model_name}...")
|
| 28 |
-
snapshot_download(
|
| 29 |
-
repo_id=f"{REPO_ROOT}/{model_name}",
|
| 30 |
-
local_dir=os.path.join(CACHE_PATH, model_name),
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
return model_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
example/en.mp3
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f10378336a4e584f3f63799e62f99d5add3c2a401b51d3abe7d3a3a82f255ada
|
| 3 |
-
size 57441
|
|
|
|
|
|
|
|
|
|
|
|
example/ja.mp3
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:496dbc43b289e1d0d0cb916df9737450bca56acd8aaca046a7a2472363b1be53
|
| 3 |
-
size 57837
|
|
|
|
|
|
|
|
|
|
|
|
example/ko.mp3
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8612f62db8319a6cb4ab4b1d2039bfc32f174f89611889ddafdeb5c0a6070b5f
|
| 3 |
-
size 27909
|
|
|
|
|
|
|
|
|
|
|
|
example/yue.mp3
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5098eebc13530a66e4eac1f30d3246e65c9cfc4e096665f9d395aca8eff0d181
|
| 3 |
-
size 31246
|
|
|
|
|
|
|
|
|
|
|
|
example/zh.mp3
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:0e64de19e4ff9a02e682955c9112f32d2317cfdbb5bc2f3504664044c993f195
|
| 3 |
-
size 44973
|
|
|
|
|
|
|
|
|
|
|
|
gradio_demo.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import os
|
| 3 |
-
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
-
from print_utils import rich_transcription_postprocess
|
| 5 |
-
|
| 6 |
-
max_len = 256
|
| 7 |
-
|
| 8 |
-
model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
|
| 9 |
-
|
| 10 |
-
assert os.path.exists(model_path), f"model {model_path} not exist"
|
| 11 |
-
|
| 12 |
-
pipeline = SenseVoiceAx(
|
| 13 |
-
model_path,
|
| 14 |
-
max_len=max_len,
|
| 15 |
-
beam_size=3,
|
| 16 |
-
language="auto",
|
| 17 |
-
hot_words=None,
|
| 18 |
-
use_itn=True,
|
| 19 |
-
streaming=False,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def speech_to_text(audio_path, lang):
|
| 24 |
-
"""
|
| 25 |
-
audio_path: 音频文件路径
|
| 26 |
-
lang: 语言类型 "auto", "zh", "en", "yue", "ja", "ko"
|
| 27 |
-
"""
|
| 28 |
-
if not audio_path:
|
| 29 |
-
return "无音频"
|
| 30 |
-
|
| 31 |
-
pipeline.choose_language(language=lang)
|
| 32 |
-
asr_res = pipeline.infer(audio_path, print_rtf=False)
|
| 33 |
-
|
| 34 |
-
return asr_res
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def main():
|
| 38 |
-
with gr.Blocks() as demo:
|
| 39 |
-
with gr.Row():
|
| 40 |
-
output_text = gr.Textbox(label="识别结果", lines=5)
|
| 41 |
-
|
| 42 |
-
with gr.Row():
|
| 43 |
-
audio_input = gr.Audio(
|
| 44 |
-
sources=["upload"], type="filepath", label="录制或上传音频", format="mp3"
|
| 45 |
-
)
|
| 46 |
-
lang_dropdown = gr.Dropdown(
|
| 47 |
-
choices=["auto", "zh", "en", "yue", "ja", "ko"],
|
| 48 |
-
value="auto",
|
| 49 |
-
label="选择音频语言",
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
audio_input.change(
|
| 53 |
-
fn=speech_to_text, inputs=[audio_input, lang_dropdown], outputs=output_text
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
demo.launch(
|
| 57 |
-
server_name="0.0.0.0",
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
if __name__ == "__main__":
|
| 62 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import argparse
|
| 3 |
-
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
-
import librosa
|
| 5 |
-
import numpy as np
|
| 6 |
-
import time
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def get_args():
|
| 10 |
-
parser = argparse.ArgumentParser()
|
| 11 |
-
parser.add_argument(
|
| 12 |
-
"--input", "-i", required=True, type=str, help="Input audio file"
|
| 13 |
-
)
|
| 14 |
-
parser.add_argument(
|
| 15 |
-
"--language",
|
| 16 |
-
"-l",
|
| 17 |
-
required=False,
|
| 18 |
-
type=str,
|
| 19 |
-
default="auto",
|
| 20 |
-
choices=["auto", "zh", "en", "yue", "ja", "ko"],
|
| 21 |
-
)
|
| 22 |
-
parser.add_argument("--streaming", action="store_true")
|
| 23 |
-
return parser.parse_args()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def main():
|
| 27 |
-
args = get_args()
|
| 28 |
-
|
| 29 |
-
input_audio = args.input
|
| 30 |
-
language = args.language
|
| 31 |
-
use_itn = True # 标点符号预测
|
| 32 |
-
if not args.streaming:
|
| 33 |
-
max_len = 256
|
| 34 |
-
model_path = os.path.join("sensevoice_ax630c", "sensevoice.axmodel")
|
| 35 |
-
else:
|
| 36 |
-
max_len = 26
|
| 37 |
-
model_path = os.path.join("sensevoice_ax630c", "streaming_sensevoice.axmodel")
|
| 38 |
-
|
| 39 |
-
assert os.path.exists(model_path), f"model {model_path} not exist"
|
| 40 |
-
|
| 41 |
-
print(f"input_audio: {input_audio}")
|
| 42 |
-
print(f"language: {language}")
|
| 43 |
-
print(f"use_itn: {use_itn}")
|
| 44 |
-
print(f"model_path: {model_path}")
|
| 45 |
-
print(f"streaming: {args.streaming}")
|
| 46 |
-
|
| 47 |
-
pipeline = SenseVoiceAx(
|
| 48 |
-
model_path,
|
| 49 |
-
max_len=max_len,
|
| 50 |
-
beam_size=3,
|
| 51 |
-
language="auto",
|
| 52 |
-
hot_words=None,
|
| 53 |
-
use_itn=True,
|
| 54 |
-
streaming=args.streaming,
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
if not args.streaming:
|
| 58 |
-
asr_res = pipeline.infer(input_audio, print_rtf=True)
|
| 59 |
-
print("ASR result: " + asr_res)
|
| 60 |
-
else:
|
| 61 |
-
samples, sr = librosa.load(input_audio, sr=16000)
|
| 62 |
-
samples = (samples * 32768).tolist()
|
| 63 |
-
duration = len(samples) / 16000
|
| 64 |
-
|
| 65 |
-
start = time.time()
|
| 66 |
-
step = int(0.1 * sr)
|
| 67 |
-
for i in range(0, len(samples), step):
|
| 68 |
-
is_last = i + step >= len(samples)
|
| 69 |
-
for res in pipeline.stream_infer(samples[i : i + step], is_last):
|
| 70 |
-
print(res)
|
| 71 |
-
|
| 72 |
-
end = time.time()
|
| 73 |
-
cost_time = end - start
|
| 74 |
-
|
| 75 |
-
print(f"RTF: {cost_time / duration}")
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
if __name__ == "__main__":
|
| 79 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print_utils.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
| 1 |
-
emo_dict = {
|
| 2 |
-
"<|HAPPY|>": "😊",
|
| 3 |
-
"<|SAD|>": "😔",
|
| 4 |
-
"<|ANGRY|>": "😡",
|
| 5 |
-
"<|NEUTRAL|>": "",
|
| 6 |
-
"<|FEARFUL|>": "😰",
|
| 7 |
-
"<|DISGUSTED|>": "🤢",
|
| 8 |
-
"<|SURPRISED|>": "😮",
|
| 9 |
-
}
|
| 10 |
-
|
| 11 |
-
event_dict = {
|
| 12 |
-
"<|BGM|>": "🎼",
|
| 13 |
-
"<|Speech|>": "",
|
| 14 |
-
"<|Applause|>": "👏",
|
| 15 |
-
"<|Laughter|>": "😀",
|
| 16 |
-
"<|Cry|>": "😭",
|
| 17 |
-
"<|Sneeze|>": "🤧",
|
| 18 |
-
"<|Breath|>": "",
|
| 19 |
-
"<|Cough|>": "🤧",
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
lang_dict = {
|
| 23 |
-
"<|zh|>": "<|lang|>",
|
| 24 |
-
"<|en|>": "<|lang|>",
|
| 25 |
-
"<|yue|>": "<|lang|>",
|
| 26 |
-
"<|ja|>": "<|lang|>",
|
| 27 |
-
"<|ko|>": "<|lang|>",
|
| 28 |
-
"<|nospeech|>": "<|lang|>",
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
emoji_dict = {
|
| 32 |
-
"<|nospeech|><|Event_UNK|>": "❓",
|
| 33 |
-
"<|zh|>": "",
|
| 34 |
-
"<|en|>": "",
|
| 35 |
-
"<|yue|>": "",
|
| 36 |
-
"<|ja|>": "",
|
| 37 |
-
"<|ko|>": "",
|
| 38 |
-
"<|nospeech|>": "",
|
| 39 |
-
"<|HAPPY|>": "😊",
|
| 40 |
-
"<|SAD|>": "😔",
|
| 41 |
-
"<|ANGRY|>": "😡",
|
| 42 |
-
"<|NEUTRAL|>": "",
|
| 43 |
-
"<|BGM|>": "🎼",
|
| 44 |
-
"<|Speech|>": "",
|
| 45 |
-
"<|Applause|>": "👏",
|
| 46 |
-
"<|Laughter|>": "😀",
|
| 47 |
-
"<|FEARFUL|>": "😰",
|
| 48 |
-
"<|DISGUSTED|>": "🤢",
|
| 49 |
-
"<|SURPRISED|>": "😮",
|
| 50 |
-
"<|Cry|>": "😭",
|
| 51 |
-
"<|EMO_UNKNOWN|>": "",
|
| 52 |
-
"<|Sneeze|>": "🤧",
|
| 53 |
-
"<|Breath|>": "",
|
| 54 |
-
"<|Cough|>": "😷",
|
| 55 |
-
"<|Sing|>": "",
|
| 56 |
-
"<|Speech_Noise|>": "",
|
| 57 |
-
"<|withitn|>": "",
|
| 58 |
-
"<|woitn|>": "",
|
| 59 |
-
"<|GBG|>": "",
|
| 60 |
-
"<|Event_UNK|>": "",
|
| 61 |
-
}
|
| 62 |
-
|
| 63 |
-
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
|
| 64 |
-
event_set = {
|
| 65 |
-
"🎼",
|
| 66 |
-
"👏",
|
| 67 |
-
"😀",
|
| 68 |
-
"😭",
|
| 69 |
-
"🤧",
|
| 70 |
-
"😷",
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def format_str_v2(s):
|
| 75 |
-
sptk_dict = {}
|
| 76 |
-
for sptk in emoji_dict:
|
| 77 |
-
sptk_dict[sptk] = s.count(sptk)
|
| 78 |
-
s = s.replace(sptk, "")
|
| 79 |
-
emo = "<|NEUTRAL|>"
|
| 80 |
-
for e in emo_dict:
|
| 81 |
-
if sptk_dict[e] > sptk_dict[emo]:
|
| 82 |
-
emo = e
|
| 83 |
-
for e in event_dict:
|
| 84 |
-
if sptk_dict[e] > 0:
|
| 85 |
-
s = event_dict[e] + s
|
| 86 |
-
s = s + emo_dict[emo]
|
| 87 |
-
|
| 88 |
-
for emoji in emo_set.union(event_set):
|
| 89 |
-
s = s.replace(" " + emoji, emoji)
|
| 90 |
-
s = s.replace(emoji + " ", emoji)
|
| 91 |
-
return s.strip()
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def rich_transcription_postprocess(s):
|
| 95 |
-
def get_emo(s):
|
| 96 |
-
return s[-1] if s[-1] in emo_set else None
|
| 97 |
-
|
| 98 |
-
def get_event(s):
|
| 99 |
-
return s[0] if s[0] in event_set else None
|
| 100 |
-
|
| 101 |
-
s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
|
| 102 |
-
for lang in lang_dict:
|
| 103 |
-
s = s.replace(lang, "<|lang|>")
|
| 104 |
-
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
|
| 105 |
-
new_s = " " + s_list[0]
|
| 106 |
-
cur_ent_event = get_event(new_s)
|
| 107 |
-
for i in range(1, len(s_list)):
|
| 108 |
-
if len(s_list[i]) == 0:
|
| 109 |
-
continue
|
| 110 |
-
if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
|
| 111 |
-
s_list[i] = s_list[i][1:]
|
| 112 |
-
# else:
|
| 113 |
-
cur_ent_event = get_event(s_list[i])
|
| 114 |
-
if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
|
| 115 |
-
new_s = new_s[:-1]
|
| 116 |
-
new_s += s_list[i].strip().lstrip()
|
| 117 |
-
new_s = new_s.replace("The.", " ")
|
| 118 |
-
return new_s.strip()
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
|
| 122 |
-
res = "".join([rich_transcription_postprocess(i) for i in asr_res])
|
| 123 |
-
|
| 124 |
-
if remove_punc:
|
| 125 |
-
res = res.replace(",", "")
|
| 126 |
-
res = res.replace("。", "")
|
| 127 |
-
|
| 128 |
-
if will_print:
|
| 129 |
-
print(res)
|
| 130 |
-
|
| 131 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
huggingface_hub
|
| 2 |
-
numpy<2
|
| 3 |
-
kaldi-native-fbank
|
| 4 |
-
librosa==0.9.1
|
| 5 |
-
sentencepiece
|
| 6 |
-
fastapi
|
| 7 |
-
gradio
|
| 8 |
-
emoji
|
| 9 |
-
asr-decoder
|
| 10 |
-
online-fbank
|
| 11 |
-
torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SenseVoiceAx.py → sensevoice-maixcam2/SenseVoiceAx.py
RENAMED
|
@@ -281,18 +281,18 @@ class SenseVoiceAx:
|
|
| 281 |
|
| 282 |
"""
|
| 283 |
model_path_root = os.path.dirname(model_path)
|
| 284 |
-
emb_path = os.path.join(model_path_root, "
|
| 285 |
-
cmvn_file = os.path.join(model_path_root, "
|
| 286 |
bpe_model = os.path.join(
|
| 287 |
-
model_path_root, "
|
| 288 |
)
|
| 289 |
if streaming:
|
| 290 |
self.position_encoding = np.load(
|
| 291 |
-
os.path.join(model_path_root, "
|
| 292 |
)
|
| 293 |
else:
|
| 294 |
self.position_encoding = np.load(
|
| 295 |
-
os.path.join(model_path_root, "
|
| 296 |
)
|
| 297 |
|
| 298 |
self.streaming = streaming
|
|
@@ -553,7 +553,7 @@ class SenseVoiceAx:
|
|
| 553 |
self.cur_idx = -1
|
| 554 |
self.decoder.reset()
|
| 555 |
self.fbank = OnlineFbank(window_type="hamming")
|
| 556 |
-
self.caches = np.zeros(self.caches_shape)
|
| 557 |
|
| 558 |
def get_size(self):
|
| 559 |
effective_size = self.cur_idx + 1 - self.padding
|
|
|
|
| 281 |
|
| 282 |
"""
|
| 283 |
model_path_root = os.path.dirname(model_path)
|
| 284 |
+
emb_path = os.path.join(model_path_root, "embeddings.npy")
|
| 285 |
+
cmvn_file = os.path.join(model_path_root, "am.mvn")
|
| 286 |
bpe_model = os.path.join(
|
| 287 |
+
model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model"
|
| 288 |
)
|
| 289 |
if streaming:
|
| 290 |
self.position_encoding = np.load(
|
| 291 |
+
os.path.join(model_path_root, "pe_streaming.npy")
|
| 292 |
)
|
| 293 |
else:
|
| 294 |
self.position_encoding = np.load(
|
| 295 |
+
os.path.join(model_path_root, "pe_nonstream.npy")
|
| 296 |
)
|
| 297 |
|
| 298 |
self.streaming = streaming
|
|
|
|
| 553 |
self.cur_idx = -1
|
| 554 |
self.decoder.reset()
|
| 555 |
self.fbank = OnlineFbank(window_type="hamming")
|
| 556 |
+
self.caches = np.zeros(self.caches_shape, dtype=np.float32)
|
| 557 |
|
| 558 |
def get_size(self):
|
| 559 |
effective_size = self.cur_idx + 1 - self.padding
|
am.mvn → sensevoice-maixcam2/am.mvn
RENAMED
|
File without changes
|
chn_jpn_yue_eng_ko_spectok.bpe.model → sensevoice-maixcam2/chn_jpn_yue_eng_ko_spectok.bpe.model
RENAMED
|
File without changes
|
client.py → sensevoice-maixcam2/client.py
RENAMED
|
@@ -6,7 +6,7 @@ import requests, json, os
|
|
| 6 |
import wave
|
| 7 |
import numpy as np
|
| 8 |
import threading
|
| 9 |
-
from maix import app, time
|
| 10 |
|
| 11 |
class Sensevoice:
|
| 12 |
def __init__(self, model = "", url="http://0.0.0.0:12347", lauguage="auto", stream=False):
|
|
@@ -26,6 +26,17 @@ class Sensevoice:
|
|
| 26 |
if not os.path.exists(model):
|
| 27 |
raise ValueError(f'Model {self.model} is not existed!')
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def _check_service(self):
|
| 30 |
try:
|
| 31 |
response = requests.get(self.url + '/status')
|
|
@@ -75,6 +86,15 @@ class Sensevoice:
|
|
| 75 |
except:
|
| 76 |
return "not loaded"
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def _start_model(self):
|
| 79 |
try:
|
| 80 |
data = {
|
|
@@ -83,6 +103,7 @@ class Sensevoice:
|
|
| 83 |
"language": self.launguage,
|
| 84 |
"stream": self.stream
|
| 85 |
}
|
|
|
|
| 86 |
response = requests.post(self.url + '/start_model', json=data)
|
| 87 |
if response.status_code == 200:
|
| 88 |
res = json.loads(response.text)
|
|
@@ -129,6 +150,9 @@ class Sensevoice:
|
|
| 129 |
|
| 130 |
def is_ready(self, block=False):
|
| 131 |
while not app.need_exit():
|
|
|
|
|
|
|
|
|
|
| 132 |
if self._get_status() == "loaded":
|
| 133 |
return True
|
| 134 |
else:
|
|
@@ -136,10 +160,6 @@ class Sensevoice:
|
|
| 136 |
time.sleep(1)
|
| 137 |
else:
|
| 138 |
return False
|
| 139 |
-
|
| 140 |
-
if self.thread_is_exit:
|
| 141 |
-
return True if self.thread_exit_code == 0 else False
|
| 142 |
-
|
| 143 |
return False
|
| 144 |
|
| 145 |
def stop(self):
|
|
@@ -225,17 +245,20 @@ class Sensevoice:
|
|
| 225 |
}
|
| 226 |
|
| 227 |
try:
|
|
|
|
| 228 |
response = requests.post(self.url + '/asr', json=data)
|
| 229 |
if response.status_code == 200:
|
| 230 |
res = json.loads(response.text)
|
| 231 |
text = res.get("text", "")
|
| 232 |
if len(text) > 0:
|
| 233 |
-
|
| 234 |
else:
|
| 235 |
-
|
| 236 |
else:
|
| 237 |
print(f"Requests failed: {response.status_code}")
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
except Exception as e:
|
| 240 |
print("Requests failed:", e)
|
| 241 |
return ""
|
|
@@ -258,13 +281,15 @@ class Sensevoice:
|
|
| 258 |
"launguage": "auto",
|
| 259 |
"step": 0.1,
|
| 260 |
}
|
| 261 |
-
|
| 262 |
try:
|
| 263 |
response = requests.post(self.url + '/asr_stream', json=data, stream=True)
|
| 264 |
for line in response.iter_lines():
|
| 265 |
if line:
|
| 266 |
chunk = json.loads(line)
|
| 267 |
yield chunk.get("text", "")
|
|
|
|
|
|
|
| 268 |
except Exception as e:
|
| 269 |
print("Requests failed:", e)
|
| 270 |
return ""
|
|
|
|
| 6 |
import wave
|
| 7 |
import numpy as np
|
| 8 |
import threading
|
| 9 |
+
from maix import app, time, nn
|
| 10 |
|
| 11 |
class Sensevoice:
|
| 12 |
def __init__(self, model = "", url="http://0.0.0.0:12347", lauguage="auto", stream=False):
|
|
|
|
| 26 |
if not os.path.exists(model):
|
| 27 |
raise ValueError(f'Model {self.model} is not existed!')
|
| 28 |
|
| 29 |
+
self.model_mud = nn.MUD(model)
|
| 30 |
+
self.model_configs = self.model_mud.items
|
| 31 |
+
server_path = self.model_configs.get('extra', {}).get('server', '/root/models/sensevoice_maixcam2/server.py')
|
| 32 |
+
model_dir = os.path.dirname(self.model)
|
| 33 |
+
service_env_path = self.model_configs.get('extra', {}).get('sensevoice_service_env')
|
| 34 |
+
self._create_service_environment_file(model_dir, os.path.join(model_dir, server_path), os.path.join(model_dir, service_env_path))
|
| 35 |
+
|
| 36 |
+
def _create_service_environment_file(self, working_dir: str, server_path: str, env_file_path: str = "/tmp/sensevoice.service.env"):
|
| 37 |
+
os.system(f"echo 'WORK_DIR={os.path.realpath(working_dir)}' > {os.path.realpath(env_file_path)}")
|
| 38 |
+
os.system(f"echo 'SERVER_FILE={os.path.realpath(server_path)}' >> {os.path.realpath(env_file_path)}")
|
| 39 |
+
|
| 40 |
def _check_service(self):
|
| 41 |
try:
|
| 42 |
response = requests.get(self.url + '/status')
|
|
|
|
| 86 |
except:
|
| 87 |
return "not loaded"
|
| 88 |
|
| 89 |
+
def _reset_model(self):
|
| 90 |
+
try:
|
| 91 |
+
response = requests.post(self.url + '/reset_model')
|
| 92 |
+
if response.status_code == 200:
|
| 93 |
+
res = json.loads(response.text)
|
| 94 |
+
return res["status"]
|
| 95 |
+
except:
|
| 96 |
+
return "not loaded"
|
| 97 |
+
|
| 98 |
def _start_model(self):
|
| 99 |
try:
|
| 100 |
data = {
|
|
|
|
| 103 |
"language": self.launguage,
|
| 104 |
"stream": self.stream
|
| 105 |
}
|
| 106 |
+
|
| 107 |
response = requests.post(self.url + '/start_model', json=data)
|
| 108 |
if response.status_code == 200:
|
| 109 |
res = json.loads(response.text)
|
|
|
|
| 150 |
|
| 151 |
def is_ready(self, block=False):
|
| 152 |
while not app.need_exit():
|
| 153 |
+
if self.thread_is_exit:
|
| 154 |
+
return True if self.thread_exit_code == 0 else False
|
| 155 |
+
|
| 156 |
if self._get_status() == "loaded":
|
| 157 |
return True
|
| 158 |
else:
|
|
|
|
| 160 |
time.sleep(1)
|
| 161 |
else:
|
| 162 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
return False
|
| 164 |
|
| 165 |
def stop(self):
|
|
|
|
| 245 |
}
|
| 246 |
|
| 247 |
try:
|
| 248 |
+
result = ""
|
| 249 |
response = requests.post(self.url + '/asr', json=data)
|
| 250 |
if response.status_code == 200:
|
| 251 |
res = json.loads(response.text)
|
| 252 |
text = res.get("text", "")
|
| 253 |
if len(text) > 0:
|
| 254 |
+
result = text[0]
|
| 255 |
else:
|
| 256 |
+
result = ""
|
| 257 |
else:
|
| 258 |
print(f"Requests failed: {response.status_code}")
|
| 259 |
+
result = ""
|
| 260 |
+
|
| 261 |
+
return result
|
| 262 |
except Exception as e:
|
| 263 |
print("Requests failed:", e)
|
| 264 |
return ""
|
|
|
|
| 281 |
"launguage": "auto",
|
| 282 |
"step": 0.1,
|
| 283 |
}
|
| 284 |
+
|
| 285 |
try:
|
| 286 |
response = requests.post(self.url + '/asr_stream', json=data, stream=True)
|
| 287 |
for line in response.iter_lines():
|
| 288 |
if line:
|
| 289 |
chunk = json.loads(line)
|
| 290 |
yield chunk.get("text", "")
|
| 291 |
+
|
| 292 |
+
self._reset_model()
|
| 293 |
except Exception as e:
|
| 294 |
print("Requests failed:", e)
|
| 295 |
return ""
|
embeddings.npy → sensevoice-maixcam2/embeddings.npy
RENAMED
|
File without changes
|
frontend.py → sensevoice-maixcam2/frontend.py
RENAMED
|
File without changes
|
model.mud → sensevoice-maixcam2/model.mud
RENAMED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
[basic]
|
| 2 |
type = axmodel
|
| 3 |
-
model_npu =
|
| 4 |
model_vnpu =
|
| 5 |
|
| 6 |
[extra]
|
|
@@ -11,5 +11,6 @@ beam_size = 3
|
|
| 11 |
language = auto
|
| 12 |
hot_words = None,
|
| 13 |
use_itn = True
|
| 14 |
-
stream_model =
|
| 15 |
-
|
|
|
|
|
|
| 1 |
[basic]
|
| 2 |
type = axmodel
|
| 3 |
+
model_npu = sensevoice.axmodel
|
| 4 |
model_vnpu =
|
| 5 |
|
| 6 |
[extra]
|
|
|
|
| 11 |
language = auto
|
| 12 |
hot_words = None,
|
| 13 |
use_itn = True
|
| 14 |
+
stream_model = streaming_sensevoice.axmodel
|
| 15 |
+
sensevoice_service_env = /tmp/sensevoice.service.env
|
| 16 |
+
server = server.py
|
pe_nonstream.npy → sensevoice-maixcam2/pe_nonstream.npy
RENAMED
|
File without changes
|
pe_streaming.npy → sensevoice-maixcam2/pe_streaming.npy
RENAMED
|
File without changes
|
{sensevoice_ax630c → sensevoice-maixcam2}/sensevoice.axmodel
RENAMED
|
File without changes
|
server.py → sensevoice-maixcam2/server.py
RENAMED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import numpy as np
|
| 2 |
from fastapi import FastAPI, HTTPException, Body
|
| 3 |
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
| 4 |
from typing import List, Optional
|
| 5 |
import logging
|
| 6 |
import json
|
|
@@ -15,13 +16,12 @@ import time
|
|
| 15 |
logging.basicConfig(level=logging.INFO)
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
-
app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
|
| 19 |
-
|
| 20 |
# 全局变量存储模型
|
| 21 |
asr_model = None
|
| 22 |
asr_model_is_loaded = False
|
| 23 |
mud_configs = None
|
| 24 |
|
|
|
|
| 25 |
def parse_config_file_to_json(file_path):
|
| 26 |
"""从文件读取配置并解析为JSON"""
|
| 27 |
if not os.path.exists(file_path):
|
|
@@ -96,9 +96,12 @@ async def handle_heartbeat():
|
|
| 96 |
last_get_heartbeat_s = time.time() # 模型未加载时,重置获取心跳的时间
|
| 97 |
await asyncio.sleep(10) # 每隔 10 秒检查一次
|
| 98 |
|
| 99 |
-
@
|
| 100 |
-
async def
|
| 101 |
asyncio.create_task(handle_heartbeat())
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
|
| 104 |
"""
|
|
@@ -147,6 +150,15 @@ async def get_status():
|
|
| 147 |
global asr_model_is_loaded
|
| 148 |
return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
@app.post("/start_model", summary="Load model")
|
| 151 |
async def start_model(
|
| 152 |
model_path: str = Body(
|
|
@@ -161,10 +173,14 @@ async def start_model(
|
|
| 161 |
"""
|
| 162 |
global asr_model
|
| 163 |
global asr_model_is_loaded
|
| 164 |
-
logger.info("Loading ASR model...")
|
| 165 |
|
| 166 |
if asr_model_is_loaded:
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
try:
|
| 170 |
mud_configs = parse_config_file_to_json(model_path)
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from fastapi import FastAPI, HTTPException, Body
|
| 3 |
from fastapi.responses import JSONResponse, StreamingResponse
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
from typing import List, Optional
|
| 6 |
import logging
|
| 7 |
import json
|
|
|
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
|
|
|
|
|
|
| 19 |
# 全局变量存储模型
|
| 20 |
asr_model = None
|
| 21 |
asr_model_is_loaded = False
|
| 22 |
mud_configs = None
|
| 23 |
|
| 24 |
+
|
| 25 |
def parse_config_file_to_json(file_path):
|
| 26 |
"""从文件读取配置并解析为JSON"""
|
| 27 |
if not os.path.exists(file_path):
|
|
|
|
| 96 |
last_get_heartbeat_s = time.time() # 模型未加载时,重置获取心跳的时间
|
| 97 |
await asyncio.sleep(10) # 每隔 10 秒检查一次
|
| 98 |
|
| 99 |
+
@asynccontextmanager
|
| 100 |
+
async def lifespan(app: FastAPI):
|
| 101 |
asyncio.create_task(handle_heartbeat())
|
| 102 |
+
yield
|
| 103 |
+
|
| 104 |
+
app = FastAPI(title="ASR Server", lifespan = lifespan, description="Automatic Speech Recognition API")
|
| 105 |
|
| 106 |
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
|
| 107 |
"""
|
|
|
|
| 150 |
global asr_model_is_loaded
|
| 151 |
return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
|
| 152 |
|
| 153 |
+
@app.post("/reset_model", summary="Reset model")
|
| 154 |
+
async def reset_model():
|
| 155 |
+
global asr_model
|
| 156 |
+
global asr_model_is_loaded
|
| 157 |
+
if not asr_model_is_loaded:
|
| 158 |
+
return JSONResponse(content={"status": "not loaded"})
|
| 159 |
+
asr_model.reset()
|
| 160 |
+
return JSONResponse(content={"status": "ok"})
|
| 161 |
+
|
| 162 |
@app.post("/start_model", summary="Load model")
|
| 163 |
async def start_model(
|
| 164 |
model_path: str = Body(
|
|
|
|
| 173 |
"""
|
| 174 |
global asr_model
|
| 175 |
global asr_model_is_loaded
|
|
|
|
| 176 |
|
| 177 |
if asr_model_is_loaded:
|
| 178 |
+
if asr_model.streaming == stream:
|
| 179 |
+
return JSONResponse(content={"status": "loaded"})
|
| 180 |
+
else:
|
| 181 |
+
await try_stop_model()
|
| 182 |
+
|
| 183 |
+
logger.info("Loading ASR model...")
|
| 184 |
|
| 185 |
try:
|
| 186 |
mud_configs = parse_config_file_to_json(model_path)
|
{sensevoice_ax630c → sensevoice-maixcam2}/streaming_sensevoice.axmodel
RENAMED
|
File without changes
|
tokenizer.py → sensevoice-maixcam2/tokenizer.py
RENAMED
|
File without changes
|
test_wer.py
DELETED
|
@@ -1,296 +0,0 @@
|
|
| 1 |
-
import os, sys
|
| 2 |
-
import argparse
|
| 3 |
-
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
-
from tokenizer import SentencepiecesTokenizer
|
| 5 |
-
from print_utils import rich_transcription_postprocess, rich_print_asr_res
|
| 6 |
-
from download_utils import download_model
|
| 7 |
-
import logging
|
| 8 |
-
import re
|
| 9 |
-
import emoji
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def setup_logging():
    """Configure the root logger to write to both a file and the console.

    The log file ``test_wer.log`` is created next to this script and is
    truncated on every run. Existing handlers are removed first so
    repeated calls do not duplicate output. Returns the root logger.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    log_path = os.path.join(base_dir, "test_wer.log")

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
    )

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop any handlers left over from a previous configuration.
    for old in list(root.handlers):
        root.removeHandler(old)

    # File handler first, then console — both at INFO with the same format.
    for handler in (
        logging.FileHandler(log_path, mode="w", encoding="utf-8"),
        logging.StreamHandler(),
    ):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        root.addHandler(handler)

    return root
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class AIShellDataset:
    """AIShell test-set parser.

    Reads a ground-truth file where each line is ``<utt_id> <transcript>``
    and resolves each utterance id to
    ``<dataset_dir>/aishell_S0764/<utt_id>.wav``. Iterating yields
    ``(audio_path, ground_truth)`` tuples.
    """

    def __init__(self, gt_path: str):
        """
        Args:
            gt_path: path to the ground-truth text file.
        """
        self.gt_path = gt_path
        self.dataset_dir = os.path.dirname(gt_path)
        self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")

        # Fail fast when the dataset layout is wrong.
        assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
        assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(gt_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # crashing on the unpack below.
                    continue
                # maxsplit=1: only the first space separates id from
                # transcript, so transcripts containing spaces stay intact
                # (plain split(" ") would raise ValueError on them).
                audio_path, gt = line.split(" ", 1)
                audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
                self.data.append({"audio_path": audio_path, "gt": gt})

        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return self as the iterator."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Number of loaded utterances."""
        return len(self.data)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
class CommonVoiceDataset:
    """Common Voice dataset parser.

    Reads a TSV manifest (first line is a header) where column 1 is the
    clip file name and column 3 is the transcript; clips are resolved to
    ``<dataset_dir>/clips/<name>``. Iterating yields
    ``(audio_path, ground_truth)`` tuples.
    """

    def __init__(self, tsv_path: str):
        """
        Args:
            tsv_path: path to the Common Voice ``*.tsv`` manifest.
        """
        self.tsv_path = tsv_path
        self.dataset_dir = os.path.dirname(tsv_path)
        self.voice_dir = os.path.join(self.dataset_dir, "clips")

        # Fail fast when the dataset layout is wrong.
        assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
        assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(tsv_path, "r", encoding="utf-8") as f:
            f.readline()  # skip the TSV header row
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank/trailing lines instead of crashing.
                    continue
                splits = line.split("\t")
                if len(splits) < 4:
                    # Malformed row: too few columns for path + sentence,
                    # would raise IndexError below — skip it.
                    continue
                audio_path = os.path.join(self.voice_dir, splits[1])
                self.data.append({"audio_path": audio_path, "gt": splits[3]})

        # 使用logging而不是print
        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return self as the iterator."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Number of loaded clips."""
        return len(self.data)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def get_args():
    """Parse and return the command-line arguments for the WER run."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", "-d",
        type=str,
        required=True,
        choices=["aishell", "common_voice"],
        help="Test dataset",
    )
    parser.add_argument(
        "--gt_path", "-g",
        type=str,
        required=True,
        help="Test dataset ground truth file",
    )
    parser.add_argument(
        "--language", "-l",
        required=False,
        type=str,
        default="auto",
        choices=["auto", "zh", "en", "yue", "ja", "ko"],
    )
    parser.add_argument(
        "--max_num",
        type=int,
        default=-1,
        required=False,
        help="Maximum test data num",
    )
    return parser.parse_args()
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def min_distance(word1: str, word2: str) -> int:
    """Levenshtein edit distance between *word1* and *word2*.

    Counts the minimum number of single-character insertions, deletions
    and substitutions needed to turn *word1* into *word2*.
    """
    m, n = len(word1), len(word2)

    # Rolling two-row DP: only the previous row is needed at any point.
    prev = list(range(n + 1))  # distance from "" to word2[:j]
    for i in range(1, m + 1):
        curr = [i] + [0] * n  # distance from word1[:i] to ""
        for j in range(1, n + 1):
            if word1[i - 1] == word2[j - 1]:
                curr[j] = prev[j - 1]
            else:
                # substitute, insert, delete — pick the cheapest.
                curr[j] = 1 + min(prev[j - 1], curr[j - 1], prev[j])
        prev = curr

    return prev[n]
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
def remove_punctuation(text):
    """Strip punctuation (including underscores) from *text*.

    Deletes every character that is neither a word character nor
    whitespace, plus ``_`` (a word character the character class would
    otherwise keep). Letters, digits and whitespace survive — including
    CJK characters, which count as word characters — covering both ASCII
    and Chinese punctuation.
    """
    return re.sub(r"[^\w\s]|_", "", text)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
def main():
    """Run WER evaluation of the SenseVoice model over a test dataset."""
    logger = setup_logging()
    args = get_args()

    language = args.language
    use_itn = False  # punctuation/ITN prediction off, so output matches the punctuation-stripped references
    max_num = args.max_num

    dataset_type = args.dataset.lower()
    if dataset_type == "aishell":
        dataset = AIShellDataset(args.gt_path)
    elif dataset_type == "common_voice":
        dataset = CommonVoiceDataset(args.gt_path)
    else:
        raise ValueError(f"Unknown dataset type {dataset_type}")

    # model_path_root = download_model("SenseVoice")
    model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
    bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"

    assert os.path.exists(model_path), f"model {model_path} not exist"

    logger.info(f"dataset: {args.dataset}")
    logger.info(f"language: {language}")
    logger.info(f"use_itn: {use_itn}")
    logger.info(f"model_path: {model_path}")

    tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
    pipeline = SenseVoiceAx(
        model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
    )

    # Iterate over dataset
    hyp = []
    references = []
    all_character_error_num = 0
    all_character_num = 0
    max_data_num = max_num if max_num > 0 else len(dataset)
    for n, (audio_path, reference) in enumerate(dataset):
        # Normalize the reference the same way the hypothesis is normalized.
        reference = remove_punctuation(reference).lower()

        asr_res = pipeline.infer(audio_path, print_rtf=False)
        hypothesis = rich_print_asr_res(
            asr_res, will_print=False, remove_punc=True
        ).lower()
        # Drop emoji tags from the hypothesis before scoring.
        hypothesis = emoji.replace_emoji(hypothesis, replace="")

        # Per-utterance CER: edit distance over reference length.
        character_error_num = min_distance(reference, hypothesis)
        character_num = len(reference)
        character_error_rate = character_error_num / character_num * 100

        all_character_error_num += character_error_num
        all_character_num += character_num

        hyp.append(hypothesis)
        references.append(reference)

        line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
        logger.info(line_content)

        if n + 1 >= max_data_num:
            break

    # Corpus-level rate: total edit operations over total reference characters.
    total_character_error_rate = all_character_error_num / all_character_num * 100

    logger.info(f"Total WER: {total_character_error_rate}%")


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|