inoryQwQ
commited on
Commit
·
798e40d
1
Parent(s):
33455e8
Update cpp bins, python scripts, add English readme
Browse files- .gitattributes +4 -0
- .gitignore +1 -0
- README.md +61 -74
- README_EN.md +261 -0
- cpp/ax630c/TSCharacters.ocd2 +0 -0
- cpp/ax630c/TSPhrases.ocd2 +0 -0
- cpp/ax630c/include/ax_whisper_api.h +107 -0
- cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config-release.cmake +19 -0
- cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config.cmake +94 -0
- cpp/{TSCharacters.ocd2 → ax630c/lib/libax_whisper.so} +2 -2
- cpp/ax630c/t2s.json +22 -0
- cpp/ax630c/whisper_cli +0 -0
- cpp/{TSPhrases.ocd2 → ax630c/whisper_svr} +2 -2
- cpp/ax650/TSCharacters.ocd2 +0 -0
- cpp/ax650/TSPhrases.ocd2 +0 -0
- cpp/ax650/include/ax_whisper_api.h +107 -0
- cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config-release.cmake +19 -0
- cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config.cmake +94 -0
- cpp/{t2s.json → ax650/lib/libax_whisper.so} +2 -2
- cpp/ax650/t2s.json +22 -0
- cpp/ax650/whisper_cli +0 -0
- cpp/{whisper_aarch64 → ax650/whisper_svr} +2 -2
- cpp/whisper_axcl_aarch64 +0 -3
- cpp/whisper_axcl_x86 +0 -3
- cpp/whisper_srv +0 -3
- python/assets/multilingual.tiktoken +0 -0
- python/languages.py +0 -102
- python/main.py +0 -74
- python/test_svr.py +46 -0
- python/test_wer.py +196 -25
- python/{whisper.py → whisper_ax.py} +76 -45
- python/whisper_cli.py +59 -35
- python/whisper_svr.py +1 -1
- python/whisper_tokenizer.py +0 -395
.gitattributes
CHANGED
|
@@ -51,3 +51,7 @@ ax620e/install/whisper filter=lfs diff=lfs merge=lfs -text
|
|
| 51 |
*axcl_aarch64/whisper filter=lfs diff=lfs merge=lfs -text
|
| 52 |
*ax650/install/whisper filter=lfs diff=lfs merge=lfs -text
|
| 53 |
*ax620e/install/whisper filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
*axcl_aarch64/whisper filter=lfs diff=lfs merge=lfs -text
|
| 52 |
*ax650/install/whisper filter=lfs diff=lfs merge=lfs -text
|
| 53 |
*ax620e/install/whisper filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
cpp/ax630c/lib/libax_whisper.so filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
cpp/ax650/lib/libax_whisper.so filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
cpp/ax650/whisper_svr filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
cpp/ax630c/whisper_svr filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
README.md
CHANGED
|
@@ -5,6 +5,10 @@ pipeline_tag: automatic-speech-recognition
|
|
| 5 |
|
| 6 |
# Whisper
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
OpenAI Whisper on Axera
|
| 9 |
|
| 10 |
- 目前支持 C++ 和 Python 两种语言
|
|
@@ -13,6 +17,12 @@ OpenAI Whisper on Axera
|
|
| 13 |
|
| 14 |
- 如需自行转换请参考[模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
## 支持平台
|
| 17 |
|
| 18 |
- [x] AX650N
|
|
@@ -20,6 +30,21 @@ OpenAI Whisper on Axera
|
|
| 20 |
|
| 21 |
## 模型转换
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
[模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
|
| 24 |
|
| 25 |
## 上板部署
|
|
@@ -87,60 +112,22 @@ python3 main.py --model_type small --model_path ../models-ax650 --wav ../demo.wa
|
|
| 87 |
输出结果
|
| 88 |
|
| 89 |
```
|
| 90 |
-
root@ax650:/mnt/
|
| 91 |
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 92 |
-
wav: ../demo.wav
|
| 93 |
-
model_type: small
|
| 94 |
-
model_path: ../models/
|
| 95 |
-
language: zh
|
| 96 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 97 |
[INFO] Chip type: ChipType.MC50
|
| 98 |
[INFO] VNPU type: VNPUType.DISABLED
|
| 99 |
-
[INFO] Engine version: 2.
|
| 100 |
[INFO] Model type: 2 (triple core)
|
| 101 |
-
[INFO] Compiler version:
|
| 102 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 103 |
[INFO] Model type: 2 (triple core)
|
| 104 |
-
[INFO] Compiler version:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
Preprocess wav take 6971.68493270874ms
|
| 110 |
-
Run encoder take 211.52877807617188ms
|
| 111 |
-
Run decoder_main take 79.00094985961914ms
|
| 112 |
-
First token: 17556
|
| 113 |
-
Run decoder_loop take 101.91774368286133ms
|
| 114 |
-
Iter 0 Token: 20844
|
| 115 |
-
Run decoder_loop take 60.30416488647461ms
|
| 116 |
-
Iter 1 Token: 7781
|
| 117 |
-
Run decoder_loop take 60.22000312805176ms
|
| 118 |
-
Iter 2 Token: 20204
|
| 119 |
-
Run decoder_loop take 60.23716926574707ms
|
| 120 |
-
Iter 3 Token: 28455
|
| 121 |
-
Run decoder_loop take 60.214996337890625ms
|
| 122 |
-
Iter 4 Token: 31962
|
| 123 |
-
Run decoder_loop take 60.17565727233887ms
|
| 124 |
-
Iter 5 Token: 6336
|
| 125 |
-
Run decoder_loop take 60.94002723693848ms
|
| 126 |
-
Iter 6 Token: 254
|
| 127 |
-
Run decoder_loop take 60.71639060974121ms
|
| 128 |
-
Iter 7 Token: 2930
|
| 129 |
-
Run decoder_loop take 60.225725173950195ms
|
| 130 |
-
Iter 8 Token: 236
|
| 131 |
-
Run decoder_loop take 60.167789459228516ms
|
| 132 |
-
Iter 9 Token: 36135
|
| 133 |
-
Run decoder_loop take 60.29987335205078ms
|
| 134 |
-
Iter 10 Token: 15868
|
| 135 |
-
Run decoder_loop take 61.163902282714844ms
|
| 136 |
-
Iter 11 Token: 252
|
| 137 |
-
Run decoder_loop take 60.273170471191406ms
|
| 138 |
-
Iter 12 Token: 1546
|
| 139 |
-
Run decoder_loop take 60.23144721984863ms
|
| 140 |
-
Iter 13 Token: 46514
|
| 141 |
-
Run decoder_loop take 60.31966209411621ms
|
| 142 |
-
Iter 14 Token: 50257
|
| 143 |
-
Result: 甚至出现交易几乎停滞的情况
|
| 144 |
```
|
| 145 |
|
| 146 |
运行参数说明:
|
|
@@ -152,6 +139,21 @@ Result: 甚至出现交易几乎停滞的情况
|
|
| 152 |
| --language/-l | 识别语言 | zh |
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
<h3 id="CPP">CPP</h3>
|
| 156 |
|
| 157 |
#### 运行
|
|
@@ -160,50 +162,35 @@ Result: 甚至出现交易几乎停滞的情况
|
|
| 160 |
|
| 161 |
```
|
| 162 |
cd cpp
|
| 163 |
-
./
|
| 164 |
```
|
| 165 |
|
| 166 |
或
|
| 167 |
|
| 168 |
```
|
| 169 |
cd cpp
|
| 170 |
-
./
|
| 171 |
```
|
| 172 |
|
| 173 |
输出结果
|
| 174 |
|
| 175 |
```
|
| 176 |
-
root@ax650:/mnt/
|
| 177 |
-
wav_file:
|
| 178 |
-
model_path:
|
| 179 |
-
model_type:
|
| 180 |
language: zh
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
Next Token: 20204 take 29.64ms
|
| 186 |
-
Next Token: 28455 take 29.65ms
|
| 187 |
-
Next Token: 31962 take 29.61ms
|
| 188 |
-
Next Token: 6336 take 29.67ms
|
| 189 |
-
Next Token: 254 take 29.63ms
|
| 190 |
-
Next Token: 2930 take 29.61ms
|
| 191 |
-
Next Token: 236 take 29.56ms
|
| 192 |
-
Next Token: 36135 take 29.64ms
|
| 193 |
-
Next Token: 15868 take 29.71ms
|
| 194 |
-
Next Token: 252 take 29.51ms
|
| 195 |
-
Next Token: 1546 take 29.63ms
|
| 196 |
-
Next Token: 46514 take 29.51ms
|
| 197 |
-
Next Token: 50257 take 29.69ms
|
| 198 |
-
All take 801.13 ms
|
| 199 |
-
Result: 甚至出现交易几乎停滞的情况
|
| 200 |
```
|
| 201 |
|
| 202 |
### 服务端
|
| 203 |
|
| 204 |
```
|
| 205 |
-
cd cpp
|
| 206 |
-
./whisper_srv --model_type tiny --
|
| 207 |
```
|
| 208 |
|
| 209 |
### 客户端
|
|
|
|
| 5 |
|
| 6 |
# Whisper
|
| 7 |
|
| 8 |
+
<div align="center">
|
| 9 |
+
<a href="README_EN.md">English</a> | <a href="README.md">中文</a>
|
| 10 |
+
</div>
|
| 11 |
+
|
| 12 |
OpenAI Whisper on Axera
|
| 13 |
|
| 14 |
- 目前支持 C++ 和 Python 两种语言
|
|
|
|
| 17 |
|
| 18 |
- 如需自行转换请参考[模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
|
| 19 |
|
| 20 |
+
|
| 21 |
+
## Update
|
| 22 |
+
|
| 23 |
+
- 2026/01/14: 更简单的模型结构,现在只需要encoder和decoder,去掉原来的decoder_main和decoder_loop;支持来自HuggingFace的模型导出
|
| 24 |
+
|
| 25 |
+
|
| 26 |
## 支持平台
|
| 27 |
|
| 28 |
- [x] AX650N
|
|
|
|
| 30 |
|
| 31 |
## 模型转换
|
| 32 |
|
| 33 |
+
目前支持的模型规模:
|
| 34 |
+
- tiny
|
| 35 |
+
- base
|
| 36 |
+
- small
|
| 37 |
+
- medium
|
| 38 |
+
- turbo
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
目前测试过的语言:
|
| 42 |
+
- English
|
| 43 |
+
- Chinese
|
| 44 |
+
- Japanese
|
| 45 |
+
- Korean
|
| 46 |
+
- Malaysian
|
| 47 |
+
|
| 48 |
[模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
|
| 49 |
|
| 50 |
## 上板部署
|
|
|
|
| 112 |
输出结果
|
| 113 |
|
| 114 |
```
|
| 115 |
+
(whisper) root@ax650:/mnt/data/Github/whisper.axera/python# python whisper_cli.py -t tiny -w ../demo.wav
|
| 116 |
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 117 |
+
{'wav': '../demo.wav', 'model_type': 'tiny', 'model_path': '../models-ax650', 'language': 'zh', 'task': 'transcribe'}
|
|
|
|
|
|
|
|
|
|
| 118 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 119 |
[INFO] Chip type: ChipType.MC50
|
| 120 |
[INFO] VNPU type: VNPUType.DISABLED
|
| 121 |
+
[INFO] Engine version: 2.12.0s
|
| 122 |
[INFO] Model type: 2 (triple core)
|
| 123 |
+
[INFO] Compiler version: 5.0 76f70fdc
|
| 124 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 125 |
[INFO] Model type: 2 (triple core)
|
| 126 |
+
[INFO] Compiler version: 5.0 76f70fdc
|
| 127 |
+
ASR result:
|
| 128 |
+
擅职出现交易几乎停止的情况
|
| 129 |
+
RTF: 0.11406774537746188
|
| 130 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
```
|
| 132 |
|
| 133 |
运行参数说明:
|
|
|
|
| 139 |
| --language/-l | 识别语言 | zh |
|
| 140 |
|
| 141 |
|
| 142 |
+
##### 服务端
|
| 143 |
+
|
| 144 |
+
```
|
| 145 |
+
(whisper) root@ax650:/mnt/data/Github/whisper.axera/python# python whisper_svr.py
|
| 146 |
+
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 147 |
+
Server started at http://0.0.0.0:8000
|
| 148 |
+
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
测试服务端
|
| 152 |
+
```
|
| 153 |
+
python test_svr.py
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
|
| 157 |
<h3 id="CPP">CPP</h3>
|
| 158 |
|
| 159 |
#### 运行
|
|
|
|
| 162 |
|
| 163 |
```
|
| 164 |
cd cpp
|
| 165 |
+
./whisper_cli -w ../demo.wav -t tiny
|
| 166 |
```
|
| 167 |
|
| 168 |
或
|
| 169 |
|
| 170 |
```
|
| 171 |
cd cpp
|
| 172 |
+
./whisper_cli --model_type small -w ../demo.wav
|
| 173 |
```
|
| 174 |
|
| 175 |
输出结果
|
| 176 |
|
| 177 |
```
|
| 178 |
+
(whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_cli -w ../../demo.wav -t tiny
|
| 179 |
+
wav_file: ../../demo.wav
|
| 180 |
+
model_path: ../../models-ax650
|
| 181 |
+
model_type: tiny
|
| 182 |
language: zh
|
| 183 |
+
Init whisper success, take 0.3540seconds
|
| 184 |
+
Result: 甚至出现交易几乎停止的情况
|
| 185 |
+
RTF: 0.0968
|
| 186 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
```
|
| 188 |
|
| 189 |
### 服务端
|
| 190 |
|
| 191 |
```
|
| 192 |
+
cd cpp/ax650
|
| 193 |
+
./whisper_srv --model_type tiny --language zh --port 8080
|
| 194 |
```
|
| 195 |
|
| 196 |
### 客户端
|
README_EN.md
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# whisper.axera
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<a href="README_EN.md">English</a> | <a href="README.md">中文</a>
|
| 5 |
+
</div>
|
| 6 |
+
|
| 7 |
+
OpenAI Whisper on Axera Platform
|
| 8 |
+
|
| 9 |
+
## Overview
|
| 10 |
+
|
| 11 |
+
This project provides an optimized implementation of OpenAI's Whisper speech recognition model for Axera AI processors (AX650N/AX630C). It supports both C++ and Python interfaces for efficient on-device speech-to-text conversion.
|
| 12 |
+
|
| 13 |
+
## Features
|
| 14 |
+
|
| 15 |
+
- **Dual Language Support**: Both C++ and Python APIs available
|
| 16 |
+
- **Multiple Model Sizes**: Support for tiny, base, small, and turbo model variants
|
| 17 |
+
- **Multi-language Recognition**: Tested with English, Chinese, Japanese, and Korean
|
| 18 |
+
- **Optimized Performance**: Specially optimized for Axera NPU acceleration
|
| 19 |
+
- **Easy Deployment**: Pre-built packages and cross-compilation support
|
| 20 |
+
|
| 21 |
+
## Update
|
| 22 |
+
|
| 23 |
+
- 2026/01/14: We provide cleaner model architecture now.(With encoder and decoder instead of decoder_main and decoder_loop). Support exporting models from huggingface.
|
| 24 |
+
|
| 25 |
+
## Supported Platforms
|
| 26 |
+
|
| 27 |
+
- ✅ AX650N
|
| 28 |
+
- ✅ AX630C
|
| 29 |
+
|
| 30 |
+
## Pre-trained Models
|
| 31 |
+
|
| 32 |
+
Download pre-compiled models from:
|
| 33 |
+
- [Baidu Cloud](https://pan.baidu.com/s/1tOHVMZCin0A68T5HmKRJyg?pwd=axyz)
|
| 34 |
+
- [Huggingface](https://huggingface.co/AXERA-TECH/Whisper)
|
| 35 |
+
|
| 36 |
+
For custom model conversion, please refer to [Model Conversion Guide](./model_convert/README_EN.md).
|
| 37 |
+
|
| 38 |
+
## Model Conversion
|
| 39 |
+
|
| 40 |
+
Currently supported model scales:
|
| 41 |
+
- tiny
|
| 42 |
+
- base
|
| 43 |
+
- small
|
| 44 |
+
- medium
|
| 45 |
+
- turbo
|
| 46 |
+
|
| 47 |
+
Tested languages:
|
| 48 |
+
- English
|
| 49 |
+
- Chinese
|
| 50 |
+
- Japanese
|
| 51 |
+
- Korean
|
| 52 |
+
- Malaysian
|
| 53 |
+
|
| 54 |
+
For other languages or custom model sizes, please refer to the [Model Conversion Guide](./model_convert/README_EN.md).
|
| 55 |
+
|
| 56 |
+
## Deployment on Target Devices
|
| 57 |
+
|
| 58 |
+
### Prerequisites
|
| 59 |
+
- AX650N/AX630C devices with Ubuntu 22.04 pre-installed
|
| 60 |
+
- Internet connection for `apt install` and `pip install`
|
| 61 |
+
- Verified hardware platforms:
|
| 62 |
+
- [MaixIV M4nDock (AX650N)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
|
| 63 |
+
- [M.2 Accelerator Card (AX650N)](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
|
| 64 |
+
- [Axera Pi 2 (AX630C)](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
|
| 65 |
+
- [Module-LLM (AX630C)](https://docs.m5stack.com/zh_CN/module/Module-LLM)
|
| 66 |
+
- [LLM630 Compute Kit (AX630C)](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
|
| 67 |
+
|
| 68 |
+
## Programming Language Support
|
| 69 |
+
|
| 70 |
+
### Python
|
| 71 |
+
|
| 72 |
+
Tested with Python 3.12. We recommend using [Miniconda](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh) for environment management.
|
| 73 |
+
|
| 74 |
+
#### Installation
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
cd python
|
| 78 |
+
pip3 install -r requirements.txt
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
#### pyaxenigne
|
| 82 |
+
|
| 83 |
+
Install NPU Python API from: https://github.com/AXERA-TECH/pyaxengine
|
| 84 |
+
|
| 85 |
+
#### Usage
|
| 86 |
+
|
| 87 |
+
##### Command Line Interface
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
cd python
|
| 91 |
+
(whisper) root@ax650:/mnt/data/HF/Whisper/python# python whisper_cli.py -w ../demo.wav -t tiny
|
| 92 |
+
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 93 |
+
{'wav': '../demo.wav', 'model_type': 'tiny', 'model_path': '../models-ax650', 'language': 'zh', 'task': 'transcribe'}
|
| 94 |
+
[INFO] Using provider: AxEngineExecutionProvider
|
| 95 |
+
[INFO] Chip type: ChipType.MC50
|
| 96 |
+
[INFO] VNPU type: VNPUType.DISABLED
|
| 97 |
+
[INFO] Engine version: 2.12.0s
|
| 98 |
+
[INFO] Model type: 2 (triple core)
|
| 99 |
+
[INFO] Compiler version: 5.0 76f70fdc
|
| 100 |
+
[INFO] Using provider: AxEngineExecutionProvider
|
| 101 |
+
[INFO] Model type: 2 (triple core)
|
| 102 |
+
[INFO] Compiler version: 5.0 76f70fdc
|
| 103 |
+
ASR result:
|
| 104 |
+
擅职出现交易几乎停止的情况
|
| 105 |
+
RTF: 0.10313174677896837
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
Command line arguments:
|
| 110 |
+
| Argument | Description | Default |
|
| 111 |
+
| --- | --- | --- |
|
| 112 |
+
| --wav | Input audio file | - |
|
| 113 |
+
| --model_type/-t | Model type: tiny/base/small | - |
|
| 114 |
+
| --model_path/-p | Model directory | ../models |
|
| 115 |
+
| --language/-l | Recognition language | zh |
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
##### Server Mode
|
| 119 |
+
|
| 120 |
+
```
|
| 121 |
+
(whisper) root@ax650:/mnt/data/HF/Whisper/python# python whisper_svr.py
|
| 122 |
+
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 123 |
+
Server started at http://0.0.0.0:8000
|
| 124 |
+
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
Test the server:
|
| 128 |
+
```
|
| 129 |
+
python test_svr.py
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
<h3 id="CPP">CPP</h3>
|
| 134 |
+
|
| 135 |
+
#### Usage on Target Device
|
| 136 |
+
```
|
| 137 |
+
cd cpp/ax650
|
| 138 |
+
./whisper_cli -w ../demo.wav -t tiny
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
或
|
| 142 |
+
|
| 143 |
+
```
|
| 144 |
+
cd cpp/ax650
|
| 145 |
+
./whisper_cli --model_type small -w ../demo.wav
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Example Output:
|
| 149 |
+
|
| 150 |
+
```
|
| 151 |
+
(whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_cli -w ../../demo.wav -t tiny
|
| 152 |
+
wav_file: ../../demo.wav
|
| 153 |
+
model_path: ../../models-ax650
|
| 154 |
+
model_type: tiny
|
| 155 |
+
language: zh
|
| 156 |
+
Init whisper success, take 0.3540seconds
|
| 157 |
+
Result: 甚至出现交易几乎停止的情况
|
| 158 |
+
RTF: 0.0968
|
| 159 |
+
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### Server Mode
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
cd cpp/ax650
|
| 166 |
+
(whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_svr -t tiny
|
| 167 |
+
port: 8080
|
| 168 |
+
model_path: ../../models-ax650
|
| 169 |
+
model_type: tiny
|
| 170 |
+
language: zh
|
| 171 |
+
[I][ main][ 60]: Initializing server...
|
| 172 |
+
[I][ main][ 65]: Init server success
|
| 173 |
+
[I][ start][ 32]: Start server at port 8080, POST binary stream to IP:8080/asr
|
| 174 |
+
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Client test using curl:
|
| 178 |
+
|
| 179 |
+
```
|
| 180 |
+
ffmpeg -i demo.wav -f f32le -c:a pcm_f32le - 2>/dev/null | \
|
| 181 |
+
curl -X POST 10.126.33.192:8080/asr \
|
| 182 |
+
-H "Content-Type: application/octet-stream" \
|
| 183 |
+
--data-binary @-
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
## Performance Benchmarks
|
| 187 |
+
|
| 188 |
+
### Latency
|
| 189 |
+
|
| 190 |
+
RTF: Real-Time Factor
|
| 191 |
+
|
| 192 |
+
CPP:
|
| 193 |
+
|
| 194 |
+
| Models | AX650N | AX630C |
|
| 195 |
+
| ------------- | ------ | ------ |
|
| 196 |
+
| Whisper-Tiny | 0.08 | |
|
| 197 |
+
| Whisper-Base | 0.11 | 0.35 |
|
| 198 |
+
| Whisper-Small | 0.24 | |
|
| 199 |
+
| Whisper-Turbo | 0.48 | |
|
| 200 |
+
|
| 201 |
+
Python:
|
| 202 |
+
|
| 203 |
+
| Models | AX650N | AX630C |
|
| 204 |
+
| ------------- | ------ | ------ |
|
| 205 |
+
| Whisper-Tiny | 0.12 | |
|
| 206 |
+
| Whisper-Base | 0.16 | 0.35 |
|
| 207 |
+
| Whisper-Small | 0.50 | |
|
| 208 |
+
| Whisper-Turbo | 0.60 | |
|
| 209 |
+
|
| 210 |
+
### Word Error Rate(Test on AIShell dataset)
|
| 211 |
+
|
| 212 |
+
| Models | AX650N | AX630C |
|
| 213 |
+
| ------------- | ------ | ------ |
|
| 214 |
+
| Whisper-Tiny | 0.24 | |
|
| 215 |
+
| Whisper-Base | 0.18 | |
|
| 216 |
+
| Whisper-Small | 0.11 | |
|
| 217 |
+
| Whisper-Turbo | 0.06 | |
|
| 218 |
+
|
| 219 |
+
To reproduce WER test results:
|
| 220 |
+
|
| 221 |
+
Download dataset:
|
| 222 |
+
```
|
| 223 |
+
cd model_convert
|
| 224 |
+
bash download_dataset.sh
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
Run test script:
|
| 228 |
+
```
|
| 229 |
+
cd python
|
| 230 |
+
conda activate whisper
|
| 231 |
+
python test_wer.py -d aishell --gt_path ../model_convert/datasets/ground_truth.txt --model_type tiny
|
| 232 |
+
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
### MEM Usage
|
| 236 |
+
|
| 237 |
+
* CMM Stands for Physical memory used by Axera modules like VDEC(Video decoder), VENC(Video encoder), NPU, etc.
|
| 238 |
+
|
| 239 |
+
Python:
|
| 240 |
+
|
| 241 |
+
| Models | CMM(MB)| OS(MB) |
|
| 242 |
+
| ------------- | ------ | ------ |
|
| 243 |
+
| Whisper-Tiny | 332 | 512 |
|
| 244 |
+
| Whisper-Base | 533 | 644 |
|
| 245 |
+
| Whisper-Small | 1106 | 906 |
|
| 246 |
+
| Whisper-Turbo | 2065 | 2084 |
|
| 247 |
+
|
| 248 |
+
C++:
|
| 249 |
+
|
| 250 |
+
| Models | CMM(MB)| OS(MB) |
|
| 251 |
+
| ------------- | ------ | ------ |
|
| 252 |
+
| Whisper-Tiny | 332 | 31 |
|
| 253 |
+
| Whisper-Base | 533 | 54 |
|
| 254 |
+
| Whisper-Small | 1106 | 146 |
|
| 255 |
+
| Whisper-Turbo | 2065 | 86 |
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
## Technical Discussion
|
| 259 |
+
|
| 260 |
+
- Github issues
|
| 261 |
+
- Tencent QQ Group: 139953715
|
cpp/ax630c/TSCharacters.ocd2
ADDED
|
Binary file (46.1 kB). View file
|
|
|
cpp/ax630c/TSPhrases.ocd2
ADDED
|
Binary file (9.78 kB). View file
|
|
|
cpp/ax630c/include/ax_whisper_api.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @file ax_whisper_api.h
|
| 3 |
+
* @brief AX Whisper API header - C-compatible interface for Whisper ASR system
|
| 4 |
+
* @note This header provides a C interface to the Whisper speech recognition system
|
| 5 |
+
*/
|
| 6 |
+
|
| 7 |
+
#ifndef _AX_WHISPER_API_H_
|
| 8 |
+
#define _AX_WHISPER_API_H_
|
| 9 |
+
|
| 10 |
+
#ifdef __cplusplus
|
| 11 |
+
extern "C" {
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
#define AX_WHISPER_API __attribute__((visibility("default")))
|
| 15 |
+
|
| 16 |
+
/**
|
| 17 |
+
* @brief Opaque handle type for Whisper ASR context
|
| 18 |
+
*
|
| 19 |
+
* This handle encapsulates all internal state of the Whisper ASR system.
|
| 20 |
+
* The actual implementation is hidden from C callers to maintain ABI stability.
|
| 21 |
+
*/
|
| 22 |
+
typedef void* AX_WHISPER_HANDLE;
|
| 23 |
+
|
| 24 |
+
/**
|
| 25 |
+
* @brief Initialize the Whisper ASR system with specific configuration
|
| 26 |
+
*
|
| 27 |
+
* Creates and initializes a new Whisper ASR context with the specified
|
| 28 |
+
* model type, model path, and language. This function loads the appropriate
|
| 29 |
+
* models, configures the recognizer, and prepares it for speech recognition.
|
| 30 |
+
*
|
| 31 |
+
* @param model_type Type of Whisper model to use (e.g., "tiny", "base", "small", "medium", "large")
|
| 32 |
+
* or custom model identifier
|
| 33 |
+
* @param model_path Directory path where model files are stored
|
| 34 |
+
* Model files are expected to be in the format:
|
| 35 |
+
* - {model_path}/{model_type}/{model_type}-encoder.axmodel
|
| 36 |
+
* - {model_path}/{model_type}/{model_type}-decoder.axmodel
|
| 37 |
+
* - {model_path}/{model_type}/{model_type}-tokens.txt
|
| 38 |
+
* - {model_path}/{model_type}/{model_type}_config.json
|
| 39 |
+
* @param language Language code for recognition (e.g., "en", "zh", "ja", "ko")
|
| 40 |
+
* Use "auto" for automatic language detection if supported
|
| 41 |
+
*
|
| 42 |
+
* @return AX_WHISPER_HANDLE Opaque handle to the initialized Whisper context,
|
| 43 |
+
* or NULL if initialization fails
|
| 44 |
+
*
|
| 45 |
+
* @note The caller is responsible for calling AX_WHISPER_Uninit() to free
|
| 46 |
+
* resources when the handle is no longer needed.
|
| 47 |
+
* @note If language is not supported by the model, the function may fall back
|
| 48 |
+
* to a default language or return NULL.
|
| 49 |
+
* @example
|
| 50 |
+
* // Initialize English recognition with base model
|
| 51 |
+
* AX_WHISPER_HANDLE handle = AX_WHISPER_Init("base", "../models-ax650", "en");
|
| 52 |
+
*
|
| 53 |
+
*/
|
| 54 |
+
AX_WHISPER_API AX_WHISPER_HANDLE AX_WHISPER_Init(const char* model_type, const char* model_path, const char* language);
|
| 55 |
+
|
| 56 |
+
/**
|
| 57 |
+
* @brief Deinitialize and release Whisper ASR resources
|
| 58 |
+
*
|
| 59 |
+
* Cleans up all resources associated with the Whisper context, including
|
| 60 |
+
* unloading models, freeing memory, and releasing hardware resources.
|
| 61 |
+
*
|
| 62 |
+
* @param handle Whisper context handle obtained from AX_WHISPER_Init()
|
| 63 |
+
*
|
| 64 |
+
* @warning After calling this function, the handle becomes invalid and
|
| 65 |
+
* should not be used in any subsequent API calls.
|
| 66 |
+
*/
|
| 67 |
+
AX_WHISPER_API void AX_WHISPER_Uninit(AX_WHISPER_HANDLE handle);
|
| 68 |
+
|
| 69 |
+
/**
|
| 70 |
+
* @brief Perform speech recognition and return dynamically allocated string
|
| 71 |
+
*
|
| 72 |
+
* @param handle Whisper context handle
|
| 73 |
+
* @param wav_file Path to the input 16k pcmf32 WAV audio file
|
| 74 |
+
* @param result Pointer to receive the allocated result string
|
| 75 |
+
*
|
| 76 |
+
* @return int Status code (0 = success, <0 = error)
|
| 77 |
+
*
|
| 78 |
+
* @note The returned string is allocated with malloc() and must be freed
|
| 79 |
+
* by the caller using free() when no longer needed.
|
| 80 |
+
*/
|
| 81 |
+
AX_WHISPER_API int AX_WHISPER_RunFile(AX_WHISPER_HANDLE handle,
|
| 82 |
+
const char* wav_file,
|
| 83 |
+
char** result);
|
| 84 |
+
|
| 85 |
+
/**
|
| 86 |
+
* @brief Perform speech recognition and return dynamically allocated string
|
| 87 |
+
*
|
| 88 |
+
* @param handle Whisper context handle
|
| 89 |
+
* @param pcm_data 16k Mono PCM f32 data, range from -1.0 to 1.0
|
| 90 |
+
* @param num_samples Sample num of PCM data
|
| 91 |
+
* @param result Pointer to receive the allocated result string
|
| 92 |
+
*
|
| 93 |
+
* @return int Status code (0 = success, <0 = error)
|
| 94 |
+
*
|
| 95 |
+
* @note The returned string is allocated with malloc() and must be freed
|
| 96 |
+
* by the caller using free() when no longer needed.
|
| 97 |
+
*/
|
| 98 |
+
AX_WHISPER_API int AX_WHISPER_RunPCM(AX_WHISPER_HANDLE handle,
|
| 99 |
+
float* pcm_data,
|
| 100 |
+
int num_samples,
|
| 101 |
+
char** result);
|
| 102 |
+
|
| 103 |
+
#ifdef __cplusplus
|
| 104 |
+
}
|
| 105 |
+
#endif
|
| 106 |
+
|
| 107 |
+
#endif // _AX_WHISPER_API_H_
|
cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config-release.cmake
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#----------------------------------------------------------------
|
| 2 |
+
# Generated CMake target import file for configuration "Release".
|
| 3 |
+
#----------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
# Commands may need to know the format version.
|
| 6 |
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
| 7 |
+
|
| 8 |
+
# Import target "ax::ax_whisper" for configuration "Release"
|
| 9 |
+
set_property(TARGET ax::ax_whisper APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
|
| 10 |
+
set_target_properties(ax::ax_whisper PROPERTIES
|
| 11 |
+
IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_whisper.so"
|
| 12 |
+
IMPORTED_SONAME_RELEASE "libax_whisper.so"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
list(APPEND _IMPORT_CHECK_TARGETS ax::ax_whisper )
|
| 16 |
+
list(APPEND _IMPORT_CHECK_FILES_FOR_ax::ax_whisper "${_IMPORT_PREFIX}/lib/libax_whisper.so" )
|
| 17 |
+
|
| 18 |
+
# Commands beyond this point should not need to know the version.
|
| 19 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config.cmake
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by CMake
|
| 2 |
+
|
| 3 |
+
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6)
|
| 4 |
+
message(FATAL_ERROR "CMake >= 2.6.0 required")
|
| 5 |
+
endif()
|
| 6 |
+
cmake_policy(PUSH)
|
| 7 |
+
cmake_policy(VERSION 2.6...3.20)
|
| 8 |
+
#----------------------------------------------------------------
|
| 9 |
+
# Generated CMake target import file.
|
| 10 |
+
#----------------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
# Commands may need to know the format version.
|
| 13 |
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
| 14 |
+
|
| 15 |
+
# Protect against multiple inclusion, which would fail when already imported targets are added once more.
|
| 16 |
+
set(_targetsDefined)
|
| 17 |
+
set(_targetsNotDefined)
|
| 18 |
+
set(_expectedTargets)
|
| 19 |
+
foreach(_expectedTarget ax::ax_whisper)
|
| 20 |
+
list(APPEND _expectedTargets ${_expectedTarget})
|
| 21 |
+
if(NOT TARGET ${_expectedTarget})
|
| 22 |
+
list(APPEND _targetsNotDefined ${_expectedTarget})
|
| 23 |
+
endif()
|
| 24 |
+
if(TARGET ${_expectedTarget})
|
| 25 |
+
list(APPEND _targetsDefined ${_expectedTarget})
|
| 26 |
+
endif()
|
| 27 |
+
endforeach()
|
| 28 |
+
if("${_targetsDefined}" STREQUAL "${_expectedTargets}")
|
| 29 |
+
unset(_targetsDefined)
|
| 30 |
+
unset(_targetsNotDefined)
|
| 31 |
+
unset(_expectedTargets)
|
| 32 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
| 33 |
+
cmake_policy(POP)
|
| 34 |
+
return()
|
| 35 |
+
endif()
|
| 36 |
+
if(NOT "${_targetsDefined}" STREQUAL "")
|
| 37 |
+
message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n")
|
| 38 |
+
endif()
|
| 39 |
+
unset(_targetsDefined)
|
| 40 |
+
unset(_targetsNotDefined)
|
| 41 |
+
unset(_expectedTargets)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Compute the installation prefix relative to this file.
|
| 45 |
+
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
| 46 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 47 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 48 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 49 |
+
if(_IMPORT_PREFIX STREQUAL "/")
|
| 50 |
+
set(_IMPORT_PREFIX "")
|
| 51 |
+
endif()
|
| 52 |
+
|
| 53 |
+
# Create imported target ax::ax_whisper
|
| 54 |
+
add_library(ax::ax_whisper SHARED IMPORTED)
|
| 55 |
+
|
| 56 |
+
set_target_properties(ax::ax_whisper PROPERTIES
|
| 57 |
+
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Load information for each installed configuration.
|
| 61 |
+
get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
| 62 |
+
file(GLOB CONFIG_FILES "${_DIR}/ax_whisper-config-*.cmake")
|
| 63 |
+
foreach(f ${CONFIG_FILES})
|
| 64 |
+
include(${f})
|
| 65 |
+
endforeach()
|
| 66 |
+
|
| 67 |
+
# Cleanup temporary variables.
|
| 68 |
+
set(_IMPORT_PREFIX)
|
| 69 |
+
|
| 70 |
+
# Loop over all imported files and verify that they actually exist
|
| 71 |
+
foreach(target ${_IMPORT_CHECK_TARGETS} )
|
| 72 |
+
foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} )
|
| 73 |
+
if(NOT EXISTS "${file}" )
|
| 74 |
+
message(FATAL_ERROR "The imported target \"${target}\" references the file
|
| 75 |
+
\"${file}\"
|
| 76 |
+
but this file does not exist. Possible reasons include:
|
| 77 |
+
* The file was deleted, renamed, or moved to another location.
|
| 78 |
+
* An install or uninstall procedure did not complete successfully.
|
| 79 |
+
* The installation package was faulty and contained
|
| 80 |
+
\"${CMAKE_CURRENT_LIST_FILE}\"
|
| 81 |
+
but not all the files it references.
|
| 82 |
+
")
|
| 83 |
+
endif()
|
| 84 |
+
endforeach()
|
| 85 |
+
unset(_IMPORT_CHECK_FILES_FOR_${target})
|
| 86 |
+
endforeach()
|
| 87 |
+
unset(_IMPORT_CHECK_TARGETS)
|
| 88 |
+
|
| 89 |
+
# This file does not depend on other imported targets which have
|
| 90 |
+
# been exported from the same project but in a separate export set.
|
| 91 |
+
|
| 92 |
+
# Commands beyond this point should not need to know the version.
|
| 93 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
| 94 |
+
cmake_policy(POP)
|
cpp/{TSCharacters.ocd2 → ax630c/lib/libax_whisper.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:469ffe5522d67f92b9b7d5390dc17740cbef3cd3b8b484542a8e1de44c11ad5a
|
| 3 |
+
size 623784
|
cpp/ax630c/t2s.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "Traditional Chinese to Simplified Chinese",
|
| 3 |
+
"segmentation": {
|
| 4 |
+
"type": "mmseg",
|
| 5 |
+
"dict": {
|
| 6 |
+
"type": "ocd2",
|
| 7 |
+
"file": "TSPhrases.ocd2"
|
| 8 |
+
}
|
| 9 |
+
},
|
| 10 |
+
"conversion_chain": [{
|
| 11 |
+
"dict": {
|
| 12 |
+
"type": "group",
|
| 13 |
+
"dicts": [{
|
| 14 |
+
"type": "ocd2",
|
| 15 |
+
"file": "TSPhrases.ocd2"
|
| 16 |
+
}, {
|
| 17 |
+
"type": "ocd2",
|
| 18 |
+
"file": "TSCharacters.ocd2"
|
| 19 |
+
}]
|
| 20 |
+
}
|
| 21 |
+
}]
|
| 22 |
+
}
|
cpp/ax630c/whisper_cli
ADDED
|
Binary file (93.8 kB). View file
|
|
|
cpp/{TSPhrases.ocd2 → ax630c/whisper_svr}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2af4115e31d25a1c85f39666aba3918403449ca64e60d11f5af7e49c7456cd7
|
| 3 |
+
size 531688
|
cpp/ax650/TSCharacters.ocd2
ADDED
|
Binary file (46.1 kB). View file
|
|
|
cpp/ax650/TSPhrases.ocd2
ADDED
|
Binary file (9.78 kB). View file
|
|
|
cpp/ax650/include/ax_whisper_api.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @file ax_whisper_api.h
|
| 3 |
+
* @brief AX Whisper API header - C-compatible interface for Whisper ASR system
|
| 4 |
+
* @note This header provides a C interface to the Whisper speech recognition system
|
| 5 |
+
*/
|
| 6 |
+
|
| 7 |
+
#ifndef _AX_WHISPER_API_H_
|
| 8 |
+
#define _AX_WHISPER_API_H_
|
| 9 |
+
|
| 10 |
+
#ifdef __cplusplus
|
| 11 |
+
extern "C" {
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
#define AX_WHISPER_API __attribute__((visibility("default")))
|
| 15 |
+
|
| 16 |
+
/**
|
| 17 |
+
* @brief Opaque handle type for Whisper ASR context
|
| 18 |
+
*
|
| 19 |
+
* This handle encapsulates all internal state of the Whisper ASR system.
|
| 20 |
+
* The actual implementation is hidden from C callers to maintain ABI stability.
|
| 21 |
+
*/
|
| 22 |
+
typedef void* AX_WHISPER_HANDLE;
|
| 23 |
+
|
| 24 |
+
/**
|
| 25 |
+
* @brief Initialize the Whisper ASR system with specific configuration
|
| 26 |
+
*
|
| 27 |
+
* Creates and initializes a new Whisper ASR context with the specified
|
| 28 |
+
* model type, model path, and language. This function loads the appropriate
|
| 29 |
+
* models, configures the recognizer, and prepares it for speech recognition.
|
| 30 |
+
*
|
| 31 |
+
* @param model_type Type of Whisper model to use (e.g., "tiny", "base", "small", "medium", "large")
|
| 32 |
+
* or custom model identifier
|
| 33 |
+
* @param model_path Directory path where model files are stored
|
| 34 |
+
* Model files are expected to be in the format:
|
| 35 |
+
* - {model_path}/{model_type}/{model_type}-encoder.axmodel
|
| 36 |
+
* - {model_path}/{model_type}/{model_type}-decoder.axmodel
|
| 37 |
+
* - {model_path}/{model_type}/{model_type}-tokens.txt
|
| 38 |
+
* - {model_path}/{model_type}/{model_type}_config.json
|
| 39 |
+
* @param language Language code for recognition (e.g., "en", "zh", "ja", "ko")
|
| 40 |
+
* Use "auto" for automatic language detection if supported
|
| 41 |
+
*
|
| 42 |
+
* @return AX_WHISPER_HANDLE Opaque handle to the initialized Whisper context,
|
| 43 |
+
* or NULL if initialization fails
|
| 44 |
+
*
|
| 45 |
+
* @note The caller is responsible for calling AX_WHISPER_Uninit() to free
|
| 46 |
+
* resources when the handle is no longer needed.
|
| 47 |
+
* @note If language is not supported by the model, the function may fall back
|
| 48 |
+
* to a default language or return NULL.
|
| 49 |
+
* @example
|
| 50 |
+
* // Initialize English recognition with base model
|
| 51 |
+
* AX_WHISPER_HANDLE handle = AX_WHISPER_Init("base", "../models-ax650", "en");
|
| 52 |
+
*
|
| 53 |
+
*/
|
| 54 |
+
AX_WHISPER_API AX_WHISPER_HANDLE AX_WHISPER_Init(const char* model_type, const char* model_path, const char* language);
|
| 55 |
+
|
| 56 |
+
/**
|
| 57 |
+
* @brief Deinitialize and release Whisper ASR resources
|
| 58 |
+
*
|
| 59 |
+
* Cleans up all resources associated with the Whisper context, including
|
| 60 |
+
* unloading models, freeing memory, and releasing hardware resources.
|
| 61 |
+
*
|
| 62 |
+
* @param handle Whisper context handle obtained from AX_WHISPER_Init()
|
| 63 |
+
*
|
| 64 |
+
* @warning After calling this function, the handle becomes invalid and
|
| 65 |
+
* should not be used in any subsequent API calls.
|
| 66 |
+
*/
|
| 67 |
+
AX_WHISPER_API void AX_WHISPER_Uninit(AX_WHISPER_HANDLE handle);
|
| 68 |
+
|
| 69 |
+
/**
|
| 70 |
+
* @brief Perform speech recognition and return dynamically allocated string
|
| 71 |
+
*
|
| 72 |
+
* @param handle Whisper context handle
|
| 73 |
+
* @param wav_file Path to the input 16k pcmf32 WAV audio file
|
| 74 |
+
* @param result Pointer to receive the allocated result string
|
| 75 |
+
*
|
| 76 |
+
* @return int Status code (0 = success, <0 = error)
|
| 77 |
+
*
|
| 78 |
+
* @note The returned string is allocated with malloc() and must be freed
|
| 79 |
+
* by the caller using free() when no longer needed.
|
| 80 |
+
*/
|
| 81 |
+
AX_WHISPER_API int AX_WHISPER_RunFile(AX_WHISPER_HANDLE handle,
|
| 82 |
+
const char* wav_file,
|
| 83 |
+
char** result);
|
| 84 |
+
|
| 85 |
+
/**
|
| 86 |
+
* @brief Perform speech recognition and return dynamically allocated string
|
| 87 |
+
*
|
| 88 |
+
* @param handle Whisper context handle
|
| 89 |
+
* @param pcm_data 16k Mono PCM f32 data, range from -1.0 to 1.0
|
| 90 |
+
* @param num_samples Sample num of PCM data
|
| 91 |
+
* @param result Pointer to receive the allocated result string
|
| 92 |
+
*
|
| 93 |
+
* @return int Status code (0 = success, <0 = error)
|
| 94 |
+
*
|
| 95 |
+
* @note The returned string is allocated with malloc() and must be freed
|
| 96 |
+
* by the caller using free() when no longer needed.
|
| 97 |
+
*/
|
| 98 |
+
AX_WHISPER_API int AX_WHISPER_RunPCM(AX_WHISPER_HANDLE handle,
|
| 99 |
+
float* pcm_data,
|
| 100 |
+
int num_samples,
|
| 101 |
+
char** result);
|
| 102 |
+
|
| 103 |
+
#ifdef __cplusplus
|
| 104 |
+
}
|
| 105 |
+
#endif
|
| 106 |
+
|
| 107 |
+
#endif // _AX_WHISPER_API_H_
|
cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config-release.cmake
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#----------------------------------------------------------------
|
| 2 |
+
# Generated CMake target import file for configuration "Release".
|
| 3 |
+
#----------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
# Commands may need to know the format version.
|
| 6 |
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
| 7 |
+
|
| 8 |
+
# Import target "ax::ax_whisper" for configuration "Release"
|
| 9 |
+
set_property(TARGET ax::ax_whisper APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
|
| 10 |
+
set_target_properties(ax::ax_whisper PROPERTIES
|
| 11 |
+
IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_whisper.so"
|
| 12 |
+
IMPORTED_SONAME_RELEASE "libax_whisper.so"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
list(APPEND _IMPORT_CHECK_TARGETS ax::ax_whisper )
|
| 16 |
+
list(APPEND _IMPORT_CHECK_FILES_FOR_ax::ax_whisper "${_IMPORT_PREFIX}/lib/libax_whisper.so" )
|
| 17 |
+
|
| 18 |
+
# Commands beyond this point should not need to know the version.
|
| 19 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config.cmake
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by CMake
|
| 2 |
+
|
| 3 |
+
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6)
|
| 4 |
+
message(FATAL_ERROR "CMake >= 2.6.0 required")
|
| 5 |
+
endif()
|
| 6 |
+
cmake_policy(PUSH)
|
| 7 |
+
cmake_policy(VERSION 2.6...3.20)
|
| 8 |
+
#----------------------------------------------------------------
|
| 9 |
+
# Generated CMake target import file.
|
| 10 |
+
#----------------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
# Commands may need to know the format version.
|
| 13 |
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
| 14 |
+
|
| 15 |
+
# Protect against multiple inclusion, which would fail when already imported targets are added once more.
|
| 16 |
+
set(_targetsDefined)
|
| 17 |
+
set(_targetsNotDefined)
|
| 18 |
+
set(_expectedTargets)
|
| 19 |
+
foreach(_expectedTarget ax::ax_whisper)
|
| 20 |
+
list(APPEND _expectedTargets ${_expectedTarget})
|
| 21 |
+
if(NOT TARGET ${_expectedTarget})
|
| 22 |
+
list(APPEND _targetsNotDefined ${_expectedTarget})
|
| 23 |
+
endif()
|
| 24 |
+
if(TARGET ${_expectedTarget})
|
| 25 |
+
list(APPEND _targetsDefined ${_expectedTarget})
|
| 26 |
+
endif()
|
| 27 |
+
endforeach()
|
| 28 |
+
if("${_targetsDefined}" STREQUAL "${_expectedTargets}")
|
| 29 |
+
unset(_targetsDefined)
|
| 30 |
+
unset(_targetsNotDefined)
|
| 31 |
+
unset(_expectedTargets)
|
| 32 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
| 33 |
+
cmake_policy(POP)
|
| 34 |
+
return()
|
| 35 |
+
endif()
|
| 36 |
+
if(NOT "${_targetsDefined}" STREQUAL "")
|
| 37 |
+
message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n")
|
| 38 |
+
endif()
|
| 39 |
+
unset(_targetsDefined)
|
| 40 |
+
unset(_targetsNotDefined)
|
| 41 |
+
unset(_expectedTargets)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Compute the installation prefix relative to this file.
|
| 45 |
+
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
| 46 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 47 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 48 |
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
| 49 |
+
if(_IMPORT_PREFIX STREQUAL "/")
|
| 50 |
+
set(_IMPORT_PREFIX "")
|
| 51 |
+
endif()
|
| 52 |
+
|
| 53 |
+
# Create imported target ax::ax_whisper
|
| 54 |
+
add_library(ax::ax_whisper SHARED IMPORTED)
|
| 55 |
+
|
| 56 |
+
set_target_properties(ax::ax_whisper PROPERTIES
|
| 57 |
+
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Load information for each installed configuration.
|
| 61 |
+
get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
| 62 |
+
file(GLOB CONFIG_FILES "${_DIR}/ax_whisper-config-*.cmake")
|
| 63 |
+
foreach(f ${CONFIG_FILES})
|
| 64 |
+
include(${f})
|
| 65 |
+
endforeach()
|
| 66 |
+
|
| 67 |
+
# Cleanup temporary variables.
|
| 68 |
+
set(_IMPORT_PREFIX)
|
| 69 |
+
|
| 70 |
+
# Loop over all imported files and verify that they actually exist
|
| 71 |
+
foreach(target ${_IMPORT_CHECK_TARGETS} )
|
| 72 |
+
foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} )
|
| 73 |
+
if(NOT EXISTS "${file}" )
|
| 74 |
+
message(FATAL_ERROR "The imported target \"${target}\" references the file
|
| 75 |
+
\"${file}\"
|
| 76 |
+
but this file does not exist. Possible reasons include:
|
| 77 |
+
* The file was deleted, renamed, or moved to another location.
|
| 78 |
+
* An install or uninstall procedure did not complete successfully.
|
| 79 |
+
* The installation package was faulty and contained
|
| 80 |
+
\"${CMAKE_CURRENT_LIST_FILE}\"
|
| 81 |
+
but not all the files it references.
|
| 82 |
+
")
|
| 83 |
+
endif()
|
| 84 |
+
endforeach()
|
| 85 |
+
unset(_IMPORT_CHECK_FILES_FOR_${target})
|
| 86 |
+
endforeach()
|
| 87 |
+
unset(_IMPORT_CHECK_TARGETS)
|
| 88 |
+
|
| 89 |
+
# This file does not depend on other imported targets which have
|
| 90 |
+
# been exported from the same project but in a separate export set.
|
| 91 |
+
|
| 92 |
+
# Commands beyond this point should not need to know the version.
|
| 93 |
+
set(CMAKE_IMPORT_FILE_VERSION)
|
| 94 |
+
cmake_policy(POP)
|
cpp/{t2s.json → ax650/lib/libax_whisper.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2aafbad2ea23d93c226fdb3d7d22ccbad44813638c78222787c8de9f85ec358
|
| 3 |
+
size 624072
|
cpp/ax650/t2s.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "Traditional Chinese to Simplified Chinese",
|
| 3 |
+
"segmentation": {
|
| 4 |
+
"type": "mmseg",
|
| 5 |
+
"dict": {
|
| 6 |
+
"type": "ocd2",
|
| 7 |
+
"file": "TSPhrases.ocd2"
|
| 8 |
+
}
|
| 9 |
+
},
|
| 10 |
+
"conversion_chain": [{
|
| 11 |
+
"dict": {
|
| 12 |
+
"type": "group",
|
| 13 |
+
"dicts": [{
|
| 14 |
+
"type": "ocd2",
|
| 15 |
+
"file": "TSPhrases.ocd2"
|
| 16 |
+
}, {
|
| 17 |
+
"type": "ocd2",
|
| 18 |
+
"file": "TSCharacters.ocd2"
|
| 19 |
+
}]
|
| 20 |
+
}
|
| 21 |
+
}]
|
| 22 |
+
}
|
cpp/ax650/whisper_cli
ADDED
|
Binary file (93.8 kB). View file
|
|
|
cpp/{whisper_aarch64 → ax650/whisper_svr}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:400a719b950cf61863179204d509c4f785b75184af2415bca9ed04a0198bf363
|
| 3 |
+
size 531688
|
cpp/whisper_axcl_aarch64
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:47ad92295eecfcd230f420acbb5ecd86211c53735aa216d935b22e17d740cd09
|
| 3 |
-
size 1251248
|
|
|
|
|
|
|
|
|
|
|
|
cpp/whisper_axcl_x86
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f59c233fb5c7d935422b650b5cd968e4b40654d368c671e6551c8ea7dcc78cee
|
| 3 |
-
size 1212056
|
|
|
|
|
|
|
|
|
|
|
|
cpp/whisper_srv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:27f1a07763df0f8fc2a94cf32551daae4082111668ea5e938b93bdd6a76e3d29
|
| 3 |
-
size 1006728
|
|
|
|
|
|
|
|
|
|
|
|
python/assets/multilingual.tiktoken
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
python/languages.py
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
WHISPER_LANGUAGES = {
|
| 2 |
-
"en": "english",
|
| 3 |
-
"zh": "chinese",
|
| 4 |
-
"de": "german",
|
| 5 |
-
"es": "spanish",
|
| 6 |
-
"ru": "russian",
|
| 7 |
-
"ko": "korean",
|
| 8 |
-
"fr": "french",
|
| 9 |
-
"ja": "japanese",
|
| 10 |
-
"pt": "portuguese",
|
| 11 |
-
"tr": "turkish",
|
| 12 |
-
"pl": "polish",
|
| 13 |
-
"ca": "catalan",
|
| 14 |
-
"nl": "dutch",
|
| 15 |
-
"ar": "arabic",
|
| 16 |
-
"sv": "swedish",
|
| 17 |
-
"it": "italian",
|
| 18 |
-
"id": "indonesian",
|
| 19 |
-
"hi": "hindi",
|
| 20 |
-
"fi": "finnish",
|
| 21 |
-
"vi": "vietnamese",
|
| 22 |
-
"he": "hebrew",
|
| 23 |
-
"uk": "ukrainian",
|
| 24 |
-
"el": "greek",
|
| 25 |
-
"ms": "malay",
|
| 26 |
-
"cs": "czech",
|
| 27 |
-
"ro": "romanian",
|
| 28 |
-
"da": "danish",
|
| 29 |
-
"hu": "hungarian",
|
| 30 |
-
"ta": "tamil",
|
| 31 |
-
"no": "norwegian",
|
| 32 |
-
"th": "thai",
|
| 33 |
-
"ur": "urdu",
|
| 34 |
-
"hr": "croatian",
|
| 35 |
-
"bg": "bulgarian",
|
| 36 |
-
"lt": "lithuanian",
|
| 37 |
-
"la": "latin",
|
| 38 |
-
"mi": "maori",
|
| 39 |
-
"ml": "malayalam",
|
| 40 |
-
"cy": "welsh",
|
| 41 |
-
"sk": "slovak",
|
| 42 |
-
"te": "telugu",
|
| 43 |
-
"fa": "persian",
|
| 44 |
-
"lv": "latvian",
|
| 45 |
-
"bn": "bengali",
|
| 46 |
-
"sr": "serbian",
|
| 47 |
-
"az": "azerbaijani",
|
| 48 |
-
"sl": "slovenian",
|
| 49 |
-
"kn": "kannada",
|
| 50 |
-
"et": "estonian",
|
| 51 |
-
"mk": "macedonian",
|
| 52 |
-
"br": "breton",
|
| 53 |
-
"eu": "basque",
|
| 54 |
-
"is": "icelandic",
|
| 55 |
-
"hy": "armenian",
|
| 56 |
-
"ne": "nepali",
|
| 57 |
-
"mn": "mongolian",
|
| 58 |
-
"bs": "bosnian",
|
| 59 |
-
"kk": "kazakh",
|
| 60 |
-
"sq": "albanian",
|
| 61 |
-
"sw": "swahili",
|
| 62 |
-
"gl": "galician",
|
| 63 |
-
"mr": "marathi",
|
| 64 |
-
"pa": "punjabi",
|
| 65 |
-
"si": "sinhala",
|
| 66 |
-
"km": "khmer",
|
| 67 |
-
"sn": "shona",
|
| 68 |
-
"yo": "yoruba",
|
| 69 |
-
"so": "somali",
|
| 70 |
-
"af": "afrikaans",
|
| 71 |
-
"oc": "occitan",
|
| 72 |
-
"ka": "georgian",
|
| 73 |
-
"be": "belarusian",
|
| 74 |
-
"tg": "tajik",
|
| 75 |
-
"sd": "sindhi",
|
| 76 |
-
"gu": "gujarati",
|
| 77 |
-
"am": "amharic",
|
| 78 |
-
"yi": "yiddish",
|
| 79 |
-
"lo": "lao",
|
| 80 |
-
"uz": "uzbek",
|
| 81 |
-
"fo": "faroese",
|
| 82 |
-
"ht": "haitian creole",
|
| 83 |
-
"ps": "pashto",
|
| 84 |
-
"tk": "turkmen",
|
| 85 |
-
"nn": "nynorsk",
|
| 86 |
-
"mt": "maltese",
|
| 87 |
-
"sa": "sanskrit",
|
| 88 |
-
"lb": "luxembourgish",
|
| 89 |
-
"my": "myanmar",
|
| 90 |
-
"bo": "tibetan",
|
| 91 |
-
"tl": "tagalog",
|
| 92 |
-
"mg": "malagasy",
|
| 93 |
-
"as": "assamese",
|
| 94 |
-
"tt": "tatar",
|
| 95 |
-
"haw": "hawaiian",
|
| 96 |
-
"ln": "lingala",
|
| 97 |
-
"ha": "hausa",
|
| 98 |
-
"ba": "bashkir",
|
| 99 |
-
"jw": "javanese",
|
| 100 |
-
"su": "sundanese",
|
| 101 |
-
"yue": "cantonese",
|
| 102 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
python/main.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
from whisper import Whisper
|
| 4 |
-
import time
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def get_args():
|
| 8 |
-
parser = argparse.ArgumentParser(
|
| 9 |
-
prog="whisper", description="Run Whisper on input audio file"
|
| 10 |
-
)
|
| 11 |
-
parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
|
| 12 |
-
parser.add_argument(
|
| 13 |
-
"--model_type",
|
| 14 |
-
"-t",
|
| 15 |
-
type=str,
|
| 16 |
-
choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
|
| 17 |
-
required=True,
|
| 18 |
-
help="model type, only support tiny, base and small currently",
|
| 19 |
-
)
|
| 20 |
-
parser.add_argument(
|
| 21 |
-
"--model_path",
|
| 22 |
-
"-p",
|
| 23 |
-
type=str,
|
| 24 |
-
required=False,
|
| 25 |
-
default="../models-ax650",
|
| 26 |
-
help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
|
| 27 |
-
)
|
| 28 |
-
parser.add_argument(
|
| 29 |
-
"--language",
|
| 30 |
-
"-l",
|
| 31 |
-
type=str,
|
| 32 |
-
required=False,
|
| 33 |
-
default="zh",
|
| 34 |
-
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 35 |
-
)
|
| 36 |
-
parser.add_argument(
|
| 37 |
-
"--task",
|
| 38 |
-
type=str,
|
| 39 |
-
required=False,
|
| 40 |
-
choices=["translate", "transcribe"],
|
| 41 |
-
default="transcribe",
|
| 42 |
-
)
|
| 43 |
-
parser.add_argument(
|
| 44 |
-
"--print_rtf", action="store_true", help="Print Real-Time Factor"
|
| 45 |
-
)
|
| 46 |
-
return parser.parse_args()
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def main():
|
| 50 |
-
args = get_args()
|
| 51 |
-
print(vars(args))
|
| 52 |
-
|
| 53 |
-
# Check wav existence
|
| 54 |
-
wav_path = args.wav
|
| 55 |
-
assert os.path.exists(wav_path), f"{wav_path} NOT exist"
|
| 56 |
-
|
| 57 |
-
model = Whisper(args.model_type, args.model_path, args.language, args.task)
|
| 58 |
-
|
| 59 |
-
print("ASR result:")
|
| 60 |
-
start = time.time()
|
| 61 |
-
print(model.run(wav_path))
|
| 62 |
-
end = time.time()
|
| 63 |
-
|
| 64 |
-
if args.print_rtf:
|
| 65 |
-
import librosa
|
| 66 |
-
|
| 67 |
-
samples, sr = librosa.load(wav_path, sr=16000)
|
| 68 |
-
duration = len(samples) / sr
|
| 69 |
-
process_time = end - start
|
| 70 |
-
print(f"RTF: {process_time / duration}")
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
if __name__ == "__main__":
|
| 74 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
python/test_svr.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def transcribe_audio(
|
| 5 |
+
server_url: str,
|
| 6 |
+
wav_path: str,
|
| 7 |
+
model_type: str = "tiny",
|
| 8 |
+
model_path: str = "../models-ax650",
|
| 9 |
+
language: str = "zh",
|
| 10 |
+
task: str = "transcribe",
|
| 11 |
+
):
|
| 12 |
+
url = f"{server_url.rstrip('/')}/asr"
|
| 13 |
+
|
| 14 |
+
files = {
|
| 15 |
+
"wav": open(wav_path, "rb"),
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
data = {
|
| 19 |
+
"model_type": model_type,
|
| 20 |
+
"model_path": model_path,
|
| 21 |
+
"language": language,
|
| 22 |
+
"task": task,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
print(f"Sending request to: {url}")
|
| 26 |
+
|
| 27 |
+
response = requests.post(url, files=files, data=data)
|
| 28 |
+
if response.status_code != 200:
|
| 29 |
+
print("❌ Error:", response.text)
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
result = response.json()
|
| 33 |
+
print("服务器返回结果:")
|
| 34 |
+
print(result)
|
| 35 |
+
|
| 36 |
+
return result
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
# 你的服务器地址
|
| 41 |
+
SERVER = "http://127.0.0.1:8000"
|
| 42 |
+
|
| 43 |
+
# 本地 wav 文件路径
|
| 44 |
+
WAV = "../demo.wav"
|
| 45 |
+
|
| 46 |
+
transcribe_audio(SERVER, WAV)
|
python/test_wer.py
CHANGED
|
@@ -2,14 +2,19 @@ import argparse
|
|
| 2 |
import os
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
-
def setup_logging():
|
| 9 |
"""配置日志系统,同时输出到控制台和文件"""
|
| 10 |
# 获取脚本所在目录
|
| 11 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
-
log_file = os.path.join(script_dir, "
|
| 13 |
|
| 14 |
# 配置日志格式
|
| 15 |
log_format = "%(asctime)s - %(levelname)s - %(message)s"
|
|
@@ -24,7 +29,7 @@ def setup_logging():
|
|
| 24 |
logger.removeHandler(handler)
|
| 25 |
|
| 26 |
# 创建文件handler
|
| 27 |
-
file_handler = logging.FileHandler(log_file, mode="
|
| 28 |
file_handler.setLevel(logging.INFO)
|
| 29 |
file_formatter = logging.Formatter(log_format, date_format)
|
| 30 |
file_handler.setFormatter(file_formatter)
|
|
@@ -42,6 +47,61 @@ def setup_logging():
|
|
| 42 |
return logger
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
class AIShellDataset:
|
| 46 |
def __init__(self, gt_path: str):
|
| 47 |
"""
|
|
@@ -149,6 +209,56 @@ class CommonVoiceDataset:
|
|
| 149 |
return len(self.data)
|
| 150 |
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def get_args():
|
| 153 |
parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
|
| 154 |
parser.add_argument(
|
|
@@ -156,7 +266,7 @@ def get_args():
|
|
| 156 |
"-d",
|
| 157 |
type=str,
|
| 158 |
required=True,
|
| 159 |
-
choices=["aishell", "common_voice"],
|
| 160 |
help="Test dataset",
|
| 161 |
)
|
| 162 |
parser.add_argument(
|
|
@@ -173,7 +283,7 @@ def get_args():
|
|
| 173 |
"--model_type",
|
| 174 |
"-t",
|
| 175 |
type=str,
|
| 176 |
-
choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
|
| 177 |
required=True,
|
| 178 |
help="model type, only support tiny, base and small currently",
|
| 179 |
)
|
|
@@ -182,8 +292,11 @@ def get_args():
|
|
| 182 |
"-p",
|
| 183 |
type=str,
|
| 184 |
required=False,
|
| 185 |
-
default="../models
|
| 186 |
-
help="model path for *.axmodel, tokens.txt
|
|
|
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
parser.add_argument(
|
| 189 |
"--language",
|
|
@@ -193,17 +306,16 @@ def get_args():
|
|
| 193 |
default="zh",
|
| 194 |
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 195 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
return parser.parse_args()
|
| 197 |
|
| 198 |
|
| 199 |
def print_args(args):
|
| 200 |
logger = logging.getLogger()
|
| 201 |
-
logger.info(
|
| 202 |
-
logger.info(f"gt_path: {args.gt_path}")
|
| 203 |
-
logger.info(f"max_num: {args.max_num}")
|
| 204 |
-
logger.info(f"model_type: {args.model_type}")
|
| 205 |
-
logger.info(f"model_path: {args.model_path}")
|
| 206 |
-
logger.info(f"language: {args.language}")
|
| 207 |
|
| 208 |
|
| 209 |
def min_distance(word1: str, word2: str) -> int:
|
|
@@ -247,10 +359,10 @@ def remove_punctuation(text):
|
|
| 247 |
|
| 248 |
|
| 249 |
def main():
|
| 250 |
-
# 设置日志系统
|
| 251 |
-
logger = setup_logging()
|
| 252 |
-
|
| 253 |
args = get_args()
|
|
|
|
|
|
|
|
|
|
| 254 |
print_args(args)
|
| 255 |
|
| 256 |
dataset_type = args.dataset.lower()
|
|
@@ -258,26 +370,88 @@ def main():
|
|
| 258 |
dataset = AIShellDataset(args.gt_path)
|
| 259 |
elif dataset_type == "common_voice":
|
| 260 |
dataset = CommonVoiceDataset(args.gt_path)
|
|
|
|
|
|
|
| 261 |
else:
|
| 262 |
raise ValueError(f"Unknown dataset type {dataset_type}")
|
| 263 |
|
| 264 |
max_num = args.max_num
|
| 265 |
|
| 266 |
# Load model
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
# Iterate over dataset
|
| 270 |
references = []
|
| 271 |
hyp = []
|
| 272 |
all_character_error_num = 0
|
| 273 |
all_character_num = 0
|
| 274 |
-
wer_file = open("wer.txt", "w")
|
| 275 |
max_data_num = max_num if max_num > 0 else len(dataset)
|
| 276 |
for n, (audio_path, reference) in enumerate(dataset):
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
-
hypothesis = remove_punctuation(hypothesis)
|
| 280 |
-
reference = remove_punctuation(reference)
|
| 281 |
|
| 282 |
character_error_num = min_distance(reference, hypothesis)
|
| 283 |
character_num = len(reference)
|
|
@@ -290,7 +464,6 @@ def main():
|
|
| 290 |
references.append(reference)
|
| 291 |
|
| 292 |
line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
|
| 293 |
-
wer_file.write(line_content + "\n")
|
| 294 |
logger.info(line_content)
|
| 295 |
|
| 296 |
if n + 1 >= max_data_num:
|
|
@@ -299,8 +472,6 @@ def main():
|
|
| 299 |
total_character_error_rate = all_character_error_num / all_character_num * 100
|
| 300 |
|
| 301 |
logger.info(f"Total WER: {total_character_error_rate}%")
|
| 302 |
-
wer_file.write(f"Total WER: {total_character_error_rate}%")
|
| 303 |
-
wer_file.close()
|
| 304 |
|
| 305 |
|
| 306 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import os
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
import numpy as np
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import zhconv
|
| 10 |
+
import librosa
|
| 11 |
|
| 12 |
|
| 13 |
+
def setup_logging(filename):
|
| 14 |
"""配置日志系统,同时输出到控制台和文件"""
|
| 15 |
# 获取脚本所在目录
|
| 16 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
log_file = os.path.join(script_dir, f"{filename}.log")
|
| 18 |
|
| 19 |
# 配置日志格式
|
| 20 |
log_format = "%(asctime)s - %(levelname)s - %(message)s"
|
|
|
|
| 29 |
logger.removeHandler(handler)
|
| 30 |
|
| 31 |
# 创建文件handler
|
| 32 |
+
file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
|
| 33 |
file_handler.setLevel(logging.INFO)
|
| 34 |
file_formatter = logging.Formatter(log_format, date_format)
|
| 35 |
file_handler.setFormatter(file_formatter)
|
|
|
|
| 47 |
return logger
|
| 48 |
|
| 49 |
|
| 50 |
+
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
| 51 |
+
data, sample_rate = sf.read(
|
| 52 |
+
filename,
|
| 53 |
+
always_2d=True,
|
| 54 |
+
dtype="float32",
|
| 55 |
+
)
|
| 56 |
+
data = data[:, 0] # use only the first channel
|
| 57 |
+
if sample_rate != 16000:
|
| 58 |
+
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
|
| 59 |
+
sample_rate = 16000
|
| 60 |
+
samples = np.ascontiguousarray(data)
|
| 61 |
+
return samples, sample_rate
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def compute_feat(filename: str, n_mels: int = 80):
|
| 65 |
+
audio, sample_rate = load_audio(filename)
|
| 66 |
+
if sample_rate != 16000:
|
| 67 |
+
audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
|
| 68 |
+
sample_rate = 16000
|
| 69 |
+
|
| 70 |
+
mel = librosa.feature.melspectrogram(
|
| 71 |
+
y=audio,
|
| 72 |
+
sr=sample_rate,
|
| 73 |
+
n_fft=480,
|
| 74 |
+
hop_length=160,
|
| 75 |
+
window="hann",
|
| 76 |
+
center=True,
|
| 77 |
+
pad_mode="reflect",
|
| 78 |
+
power=2.0,
|
| 79 |
+
n_mels=n_mels,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
log_spec = np.log10(np.maximum(mel, 1e-10))
|
| 83 |
+
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
| 84 |
+
mel = (log_spec + 4.0) / 4.0
|
| 85 |
+
|
| 86 |
+
target = 3000
|
| 87 |
+
if mel.shape[1] > target:
|
| 88 |
+
# -50 so that there are some zero tail paddings.
|
| 89 |
+
mel = mel[:, :target]
|
| 90 |
+
mel[:, -50:] = 0
|
| 91 |
+
|
| 92 |
+
# We don't need to pad it to 30 seconds now!
|
| 93 |
+
if mel.shape[1] < target:
|
| 94 |
+
mel = np.concatenate(
|
| 95 |
+
(
|
| 96 |
+
mel,
|
| 97 |
+
np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32),
|
| 98 |
+
),
|
| 99 |
+
axis=-1,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
return mel[np.newaxis, ...]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
class AIShellDataset:
|
| 106 |
def __init__(self, gt_path: str):
|
| 107 |
"""
|
|
|
|
| 209 |
return len(self.data)
|
| 210 |
|
| 211 |
|
| 212 |
+
class CustomDataset:
|
| 213 |
+
"""自定义数据集解析器"""
|
| 214 |
+
|
| 215 |
+
def __init__(self, label_path: str):
|
| 216 |
+
"""
|
| 217 |
+
初始化数据集
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
self.label_path = label_path
|
| 221 |
+
self.dataset_dir = os.path.dirname(label_path)
|
| 222 |
+
|
| 223 |
+
# 检查必要文件和文件夹是否存在
|
| 224 |
+
assert os.path.exists(label_path), f"{label_path}文件不存在: {label_path}"
|
| 225 |
+
|
| 226 |
+
# 加载csv
|
| 227 |
+
self.data = []
|
| 228 |
+
df = pd.read_csv(label_path, sep="\t")
|
| 229 |
+
for i, row in df.iterrows():
|
| 230 |
+
audio_path = os.path.join(
|
| 231 |
+
self.dataset_dir, row["SPEAKER_ID"], row["UTTRANS_ID"]
|
| 232 |
+
)
|
| 233 |
+
gt = row["TRANSCRIPTION"]
|
| 234 |
+
self.data.append({"audio_path": audio_path, "gt": gt})
|
| 235 |
+
|
| 236 |
+
# 使用logging而不是print
|
| 237 |
+
logger = logging.getLogger()
|
| 238 |
+
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 239 |
+
|
| 240 |
+
def __iter__(self):
|
| 241 |
+
"""返回迭代器"""
|
| 242 |
+
self.index = 0
|
| 243 |
+
return self
|
| 244 |
+
|
| 245 |
+
def __next__(self):
|
| 246 |
+
"""返回下一个数据项"""
|
| 247 |
+
if self.index >= len(self.data):
|
| 248 |
+
raise StopIteration
|
| 249 |
+
|
| 250 |
+
item = self.data[self.index]
|
| 251 |
+
audio_path = item["audio_path"]
|
| 252 |
+
ground_truth = item["gt"]
|
| 253 |
+
|
| 254 |
+
self.index += 1
|
| 255 |
+
return audio_path, ground_truth
|
| 256 |
+
|
| 257 |
+
def __len__(self):
|
| 258 |
+
"""返回数据集大小"""
|
| 259 |
+
return len(self.data)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
def get_args():
|
| 263 |
parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
|
| 264 |
parser.add_argument(
|
|
|
|
| 266 |
"-d",
|
| 267 |
type=str,
|
| 268 |
required=True,
|
| 269 |
+
choices=["aishell", "common_voice", "custom"],
|
| 270 |
help="Test dataset",
|
| 271 |
)
|
| 272 |
parser.add_argument(
|
|
|
|
| 283 |
"--model_type",
|
| 284 |
"-t",
|
| 285 |
type=str,
|
| 286 |
+
choices=["tiny", "base", "small", "medium", "large", "large-v3", "turbo"],
|
| 287 |
required=True,
|
| 288 |
help="model type, only support tiny, base and small currently",
|
| 289 |
)
|
|
|
|
| 292 |
"-p",
|
| 293 |
type=str,
|
| 294 |
required=False,
|
| 295 |
+
default="../models-ax650",
|
| 296 |
+
help="model path for *.axmodel, tokens.txt",
|
| 297 |
+
)
|
| 298 |
+
parser.add_argument(
|
| 299 |
+
"--repo_id", type=str, default=None, help="repo id from huggingface"
|
| 300 |
)
|
| 301 |
parser.add_argument(
|
| 302 |
"--language",
|
|
|
|
| 306 |
default="zh",
|
| 307 |
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 308 |
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--backend", type=str, default="ax", choices=["ax", "torch", "onnx"]
|
| 311 |
+
)
|
| 312 |
+
parser.add_argument("--log_name", type=str, default="test_wer")
|
| 313 |
return parser.parse_args()
|
| 314 |
|
| 315 |
|
| 316 |
def print_args(args):
|
| 317 |
logger = logging.getLogger()
|
| 318 |
+
logger.info(vars(args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
|
| 321 |
def min_distance(word1: str, word2: str) -> int:
|
|
|
|
| 359 |
|
| 360 |
|
| 361 |
def main():
|
|
|
|
|
|
|
|
|
|
| 362 |
args = get_args()
|
| 363 |
+
|
| 364 |
+
# 设置日志系统
|
| 365 |
+
logger = setup_logging(args.log_name)
|
| 366 |
print_args(args)
|
| 367 |
|
| 368 |
dataset_type = args.dataset.lower()
|
|
|
|
| 370 |
dataset = AIShellDataset(args.gt_path)
|
| 371 |
elif dataset_type == "common_voice":
|
| 372 |
dataset = CommonVoiceDataset(args.gt_path)
|
| 373 |
+
elif dataset_type == "custom":
|
| 374 |
+
dataset = CustomDataset(args.gt_path)
|
| 375 |
else:
|
| 376 |
raise ValueError(f"Unknown dataset type {dataset_type}")
|
| 377 |
|
| 378 |
max_num = args.max_num
|
| 379 |
|
| 380 |
# Load model
|
| 381 |
+
use_hf_model = False
|
| 382 |
+
tokenizer = None
|
| 383 |
+
task = "transcribe"
|
| 384 |
+
|
| 385 |
+
if args.backend == "ax":
|
| 386 |
+
from whisper_ax import Whisper
|
| 387 |
+
|
| 388 |
+
model = Whisper(args.model_type, args.model_path, args.language, task)
|
| 389 |
+
elif args.backend == "torch":
|
| 390 |
+
if args.repo_id is not None:
|
| 391 |
+
use_hf_model = True
|
| 392 |
+
|
| 393 |
+
from transformers import WhisperForConditionalGeneration
|
| 394 |
+
import torch
|
| 395 |
+
|
| 396 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
| 397 |
+
args.repo_id,
|
| 398 |
+
dtype=torch.float32,
|
| 399 |
+
).cpu()
|
| 400 |
+
else:
|
| 401 |
+
import whisper
|
| 402 |
+
|
| 403 |
+
model = whisper.load_model(args.model_type).cpu()
|
| 404 |
+
|
| 405 |
+
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
|
| 406 |
+
elif args.backend == "onnx":
|
| 407 |
+
import onnxruntime as ort
|
| 408 |
+
from ..model_convert.generate_data import OnnxModel
|
| 409 |
+
|
| 410 |
+
encoder_path = os.path.join(
|
| 411 |
+
args.model_path, f"{args.model_type}/{args.model_type}-encoder.onnx"
|
| 412 |
+
)
|
| 413 |
+
decoder_path = os.path.join(
|
| 414 |
+
args.model_path, f"{args.model_type}/{args.model_type}-decoder.onnx"
|
| 415 |
+
)
|
| 416 |
+
model = OnnxModel(encoder_path, decoder_path)
|
| 417 |
|
| 418 |
# Iterate over dataset
|
| 419 |
references = []
|
| 420 |
hyp = []
|
| 421 |
all_character_error_num = 0
|
| 422 |
all_character_num = 0
|
|
|
|
| 423 |
max_data_num = max_num if max_num > 0 else len(dataset)
|
| 424 |
for n, (audio_path, reference) in enumerate(dataset):
|
| 425 |
+
if args.backend == "ax":
|
| 426 |
+
hypothesis = model.run(audio_path)
|
| 427 |
+
elif args.backend == "torch":
|
| 428 |
+
if use_hf_model:
|
| 429 |
+
with torch.no_grad():
|
| 430 |
+
feature = compute_feat(audio_path, model.config.num_mel_bins)
|
| 431 |
+
r = model.generate(
|
| 432 |
+
torch.from_numpy(feature),
|
| 433 |
+
output_scores=True,
|
| 434 |
+
return_dict_in_generate=True,
|
| 435 |
+
return_timestamps=False,
|
| 436 |
+
language=args.language,
|
| 437 |
+
task="transcribe",
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
tokens = r["sequences"][0][4:-1]
|
| 441 |
+
hypothesis = "".join(tokenizer.decode(tokens)).strip()
|
| 442 |
+
else:
|
| 443 |
+
result = model.transcribe(
|
| 444 |
+
audio_path, fp16=False, language=args.language
|
| 445 |
+
)
|
| 446 |
+
hypothesis = result["text"]
|
| 447 |
+
if args.language == "zh":
|
| 448 |
+
hypothesis = zhconv.convert(hypothesis, "zh-hans")
|
| 449 |
+
|
| 450 |
+
elif args.backend == "onnx":
|
| 451 |
+
hypothesis = model.run(audio_path, args.language, task)
|
| 452 |
|
| 453 |
+
hypothesis = remove_punctuation(hypothesis).lower()
|
| 454 |
+
reference = remove_punctuation(reference).lower()
|
| 455 |
|
| 456 |
character_error_num = min_distance(reference, hypothesis)
|
| 457 |
character_num = len(reference)
|
|
|
|
| 464 |
references.append(reference)
|
| 465 |
|
| 466 |
line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
|
|
|
|
| 467 |
logger.info(line_content)
|
| 468 |
|
| 469 |
if n + 1 >= max_data_num:
|
|
|
|
| 472 |
total_character_error_rate = all_character_error_num / all_character_num * 100
|
| 473 |
|
| 474 |
logger.info(f"Total WER: {total_character_error_rate}%")
|
|
|
|
|
|
|
| 475 |
|
| 476 |
|
| 477 |
if __name__ == "__main__":
|
python/{whisper.py → whisper_ax.py}
RENAMED
|
@@ -2,11 +2,11 @@ import axengine as axe
|
|
| 2 |
import numpy as np
|
| 3 |
import librosa
|
| 4 |
import os
|
| 5 |
-
from typing import Union
|
| 6 |
-
from whisper_tokenizer import *
|
| 7 |
import json
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
import zhconv
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
|
@@ -34,11 +34,9 @@ class WhisperConfig:
|
|
| 34 |
|
| 35 |
class Whisper:
|
| 36 |
def __init__(self, model_type: str, model_path: str, language: str, task: str):
|
| 37 |
-
assert task in ["translate", "transcribe"]
|
| 38 |
-
|
| 39 |
self.language = language
|
| 40 |
self.task = task
|
| 41 |
-
self.encoder, self.decoder,
|
| 42 |
model_type, model_path, language, task
|
| 43 |
)
|
| 44 |
self.config = self.load_config(model_config)
|
|
@@ -73,16 +71,20 @@ class Whisper:
|
|
| 73 |
model_config["all_language_codes"] = [
|
| 74 |
i for i in model_config["all_language_codes"].split(",")
|
| 75 |
]
|
| 76 |
-
tokenizer = get_tokenizer(
|
| 77 |
-
model_config["is_multilingual"],
|
| 78 |
-
num_languages=len(model_config["all_language_codes"]),
|
| 79 |
-
language=language,
|
| 80 |
-
task=task,
|
| 81 |
-
)
|
| 82 |
|
| 83 |
self.id2token = self.load_tokens(required_files[3])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
return encoder, decoder,
|
| 86 |
|
| 87 |
def load_config(self, model_config):
|
| 88 |
config = WhisperConfig
|
|
@@ -109,6 +111,7 @@ class Whisper:
|
|
| 109 |
task_token = (
|
| 110 |
config.transcribe if self.task == "transcribe" else config.translate
|
| 111 |
)
|
|
|
|
| 112 |
config.sot_sequence = np.array(
|
| 113 |
[config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
|
| 114 |
)
|
|
@@ -124,9 +127,14 @@ class Whisper:
|
|
| 124 |
return tokens
|
| 125 |
|
| 126 |
def load_audio(self, audio: str):
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def compute_feature(self, audio: np.ndarray):
|
| 132 |
mel = librosa.feature.melspectrogram(
|
|
@@ -189,19 +197,27 @@ class Whisper:
|
|
| 189 |
return out
|
| 190 |
|
| 191 |
def get_self_cache(self) -> List[np.ndarray]:
|
| 192 |
-
self_cache = []
|
| 193 |
batch_size = 1
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
def causal_mask_1d(self, n: int, L: int):
|
| 207 |
"""
|
|
@@ -214,47 +230,46 @@ class Whisper:
|
|
| 214 |
mask[:n] = 0
|
| 215 |
return mask
|
| 216 |
|
| 217 |
-
def
|
| 218 |
-
|
| 219 |
-
audio, sample_rate = self.load_audio(audio)
|
| 220 |
-
|
| 221 |
-
mel = self.compute_feature(audio)
|
| 222 |
-
|
| 223 |
-
cross_kv = self.run_encoder(mel)
|
| 224 |
|
| 225 |
-
|
| 226 |
|
| 227 |
offset = np.array([0], dtype=np.int32)
|
| 228 |
for t in self.config.sot_sequence:
|
| 229 |
token = np.array([[t]], dtype=np.int32) # sot
|
| 230 |
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 231 |
|
| 232 |
-
|
|
|
|
|
|
|
| 233 |
|
| 234 |
-
|
| 235 |
-
|
| 236 |
|
| 237 |
offset += 1
|
| 238 |
|
| 239 |
-
idx =
|
| 240 |
|
| 241 |
eot = self.config.eot
|
| 242 |
|
| 243 |
ans = []
|
| 244 |
|
| 245 |
-
while idx != eot and offset.item() <
|
| 246 |
ans.append(idx)
|
| 247 |
token = np.array([[idx]], dtype=np.int32)
|
| 248 |
|
| 249 |
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 250 |
|
| 251 |
-
|
|
|
|
|
|
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
|
| 256 |
offset += 1
|
| 257 |
-
idx =
|
| 258 |
|
| 259 |
# print(ans)
|
| 260 |
|
|
@@ -273,3 +288,19 @@ class Whisper:
|
|
| 273 |
return text
|
| 274 |
|
| 275 |
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import librosa
|
| 4 |
import os
|
| 5 |
+
from typing import Union, List
|
|
|
|
| 6 |
import json
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
import zhconv
|
| 9 |
+
import base64
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
|
|
|
| 34 |
|
| 35 |
class Whisper:
|
| 36 |
def __init__(self, model_type: str, model_path: str, language: str, task: str):
|
|
|
|
|
|
|
| 37 |
self.language = language
|
| 38 |
self.task = task
|
| 39 |
+
self.encoder, self.decoder, model_config = self.load_model(
|
| 40 |
model_type, model_path, language, task
|
| 41 |
)
|
| 42 |
self.config = self.load_config(model_config)
|
|
|
|
| 71 |
model_config["all_language_codes"] = [
|
| 72 |
i for i in model_config["all_language_codes"].split(",")
|
| 73 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
self.id2token = self.load_tokens(required_files[3])
|
| 76 |
+
self.lang2token = {
|
| 77 |
+
k: v
|
| 78 |
+
for k, v in zip(
|
| 79 |
+
model_config["all_language_codes"], model_config["all_language_tokens"]
|
| 80 |
+
)
|
| 81 |
+
}
|
| 82 |
+
self.task2token = {
|
| 83 |
+
"transcribe": model_config["transcribe"],
|
| 84 |
+
"translate": model_config["translate"],
|
| 85 |
+
}
|
| 86 |
|
| 87 |
+
return encoder, decoder, model_config
|
| 88 |
|
| 89 |
def load_config(self, model_config):
|
| 90 |
config = WhisperConfig
|
|
|
|
| 111 |
task_token = (
|
| 112 |
config.transcribe if self.task == "transcribe" else config.translate
|
| 113 |
)
|
| 114 |
+
|
| 115 |
config.sot_sequence = np.array(
|
| 116 |
[config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
|
| 117 |
)
|
|
|
|
| 127 |
return tokens
|
| 128 |
|
| 129 |
def load_audio(self, audio: str):
|
| 130 |
+
samples, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
|
| 131 |
+
if sample_rate != self.config.sample_rate:
|
| 132 |
+
samples = librosa.resample(
|
| 133 |
+
samples, orig_sr=sample_rate, target_sr=self.config.sample_rate
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
samples = np.ascontiguousarray(samples)
|
| 137 |
+
return samples, self.config.sample_rate
|
| 138 |
|
| 139 |
def compute_feature(self, audio: np.ndarray):
|
| 140 |
mel = librosa.feature.melspectrogram(
|
|
|
|
| 197 |
return out
|
| 198 |
|
| 199 |
def get_self_cache(self) -> List[np.ndarray]:
|
|
|
|
| 200 |
batch_size = 1
|
| 201 |
+
|
| 202 |
+
self_k = np.zeros(
|
| 203 |
+
(
|
| 204 |
+
self.config.n_text_layer,
|
| 205 |
+
batch_size,
|
| 206 |
+
self.config.n_text_ctx,
|
| 207 |
+
self.config.n_text_state,
|
| 208 |
+
),
|
| 209 |
+
dtype=np.float32,
|
| 210 |
+
)
|
| 211 |
+
self_v = np.zeros(
|
| 212 |
+
(
|
| 213 |
+
self.config.n_text_layer,
|
| 214 |
+
batch_size,
|
| 215 |
+
self.config.n_text_ctx,
|
| 216 |
+
self.config.n_text_state,
|
| 217 |
+
),
|
| 218 |
+
dtype=np.float32,
|
| 219 |
+
)
|
| 220 |
+
return self_k, self_v
|
| 221 |
|
| 222 |
def causal_mask_1d(self, n: int, L: int):
|
| 223 |
"""
|
|
|
|
| 230 |
mask[:n] = 0
|
| 231 |
return mask
|
| 232 |
|
| 233 |
+
def run_mel(self, mel):
|
| 234 |
+
cross_k, cross_v = self.run_encoder(mel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
+
self_k, self_v = self.get_self_cache()
|
| 237 |
|
| 238 |
offset = np.array([0], dtype=np.int32)
|
| 239 |
for t in self.config.sot_sequence:
|
| 240 |
token = np.array([[t]], dtype=np.int32) # sot
|
| 241 |
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 242 |
|
| 243 |
+
logits, this_self_k, this_self_v = self.run_decoder(
|
| 244 |
+
[token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
|
| 245 |
+
)
|
| 246 |
|
| 247 |
+
self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
|
| 248 |
+
self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v
|
| 249 |
|
| 250 |
offset += 1
|
| 251 |
|
| 252 |
+
idx = logits[0, 0].argmax()
|
| 253 |
|
| 254 |
eot = self.config.eot
|
| 255 |
|
| 256 |
ans = []
|
| 257 |
|
| 258 |
+
while idx != eot and offset.item() < self.config.n_text_ctx:
|
| 259 |
ans.append(idx)
|
| 260 |
token = np.array([[idx]], dtype=np.int32)
|
| 261 |
|
| 262 |
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 263 |
|
| 264 |
+
logits, this_self_k, this_self_v = self.run_decoder(
|
| 265 |
+
[token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
|
| 266 |
+
)
|
| 267 |
|
| 268 |
+
self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
|
| 269 |
+
self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v
|
| 270 |
|
| 271 |
offset += 1
|
| 272 |
+
idx = logits[0, 0].argmax()
|
| 273 |
|
| 274 |
# print(ans)
|
| 275 |
|
|
|
|
| 288 |
return text
|
| 289 |
|
| 290 |
return text
|
| 291 |
+
|
| 292 |
+
def run(
|
| 293 |
+
self, audio: Union[str, np.ndarray], language: str = None, task: str = None
|
| 294 |
+
) -> str:
|
| 295 |
+
if isinstance(audio, str):
|
| 296 |
+
audio, sample_rate = self.load_audio(audio)
|
| 297 |
+
|
| 298 |
+
mel = self.compute_feature(audio)
|
| 299 |
+
|
| 300 |
+
if language is not None and self.language != language:
|
| 301 |
+
self.config.sot_sequence[1] = self.lang2token(language)
|
| 302 |
+
|
| 303 |
+
if task is not None and self.task != task:
|
| 304 |
+
self.config.sot_sequence[2] = self.task2token(task)
|
| 305 |
+
|
| 306 |
+
return self.run_mel(mel)
|
python/whisper_cli.py
CHANGED
|
@@ -1,46 +1,70 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
-
def
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
files = {
|
| 15 |
-
"wav": open(wav_path, "rb"),
|
| 16 |
-
}
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
"language": language,
|
| 22 |
-
"task": task,
|
| 23 |
-
}
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
if response.status_code != 200:
|
| 29 |
-
print("❌ Error:", response.text)
|
| 30 |
-
return None
|
| 31 |
|
| 32 |
-
result
|
| 33 |
-
|
| 34 |
-
print(
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
if __name__ == "__main__":
|
| 40 |
-
# 你的服务器地址
|
| 41 |
-
SERVER = "http://127.0.0.1:8000"
|
| 42 |
-
|
| 43 |
-
# 本地 wav 文件路径
|
| 44 |
-
WAV = "../demo.wav"
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from whisper_ax import Whisper
|
| 4 |
+
import time
|
| 5 |
|
| 6 |
|
| 7 |
+
def get_args():
|
| 8 |
+
parser = argparse.ArgumentParser(
|
| 9 |
+
prog="whisper", description="Run Whisper on input audio file"
|
| 10 |
+
)
|
| 11 |
+
parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
|
| 12 |
+
parser.add_argument(
|
| 13 |
+
"--model_type",
|
| 14 |
+
"-t",
|
| 15 |
+
type=str,
|
| 16 |
+
choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
|
| 17 |
+
required=True,
|
| 18 |
+
help="model type, only support tiny, base and small currently",
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--model_path",
|
| 22 |
+
"-p",
|
| 23 |
+
type=str,
|
| 24 |
+
required=False,
|
| 25 |
+
default="../models-ax650",
|
| 26 |
+
help="model path for *.axmodel, tokens.txt",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--language",
|
| 30 |
+
"-l",
|
| 31 |
+
type=str,
|
| 32 |
+
required=False,
|
| 33 |
+
default="zh",
|
| 34 |
+
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--task",
|
| 38 |
+
type=str,
|
| 39 |
+
required=False,
|
| 40 |
+
choices=["translate", "transcribe"],
|
| 41 |
+
default="transcribe",
|
| 42 |
+
)
|
| 43 |
+
return parser.parse_args()
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
def main():
|
| 47 |
+
args = get_args()
|
| 48 |
+
print(vars(args))
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
# Check wav existence
|
| 51 |
+
wav_path = args.wav
|
| 52 |
+
assert os.path.exists(wav_path), f"{wav_path} NOT exist"
|
| 53 |
|
| 54 |
+
model = Whisper(args.model_type, args.model_path, args.language, args.task)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
print("ASR result:")
|
| 57 |
+
start = time.time()
|
| 58 |
+
print(model.run(wav_path))
|
| 59 |
+
end = time.time()
|
| 60 |
|
| 61 |
+
import librosa
|
| 62 |
|
| 63 |
+
samples, sr = librosa.load(wav_path, sr=16000)
|
| 64 |
+
duration = len(samples) / sr
|
| 65 |
+
process_time = end - start
|
| 66 |
+
print(f"RTF: {process_time / duration}")
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
main()
|
python/whisper_svr.py
CHANGED
|
@@ -5,7 +5,7 @@ import tempfile
|
|
| 5 |
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 6 |
from urllib.parse import parse_qs
|
| 7 |
|
| 8 |
-
from
|
| 9 |
import cgi
|
| 10 |
|
| 11 |
|
|
|
|
| 5 |
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 6 |
from urllib.parse import parse_qs
|
| 7 |
|
| 8 |
+
from whisper_ax import Whisper
|
| 9 |
import cgi
|
| 10 |
|
| 11 |
|
python/whisper_tokenizer.py
DELETED
|
@@ -1,395 +0,0 @@
|
|
| 1 |
-
import base64
|
| 2 |
-
import os
|
| 3 |
-
import string
|
| 4 |
-
from dataclasses import dataclass, field
|
| 5 |
-
from functools import cached_property, lru_cache
|
| 6 |
-
from typing import Dict, List, Optional, Tuple
|
| 7 |
-
|
| 8 |
-
import tiktoken
|
| 9 |
-
|
| 10 |
-
LANGUAGES = {
|
| 11 |
-
"en": "english",
|
| 12 |
-
"zh": "chinese",
|
| 13 |
-
"de": "german",
|
| 14 |
-
"es": "spanish",
|
| 15 |
-
"ru": "russian",
|
| 16 |
-
"ko": "korean",
|
| 17 |
-
"fr": "french",
|
| 18 |
-
"ja": "japanese",
|
| 19 |
-
"pt": "portuguese",
|
| 20 |
-
"tr": "turkish",
|
| 21 |
-
"pl": "polish",
|
| 22 |
-
"ca": "catalan",
|
| 23 |
-
"nl": "dutch",
|
| 24 |
-
"ar": "arabic",
|
| 25 |
-
"sv": "swedish",
|
| 26 |
-
"it": "italian",
|
| 27 |
-
"id": "indonesian",
|
| 28 |
-
"hi": "hindi",
|
| 29 |
-
"fi": "finnish",
|
| 30 |
-
"vi": "vietnamese",
|
| 31 |
-
"he": "hebrew",
|
| 32 |
-
"uk": "ukrainian",
|
| 33 |
-
"el": "greek",
|
| 34 |
-
"ms": "malay",
|
| 35 |
-
"cs": "czech",
|
| 36 |
-
"ro": "romanian",
|
| 37 |
-
"da": "danish",
|
| 38 |
-
"hu": "hungarian",
|
| 39 |
-
"ta": "tamil",
|
| 40 |
-
"no": "norwegian",
|
| 41 |
-
"th": "thai",
|
| 42 |
-
"ur": "urdu",
|
| 43 |
-
"hr": "croatian",
|
| 44 |
-
"bg": "bulgarian",
|
| 45 |
-
"lt": "lithuanian",
|
| 46 |
-
"la": "latin",
|
| 47 |
-
"mi": "maori",
|
| 48 |
-
"ml": "malayalam",
|
| 49 |
-
"cy": "welsh",
|
| 50 |
-
"sk": "slovak",
|
| 51 |
-
"te": "telugu",
|
| 52 |
-
"fa": "persian",
|
| 53 |
-
"lv": "latvian",
|
| 54 |
-
"bn": "bengali",
|
| 55 |
-
"sr": "serbian",
|
| 56 |
-
"az": "azerbaijani",
|
| 57 |
-
"sl": "slovenian",
|
| 58 |
-
"kn": "kannada",
|
| 59 |
-
"et": "estonian",
|
| 60 |
-
"mk": "macedonian",
|
| 61 |
-
"br": "breton",
|
| 62 |
-
"eu": "basque",
|
| 63 |
-
"is": "icelandic",
|
| 64 |
-
"hy": "armenian",
|
| 65 |
-
"ne": "nepali",
|
| 66 |
-
"mn": "mongolian",
|
| 67 |
-
"bs": "bosnian",
|
| 68 |
-
"kk": "kazakh",
|
| 69 |
-
"sq": "albanian",
|
| 70 |
-
"sw": "swahili",
|
| 71 |
-
"gl": "galician",
|
| 72 |
-
"mr": "marathi",
|
| 73 |
-
"pa": "punjabi",
|
| 74 |
-
"si": "sinhala",
|
| 75 |
-
"km": "khmer",
|
| 76 |
-
"sn": "shona",
|
| 77 |
-
"yo": "yoruba",
|
| 78 |
-
"so": "somali",
|
| 79 |
-
"af": "afrikaans",
|
| 80 |
-
"oc": "occitan",
|
| 81 |
-
"ka": "georgian",
|
| 82 |
-
"be": "belarusian",
|
| 83 |
-
"tg": "tajik",
|
| 84 |
-
"sd": "sindhi",
|
| 85 |
-
"gu": "gujarati",
|
| 86 |
-
"am": "amharic",
|
| 87 |
-
"yi": "yiddish",
|
| 88 |
-
"lo": "lao",
|
| 89 |
-
"uz": "uzbek",
|
| 90 |
-
"fo": "faroese",
|
| 91 |
-
"ht": "haitian creole",
|
| 92 |
-
"ps": "pashto",
|
| 93 |
-
"tk": "turkmen",
|
| 94 |
-
"nn": "nynorsk",
|
| 95 |
-
"mt": "maltese",
|
| 96 |
-
"sa": "sanskrit",
|
| 97 |
-
"lb": "luxembourgish",
|
| 98 |
-
"my": "myanmar",
|
| 99 |
-
"bo": "tibetan",
|
| 100 |
-
"tl": "tagalog",
|
| 101 |
-
"mg": "malagasy",
|
| 102 |
-
"as": "assamese",
|
| 103 |
-
"tt": "tatar",
|
| 104 |
-
"haw": "hawaiian",
|
| 105 |
-
"ln": "lingala",
|
| 106 |
-
"ha": "hausa",
|
| 107 |
-
"ba": "bashkir",
|
| 108 |
-
"jw": "javanese",
|
| 109 |
-
"su": "sundanese",
|
| 110 |
-
"yue": "cantonese",
|
| 111 |
-
}
|
| 112 |
-
|
| 113 |
-
# language code lookup by name, with a few language aliases
|
| 114 |
-
TO_LANGUAGE_CODE = {
|
| 115 |
-
**{language: code for code, language in LANGUAGES.items()},
|
| 116 |
-
"burmese": "my",
|
| 117 |
-
"valencian": "ca",
|
| 118 |
-
"flemish": "nl",
|
| 119 |
-
"haitian": "ht",
|
| 120 |
-
"letzeburgesch": "lb",
|
| 121 |
-
"pushto": "ps",
|
| 122 |
-
"panjabi": "pa",
|
| 123 |
-
"moldavian": "ro",
|
| 124 |
-
"moldovan": "ro",
|
| 125 |
-
"sinhalese": "si",
|
| 126 |
-
"castilian": "es",
|
| 127 |
-
"mandarin": "zh",
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
@dataclass
|
| 132 |
-
class Tokenizer:
|
| 133 |
-
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
| 134 |
-
|
| 135 |
-
encoding: tiktoken.Encoding
|
| 136 |
-
num_languages: int
|
| 137 |
-
language: Optional[str] = None
|
| 138 |
-
task: Optional[str] = None
|
| 139 |
-
sot_sequence: Tuple[int] = ()
|
| 140 |
-
special_tokens: Dict[str, int] = field(default_factory=dict)
|
| 141 |
-
|
| 142 |
-
def __post_init__(self):
|
| 143 |
-
for special in self.encoding.special_tokens_set:
|
| 144 |
-
special_token = self.encoding.encode_single_token(special)
|
| 145 |
-
self.special_tokens[special] = special_token
|
| 146 |
-
|
| 147 |
-
sot: int = self.special_tokens["<|startoftranscript|>"]
|
| 148 |
-
translate: int = self.special_tokens["<|translate|>"]
|
| 149 |
-
transcribe: int = self.special_tokens["<|transcribe|>"]
|
| 150 |
-
|
| 151 |
-
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
| 152 |
-
sot_sequence = [sot]
|
| 153 |
-
if self.language is not None:
|
| 154 |
-
sot_sequence.append(sot + 1 + langs.index(self.language))
|
| 155 |
-
if self.task is not None:
|
| 156 |
-
task_token: int = transcribe if self.task == "transcribe" else translate
|
| 157 |
-
sot_sequence.append(task_token)
|
| 158 |
-
|
| 159 |
-
self.sot_sequence = tuple(sot_sequence)
|
| 160 |
-
|
| 161 |
-
def encode(self, text, **kwargs):
|
| 162 |
-
return self.encoding.encode(text, **kwargs)
|
| 163 |
-
|
| 164 |
-
def decode(self, token_ids: List[int], **kwargs) -> str:
|
| 165 |
-
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
| 166 |
-
return self.encoding.decode(token_ids, **kwargs)
|
| 167 |
-
|
| 168 |
-
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
| 169 |
-
"""
|
| 170 |
-
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
| 171 |
-
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
| 172 |
-
"""
|
| 173 |
-
return self.encoding.decode(token_ids, **kwargs)
|
| 174 |
-
|
| 175 |
-
@cached_property
|
| 176 |
-
def eot(self) -> int:
|
| 177 |
-
return self.encoding.eot_token
|
| 178 |
-
|
| 179 |
-
@cached_property
|
| 180 |
-
def transcribe(self) -> int:
|
| 181 |
-
return self.special_tokens["<|transcribe|>"]
|
| 182 |
-
|
| 183 |
-
@cached_property
|
| 184 |
-
def translate(self) -> int:
|
| 185 |
-
return self.special_tokens["<|translate|>"]
|
| 186 |
-
|
| 187 |
-
@cached_property
|
| 188 |
-
def sot(self) -> int:
|
| 189 |
-
return self.special_tokens["<|startoftranscript|>"]
|
| 190 |
-
|
| 191 |
-
@cached_property
|
| 192 |
-
def sot_lm(self) -> int:
|
| 193 |
-
return self.special_tokens["<|startoflm|>"]
|
| 194 |
-
|
| 195 |
-
@cached_property
|
| 196 |
-
def sot_prev(self) -> int:
|
| 197 |
-
return self.special_tokens["<|startofprev|>"]
|
| 198 |
-
|
| 199 |
-
@cached_property
|
| 200 |
-
def no_speech(self) -> int:
|
| 201 |
-
return self.special_tokens["<|nospeech|>"]
|
| 202 |
-
|
| 203 |
-
@cached_property
|
| 204 |
-
def no_timestamps(self) -> int:
|
| 205 |
-
return self.special_tokens["<|notimestamps|>"]
|
| 206 |
-
|
| 207 |
-
@cached_property
|
| 208 |
-
def timestamp_begin(self) -> int:
|
| 209 |
-
return self.special_tokens["<|0.00|>"]
|
| 210 |
-
|
| 211 |
-
@cached_property
|
| 212 |
-
def language_token(self) -> int:
|
| 213 |
-
"""Returns the token id corresponding to the value of the `language` field"""
|
| 214 |
-
if self.language is None:
|
| 215 |
-
raise ValueError("This tokenizer does not have language token configured")
|
| 216 |
-
|
| 217 |
-
return self.to_language_token(self.language)
|
| 218 |
-
|
| 219 |
-
def to_language_token(self, language):
|
| 220 |
-
if token := self.special_tokens.get(f"<|{language}|>", None):
|
| 221 |
-
return token
|
| 222 |
-
|
| 223 |
-
raise KeyError(f"Language {language} not found in tokenizer.")
|
| 224 |
-
|
| 225 |
-
@cached_property
|
| 226 |
-
def all_language_tokens(self) -> Tuple[int]:
|
| 227 |
-
result = []
|
| 228 |
-
for token, token_id in self.special_tokens.items():
|
| 229 |
-
if token.strip("<|>") in LANGUAGES:
|
| 230 |
-
result.append(token_id)
|
| 231 |
-
return tuple(result)[: self.num_languages]
|
| 232 |
-
|
| 233 |
-
@cached_property
|
| 234 |
-
def all_language_codes(self) -> Tuple[str]:
|
| 235 |
-
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
| 236 |
-
|
| 237 |
-
@cached_property
|
| 238 |
-
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
| 239 |
-
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
| 240 |
-
|
| 241 |
-
@cached_property
|
| 242 |
-
def non_speech_tokens(self) -> Tuple[int]:
|
| 243 |
-
"""
|
| 244 |
-
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
| 245 |
-
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
| 246 |
-
|
| 247 |
-
- ♪♪♪
|
| 248 |
-
- ( SPEAKING FOREIGN LANGUAGE )
|
| 249 |
-
- [DAVID] Hey there,
|
| 250 |
-
|
| 251 |
-
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
| 252 |
-
"""
|
| 253 |
-
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
| 254 |
-
symbols += (
|
| 255 |
-
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
| 259 |
-
# In case they're multiple tokens, suppress the first token, which is safe because:
|
| 260 |
-
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
| 261 |
-
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
| 262 |
-
miscellaneous = set("♩♪♫♬♭♮♯")
|
| 263 |
-
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
| 264 |
-
|
| 265 |
-
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
| 266 |
-
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
| 267 |
-
for symbol in symbols + list(miscellaneous):
|
| 268 |
-
for tokens in [
|
| 269 |
-
self.encoding.encode(symbol),
|
| 270 |
-
self.encoding.encode(" " + symbol),
|
| 271 |
-
]:
|
| 272 |
-
if len(tokens) == 1 or symbol in miscellaneous:
|
| 273 |
-
result.add(tokens[0])
|
| 274 |
-
|
| 275 |
-
return tuple(sorted(result))
|
| 276 |
-
|
| 277 |
-
def split_to_word_tokens(self, tokens: List[int]):
|
| 278 |
-
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
| 279 |
-
# These languages don't typically use spaces, so it is difficult to split words
|
| 280 |
-
# without morpheme analysis. Here, we instead split words at any
|
| 281 |
-
# position where the tokens are decoded as valid unicode points
|
| 282 |
-
return self.split_tokens_on_unicode(tokens)
|
| 283 |
-
|
| 284 |
-
return self.split_tokens_on_spaces(tokens)
|
| 285 |
-
|
| 286 |
-
def split_tokens_on_unicode(self, tokens: List[int]):
|
| 287 |
-
decoded_full = self.decode_with_timestamps(tokens)
|
| 288 |
-
replacement_char = "\ufffd"
|
| 289 |
-
|
| 290 |
-
words = []
|
| 291 |
-
word_tokens = []
|
| 292 |
-
current_tokens = []
|
| 293 |
-
unicode_offset = 0
|
| 294 |
-
|
| 295 |
-
for token in tokens:
|
| 296 |
-
current_tokens.append(token)
|
| 297 |
-
decoded = self.decode_with_timestamps(current_tokens)
|
| 298 |
-
|
| 299 |
-
if (
|
| 300 |
-
replacement_char not in decoded
|
| 301 |
-
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
| 302 |
-
== replacement_char
|
| 303 |
-
):
|
| 304 |
-
words.append(decoded)
|
| 305 |
-
word_tokens.append(current_tokens)
|
| 306 |
-
current_tokens = []
|
| 307 |
-
unicode_offset += len(decoded)
|
| 308 |
-
|
| 309 |
-
return words, word_tokens
|
| 310 |
-
|
| 311 |
-
def split_tokens_on_spaces(self, tokens: List[int]):
|
| 312 |
-
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
| 313 |
-
words = []
|
| 314 |
-
word_tokens = []
|
| 315 |
-
|
| 316 |
-
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
| 317 |
-
special = subword_tokens[0] >= self.eot
|
| 318 |
-
with_space = subword.startswith(" ")
|
| 319 |
-
punctuation = subword.strip() in string.punctuation
|
| 320 |
-
if special or with_space or punctuation or len(words) == 0:
|
| 321 |
-
words.append(subword)
|
| 322 |
-
word_tokens.append(subword_tokens)
|
| 323 |
-
else:
|
| 324 |
-
words[-1] = words[-1] + subword
|
| 325 |
-
word_tokens[-1].extend(subword_tokens)
|
| 326 |
-
|
| 327 |
-
return words, word_tokens
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
@lru_cache(maxsize=None)
|
| 331 |
-
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
| 332 |
-
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
| 333 |
-
ranks = {
|
| 334 |
-
base64.b64decode(token): int(rank)
|
| 335 |
-
for token, rank in (line.split() for line in open(vocab_path) if line)
|
| 336 |
-
}
|
| 337 |
-
n_vocab = len(ranks)
|
| 338 |
-
special_tokens = {}
|
| 339 |
-
|
| 340 |
-
specials = [
|
| 341 |
-
"<|endoftext|>",
|
| 342 |
-
"<|startoftranscript|>",
|
| 343 |
-
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
| 344 |
-
"<|translate|>",
|
| 345 |
-
"<|transcribe|>",
|
| 346 |
-
"<|startoflm|>",
|
| 347 |
-
"<|startofprev|>",
|
| 348 |
-
"<|nospeech|>",
|
| 349 |
-
"<|notimestamps|>",
|
| 350 |
-
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
| 351 |
-
]
|
| 352 |
-
|
| 353 |
-
for token in specials:
|
| 354 |
-
special_tokens[token] = n_vocab
|
| 355 |
-
n_vocab += 1
|
| 356 |
-
|
| 357 |
-
return tiktoken.Encoding(
|
| 358 |
-
name=os.path.basename(vocab_path),
|
| 359 |
-
explicit_n_vocab=n_vocab,
|
| 360 |
-
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
| 361 |
-
mergeable_ranks=ranks,
|
| 362 |
-
special_tokens=special_tokens,
|
| 363 |
-
)
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
@lru_cache(maxsize=None)
|
| 367 |
-
def get_tokenizer(
|
| 368 |
-
multilingual: bool,
|
| 369 |
-
*,
|
| 370 |
-
num_languages: int = 99,
|
| 371 |
-
language: Optional[str] = None,
|
| 372 |
-
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
| 373 |
-
) -> Tokenizer:
|
| 374 |
-
if language is not None:
|
| 375 |
-
language = language.lower()
|
| 376 |
-
if language not in LANGUAGES:
|
| 377 |
-
if language in TO_LANGUAGE_CODE:
|
| 378 |
-
language = TO_LANGUAGE_CODE[language]
|
| 379 |
-
else:
|
| 380 |
-
raise ValueError(f"Unsupported language: {language}")
|
| 381 |
-
|
| 382 |
-
if multilingual:
|
| 383 |
-
encoding_name = "multilingual"
|
| 384 |
-
language = language or "en"
|
| 385 |
-
task = task or "transcribe"
|
| 386 |
-
else:
|
| 387 |
-
encoding_name = "gpt2"
|
| 388 |
-
language = None
|
| 389 |
-
task = None
|
| 390 |
-
|
| 391 |
-
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
| 392 |
-
|
| 393 |
-
return Tokenizer(
|
| 394 |
-
encoding=encoding, num_languages=num_languages, language=language, task=task
|
| 395 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|