Commit ·
ee4406b
0
Parent(s):
* support for maixcam2
Browse files- .gitattributes +39 -0
- .gitignore +2 -0
- LICENSE +21 -0
- README.md +96 -0
- SenseVoiceAx.py +415 -0
- am.mvn +8 -0
- chn_jpn_yue_eng_ko_spectok.bpe.model +3 -0
- client.py +146 -0
- config.json +0 -0
- download_dataset.sh +2 -0
- download_utils.py +33 -0
- embeddings.npy +3 -0
- example/en.mp3 +3 -0
- example/ja.mp3 +3 -0
- example/ko.mp3 +3 -0
- example/yue.mp3 +3 -0
- example/zh.mp3 +3 -0
- frontend.py +460 -0
- gradio_demo.py +62 -0
- main.py +79 -0
- model.mud +15 -0
- pe_nonstream.npy +3 -0
- pe_streaming.npy +3 -0
- print_utils.py +131 -0
- requirements.txt +11 -0
- sensevoice_ax630c/sensevoice.axmodel +3 -0
- sensevoice_ax630c/streaming_sensevoice.axmodel +3 -0
- server.py +270 -0
- test_wer.py +296 -0
- tokenizer.py +135 -0
.gitattributes
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.axmodel filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
.gradio
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 祈Inory
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: automatic-speech-recognition
|
| 6 |
+
---
|
| 7 |
+
# sensevoice.axera
|
| 8 |
+
FunASR SenseVoice on Axera, official repo: https://github.com/FunAudioLLM/SenseVoice
|
| 9 |
+
|
| 10 |
+
## TODO
|
| 11 |
+
|
| 12 |
+
- [x] 支持AX630C
|
| 13 |
+
- [ ] 支持C++
|
| 14 |
+
- [x] 支持FastAPI
|
| 15 |
+
|
| 16 |
+
## 功能
|
| 17 |
+
- 语音识别
|
| 18 |
+
- 自动识别语言(支持中文、英文、粤语、日语、韩语)
|
| 19 |
+
- 情感识别
|
| 20 |
+
- 自动标点
|
| 21 |
+
- 支持流式识别
|
| 22 |
+
|
| 23 |
+
## 支持平台
|
| 24 |
+
|
| 25 |
+
- [x] AX650N
|
| 26 |
+
- [x] AX630C
|
| 27 |
+
|
| 28 |
+
## 环境安装
|
| 29 |
+
```
|
| 30 |
+
pip3 install -r requirements.txt
|
| 31 |
+
```
|
| 32 |
+
如果空间不足可以使用 --prefix 指定别的安装路径
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## 使用
|
| 36 |
+
```
|
| 37 |
+
# 首次运行会自动从huggingface上下载模型, 保存到models中
|
| 38 |
+
python3 main.py -i 输入音频文件
|
| 39 |
+
```
|
| 40 |
+
运行参数说明:
|
| 41 |
+
| 参数名称 | 说明 | 默认值 |
|
| 42 |
+
| --- | --- | --- |
|
| 43 |
+
| --input/-i | 输入音频文件 | |
|
| 44 |
+
| --language/-l | 识别语言,支持auto, zh, en, yue, ja, ko | auto |
|
| 45 |
+
| --streaming | 流式识别 | |
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
### 示例:
|
| 49 |
+
example下有测试音频
|
| 50 |
+
|
| 51 |
+
如 粤语测试
|
| 52 |
+
```
|
| 53 |
+
python3 main.py -i example/yue.mp3
|
| 54 |
+
```
|
| 55 |
+
输出
|
| 56 |
+
```
|
| 57 |
+
RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
|
| 58 |
+
['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
流式识别
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
python3 main.py -i example/zh.mp3 --streaming
|
| 65 |
+
```
|
| 66 |
+
输出
|
| 67 |
+
```
|
| 68 |
+
{'timestamps': [540], 'text': '开'}
|
| 69 |
+
{'timestamps': [540, 780, 1080], 'text': '开放时'}
|
| 70 |
+
{'timestamps': [540, 780, 1080, 1260, 1740], 'text': '开放时间早'}
|
| 71 |
+
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340], 'text': '开放时间早上9'}
|
| 72 |
+
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640], 'text': '开放时间早上9点'}
|
| 73 |
+
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060], 'text': '开放时间早上9点至'}
|
| 74 |
+
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020], 'text': '开放时间早上9点至下午'}
|
| 75 |
+
{'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020, 4440, 4620], 'text': '开放时间早上9点至下午五点'}
|
| 76 |
+
RTF: 0.03678379235444246
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## 准确率
|
| 81 |
+
|
| 82 |
+
使用WER(Word-Error-Rate)作为评价标准
|
| 83 |
+
|
| 84 |
+
**WER = 0.0389**
|
| 85 |
+
|
| 86 |
+
### 复现测试结果
|
| 87 |
+
|
| 88 |
+
```
|
| 89 |
+
./download_datasets.sh
|
| 90 |
+
python test_wer.py -d datasets -l zh
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## 技术讨论
|
| 94 |
+
|
| 95 |
+
- Github issues
|
| 96 |
+
- QQ 群: 139953715
|
SenseVoiceAx.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import axengine as axe
|
| 2 |
+
import numpy as np
|
| 3 |
+
import librosa
|
| 4 |
+
from frontend import WavFrontend
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from typing import List, Union, Optional
|
| 8 |
+
from asr_decoder import CTCDecoder
|
| 9 |
+
from tokenizer import SentencepiecesTokenizer
|
| 10 |
+
from online_fbank import OnlineFbank
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sequence_mask(lengths, maxlen=None, dtype=np.float32):
|
| 15 |
+
# 如果 maxlen 未指定,则取 lengths 中的最大值
|
| 16 |
+
if maxlen is None:
|
| 17 |
+
maxlen = np.max(lengths)
|
| 18 |
+
|
| 19 |
+
# 创建一个从 0 到 maxlen-1 的行向量
|
| 20 |
+
row_vector = np.arange(0, maxlen, 1)
|
| 21 |
+
|
| 22 |
+
# 将 lengths 转换为列向量
|
| 23 |
+
matrix = np.expand_dims(lengths, axis=-1)
|
| 24 |
+
|
| 25 |
+
# 比较生成掩码
|
| 26 |
+
mask = row_vector < matrix
|
| 27 |
+
if mask.shape[-1] < lengths[0]:
|
| 28 |
+
mask = np.concatenate(
|
| 29 |
+
[
|
| 30 |
+
mask,
|
| 31 |
+
np.zeros(
|
| 32 |
+
(mask.shape[0], lengths[0] - mask.shape[-1]), dtype=np.float32
|
| 33 |
+
),
|
| 34 |
+
],
|
| 35 |
+
axis=-1,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# 返回指定数据类型的掩码
|
| 39 |
+
return mask.astype(dtype)[None, ...]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def unique_consecutive_np(arr):
|
| 43 |
+
"""
|
| 44 |
+
找出数组中连续的唯一值,模拟 torch.unique_consecutive(yseq, dim=-1)
|
| 45 |
+
|
| 46 |
+
参数:
|
| 47 |
+
arr: 一维numpy数组
|
| 48 |
+
|
| 49 |
+
返回:
|
| 50 |
+
unique_values: 去除连续重复值后的数组
|
| 51 |
+
"""
|
| 52 |
+
if len(arr) == 0:
|
| 53 |
+
return np.array([])
|
| 54 |
+
|
| 55 |
+
if len(arr) == 1:
|
| 56 |
+
return arr.copy()
|
| 57 |
+
|
| 58 |
+
# 找出变化的位置
|
| 59 |
+
diff = np.diff(arr)
|
| 60 |
+
change_positions = np.where(diff != 0)[0] + 1
|
| 61 |
+
|
| 62 |
+
# 添加起始位置
|
| 63 |
+
start_positions = np.concatenate(([0], change_positions))
|
| 64 |
+
|
| 65 |
+
# 获取唯一值(每个连续段的第一个值)
|
| 66 |
+
unique_values = arr[start_positions]
|
| 67 |
+
|
| 68 |
+
return unique_values
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SenseVoiceAx:
|
| 72 |
+
"""SenseVoice axmodel runner"""
|
| 73 |
+
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
model_path: str,
|
| 77 |
+
max_len: int = 256,
|
| 78 |
+
beam_size: int = 3,
|
| 79 |
+
language: str = "auto",
|
| 80 |
+
hot_words: Optional[List[str]] = None,
|
| 81 |
+
use_itn: bool = True,
|
| 82 |
+
streaming: bool = False,
|
| 83 |
+
):
|
| 84 |
+
"""
|
| 85 |
+
Initialize SenseVoiceAx
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
model_path: Path of axmodel
|
| 89 |
+
max_len: Fixed shape of input of axmodel
|
| 90 |
+
beam_size: Max number of hypos to hold after each decode step
|
| 91 |
+
language: Support auto, zh(Chinese), en(English), yue(Cantonese), ja(Japanese), ko(Korean)
|
| 92 |
+
hot_words: Words that may fail to recognize,
|
| 93 |
+
special words/phrases (aka hotwords) like rare words, personalized information etc.
|
| 94 |
+
use_itn: Allow Invert Text Normalization if True,
|
| 95 |
+
ITN converts ASR model output into its written form to improve text readability,
|
| 96 |
+
For example, the ITN module replaces “one hundred and twenty-three dollars” transcribed by an ASR model with “$123.”
|
| 97 |
+
streaming: Processes audio in small segments or "chunks" sequentially and outputs text on the fly.
|
| 98 |
+
Use stream_infer method if streaming is true otherwise infer.
|
| 99 |
+
|
| 100 |
+
"""
|
| 101 |
+
model_path_root = os.path.dirname(model_path)
|
| 102 |
+
emb_path = os.path.join(model_path_root, "../embeddings.npy")
|
| 103 |
+
cmvn_file = os.path.join(model_path_root, "../am.mvn")
|
| 104 |
+
bpe_model = os.path.join(
|
| 105 |
+
model_path_root, "../chn_jpn_yue_eng_ko_spectok.bpe.model"
|
| 106 |
+
)
|
| 107 |
+
if streaming:
|
| 108 |
+
self.position_encoding = np.load(
|
| 109 |
+
os.path.join(model_path_root, "../pe_streaming.npy")
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
self.position_encoding = np.load(
|
| 113 |
+
os.path.join(model_path_root, "../pe_nonstream.npy")
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.streaming = streaming
|
| 117 |
+
self.tokenizer = SentencepiecesTokenizer(bpemodel=bpe_model)
|
| 118 |
+
|
| 119 |
+
self.frontend = WavFrontend(
|
| 120 |
+
cmvn_file=cmvn_file,
|
| 121 |
+
fs=16000,
|
| 122 |
+
window="hamming",
|
| 123 |
+
n_mels=80,
|
| 124 |
+
frame_length=25,
|
| 125 |
+
frame_shift=10,
|
| 126 |
+
lfr_m=7,
|
| 127 |
+
lfr_n=6,
|
| 128 |
+
)
|
| 129 |
+
self.model = axe.InferenceSession(model_path)
|
| 130 |
+
self.sample_rate = 16000
|
| 131 |
+
self.blank_id = 0
|
| 132 |
+
self.max_len = max_len
|
| 133 |
+
self.padding = 16
|
| 134 |
+
self.input_size = 560
|
| 135 |
+
|
| 136 |
+
self.lid_dict = {
|
| 137 |
+
"auto": 0,
|
| 138 |
+
"zh": 3,
|
| 139 |
+
"en": 4,
|
| 140 |
+
"yue": 7,
|
| 141 |
+
"ja": 11,
|
| 142 |
+
"ko": 12,
|
| 143 |
+
"nospeech": 13,
|
| 144 |
+
}
|
| 145 |
+
self.lid_int_dict = {
|
| 146 |
+
24884: 3,
|
| 147 |
+
24885: 4,
|
| 148 |
+
24888: 7,
|
| 149 |
+
24892: 11,
|
| 150 |
+
24896: 12,
|
| 151 |
+
24992: 13,
|
| 152 |
+
}
|
| 153 |
+
self.textnorm_dict = {"withitn": 14, "woitn": 15}
|
| 154 |
+
self.textnorm_int_dict = {25016: 14, 25017: 15}
|
| 155 |
+
self.emo_dict = {
|
| 156 |
+
"unk": 25009,
|
| 157 |
+
"happy": 25001,
|
| 158 |
+
"sad": 25002,
|
| 159 |
+
"angry": 25003,
|
| 160 |
+
"neutral": 25004,
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
self.load_embeddings(emb_path, language, use_itn)
|
| 164 |
+
self.language = language
|
| 165 |
+
|
| 166 |
+
# decoder
|
| 167 |
+
if beam_size > 1 and hot_words is not None:
|
| 168 |
+
self.beam_size = beam_size
|
| 169 |
+
symbol_table = {}
|
| 170 |
+
for i in range(self.tokenizer.get_vocab_size()):
|
| 171 |
+
symbol_table[self.tokenizer.decode(i)] = i
|
| 172 |
+
self.decoder = CTCDecoder(hot_words, symbol_table, bpe_model)
|
| 173 |
+
else:
|
| 174 |
+
self.beam_size = 1
|
| 175 |
+
self.decoder = CTCDecoder()
|
| 176 |
+
|
| 177 |
+
if streaming:
|
| 178 |
+
self.cur_idx = -1
|
| 179 |
+
self.chunk_size = max_len - self.padding
|
| 180 |
+
self.caches_shape = (max_len, self.input_size)
|
| 181 |
+
self.caches = np.zeros(self.caches_shape, dtype=np.float32)
|
| 182 |
+
self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
|
| 183 |
+
self.neg_mean, self.inv_stddev = (
|
| 184 |
+
self.frontend.cmvn[0, :],
|
| 185 |
+
self.frontend.cmvn[1, :],
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.fbank = OnlineFbank(window_type="hamming")
|
| 189 |
+
self.masks = sequence_mask(
|
| 190 |
+
np.array([self.max_len], dtype=np.int32),
|
| 191 |
+
maxlen=self.max_len,
|
| 192 |
+
dtype=np.float32,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def language_options(self):
|
| 197 |
+
return list(self.lid_dict.keys())
|
| 198 |
+
|
| 199 |
+
@property
|
| 200 |
+
def textnorm_options(self):
|
| 201 |
+
return list(self.textnorm_dict.keys())
|
| 202 |
+
|
| 203 |
+
def load_embeddings(self, emb_path, language, use_itn):
|
| 204 |
+
self.embeddings = np.load(emb_path, allow_pickle=True).item()
|
| 205 |
+
self.language_query = self.embeddings[language]
|
| 206 |
+
self.textnorm_query = (
|
| 207 |
+
self.embeddings["withitn"] if use_itn else self.embeddings["woitn"]
|
| 208 |
+
)
|
| 209 |
+
self.event_emo_query = self.embeddings["event_emo"]
|
| 210 |
+
self.input_query = np.concatenate(
|
| 211 |
+
(self.textnorm_query, self.language_query, self.event_emo_query), axis=1
|
| 212 |
+
)
|
| 213 |
+
self.query_num = self.input_query.shape[1]
|
| 214 |
+
|
| 215 |
+
def choose_language(self, language):
|
| 216 |
+
self.language_query = self.embeddings[language]
|
| 217 |
+
self.input_query = np.concatenate(
|
| 218 |
+
(self.textnorm_query, self.language_query, self.event_emo_query), axis=1
|
| 219 |
+
)
|
| 220 |
+
self.language = language
|
| 221 |
+
|
| 222 |
+
def load_data(self, filepath: str) -> np.ndarray:
|
| 223 |
+
waveform, _ = librosa.load(filepath, sr=self.sample_rate)
|
| 224 |
+
return waveform.flatten()
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
|
| 228 |
+
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
|
| 229 |
+
pad_width = ((0, max_feat_len - cur_len), (0, 0))
|
| 230 |
+
return np.pad(feat, pad_width, "constant", constant_values=0)
|
| 231 |
+
|
| 232 |
+
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
|
| 233 |
+
feats = np.array(feat_res).astype(np.float32)
|
| 234 |
+
return feats
|
| 235 |
+
|
| 236 |
+
def preprocess(self, waveform):
|
| 237 |
+
feats, feats_len = [], []
|
| 238 |
+
for wf in [waveform]:
|
| 239 |
+
speech, _ = self.frontend.fbank(wf)
|
| 240 |
+
feat, feat_len = self.frontend.lfr_cmvn(speech)
|
| 241 |
+
feats.append(feat)
|
| 242 |
+
feats_len.append(feat_len)
|
| 243 |
+
|
| 244 |
+
feats = self.pad_feats(feats, np.max(feats_len))
|
| 245 |
+
feats_len = np.array(feats_len).astype(np.int32)
|
| 246 |
+
return feats, feats_len
|
| 247 |
+
|
| 248 |
+
def postprocess(self, ctc_logits, encoder_out_lens):
|
| 249 |
+
# 提取数据
|
| 250 |
+
x = ctc_logits[0, 4 : encoder_out_lens[0], :]
|
| 251 |
+
|
| 252 |
+
# 获取最大值索引
|
| 253 |
+
yseq = np.argmax(x, axis=-1)
|
| 254 |
+
|
| 255 |
+
# 去除连续重复元素
|
| 256 |
+
yseq = unique_consecutive_np(yseq)
|
| 257 |
+
|
| 258 |
+
# 创建掩码并过滤 blank_id
|
| 259 |
+
mask = yseq != self.blank_id
|
| 260 |
+
token_int = yseq[mask].tolist()
|
| 261 |
+
|
| 262 |
+
return token_int
|
| 263 |
+
|
| 264 |
+
def infer_waveform(self, waveform: np.ndarray, language="auto"):
|
| 265 |
+
if language != self.language:
|
| 266 |
+
self.choose_language(language)
|
| 267 |
+
|
| 268 |
+
# start = time.time()
|
| 269 |
+
feat, feat_len = self.preprocess(waveform)
|
| 270 |
+
# print(f"Preprocess take {time.time() - start}s")
|
| 271 |
+
|
| 272 |
+
slice_len = self.max_len - self.query_num
|
| 273 |
+
slice_num = int(np.ceil(feat.shape[1] / slice_len))
|
| 274 |
+
|
| 275 |
+
asr_res = []
|
| 276 |
+
for i in range(slice_num):
|
| 277 |
+
if i == 0:
|
| 278 |
+
sub_feat = feat[:, i * slice_len : (i + 1) * slice_len, :]
|
| 279 |
+
else:
|
| 280 |
+
sub_feat = feat[
|
| 281 |
+
:,
|
| 282 |
+
i * slice_len - self.padding : (i + 1) * slice_len - self.padding,
|
| 283 |
+
:,
|
| 284 |
+
]
|
| 285 |
+
# concat query
|
| 286 |
+
sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)
|
| 287 |
+
real_len = sub_feat.shape[1]
|
| 288 |
+
if real_len < self.max_len:
|
| 289 |
+
sub_feat = np.concatenate(
|
| 290 |
+
[
|
| 291 |
+
sub_feat,
|
| 292 |
+
np.zeros(
|
| 293 |
+
(1, self.max_len - real_len, sub_feat.shape[-1]),
|
| 294 |
+
dtype=np.float32,
|
| 295 |
+
),
|
| 296 |
+
],
|
| 297 |
+
axis=1,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
masks = sequence_mask(
|
| 301 |
+
np.array([self.max_len], dtype=np.int32),
|
| 302 |
+
maxlen=real_len,
|
| 303 |
+
dtype=np.float32,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# start = time.time()
|
| 307 |
+
outputs = self.model.run(
|
| 308 |
+
None,
|
| 309 |
+
{
|
| 310 |
+
"speech": sub_feat,
|
| 311 |
+
"masks": masks,
|
| 312 |
+
"position_encoding": self.position_encoding,
|
| 313 |
+
},
|
| 314 |
+
)
|
| 315 |
+
ctc_logits, encoder_out_lens = outputs
|
| 316 |
+
|
| 317 |
+
token_int = self.postprocess(ctc_logits, encoder_out_lens)
|
| 318 |
+
|
| 319 |
+
if self.tokenizer is not None:
|
| 320 |
+
asr_res.append(self.tokenizer.tokens2text(token_int))
|
| 321 |
+
else:
|
| 322 |
+
asr_res.append(token_int)
|
| 323 |
+
|
| 324 |
+
return asr_res
|
| 325 |
+
|
| 326 |
+
def infer(
|
| 327 |
+
self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=False
|
| 328 |
+
):
|
| 329 |
+
assert not self.streaming, "This method is for non-streaming model"
|
| 330 |
+
|
| 331 |
+
if isinstance(filepath_or_data, str):
|
| 332 |
+
waveform = self.load_data(filepath_or_data)
|
| 333 |
+
else:
|
| 334 |
+
waveform = filepath_or_data
|
| 335 |
+
|
| 336 |
+
total_time = waveform.shape[-1] / self.sample_rate
|
| 337 |
+
|
| 338 |
+
start = time.time()
|
| 339 |
+
asr_res = self.infer_waveform(waveform, language)
|
| 340 |
+
latency = time.time() - start
|
| 341 |
+
|
| 342 |
+
if print_rtf:
|
| 343 |
+
rtf = latency / total_time
|
| 344 |
+
print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
|
| 345 |
+
return "".join(asr_res)
|
| 346 |
+
|
| 347 |
+
def decode(self, times, tokens):
|
| 348 |
+
times_ms = []
|
| 349 |
+
for step, token in zip(times, tokens):
|
| 350 |
+
if len(self.tokenizer.decode(token).strip()) == 0:
|
| 351 |
+
continue
|
| 352 |
+
times_ms.append(step * 60)
|
| 353 |
+
return times_ms, self.tokenizer.decode(tokens)
|
| 354 |
+
|
| 355 |
+
def reset(self):
|
| 356 |
+
self.cur_idx = -1
|
| 357 |
+
self.decoder.reset()
|
| 358 |
+
self.fbank = OnlineFbank(window_type="hamming")
|
| 359 |
+
self.caches = np.zeros(self.caches_shape)
|
| 360 |
+
|
| 361 |
+
def get_size(self):
|
| 362 |
+
effective_size = self.cur_idx + 1 - self.padding
|
| 363 |
+
if effective_size <= 0:
|
| 364 |
+
return 0
|
| 365 |
+
return effective_size % self.chunk_size or self.chunk_size
|
| 366 |
+
|
| 367 |
+
def stream_infer(self, audio, is_last, language="auto"):
|
| 368 |
+
assert self.streaming, "This method is for streaming model"
|
| 369 |
+
|
| 370 |
+
if language != self.language:
|
| 371 |
+
self.choose_language(language)
|
| 372 |
+
|
| 373 |
+
self.fbank.accept_waveform(audio, is_last)
|
| 374 |
+
features = self.fbank.get_lfr_frames(
|
| 375 |
+
neg_mean=self.neg_mean, inv_stddev=self.inv_stddev
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
if is_last and len(features) == 0:
|
| 379 |
+
features = self.zeros
|
| 380 |
+
|
| 381 |
+
for idx, feature in enumerate(features):
|
| 382 |
+
is_last = is_last and idx == features.shape[0] - 1
|
| 383 |
+
self.caches = np.roll(self.caches, -1, axis=0)
|
| 384 |
+
self.caches[-1, :] = feature
|
| 385 |
+
self.cur_idx += 1
|
| 386 |
+
cur_size = self.get_size()
|
| 387 |
+
if cur_size != self.chunk_size and not is_last:
|
| 388 |
+
continue
|
| 389 |
+
|
| 390 |
+
speech = self.caches[None, ...]
|
| 391 |
+
outputs = self.model.run(
|
| 392 |
+
None,
|
| 393 |
+
{
|
| 394 |
+
"speech": speech,
|
| 395 |
+
"masks": self.masks,
|
| 396 |
+
"position_encoding": self.position_encoding,
|
| 397 |
+
},
|
| 398 |
+
)
|
| 399 |
+
ctc_logits, encoder_out_lens = outputs
|
| 400 |
+
probs = ctc_logits[0, 4 : encoder_out_lens[0]]
|
| 401 |
+
probs = torch.from_numpy(probs)
|
| 402 |
+
|
| 403 |
+
if cur_size != self.chunk_size:
|
| 404 |
+
probs = probs[self.chunk_size - cur_size :]
|
| 405 |
+
if not is_last:
|
| 406 |
+
probs = probs[: self.chunk_size]
|
| 407 |
+
if self.beam_size > 1:
|
| 408 |
+
res = self.decoder.ctc_prefix_beam_search(
|
| 409 |
+
probs, beam_size=self.beam_size, is_last=is_last
|
| 410 |
+
)
|
| 411 |
+
times_ms, text = self.decode(res["times"][0], res["tokens"][0])
|
| 412 |
+
else:
|
| 413 |
+
res = self.decoder.ctc_greedy_search(probs, is_last=is_last)
|
| 414 |
+
times_ms, text = self.decode(res["times"], res["tokens"])
|
| 415 |
+
yield {"timestamps": times_ms, "text": text}
|
am.mvn
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<Nnet>
|
| 2 |
+
<Splice> 560 560
|
| 3 |
+
[ 0 ]
|
| 4 |
+
<AddShift> 560 560
|
| 5 |
+
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
|
| 6 |
+
<Rescale> 560 560
|
| 7 |
+
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
|
| 8 |
+
</Nnet>
|
chn_jpn_yue_eng_ko_spectok.bpe.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
|
| 3 |
+
size 377341
|
client.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests, json, os
|
| 2 |
+
import librosa
|
| 3 |
+
|
| 4 |
+
class SensevoiceClient:
|
| 5 |
+
def __init__(self, model = "", url="http://0.0.0.0:12347", lauguage="auto", stream=False):
|
| 6 |
+
self.model = model
|
| 7 |
+
self.url = url
|
| 8 |
+
self.stream = stream
|
| 9 |
+
self.launguage = lauguage
|
| 10 |
+
def _check_service(self):
|
| 11 |
+
try:
|
| 12 |
+
response = requests.get(self.url + '/status')
|
| 13 |
+
if response.status_code == 200:
|
| 14 |
+
return True
|
| 15 |
+
except:
|
| 16 |
+
return False
|
| 17 |
+
|
| 18 |
+
def _start_service(self):
|
| 19 |
+
import time
|
| 20 |
+
if not self._check_service():
|
| 21 |
+
os.system("systemctl start sensevoice.service")
|
| 22 |
+
|
| 23 |
+
while not self._check_service():
|
| 24 |
+
print("Waiting for service to start...")
|
| 25 |
+
time.sleep(1)
|
| 26 |
+
|
| 27 |
+
return True
|
| 28 |
+
|
| 29 |
+
def _stop_service(self):
|
| 30 |
+
os.system("systemctl stop sensevoice.service")
|
| 31 |
+
|
| 32 |
+
def _get_status(self):
|
| 33 |
+
try:
|
| 34 |
+
response = requests.get(self.url + '/status')
|
| 35 |
+
if response.status_code == 200:
|
| 36 |
+
res = json.loads(response.text)
|
| 37 |
+
return res["status"]
|
| 38 |
+
except:
|
| 39 |
+
return "not loaded"
|
| 40 |
+
|
| 41 |
+
def _start_model(self):
|
| 42 |
+
try:
|
| 43 |
+
data = {
|
| 44 |
+
"model_path": self.model,
|
| 45 |
+
"sample_rate": 16000,
|
| 46 |
+
"language": self.launguage,
|
| 47 |
+
"stream": self.stream
|
| 48 |
+
}
|
| 49 |
+
response = requests.post(self.url + '/start_model', json=data)
|
| 50 |
+
if response.status_code == 200:
|
| 51 |
+
res = json.loads(response.text)
|
| 52 |
+
return True if res["status"] == 'loaded' else False
|
| 53 |
+
except Exception as e:
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
def _stop_model(self):
|
| 57 |
+
try:
|
| 58 |
+
response = requests.post(self.url + '/_stop_model')
|
| 59 |
+
if response.status_code == 200:
|
| 60 |
+
res = json.loads(response.text)
|
| 61 |
+
return True if res["status"] == 'not loaded' else False
|
| 62 |
+
except Exception as e:
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
def start(self):
|
| 66 |
+
if self._start_service():
|
| 67 |
+
print("Service started successfully.")
|
| 68 |
+
else:
|
| 69 |
+
print("Failed to start service.")
|
| 70 |
+
return False
|
| 71 |
+
|
| 72 |
+
if self._start_model():
|
| 73 |
+
print("Model started successfully.")
|
| 74 |
+
else:
|
| 75 |
+
print("Failed to start model.")
|
| 76 |
+
return False
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
def stop_model(self):
|
| 80 |
+
self._stop_model()
|
| 81 |
+
|
| 82 |
+
def stop(self):
|
| 83 |
+
self._stop_model()
|
| 84 |
+
self._stop_service()
|
| 85 |
+
|
| 86 |
+
def get_wave_form(self, path):
|
| 87 |
+
waveform, _ = librosa.load(path, sr=16000)
|
| 88 |
+
return waveform
|
| 89 |
+
|
| 90 |
+
def refer(self, filepath):
|
| 91 |
+
if self.stream:
|
| 92 |
+
print("Streaming mode, use refer_stream() instead.")
|
| 93 |
+
return ""
|
| 94 |
+
waveform = self.get_wave_form(filepath)
|
| 95 |
+
data = {
|
| 96 |
+
"audio_data": waveform.tolist(),
|
| 97 |
+
"sample_rate": 16000,
|
| 98 |
+
"launguage": "auto"
|
| 99 |
+
}
|
| 100 |
+
try:
|
| 101 |
+
response = requests.post(self.url + '/asr', json=data)
|
| 102 |
+
if response.status_code == 200:
|
| 103 |
+
res = json.loads(response.text)
|
| 104 |
+
return res.get("text", "")
|
| 105 |
+
else:
|
| 106 |
+
print(f"Requests failed: {response.status_code}")
|
| 107 |
+
return ""
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print("Requests failed:", e)
|
| 110 |
+
return ""
|
| 111 |
+
|
| 112 |
+
def refer_stream(self, filepath):
|
| 113 |
+
if not self.stream:
|
| 114 |
+
print("Streaming mode, use refer() instead.")
|
| 115 |
+
return ""
|
| 116 |
+
waveform = self.get_wave_form(filepath)
|
| 117 |
+
data = {
|
| 118 |
+
"audio_data": waveform.tolist(),
|
| 119 |
+
"sample_rate": 16000,
|
| 120 |
+
"launguage": "auto",
|
| 121 |
+
"step": 0.1,
|
| 122 |
+
}
|
| 123 |
+
print('start post')
|
| 124 |
+
try:
|
| 125 |
+
response = requests.post(self.url + '/asr_stream', json=data, stream=True)
|
| 126 |
+
for line in response.iter_lines():
|
| 127 |
+
if line:
|
| 128 |
+
chunk = json.loads(line)
|
| 129 |
+
yield chunk.get("text", "")
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print("Requests failed:", e)
|
| 132 |
+
return ""
|
| 133 |
+
|
| 134 |
+
stream = True
|
| 135 |
+
client = SensevoiceClient(model="/root/models/sensevoice-maixcam2/model.mud", stream=stream)
|
| 136 |
+
if client.start() is False:
|
| 137 |
+
print("Failed to start service or model.")
|
| 138 |
+
exit()
|
| 139 |
+
if not stream:
|
| 140 |
+
print('start refer')
|
| 141 |
+
text = client.refer("example/zh.mp3")
|
| 142 |
+
print(text)
|
| 143 |
+
else:
|
| 144 |
+
print('start refer stream')
|
| 145 |
+
for text in client.refer_stream("example/zh.mp3"):
|
| 146 |
+
print(text)
|
config.json
ADDED
|
File without changes
|
download_dataset.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wget https://github.com/ml-inory/whisper.axera/releases/download/v1.0/datasets.zip
|
| 2 |
+
unzip datasets.zip -d ./
|
download_utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Speed up hf download using mirror url
|
| 4 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
|
| 7 |
+
current_file_path = os.path.dirname(__file__)
|
| 8 |
+
REPO_ROOT = "AXERA-TECH"
|
| 9 |
+
CACHE_PATH = os.path.join(current_file_path, "models")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def download_model(model_name: str) -> str:
|
| 13 |
+
"""
|
| 14 |
+
Download model from AXERA-TECH's huggingface space.
|
| 15 |
+
|
| 16 |
+
model_name: str
|
| 17 |
+
Available model names could be checked on https://huggingface.co/AXERA-TECH.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
str: Path to model_name
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
os.makedirs(CACHE_PATH, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
model_path = os.path.join(CACHE_PATH, model_name)
|
| 26 |
+
if not os.path.exists(model_path):
|
| 27 |
+
print(f"Downloading {model_name}...")
|
| 28 |
+
snapshot_download(
|
| 29 |
+
repo_id=f"{REPO_ROOT}/{model_name}",
|
| 30 |
+
local_dir=os.path.join(CACHE_PATH, model_name),
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return model_path
|
embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a453244ab037744531b97bcb8574c8442301dac11f6406fdab208dddb83b93e
|
| 3 |
+
size 25523
|
example/en.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f10378336a4e584f3f63799e62f99d5add3c2a401b51d3abe7d3a3a82f255ada
|
| 3 |
+
size 57441
|
example/ja.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:496dbc43b289e1d0d0cb916df9737450bca56acd8aaca046a7a2472363b1be53
|
| 3 |
+
size 57837
|
example/ko.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8612f62db8319a6cb4ab4b1d2039bfc32f174f89611889ddafdeb5c0a6070b5f
|
| 3 |
+
size 27909
|
example/yue.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5098eebc13530a66e4eac1f30d3246e65c9cfc4e096665f9d395aca8eff0d181
|
| 3 |
+
size 31246
|
example/zh.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e64de19e4ff9a02e682955c9112f32d2317cfdbb5bc2f3504664044c993f195
|
| 3 |
+
size 44973
|
frontend.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- encoding: utf-8 -*-
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import kaldi_native_fbank as knf
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class WavFrontend:
|
| 11 |
+
"""Conventional frontend structure for ASR."""
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
cmvn_file: str = None,
|
| 16 |
+
fs: int = 16000,
|
| 17 |
+
window: str = "hamming",
|
| 18 |
+
n_mels: int = 80,
|
| 19 |
+
frame_length: int = 25,
|
| 20 |
+
frame_shift: int = 10,
|
| 21 |
+
lfr_m: int = 1,
|
| 22 |
+
lfr_n: int = 1,
|
| 23 |
+
dither: float = 1.0,
|
| 24 |
+
**kwargs,
|
| 25 |
+
) -> None:
|
| 26 |
+
|
| 27 |
+
opts = knf.FbankOptions()
|
| 28 |
+
opts.frame_opts.samp_freq = fs
|
| 29 |
+
opts.frame_opts.dither = dither
|
| 30 |
+
opts.frame_opts.window_type = window
|
| 31 |
+
opts.frame_opts.frame_shift_ms = float(frame_shift)
|
| 32 |
+
opts.frame_opts.frame_length_ms = float(frame_length)
|
| 33 |
+
opts.mel_opts.num_bins = n_mels
|
| 34 |
+
opts.energy_floor = 0
|
| 35 |
+
opts.frame_opts.snip_edges = True
|
| 36 |
+
opts.mel_opts.debug_mel = False
|
| 37 |
+
self.opts = opts
|
| 38 |
+
|
| 39 |
+
self.lfr_m = lfr_m
|
| 40 |
+
self.lfr_n = lfr_n
|
| 41 |
+
self.cmvn_file = cmvn_file
|
| 42 |
+
|
| 43 |
+
if self.cmvn_file:
|
| 44 |
+
self.cmvn = self.load_cmvn()
|
| 45 |
+
self.fbank_fn = None
|
| 46 |
+
self.fbank_beg_idx = 0
|
| 47 |
+
self.reset_status()
|
| 48 |
+
|
| 49 |
+
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 50 |
+
waveform = waveform * (1 << 15)
|
| 51 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 52 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 53 |
+
frames = self.fbank_fn.num_frames_ready
|
| 54 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 55 |
+
for i in range(frames):
|
| 56 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 57 |
+
feat = mat.astype(np.float32)
|
| 58 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 59 |
+
return feat, feat_len
|
| 60 |
+
|
| 61 |
+
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 62 |
+
waveform = waveform * (1 << 15)
|
| 63 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 64 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
| 65 |
+
frames = self.fbank_fn.num_frames_ready
|
| 66 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 67 |
+
for i in range(self.fbank_beg_idx, frames):
|
| 68 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 69 |
+
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
|
| 70 |
+
feat = mat.astype(np.float32)
|
| 71 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 72 |
+
return feat, feat_len
|
| 73 |
+
|
| 74 |
+
def reset_status(self):
|
| 75 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 76 |
+
self.fbank_beg_idx = 0
|
| 77 |
+
|
| 78 |
+
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 79 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 80 |
+
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
|
| 81 |
+
|
| 82 |
+
if self.cmvn_file:
|
| 83 |
+
feat = self.apply_cmvn(feat)
|
| 84 |
+
|
| 85 |
+
feat_len = np.array(feat.shape[0]).astype(np.int32)
|
| 86 |
+
return feat, feat_len
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
|
| 90 |
+
LFR_inputs = []
|
| 91 |
+
|
| 92 |
+
T = inputs.shape[0]
|
| 93 |
+
T_lfr = int(np.ceil(T / lfr_n))
|
| 94 |
+
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
|
| 95 |
+
inputs = np.vstack((left_padding, inputs))
|
| 96 |
+
T = T + (lfr_m - 1) // 2
|
| 97 |
+
for i in range(T_lfr):
|
| 98 |
+
if lfr_m <= T - i * lfr_n:
|
| 99 |
+
LFR_inputs.append(
|
| 100 |
+
(inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
# process last LFR frame
|
| 104 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 105 |
+
frame = inputs[i * lfr_n :].reshape(-1)
|
| 106 |
+
for _ in range(num_padding):
|
| 107 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 108 |
+
|
| 109 |
+
LFR_inputs.append(frame)
|
| 110 |
+
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
|
| 111 |
+
return LFR_outputs
|
| 112 |
+
|
| 113 |
+
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
|
| 114 |
+
"""
|
| 115 |
+
Apply CMVN with mvn data
|
| 116 |
+
"""
|
| 117 |
+
frame, dim = inputs.shape
|
| 118 |
+
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
|
| 119 |
+
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
|
| 120 |
+
inputs = (inputs + means) * vars
|
| 121 |
+
return inputs
|
| 122 |
+
|
| 123 |
+
def load_cmvn(
|
| 124 |
+
self,
|
| 125 |
+
) -> np.ndarray:
|
| 126 |
+
with open(self.cmvn_file, "r", encoding="utf-8") as f:
|
| 127 |
+
lines = f.readlines()
|
| 128 |
+
|
| 129 |
+
means_list = []
|
| 130 |
+
vars_list = []
|
| 131 |
+
for i in range(len(lines)):
|
| 132 |
+
line_item = lines[i].split()
|
| 133 |
+
if line_item[0] == "<AddShift>":
|
| 134 |
+
line_item = lines[i + 1].split()
|
| 135 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 136 |
+
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
| 137 |
+
means_list = list(add_shift_line)
|
| 138 |
+
continue
|
| 139 |
+
elif line_item[0] == "<Rescale>":
|
| 140 |
+
line_item = lines[i + 1].split()
|
| 141 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 142 |
+
rescale_line = line_item[3 : (len(line_item) - 1)]
|
| 143 |
+
vars_list = list(rescale_line)
|
| 144 |
+
continue
|
| 145 |
+
|
| 146 |
+
means = np.array(means_list).astype(np.float64)
|
| 147 |
+
vars = np.array(vars_list).astype(np.float64)
|
| 148 |
+
cmvn = np.array([means, vars])
|
| 149 |
+
return cmvn
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class WavFrontendOnline(WavFrontend):
|
| 153 |
+
def __init__(self, **kwargs):
|
| 154 |
+
super().__init__(**kwargs)
|
| 155 |
+
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 156 |
+
# add variables
|
| 157 |
+
self.frame_sample_length = int(
|
| 158 |
+
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
|
| 159 |
+
)
|
| 160 |
+
self.frame_shift_sample_length = int(
|
| 161 |
+
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
|
| 162 |
+
)
|
| 163 |
+
self.waveform = None
|
| 164 |
+
self.reserve_waveforms = None
|
| 165 |
+
self.input_cache = None
|
| 166 |
+
self.lfr_splice_cache = []
|
| 167 |
+
|
| 168 |
+
@staticmethod
|
| 169 |
+
# inputs has catted the cache
|
| 170 |
+
def apply_lfr(
|
| 171 |
+
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
|
| 172 |
+
) -> Tuple[np.ndarray, np.ndarray, int]:
|
| 173 |
+
"""
|
| 174 |
+
Apply lfr with data
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
LFR_inputs = []
|
| 178 |
+
T = inputs.shape[0] # include the right context
|
| 179 |
+
T_lfr = int(
|
| 180 |
+
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
|
| 181 |
+
) # minus the right context: (lfr_m - 1) // 2
|
| 182 |
+
splice_idx = T_lfr
|
| 183 |
+
for i in range(T_lfr):
|
| 184 |
+
if lfr_m <= T - i * lfr_n:
|
| 185 |
+
LFR_inputs.append(
|
| 186 |
+
(inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
|
| 187 |
+
)
|
| 188 |
+
else: # process last LFR frame
|
| 189 |
+
if is_final:
|
| 190 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
| 191 |
+
frame = (inputs[i * lfr_n :]).reshape(-1)
|
| 192 |
+
for _ in range(num_padding):
|
| 193 |
+
frame = np.hstack((frame, inputs[-1]))
|
| 194 |
+
LFR_inputs.append(frame)
|
| 195 |
+
else:
|
| 196 |
+
# update splice_idx and break the circle
|
| 197 |
+
splice_idx = i
|
| 198 |
+
break
|
| 199 |
+
splice_idx = min(T - 1, splice_idx * lfr_n)
|
| 200 |
+
lfr_splice_cache = inputs[splice_idx:, :]
|
| 201 |
+
LFR_outputs = np.vstack(LFR_inputs)
|
| 202 |
+
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def compute_frame_num(
|
| 206 |
+
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
|
| 207 |
+
) -> int:
|
| 208 |
+
frame_num = int(
|
| 209 |
+
(sample_length - frame_sample_length) / frame_shift_sample_length + 1
|
| 210 |
+
)
|
| 211 |
+
return (
|
| 212 |
+
frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
def fbank(
|
| 216 |
+
self, input: np.ndarray, input_lengths: np.ndarray
|
| 217 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 218 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 219 |
+
batch_size = input.shape[0]
|
| 220 |
+
if self.input_cache is None:
|
| 221 |
+
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
|
| 222 |
+
input = np.concatenate((self.input_cache, input), axis=1)
|
| 223 |
+
frame_num = self.compute_frame_num(
|
| 224 |
+
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
|
| 225 |
+
)
|
| 226 |
+
# update self.in_cache
|
| 227 |
+
self.input_cache = input[
|
| 228 |
+
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
|
| 229 |
+
]
|
| 230 |
+
waveforms = np.empty(0, dtype=np.float32)
|
| 231 |
+
feats_pad = np.empty(0, dtype=np.float32)
|
| 232 |
+
feats_lens = np.empty(0, dtype=np.int32)
|
| 233 |
+
if frame_num:
|
| 234 |
+
waveforms = []
|
| 235 |
+
feats = []
|
| 236 |
+
feats_lens = []
|
| 237 |
+
for i in range(batch_size):
|
| 238 |
+
waveform = input[i]
|
| 239 |
+
waveforms.append(
|
| 240 |
+
waveform[
|
| 241 |
+
: (
|
| 242 |
+
(frame_num - 1) * self.frame_shift_sample_length
|
| 243 |
+
+ self.frame_sample_length
|
| 244 |
+
)
|
| 245 |
+
]
|
| 246 |
+
)
|
| 247 |
+
waveform = waveform * (1 << 15)
|
| 248 |
+
|
| 249 |
+
self.fbank_fn.accept_waveform(
|
| 250 |
+
self.opts.frame_opts.samp_freq, waveform.tolist()
|
| 251 |
+
)
|
| 252 |
+
frames = self.fbank_fn.num_frames_ready
|
| 253 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
| 254 |
+
for i in range(frames):
|
| 255 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
| 256 |
+
feat = mat.astype(np.float32)
|
| 257 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
| 258 |
+
feats.append(feat)
|
| 259 |
+
feats_lens.append(feat_len)
|
| 260 |
+
|
| 261 |
+
waveforms = np.stack(waveforms)
|
| 262 |
+
feats_lens = np.array(feats_lens)
|
| 263 |
+
feats_pad = np.array(feats)
|
| 264 |
+
self.fbanks = feats_pad
|
| 265 |
+
self.fbanks_lens = copy.deepcopy(feats_lens)
|
| 266 |
+
return waveforms, feats_pad, feats_lens
|
| 267 |
+
|
| 268 |
+
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
|
| 269 |
+
return self.fbanks, self.fbanks_lens
|
| 270 |
+
|
| 271 |
+
def lfr_cmvn(
|
| 272 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 273 |
+
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
|
| 274 |
+
batch_size = input.shape[0]
|
| 275 |
+
feats = []
|
| 276 |
+
feats_lens = []
|
| 277 |
+
lfr_splice_frame_idxs = []
|
| 278 |
+
for i in range(batch_size):
|
| 279 |
+
mat = input[i, : input_lengths[i], :]
|
| 280 |
+
lfr_splice_frame_idx = -1
|
| 281 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
| 282 |
+
# update self.lfr_splice_cache in self.apply_lfr
|
| 283 |
+
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
|
| 284 |
+
mat, self.lfr_m, self.lfr_n, is_final
|
| 285 |
+
)
|
| 286 |
+
if self.cmvn_file is not None:
|
| 287 |
+
mat = self.apply_cmvn(mat)
|
| 288 |
+
feat_length = mat.shape[0]
|
| 289 |
+
feats.append(mat)
|
| 290 |
+
feats_lens.append(feat_length)
|
| 291 |
+
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
|
| 292 |
+
|
| 293 |
+
feats_lens = np.array(feats_lens)
|
| 294 |
+
feats_pad = np.array(feats)
|
| 295 |
+
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
| 296 |
+
|
| 297 |
+
def extract_fbank(
|
| 298 |
+
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
| 299 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 300 |
+
batch_size = input.shape[0]
|
| 301 |
+
assert (
|
| 302 |
+
batch_size == 1
|
| 303 |
+
), "we support to extract feature online only when the batch size is equal to 1 now"
|
| 304 |
+
waveforms, feats, feats_lengths = self.fbank(
|
| 305 |
+
input, input_lengths
|
| 306 |
+
) # input shape: B T D
|
| 307 |
+
if feats.shape[0]:
|
| 308 |
+
self.waveforms = (
|
| 309 |
+
waveforms
|
| 310 |
+
if self.reserve_waveforms is None
|
| 311 |
+
else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
|
| 312 |
+
)
|
| 313 |
+
if not self.lfr_splice_cache:
|
| 314 |
+
for i in range(batch_size):
|
| 315 |
+
self.lfr_splice_cache.append(
|
| 316 |
+
np.expand_dims(feats[i][0, :], axis=0).repeat(
|
| 317 |
+
(self.lfr_m - 1) // 2, axis=0
|
| 318 |
+
)
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
|
| 322 |
+
lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
|
| 323 |
+
feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
|
| 324 |
+
feats_lengths += lfr_splice_cache_np[0].shape[0]
|
| 325 |
+
frame_from_waveforms = int(
|
| 326 |
+
(self.waveforms.shape[1] - self.frame_sample_length)
|
| 327 |
+
/ self.frame_shift_sample_length
|
| 328 |
+
+ 1
|
| 329 |
+
)
|
| 330 |
+
minus_frame = (
|
| 331 |
+
(self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
|
| 332 |
+
)
|
| 333 |
+
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
|
| 334 |
+
feats, feats_lengths, is_final
|
| 335 |
+
)
|
| 336 |
+
if self.lfr_m == 1:
|
| 337 |
+
self.reserve_waveforms = None
|
| 338 |
+
else:
|
| 339 |
+
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
|
| 340 |
+
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
|
| 341 |
+
# print('frame_frame: ' + str(frame_from_waveforms))
|
| 342 |
+
self.reserve_waveforms = self.waveforms[
|
| 343 |
+
:,
|
| 344 |
+
reserve_frame_idx
|
| 345 |
+
* self.frame_shift_sample_length : frame_from_waveforms
|
| 346 |
+
* self.frame_shift_sample_length,
|
| 347 |
+
]
|
| 348 |
+
sample_length = (
|
| 349 |
+
frame_from_waveforms - 1
|
| 350 |
+
) * self.frame_shift_sample_length + self.frame_sample_length
|
| 351 |
+
self.waveforms = self.waveforms[:, :sample_length]
|
| 352 |
+
else:
|
| 353 |
+
# update self.reserve_waveforms and self.lfr_splice_cache
|
| 354 |
+
self.reserve_waveforms = self.waveforms[
|
| 355 |
+
:, : -(self.frame_sample_length - self.frame_shift_sample_length)
|
| 356 |
+
]
|
| 357 |
+
for i in range(batch_size):
|
| 358 |
+
self.lfr_splice_cache[i] = np.concatenate(
|
| 359 |
+
(self.lfr_splice_cache[i], feats[i]), axis=0
|
| 360 |
+
)
|
| 361 |
+
return np.empty(0, dtype=np.float32), feats_lengths
|
| 362 |
+
else:
|
| 363 |
+
if is_final:
|
| 364 |
+
self.waveforms = (
|
| 365 |
+
waveforms
|
| 366 |
+
if self.reserve_waveforms is None
|
| 367 |
+
else self.reserve_waveforms
|
| 368 |
+
)
|
| 369 |
+
feats = np.stack(self.lfr_splice_cache)
|
| 370 |
+
feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
|
| 371 |
+
feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
|
| 372 |
+
if is_final:
|
| 373 |
+
self.cache_reset()
|
| 374 |
+
return feats, feats_lengths
|
| 375 |
+
|
| 376 |
+
def get_waveforms(self):
|
| 377 |
+
return self.waveforms
|
| 378 |
+
|
| 379 |
+
def cache_reset(self):
|
| 380 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
| 381 |
+
self.reserve_waveforms = None
|
| 382 |
+
self.input_cache = None
|
| 383 |
+
self.lfr_splice_cache = []
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def load_bytes(input):
|
| 387 |
+
middle_data = np.frombuffer(input, dtype=np.int16)
|
| 388 |
+
middle_data = np.asarray(middle_data)
|
| 389 |
+
if middle_data.dtype.kind not in "iu":
|
| 390 |
+
raise TypeError("'middle_data' must be an array of integers")
|
| 391 |
+
dtype = np.dtype("float32")
|
| 392 |
+
if dtype.kind != "f":
|
| 393 |
+
raise TypeError("'dtype' must be a floating point type")
|
| 394 |
+
|
| 395 |
+
i = np.iinfo(middle_data.dtype)
|
| 396 |
+
abs_max = 2 ** (i.bits - 1)
|
| 397 |
+
offset = i.min + abs_max
|
| 398 |
+
array = np.frombuffer(
|
| 399 |
+
(middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32
|
| 400 |
+
)
|
| 401 |
+
return array
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class SinusoidalPositionEncoderOnline:
|
| 405 |
+
"""Streaming Positional encoding."""
|
| 406 |
+
|
| 407 |
+
def encode(
|
| 408 |
+
self,
|
| 409 |
+
positions: np.ndarray = None,
|
| 410 |
+
depth: int = None,
|
| 411 |
+
dtype: np.dtype = np.float32,
|
| 412 |
+
):
|
| 413 |
+
batch_size = positions.shape[0]
|
| 414 |
+
positions = positions.astype(dtype)
|
| 415 |
+
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (
|
| 416 |
+
depth / 2 - 1
|
| 417 |
+
)
|
| 418 |
+
inv_timescales = np.exp(
|
| 419 |
+
np.arange(depth / 2).astype(dtype) * (-log_timescale_increment)
|
| 420 |
+
)
|
| 421 |
+
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
|
| 422 |
+
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(
|
| 423 |
+
inv_timescales, [1, 1, -1]
|
| 424 |
+
)
|
| 425 |
+
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
|
| 426 |
+
return encoding.astype(dtype)
|
| 427 |
+
|
| 428 |
+
def forward(self, x, start_idx=0):
|
| 429 |
+
batch_size, timesteps, input_dim = x.shape
|
| 430 |
+
positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
|
| 431 |
+
position_encoding = self.encode(positions, input_dim, x.dtype)
|
| 432 |
+
|
| 433 |
+
return x + position_encoding[:, start_idx : start_idx + timesteps]
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def test():
|
| 437 |
+
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
|
| 438 |
+
import librosa
|
| 439 |
+
|
| 440 |
+
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
|
| 441 |
+
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
|
| 442 |
+
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
|
| 443 |
+
|
| 444 |
+
config = read_yaml(config_file)
|
| 445 |
+
waveform, _ = librosa.load(path, sr=None)
|
| 446 |
+
frontend = WavFrontend(
|
| 447 |
+
cmvn_file=cmvn_file,
|
| 448 |
+
**config["frontend_conf"],
|
| 449 |
+
)
|
| 450 |
+
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
|
| 451 |
+
feat, feat_len = frontend.lfr_cmvn(
|
| 452 |
+
speech
|
| 453 |
+
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
|
| 454 |
+
|
| 455 |
+
frontend.reset_status() # clear cache
|
| 456 |
+
return feat, feat_len
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
if __name__ == "__main__":
|
| 460 |
+
test()
|
gradio_demo.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
+
from print_utils import rich_transcription_postprocess
|
| 5 |
+
|
| 6 |
+
max_len = 256
|
| 7 |
+
|
| 8 |
+
model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
|
| 9 |
+
|
| 10 |
+
assert os.path.exists(model_path), f"model {model_path} not exist"
|
| 11 |
+
|
| 12 |
+
pipeline = SenseVoiceAx(
|
| 13 |
+
model_path,
|
| 14 |
+
max_len=max_len,
|
| 15 |
+
beam_size=3,
|
| 16 |
+
language="auto",
|
| 17 |
+
hot_words=None,
|
| 18 |
+
use_itn=True,
|
| 19 |
+
streaming=False,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def speech_to_text(audio_path, lang):
|
| 24 |
+
"""
|
| 25 |
+
audio_path: 音频文件路径
|
| 26 |
+
lang: 语言类型 "auto", "zh", "en", "yue", "ja", "ko"
|
| 27 |
+
"""
|
| 28 |
+
if not audio_path:
|
| 29 |
+
return "无音频"
|
| 30 |
+
|
| 31 |
+
pipeline.choose_language(language=lang)
|
| 32 |
+
asr_res = pipeline.infer(audio_path, print_rtf=False)
|
| 33 |
+
|
| 34 |
+
return asr_res
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
with gr.Blocks() as demo:
|
| 39 |
+
with gr.Row():
|
| 40 |
+
output_text = gr.Textbox(label="识别结果", lines=5)
|
| 41 |
+
|
| 42 |
+
with gr.Row():
|
| 43 |
+
audio_input = gr.Audio(
|
| 44 |
+
sources=["upload"], type="filepath", label="录制或上传音频", format="mp3"
|
| 45 |
+
)
|
| 46 |
+
lang_dropdown = gr.Dropdown(
|
| 47 |
+
choices=["auto", "zh", "en", "yue", "ja", "ko"],
|
| 48 |
+
value="auto",
|
| 49 |
+
label="选择音频语言",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
audio_input.change(
|
| 53 |
+
fn=speech_to_text, inputs=[audio_input, lang_dropdown], outputs=output_text
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
demo.launch(
|
| 57 |
+
server_name="0.0.0.0",
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
main.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
+
import librosa
|
| 5 |
+
import numpy as np
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_args():
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument(
|
| 12 |
+
"--input", "-i", required=True, type=str, help="Input audio file"
|
| 13 |
+
)
|
| 14 |
+
parser.add_argument(
|
| 15 |
+
"--language",
|
| 16 |
+
"-l",
|
| 17 |
+
required=False,
|
| 18 |
+
type=str,
|
| 19 |
+
default="auto",
|
| 20 |
+
choices=["auto", "zh", "en", "yue", "ja", "ko"],
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument("--streaming", action="store_true")
|
| 23 |
+
return parser.parse_args()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
|
| 27 |
+
args = get_args()
|
| 28 |
+
|
| 29 |
+
input_audio = args.input
|
| 30 |
+
language = args.language
|
| 31 |
+
use_itn = True # 标点符号预测
|
| 32 |
+
if not args.streaming:
|
| 33 |
+
max_len = 256
|
| 34 |
+
model_path = os.path.join("sensevoice_ax630c", "sensevoice.axmodel")
|
| 35 |
+
else:
|
| 36 |
+
max_len = 26
|
| 37 |
+
model_path = os.path.join("sensevoice_ax630c", "streaming_sensevoice.axmodel")
|
| 38 |
+
|
| 39 |
+
assert os.path.exists(model_path), f"model {model_path} not exist"
|
| 40 |
+
|
| 41 |
+
print(f"input_audio: {input_audio}")
|
| 42 |
+
print(f"language: {language}")
|
| 43 |
+
print(f"use_itn: {use_itn}")
|
| 44 |
+
print(f"model_path: {model_path}")
|
| 45 |
+
print(f"streaming: {args.streaming}")
|
| 46 |
+
|
| 47 |
+
pipeline = SenseVoiceAx(
|
| 48 |
+
model_path,
|
| 49 |
+
max_len=max_len,
|
| 50 |
+
beam_size=3,
|
| 51 |
+
language="auto",
|
| 52 |
+
hot_words=None,
|
| 53 |
+
use_itn=True,
|
| 54 |
+
streaming=args.streaming,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if not args.streaming:
|
| 58 |
+
asr_res = pipeline.infer(input_audio, print_rtf=True)
|
| 59 |
+
print("ASR result: " + asr_res)
|
| 60 |
+
else:
|
| 61 |
+
samples, sr = librosa.load(input_audio, sr=16000)
|
| 62 |
+
samples = (samples * 32768).tolist()
|
| 63 |
+
duration = len(samples) / 16000
|
| 64 |
+
|
| 65 |
+
start = time.time()
|
| 66 |
+
step = int(0.1 * sr)
|
| 67 |
+
for i in range(0, len(samples), step):
|
| 68 |
+
is_last = i + step >= len(samples)
|
| 69 |
+
for res in pipeline.stream_infer(samples[i : i + step], is_last):
|
| 70 |
+
print(res)
|
| 71 |
+
|
| 72 |
+
end = time.time()
|
| 73 |
+
cost_time = end - start
|
| 74 |
+
|
| 75 |
+
print(f"RTF: {cost_time / duration}")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
main()
|
model.mud
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[basic]
|
| 2 |
+
type = axmodel
|
| 3 |
+
model_npu = sensevoice_ax630c/sensevoice.axmodel
|
| 4 |
+
model_vnpu =
|
| 5 |
+
|
| 6 |
+
[extra]
|
| 7 |
+
model_type = sensevoice
|
| 8 |
+
input_cache = true
|
| 9 |
+
output_cache = true
|
| 10 |
+
beam_size = 3
|
| 11 |
+
language = auto
|
| 12 |
+
hot_words = None,
|
| 13 |
+
use_itn = True
|
| 14 |
+
stream_model = sensevoice_ax630c/streaming_sensevoice.axmodel
|
| 15 |
+
server_url = http://127.0.0.1:12345
|
pe_nonstream.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f1c9c550bd62fa164a959517f52d46a28591812fafdf002df0df2bd998f44b5
|
| 3 |
+
size 573568
|
pe_streaming.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54fec2fe2670168d36678c5857e65c459c634e6b6d6df928b7d415399ce2c291
|
| 3 |
+
size 58368
|
print_utils.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
emo_dict = {
|
| 2 |
+
"<|HAPPY|>": "😊",
|
| 3 |
+
"<|SAD|>": "😔",
|
| 4 |
+
"<|ANGRY|>": "😡",
|
| 5 |
+
"<|NEUTRAL|>": "",
|
| 6 |
+
"<|FEARFUL|>": "😰",
|
| 7 |
+
"<|DISGUSTED|>": "🤢",
|
| 8 |
+
"<|SURPRISED|>": "😮",
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
event_dict = {
|
| 12 |
+
"<|BGM|>": "🎼",
|
| 13 |
+
"<|Speech|>": "",
|
| 14 |
+
"<|Applause|>": "👏",
|
| 15 |
+
"<|Laughter|>": "😀",
|
| 16 |
+
"<|Cry|>": "😭",
|
| 17 |
+
"<|Sneeze|>": "🤧",
|
| 18 |
+
"<|Breath|>": "",
|
| 19 |
+
"<|Cough|>": "🤧",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
lang_dict = {
|
| 23 |
+
"<|zh|>": "<|lang|>",
|
| 24 |
+
"<|en|>": "<|lang|>",
|
| 25 |
+
"<|yue|>": "<|lang|>",
|
| 26 |
+
"<|ja|>": "<|lang|>",
|
| 27 |
+
"<|ko|>": "<|lang|>",
|
| 28 |
+
"<|nospeech|>": "<|lang|>",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
emoji_dict = {
|
| 32 |
+
"<|nospeech|><|Event_UNK|>": "❓",
|
| 33 |
+
"<|zh|>": "",
|
| 34 |
+
"<|en|>": "",
|
| 35 |
+
"<|yue|>": "",
|
| 36 |
+
"<|ja|>": "",
|
| 37 |
+
"<|ko|>": "",
|
| 38 |
+
"<|nospeech|>": "",
|
| 39 |
+
"<|HAPPY|>": "😊",
|
| 40 |
+
"<|SAD|>": "😔",
|
| 41 |
+
"<|ANGRY|>": "😡",
|
| 42 |
+
"<|NEUTRAL|>": "",
|
| 43 |
+
"<|BGM|>": "🎼",
|
| 44 |
+
"<|Speech|>": "",
|
| 45 |
+
"<|Applause|>": "👏",
|
| 46 |
+
"<|Laughter|>": "😀",
|
| 47 |
+
"<|FEARFUL|>": "😰",
|
| 48 |
+
"<|DISGUSTED|>": "🤢",
|
| 49 |
+
"<|SURPRISED|>": "😮",
|
| 50 |
+
"<|Cry|>": "😭",
|
| 51 |
+
"<|EMO_UNKNOWN|>": "",
|
| 52 |
+
"<|Sneeze|>": "🤧",
|
| 53 |
+
"<|Breath|>": "",
|
| 54 |
+
"<|Cough|>": "😷",
|
| 55 |
+
"<|Sing|>": "",
|
| 56 |
+
"<|Speech_Noise|>": "",
|
| 57 |
+
"<|withitn|>": "",
|
| 58 |
+
"<|woitn|>": "",
|
| 59 |
+
"<|GBG|>": "",
|
| 60 |
+
"<|Event_UNK|>": "",
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
|
| 64 |
+
event_set = {
|
| 65 |
+
"🎼",
|
| 66 |
+
"👏",
|
| 67 |
+
"😀",
|
| 68 |
+
"😭",
|
| 69 |
+
"🤧",
|
| 70 |
+
"😷",
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def format_str_v2(s):
|
| 75 |
+
sptk_dict = {}
|
| 76 |
+
for sptk in emoji_dict:
|
| 77 |
+
sptk_dict[sptk] = s.count(sptk)
|
| 78 |
+
s = s.replace(sptk, "")
|
| 79 |
+
emo = "<|NEUTRAL|>"
|
| 80 |
+
for e in emo_dict:
|
| 81 |
+
if sptk_dict[e] > sptk_dict[emo]:
|
| 82 |
+
emo = e
|
| 83 |
+
for e in event_dict:
|
| 84 |
+
if sptk_dict[e] > 0:
|
| 85 |
+
s = event_dict[e] + s
|
| 86 |
+
s = s + emo_dict[emo]
|
| 87 |
+
|
| 88 |
+
for emoji in emo_set.union(event_set):
|
| 89 |
+
s = s.replace(" " + emoji, emoji)
|
| 90 |
+
s = s.replace(emoji + " ", emoji)
|
| 91 |
+
return s.strip()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def rich_transcription_postprocess(s):
|
| 95 |
+
def get_emo(s):
|
| 96 |
+
return s[-1] if s[-1] in emo_set else None
|
| 97 |
+
|
| 98 |
+
def get_event(s):
|
| 99 |
+
return s[0] if s[0] in event_set else None
|
| 100 |
+
|
| 101 |
+
s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
|
| 102 |
+
for lang in lang_dict:
|
| 103 |
+
s = s.replace(lang, "<|lang|>")
|
| 104 |
+
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
|
| 105 |
+
new_s = " " + s_list[0]
|
| 106 |
+
cur_ent_event = get_event(new_s)
|
| 107 |
+
for i in range(1, len(s_list)):
|
| 108 |
+
if len(s_list[i]) == 0:
|
| 109 |
+
continue
|
| 110 |
+
if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
|
| 111 |
+
s_list[i] = s_list[i][1:]
|
| 112 |
+
# else:
|
| 113 |
+
cur_ent_event = get_event(s_list[i])
|
| 114 |
+
if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
|
| 115 |
+
new_s = new_s[:-1]
|
| 116 |
+
new_s += s_list[i].strip().lstrip()
|
| 117 |
+
new_s = new_s.replace("The.", " ")
|
| 118 |
+
return new_s.strip()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
|
| 122 |
+
res = "".join([rich_transcription_postprocess(i) for i in asr_res])
|
| 123 |
+
|
| 124 |
+
if remove_punc:
|
| 125 |
+
res = res.replace(",", "")
|
| 126 |
+
res = res.replace("。", "")
|
| 127 |
+
|
| 128 |
+
if will_print:
|
| 129 |
+
print(res)
|
| 130 |
+
|
| 131 |
+
return res
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
huggingface_hub
|
| 2 |
+
numpy<2
|
| 3 |
+
kaldi-native-fbank
|
| 4 |
+
librosa==0.9.1
|
| 5 |
+
sentencepiece
|
| 6 |
+
fastapi
|
| 7 |
+
gradio
|
| 8 |
+
emoji
|
| 9 |
+
asr-decoder
|
| 10 |
+
online-fbank
|
| 11 |
+
torch
|
sensevoice_ax630c/sensevoice.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67d290cf7cebf45db5f37b2e93b8bdfff44dc35110bb29d84204a5f9eae9fd4d
|
| 3 |
+
size 256550253
|
sensevoice_ax630c/streaming_sensevoice.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba1ddd60841297903bfdae059ad88092d0fd1c543e1d80d7f64199d4e27b8263
|
| 3 |
+
size 249023211
|
server.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from fastapi import FastAPI, HTTPException, Body
|
| 3 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
import logging
|
| 6 |
+
import json
|
| 7 |
+
import configparser
|
| 8 |
+
from SenseVoiceAx import SenseVoiceAx
|
| 9 |
+
import os
|
| 10 |
+
import librosa
|
| 11 |
+
|
| 12 |
+
# 初始化日志
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
|
| 17 |
+
|
| 18 |
+
# 全局变量存储模型
|
| 19 |
+
asr_model = None
|
| 20 |
+
asr_model_is_loaded = False
|
| 21 |
+
mud_configs = None
|
| 22 |
+
|
| 23 |
+
def parse_config_file_to_json(file_path):
|
| 24 |
+
"""从文件读取配置并解析为JSON"""
|
| 25 |
+
if not os.path.exists(file_path):
|
| 26 |
+
raise FileNotFoundError(f"配置文件不存在: {file_path}")
|
| 27 |
+
|
| 28 |
+
config = configparser.ConfigParser()
|
| 29 |
+
config.read(file_path, encoding='utf-8')
|
| 30 |
+
|
| 31 |
+
result = {}
|
| 32 |
+
for section in config.sections():
|
| 33 |
+
result[section] = {}
|
| 34 |
+
for key, value in config[section].items():
|
| 35 |
+
# 简单类型转换
|
| 36 |
+
value = value.strip()
|
| 37 |
+
|
| 38 |
+
if value.lower() == 'true':
|
| 39 |
+
result[section][key] = True
|
| 40 |
+
elif value.lower() == 'false':
|
| 41 |
+
result[section][key] = False
|
| 42 |
+
elif value.lower() == 'none' or value == '':
|
| 43 |
+
result[section][key] = None
|
| 44 |
+
elif value.isdigit():
|
| 45 |
+
result[section][key] = int(value)
|
| 46 |
+
else:
|
| 47 |
+
result[section][key] = value
|
| 48 |
+
|
| 49 |
+
return result
|
| 50 |
+
|
| 51 |
+
@app.on_event("startup")
|
| 52 |
+
async def load_model():
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
|
| 56 |
+
"""
|
| 57 |
+
验证并转换音频数据为numpy数组
|
| 58 |
+
|
| 59 |
+
参数:
|
| 60 |
+
- audio_data: 浮点数列表表示的音频数据
|
| 61 |
+
|
| 62 |
+
返回:
|
| 63 |
+
- 验证后的numpy数组
|
| 64 |
+
"""
|
| 65 |
+
try:
|
| 66 |
+
# 转换为numpy数组
|
| 67 |
+
np_array = np.array(audio_data, dtype=np.float32)
|
| 68 |
+
|
| 69 |
+
# 验证数据有效性
|
| 70 |
+
if np_array.ndim != 1:
|
| 71 |
+
raise ValueError("Audio data must be 1-dimensional")
|
| 72 |
+
|
| 73 |
+
if len(np_array) == 0:
|
| 74 |
+
raise ValueError("Audio data cannot be empty")
|
| 75 |
+
|
| 76 |
+
return np_array
|
| 77 |
+
except Exception as e:
|
| 78 |
+
raise ValueError(f"Invalid audio data: {str(e)}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@app.get("/get_language", summary="Get current language")
|
| 82 |
+
async def get_language():
|
| 83 |
+
return JSONResponse(content={"language": asr_model.language})
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@app.get(
|
| 87 |
+
"/get_language_options",
|
| 88 |
+
summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
|
| 89 |
+
)
|
| 90 |
+
async def get_language_options():
|
| 91 |
+
return JSONResponse(content={"language_options": asr_model.language_options})
|
| 92 |
+
|
| 93 |
+
@app.get("/status", summary="Get ASR model status")
|
| 94 |
+
async def get_status():
|
| 95 |
+
global asr_model_is_loaded
|
| 96 |
+
return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
|
| 97 |
+
|
| 98 |
+
@app.post("/start_model", summary="Load model")
|
| 99 |
+
async def start_model(
|
| 100 |
+
model_path: str = Body(
|
| 101 |
+
"sensevoice_ax630c/sensevoice.axmodel",
|
| 102 |
+
description="Path to the model file",
|
| 103 |
+
),
|
| 104 |
+
language: str = Body("auto", description="Language"),
|
| 105 |
+
stream: bool = Body(False, description="streaming or not"),
|
| 106 |
+
):
|
| 107 |
+
"""
|
| 108 |
+
服务启动时加载ASR模型
|
| 109 |
+
"""
|
| 110 |
+
global asr_model
|
| 111 |
+
global asr_model_is_loaded
|
| 112 |
+
logger.info("Loading ASR model...")
|
| 113 |
+
|
| 114 |
+
if asr_model_is_loaded:
|
| 115 |
+
return JSONResponse(content={"status": "loaded"})
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
mud_configs = parse_config_file_to_json(model_path)
|
| 119 |
+
axmodel_path = mud_configs.get("basic", {}).get("model_npu", None)
|
| 120 |
+
streaming_axmodel_path = mud_configs.get("extra", {}).get("stream_model", None)
|
| 121 |
+
model_dir_path = os.path.dirname(model_path)
|
| 122 |
+
if stream:
|
| 123 |
+
if streaming_axmodel_path is None:
|
| 124 |
+
logger.error(f"Not found model:{streaming_axmodel_path}")
|
| 125 |
+
raise HTTPException(status_code=400, detail=f"Not found model:{streaming_axmodel_path}")
|
| 126 |
+
model_path = os.path.join(model_dir_path, streaming_axmodel_path)
|
| 127 |
+
else:
|
| 128 |
+
if axmodel_path is None:
|
| 129 |
+
logger.error(f"Not found model:{axmodel_path}")
|
| 130 |
+
raise HTTPException(status_code=400, detail=f"Not found model:{axmodel_path}")
|
| 131 |
+
model_path = os.path.join(model_dir_path, axmodel_path)
|
| 132 |
+
|
| 133 |
+
# 模型加载
|
| 134 |
+
use_itn = mud_configs.get("extra", {}).get("use_itn", True) # 逆文本规范
|
| 135 |
+
beam_size = mud_configs.get("extra", {}).get("beam_size", 3)
|
| 136 |
+
hot_words = mud_configs.get("extra", {}).get("hot_words", None)
|
| 137 |
+
use_itn = mud_configs.get("extra", {}).get("use_itn", True)
|
| 138 |
+
streaming = stream
|
| 139 |
+
max_len = 26 if streaming else 256
|
| 140 |
+
|
| 141 |
+
print(f'model path: {model_path}')
|
| 142 |
+
print(f'max_len: {max_len}')
|
| 143 |
+
print(f'beam_size: {beam_size}')
|
| 144 |
+
print(f"language: {language}")
|
| 145 |
+
print(f'hot_words: {hot_words}')
|
| 146 |
+
print(f"use_itn: {use_itn}")
|
| 147 |
+
print(f'streaming: {streaming}')
|
| 148 |
+
|
| 149 |
+
if not os.path.exists(model_path):
|
| 150 |
+
raise HTTPException(status_code=400, detail=f"model {model_path} not exist")
|
| 151 |
+
|
| 152 |
+
asr_model = SenseVoiceAx(
|
| 153 |
+
model_path,
|
| 154 |
+
max_len=max_len,
|
| 155 |
+
beam_size=beam_size,
|
| 156 |
+
language=language,
|
| 157 |
+
hot_words=hot_words,
|
| 158 |
+
use_itn=use_itn,
|
| 159 |
+
streaming=streaming,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
logger.info("ASR model loaded successfully")
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Failed to load ASR model: {str(e)}")
|
| 165 |
+
raise
|
| 166 |
+
|
| 167 |
+
return JSONResponse(content={"status": "loaded"})
|
| 168 |
+
|
| 169 |
+
@app.post("/stop_model", summary="Load model")
|
| 170 |
+
async def stop_model(
|
| 171 |
+
):
|
| 172 |
+
global asr_model
|
| 173 |
+
global asr_model_is_loaded
|
| 174 |
+
del asr_model
|
| 175 |
+
asr_model = None
|
| 176 |
+
asr_model_is_loaded = False
|
| 177 |
+
|
| 178 |
+
@app.post("/asr", summary="Recognize speech from numpy audio data")
|
| 179 |
+
async def recognize_speech(
|
| 180 |
+
audio_data: List[float] = Body(
|
| 181 |
+
..., embed=True, description="Audio data as list of floats"
|
| 182 |
+
),
|
| 183 |
+
sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
|
| 184 |
+
language: Optional[str] = Body("auto", description="Language"),
|
| 185 |
+
):
|
| 186 |
+
"""
|
| 187 |
+
接收numpy数组格式的音频数据并返回识别结果
|
| 188 |
+
|
| 189 |
+
参数:
|
| 190 |
+
- audio_data: 浮点数列表表示的音频数据
|
| 191 |
+
- sample_rate: 音频采样率(默认16000Hz)
|
| 192 |
+
|
| 193 |
+
返回:
|
| 194 |
+
- JSON包含识别文本
|
| 195 |
+
"""
|
| 196 |
+
try:
|
| 197 |
+
# 检查模型是否已加载
|
| 198 |
+
if asr_model is None:
|
| 199 |
+
raise HTTPException(status_code=503, detail="ASR model not loaded")
|
| 200 |
+
|
| 201 |
+
logger.info(f"Received audio data with length: {len(audio_data)}")
|
| 202 |
+
|
| 203 |
+
# 验证并转换数据
|
| 204 |
+
np_audio = validate_audio_data(audio_data)
|
| 205 |
+
if sample_rate != asr_model.sample_rate:
|
| 206 |
+
np_audio = librosa.resample(np_audio, sample_rate, asr_model.sample_rate)
|
| 207 |
+
|
| 208 |
+
# 调用模型进行识别
|
| 209 |
+
result = asr_model.infer_waveform(np_audio, language)
|
| 210 |
+
|
| 211 |
+
return JSONResponse(content={"text": result})
|
| 212 |
+
|
| 213 |
+
except ValueError as e:
|
| 214 |
+
logger.error(f"Validation error: {str(e)}")
|
| 215 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 216 |
+
except Exception as e:
|
| 217 |
+
logger.error(f"Recognition error: {str(e)}")
|
| 218 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 219 |
+
|
| 220 |
+
@app.post("/asr_stream", summary="Recognize speech from numpy audio data")
|
| 221 |
+
async def recognize_speech_stream(
|
| 222 |
+
audio_data: List[float] = Body(
|
| 223 |
+
..., embed=True, description="Audio data as list of floats"
|
| 224 |
+
),
|
| 225 |
+
sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
|
| 226 |
+
language: Optional[str] = Body("auto", description="Language"),
|
| 227 |
+
step: Optional[float] = Body(0.1, description="step in seconds"),
|
| 228 |
+
):
|
| 229 |
+
"""
|
| 230 |
+
接收numpy数组格式的音频数据并返回识别结果
|
| 231 |
+
|
| 232 |
+
参数:
|
| 233 |
+
- audio_data: 浮点数列表表示的音频数据
|
| 234 |
+
- sample_rate: 音频采样率(默认16000Hz)
|
| 235 |
+
|
| 236 |
+
返回:
|
| 237 |
+
- JSON包含识别文本
|
| 238 |
+
"""
|
| 239 |
+
try:
|
| 240 |
+
# 检查模型是否已加载
|
| 241 |
+
if asr_model is None:
|
| 242 |
+
raise HTTPException(status_code=503, detail="ASR model not loaded")
|
| 243 |
+
|
| 244 |
+
logger.info(f"Received audio data with length: {len(audio_data)}")
|
| 245 |
+
|
| 246 |
+
# 验证并转换数据
|
| 247 |
+
np_audio = validate_audio_data(audio_data)
|
| 248 |
+
if sample_rate != asr_model.sample_rate:
|
| 249 |
+
np_audio = librosa.resample(np_audio, sample_rate, asr_model.sample_rate)
|
| 250 |
+
# 调用模型进行识别
|
| 251 |
+
def stream_infer(np_audio, step):
|
| 252 |
+
samples = (np_audio * 32768).tolist()
|
| 253 |
+
|
| 254 |
+
step = int(step * 16000)
|
| 255 |
+
for i in range(0, len(samples), step):
|
| 256 |
+
is_last = i + step >= len(samples)
|
| 257 |
+
for res in asr_model.stream_infer(samples[i : i + step], is_last, language):
|
| 258 |
+
yield json.dumps(res) + "\n"
|
| 259 |
+
return StreamingResponse(stream_infer(np_audio, step), media_type="application/json")
|
| 260 |
+
except ValueError as e:
|
| 261 |
+
logger.error(f"Validation error: {str(e)}")
|
| 262 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 263 |
+
except Exception as e:
|
| 264 |
+
logger.error(f"Recognition error: {str(e)}")
|
| 265 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 266 |
+
|
| 267 |
+
if __name__ == "__main__":
|
| 268 |
+
import uvicorn
|
| 269 |
+
|
| 270 |
+
uvicorn.run(app, host="0.0.0.0", port=12347)
|
test_wer.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
import argparse
|
| 3 |
+
from SenseVoiceAx import SenseVoiceAx
|
| 4 |
+
from tokenizer import SentencepiecesTokenizer
|
| 5 |
+
from print_utils import rich_transcription_postprocess, rich_print_asr_res
|
| 6 |
+
from download_utils import download_model
|
| 7 |
+
import logging
|
| 8 |
+
import re
|
| 9 |
+
import emoji
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def setup_logging():
|
| 13 |
+
"""配置日志系统,同时输出到控制台和文件"""
|
| 14 |
+
# 获取脚本所在目录
|
| 15 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
log_file = os.path.join(script_dir, "test_wer.log")
|
| 17 |
+
|
| 18 |
+
# 配置日志格式
|
| 19 |
+
log_format = "%(asctime)s - %(levelname)s - %(message)s"
|
| 20 |
+
date_format = "%Y-%m-%d %H:%M:%S"
|
| 21 |
+
|
| 22 |
+
# 创建logger
|
| 23 |
+
logger = logging.getLogger()
|
| 24 |
+
logger.setLevel(logging.INFO)
|
| 25 |
+
|
| 26 |
+
# 清除现有的handler
|
| 27 |
+
for handler in logger.handlers[:]:
|
| 28 |
+
logger.removeHandler(handler)
|
| 29 |
+
|
| 30 |
+
# 创建文件handler
|
| 31 |
+
file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
|
| 32 |
+
file_handler.setLevel(logging.INFO)
|
| 33 |
+
file_formatter = logging.Formatter(log_format, date_format)
|
| 34 |
+
file_handler.setFormatter(file_formatter)
|
| 35 |
+
|
| 36 |
+
# 创建控制台handler
|
| 37 |
+
console_handler = logging.StreamHandler()
|
| 38 |
+
console_handler.setLevel(logging.INFO)
|
| 39 |
+
console_formatter = logging.Formatter(log_format, date_format)
|
| 40 |
+
console_handler.setFormatter(console_formatter)
|
| 41 |
+
|
| 42 |
+
# 添加handler到logger
|
| 43 |
+
logger.addHandler(file_handler)
|
| 44 |
+
logger.addHandler(console_handler)
|
| 45 |
+
|
| 46 |
+
return logger
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AIShellDataset:
|
| 50 |
+
def __init__(self, gt_path: str):
|
| 51 |
+
"""
|
| 52 |
+
初始化数据集
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
json_path: voice.json文件的路径
|
| 56 |
+
"""
|
| 57 |
+
self.gt_path = gt_path
|
| 58 |
+
self.dataset_dir = os.path.dirname(gt_path)
|
| 59 |
+
self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
|
| 60 |
+
|
| 61 |
+
# 检查必要文件和文件夹是否存在
|
| 62 |
+
assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
|
| 63 |
+
assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
|
| 64 |
+
|
| 65 |
+
# 加载数据
|
| 66 |
+
self.data = []
|
| 67 |
+
with open(gt_path, "r", encoding="utf-8") as f:
|
| 68 |
+
for line in f:
|
| 69 |
+
line = line.strip()
|
| 70 |
+
audio_path, gt = line.split(" ")
|
| 71 |
+
audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
|
| 72 |
+
self.data.append({"audio_path": audio_path, "gt": gt})
|
| 73 |
+
|
| 74 |
+
# 使用logging而不是print
|
| 75 |
+
logger = logging.getLogger()
|
| 76 |
+
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 77 |
+
|
| 78 |
+
def __iter__(self):
|
| 79 |
+
"""返回迭代器"""
|
| 80 |
+
self.index = 0
|
| 81 |
+
return self
|
| 82 |
+
|
| 83 |
+
def __next__(self):
|
| 84 |
+
"""返回下一个数据项"""
|
| 85 |
+
if self.index >= len(self.data):
|
| 86 |
+
raise StopIteration
|
| 87 |
+
|
| 88 |
+
item = self.data[self.index]
|
| 89 |
+
audio_path = item["audio_path"]
|
| 90 |
+
ground_truth = item["gt"]
|
| 91 |
+
|
| 92 |
+
self.index += 1
|
| 93 |
+
return audio_path, ground_truth
|
| 94 |
+
|
| 95 |
+
def __len__(self):
|
| 96 |
+
"""返回数据集大小"""
|
| 97 |
+
return len(self.data)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class CommonVoiceDataset:
|
| 101 |
+
"""Common Voice数据集解析器"""
|
| 102 |
+
|
| 103 |
+
def __init__(self, tsv_path: str):
|
| 104 |
+
"""
|
| 105 |
+
初始化数据集
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
json_path: voice.json文件的路径
|
| 109 |
+
"""
|
| 110 |
+
self.tsv_path = tsv_path
|
| 111 |
+
self.dataset_dir = os.path.dirname(tsv_path)
|
| 112 |
+
self.voice_dir = os.path.join(self.dataset_dir, "clips")
|
| 113 |
+
|
| 114 |
+
# 检查必要文件和文件夹是否存在
|
| 115 |
+
assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
|
| 116 |
+
assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
|
| 117 |
+
|
| 118 |
+
# 加载JSON数据
|
| 119 |
+
self.data = []
|
| 120 |
+
with open(tsv_path, "r", encoding="utf-8") as f:
|
| 121 |
+
f.readline()
|
| 122 |
+
for line in f:
|
| 123 |
+
line = line.strip()
|
| 124 |
+
splits = line.split("\t")
|
| 125 |
+
audio_path = splits[1]
|
| 126 |
+
gt = splits[3]
|
| 127 |
+
audio_path = os.path.join(self.voice_dir, audio_path)
|
| 128 |
+
self.data.append({"audio_path": audio_path, "gt": gt})
|
| 129 |
+
|
| 130 |
+
# 使用logging而不是print
|
| 131 |
+
logger = logging.getLogger()
|
| 132 |
+
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 133 |
+
|
| 134 |
+
def __iter__(self):
|
| 135 |
+
"""返回迭代器"""
|
| 136 |
+
self.index = 0
|
| 137 |
+
return self
|
| 138 |
+
|
| 139 |
+
def __next__(self):
|
| 140 |
+
"""返回下一个数据项"""
|
| 141 |
+
if self.index >= len(self.data):
|
| 142 |
+
raise StopIteration
|
| 143 |
+
|
| 144 |
+
item = self.data[self.index]
|
| 145 |
+
audio_path = item["audio_path"]
|
| 146 |
+
ground_truth = item["gt"]
|
| 147 |
+
|
| 148 |
+
self.index += 1
|
| 149 |
+
return audio_path, ground_truth
|
| 150 |
+
|
| 151 |
+
def __len__(self):
|
| 152 |
+
"""返回数据集大小"""
|
| 153 |
+
return len(self.data)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_args():
|
| 157 |
+
parser = argparse.ArgumentParser()
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--dataset",
|
| 160 |
+
"-d",
|
| 161 |
+
type=str,
|
| 162 |
+
required=True,
|
| 163 |
+
choices=["aishell", "common_voice"],
|
| 164 |
+
help="Test dataset",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--gt_path",
|
| 168 |
+
"-g",
|
| 169 |
+
type=str,
|
| 170 |
+
required=True,
|
| 171 |
+
help="Test dataset ground truth file",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--language",
|
| 175 |
+
"-l",
|
| 176 |
+
required=False,
|
| 177 |
+
type=str,
|
| 178 |
+
default="auto",
|
| 179 |
+
choices=["auto", "zh", "en", "yue", "ja", "ko"],
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--max_num", type=int, default=-1, required=False, help="Maximum test data num"
|
| 183 |
+
)
|
| 184 |
+
return parser.parse_args()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def min_distance(word1: str, word2: str) -> int:
|
| 188 |
+
|
| 189 |
+
row = len(word1) + 1
|
| 190 |
+
column = len(word2) + 1
|
| 191 |
+
|
| 192 |
+
cache = [[0] * column for i in range(row)]
|
| 193 |
+
|
| 194 |
+
for i in range(row):
|
| 195 |
+
for j in range(column):
|
| 196 |
+
|
| 197 |
+
if i == 0 and j == 0:
|
| 198 |
+
cache[i][j] = 0
|
| 199 |
+
elif i == 0 and j != 0:
|
| 200 |
+
cache[i][j] = j
|
| 201 |
+
elif j == 0 and i != 0:
|
| 202 |
+
cache[i][j] = i
|
| 203 |
+
else:
|
| 204 |
+
if word1[i - 1] == word2[j - 1]:
|
| 205 |
+
cache[i][j] = cache[i - 1][j - 1]
|
| 206 |
+
else:
|
| 207 |
+
replace = cache[i - 1][j - 1] + 1
|
| 208 |
+
insert = cache[i][j - 1] + 1
|
| 209 |
+
remove = cache[i - 1][j] + 1
|
| 210 |
+
|
| 211 |
+
cache[i][j] = min(replace, insert, remove)
|
| 212 |
+
|
| 213 |
+
return cache[row - 1][column - 1]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def remove_punctuation(text):
|
| 217 |
+
# 定义正则表达式模式,匹配所有标点符号
|
| 218 |
+
# 这个模式包括常见的标点符号和中文标点
|
| 219 |
+
pattern = r"[^\w\s]|_"
|
| 220 |
+
|
| 221 |
+
# 使用sub方法将所有匹配的标点符号替换为空字符串
|
| 222 |
+
cleaned_text = re.sub(pattern, "", text)
|
| 223 |
+
|
| 224 |
+
return cleaned_text
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def main():
|
| 228 |
+
logger = setup_logging()
|
| 229 |
+
args = get_args()
|
| 230 |
+
|
| 231 |
+
language = args.language
|
| 232 |
+
use_itn = False # 标点符号预测
|
| 233 |
+
max_num = args.max_num
|
| 234 |
+
|
| 235 |
+
dataset_type = args.dataset.lower()
|
| 236 |
+
if dataset_type == "aishell":
|
| 237 |
+
dataset = AIShellDataset(args.gt_path)
|
| 238 |
+
elif dataset_type == "common_voice":
|
| 239 |
+
dataset = CommonVoiceDataset(args.gt_path)
|
| 240 |
+
else:
|
| 241 |
+
raise ValueError(f"Unknown dataset type {dataset_type}")
|
| 242 |
+
|
| 243 |
+
# model_path_root = download_model("SenseVoice")
|
| 244 |
+
model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
|
| 245 |
+
bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"
|
| 246 |
+
|
| 247 |
+
assert os.path.exists(model_path), f"model {model_path} not exist"
|
| 248 |
+
|
| 249 |
+
logger.info(f"dataset: {args.dataset}")
|
| 250 |
+
logger.info(f"language: {language}")
|
| 251 |
+
logger.info(f"use_itn: {use_itn}")
|
| 252 |
+
logger.info(f"model_path: {model_path}")
|
| 253 |
+
|
| 254 |
+
tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
|
| 255 |
+
pipeline = SenseVoiceAx(
|
| 256 |
+
model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Iterate over dataset
|
| 260 |
+
hyp = []
|
| 261 |
+
references = []
|
| 262 |
+
all_character_error_num = 0
|
| 263 |
+
all_character_num = 0
|
| 264 |
+
max_data_num = max_num if max_num > 0 else len(dataset)
|
| 265 |
+
for n, (audio_path, reference) in enumerate(dataset):
|
| 266 |
+
reference = remove_punctuation(reference).lower()
|
| 267 |
+
|
| 268 |
+
asr_res = pipeline.infer(audio_path, print_rtf=False)
|
| 269 |
+
hypothesis = rich_print_asr_res(
|
| 270 |
+
asr_res, will_print=False, remove_punc=True
|
| 271 |
+
).lower()
|
| 272 |
+
hypothesis = emoji.replace_emoji(hypothesis, replace="")
|
| 273 |
+
|
| 274 |
+
character_error_num = min_distance(reference, hypothesis)
|
| 275 |
+
character_num = len(reference)
|
| 276 |
+
character_error_rate = character_error_num / character_num * 100
|
| 277 |
+
|
| 278 |
+
all_character_error_num += character_error_num
|
| 279 |
+
all_character_num += character_num
|
| 280 |
+
|
| 281 |
+
hyp.append(hypothesis)
|
| 282 |
+
references.append(reference)
|
| 283 |
+
|
| 284 |
+
line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
|
| 285 |
+
logger.info(line_content)
|
| 286 |
+
|
| 287 |
+
if n + 1 >= max_data_num:
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
total_character_error_rate = all_character_error_num / all_character_num * 100
|
| 291 |
+
|
| 292 |
+
logger.info(f"Total WER: {total_character_error_rate}%")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
main()
|
tokenizer.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sentencepiece as spm
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from abc import abstractmethod
|
| 8 |
+
from abc import ABC
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseTokenizer(ABC):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
token_list: Union[Path, str, Iterable[str]] = None,
|
| 16 |
+
unk_symbol: str = "<unk>",
|
| 17 |
+
**kwargs,
|
| 18 |
+
):
|
| 19 |
+
|
| 20 |
+
if token_list is not None:
|
| 21 |
+
if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
|
| 22 |
+
token_list = Path(token_list)
|
| 23 |
+
self.token_list_repr = str(token_list)
|
| 24 |
+
self.token_list: List[str] = []
|
| 25 |
+
|
| 26 |
+
with token_list.open("r", encoding="utf-8") as f:
|
| 27 |
+
for idx, line in enumerate(f):
|
| 28 |
+
line = line.rstrip()
|
| 29 |
+
self.token_list.append(line)
|
| 30 |
+
elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
|
| 31 |
+
token_list = Path(token_list)
|
| 32 |
+
self.token_list_repr = str(token_list)
|
| 33 |
+
self.token_list: List[str] = []
|
| 34 |
+
|
| 35 |
+
with open(token_list, "r", encoding="utf-8") as f:
|
| 36 |
+
self.token_list = json.load(f)
|
| 37 |
+
|
| 38 |
+
else:
|
| 39 |
+
self.token_list: List[str] = list(token_list)
|
| 40 |
+
self.token_list_repr = ""
|
| 41 |
+
for i, t in enumerate(self.token_list):
|
| 42 |
+
if i == 3:
|
| 43 |
+
break
|
| 44 |
+
self.token_list_repr += f"{t}, "
|
| 45 |
+
self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
|
| 46 |
+
|
| 47 |
+
self.token2id: Dict[str, int] = {}
|
| 48 |
+
for i, t in enumerate(self.token_list):
|
| 49 |
+
if t in self.token2id:
|
| 50 |
+
raise RuntimeError(f'Symbol "{t}" is duplicated')
|
| 51 |
+
self.token2id[t] = i
|
| 52 |
+
|
| 53 |
+
self.unk_symbol = unk_symbol
|
| 54 |
+
if self.unk_symbol not in self.token2id:
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
|
| 57 |
+
)
|
| 58 |
+
self.unk_id = self.token2id[self.unk_symbol]
|
| 59 |
+
|
| 60 |
+
def encode(self, text, **kwargs):
|
| 61 |
+
tokens = self.text2tokens(text)
|
| 62 |
+
text_ints = self.tokens2ids(tokens)
|
| 63 |
+
|
| 64 |
+
return text_ints
|
| 65 |
+
|
| 66 |
+
def decode(self, text_ints):
|
| 67 |
+
token = self.ids2tokens(text_ints)
|
| 68 |
+
text = self.tokens2text(token)
|
| 69 |
+
return text
|
| 70 |
+
|
| 71 |
+
def get_num_vocabulary_size(self) -> int:
|
| 72 |
+
return len(self.token_list)
|
| 73 |
+
|
| 74 |
+
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
| 75 |
+
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
| 76 |
+
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
|
| 77 |
+
return [self.token_list[i] for i in integers]
|
| 78 |
+
|
| 79 |
+
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
| 80 |
+
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
| 81 |
+
|
| 82 |
+
@abstractmethod
|
| 83 |
+
def text2tokens(self, line: str) -> List[str]:
|
| 84 |
+
raise NotImplementedError
|
| 85 |
+
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def tokens2text(self, tokens: Iterable[str]) -> str:
|
| 88 |
+
raise NotImplementedError
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class SentencepiecesTokenizer(BaseTokenizer):
|
| 92 |
+
def __init__(self, bpemodel: Union[Path, str], **kwargs):
|
| 93 |
+
super().__init__(**kwargs)
|
| 94 |
+
self.bpemodel = str(bpemodel)
|
| 95 |
+
# NOTE(kamo):
|
| 96 |
+
# Don't build SentencePieceProcessor in __init__()
|
| 97 |
+
# because it's not picklable and it may cause following error,
|
| 98 |
+
# "TypeError: can't pickle SwigPyObject objects",
|
| 99 |
+
# when giving it as argument of "multiprocessing.Process()".
|
| 100 |
+
self.sp = None
|
| 101 |
+
self._build_sentence_piece_processor()
|
| 102 |
+
|
| 103 |
+
def __repr__(self):
|
| 104 |
+
return f'{self.__class__.__name__}(model="{self.bpemodel}")'
|
| 105 |
+
|
| 106 |
+
def _build_sentence_piece_processor(self):
|
| 107 |
+
# Build SentencePieceProcessor lazily.
|
| 108 |
+
if self.sp is None:
|
| 109 |
+
self.sp = spm.SentencePieceProcessor()
|
| 110 |
+
self.sp.load(self.bpemodel)
|
| 111 |
+
|
| 112 |
+
def text2tokens(self, line: str) -> List[str]:
|
| 113 |
+
self._build_sentence_piece_processor()
|
| 114 |
+
return self.sp.EncodeAsPieces(line)
|
| 115 |
+
|
| 116 |
+
def tokens2text(self, tokens: Iterable[str]) -> str:
|
| 117 |
+
self._build_sentence_piece_processor()
|
| 118 |
+
return self.sp.DecodePieces(list(tokens))
|
| 119 |
+
|
| 120 |
+
def encode(self, line: str, **kwargs) -> List[int]:
|
| 121 |
+
self._build_sentence_piece_processor()
|
| 122 |
+
return self.sp.EncodeAsIds(line)
|
| 123 |
+
|
| 124 |
+
def decode(self, line: List[int], **kwargs):
|
| 125 |
+
self._build_sentence_piece_processor()
|
| 126 |
+
return self.sp.DecodeIds(line)
|
| 127 |
+
|
| 128 |
+
def get_vocab_size(self):
|
| 129 |
+
return self.sp.GetPieceSize()
|
| 130 |
+
|
| 131 |
+
def ids2tokens(self, *args, **kwargs):
|
| 132 |
+
return self.decode(*args, **kwargs)
|
| 133 |
+
|
| 134 |
+
def tokens2ids(self, *args, **kwargs):
|
| 135 |
+
return self.encode(*args, **kwargs)
|