inoryQwQ committed
Commit 798e40d · 1 parent: 33455e8

Update cpp bins, python scripts, add English readme

.gitattributes CHANGED
@@ -51,3 +51,7 @@ ax620e/install/whisper filter=lfs diff=lfs merge=lfs -text
 *axcl_aarch64/whisper filter=lfs diff=lfs merge=lfs -text
 *ax650/install/whisper filter=lfs diff=lfs merge=lfs -text
 *ax620e/install/whisper filter=lfs diff=lfs merge=lfs -text
+cpp/ax630c/lib/libax_whisper.so filter=lfs diff=lfs merge=lfs -text
+cpp/ax650/lib/libax_whisper.so filter=lfs diff=lfs merge=lfs -text
+cpp/ax650/whisper_svr filter=lfs diff=lfs merge=lfs -text
+cpp/ax630c/whisper_svr filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
README.md CHANGED
@@ -5,6 +5,10 @@ pipeline_tag: automatic-speech-recognition
 
 # Whisper
 
+<div align="center">
+<a href="README_EN.md">English</a> | <a href="README.md">中文</a>
+</div>
+
 OpenAI Whisper on Axera
 
 - 目前支持 C++ 和 Python 两种语言
@@ -13,6 +17,12 @@ OpenAI Whisper on Axera
 
 - 如需自行转换请参考[模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
 
+
+## Update
+
+- 2026/01/14: 更简单的模型结构,现在只需要 encoder 和 decoder,去掉原来的 decoder_main 和 decoder_loop;支持从 HuggingFace 导出模型
+
+
 ## 支持平台
 
 - [x] AX650N
@@ -20,6 +30,21 @@ OpenAI Whisper on Axera
 
 ## 模型转换
 
+目前支持的模型规模:
+- tiny
+- base
+- small
+- medium
+- turbo
+
+
+目前测试过的语言:
+- English
+- Chinese
+- Japanese
+- Korean
+- Malay
+
 [模型转换](https://github.com/ml-inory/whisper.axera/blob/main/model_convert/README.md)
 
 ## 上板部署
@@ -87,60 +112,22 @@ python3 main.py --model_type small --model_path ../models-ax650 --wav ../demo.wav
 输出结果
 
 ```
-root@ax650:/mnt/qtang/whisper.axera/python# python3 main.py --wav ../demo.wav --model_type small --model_path ../models/ --language zh
+(whisper) root@ax650:/mnt/data/Github/whisper.axera/python# python whisper_cli.py -t tiny -w ../demo.wav
 [INFO] Available providers: ['AxEngineExecutionProvider']
-wav: ../demo.wav
-model_type: small
-model_path: ../models/
-language: zh
+{'wav': '../demo.wav', 'model_type': 'tiny', 'model_path': '../models-ax650', 'language': 'zh', 'task': 'transcribe'}
 [INFO] Using provider: AxEngineExecutionProvider
 [INFO] Chip type: ChipType.MC50
 [INFO] VNPU type: VNPUType.DISABLED
-[INFO] Engine version: 2.10.1s
+[INFO] Engine version: 2.12.0s
 [INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 3.2-patch1 117f5fd4
+[INFO] Compiler version: 5.0 76f70fdc
 [INFO] Using provider: AxEngineExecutionProvider
 [INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 3.2-patch1 117f5fd4
-[INFO] Using provider: AxEngineExecutionProvider
-[INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 3.2-patch1 117f5fd4
-Load models take 2322.563409805298ms
-Preprocess wav take 6971.68493270874ms
-Run encoder take 211.52877807617188ms
-Run decoder_main take 79.00094985961914ms
-First token: 17556
-Run decoder_loop take 101.91774368286133ms
-Iter 0 Token: 20844
-Run decoder_loop take 60.30416488647461ms
-Iter 1 Token: 7781
-Run decoder_loop take 60.22000312805176ms
-Iter 2 Token: 20204
-Run decoder_loop take 60.23716926574707ms
-Iter 3 Token: 28455
-Run decoder_loop take 60.214996337890625ms
-Iter 4 Token: 31962
-Run decoder_loop take 60.17565727233887ms
-Iter 5 Token: 6336
-Run decoder_loop take 60.94002723693848ms
-Iter 6 Token: 254
-Run decoder_loop take 60.71639060974121ms
-Iter 7 Token: 2930
-Run decoder_loop take 60.225725173950195ms
-Iter 8 Token: 236
-Run decoder_loop take 60.167789459228516ms
-Iter 9 Token: 36135
-Run decoder_loop take 60.29987335205078ms
-Iter 10 Token: 15868
-Run decoder_loop take 61.163902282714844ms
-Iter 11 Token: 252
-Run decoder_loop take 60.273170471191406ms
-Iter 12 Token: 1546
-Run decoder_loop take 60.23144721984863ms
-Iter 13 Token: 46514
-Run decoder_loop take 60.31966209411621ms
-Iter 14 Token: 50257
-Result: 甚至出现交易几乎停滞的情况
+[INFO] Compiler version: 5.0 76f70fdc
+ASR result:
+擅职出现交易几乎停止的情况
+RTF: 0.11406774537746188
+
 ```
 
 运行参数说明:
@@ -152,6 +139,21 @@ Result: 甚至出现交易几乎停滞的情况
 | --language/-l | 识别语言 | zh |
 
 
+##### 服务端
+
+```
+(whisper) root@ax650:/mnt/data/Github/whisper.axera/python# python whisper_svr.py
+[INFO] Available providers: ['AxEngineExecutionProvider']
+Server started at http://0.0.0.0:8000
+
+```
+
+测试服务端
+```
+python test_svr.py
+```
+
+
 <h3 id="CPP">CPP</h3>
 
 #### 运行
@@ -160,50 +162,35 @@ Result: 甚至出现交易几乎停滞的情况
 
 ```
 cd cpp
-./whisper -w ../demo.wav
+./whisper_cli -w ../demo.wav -t tiny
 ```
 
 
 
 ```
 cd cpp
-./whisper --model_type small --model_path ../models -w ../demo.wav
+./whisper_cli --model_type small -w ../demo.wav
 ```
 
 输出结果
 
 ```
-root@ax650:/mnt/qtang/whisper.axera/cpp# ./install/whisper --wav ../demo.wav --model_type small --model_path ../models/ --language zh
-wav_file: ../demo.wav
-model_path: ../models/
-model_type: small
+(whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_cli -w ../../demo.wav -t tiny
+wav_file: ../../demo.wav
+model_path: ../../models-ax650
+model_type: tiny
 language: zh
-Encoder run take 188.30 ms
-First token: 17556 take 81.88ms
-Next Token: 20844 take 29.64ms
-Next Token: 7781 take 29.70ms
-Next Token: 20204 take 29.64ms
-Next Token: 28455 take 29.65ms
-Next Token: 31962 take 29.61ms
-Next Token: 6336 take 29.67ms
-Next Token: 254 take 29.63ms
-Next Token: 2930 take 29.61ms
-Next Token: 236 take 29.56ms
-Next Token: 36135 take 29.64ms
-Next Token: 15868 take 29.71ms
-Next Token: 252 take 29.51ms
-Next Token: 1546 take 29.63ms
-Next Token: 46514 take 29.51ms
-Next Token: 50257 take 29.69ms
-All take 801.13 ms
-Result: 甚至出现交易几乎停滞的情况
+Init whisper success, take 0.3540seconds
+Result: 甚至出现交易几乎停止的情况
+RTF: 0.0968
+
 ```
 
 ### 服务端
 
 ```
-cd cpp
-./whisper_srv --model_type tiny --model_path ../models-ax650 --language zh --port 8080
+cd cpp/ax650
+./whisper_srv --model_type tiny --language zh --port 8080
 ```
 
 ### 客户端
README_EN.md ADDED
@@ -0,0 +1,261 @@
+ # whisper.axera
+
+ <div align="center">
+ <a href="README_EN.md">English</a> | <a href="README.md">中文</a>
+ </div>
+
+ OpenAI Whisper on the Axera Platform
+
+ ## Overview
+
+ This project provides an optimized implementation of OpenAI's Whisper speech recognition model for Axera AI processors (AX650N/AX630C). It supports both C++ and Python interfaces for efficient on-device speech-to-text conversion.
+
+ ## Features
+
+ - **Dual Language Support**: Both C++ and Python APIs are available
+ - **Multiple Model Sizes**: Support for the tiny, base, small, medium, and turbo model variants
+ - **Multi-language Recognition**: Tested with English, Chinese, Japanese, Korean, and Malay
+ - **Optimized Performance**: Specially optimized for Axera NPU acceleration
+ - **Easy Deployment**: Pre-built packages and cross-compilation support
+
+ ## Update
+
+ - 2026/01/14: Cleaner model architecture: a single encoder and decoder replace the previous decoder_main and decoder_loop. Models exported from Hugging Face are now supported.
+
+ ## Supported Platforms
+
+ - ✅ AX650N
+ - ✅ AX630C
+
+ ## Pre-trained Models
+
+ Download pre-compiled models from:
+ - [Baidu Cloud](https://pan.baidu.com/s/1tOHVMZCin0A68T5HmKRJyg?pwd=axyz)
+ - [Huggingface](https://huggingface.co/AXERA-TECH/Whisper)
+
+ For custom model conversion, please refer to the [Model Conversion Guide](./model_convert/README_EN.md).
+
+ ## Model Conversion
+
+ Currently supported model sizes:
+ - tiny
+ - base
+ - small
+ - medium
+ - turbo
+
+ Tested languages:
+ - English
+ - Chinese
+ - Japanese
+ - Korean
+ - Malay
+
+ For other languages or custom model sizes, please refer to the [Model Conversion Guide](./model_convert/README_EN.md).
+
+ ## Deployment on Target Devices
+
+ ### Prerequisites
+ - AX650N/AX630C devices with Ubuntu 22.04 pre-installed
+ - Internet connection for `apt install` and `pip install`
+ - Verified hardware platforms:
+   - [MaixIV M4nDock (AX650N)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+   - [M.2 Accelerator Card (AX650N)](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
+   - [Axera Pi 2 (AX630C)](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
+   - [Module-LLM (AX630C)](https://docs.m5stack.com/zh_CN/module/Module-LLM)
+   - [LLM630 Compute Kit (AX630C)](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
+
+ ## Programming Language Support
+
+ ### Python
+
+ Tested with Python 3.12. We recommend using [Miniconda](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh) for environment management.
+
+ #### Installation
+
+ ```bash
+ cd python
+ pip3 install -r requirements.txt
+ ```
+
+ #### pyaxengine
+
+ Install the NPU Python API from: https://github.com/AXERA-TECH/pyaxengine
+
+ #### Usage
+
+ ##### Command Line Interface
+
+ ```
+ cd python
+ (whisper) root@ax650:/mnt/data/HF/Whisper/python# python whisper_cli.py -w ../demo.wav -t tiny
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ {'wav': '../demo.wav', 'model_type': 'tiny', 'model_path': '../models-ax650', 'language': 'zh', 'task': 'transcribe'}
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Chip type: ChipType.MC50
+ [INFO] VNPU type: VNPUType.DISABLED
+ [INFO] Engine version: 2.12.0s
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 5.0 76f70fdc
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 5.0 76f70fdc
+ ASR result:
+ 擅职出现交易几乎停止的情况
+ RTF: 0.10313174677896837
+ ```
+
+ Command line arguments:
+ | Argument | Description | Default |
+ | --- | --- | --- |
+ | --wav/-w | Input audio file | - |
+ | --model_type/-t | Model type: tiny/base/small/medium/turbo | - |
+ | --model_path/-p | Model directory | ../models |
+ | --language/-l | Recognition language | zh |
+
+ ##### Server Mode
+
+ ```
+ (whisper) root@ax650:/mnt/data/HF/Whisper/python# python whisper_svr.py
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ Server started at http://0.0.0.0:8000
+ ```
+
+ Test the server:
+ ```
+ python test_svr.py
+ ```
+
+ <h3 id="CPP">CPP</h3>
+
+ #### Usage on Target Device
+
+ ```
+ cd cpp/ax650
+ ./whisper_cli -w ../demo.wav -t tiny
+ ```
+
+ ```
+ cd cpp/ax650
+ ./whisper_cli --model_type small -w ../demo.wav
+ ```
+
+ Example output:
+
+ ```
+ (whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_cli -w ../../demo.wav -t tiny
+ wav_file: ../../demo.wav
+ model_path: ../../models-ax650
+ model_type: tiny
+ language: zh
+ Init whisper success, take 0.3540seconds
+ Result: 甚至出现交易几乎停止的情况
+ RTF: 0.0968
+ ```
+
+ ### Server Mode
+
+ ```
+ cd cpp/ax650
+ (whisper) root@ax650:/mnt/data/HF/Whisper/cpp/ax650# ./whisper_svr -t tiny
+ port: 8080
+ model_path: ../../models-ax650
+ model_type: tiny
+ language: zh
+ [I][ main][ 60]: Initializing server...
+ [I][ main][ 65]: Init server success
+ [I][ start][ 32]: Start server at port 8080, POST binary stream to IP:8080/asr
+ ```
+
+ ### Client test using curl
+
+ ```
+ ffmpeg -i demo.wav -f f32le -c:a pcm_f32le - 2>/dev/null | \
+ curl -X POST 10.126.33.192:8080/asr \
+ -H "Content-Type: application/octet-stream" \
+ --data-binary @-
+ ```
+
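The `/asr` endpoint consumes a raw stream of 16 kHz mono float32 PCM samples, which is exactly what the `ffmpeg -f f32le` pipe above produces. As a hedged illustration (standard library only; the helper name `wav_to_f32le` is ours, not part of this repo), the same payload can be built from a 16-bit WAV in Python:

```python
import struct
import wave


def wav_to_f32le(wav_path: str) -> bytes:
    """Convert a 16-bit mono WAV file to raw little-endian float32 PCM in [-1.0, 1.0]."""
    with wave.open(wav_path, "rb") as w:
        assert w.getsampwidth() == 2 and w.getnchannels() == 1, "expect 16-bit mono input"
        frames = w.readframes(w.getnframes())
    # Each sample is a little-endian int16; scale by 32768 to land in [-1.0, 1.0)
    samples = struct.unpack("<%dh" % (len(frames) // 2), frames)
    return struct.pack("<%df" % len(samples), *(s / 32768.0 for s in samples))
```

The resulting bytes can then be POSTed to `IP:8080/asr` with any HTTP client, mirroring the curl invocation above.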
+ ## Performance Benchmarks
+
+ ### Latency
+
+ RTF: Real-Time Factor (processing time divided by audio duration; lower is better)
+
+ CPP:
+
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.08 | |
+ | Whisper-Base | 0.11 | 0.35 |
+ | Whisper-Small | 0.24 | |
+ | Whisper-Turbo | 0.48 | |
+
+ Python:
+
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.12 | |
+ | Whisper-Base | 0.16 | 0.35 |
+ | Whisper-Small | 0.50 | |
+ | Whisper-Turbo | 0.60 | |
+
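The RTF figures in these tables are the standard ratio of wall-clock processing time to audio duration, so a value below 1.0 means faster-than-real-time transcription. A minimal sketch of the computation (the function name and the timing numbers are ours, for illustration only):

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF = wall-clock processing time / audio duration; < 1.0 means faster than real time."""
    if audio_seconds <= 0:
        raise ValueError("audio duration must be positive")
    return processing_seconds / audio_seconds


# Hypothetical example: a 5.0 s clip transcribed in 0.5 s
print(real_time_factor(0.5, 5.0))  # prints 0.1
```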
+ ### Word Error Rate (tested on the AIShell dataset)
+
+ | Models | AX650N | AX630C |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 0.24 | |
+ | Whisper-Base | 0.18 | |
+ | Whisper-Small | 0.11 | |
+ | Whisper-Turbo | 0.06 | |
+
+ To reproduce the WER test results:
+
+ Download the dataset:
+ ```
+ cd model_convert
+ bash download_dataset.sh
+ ```
+
+ Run the test script:
+ ```
+ cd python
+ conda activate whisper
+ python test_wer.py -d aishell --gt_path ../model_convert/datasets/ground_truth.txt --model_type tiny
+ ```
+
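For reference, word error rate is the Levenshtein edit distance between reference and hypothesis divided by the reference length; for Chinese ASR it is typically computed over characters. The sketch below is a generic scorer of this kind, not the repo's `test_wer.py`:

```python
def wer(reference: list, hypothesis: list) -> float:
    """Word (or character) error rate: Levenshtein edit distance / reference length."""
    m, n = len(reference), len(hypothesis)
    # d[i][j] = edit distance between reference[:i] and hypothesis[:j]
    d = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        d[i][0] = i
    for j in range(n + 1):
        d[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,      # deletion
                          d[i][j - 1] + 1,      # insertion
                          d[i - 1][j - 1] + cost)  # substitution / match
    return d[m][n] / m


# Character-level scoring, as is common for Chinese ASR:
# one substitution (滞 → 止) over 13 reference characters ≈ 0.077
print(wer(list("甚至出现交易几乎停滞的情况"), list("甚至出现交易几乎停止的情况")))
```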
+ ### Memory Usage
+
+ * CMM stands for the physical memory used by Axera modules such as VDEC (video decoder), VENC (video encoder), the NPU, etc.
+
+ Python:
+
+ | Models | CMM (MB) | OS (MB) |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 332 | 512 |
+ | Whisper-Base | 533 | 644 |
+ | Whisper-Small | 1106 | 906 |
+ | Whisper-Turbo | 2065 | 2084 |
+
+ C++:
+
+ | Models | CMM (MB) | OS (MB) |
+ | ------------- | ------ | ------ |
+ | Whisper-Tiny | 332 | 31 |
+ | Whisper-Base | 533 | 54 |
+ | Whisper-Small | 1106 | 146 |
+ | Whisper-Turbo | 2065 | 86 |
+
+ ## Technical Discussion
+
+ - GitHub issues
+ - Tencent QQ Group: 139953715
cpp/ax630c/TSCharacters.ocd2 ADDED
Binary file (46.1 kB)
 
cpp/ax630c/TSPhrases.ocd2 ADDED
Binary file (9.78 kB)
 
cpp/ax630c/include/ax_whisper_api.h ADDED
@@ -0,0 +1,107 @@
+ /**
+  * @file ax_whisper_api.h
+  * @brief AX Whisper API header - C-compatible interface for the Whisper ASR system
+  * @note This header provides a C interface to the Whisper speech recognition system
+  */
+
+ #ifndef _AX_WHISPER_API_H_
+ #define _AX_WHISPER_API_H_
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ #define AX_WHISPER_API __attribute__((visibility("default")))
+
+ /**
+  * @brief Opaque handle type for a Whisper ASR context
+  *
+  * This handle encapsulates all internal state of the Whisper ASR system.
+  * The actual implementation is hidden from C callers to maintain ABI stability.
+  */
+ typedef void* AX_WHISPER_HANDLE;
+
+ /**
+  * @brief Initialize the Whisper ASR system with a specific configuration
+  *
+  * Creates and initializes a new Whisper ASR context with the specified
+  * model type, model path, and language. This function loads the appropriate
+  * models, configures the recognizer, and prepares it for speech recognition.
+  *
+  * @param model_type Type of Whisper model to use (e.g., "tiny", "base", "small", "medium", "large")
+  *                   or a custom model identifier
+  * @param model_path Directory path where the model files are stored.
+  *                   Model files are expected in the following layout:
+  *                   - {model_path}/{model_type}/{model_type}-encoder.axmodel
+  *                   - {model_path}/{model_type}/{model_type}-decoder.axmodel
+  *                   - {model_path}/{model_type}/{model_type}-tokens.txt
+  *                   - {model_path}/{model_type}/{model_type}_config.json
+  * @param language Language code for recognition (e.g., "en", "zh", "ja", "ko").
+  *                 Use "auto" for automatic language detection if supported.
+  *
+  * @return AX_WHISPER_HANDLE Opaque handle to the initialized Whisper context,
+  *                           or NULL if initialization fails
+  *
+  * @note The caller is responsible for calling AX_WHISPER_Uninit() to free
+  *       resources when the handle is no longer needed.
+  * @note If the language is not supported by the model, the function may fall back
+  *       to a default language or return NULL.
+  * @code
+  * // Initialize English recognition with the base model
+  * AX_WHISPER_HANDLE handle = AX_WHISPER_Init("base", "../models-ax650", "en");
+  * @endcode
+  */
+ AX_WHISPER_API AX_WHISPER_HANDLE AX_WHISPER_Init(const char* model_type, const char* model_path, const char* language);
+
+ /**
+  * @brief Deinitialize and release Whisper ASR resources
+  *
+  * Cleans up all resources associated with the Whisper context, including
+  * unloading models, freeing memory, and releasing hardware resources.
+  *
+  * @param handle Whisper context handle obtained from AX_WHISPER_Init()
+  *
+  * @warning After calling this function, the handle becomes invalid and
+  *          must not be used in any subsequent API calls.
+  */
+ AX_WHISPER_API void AX_WHISPER_Uninit(AX_WHISPER_HANDLE handle);
+
+ /**
+  * @brief Perform speech recognition on a WAV file and return a dynamically allocated string
+  *
+  * @param handle Whisper context handle
+  * @param wav_file Path to the input WAV audio file (16 kHz, float32 PCM)
+  * @param result Pointer to receive the allocated result string
+  *
+  * @return int Status code (0 = success, <0 = error)
+  *
+  * @note The returned string is allocated with malloc() and must be freed
+  *       by the caller using free() when no longer needed.
+  */
+ AX_WHISPER_API int AX_WHISPER_RunFile(AX_WHISPER_HANDLE handle,
+                                       const char* wav_file,
+                                       char** result);
+
+ /**
+  * @brief Perform speech recognition on raw PCM data and return a dynamically allocated string
+  *
+  * @param handle Whisper context handle
+  * @param pcm_data 16 kHz mono float32 PCM data, in the range -1.0 to 1.0
+  * @param num_samples Number of samples in pcm_data
+  * @param result Pointer to receive the allocated result string
+  *
+  * @return int Status code (0 = success, <0 = error)
+  *
+  * @note The returned string is allocated with malloc() and must be freed
+  *       by the caller using free() when no longer needed.
+  */
+ AX_WHISPER_API int AX_WHISPER_RunPCM(AX_WHISPER_HANDLE handle,
+                                      float* pcm_data,
+                                      int num_samples,
+                                      char** result);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
+ #endif // _AX_WHISPER_API_H_
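The docstring of `AX_WHISPER_Init` pins down the on-disk model layout. As an illustrative aid (Python, not part of this repo; the helper names are ours), those expected paths can be assembled and checked before handing them to the C API:

```python
from pathlib import Path


def expected_model_files(model_path: str, model_type: str) -> list:
    """Paths AX_WHISPER_Init expects, per the header documentation above."""
    base = Path(model_path) / model_type
    return [
        base / f"{model_type}-encoder.axmodel",
        base / f"{model_type}-decoder.axmodel",
        base / f"{model_type}-tokens.txt",
        base / f"{model_type}_config.json",
    ]


def missing_model_files(model_path: str, model_type: str) -> list:
    """Return the subset of expected files that are absent on disk."""
    return [p for p in expected_model_files(model_path, model_type) if not p.is_file()]
```

A caller might run `missing_model_files("../models-ax650", "tiny")` first and only invoke `AX_WHISPER_Init` when the list is empty.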
cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config-release.cmake ADDED
@@ -0,0 +1,19 @@
+ #----------------------------------------------------------------
+ # Generated CMake target import file for configuration "Release".
+ #----------------------------------------------------------------
+
+ # Commands may need to know the format version.
+ set(CMAKE_IMPORT_FILE_VERSION 1)
+
+ # Import target "ax::ax_whisper" for configuration "Release"
+ set_property(TARGET ax::ax_whisper APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+ set_target_properties(ax::ax_whisper PROPERTIES
+   IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_whisper.so"
+   IMPORTED_SONAME_RELEASE "libax_whisper.so"
+   )
+
+ list(APPEND _IMPORT_CHECK_TARGETS ax::ax_whisper )
+ list(APPEND _IMPORT_CHECK_FILES_FOR_ax::ax_whisper "${_IMPORT_PREFIX}/lib/libax_whisper.so" )
+
+ # Commands beyond this point should not need to know the version.
+ set(CMAKE_IMPORT_FILE_VERSION)
cpp/ax630c/lib/cmake/ax_whisper/ax_whisper-config.cmake ADDED
@@ -0,0 +1,94 @@
+ # Generated by CMake
+
+ if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6)
+    message(FATAL_ERROR "CMake >= 2.6.0 required")
+ endif()
+ cmake_policy(PUSH)
+ cmake_policy(VERSION 2.6...3.20)
+ #----------------------------------------------------------------
+ # Generated CMake target import file.
+ #----------------------------------------------------------------
+
+ # Commands may need to know the format version.
+ set(CMAKE_IMPORT_FILE_VERSION 1)
+
+ # Protect against multiple inclusion, which would fail when already imported targets are added once more.
+ set(_targetsDefined)
+ set(_targetsNotDefined)
+ set(_expectedTargets)
+ foreach(_expectedTarget ax::ax_whisper)
+   list(APPEND _expectedTargets ${_expectedTarget})
+   if(NOT TARGET ${_expectedTarget})
+     list(APPEND _targetsNotDefined ${_expectedTarget})
+   endif()
+   if(TARGET ${_expectedTarget})
+     list(APPEND _targetsDefined ${_expectedTarget})
+   endif()
+ endforeach()
+ if("${_targetsDefined}" STREQUAL "${_expectedTargets}")
+   unset(_targetsDefined)
+   unset(_targetsNotDefined)
+   unset(_expectedTargets)
+   set(CMAKE_IMPORT_FILE_VERSION)
+   cmake_policy(POP)
+   return()
+ endif()
+ if(NOT "${_targetsDefined}" STREQUAL "")
+   message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n")
+ endif()
+ unset(_targetsDefined)
+ unset(_targetsNotDefined)
+ unset(_expectedTargets)
+
+
+ # Compute the installation prefix relative to this file.
+ get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ if(_IMPORT_PREFIX STREQUAL "/")
+   set(_IMPORT_PREFIX "")
+ endif()
+
+ # Create imported target ax::ax_whisper
+ add_library(ax::ax_whisper SHARED IMPORTED)
+
+ set_target_properties(ax::ax_whisper PROPERTIES
+   INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+ )
+
+ # Load information for each installed configuration.
+ get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+ file(GLOB CONFIG_FILES "${_DIR}/ax_whisper-config-*.cmake")
+ foreach(f ${CONFIG_FILES})
+   include(${f})
+ endforeach()
+
+ # Cleanup temporary variables.
+ set(_IMPORT_PREFIX)
+
+ # Loop over all imported files and verify that they actually exist
+ foreach(target ${_IMPORT_CHECK_TARGETS} )
+   foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} )
+     if(NOT EXISTS "${file}" )
+       message(FATAL_ERROR "The imported target \"${target}\" references the file
+    \"${file}\"
+ but this file does not exist. Possible reasons include:
+ * The file was deleted, renamed, or moved to another location.
+ * An install or uninstall procedure did not complete successfully.
+ * The installation package was faulty and contained
+    \"${CMAKE_CURRENT_LIST_FILE}\"
+ but not all the files it references.
+ ")
+     endif()
+   endforeach()
+   unset(_IMPORT_CHECK_FILES_FOR_${target})
+ endforeach()
+ unset(_IMPORT_CHECK_TARGETS)
+
+ # This file does not depend on other imported targets which have
+ # been exported from the same project but in a separate export set.
+
+ # Commands beyond this point should not need to know the version.
+ set(CMAKE_IMPORT_FILE_VERSION)
+ cmake_policy(POP)
cpp/{TSCharacters.ocd2 → ax630c/lib/libax_whisper.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:85291e0173e972bbca58c848fb90b3bb41c79674cb61a75645e01bd884ad5927
-size 46126
+oid sha256:469ffe5522d67f92b9b7d5390dc17740cbef3cd3b8b484542a8e1de44c11ad5a
+size 623784
cpp/ax630c/t2s.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "name": "Traditional Chinese to Simplified Chinese",
+   "segmentation": {
+     "type": "mmseg",
+     "dict": {
+       "type": "ocd2",
+       "file": "TSPhrases.ocd2"
+     }
+   },
+   "conversion_chain": [{
+     "dict": {
+       "type": "group",
+       "dicts": [{
+         "type": "ocd2",
+         "file": "TSPhrases.ocd2"
+       }, {
+         "type": "ocd2",
+         "file": "TSCharacters.ocd2"
+       }]
+     }
+   }]
+ }
cpp/ax630c/whisper_cli ADDED
Binary file (93.8 kB)
 
cpp/{TSPhrases.ocd2 → ax630c/whisper_svr} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:eea69e525e01b8475a1b1ad45f78f25e5aa78986305f185ef6f85e11f5325387
-size 9782
+oid sha256:b2af4115e31d25a1c85f39666aba3918403449ca64e60d11f5af7e49c7456cd7
+size 531688
cpp/ax650/TSCharacters.ocd2 ADDED
Binary file (46.1 kB)
 
cpp/ax650/TSPhrases.ocd2 ADDED
Binary file (9.78 kB)
 
cpp/ax650/include/ax_whisper_api.h ADDED
@@ -0,0 +1,107 @@
1
+ /**
2
+ * @file ax_whisper_api.h
3
+ * @brief AX Whisper API header - C-compatible interface for Whisper ASR system
4
+ * @note This header provides a C interface to the Whisper speech recognition system
5
+ */
6
+
7
+ #ifndef _AX_WHISPER_API_H_
8
+ #define _AX_WHISPER_API_H_
9
+
10
+ #ifdef __cplusplus
11
+ extern "C" {
12
+ #endif
13
+
14
+ #define AX_WHISPER_API __attribute__((visibility("default")))
15
+
16
+ /**
17
+ * @brief Opaque handle type for Whisper ASR context
18
+ *
19
+ * This handle encapsulates all internal state of the Whisper ASR system.
20
+ * The actual implementation is hidden from C callers to maintain ABI stability.
21
+ */
22
+ typedef void* AX_WHISPER_HANDLE;
23
+
24
+ /**
25
+ * @brief Initialize the Whisper ASR system with specific configuration
26
+ *
27
+ * Creates and initializes a new Whisper ASR context with the specified
28
+ * model type, model path, and language. This function loads the appropriate
29
+ * models, configures the recognizer, and prepares it for speech recognition.
30
+ *
31
+ * @param model_type Type of Whisper model to use (e.g., "tiny", "base", "small", "medium", "large")
32
+ * or custom model identifier
33
+ * @param model_path Directory path where model files are stored
34
+ * Model files are expected to be in the format:
35
+ * - {model_path}/{model_type}/{model_type}-encoder.axmodel
36
+ * - {model_path}/{model_type}/{model_type}-decoder.axmodel
37
+ * - {model_path}/{model_type}/{model_type}-tokens.txt
38
+ * - {model_path}/{model_type}/{model_type}_config.json
39
+ * @param language Language code for recognition (e.g., "en", "zh", "ja", "ko")
40
+ * Use "auto" for automatic language detection if supported
41
+ *
42
+ * @return AX_WHISPER_HANDLE Opaque handle to the initialized Whisper context,
43
+ * or NULL if initialization fails
44
+ *
45
+ * @note The caller is responsible for calling AX_WHISPER_Uninit() to free
46
+ * resources when the handle is no longer needed.
47
+ * @note If language is not supported by the model, the function may fall back
48
+ * to a default language or return NULL.
49
+ * @example
50
+ * // Initialize English recognition with base model
+ * AX_WHISPER_HANDLE handle = AX_WHISPER_Init("base", "../models-ax650", "en");
+ *
+ */
+AX_WHISPER_API AX_WHISPER_HANDLE AX_WHISPER_Init(const char* model_type, const char* model_path, const char* language);
+
+/**
+ * @brief Deinitialize and release Whisper ASR resources
+ *
+ * Cleans up all resources associated with the Whisper context, including
+ * unloading models, freeing memory, and releasing hardware resources.
+ *
+ * @param handle Whisper context handle obtained from AX_WHISPER_Init()
+ *
+ * @warning After calling this function, the handle becomes invalid and
+ *          should not be used in any subsequent API calls.
+ */
+AX_WHISPER_API void AX_WHISPER_Uninit(AX_WHISPER_HANDLE handle);
+
+/**
+ * @brief Perform speech recognition on a WAV file and return a dynamically allocated string
+ *
+ * @param handle   Whisper context handle
+ * @param wav_file Path to the input 16 kHz pcmf32 WAV audio file
+ * @param result   Pointer to receive the allocated result string
+ *
+ * @return int Status code (0 = success, <0 = error)
+ *
+ * @note The returned string is allocated with malloc() and must be freed
+ *       by the caller using free() when no longer needed.
+ */
+AX_WHISPER_API int AX_WHISPER_RunFile(AX_WHISPER_HANDLE handle,
+                                      const char* wav_file,
+                                      char** result);
+
+/**
+ * @brief Perform speech recognition on raw PCM and return a dynamically allocated string
+ *
+ * @param handle      Whisper context handle
+ * @param pcm_data    16 kHz mono PCM float32 data, range -1.0 to 1.0
+ * @param num_samples Number of samples in the PCM data
+ * @param result      Pointer to receive the allocated result string
+ *
+ * @return int Status code (0 = success, <0 = error)
+ *
+ * @note The returned string is allocated with malloc() and must be freed
+ *       by the caller using free() when no longer needed.
+ */
+AX_WHISPER_API int AX_WHISPER_RunPCM(AX_WHISPER_HANDLE handle,
+                                     float* pcm_data,
+                                     int num_samples,
+                                     char** result);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // _AX_WHISPER_API_H_
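The header above defines an init/run/free/uninit lifecycle. For quick experiments from Python, the shared library can be driven through ctypes; this is a sketch against the prototypes above, not code from the repo, and `load_whisper`/`transcribe_file` plus the model paths are illustrative assumptions:

```python
import ctypes


def load_whisper(lib_path: str) -> ctypes.CDLL:
    """Bind the C API declared in ax_whisper_api.h. Raises OSError if missing."""
    lib = ctypes.CDLL(lib_path)
    lib.AX_WHISPER_Init.restype = ctypes.c_void_p
    lib.AX_WHISPER_Init.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
    lib.AX_WHISPER_RunFile.restype = ctypes.c_int
    lib.AX_WHISPER_RunFile.argtypes = [
        ctypes.c_void_p,
        ctypes.c_char_p,
        ctypes.POINTER(ctypes.c_char_p),
    ]
    lib.AX_WHISPER_Uninit.restype = None
    lib.AX_WHISPER_Uninit.argtypes = [ctypes.c_void_p]
    return lib


def transcribe_file(lib: ctypes.CDLL, wav: str) -> str:
    # model type / path / language are placeholders from the header's example
    handle = lib.AX_WHISPER_Init(b"small", b"../models-ax650", b"zh")
    result = ctypes.c_char_p()
    ret = lib.AX_WHISPER_RunFile(handle, wav.encode(), ctypes.byref(result))
    text = result.value.decode() if ret == 0 and result.value else ""
    # the header says the string is malloc()'d; a real binding should free()
    # it via libc before the pointer goes out of scope
    lib.AX_WHISPER_Uninit(handle)
    return text
```

On device the library path would be e.g. `cpp/ax650/lib/libax_whisper.so`; off device, `load_whisper` fails cleanly with an OSError.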
cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config-release.cmake ADDED
@@ -0,0 +1,19 @@
+#----------------------------------------------------------------
+# Generated CMake target import file for configuration "Release".
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Import target "ax::ax_whisper" for configuration "Release"
+set_property(TARGET ax::ax_whisper APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+set_target_properties(ax::ax_whisper PROPERTIES
+  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libax_whisper.so"
+  IMPORTED_SONAME_RELEASE "libax_whisper.so"
+  )
+
+list(APPEND _IMPORT_CHECK_TARGETS ax::ax_whisper )
+list(APPEND _IMPORT_CHECK_FILES_FOR_ax::ax_whisper "${_IMPORT_PREFIX}/lib/libax_whisper.so" )
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
cpp/ax650/lib/cmake/ax_whisper/ax_whisper-config.cmake ADDED
@@ -0,0 +1,94 @@
+# Generated by CMake
+
+if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6)
+   message(FATAL_ERROR "CMake >= 2.6.0 required")
+endif()
+cmake_policy(PUSH)
+cmake_policy(VERSION 2.6...3.20)
+#----------------------------------------------------------------
+# Generated CMake target import file.
+#----------------------------------------------------------------
+
+# Commands may need to know the format version.
+set(CMAKE_IMPORT_FILE_VERSION 1)
+
+# Protect against multiple inclusion, which would fail when already imported targets are added once more.
+set(_targetsDefined)
+set(_targetsNotDefined)
+set(_expectedTargets)
+foreach(_expectedTarget ax::ax_whisper)
+  list(APPEND _expectedTargets ${_expectedTarget})
+  if(NOT TARGET ${_expectedTarget})
+    list(APPEND _targetsNotDefined ${_expectedTarget})
+  endif()
+  if(TARGET ${_expectedTarget})
+    list(APPEND _targetsDefined ${_expectedTarget})
+  endif()
+endforeach()
+if("${_targetsDefined}" STREQUAL "${_expectedTargets}")
+  unset(_targetsDefined)
+  unset(_targetsNotDefined)
+  unset(_expectedTargets)
+  set(CMAKE_IMPORT_FILE_VERSION)
+  cmake_policy(POP)
+  return()
+endif()
+if(NOT "${_targetsDefined}" STREQUAL "")
+  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n")
+endif()
+unset(_targetsDefined)
+unset(_targetsNotDefined)
+unset(_expectedTargets)
+
+
+# Compute the installation prefix relative to this file.
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+if(_IMPORT_PREFIX STREQUAL "/")
+  set(_IMPORT_PREFIX "")
+endif()
+
+# Create imported target ax::ax_whisper
+add_library(ax::ax_whisper SHARED IMPORTED)
+
+set_target_properties(ax::ax_whisper PROPERTIES
+  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
+)
+
+# Load information for each installed configuration.
+get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
+file(GLOB CONFIG_FILES "${_DIR}/ax_whisper-config-*.cmake")
+foreach(f ${CONFIG_FILES})
+  include(${f})
+endforeach()
+
+# Cleanup temporary variables.
+set(_IMPORT_PREFIX)
+
+# Loop over all imported files and verify that they actually exist
+foreach(target ${_IMPORT_CHECK_TARGETS} )
+  foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} )
+    if(NOT EXISTS "${file}" )
+      message(FATAL_ERROR "The imported target \"${target}\" references the file
+   \"${file}\"
+but this file does not exist. Possible reasons include:
+* The file was deleted, renamed, or moved to another location.
+* An install or uninstall procedure did not complete successfully.
+* The installation package was faulty and contained
+   \"${CMAKE_CURRENT_LIST_FILE}\"
+but not all the files it references.
+")
+    endif()
+  endforeach()
+  unset(_IMPORT_CHECK_FILES_FOR_${target})
+endforeach()
+unset(_IMPORT_CHECK_TARGETS)
+
+# This file does not depend on other imported targets which have
+# been exported from the same project but in a separate export set.
+
+# Commands beyond this point should not need to know the version.
+set(CMAKE_IMPORT_FILE_VERSION)
+cmake_policy(POP)
cpp/{t2s.json → ax650/lib/libax_whisper.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b818534194f27c2d95f01001edb0a5ec49b9050119892cb30a0504bb202cc07c
-size 406
+oid sha256:a2aafbad2ea23d93c226fdb3d7d22ccbad44813638c78222787c8de9f85ec358
+size 624072
cpp/ax650/t2s.json ADDED
@@ -0,0 +1,22 @@
+{
+  "name": "Traditional Chinese to Simplified Chinese",
+  "segmentation": {
+    "type": "mmseg",
+    "dict": {
+      "type": "ocd2",
+      "file": "TSPhrases.ocd2"
+    }
+  },
+  "conversion_chain": [{
+    "dict": {
+      "type": "group",
+      "dicts": [{
+        "type": "ocd2",
+        "file": "TSPhrases.ocd2"
+      }, {
+        "type": "ocd2",
+        "file": "TSCharacters.ocd2"
+      }]
+    }
+  }]
+}
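This t2s.json configures an OpenCC-style Traditional-to-Simplified pipeline: input is segmented by maximum match, and the conversion chain consults the phrase table (TSPhrases.ocd2) before falling back to per-character mapping (TSCharacters.ocd2). A toy illustration of that lookup order, with made-up dictionary entries standing in for the real .ocd2 tables:

```python
# toy stand-ins for TSPhrases.ocd2 / TSCharacters.ocd2
PHRASES = {"電腦": "电脑"}                       # phrase table, tried first
CHARS = {"電": "电", "腦": "脑", "車": "车"}      # per-character fallback


def t2s(text: str) -> str:
    """Greedy longest-match conversion: phrases first, then single characters."""
    out, i = [], 0
    max_len = max(map(len, PHRASES), default=1)
    while i < len(text):
        for n in range(min(max_len, len(text) - i), 0, -1):
            seg = text[i : i + n]
            if seg in PHRASES:          # phrase dictionary wins
                out.append(PHRASES[seg])
                i += n
                break
        else:                           # no phrase matched: convert one char
            out.append(CHARS.get(text[i], text[i]))
            i += 1
    return "".join(out)
```

The real converter is table-driven and far larger, but the phrase-before-character precedence is exactly what the `conversion_chain` group in the JSON encodes.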
cpp/ax650/whisper_cli ADDED
Binary file (93.8 kB)
 
cpp/{whisper_aarch64 → ax650/whisper_svr} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:70a165bd3b25a07c17fd45ed84dd04231f20dff7614ce4576ac090a51cc64513
-size 584440
+oid sha256:400a719b950cf61863179204d509c4f785b75184af2415bca9ed04a0198bf363
+size 531688
cpp/whisper_axcl_aarch64 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:47ad92295eecfcd230f420acbb5ecd86211c53735aa216d935b22e17d740cd09
-size 1251248
cpp/whisper_axcl_x86 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f59c233fb5c7d935422b650b5cd968e4b40654d368c671e6551c8ea7dcc78cee
-size 1212056
cpp/whisper_srv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:27f1a07763df0f8fc2a94cf32551daae4082111668ea5e938b93bdd6a76e3d29
-size 1006728
python/assets/multilingual.tiktoken DELETED
The diff for this file is too large to render.
 
python/languages.py DELETED
@@ -1,102 +0,0 @@
-WHISPER_LANGUAGES = {
-    "en": "english",
-    "zh": "chinese",
-    "de": "german",
-    "es": "spanish",
-    "ru": "russian",
-    "ko": "korean",
-    "fr": "french",
-    "ja": "japanese",
-    "pt": "portuguese",
-    "tr": "turkish",
-    "pl": "polish",
-    "ca": "catalan",
-    "nl": "dutch",
-    "ar": "arabic",
-    "sv": "swedish",
-    "it": "italian",
-    "id": "indonesian",
-    "hi": "hindi",
-    "fi": "finnish",
-    "vi": "vietnamese",
-    "he": "hebrew",
-    "uk": "ukrainian",
-    "el": "greek",
-    "ms": "malay",
-    "cs": "czech",
-    "ro": "romanian",
-    "da": "danish",
-    "hu": "hungarian",
-    "ta": "tamil",
-    "no": "norwegian",
-    "th": "thai",
-    "ur": "urdu",
-    "hr": "croatian",
-    "bg": "bulgarian",
-    "lt": "lithuanian",
-    "la": "latin",
-    "mi": "maori",
-    "ml": "malayalam",
-    "cy": "welsh",
-    "sk": "slovak",
-    "te": "telugu",
-    "fa": "persian",
-    "lv": "latvian",
-    "bn": "bengali",
-    "sr": "serbian",
-    "az": "azerbaijani",
-    "sl": "slovenian",
-    "kn": "kannada",
-    "et": "estonian",
-    "mk": "macedonian",
-    "br": "breton",
-    "eu": "basque",
-    "is": "icelandic",
-    "hy": "armenian",
-    "ne": "nepali",
-    "mn": "mongolian",
-    "bs": "bosnian",
-    "kk": "kazakh",
-    "sq": "albanian",
-    "sw": "swahili",
-    "gl": "galician",
-    "mr": "marathi",
-    "pa": "punjabi",
-    "si": "sinhala",
-    "km": "khmer",
-    "sn": "shona",
-    "yo": "yoruba",
-    "so": "somali",
-    "af": "afrikaans",
-    "oc": "occitan",
-    "ka": "georgian",
-    "be": "belarusian",
-    "tg": "tajik",
-    "sd": "sindhi",
-    "gu": "gujarati",
-    "am": "amharic",
-    "yi": "yiddish",
-    "lo": "lao",
-    "uz": "uzbek",
-    "fo": "faroese",
-    "ht": "haitian creole",
-    "ps": "pashto",
-    "tk": "turkmen",
-    "nn": "nynorsk",
-    "mt": "maltese",
-    "sa": "sanskrit",
-    "lb": "luxembourgish",
-    "my": "myanmar",
-    "bo": "tibetan",
-    "tl": "tagalog",
-    "mg": "malagasy",
-    "as": "assamese",
-    "tt": "tatar",
-    "haw": "hawaiian",
-    "ln": "lingala",
-    "ha": "hausa",
-    "ba": "bashkir",
-    "jw": "javanese",
-    "su": "sundanese",
-    "yue": "cantonese",
-}
python/main.py DELETED
@@ -1,74 +0,0 @@
-import argparse
-import os
-from whisper import Whisper
-import time
-
-
-def get_args():
-    parser = argparse.ArgumentParser(
-        prog="whisper", description="Run Whisper on input audio file"
-    )
-    parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
-    parser.add_argument(
-        "--model_type",
-        "-t",
-        type=str,
-        choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
-        required=True,
-        help="model type, only support tiny, base and small currently",
-    )
-    parser.add_argument(
-        "--model_path",
-        "-p",
-        type=str,
-        required=False,
-        default="../models-ax650",
-        help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
-    )
-    parser.add_argument(
-        "--language",
-        "-l",
-        type=str,
-        required=False,
-        default="zh",
-        help="Target language, support en, zh, ja, and others. See languages.py for more options.",
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        required=False,
-        choices=["translate", "transcribe"],
-        default="transcribe",
-    )
-    parser.add_argument(
-        "--print_rtf", action="store_true", help="Print Real-Time Factor"
-    )
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-    print(vars(args))
-
-    # Check wav existence
-    wav_path = args.wav
-    assert os.path.exists(wav_path), f"{wav_path} NOT exist"
-
-    model = Whisper(args.model_type, args.model_path, args.language, args.task)
-
-    print("ASR result:")
-    start = time.time()
-    print(model.run(wav_path))
-    end = time.time()
-
-    if args.print_rtf:
-        import librosa
-
-        samples, sr = librosa.load(wav_path, sr=16000)
-        duration = len(samples) / sr
-        process_time = end - start
-        print(f"RTF: {process_time / duration}")
-
-
-if __name__ == "__main__":
-    main()
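The deleted main.py computed an RTF (Real-Time Factor): wall-clock processing time divided by audio duration, where values below 1.0 mean faster-than-real-time transcription. The same arithmetic, stripped of the librosa dependency:

```python
def real_time_factor(process_time_s: float, n_samples: int, sample_rate: int = 16000) -> float:
    """RTF = wall-clock processing time / audio duration in seconds."""
    duration = n_samples / sample_rate
    return process_time_s / duration


# e.g. 2.0 s of processing for 5 s of 16 kHz audio -> RTF 0.4
rtf = real_time_factor(2.0, 80000)
```

The Python runner measured `process_time_s` with `time.time()` around `model.run()` and took `n_samples` from the loaded 16 kHz waveform.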
python/test_svr.py ADDED
@@ -0,0 +1,46 @@
+import requests
+
+
+def transcribe_audio(
+    server_url: str,
+    wav_path: str,
+    model_type: str = "tiny",
+    model_path: str = "../models-ax650",
+    language: str = "zh",
+    task: str = "transcribe",
+):
+    url = f"{server_url.rstrip('/')}/asr"
+
+    files = {
+        "wav": open(wav_path, "rb"),
+    }
+
+    data = {
+        "model_type": model_type,
+        "model_path": model_path,
+        "language": language,
+        "task": task,
+    }
+
+    print(f"Sending request to: {url}")
+
+    response = requests.post(url, files=files, data=data)
+    if response.status_code != 200:
+        print("❌ Error:", response.text)
+        return None
+
+    result = response.json()
+    print("Server response:")
+    print(result)
+
+    return result
+
+
+if __name__ == "__main__":
+    # server address
+    SERVER = "http://127.0.0.1:8000"
+
+    # local wav file path
+    WAV = "../demo.wav"
+
+    transcribe_audio(SERVER, WAV)
python/test_wer.py CHANGED
@@ -2,14 +2,19 @@ import argparse
 import os
 import logging
 import re
-from whisper import Whisper
+import pandas as pd
+from typing import Tuple
+import numpy as np
+import soundfile as sf
+import zhconv
+import librosa
 
 
-def setup_logging():
+def setup_logging(filename):
     """Configure logging to output to both the console and a file"""
     # directory containing this script
     script_dir = os.path.dirname(os.path.abspath(__file__))
-    log_file = os.path.join(script_dir, "test_wer.log")
+    log_file = os.path.join(script_dir, f"{filename}.log")
 
     # log format
     log_format = "%(asctime)s - %(levelname)s - %(message)s"
@@ -24,7 +29,7 @@ def setup_logging():
     logger.removeHandler(handler)
 
     # create the file handler
-    file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
+    file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
     file_handler.setLevel(logging.INFO)
     file_formatter = logging.Formatter(log_format, date_format)
     file_handler.setFormatter(file_formatter)
@@ -42,6 +47,61 @@ def setup_logging():
     return logger
 
 
+def load_audio(filename: str) -> Tuple[np.ndarray, int]:
+    data, sample_rate = sf.read(
+        filename,
+        always_2d=True,
+        dtype="float32",
+    )
+    data = data[:, 0]  # use only the first channel
+    if sample_rate != 16000:
+        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
+        sample_rate = 16000
+    samples = np.ascontiguousarray(data)
+    return samples, sample_rate
+
+
+def compute_feat(filename: str, n_mels: int = 80):
+    audio, sample_rate = load_audio(filename)
+    if sample_rate != 16000:
+        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
+        sample_rate = 16000
+
+    mel = librosa.feature.melspectrogram(
+        y=audio,
+        sr=sample_rate,
+        n_fft=480,
+        hop_length=160,
+        window="hann",
+        center=True,
+        pad_mode="reflect",
+        power=2.0,
+        n_mels=n_mels,
+    )
+
+    log_spec = np.log10(np.maximum(mel, 1e-10))
+    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+    mel = (log_spec + 4.0) / 4.0
+
+    target = 3000
+    if mel.shape[1] > target:
+        # -50 so that there are some zero tail paddings.
+        mel = mel[:, :target]
+        mel[:, -50:] = 0
+
+    # We don't need to pad it to 30 seconds now!
+    if mel.shape[1] < target:
+        mel = np.concatenate(
+            (
+                mel,
+                np.zeros((n_mels, target - mel.shape[1]), dtype=np.float32),
+            ),
+            axis=-1,
+        )
+
+    return mel[np.newaxis, ...]
+
+
 class AIShellDataset:
     def __init__(self, gt_path: str):
         """
@@ -149,6 +209,56 @@ class CommonVoiceDataset:
         return len(self.data)
 
 
+class CustomDataset:
+    """Custom dataset parser"""
+
+    def __init__(self, label_path: str):
+        """
+        Initialize the dataset
+        """
+
+        self.label_path = label_path
+        self.dataset_dir = os.path.dirname(label_path)
+
+        # check that the required files exist
+        assert os.path.exists(label_path), f"File not found: {label_path}"
+
+        # load the csv
+        self.data = []
+        df = pd.read_csv(label_path, sep="\t")
+        for i, row in df.iterrows():
+            audio_path = os.path.join(
+                self.dataset_dir, row["SPEAKER_ID"], row["UTTRANS_ID"]
+            )
+            gt = row["TRANSCRIPTION"]
+            self.data.append({"audio_path": audio_path, "gt": gt})
+
+        # use logging instead of print
+        logger = logging.getLogger()
+        logger.info(f"Loaded {len(self.data)} samples")
+
+    def __iter__(self):
+        """Return an iterator"""
+        self.index = 0
+        return self
+
+    def __next__(self):
+        """Return the next item"""
+        if self.index >= len(self.data):
+            raise StopIteration
+
+        item = self.data[self.index]
+        audio_path = item["audio_path"]
+        ground_truth = item["gt"]
+
+        self.index += 1
+        return audio_path, ground_truth
+
+    def __len__(self):
+        """Return the dataset size"""
+        return len(self.data)
+
+
 def get_args():
     parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
     parser.add_argument(
@@ -156,7 +266,7 @@ def get_args():
         "-d",
         type=str,
         required=True,
-        choices=["aishell", "common_voice"],
+        choices=["aishell", "common_voice", "custom"],
         help="Test dataset",
     )
     parser.add_argument(
@@ -173,7 +283,7 @@ def get_args():
         "--model_type",
         "-t",
         type=str,
-        choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
+        choices=["tiny", "base", "small", "medium", "large", "large-v3", "turbo"],
         required=True,
         help="model type, only support tiny, base and small currently",
     )
@@ -182,8 +292,11 @@ def get_args():
         "-p",
         type=str,
         required=False,
-        default="../models/models-ax650",
-        help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
+        default="../models-ax650",
+        help="model path for *.axmodel, tokens.txt",
+    )
+    parser.add_argument(
+        "--repo_id", type=str, default=None, help="repo id from huggingface"
     )
     parser.add_argument(
         "--language",
@@ -193,17 +306,16 @@ def get_args():
         default="zh",
         help="Target language, support en, zh, ja, and others. See languages.py for more options.",
     )
+    parser.add_argument(
+        "--backend", type=str, default="ax", choices=["ax", "torch", "onnx"]
+    )
+    parser.add_argument("--log_name", type=str, default="test_wer")
     return parser.parse_args()
 
 
 def print_args(args):
     logger = logging.getLogger()
-    logger.info(f"dataset: {args.dataset}")
-    logger.info(f"gt_path: {args.gt_path}")
-    logger.info(f"max_num: {args.max_num}")
-    logger.info(f"model_type: {args.model_type}")
-    logger.info(f"model_path: {args.model_path}")
-    logger.info(f"language: {args.language}")
+    logger.info(vars(args))
 
 
 def min_distance(word1: str, word2: str) -> int:
@@ -247,10 +359,10 @@ def remove_punctuation(text):
 
 
 def main():
-    # set up logging
-    logger = setup_logging()
-
     args = get_args()
+
+    # set up logging
+    logger = setup_logging(args.log_name)
     print_args(args)
 
    dataset_type = args.dataset.lower()
@@ -258,26 +370,88 @@ def main():
         dataset = AIShellDataset(args.gt_path)
     elif dataset_type == "common_voice":
         dataset = CommonVoiceDataset(args.gt_path)
+    elif dataset_type == "custom":
+        dataset = CustomDataset(args.gt_path)
     else:
         raise ValueError(f"Unknown dataset type {dataset_type}")
 
     max_num = args.max_num
 
     # Load model
-    model = Whisper(args.model_type, args.model_path, args.language, "transcribe")
+    use_hf_model = False
+    tokenizer = None
+    task = "transcribe"
+
+    if args.backend == "ax":
+        from whisper_ax import Whisper
+
+        model = Whisper(args.model_type, args.model_path, args.language, task)
+    elif args.backend == "torch":
+        if args.repo_id is not None:
+            use_hf_model = True
+
+            from transformers import WhisperForConditionalGeneration
+            import torch
+
+            model = WhisperForConditionalGeneration.from_pretrained(
+                args.repo_id,
+                dtype=torch.float32,
+            ).cpu()
+        else:
+            import whisper
+
+            model = whisper.load_model(args.model_type).cpu()
+
+        tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
+    elif args.backend == "onnx":
+        import onnxruntime as ort
+        from ..model_convert.generate_data import OnnxModel
+
+        encoder_path = os.path.join(
+            args.model_path, f"{args.model_type}/{args.model_type}-encoder.onnx"
+        )
+        decoder_path = os.path.join(
+            args.model_path, f"{args.model_type}/{args.model_type}-decoder.onnx"
+        )
+        model = OnnxModel(encoder_path, decoder_path)
 
     # Iterate over dataset
     references = []
     hyp = []
     all_character_error_num = 0
     all_character_num = 0
-    wer_file = open("wer.txt", "w")
     max_data_num = max_num if max_num > 0 else len(dataset)
     for n, (audio_path, reference) in enumerate(dataset):
-        hypothesis = model.run(audio_path)
+        if args.backend == "ax":
+            hypothesis = model.run(audio_path)
+        elif args.backend == "torch":
+            if use_hf_model:
+                with torch.no_grad():
+                    feature = compute_feat(audio_path, model.config.num_mel_bins)
+                    r = model.generate(
+                        torch.from_numpy(feature),
+                        output_scores=True,
+                        return_dict_in_generate=True,
+                        return_timestamps=False,
+                        language=args.language,
+                        task="transcribe",
+                    )
+
+                tokens = r["sequences"][0][4:-1]
+                hypothesis = "".join(tokenizer.decode(tokens)).strip()
+            else:
+                result = model.transcribe(
+                    audio_path, fp16=False, language=args.language
+                )
+                hypothesis = result["text"]
+                if args.language == "zh":
+                    hypothesis = zhconv.convert(hypothesis, "zh-hans")
+
+        elif args.backend == "onnx":
+            hypothesis = model.run(audio_path, args.language, task)
 
-        hypothesis = remove_punctuation(hypothesis)
-        reference = remove_punctuation(reference)
+        hypothesis = remove_punctuation(hypothesis).lower()
+        reference = remove_punctuation(reference).lower()
 
         character_error_num = min_distance(reference, hypothesis)
         character_num = len(reference)
@@ -290,7 +464,6 @@ def main():
         references.append(reference)
 
         line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
-        wer_file.write(line_content + "\n")
         logger.info(line_content)
 
         if n + 1 >= max_data_num:
@@ -299,8 +472,6 @@ def main():
     total_character_error_rate = all_character_error_num / all_character_num * 100
 
     logger.info(f"Total WER: {total_character_error_rate}%")
-    wer_file.write(f"Total WER: {total_character_error_rate}%")
-    wer_file.close()
 
 
 if __name__ == "__main__":
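test_wer.py scores each hypothesis with `min_distance` (character-level edit distance) and reports CER as errors over reference length. A self-contained sketch of that metric; the rolling-array DP here is a standard Levenshtein implementation, not necessarily line-for-line what the repo uses:

```python
def min_distance(ref: str, hyp: str) -> int:
    """Levenshtein edit distance via dynamic programming (one rolling row)."""
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))  # distance from "" to each prefix of hyp
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i  # prev holds dp_old[j-1]
        for j in range(1, n + 1):
            cur = dp[j]
            if ref[i - 1] == hyp[j - 1]:
                dp[j] = prev  # characters match: no extra edit
            else:
                # substitution, deletion, insertion
                dp[j] = 1 + min(prev, dp[j], dp[j - 1])
            prev = cur
    return dp[n]


def cer(ref: str, hyp: str) -> float:
    """Character error rate in percent, as logged per utterance above."""
    return min_distance(ref, hyp) / len(ref) * 100
```

The total WER in the script is the same ratio computed over the summed error and character counts of the whole dataset rather than per utterance.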
python/{whisper.py → whisper_ax.py} RENAMED
@@ -2,11 +2,11 @@ import axengine as axe
2
  import numpy as np
3
  import librosa
4
  import os
5
- from typing import Union
6
- from whisper_tokenizer import *
7
  import json
8
- from dataclasses import dataclass
9
  import zhconv
 
10
 
11
 
12
  @dataclass
@@ -34,11 +34,9 @@ class WhisperConfig:
34
 
35
  class Whisper:
36
  def __init__(self, model_type: str, model_path: str, language: str, task: str):
37
- assert task in ["translate", "transcribe"]
38
-
39
  self.language = language
40
  self.task = task
41
- self.encoder, self.decoder, self.tokenizer, model_config = self.load_model(
42
  model_type, model_path, language, task
43
  )
44
  self.config = self.load_config(model_config)
@@ -73,16 +71,20 @@ class Whisper:
73
  model_config["all_language_codes"] = [
74
  i for i in model_config["all_language_codes"].split(",")
75
  ]
76
- tokenizer = get_tokenizer(
77
- model_config["is_multilingual"],
78
- num_languages=len(model_config["all_language_codes"]),
79
- language=language,
80
- task=task,
81
- )
82
 
83
  self.id2token = self.load_tokens(required_files[3])
 
 
 
 
 
 
 
 
 
 
84
 
85
- return encoder, decoder, tokenizer, model_config
86
 
87
  def load_config(self, model_config):
88
  config = WhisperConfig
@@ -109,6 +111,7 @@ class Whisper:
109
  task_token = (
110
  config.transcribe if self.task == "transcribe" else config.translate
111
  )
 
112
  config.sot_sequence = np.array(
113
  [config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
114
  )
@@ -124,9 +127,14 @@ class Whisper:
124
  return tokens
125
 
126
  def load_audio(self, audio: str):
127
- data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
128
- samples = np.ascontiguousarray(data)
129
- return samples, sample_rate
 
 
 
 
 
130
 
131
  def compute_feature(self, audio: np.ndarray):
132
  mel = librosa.feature.melspectrogram(
@@ -189,19 +197,27 @@ class Whisper:
189
  return out
190
 
191
  def get_self_cache(self) -> List[np.ndarray]:
192
- self_cache = []
193
  batch_size = 1
194
- for i in range(self.config.n_text_layer):
195
- k = np.zeros(
196
- (batch_size, self.config.n_text_ctx, self.config.n_text_state),
197
- dtype=np.float32,
198
- )
199
- v = np.zeros(
200
- (batch_size, self.config.n_text_ctx, self.config.n_text_state),
201
- dtype=np.float32,
202
- )
203
- self_cache.extend([k, v])
204
- return self_cache
 
 
 
 
 
 
 
 
 
205
 
206
  def causal_mask_1d(self, n: int, L: int):
207
  """
@@ -214,47 +230,46 @@ class Whisper:
214
  mask[:n] = 0
215
  return mask
216
 
217
- def run(self, audio: Union[str, np.ndarray]) -> str:
218
- if isinstance(audio, str):
219
- audio, sample_rate = self.load_audio(audio)
220
-
221
- mel = self.compute_feature(audio)
222
-
223
- cross_kv = self.run_encoder(mel)
224
 
225
- self_kv = self.get_self_cache()
226
 
227
  offset = np.array([0], dtype=np.int32)
228
  for t in self.config.sot_sequence:
229
  token = np.array([[t]], dtype=np.int32) # sot
230
  mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
231
 
232
- out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
 
 
233
 
234
- for i in range(1, len(out)):
235
- self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
236
 
237
  offset += 1
238
 
239
- idx = out[0][0, 0].argmax()
240
 
241
  eot = self.config.eot
242
 
243
  ans = []
244
 
245
- while idx != eot and offset.item() < 100:
246
  ans.append(idx)
247
  token = np.array([[idx]], dtype=np.int32)
248
 
249
  mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
250
 
251
- out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
 
 
252
 
253
- for i in range(1, len(out)):
254
- self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
255
 
256
  offset += 1
257
- idx = out[0][0, 0].argmax()
258
 
259
  # print(ans)
260
 
@@ -273,3 +288,19 @@ class Whisper:
273
  return text
274
 
275
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import librosa
4
  import os
5
+ from typing import Union, List
 
6
  import json
7
+ from dataclasses import dataclass, field
8
  import zhconv
9
+ import base64
10
 
11
 
12
  @dataclass
 
34
 
35
  class Whisper:
36
  def __init__(self, model_type: str, model_path: str, language: str, task: str):
 
 
37
  self.language = language
38
  self.task = task
39
+ self.encoder, self.decoder, model_config = self.load_model(
40
  model_type, model_path, language, task
41
  )
42
  self.config = self.load_config(model_config)
 
71
  model_config["all_language_codes"] = [
72
  i for i in model_config["all_language_codes"].split(",")
73
  ]
 
 
 
 
 
 
74
 
75
  self.id2token = self.load_tokens(required_files[3])
76
+ self.lang2token = {
77
+ k: v
78
+ for k, v in zip(
79
+ model_config["all_language_codes"], model_config["all_language_tokens"]
80
+ )
81
+ }
82
+ self.task2token = {
83
+ "transcribe": model_config["transcribe"],
84
+ "translate": model_config["translate"],
85
+ }
86
 
87
+ return encoder, decoder, model_config
88
 
89
  def load_config(self, model_config):
90
  config = WhisperConfig
 
111
  task_token = (
112
  config.transcribe if self.task == "transcribe" else config.translate
113
  )
114
+
115
  config.sot_sequence = np.array(
116
  [config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
117
  )
 
127
  return tokens
128
 
129
  def load_audio(self, audio: str):
130
+ samples, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
131
+ if sample_rate != self.config.sample_rate:
132
+ samples = librosa.resample(
133
+ samples, orig_sr=sample_rate, target_sr=self.config.sample_rate
134
+ )
135
+
136
+ samples = np.ascontiguousarray(samples)
137
+ return samples, self.config.sample_rate
138
 
139
  def compute_feature(self, audio: np.ndarray):
140
  mel = librosa.feature.melspectrogram(
 
          return out

      def get_self_cache(self) -> List[np.ndarray]:
          batch_size = 1
+
+         self_k = np.zeros(
+             (
+                 self.config.n_text_layer,
+                 batch_size,
+                 self.config.n_text_ctx,
+                 self.config.n_text_state,
+             ),
+             dtype=np.float32,
+         )
+         self_v = np.zeros(
+             (
+                 self.config.n_text_layer,
+                 batch_size,
+                 self.config.n_text_ctx,
+                 self.config.n_text_state,
+             ),
+             dtype=np.float32,
+         )
+         return self_k, self_v

      def causal_mask_1d(self, n: int, L: int):
          """

          mask[:n] = 0
          return mask

+     def run_mel(self, mel):
+         cross_k, cross_v = self.run_encoder(mel)
+
+         self_k, self_v = self.get_self_cache()

          offset = np.array([0], dtype=np.int32)
          for t in self.config.sot_sequence:
              token = np.array([[t]], dtype=np.int32)  # sot
              mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)

+             logits, this_self_k, this_self_v = self.run_decoder(
+                 [token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
+             )

+             self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
+             self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v

              offset += 1

+         idx = logits[0, 0].argmax()

          eot = self.config.eot

          ans = []

+         while idx != eot and offset.item() < self.config.n_text_ctx:
              ans.append(idx)
              token = np.array([[idx]], dtype=np.int32)

              mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)

+             logits, this_self_k, this_self_v = self.run_decoder(
+                 [token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
+             )

+             self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
+             self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v

              offset += 1
+             idx = logits[0, 0].argmax()

          # print(ans)

              return text

          return text
+
+     def run(
+         self, audio: Union[str, np.ndarray], language: str = None, task: str = None
+     ) -> str:
+         if isinstance(audio, str):
+             audio, sample_rate = self.load_audio(audio)
+
+         mel = self.compute_feature(audio)
+
+         if language is not None and self.language != language:
+             self.config.sot_sequence[1] = self.lang2token[language]
+
+         if task is not None and self.task != task:
+             self.config.sot_sequence[2] = self.task2token[task]
+
+         return self.run_mel(mel)
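The decode loop in `run_mel` above is a greedy, one-token-per-step decoder: each step produces a single-position self-attention K/V slice that is written back into the preallocated cache at the current `offset`. A minimal numpy sketch of that cache update — the tiny dimensions here are illustrative stand-ins, not real `WhisperConfig` values:

```python
import numpy as np

# Illustrative dimensions; real values come from the model config.
n_text_layer, batch, n_text_ctx, n_text_state = 2, 1, 8, 4

# Preallocated self-attention cache, as in get_self_cache().
self_k = np.zeros((n_text_layer, batch, n_text_ctx, n_text_state), dtype=np.float32)

# One decode step yields a single-position slice per layer.
offset = 3
this_self_k = np.ones((n_text_layer, batch, 1, n_text_state), dtype=np.float32)

# Write it into the cache at the current position, as run_mel() does.
self_k[:, :, offset : offset + 1, :] = this_self_k

assert self_k[:, :, offset, :].all()            # position 3 is now filled
assert not self_k[:, :, offset + 1 :, :].any()  # later positions remain empty
```

The slice-assign keeps the cache a fixed `n_text_ctx`-length buffer, so decoding never reallocates; the causal mask from `causal_mask_1d` then limits attention to positions below `offset`.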
python/whisper_cli.py CHANGED
@@ -1,46 +1,70 @@
- import requests
-
-
- def transcribe_audio(
-     server_url: str,
-     wav_path: str,
-     model_type: str = "tiny",
-     model_path: str = "../models/models-ax650",
-     language: str = "zh",
-     task: str = "transcribe",
- ):
-     url = f"{server_url.rstrip('/')}/asr"
-
-     files = {
-         "wav": open(wav_path, "rb"),
-     }
-
-     data = {
-         "model_type": model_type,
-         "model_path": model_path,
-         "language": language,
-         "task": task,
-     }
-
-     print(f"Sending request to: {url}")
-
-     response = requests.post(url, files=files, data=data)
-     if response.status_code != 200:
-         print("❌ Error:", response.text)
-         return None
-
-     result = response.json()
-     print("Server response:")
-     print(result)
-
-     return result
-
-
- if __name__ == "__main__":
-     # Your server address
-     SERVER = "http://127.0.0.1:8000"
-
-     # Local wav file path
-     WAV = "../demo.wav"
-
-     transcribe_audio(SERVER, WAV)
 
 
+ import argparse
+ import os
+ from whisper_ax import Whisper
+ import time
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(
+         prog="whisper", description="Run Whisper on input audio file"
+     )
+     parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
+     parser.add_argument(
+         "--model_type",
+         "-t",
+         type=str,
+         choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
+         required=True,
+         help="Model type (tiny, base, small, large, large-v3, turbo)",
+     )
+     parser.add_argument(
+         "--model_path",
+         "-p",
+         type=str,
+         required=False,
+         default="../models-ax650",
+         help="Model path for *.axmodel, tokens.txt",
+     )
+     parser.add_argument(
+         "--language",
+         "-l",
+         type=str,
+         required=False,
+         default="zh",
+         help="Target language, e.g. en, zh, ja. See languages.py for more options.",
+     )
+     parser.add_argument(
+         "--task",
+         type=str,
+         required=False,
+         choices=["translate", "transcribe"],
+         default="transcribe",
+     )
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+     print(vars(args))
+
+     # Check wav existence
+     wav_path = args.wav
+     assert os.path.exists(wav_path), f"{wav_path} does not exist"
+
+     model = Whisper(args.model_type, args.model_path, args.language, args.task)
+
+     print("ASR result:")
+     start = time.time()
+     print(model.run(wav_path))
+     end = time.time()
+
+     import librosa
+
+     samples, sr = librosa.load(wav_path, sr=16000)
+     duration = len(samples) / sr
+     process_time = end - start
+     print(f"RTF: {process_time / duration}")
+
+
+ if __name__ == "__main__":
+     main()
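The HTTP client deleted above still documents the request shape that `whisper_svr.py` serves. A minimal sketch that only assembles the request, with the `/asr` endpoint and form fields taken from the removed client — treat these field names as assumptions, since the server contract may drift:

```python
def build_asr_request(
    server_url: str,
    wav_path: str,
    model_type: str = "tiny",
    model_path: str = "../models/models-ax650",
    language: str = "zh",
    task: str = "transcribe",
):
    """Assemble the URL and form fields the removed client sent to /asr."""
    url = f"{server_url.rstrip('/')}/asr"
    # File part: the caller opens the file only when actually posting.
    files = {"wav": wav_path}
    data = {
        "model_type": model_type,
        "model_path": model_path,
        "language": language,
        "task": task,
    }
    return url, files, data


url, files, data = build_asr_request("http://127.0.0.1:8000/", "../demo.wav")
print(url)  # http://127.0.0.1:8000/asr
```

Posting then mirrors the deleted code: `requests.post(url, files={"wav": open(wav_path, "rb")}, data=data)`.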
python/whisper_svr.py CHANGED
@@ -5,7 +5,7 @@ import tempfile
  from http.server import BaseHTTPRequestHandler, HTTPServer
  from urllib.parse import parse_qs

- from whisper import Whisper
+ from whisper_ax import Whisper
  import cgi

python/whisper_tokenizer.py DELETED
@@ -1,395 +0,0 @@
- import base64
- import os
- import string
- from dataclasses import dataclass, field
- from functools import cached_property, lru_cache
- from typing import Dict, List, Optional, Tuple
-
- import tiktoken
-
- LANGUAGES = {
-     "en": "english",
-     "zh": "chinese",
-     "de": "german",
-     "es": "spanish",
-     "ru": "russian",
-     "ko": "korean",
-     "fr": "french",
-     "ja": "japanese",
-     "pt": "portuguese",
-     "tr": "turkish",
-     "pl": "polish",
-     "ca": "catalan",
-     "nl": "dutch",
-     "ar": "arabic",
-     "sv": "swedish",
-     "it": "italian",
-     "id": "indonesian",
-     "hi": "hindi",
-     "fi": "finnish",
-     "vi": "vietnamese",
-     "he": "hebrew",
-     "uk": "ukrainian",
-     "el": "greek",
-     "ms": "malay",
-     "cs": "czech",
-     "ro": "romanian",
-     "da": "danish",
-     "hu": "hungarian",
-     "ta": "tamil",
-     "no": "norwegian",
-     "th": "thai",
-     "ur": "urdu",
-     "hr": "croatian",
-     "bg": "bulgarian",
-     "lt": "lithuanian",
-     "la": "latin",
-     "mi": "maori",
-     "ml": "malayalam",
-     "cy": "welsh",
-     "sk": "slovak",
-     "te": "telugu",
-     "fa": "persian",
-     "lv": "latvian",
-     "bn": "bengali",
-     "sr": "serbian",
-     "az": "azerbaijani",
-     "sl": "slovenian",
-     "kn": "kannada",
-     "et": "estonian",
-     "mk": "macedonian",
-     "br": "breton",
-     "eu": "basque",
-     "is": "icelandic",
-     "hy": "armenian",
-     "ne": "nepali",
-     "mn": "mongolian",
-     "bs": "bosnian",
-     "kk": "kazakh",
-     "sq": "albanian",
-     "sw": "swahili",
-     "gl": "galician",
-     "mr": "marathi",
-     "pa": "punjabi",
-     "si": "sinhala",
-     "km": "khmer",
-     "sn": "shona",
-     "yo": "yoruba",
-     "so": "somali",
-     "af": "afrikaans",
-     "oc": "occitan",
-     "ka": "georgian",
-     "be": "belarusian",
-     "tg": "tajik",
-     "sd": "sindhi",
-     "gu": "gujarati",
-     "am": "amharic",
-     "yi": "yiddish",
-     "lo": "lao",
-     "uz": "uzbek",
-     "fo": "faroese",
-     "ht": "haitian creole",
-     "ps": "pashto",
-     "tk": "turkmen",
-     "nn": "nynorsk",
-     "mt": "maltese",
-     "sa": "sanskrit",
-     "lb": "luxembourgish",
-     "my": "myanmar",
-     "bo": "tibetan",
-     "tl": "tagalog",
-     "mg": "malagasy",
-     "as": "assamese",
-     "tt": "tatar",
-     "haw": "hawaiian",
-     "ln": "lingala",
-     "ha": "hausa",
-     "ba": "bashkir",
-     "jw": "javanese",
-     "su": "sundanese",
-     "yue": "cantonese",
- }
-
- # language code lookup by name, with a few language aliases
- TO_LANGUAGE_CODE = {
-     **{language: code for code, language in LANGUAGES.items()},
-     "burmese": "my",
-     "valencian": "ca",
-     "flemish": "nl",
-     "haitian": "ht",
-     "letzeburgesch": "lb",
-     "pushto": "ps",
-     "panjabi": "pa",
-     "moldavian": "ro",
-     "moldovan": "ro",
-     "sinhalese": "si",
-     "castilian": "es",
-     "mandarin": "zh",
- }
-
-
- @dataclass
- class Tokenizer:
-     """A thin wrapper around `tiktoken` providing quick access to special tokens"""
-
-     encoding: tiktoken.Encoding
-     num_languages: int
-     language: Optional[str] = None
-     task: Optional[str] = None
-     sot_sequence: Tuple[int] = ()
-     special_tokens: Dict[str, int] = field(default_factory=dict)
-
-     def __post_init__(self):
-         for special in self.encoding.special_tokens_set:
-             special_token = self.encoding.encode_single_token(special)
-             self.special_tokens[special] = special_token
-
-         sot: int = self.special_tokens["<|startoftranscript|>"]
-         translate: int = self.special_tokens["<|translate|>"]
-         transcribe: int = self.special_tokens["<|transcribe|>"]
-
-         langs = tuple(LANGUAGES.keys())[: self.num_languages]
-         sot_sequence = [sot]
-         if self.language is not None:
-             sot_sequence.append(sot + 1 + langs.index(self.language))
-         if self.task is not None:
-             task_token: int = transcribe if self.task == "transcribe" else translate
-             sot_sequence.append(task_token)
-
-         self.sot_sequence = tuple(sot_sequence)
-
-     def encode(self, text, **kwargs):
-         return self.encoding.encode(text, **kwargs)
-
-     def decode(self, token_ids: List[int], **kwargs) -> str:
-         token_ids = [t for t in token_ids if t < self.timestamp_begin]
-         return self.encoding.decode(token_ids, **kwargs)
-
-     def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
-         """
-         Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
-         This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
-         """
-         return self.encoding.decode(token_ids, **kwargs)
-
-     @cached_property
-     def eot(self) -> int:
-         return self.encoding.eot_token
-
-     @cached_property
-     def transcribe(self) -> int:
-         return self.special_tokens["<|transcribe|>"]
-
-     @cached_property
-     def translate(self) -> int:
-         return self.special_tokens["<|translate|>"]
-
-     @cached_property
-     def sot(self) -> int:
-         return self.special_tokens["<|startoftranscript|>"]
-
-     @cached_property
-     def sot_lm(self) -> int:
-         return self.special_tokens["<|startoflm|>"]
-
-     @cached_property
-     def sot_prev(self) -> int:
-         return self.special_tokens["<|startofprev|>"]
-
-     @cached_property
-     def no_speech(self) -> int:
-         return self.special_tokens["<|nospeech|>"]
-
-     @cached_property
-     def no_timestamps(self) -> int:
-         return self.special_tokens["<|notimestamps|>"]
-
-     @cached_property
-     def timestamp_begin(self) -> int:
-         return self.special_tokens["<|0.00|>"]
-
-     @cached_property
-     def language_token(self) -> int:
-         """Returns the token id corresponding to the value of the `language` field"""
-         if self.language is None:
-             raise ValueError("This tokenizer does not have language token configured")
-
-         return self.to_language_token(self.language)
-
-     def to_language_token(self, language):
-         if token := self.special_tokens.get(f"<|{language}|>", None):
-             return token
-
-         raise KeyError(f"Language {language} not found in tokenizer.")
-
-     @cached_property
-     def all_language_tokens(self) -> Tuple[int]:
-         result = []
-         for token, token_id in self.special_tokens.items():
-             if token.strip("<|>") in LANGUAGES:
-                 result.append(token_id)
-         return tuple(result)[: self.num_languages]
-
-     @cached_property
-     def all_language_codes(self) -> Tuple[str]:
-         return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
-
-     @cached_property
-     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
-         return tuple(list(self.sot_sequence) + [self.no_timestamps])
-
-     @cached_property
-     def non_speech_tokens(self) -> Tuple[int]:
-         """
-         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
-         annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
-
-         - ♪♪♪
-         - ( SPEAKING FOREIGN LANGUAGE )
-         - [DAVID] Hey there,
-
-         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
-         """
-         symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-         symbols += (
-             "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-         )
-
-         # symbols that may be a single token or multiple tokens depending on the tokenizer.
-         # In case they're multiple tokens, suppress the first token, which is safe because:
-         # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
-         # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
-         miscellaneous = set("♩♪♫♬♭♮♯")
-         assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
-
-         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-         result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
-         for symbol in symbols + list(miscellaneous):
-             for tokens in [
-                 self.encoding.encode(symbol),
-                 self.encoding.encode(" " + symbol),
-             ]:
-                 if len(tokens) == 1 or symbol in miscellaneous:
-                     result.add(tokens[0])
-
-         return tuple(sorted(result))
-
-     def split_to_word_tokens(self, tokens: List[int]):
-         if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
-             # These languages don't typically use spaces, so it is difficult to split words
-             # without morpheme analysis. Here, we instead split words at any
-             # position where the tokens are decoded as valid unicode points
-             return self.split_tokens_on_unicode(tokens)
-
-         return self.split_tokens_on_spaces(tokens)
-
-     def split_tokens_on_unicode(self, tokens: List[int]):
-         decoded_full = self.decode_with_timestamps(tokens)
-         replacement_char = "\ufffd"
-
-         words = []
-         word_tokens = []
-         current_tokens = []
-         unicode_offset = 0
-
-         for token in tokens:
-             current_tokens.append(token)
-             decoded = self.decode_with_timestamps(current_tokens)
-
-             if (
-                 replacement_char not in decoded
-                 or decoded_full[unicode_offset + decoded.index(replacement_char)]
-                 == replacement_char
-             ):
-                 words.append(decoded)
-                 word_tokens.append(current_tokens)
-                 current_tokens = []
-                 unicode_offset += len(decoded)
-
-         return words, word_tokens
-
-     def split_tokens_on_spaces(self, tokens: List[int]):
-         subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
-         words = []
-         word_tokens = []
-
-         for subword, subword_tokens in zip(subwords, subword_tokens_list):
-             special = subword_tokens[0] >= self.eot
-             with_space = subword.startswith(" ")
-             punctuation = subword.strip() in string.punctuation
-             if special or with_space or punctuation or len(words) == 0:
-                 words.append(subword)
-                 word_tokens.append(subword_tokens)
-             else:
-                 words[-1] = words[-1] + subword
-                 word_tokens[-1].extend(subword_tokens)
-
-         return words, word_tokens
-
-
- @lru_cache(maxsize=None)
- def get_encoding(name: str = "gpt2", num_languages: int = 99):
-     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
-     ranks = {
-         base64.b64decode(token): int(rank)
-         for token, rank in (line.split() for line in open(vocab_path) if line)
-     }
-     n_vocab = len(ranks)
-     special_tokens = {}
-
-     specials = [
-         "<|endoftext|>",
-         "<|startoftranscript|>",
-         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
-         "<|translate|>",
-         "<|transcribe|>",
-         "<|startoflm|>",
-         "<|startofprev|>",
-         "<|nospeech|>",
-         "<|notimestamps|>",
-         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
-     ]
-
-     for token in specials:
-         special_tokens[token] = n_vocab
-         n_vocab += 1
-
-     return tiktoken.Encoding(
-         name=os.path.basename(vocab_path),
-         explicit_n_vocab=n_vocab,
-         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
-         mergeable_ranks=ranks,
-         special_tokens=special_tokens,
-     )
-
-
- @lru_cache(maxsize=None)
- def get_tokenizer(
-     multilingual: bool,
-     *,
-     num_languages: int = 99,
-     language: Optional[str] = None,
-     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
- ) -> Tokenizer:
-     if language is not None:
-         language = language.lower()
-         if language not in LANGUAGES:
-             if language in TO_LANGUAGE_CODE:
-                 language = TO_LANGUAGE_CODE[language]
-             else:
-                 raise ValueError(f"Unsupported language: {language}")
-
-     if multilingual:
-         encoding_name = "multilingual"
-         language = language or "en"
-         task = task or "transcribe"
-     else:
-         encoding_name = "gpt2"
-         language = None
-         task = None
-
-     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
-
-     return Tokenizer(
-         encoding=encoding, num_languages=num_languages, language=language, task=task
-     )
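This tokenizer module could be deleted because the ids it derived at runtime now ship inside the exported model config, which `whisper_ax.py` reads in `load_config`. A sketch of the equivalent SOT-sequence assembly, mirroring the ordering `Tokenizer.__post_init__` built (sot, language, task) plus the `<|notimestamps|>` id the new code appends — the numeric ids below are hypothetical placeholders for values read from the config:

```python
import numpy as np

# Hypothetical ids standing in for values read from the model config JSON.
config = {
    "sot": 50258,
    "lang2token": {"zh": 50260, "en": 50259},
    "transcribe": 50359,
    "translate": 50358,
    "no_timestamps": 50363,
}


def make_sot_sequence(language: str, task: str) -> np.ndarray:
    """Mirror of the deleted __post_init__ ordering: sot, language, task, no_timestamps."""
    task_token = config["transcribe"] if task == "transcribe" else config["translate"]
    return np.array(
        [config["sot"], config["lang2token"][language], task_token, config["no_timestamps"]],
        dtype=np.int32,
    )


print(make_sot_sequence("zh", "transcribe"))  # [50258 50260 50359 50363]
```

With the sequence precomputed this way, the runtime no longer needs tiktoken or the vocabulary assets at inference time — only the `tokens.txt` id-to-string table.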