lxowalle committed
Commit ee4406b · 0 Parent(s)

* support for maixcam2
.gitattributes ADDED
@@ -0,0 +1,39 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sensevoice.axmodel filter=lfs diff=lfs merge=lfs -text
+ *.axmodel filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ .gradio
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 祈Inory
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,96 @@
+ ---
+ license: mit
+ language:
+ - en
+ pipeline_tag: automatic-speech-recognition
+ ---
+ # sensevoice.axera
+ FunASR SenseVoice on Axera. Official repo: https://github.com/FunAudioLLM/SenseVoice
+
+ ## TODO
+
+ - [x] Support AX630C
+ - [ ] Support C++
+ - [x] Support FastAPI
+
+ ## Features
+ - Speech recognition
+ - Automatic language detection (Chinese, English, Cantonese, Japanese, Korean)
+ - Emotion recognition
+ - Automatic punctuation
+ - Streaming recognition
+
+ ## Supported platforms
+
+ - [x] AX650N
+ - [x] AX630C
+
+ ## Installation
+ ```
+ pip3 install -r requirements.txt
+ ```
+ If disk space is limited, pass --prefix to install to a different location.
+
+
+ ## Usage
+ ```
+ # The first run downloads the model from Hugging Face automatically and saves it under models/
+ python3 main.py -i <input audio file>
+ ```
+ Command-line arguments:
+ | Argument | Description | Default |
+ | --- | --- | --- |
+ | --input/-i | Input audio file | |
+ | --language/-l | Recognition language: auto, zh, en, yue, ja, ko | auto |
+ | --streaming | Enable streaming recognition | |
+
+
+ ### Examples
+ Test audio files are provided under example/.
+
+ Cantonese test:
+ ```
+ python3 main.py -i example/yue.mp3
+ ```
+ Output:
+ ```
+ RTF: 0.03026517820946964 Latency: 0.15689468383789062s Total length: 5.184s
+ ['呢几个字。', '都表达唔到,我想讲嘅意。', '思。']
+ ```
+
+ Streaming recognition:
+
+ ```
+ python3 main.py -i example/zh.mp3 --streaming
+ ```
+ Output:
+ ```
+ {'timestamps': [540], 'text': '开'}
+ {'timestamps': [540, 780, 1080], 'text': '开放时'}
+ {'timestamps': [540, 780, 1080, 1260, 1740], 'text': '开放时间早'}
+ {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340], 'text': '开放时间早上9'}
+ {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640], 'text': '开放时间早上9点'}
+ {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060], 'text': '开放时间早上9点至'}
+ {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020], 'text': '开放时间早上9点至下午'}
+ {'timestamps': [540, 780, 1080, 1260, 1740, 1920, 2340, 2640, 3060, 3780, 4020, 4440, 4620], 'text': '开放时间早上9点至下午五点'}
+ RTF: 0.03678379235444246
+
+ ```
+
+ ## Accuracy
+
+ Measured with WER (Word Error Rate):
+
+ **WER = 0.0389**
+
+ ### Reproducing the test results
+
+ ```
+ ./download_dataset.sh
+ python test_wer.py -d datasets -l zh
+ ```
+
+ ## Technical discussion
+
+ - GitHub issues
+ - QQ group: 139953715
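
The pipeline added in SenseVoiceAx.py below can also be driven directly from Python rather than through main.py. A minimal sketch, assuming the non-streaming sensevoice_ax630c model files are laid out as main.py expects:

```python
from SenseVoiceAx import SenseVoiceAx

# Non-streaming pipeline; paths follow the layout used by main.py in this commit.
pipeline = SenseVoiceAx(
    "sensevoice_ax630c/sensevoice.axmodel",
    max_len=256,
    beam_size=3,
    language="auto",
    use_itn=True,
    streaming=False,
)
print(pipeline.infer("example/zh.mp3", print_rtf=True))
```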
SenseVoiceAx.py ADDED
@@ -0,0 +1,415 @@
+ import axengine as axe
+ import numpy as np
+ import librosa
+ from frontend import WavFrontend
+ import os
+ import time
+ from typing import List, Union, Optional
+ from asr_decoder import CTCDecoder
+ from tokenizer import SentencepiecesTokenizer
+ from online_fbank import OnlineFbank
+ import torch
+
+
+ def sequence_mask(lengths, maxlen=None, dtype=np.float32):
+     # If maxlen is not given, use the maximum of lengths
+     if maxlen is None:
+         maxlen = np.max(lengths)
+
+     # Create a row vector from 0 to maxlen-1
+     row_vector = np.arange(0, maxlen, 1)
+
+     # Turn lengths into a column vector
+     matrix = np.expand_dims(lengths, axis=-1)
+
+     # Compare to build the mask
+     mask = row_vector < matrix
+     if mask.shape[-1] < lengths[0]:
+         mask = np.concatenate(
+             [
+                 mask,
+                 np.zeros(
+                     (mask.shape[0], lengths[0] - mask.shape[-1]), dtype=np.float32
+                 ),
+             ],
+             axis=-1,
+         )
+
+     # Return the mask with the requested dtype
+     return mask.astype(dtype)[None, ...]
+
+
+ def unique_consecutive_np(arr):
+     """
+     Find consecutive unique values in an array, mimicking torch.unique_consecutive(yseq, dim=-1)
+
+     Args:
+         arr: 1-D numpy array
+
+     Returns:
+         unique_values: array with consecutive duplicates removed
+     """
+     if len(arr) == 0:
+         return np.array([])
+
+     if len(arr) == 1:
+         return arr.copy()
+
+     # Find the positions where the value changes
+     diff = np.diff(arr)
+     change_positions = np.where(diff != 0)[0] + 1
+
+     # Prepend the start position
+     start_positions = np.concatenate(([0], change_positions))
+
+     # Take the first value of each consecutive run
+     unique_values = arr[start_positions]
+
+     return unique_values
+
+
+ class SenseVoiceAx:
+     """SenseVoice axmodel runner"""
+
+     def __init__(
+         self,
+         model_path: str,
+         max_len: int = 256,
+         beam_size: int = 3,
+         language: str = "auto",
+         hot_words: Optional[List[str]] = None,
+         use_itn: bool = True,
+         streaming: bool = False,
+     ):
+         """
+         Initialize SenseVoiceAx
+
+         Args:
+             model_path: Path of the axmodel
+             max_len: Fixed input shape of the axmodel
+             beam_size: Max number of hypotheses to hold after each decode step
+             language: Supports auto, zh (Chinese), en (English), yue (Cantonese), ja (Japanese), ko (Korean)
+             hot_words: Words that may otherwise fail to be recognized:
+                 special words/phrases (aka hotwords) such as rare words, personalized information, etc.
+             use_itn: Enable Inverse Text Normalization if True.
+                 ITN converts ASR model output into its written form to improve readability;
+                 for example, the ITN module replaces "one hundred and twenty-three dollars" transcribed by an ASR model with "$123".
+             streaming: Process audio in small sequential segments ("chunks") and emit text on the fly.
+                 Use the stream_infer method if streaming is True, otherwise infer.
+
+         """
+         model_path_root = os.path.dirname(model_path)
+         emb_path = os.path.join(model_path_root, "../embeddings.npy")
+         cmvn_file = os.path.join(model_path_root, "../am.mvn")
+         bpe_model = os.path.join(
+             model_path_root, "../chn_jpn_yue_eng_ko_spectok.bpe.model"
+         )
+         if streaming:
+             self.position_encoding = np.load(
+                 os.path.join(model_path_root, "../pe_streaming.npy")
+             )
+         else:
+             self.position_encoding = np.load(
+                 os.path.join(model_path_root, "../pe_nonstream.npy")
+             )
+
+         self.streaming = streaming
+         self.tokenizer = SentencepiecesTokenizer(bpemodel=bpe_model)
+
+         self.frontend = WavFrontend(
+             cmvn_file=cmvn_file,
+             fs=16000,
+             window="hamming",
+             n_mels=80,
+             frame_length=25,
+             frame_shift=10,
+             lfr_m=7,
+             lfr_n=6,
+         )
+         self.model = axe.InferenceSession(model_path)
+         self.sample_rate = 16000
+         self.blank_id = 0
+         self.max_len = max_len
+         self.padding = 16
+         self.input_size = 560
+
+         self.lid_dict = {
+             "auto": 0,
+             "zh": 3,
+             "en": 4,
+             "yue": 7,
+             "ja": 11,
+             "ko": 12,
+             "nospeech": 13,
+         }
+         self.lid_int_dict = {
+             24884: 3,
+             24885: 4,
+             24888: 7,
+             24892: 11,
+             24896: 12,
+             24992: 13,
+         }
+         self.textnorm_dict = {"withitn": 14, "woitn": 15}
+         self.textnorm_int_dict = {25016: 14, 25017: 15}
+         self.emo_dict = {
+             "unk": 25009,
+             "happy": 25001,
+             "sad": 25002,
+             "angry": 25003,
+             "neutral": 25004,
+         }
+
+         self.load_embeddings(emb_path, language, use_itn)
+         self.language = language
+
+         # decoder
+         if beam_size > 1 and hot_words is not None:
+             self.beam_size = beam_size
+             symbol_table = {}
+             for i in range(self.tokenizer.get_vocab_size()):
+                 symbol_table[self.tokenizer.decode(i)] = i
+             self.decoder = CTCDecoder(hot_words, symbol_table, bpe_model)
+         else:
+             self.beam_size = 1
+             self.decoder = CTCDecoder()
+
+         if streaming:
+             self.cur_idx = -1
+             self.chunk_size = max_len - self.padding
+             self.caches_shape = (max_len, self.input_size)
+             self.caches = np.zeros(self.caches_shape, dtype=np.float32)
+             self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
+             self.neg_mean, self.inv_stddev = (
+                 self.frontend.cmvn[0, :],
+                 self.frontend.cmvn[1, :],
+             )
+
+             self.fbank = OnlineFbank(window_type="hamming")
+             self.masks = sequence_mask(
+                 np.array([self.max_len], dtype=np.int32),
+                 maxlen=self.max_len,
+                 dtype=np.float32,
+             )
+
+     @property
+     def language_options(self):
+         return list(self.lid_dict.keys())
+
+     @property
+     def textnorm_options(self):
+         return list(self.textnorm_dict.keys())
+
+     def load_embeddings(self, emb_path, language, use_itn):
+         self.embeddings = np.load(emb_path, allow_pickle=True).item()
+         self.language_query = self.embeddings[language]
+         self.textnorm_query = (
+             self.embeddings["withitn"] if use_itn else self.embeddings["woitn"]
+         )
+         self.event_emo_query = self.embeddings["event_emo"]
+         self.input_query = np.concatenate(
+             (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
+         )
+         self.query_num = self.input_query.shape[1]
+
+     def choose_language(self, language):
+         self.language_query = self.embeddings[language]
+         self.input_query = np.concatenate(
+             (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
+         )
+         self.language = language
+
+     def load_data(self, filepath: str) -> np.ndarray:
+         waveform, _ = librosa.load(filepath, sr=self.sample_rate)
+         return waveform.flatten()
+
+     @staticmethod
+     def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+         def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+             pad_width = ((0, max_feat_len - cur_len), (0, 0))
+             return np.pad(feat, pad_width, "constant", constant_values=0)
+
+         feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+         feats = np.array(feat_res).astype(np.float32)
+         return feats
+
+     def preprocess(self, waveform):
+         feats, feats_len = [], []
+         for wf in [waveform]:
+             speech, _ = self.frontend.fbank(wf)
+             feat, feat_len = self.frontend.lfr_cmvn(speech)
+             feats.append(feat)
+             feats_len.append(feat_len)
+
+         feats = self.pad_feats(feats, np.max(feats_len))
+         feats_len = np.array(feats_len).astype(np.int32)
+         return feats, feats_len
+
+     def postprocess(self, ctc_logits, encoder_out_lens):
+         # Slice out the valid frames (the first 4 are query tokens)
+         x = ctc_logits[0, 4 : encoder_out_lens[0], :]
+
+         # Argmax over the vocabulary
+         yseq = np.argmax(x, axis=-1)
+
+         # Remove consecutive duplicates
+         yseq = unique_consecutive_np(yseq)
+
+         # Build a mask and filter out blank_id
+         mask = yseq != self.blank_id
+         token_int = yseq[mask].tolist()
+
+         return token_int
+
+     def infer_waveform(self, waveform: np.ndarray, language="auto"):
+         if language != self.language:
+             self.choose_language(language)
+
+         # start = time.time()
+         feat, feat_len = self.preprocess(waveform)
+         # print(f"Preprocess take {time.time() - start}s")
+
+         slice_len = self.max_len - self.query_num
+         slice_num = int(np.ceil(feat.shape[1] / slice_len))
+
+         asr_res = []
+         for i in range(slice_num):
+             if i == 0:
+                 sub_feat = feat[:, i * slice_len : (i + 1) * slice_len, :]
+             else:
+                 sub_feat = feat[
+                     :,
+                     i * slice_len - self.padding : (i + 1) * slice_len - self.padding,
+                     :,
+                 ]
+             # concat query
+             sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)
+             real_len = sub_feat.shape[1]
+             if real_len < self.max_len:
+                 sub_feat = np.concatenate(
+                     [
+                         sub_feat,
+                         np.zeros(
+                             (1, self.max_len - real_len, sub_feat.shape[-1]),
+                             dtype=np.float32,
+                         ),
+                     ],
+                     axis=1,
+                 )
+
+             masks = sequence_mask(
+                 np.array([self.max_len], dtype=np.int32),
+                 maxlen=real_len,
+                 dtype=np.float32,
+             )
+
+             # start = time.time()
+             outputs = self.model.run(
+                 None,
+                 {
+                     "speech": sub_feat,
+                     "masks": masks,
+                     "position_encoding": self.position_encoding,
+                 },
+             )
+             ctc_logits, encoder_out_lens = outputs
+
+             token_int = self.postprocess(ctc_logits, encoder_out_lens)
+
+             if self.tokenizer is not None:
+                 asr_res.append(self.tokenizer.tokens2text(token_int))
+             else:
+                 asr_res.append(token_int)
+
+         return asr_res
+
+     def infer(
+         self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=False
+     ):
+         assert not self.streaming, "This method is for the non-streaming model"
+
+         if isinstance(filepath_or_data, str):
+             waveform = self.load_data(filepath_or_data)
+         else:
+             waveform = filepath_or_data
+
+         total_time = waveform.shape[-1] / self.sample_rate
+
+         start = time.time()
+         asr_res = self.infer_waveform(waveform, language)
+         latency = time.time() - start
+
+         if print_rtf:
+             rtf = latency / total_time
+             print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
+         return "".join(asr_res)
+
+     def decode(self, times, tokens):
+         times_ms = []
+         for step, token in zip(times, tokens):
+             if len(self.tokenizer.decode(token).strip()) == 0:
+                 continue
+             times_ms.append(step * 60)
+         return times_ms, self.tokenizer.decode(tokens)
+
+     def reset(self):
+         self.cur_idx = -1
+         self.decoder.reset()
+         self.fbank = OnlineFbank(window_type="hamming")
+         self.caches = np.zeros(self.caches_shape)
+
+     def get_size(self):
+         effective_size = self.cur_idx + 1 - self.padding
+         if effective_size <= 0:
+             return 0
+         return effective_size % self.chunk_size or self.chunk_size
+
+     def stream_infer(self, audio, is_last, language="auto"):
+         assert self.streaming, "This method is for the streaming model"
+
+         if language != self.language:
+             self.choose_language(language)
+
+         self.fbank.accept_waveform(audio, is_last)
+         features = self.fbank.get_lfr_frames(
+             neg_mean=self.neg_mean, inv_stddev=self.inv_stddev
+         )
+
+         if is_last and len(features) == 0:
+             features = self.zeros
+
+         for idx, feature in enumerate(features):
+             is_last = is_last and idx == features.shape[0] - 1
+             self.caches = np.roll(self.caches, -1, axis=0)
+             self.caches[-1, :] = feature
+             self.cur_idx += 1
+             cur_size = self.get_size()
+             if cur_size != self.chunk_size and not is_last:
+                 continue
+
+             speech = self.caches[None, ...]
+             outputs = self.model.run(
+                 None,
+                 {
+                     "speech": speech,
+                     "masks": self.masks,
+                     "position_encoding": self.position_encoding,
+                 },
+             )
+             ctc_logits, encoder_out_lens = outputs
+             probs = ctc_logits[0, 4 : encoder_out_lens[0]]
+             probs = torch.from_numpy(probs)
+
+             if cur_size != self.chunk_size:
+                 probs = probs[self.chunk_size - cur_size :]
+             if not is_last:
+                 probs = probs[: self.chunk_size]
+             if self.beam_size > 1:
+                 res = self.decoder.ctc_prefix_beam_search(
+                     probs, beam_size=self.beam_size, is_last=is_last
+                 )
+                 times_ms, text = self.decode(res["times"][0], res["tokens"][0])
+             else:
+                 res = self.decoder.ctc_greedy_search(probs, is_last=is_last)
+                 times_ms, text = self.decode(res["times"], res["tokens"])
+             yield {"timestamps": times_ms, "text": text}
am.mvn ADDED
@@ -0,0 +1,8 @@
+ <Nnet>
+ <Splice> 560 560
+ [ 0 ]
+ <AddShift> 560 560
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
+ <Rescale> 560 560
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
+ </Nnet>
chn_jpn_yue_eng_ko_spectok.bpe.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
+ size 377341
client.py ADDED
@@ -0,0 +1,146 @@
+ import requests, json, os, time
+ import librosa
+
+ class SensevoiceClient:
+     def __init__(self, model="", url="http://0.0.0.0:12347", language="auto", stream=False):
+         self.model = model
+         self.url = url
+         self.stream = stream
+         self.language = language
+
+     def _check_service(self):
+         try:
+             response = requests.get(self.url + '/status')
+             if response.status_code == 200:
+                 return True
+         except Exception:
+             pass
+         return False
+
+     def _start_service(self):
+         if not self._check_service():
+             os.system("systemctl start sensevoice.service")
+
+         while not self._check_service():
+             print("Waiting for service to start...")
+             time.sleep(1)
+
+         return True
+
+     def _stop_service(self):
+         os.system("systemctl stop sensevoice.service")
+
+     def _get_status(self):
+         try:
+             response = requests.get(self.url + '/status')
+             if response.status_code == 200:
+                 res = json.loads(response.text)
+                 return res["status"]
+         except Exception:
+             pass
+         return "not loaded"
+
+     def _start_model(self):
+         try:
+             data = {
+                 "model_path": self.model,
+                 "sample_rate": 16000,
+                 "language": self.language,
+                 "stream": self.stream
+             }
+             response = requests.post(self.url + '/start_model', json=data)
+             if response.status_code == 200:
+                 res = json.loads(response.text)
+                 return res["status"] == 'loaded'
+         except Exception:
+             pass
+         return False
+
+     def _stop_model(self):
+         try:
+             response = requests.post(self.url + '/_stop_model')
+             if response.status_code == 200:
+                 res = json.loads(response.text)
+                 return res["status"] == 'not loaded'
+         except Exception:
+             pass
+         return False
+
+     def start(self):
+         if self._start_service():
+             print("Service started successfully.")
+         else:
+             print("Failed to start service.")
+             return False
+
+         if self._start_model():
+             print("Model started successfully.")
+         else:
+             print("Failed to start model.")
+             return False
+         return True
+
+     def stop_model(self):
+         self._stop_model()
+
+     def stop(self):
+         self._stop_model()
+         self._stop_service()
+
+     def get_wave_form(self, path):
+         waveform, _ = librosa.load(path, sr=16000)
+         return waveform
+
+     def refer(self, filepath):
+         if self.stream:
+             print("Streaming mode, use refer_stream() instead.")
+             return ""
+         waveform = self.get_wave_form(filepath)
+         data = {
+             "audio_data": waveform.tolist(),
+             "sample_rate": 16000,
+             "launguage": "auto"  # key spelling kept as-is to match the service API
+         }
+         try:
+             response = requests.post(self.url + '/asr', json=data)
+             if response.status_code == 200:
+                 res = json.loads(response.text)
+                 return res.get("text", "")
+             else:
+                 print(f"Request failed: {response.status_code}")
+                 return ""
+         except Exception as e:
+             print("Request failed:", e)
+             return ""
+
+     def refer_stream(self, filepath):
+         if not self.stream:
+             print("Non-streaming mode, use refer() instead.")
+             return
+         waveform = self.get_wave_form(filepath)
+         data = {
+             "audio_data": waveform.tolist(),
+             "sample_rate": 16000,
+             "launguage": "auto",  # key spelling kept as-is to match the service API
+             "step": 0.1,
+         }
+         print('start post')
+         try:
+             response = requests.post(self.url + '/asr_stream', json=data, stream=True)
+             for line in response.iter_lines():
+                 if line:
+                     chunk = json.loads(line)
+                     yield chunk.get("text", "")
+         except Exception as e:
+             print("Request failed:", e)
+             return
+
+
+ if __name__ == "__main__":
+     stream = True
+     client = SensevoiceClient(model="/root/models/sensevoice-maixcam2/model.mud", stream=stream)
+     if client.start() is False:
+         print("Failed to start service or model.")
+         exit()
+     if not stream:
+         print('start refer')
+         text = client.refer("example/zh.mp3")
+         print(text)
+     else:
+         print('start refer stream')
+         for text in client.refer_stream("example/zh.mp3"):
+             print(text)
config.json ADDED
File without changes
download_dataset.sh ADDED
@@ -0,0 +1,2 @@
+ wget https://github.com/ml-inory/whisper.axera/releases/download/v1.0/datasets.zip
+ unzip datasets.zip -d ./
download_utils.py ADDED
@@ -0,0 +1,33 @@
+ import os
+
+ # Speed up HF downloads by using a mirror endpoint
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+ from huggingface_hub import snapshot_download
+
+ current_file_path = os.path.dirname(__file__)
+ REPO_ROOT = "AXERA-TECH"
+ CACHE_PATH = os.path.join(current_file_path, "models")
+
+
+ def download_model(model_name: str) -> str:
+     """
+     Download a model from AXERA-TECH's Hugging Face space.
+
+     model_name: str
+         Available model names can be found at https://huggingface.co/AXERA-TECH.
+
+     Returns:
+         str: Path to model_name
+
+     """
+     os.makedirs(CACHE_PATH, exist_ok=True)
+
+     model_path = os.path.join(CACHE_PATH, model_name)
+     if not os.path.exists(model_path):
+         print(f"Downloading {model_name}...")
+         snapshot_download(
+             repo_id=f"{REPO_ROOT}/{model_name}",
+             local_dir=os.path.join(CACHE_PATH, model_name),
+         )
+
+     return model_path
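
A usage sketch for download_model. The model name below is illustrative only, not a confirmed repo; real names are listed at https://huggingface.co/AXERA-TECH:

```python
# "SenseVoice" is a hypothetical repo name used purely for illustration.
model_dir = download_model("SenseVoice")
print(f"Model files cached under: {model_dir}")
```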
embeddings.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2a453244ab037744531b97bcb8574c8442301dac11f6406fdab208dddb83b93e
+ size 25523
example/en.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f10378336a4e584f3f63799e62f99d5add3c2a401b51d3abe7d3a3a82f255ada
+ size 57441
example/ja.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:496dbc43b289e1d0d0cb916df9737450bca56acd8aaca046a7a2472363b1be53
+ size 57837
example/ko.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8612f62db8319a6cb4ab4b1d2039bfc32f174f89611889ddafdeb5c0a6070b5f
+ size 27909
example/yue.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5098eebc13530a66e4eac1f30d3246e65c9cfc4e096665f9d395aca8eff0d181
+ size 31246
example/zh.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e64de19e4ff9a02e682955c9112f32d2317cfdbb5bc2f3504664044c993f195
+ size 44973
frontend.py ADDED
@@ -0,0 +1,460 @@
+ # -*- encoding: utf-8 -*-
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+ import copy
+
+ import numpy as np
+ import kaldi_native_fbank as knf
+
+
+ class WavFrontend:
+     """Conventional frontend structure for ASR."""
+
+     def __init__(
+         self,
+         cmvn_file: str = None,
+         fs: int = 16000,
+         window: str = "hamming",
+         n_mels: int = 80,
+         frame_length: int = 25,
+         frame_shift: int = 10,
+         lfr_m: int = 1,
+         lfr_n: int = 1,
+         dither: float = 1.0,
+         **kwargs,
+     ) -> None:
+
+         opts = knf.FbankOptions()
+         opts.frame_opts.samp_freq = fs
+         opts.frame_opts.dither = dither
+         opts.frame_opts.window_type = window
+         opts.frame_opts.frame_shift_ms = float(frame_shift)
+         opts.frame_opts.frame_length_ms = float(frame_length)
+         opts.mel_opts.num_bins = n_mels
+         opts.energy_floor = 0
+         opts.frame_opts.snip_edges = True
+         opts.mel_opts.debug_mel = False
+         self.opts = opts
+
+         self.lfr_m = lfr_m
+         self.lfr_n = lfr_n
+         self.cmvn_file = cmvn_file
+
+         if self.cmvn_file:
+             self.cmvn = self.load_cmvn()
+         self.fbank_fn = None
+         self.fbank_beg_idx = 0
+         self.reset_status()
+
+     def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         waveform = waveform * (1 << 15)
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+         frames = self.fbank_fn.num_frames_ready
+         mat = np.empty([frames, self.opts.mel_opts.num_bins])
+         for i in range(frames):
+             mat[i, :] = self.fbank_fn.get_frame(i)
+         feat = mat.astype(np.float32)
+         feat_len = np.array(mat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         waveform = waveform * (1 << 15)
+         # self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+         frames = self.fbank_fn.num_frames_ready
+         mat = np.empty([frames, self.opts.mel_opts.num_bins])
+         for i in range(self.fbank_beg_idx, frames):
+             mat[i, :] = self.fbank_fn.get_frame(i)
+         # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
+         feat = mat.astype(np.float32)
+         feat_len = np.array(mat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     def reset_status(self):
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.fbank_beg_idx = 0
+
+     def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         if self.lfr_m != 1 or self.lfr_n != 1:
+             feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
+
+         if self.cmvn_file:
+             feat = self.apply_cmvn(feat)
+
+         feat_len = np.array(feat.shape[0]).astype(np.int32)
+         return feat, feat_len
+
+     @staticmethod
+     def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
+         LFR_inputs = []
+
+         T = inputs.shape[0]
+         T_lfr = int(np.ceil(T / lfr_n))
+         left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
+         inputs = np.vstack((left_padding, inputs))
+         T = T + (lfr_m - 1) // 2
+         for i in range(T_lfr):
+             if lfr_m <= T - i * lfr_n:
+                 LFR_inputs.append(
+                     (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
+                 )
+             else:
+                 # process last LFR frame
+                 num_padding = lfr_m - (T - i * lfr_n)
+                 frame = inputs[i * lfr_n :].reshape(-1)
+                 for _ in range(num_padding):
+                     frame = np.hstack((frame, inputs[-1]))
+
+                 LFR_inputs.append(frame)
+         LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
+         return LFR_outputs
+
+     def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
+         """
+         Apply CMVN with mvn data
+         """
+         frame, dim = inputs.shape
+         means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
+         vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
+         inputs = (inputs + means) * vars
+         return inputs
+
+     def load_cmvn(
+         self,
+     ) -> np.ndarray:
+         with open(self.cmvn_file, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+
+         means_list = []
+         vars_list = []
+         for i in range(len(lines)):
+             line_item = lines[i].split()
+             if line_item[0] == "<AddShift>":
+                 line_item = lines[i + 1].split()
+                 if line_item[0] == "<LearnRateCoef>":
+                     add_shift_line = line_item[3 : (len(line_item) - 1)]
+                     means_list = list(add_shift_line)
+                     continue
+             elif line_item[0] == "<Rescale>":
+                 line_item = lines[i + 1].split()
+                 if line_item[0] == "<LearnRateCoef>":
+                     rescale_line = line_item[3 : (len(line_item) - 1)]
+                     vars_list = list(rescale_line)
+                     continue
+
+         means = np.array(means_list).astype(np.float64)
+         vars = np.array(vars_list).astype(np.float64)
+         cmvn = np.array([means, vars])
+         return cmvn
+
+
+ class WavFrontendOnline(WavFrontend):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         # self.fbank_fn = knf.OnlineFbank(self.opts)
+         # add variables
+         self.frame_sample_length = int(
+             self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
+         )
+         self.frame_shift_sample_length = int(
+             self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
+         )
+         self.waveform = None
+         self.reserve_waveforms = None
+         self.input_cache = None
+         self.lfr_splice_cache = []
+
+     @staticmethod
+     # inputs has catted the cache
+     def apply_lfr(
+         inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray, int]:
+         """
+         Apply lfr with data
+         """
+
+         LFR_inputs = []
+         T = inputs.shape[0]  # include the right context
+         T_lfr = int(
+             np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
+         )  # minus the right context: (lfr_m - 1) // 2
+         splice_idx = T_lfr
+         for i in range(T_lfr):
+             if lfr_m <= T - i * lfr_n:
+                 LFR_inputs.append(
+                     (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
+                 )
+             else:  # process last LFR frame
+                 if is_final:
+                     num_padding = lfr_m - (T - i * lfr_n)
+                     frame = (inputs[i * lfr_n :]).reshape(-1)
+                     for _ in range(num_padding):
+                         frame = np.hstack((frame, inputs[-1]))
+                     LFR_inputs.append(frame)
+                 else:
+                     # update splice_idx and break the loop
+                     splice_idx = i
+                     break
+         splice_idx = min(T - 1, splice_idx * lfr_n)
+         lfr_splice_cache = inputs[splice_idx:, :]
+         LFR_outputs = np.vstack(LFR_inputs)
+         return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
+
+     @staticmethod
+     def compute_frame_num(
+         sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
+     ) -> int:
+         frame_num = int(
+             (sample_length - frame_sample_length) / frame_shift_sample_length + 1
+         )
+         return (
+             frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+         )
+
+     def fbank(
+         self, input: np.ndarray, input_lengths: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         batch_size = input.shape[0]
+         if self.input_cache is None:
+             self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
+         input = np.concatenate((self.input_cache, input), axis=1)
+         frame_num = self.compute_frame_num(
+             input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
+         )
+         # update self.in_cache
+         self.input_cache = input[
+             :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
+         ]
+         waveforms = np.empty(0, dtype=np.float32)
+         feats_pad = np.empty(0, dtype=np.float32)
+         feats_lens = np.empty(0, dtype=np.int32)
+         if frame_num:
+             waveforms = []
+             feats = []
+             feats_lens = []
+             for i in range(batch_size):
+                 waveform = input[i]
+                 waveforms.append(
+                     waveform[
+                         : (
+                             (frame_num - 1) * self.frame_shift_sample_length
+                             + self.frame_sample_length
+                         )
+                     ]
+                 )
+                 waveform = waveform * (1 << 15)
+
+                 self.fbank_fn.accept_waveform(
+                     self.opts.frame_opts.samp_freq, waveform.tolist()
+                 )
+                 frames = self.fbank_fn.num_frames_ready
+                 mat = np.empty([frames, self.opts.mel_opts.num_bins])
+                 # use a separate index so the batch index i is not clobbered
+                 for j in range(frames):
+                     mat[j, :] = self.fbank_fn.get_frame(j)
+                 feat = mat.astype(np.float32)
+                 feat_len = np.array(mat.shape[0]).astype(np.int32)
+                 feats.append(feat)
+                 feats_lens.append(feat_len)
+
+             waveforms = np.stack(waveforms)
+             feats_lens = np.array(feats_lens)
+             feats_pad = np.array(feats)
+         self.fbanks = feats_pad
+         self.fbanks_lens = copy.deepcopy(feats_lens)
+         return waveforms, feats_pad, feats_lens
+
+     def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
+         return self.fbanks, self.fbanks_lens
+
+     def lfr_cmvn(
+         self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
+         batch_size = input.shape[0]
+         feats = []
+         feats_lens = []
+         lfr_splice_frame_idxs = []
+         for i in range(batch_size):
+             mat = input[i, : input_lengths[i], :]
+             lfr_splice_frame_idx = -1
+             if self.lfr_m != 1 or self.lfr_n != 1:
+                 # update self.lfr_splice_cache in self.apply_lfr
+                 mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
+                     mat, self.lfr_m, self.lfr_n, is_final
+                 )
+             if self.cmvn_file is not None:
+                 mat = self.apply_cmvn(mat)
+             feat_length = mat.shape[0]
+             feats.append(mat)
+             feats_lens.append(feat_length)
+             lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+
+         feats_lens = np.array(feats_lens)
+         feats_pad = np.array(feats)
+         return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+     def extract_fbank(
+         self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         batch_size = input.shape[0]
+         assert (
+             batch_size == 1
+         ), "online feature extraction currently supports batch size 1 only"
+         waveforms, feats, feats_lengths = self.fbank(
+             input, input_lengths
+         )  # input shape: B T D
+         if feats.shape[0]:
+             self.waveforms = (
+                 waveforms
+                 if self.reserve_waveforms is None
+                 else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
+             )
+             if not self.lfr_splice_cache:
+                 for i in range(batch_size):
+                     self.lfr_splice_cache.append(
+                         np.expand_dims(feats[i][0, :], axis=0).repeat(
+                             (self.lfr_m - 1) // 2, axis=0
+                         )
+                     )
+
+             if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
+                 lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
+                 feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
+                 feats_lengths += lfr_splice_cache_np[0].shape[0]
+                 frame_from_waveforms = int(
+                     (self.waveforms.shape[1] - self.frame_sample_length)
+                     / self.frame_shift_sample_length
+                     + 1
+                 )
+                 minus_frame = (
+                     (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+                 )
+                 feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
+                     feats, feats_lengths, is_final
+                 )
+                 if self.lfr_m == 1:
+                     self.reserve_waveforms = None
+                 else:
+                     reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+                     # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+                     # print('frame_frame: ' + str(frame_from_waveforms))
+                     self.reserve_waveforms = self.waveforms[
+                         :,
+                         reserve_frame_idx
+                         * self.frame_shift_sample_length : frame_from_waveforms
+                         * self.frame_shift_sample_length,
+                     ]
+                     sample_length = (
+                         frame_from_waveforms - 1
+                     ) * self.frame_shift_sample_length + self.frame_sample_length
+                     self.waveforms = self.waveforms[:, :sample_length]
+             else:
+                 # update self.reserve_waveforms and self.lfr_splice_cache
+                 self.reserve_waveforms = self.waveforms[
+                     :, : -(self.frame_sample_length - self.frame_shift_sample_length)
+                 ]
+                 for i in range(batch_size):
+                     self.lfr_splice_cache[i] = np.concatenate(
+                         (self.lfr_splice_cache[i], feats[i]), axis=0
+                     )
+                 return np.empty(0, dtype=np.float32), feats_lengths
+         else:
+             if is_final:
+                 self.waveforms = (
+                     waveforms
+                     if self.reserve_waveforms is None
+                     else self.reserve_waveforms
+                 )
+                 feats = np.stack(self.lfr_splice_cache)
+                 feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
+                 feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
+         if is_final:
+             self.cache_reset()
+         return feats, feats_lengths
+
+     def get_waveforms(self):
+         return self.waveforms
+
+     def cache_reset(self):
+         self.fbank_fn = knf.OnlineFbank(self.opts)
+         self.reserve_waveforms = None
+         self.input_cache = None
+         self.lfr_splice_cache = []
+
+
+ def load_bytes(input):
+     middle_data = np.frombuffer(input, dtype=np.int16)
+     middle_data = np.asarray(middle_data)
+     if middle_data.dtype.kind not in "iu":
+         raise TypeError("'middle_data' must be an array of integers")
+     dtype = np.dtype("float32")
+     if dtype.kind != "f":
+         raise TypeError("'dtype' must be a floating point type")
+
+     i = np.iinfo(middle_data.dtype)
+     abs_max = 2 ** (i.bits - 1)
+     offset = i.min + abs_max
+     array = np.frombuffer(
+         (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32
+     )
+     return array
+
+
+ class SinusoidalPositionEncoderOnline:
+     """Streaming positional encoding."""
+
+     def encode(
+         self,
+         positions: np.ndarray = None,
+         depth: int = None,
+         dtype: np.dtype = np.float32,
+     ):
+         batch_size = positions.shape[0]
+         positions = positions.astype(dtype)
+         log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (
+             depth / 2 - 1
+         )
+         inv_timescales = np.exp(
+             np.arange(depth / 2).astype(dtype) * (-log_timescale_increment)
+         )
+         inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
+         scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(
+             inv_timescales, [1, 1, -1]
+         )
+         encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
+         return encoding.astype(dtype)
+
+     def forward(self, x, start_idx=0):
+         batch_size, timesteps, input_dim = x.shape
+         positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
+         position_encoding = self.encode(positions, input_dim, x.dtype)
+
+         return x + position_encoding[:, start_idx : start_idx + timesteps]
+
+
+ def test():
+     path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
+     import librosa
+
+     cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
+     config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
+     from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
+
+     config = read_yaml(config_file)
+     waveform, _ = librosa.load(path, sr=None)
+     frontend = WavFrontend(
+         cmvn_file=cmvn_file,
+         **config["frontend_conf"],
+     )
+     speech, _ = frontend.fbank_online(waveform)  # 1d, (sample,), numpy
+     feat, feat_len = frontend.lfr_cmvn(
+         speech
+     )  # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
+
+     frontend.reset_status()  # clear cache
+     return feat, feat_len
+
+
+ if __name__ == "__main__":
+     test()
gradio_demo.py ADDED
@@ -0,0 +1,62 @@
+ import gradio as gr
+ import os
+ from SenseVoiceAx import SenseVoiceAx
+ from print_utils import rich_transcription_postprocess
+
+ max_len = 256
+
+ model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
+
+ assert os.path.exists(model_path), f"model {model_path} does not exist"
+
+ pipeline = SenseVoiceAx(
+     model_path,
+     max_len=max_len,
+     beam_size=3,
+     language="auto",
+     hot_words=None,
+     use_itn=True,
+     streaming=False,
+ )
+
+
+ def speech_to_text(audio_path, lang):
+     """
+     audio_path: path to the audio file
+     lang: language type, one of "auto", "zh", "en", "yue", "ja", "ko"
+     """
+     if not audio_path:
+         return "No audio"
+
+     pipeline.choose_language(language=lang)
+     asr_res = pipeline.infer(audio_path, print_rtf=False)
+
+     return asr_res
+
+
+ def main():
+     with gr.Blocks() as demo:
+         with gr.Row():
+             output_text = gr.Textbox(label="Recognition result", lines=5)
+
+         with gr.Row():
+             audio_input = gr.Audio(
+                 sources=["upload"], type="filepath", label="Record or upload audio", format="mp3"
+             )
+             lang_dropdown = gr.Dropdown(
+                 choices=["auto", "zh", "en", "yue", "ja", "ko"],
+                 value="auto",
+                 label="Select audio language",
+             )
+
+         audio_input.change(
+             fn=speech_to_text, inputs=[audio_input, lang_dropdown], outputs=output_text
+         )
+
+     demo.launch(
+         server_name="0.0.0.0",
+     )
+
+
+ if __name__ == "__main__":
+     main()
main.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ import argparse
+ from SenseVoiceAx import SenseVoiceAx
+ import librosa
+ import numpy as np
+ import time
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--input", "-i", required=True, type=str, help="Input audio file"
+     )
+     parser.add_argument(
+         "--language",
+         "-l",
+         required=False,
+         type=str,
+         default="auto",
+         choices=["auto", "zh", "en", "yue", "ja", "ko"],
+     )
+     parser.add_argument("--streaming", action="store_true")
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+
+     input_audio = args.input
+     language = args.language
+     use_itn = True  # punctuation prediction (inverse text normalization)
+     if not args.streaming:
+         max_len = 256
+         model_path = os.path.join("sensevoice_ax630c", "sensevoice.axmodel")
+     else:
+         max_len = 26
+         model_path = os.path.join("sensevoice_ax630c", "streaming_sensevoice.axmodel")
+
+     assert os.path.exists(model_path), f"model {model_path} does not exist"
+
+     print(f"input_audio: {input_audio}")
+     print(f"language: {language}")
+     print(f"use_itn: {use_itn}")
+     print(f"model_path: {model_path}")
+     print(f"streaming: {args.streaming}")
+
+     pipeline = SenseVoiceAx(
+         model_path,
+         max_len=max_len,
+         beam_size=3,
+         language=language,  # honor the CLI flag instead of hard-coding "auto"
+         hot_words=None,
+         use_itn=use_itn,
+         streaming=args.streaming,
+     )
+
+     if not args.streaming:
+         asr_res = pipeline.infer(input_audio, print_rtf=True)
+         print("ASR result: " + asr_res)
+     else:
+         samples, sr = librosa.load(input_audio, sr=16000)
+         samples = (samples * 32768).tolist()
+         duration = len(samples) / 16000
+
+         start = time.time()
+         step = int(0.1 * sr)
+         for i in range(0, len(samples), step):
+             is_last = i + step >= len(samples)
+             for res in pipeline.stream_infer(samples[i : i + step], is_last):
+                 print(res)
+
+         end = time.time()
+         cost_time = end - start
+
+         print(f"RTF: {cost_time / duration}")
+
+
+ if __name__ == "__main__":
+     main()
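
Example invocations for both modes (input.wav is a placeholder path; the model files under sensevoice_ax630c/ must already be present):

    # offline (whole-file) recognition with RTF reporting
    python main.py -i input.wav -l zh
    # simulated streaming, feeding 100 ms chunks
    python main.py -i input.wav --streaming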
model.mud ADDED
@@ -0,0 +1,15 @@
+ [basic]
+ type = axmodel
+ model_npu = sensevoice_ax630c/sensevoice.axmodel
+ model_vnpu =
+
+ [extra]
+ model_type = sensevoice
+ input_cache = true
+ output_cache = true
+ beam_size = 3
+ language = auto
+ hot_words = None
+ use_itn = True
+ stream_model = sensevoice_ax630c/streaming_sensevoice.axmodel
+ server_url = http://127.0.0.1:12345
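
This .mud file is plain INI, which server.py (added below) reads via configparser and then coerces true/false, None/empty, and digit strings into native types. A quick sketch of reading it directly, for reference:

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("model.mud", encoding="utf-8")
    print(cfg["extra"]["beam_size"])  # "3" as a raw string here;
    # server.py's parse_config_file_to_json turns it into int 3,
    # "True" into bool True, and empty/None values into None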
pe_nonstream.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f1c9c550bd62fa164a959517f52d46a28591812fafdf002df0df2bd998f44b5
+ size 573568
pe_streaming.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54fec2fe2670168d36678c5857e65c459c634e6b6d6df928b7d415399ce2c291
+ size 58368
print_utils.py ADDED
@@ -0,0 +1,131 @@
+ emo_dict = {
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+ }
+
+ event_dict = {
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|Cry|>": "😭",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "🤧",
+ }
+
+ lang_dict = {
+     "<|zh|>": "<|lang|>",
+     "<|en|>": "<|lang|>",
+     "<|yue|>": "<|lang|>",
+     "<|ja|>": "<|lang|>",
+     "<|ko|>": "<|lang|>",
+     "<|nospeech|>": "<|lang|>",
+ }
+
+ emoji_dict = {
+     "<|nospeech|><|Event_UNK|>": "❓",
+     "<|zh|>": "",
+     "<|en|>": "",
+     "<|yue|>": "",
+     "<|ja|>": "",
+     "<|ko|>": "",
+     "<|nospeech|>": "",
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+     "<|Cry|>": "😭",
+     "<|EMO_UNKNOWN|>": "",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "😷",
+     "<|Sing|>": "",
+     "<|Speech_Noise|>": "",
+     "<|withitn|>": "",
+     "<|woitn|>": "",
+     "<|GBG|>": "",
+     "<|Event_UNK|>": "",
+ }
+
+ emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+ event_set = {
+     "🎼",
+     "👏",
+     "😀",
+     "😭",
+     "🤧",
+     "😷",
+ }
+
+
+ def format_str_v2(s):
+     sptk_dict = {}
+     for sptk in emoji_dict:
+         sptk_dict[sptk] = s.count(sptk)
+         s = s.replace(sptk, "")
+     emo = "<|NEUTRAL|>"
+     for e in emo_dict:
+         if sptk_dict[e] > sptk_dict[emo]:
+             emo = e
+     for e in event_dict:
+         if sptk_dict[e] > 0:
+             s = event_dict[e] + s
+     s = s + emo_dict[emo]
+
+     for emoji in emo_set.union(event_set):
+         s = s.replace(" " + emoji, emoji)
+         s = s.replace(emoji + " ", emoji)
+     return s.strip()
+
+
+ def rich_transcription_postprocess(s):
+     def get_emo(s):
+         return s[-1] if s[-1] in emo_set else None
+
+     def get_event(s):
+         return s[0] if s[0] in event_set else None
+
+     s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
+     for lang in lang_dict:
+         s = s.replace(lang, "<|lang|>")
+     s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+     new_s = " " + s_list[0]
+     cur_ent_event = get_event(new_s)
+     for i in range(1, len(s_list)):
+         if len(s_list[i]) == 0:
+             continue
+         if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) is not None:
+             s_list[i] = s_list[i][1:]
+         cur_ent_event = get_event(s_list[i])
+         if get_emo(s_list[i]) is not None and get_emo(s_list[i]) == get_emo(new_s):
+             new_s = new_s[:-1]
+         new_s += s_list[i].strip()
+     new_s = new_s.replace("The.", " ")
+     return new_s.strip()
+
+
+ def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
+     res = "".join([rich_transcription_postprocess(i) for i in asr_res])
+
+     if remove_punc:
+         res = res.replace(",", "")
+         res = res.replace("。", "")
+
+     if will_print:
+         print(res)
+
+     return res
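
As a quick sanity check of the postprocessing above (the tagged string is illustrative, not real model output):

    from print_utils import rich_transcription_postprocess

    raw = "<|zh|><|HAPPY|><|Speech|><|withitn|>你好,世界。"
    print(rich_transcription_postprocess(raw))  # -> 你好,世界。😊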
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ huggingface_hub
+ numpy<2
+ kaldi-native-fbank
+ librosa==0.9.1
+ sentencepiece
+ fastapi
+ gradio
+ emoji
+ asr-decoder
+ online-fbank
+ torch
sensevoice_ax630c/sensevoice.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67d290cf7cebf45db5f37b2e93b8bdfff44dc35110bb29d84204a5f9eae9fd4d
+ size 256550253
sensevoice_ax630c/streaming_sensevoice.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba1ddd60841297903bfdae059ad88092d0fd1c543e1d80d7f64199d4e27b8263
+ size 249023211
server.py ADDED
@@ -0,0 +1,270 @@
+ import numpy as np
+ from fastapi import FastAPI, HTTPException, Body
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from typing import List, Optional
+ import logging
+ import json
+ import configparser
+ from SenseVoiceAx import SenseVoiceAx
+ import os
+ import librosa
+
+ # set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API")
+
+ # module-level state for the loaded model
+ asr_model = None
+ asr_model_is_loaded = False
+ mud_configs = None
+
+
+ def parse_config_file_to_json(file_path):
+     """Read an INI-style config file and parse it into a dict."""
+     if not os.path.exists(file_path):
+         raise FileNotFoundError(f"Config file does not exist: {file_path}")
+
+     config = configparser.ConfigParser()
+     config.read(file_path, encoding="utf-8")
+
+     result = {}
+     for section in config.sections():
+         result[section] = {}
+         for key, value in config[section].items():
+             # simple type conversion
+             value = value.strip()
+
+             if value.lower() == "true":
+                 result[section][key] = True
+             elif value.lower() == "false":
+                 result[section][key] = False
+             elif value.lower() == "none" or value == "":
+                 result[section][key] = None
+             elif value.isdigit():
+                 result[section][key] = int(value)
+             else:
+                 result[section][key] = value
+
+     return result
+
+
+ @app.on_event("startup")
+ async def load_model():
+     pass
+
+
+ def validate_audio_data(audio_data: List[float]) -> np.ndarray:
+     """
+     Validate audio data and convert it to a numpy array.
+
+     Args:
+         audio_data: audio samples as a list of floats
+
+     Returns:
+         the validated numpy array
+     """
+     try:
+         # convert to a numpy array
+         np_array = np.array(audio_data, dtype=np.float32)
+
+         # validate the data
+         if np_array.ndim != 1:
+             raise ValueError("Audio data must be 1-dimensional")
+
+         if len(np_array) == 0:
+             raise ValueError("Audio data cannot be empty")
+
+         return np_array
+     except Exception as e:
+         raise ValueError(f"Invalid audio data: {str(e)}")
+
+
+ @app.get("/get_language", summary="Get current language")
+ async def get_language():
+     return JSONResponse(content={"language": asr_model.language})
+
+
+ @app.get(
+     "/get_language_options",
+     summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
+ )
+ async def get_language_options():
+     return JSONResponse(content={"language_options": asr_model.language_options})
+
+
+ @app.get("/status", summary="Get ASR model status")
+ async def get_status():
+     global asr_model_is_loaded
+     return JSONResponse(content={"status": "loaded" if asr_model_is_loaded else "not loaded"})
+
+
+ @app.post("/start_model", summary="Load model")
+ async def start_model(
+     model_path: str = Body(
+         "sensevoice_ax630c/sensevoice.axmodel",
+         description="Path to the model file",
+     ),
+     language: str = Body("auto", description="Language"),
+     stream: bool = Body(False, description="streaming or not"),
+ ):
+     """
+     Load the ASR model described by a .mud config file.
+     """
+     global asr_model
+     global asr_model_is_loaded
+     global mud_configs
+     logger.info("Loading ASR model...")
+
+     if asr_model_is_loaded:
+         return JSONResponse(content={"status": "loaded"})
+
+     try:
+         mud_configs = parse_config_file_to_json(model_path)
+         axmodel_path = mud_configs.get("basic", {}).get("model_npu", None)
+         streaming_axmodel_path = mud_configs.get("extra", {}).get("stream_model", None)
+         model_dir_path = os.path.dirname(model_path)
+         if stream:
+             if streaming_axmodel_path is None:
+                 logger.error("stream_model is not set in the config")
+                 raise HTTPException(status_code=400, detail="stream_model is not set in the config")
+             model_path = os.path.join(model_dir_path, streaming_axmodel_path)
+         else:
+             if axmodel_path is None:
+                 logger.error("model_npu is not set in the config")
+                 raise HTTPException(status_code=400, detail="model_npu is not set in the config")
+             model_path = os.path.join(model_dir_path, axmodel_path)
+
+         # load the model
+         beam_size = mud_configs.get("extra", {}).get("beam_size", 3)
+         hot_words = mud_configs.get("extra", {}).get("hot_words", None)
+         use_itn = mud_configs.get("extra", {}).get("use_itn", True)  # inverse text normalization
+         streaming = stream
+         max_len = 26 if streaming else 256
+
+         print(f"model path: {model_path}")
+         print(f"max_len: {max_len}")
+         print(f"beam_size: {beam_size}")
+         print(f"language: {language}")
+         print(f"hot_words: {hot_words}")
+         print(f"use_itn: {use_itn}")
+         print(f"streaming: {streaming}")
+
+         if not os.path.exists(model_path):
+             raise HTTPException(status_code=400, detail=f"model {model_path} does not exist")
+
+         asr_model = SenseVoiceAx(
+             model_path,
+             max_len=max_len,
+             beam_size=beam_size,
+             language=language,
+             hot_words=hot_words,
+             use_itn=use_itn,
+             streaming=streaming,
+         )
+         asr_model_is_loaded = True
+
+         logger.info("ASR model loaded successfully")
+     except Exception as e:
+         logger.error(f"Failed to load ASR model: {str(e)}")
+         raise
+
+     return JSONResponse(content={"status": "loaded"})
+
+
+ @app.post("/stop_model", summary="Unload model")
+ async def stop_model():
+     global asr_model
+     global asr_model_is_loaded
+     del asr_model
+     asr_model = None
+     asr_model_is_loaded = False
+     return JSONResponse(content={"status": "not loaded"})
+
+
+ @app.post("/asr", summary="Recognize speech from numpy audio data")
+ async def recognize_speech(
+     audio_data: List[float] = Body(
+         ..., embed=True, description="Audio data as list of floats"
+     ),
+     sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
+     language: Optional[str] = Body("auto", description="Language"),
+ ):
+     """
+     Accept audio data as a list of floats and return the recognition result.
+
+     Args:
+         audio_data: audio samples as a list of floats
+         sample_rate: audio sample rate (default 16000 Hz)
+
+     Returns:
+         JSON containing the recognized text
+     """
+     try:
+         # check that the model is loaded
+         if asr_model is None:
+             raise HTTPException(status_code=503, detail="ASR model not loaded")
+
+         logger.info(f"Received audio data with length: {len(audio_data)}")
+
+         # validate and convert the data
+         np_audio = validate_audio_data(audio_data)
+         if sample_rate != asr_model.sample_rate:
+             np_audio = librosa.resample(np_audio, orig_sr=sample_rate, target_sr=asr_model.sample_rate)
+
+         # run recognition
+         result = asr_model.infer_waveform(np_audio, language)
+
+         return JSONResponse(content={"text": result})
+
+     except ValueError as e:
+         logger.error(f"Validation error: {str(e)}")
+         raise HTTPException(status_code=400, detail=str(e))
+     except Exception as e:
+         logger.error(f"Recognition error: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/asr_stream", summary="Recognize speech from numpy audio data, streaming the results")
+ async def recognize_speech_stream(
+     audio_data: List[float] = Body(
+         ..., embed=True, description="Audio data as list of floats"
+     ),
+     sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
+     language: Optional[str] = Body("auto", description="Language"),
+     step: Optional[float] = Body(0.1, description="step in seconds"),
+ ):
+     """
+     Accept audio data as a list of floats and stream back partial recognition results.
+
+     Args:
+         audio_data: audio samples as a list of floats
+         sample_rate: audio sample rate (default 16000 Hz)
+
+     Returns:
+         streaming JSON lines containing the recognized text
+     """
+     try:
+         # check that the model is loaded
+         if asr_model is None:
+             raise HTTPException(status_code=503, detail="ASR model not loaded")
+
+         logger.info(f"Received audio data with length: {len(audio_data)}")
+
+         # validate and convert the data
+         np_audio = validate_audio_data(audio_data)
+         if sample_rate != asr_model.sample_rate:
+             np_audio = librosa.resample(np_audio, orig_sr=sample_rate, target_sr=asr_model.sample_rate)
+
+         # run streaming recognition
+         def stream_infer(np_audio, step):
+             samples = (np_audio * 32768).tolist()
+
+             step = int(step * asr_model.sample_rate)  # audio is resampled to the model rate above
+             for i in range(0, len(samples), step):
+                 is_last = i + step >= len(samples)
+                 for res in asr_model.stream_infer(samples[i : i + step], is_last, language):
+                     yield json.dumps(res) + "\n"
+
+         return StreamingResponse(stream_infer(np_audio, step), media_type="application/json")
+     except ValueError as e:
+         logger.error(f"Validation error: {str(e)}")
+         raise HTTPException(status_code=400, detail=str(e))
+     except Exception as e:
+         logger.error(f"Recognition error: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=12347)
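
A minimal client sketch against this server (file names are placeholders and requests is assumed available; the endpoint names and body fields match the handlers above, and FastAPI expects all Body parameters in one JSON object):

    import requests
    import librosa

    base = "http://127.0.0.1:12347"
    # /start_model expects the .mud config path, not the raw .axmodel
    requests.post(f"{base}/start_model",
                  json={"model_path": "model.mud", "language": "auto", "stream": False})

    audio, sr = librosa.load("test.wav", sr=16000)
    r = requests.post(f"{base}/asr",
                      json={"audio_data": audio.tolist(), "sample_rate": sr, "language": "auto"})
    print(r.json()["text"])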
test_wer.py ADDED
@@ -0,0 +1,296 @@
+ import os
+ import argparse
+ from SenseVoiceAx import SenseVoiceAx
+ from tokenizer import SentencepiecesTokenizer
+ from print_utils import rich_transcription_postprocess, rich_print_asr_res
+ from download_utils import download_model
+ import logging
+ import re
+ import emoji
+
+
+ def setup_logging():
+     """Configure logging to write to both the console and a file."""
+     # directory containing this script
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     log_file = os.path.join(script_dir, "test_wer.log")
+
+     # log format
+     log_format = "%(asctime)s - %(levelname)s - %(message)s"
+     date_format = "%Y-%m-%d %H:%M:%S"
+
+     # create the root logger
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+
+     # remove existing handlers
+     for handler in logger.handlers[:]:
+         logger.removeHandler(handler)
+
+     # file handler
+     file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
+     file_handler.setLevel(logging.INFO)
+     file_formatter = logging.Formatter(log_format, date_format)
+     file_handler.setFormatter(file_formatter)
+
+     # console handler
+     console_handler = logging.StreamHandler()
+     console_handler.setLevel(logging.INFO)
+     console_formatter = logging.Formatter(log_format, date_format)
+     console_handler.setFormatter(console_formatter)
+
+     # attach the handlers
+     logger.addHandler(file_handler)
+     logger.addHandler(console_handler)
+
+     return logger
+
+
+ class AIShellDataset:
+     def __init__(self, gt_path: str):
+         """
+         Initialize the dataset.
+
+         Args:
+             gt_path: path to the ground-truth transcript file
+         """
+         self.gt_path = gt_path
+         self.dataset_dir = os.path.dirname(gt_path)
+         self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
+
+         # check that the required file and folder exist
+         assert os.path.exists(gt_path), f"gt file does not exist: {gt_path}"
+         assert os.path.exists(self.voice_dir), f"aishell_S0764 folder does not exist: {self.voice_dir}"
+
+         # load the data
+         self.data = []
+         with open(gt_path, "r", encoding="utf-8") as f:
+             for line in f:
+                 line = line.strip()
+                 audio_path, gt = line.split(" ", 1)  # the transcript itself may contain spaces
+                 audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
+                 self.data.append({"audio_path": audio_path, "gt": gt})
+
+         # log via logging instead of print
+         logger = logging.getLogger()
+         logger.info(f"Loaded {len(self.data)} entries")
+
+     def __iter__(self):
+         """Return an iterator over the dataset."""
+         self.index = 0
+         return self
+
+     def __next__(self):
+         """Return the next item."""
+         if self.index >= len(self.data):
+             raise StopIteration
+
+         item = self.data[self.index]
+         audio_path = item["audio_path"]
+         ground_truth = item["gt"]
+
+         self.index += 1
+         return audio_path, ground_truth
+
+     def __len__(self):
+         """Return the dataset size."""
+         return len(self.data)
+
+
+ class CommonVoiceDataset:
+     """Common Voice dataset parser"""
+
+     def __init__(self, tsv_path: str):
+         """
+         Initialize the dataset.
+
+         Args:
+             tsv_path: path to the dataset .tsv file
+         """
+         self.tsv_path = tsv_path
+         self.dataset_dir = os.path.dirname(tsv_path)
+         self.voice_dir = os.path.join(self.dataset_dir, "clips")
+
+         # check that the required file and folder exist
+         assert os.path.exists(tsv_path), f"tsv file does not exist: {tsv_path}"
+         assert os.path.exists(self.voice_dir), f"clips folder does not exist: {self.voice_dir}"
+
+         # load the data
+         self.data = []
+         with open(tsv_path, "r", encoding="utf-8") as f:
+             f.readline()  # skip the header row
+             for line in f:
+                 line = line.strip()
+                 splits = line.split("\t")
+                 audio_path = splits[1]
+                 gt = splits[3]
+                 audio_path = os.path.join(self.voice_dir, audio_path)
+                 self.data.append({"audio_path": audio_path, "gt": gt})
+
+         # log via logging instead of print
+         logger = logging.getLogger()
+         logger.info(f"Loaded {len(self.data)} entries")
+
+     def __iter__(self):
+         """Return an iterator over the dataset."""
+         self.index = 0
+         return self
+
+     def __next__(self):
+         """Return the next item."""
+         if self.index >= len(self.data):
+             raise StopIteration
+
+         item = self.data[self.index]
+         audio_path = item["audio_path"]
+         ground_truth = item["gt"]
+
+         self.index += 1
+         return audio_path, ground_truth
+
+     def __len__(self):
+         """Return the dataset size."""
+         return len(self.data)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--dataset",
+         "-d",
+         type=str,
+         required=True,
+         choices=["aishell", "common_voice"],
+         help="Test dataset",
+     )
+     parser.add_argument(
+         "--gt_path",
+         "-g",
+         type=str,
+         required=True,
+         help="Test dataset ground truth file",
+     )
+     parser.add_argument(
+         "--language",
+         "-l",
+         required=False,
+         type=str,
+         default="auto",
+         choices=["auto", "zh", "en", "yue", "ja", "ko"],
+     )
+     parser.add_argument(
+         "--max_num", type=int, default=-1, required=False, help="Maximum test data num"
+     )
+     return parser.parse_args()
+
+
+ def min_distance(word1: str, word2: str) -> int:
+     """Edit (Levenshtein) distance between two strings, via dynamic programming."""
+     row = len(word1) + 1
+     column = len(word2) + 1
+
+     cache = [[0] * column for i in range(row)]
+
+     for i in range(row):
+         for j in range(column):
+             if i == 0 and j == 0:
+                 cache[i][j] = 0
+             elif i == 0 and j != 0:
+                 cache[i][j] = j
+             elif j == 0 and i != 0:
+                 cache[i][j] = i
+             else:
+                 if word1[i - 1] == word2[j - 1]:
+                     cache[i][j] = cache[i - 1][j - 1]
+                 else:
+                     replace = cache[i - 1][j - 1] + 1
+                     insert = cache[i][j - 1] + 1
+                     remove = cache[i - 1][j] + 1
+
+                     cache[i][j] = min(replace, insert, remove)
+
+     return cache[row - 1][column - 1]
+
+
+ def remove_punctuation(text):
+     # regex matching punctuation, including common Chinese punctuation
+     pattern = r"[^\w\s]|_"
+
+     # replace every match with the empty string
+     cleaned_text = re.sub(pattern, "", text)
+
+     return cleaned_text
+
+
+ def main():
+     logger = setup_logging()
+     args = get_args()
+
+     language = args.language
+     use_itn = False  # punctuation prediction off, so WER is computed on bare text
+     max_num = args.max_num
+
+     dataset_type = args.dataset.lower()
+     if dataset_type == "aishell":
+         dataset = AIShellDataset(args.gt_path)
+     elif dataset_type == "common_voice":
+         dataset = CommonVoiceDataset(args.gt_path)
+     else:
+         raise ValueError(f"Unknown dataset type {dataset_type}")
+
+     # model_path_root = download_model("SenseVoice")
+     model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
+     bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"
+
+     assert os.path.exists(model_path), f"model {model_path} does not exist"
+
+     logger.info(f"dataset: {args.dataset}")
+     logger.info(f"language: {language}")
+     logger.info(f"use_itn: {use_itn}")
+     logger.info(f"model_path: {model_path}")
+
+     tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
+     pipeline = SenseVoiceAx(
+         model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
+     )
+
+     # iterate over the dataset
+     hyp = []
+     references = []
+     all_character_error_num = 0
+     all_character_num = 0
+     max_data_num = max_num if max_num > 0 else len(dataset)
+     for n, (audio_path, reference) in enumerate(dataset):
+         reference = remove_punctuation(reference).lower()
+
+         asr_res = pipeline.infer(audio_path, print_rtf=False)
+         hypothesis = rich_print_asr_res(
+             asr_res, will_print=False, remove_punc=True
+         ).lower()
+         hypothesis = emoji.replace_emoji(hypothesis, replace="")
+
+         character_error_num = min_distance(reference, hypothesis)
+         character_num = len(reference)
+         character_error_rate = character_error_num / character_num * 100
+
+         all_character_error_num += character_error_num
+         all_character_num += character_num
+
+         hyp.append(hypothesis)
+         references.append(reference)
+
+         line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
+         logger.info(line_content)
+
+         if n + 1 >= max_data_num:
+             break
+
+     total_character_error_rate = all_character_error_num / all_character_num * 100
+
+     logger.info(f"Total WER: {total_character_error_rate}%")
+
+
+ if __name__ == "__main__":
+     main()
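
Example run (the ground-truth path is a placeholder; the aishell_S0764 clips are expected next to it, as the dataset class asserts):

    python test_wer.py -d aishell -g /path/to/aishell_gt.txt -l zh --max_num 100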
tokenizer.py ADDED
@@ -0,0 +1,135 @@
+ import sentencepiece as spm
+
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Union
+
+ import json
+ from abc import abstractmethod
+ from abc import ABC
+ import numpy as np
+
+
+ class BaseTokenizer(ABC):
+     def __init__(
+         self,
+         token_list: Union[Path, str, Iterable[str]] = None,
+         unk_symbol: str = "<unk>",
+         **kwargs,
+     ):
+         if token_list is not None:
+             if isinstance(token_list, (Path, str)) and str(token_list).endswith(".txt"):
+                 token_list = Path(token_list)
+                 self.token_list_repr = str(token_list)
+                 self.token_list: List[str] = []
+
+                 with token_list.open("r", encoding="utf-8") as f:
+                     for idx, line in enumerate(f):
+                         line = line.rstrip()
+                         self.token_list.append(line)
+             elif isinstance(token_list, (Path, str)) and str(token_list).endswith(".json"):
+                 token_list = Path(token_list)
+                 self.token_list_repr = str(token_list)
+                 self.token_list: List[str] = []
+
+                 with open(token_list, "r", encoding="utf-8") as f:
+                     self.token_list = json.load(f)
+             else:
+                 self.token_list: List[str] = list(token_list)
+                 self.token_list_repr = ""
+                 for i, t in enumerate(self.token_list):
+                     if i == 3:
+                         break
+                     self.token_list_repr += f"{t}, "
+                 self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+             self.token2id: Dict[str, int] = {}
+             for i, t in enumerate(self.token_list):
+                 if t in self.token2id:
+                     raise RuntimeError(f'Symbol "{t}" is duplicated')
+                 self.token2id[t] = i
+
+             self.unk_symbol = unk_symbol
+             if self.unk_symbol not in self.token2id:
+                 raise RuntimeError(
+                     f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+                 )
+             self.unk_id = self.token2id[self.unk_symbol]
+
+     def encode(self, text, **kwargs):
+         tokens = self.text2tokens(text)
+         text_ints = self.tokens2ids(tokens)
+
+         return text_ints
+
+     def decode(self, text_ints):
+         token = self.ids2tokens(text_ints)
+         text = self.tokens2text(token)
+         return text
+
+     def get_num_vocabulary_size(self) -> int:
+         return len(self.token_list)
+
+     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+         if isinstance(integers, np.ndarray) and integers.ndim != 1:
+             raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+         return [self.token_list[i] for i in integers]
+
+     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+         return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+     @abstractmethod
+     def text2tokens(self, line: str) -> List[str]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def tokens2text(self, tokens: Iterable[str]) -> str:
+         raise NotImplementedError
+
+
+ class SentencepiecesTokenizer(BaseTokenizer):
+     def __init__(self, bpemodel: Union[Path, str], **kwargs):
+         super().__init__(**kwargs)
+         self.bpemodel = str(bpemodel)
+         # NOTE(kamo):
+         # Don't build SentencePieceProcessor in __init__()
+         # because it's not picklable and it may cause the following error,
+         # "TypeError: can't pickle SwigPyObject objects",
+         # when giving it as an argument of "multiprocessing.Process()".
+         self.sp = None
+         self._build_sentence_piece_processor()
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(model="{self.bpemodel}")'
+
+     def _build_sentence_piece_processor(self):
+         # Build SentencePieceProcessor lazily.
+         if self.sp is None:
+             self.sp = spm.SentencePieceProcessor()
+             self.sp.load(self.bpemodel)
+
+     def text2tokens(self, line: str) -> List[str]:
+         self._build_sentence_piece_processor()
+         return self.sp.EncodeAsPieces(line)
+
+     def tokens2text(self, tokens: Iterable[str]) -> str:
+         self._build_sentence_piece_processor()
+         return self.sp.DecodePieces(list(tokens))
+
+     def encode(self, line: str, **kwargs) -> List[int]:
+         self._build_sentence_piece_processor()
+         return self.sp.EncodeAsIds(line)
+
+     def decode(self, line: List[int], **kwargs):
+         self._build_sentence_piece_processor()
+         return self.sp.DecodeIds(line)
+
+     def get_vocab_size(self):
+         return self.sp.GetPieceSize()
+
+     # delegate straight to SentencePiece: here ids2tokens maps ids to text,
+     # and tokens2ids maps text to ids
+     def ids2tokens(self, *args, **kwargs):
+         return self.decode(*args, **kwargs)
+
+     def tokens2ids(self, *args, **kwargs):
+         return self.encode(*args, **kwargs)
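
Typical round-trip usage (the .bpe.model file name matches the one referenced in test_wer.py):

    from tokenizer import SentencepiecesTokenizer

    tok = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")
    ids = tok.encode("hello world")  # text -> list of token ids
    print(tok.decode(ids))           # ids -> text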