inoryQwQ committed on
Commit
e138696
·
1 Parent(s): e5940ab

Fix python

Browse files
python/SenseVoiceAx.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import axengine as axe
2
+ import numpy as np
3
+ import librosa
4
+ from frontend import WavFrontend
5
+ import time
6
+ from typing import List, Union, Optional, Tuple
7
+ import torch
8
+
9
+
10
def unique_consecutive(arr):
    """
    Collapse runs of equal adjacent values, keeping the first value of each run.

    NumPy stand-in for ``torch.unique_consecutive(yseq, dim=-1)`` on a 1-D sequence.

    Args:
        arr: 1-D numpy array (or array-like) of values.

    Returns:
        A new 1-D numpy array with consecutive duplicates removed. The dtype of
        ``arr`` is preserved, including for empty input.
    """
    arr = np.asarray(arr)

    # Zero or one element: nothing to collapse; return a copy so the caller
    # never receives an alias of the input.
    if arr.shape[0] < 2:
        return arr.copy()

    # An element is kept when it is the first one or differs from its predecessor.
    keep = np.empty(arr.shape[0], dtype=bool)
    keep[0] = True
    keep[1:] = arr[1:] != arr[:-1]

    # Boolean indexing returns a fresh array.
    return arr[keep]
37
+
38
+
39
class SenseVoiceAx:
    """SenseVoice axmodel runner (offline via ``infer``, streaming via ``stream_infer``)."""

    def __init__(
        self,
        model_path: str,
        cmvn_file: str,
        token_file: str,
        bpe_model: str = None,
        max_seq_len: int = 256,
        beam_size: int = 3,
        hot_words: Optional[List[str]] = None,
        streaming: bool = False,
        providers=None,
    ):
        """
        Initialize SenseVoiceAx

        Args:
            model_path: Path of axmodel
            cmvn_file: Path of the CMVN statistics file handed to WavFrontend
            token_file: Path of the token table, one token per line
            bpe_model: BPE model path, forwarded to CTCDecoder together with hot_words
            max_seq_len: Fixed number of feature frames per model call
            beam_size: Max number of hypos to hold after each decode step.
                Only honoured in streaming mode when hot_words is given;
                otherwise greedy search (beam size 1) is used.
            hot_words: Words that may fail to recognize,
                special words/phrases (aka hotwords) like rare words, personalized information etc.
            streaming: Processes audio in small segments or "chunks" sequentially and outputs text on the fly.
                Use stream_infer method if streaming is true otherwise infer.
            providers: Execution providers passed to axengine.InferenceSession.
                Defaults to ["AxEngineExecutionProvider"].
        """
        # Build the default per call instead of sharing one mutable list
        # between every instance.
        if providers is None:
            providers = ["AxEngineExecutionProvider"]

        self.streaming = streaming

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            fs=16000,
            window="hamming",
            n_mels=80,
            frame_length=25,
            frame_shift=10,
            lfr_m=7,
            lfr_n=6,
        )

        self.model = axe.InferenceSession(model_path, providers=providers)
        self.sample_rate = 16000
        self.blank_id = 0
        self.max_seq_len = max_seq_len
        self.padding = 16
        self.input_size = 560
        self.query_num = 4
        self.tokens = self.load_tokens(token_file)

        self.lid_dict = {
            "auto": 0,
            "zh": 3,
            "en": 4,
            "yue": 7,
            "ja": 11,
            "ko": 12,
            "nospeech": 13,
        }

        if streaming:
            # Imported lazily so the offline path does not need these modules.
            from asr_decoder import CTCDecoder
            from online_fbank import OnlineFbank

            # decoder
            if beam_size > 1 and hot_words is not None:
                self.beam_size = beam_size
                symbol_table = {token: idx for idx, token in enumerate(self.tokens)}
                self.decoder = CTCDecoder(hot_words, symbol_table, bpe_model)
            else:
                self.beam_size = 1
                self.decoder = CTCDecoder()

            self.cur_idx = -1
            self.chunk_size = max_seq_len - self.padding
            self.caches_shape = (max_seq_len, self.input_size)
            self.caches = np.zeros(self.caches_shape, dtype=np.float32)
            self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
            self.neg_mean, self.inv_stddev = (
                self.frontend.cmvn[0, :],
                self.frontend.cmvn[1, :],
            )

            self.fbank = OnlineFbank(window_type="hamming")
            # In streaming mode the whole window is always marked valid.
            self.stream_mask = self.sequence_mask(
                max_seq_len + self.query_num, max_seq_len + self.query_num
            )

    def load_tokens(self, token_file):
        """Read the token table: one token per line, list index == token id."""
        # Strip only the line terminator so a final line without a trailing
        # newline keeps its last character, and tokens made of spaces survive.
        with open(token_file, "r", encoding="utf-8") as f:
            return [line.rstrip("\n") for line in f]

    @property
    def language_options(self):
        """Language keys accepted by ``infer`` / ``stream_infer``."""
        return list(self.lid_dict.keys())

    def sequence_mask(self, max_seq_len, actual_seq_len):
        """Return an int32 mask of shape (1, 1, max_seq_len) with the first actual_seq_len entries set to 1."""
        mask = np.zeros((1, 1, max_seq_len), dtype=np.int32)
        mask[:, :, :actual_seq_len] = 1
        return mask

    def load_data(self, filepath: str) -> np.ndarray:
        """Load an audio file with librosa at ``self.sample_rate`` and return it flattened to 1-D."""
        waveform, _ = librosa.load(filepath, sr=self.sample_rate)
        return waveform.flatten()

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each (frames, dim) feature along axis 0 to max_feat_len and stack as float32."""
        padded = [
            np.pad(
                feat,
                ((0, max_feat_len - feat.shape[0]), (0, 0)),
                "constant",
                constant_values=0,
            )
            for feat in feats
        ]
        return np.array(padded).astype(np.float32)

    def preprocess(self, waveform):
        """
        Turn one waveform into model features.

        Returns:
            feats: float32 array with a leading batch axis of size 1.
            feats_len: int32 array of shape (1,) holding the frame count.
        """
        speech, _ = self.frontend.fbank(waveform)
        feat, feat_len = self.frontend.lfr_cmvn(speech)

        feats = self.pad_feats([feat], np.max([feat_len]))
        feats_len = np.array([feat_len]).astype(np.int32)
        return feats, feats_len

    def postprocess(self, ctc_logits, encoder_out_lens):
        """Greedy CTC decode of the first batch item; returns a list of token ids."""
        # Skip the leading query positions (self.query_num, 4 by default).
        x = ctc_logits[0, self.query_num : encoder_out_lens[0], :]

        # Best token per frame
        yseq = np.argmax(x, axis=-1)

        # Merge consecutive repeats
        yseq = unique_consecutive(yseq)

        # Drop blank tokens
        return yseq[yseq != self.blank_id].tolist()

    def infer_waveform(self, waveform: np.ndarray, language="auto"):
        """
        Run offline recognition on a waveform and return the decoded text.

        Features are cut into windows of max_seq_len frames, each window is
        zero-padded to max_seq_len and decoded independently, and the token
        ids of all windows are concatenated.

        Raises:
            KeyError: if ``language`` is not a key of ``self.lid_dict``.
        """
        feat, _ = self.preprocess(waveform)

        slice_len = self.max_seq_len
        slice_num = int(np.ceil(feat.shape[1] / slice_len))

        language_token = np.array([self.lid_dict[language]], dtype=np.int32)

        asr_res = []
        for i in range(slice_num):
            # Window 0 covers [0, L); window i > 0 covers
            # [i*L - padding, (i+1)*L - padding).
            # NOTE(review): with this scheme window 1 re-reads the last
            # `padding` frames of window 0, and frames at or beyond
            # slice_num*L - padding are never decoded - confirm intended.
            start = 0 if i == 0 else i * slice_len - self.padding
            sub_feat = feat[:, start : start + slice_len, :]

            real_len = sub_feat.shape[1]
            if real_len < self.max_seq_len:
                tail = np.zeros(
                    (1, self.max_seq_len - real_len, sub_feat.shape[-1]),
                    dtype=np.float32,
                )
                sub_feat = np.concatenate([sub_feat, tail], axis=1)

            # NOTE(review): the mask has max_seq_len + query_num positions but
            # only real_len are marked valid - confirm whether the query
            # positions should be counted as well.
            mask = self.sequence_mask(self.max_seq_len + self.query_num, real_len)

            outputs = self.model.run(
                None,
                {
                    "speech": sub_feat,
                    "mask": mask,
                    "language": language_token,
                },
            )
            ctc_logits, encoder_out_lens = outputs

            asr_res.extend(self.postprocess(ctc_logits, encoder_out_lens))

        return "".join([self.tokens[i] for i in asr_res])

    def infer(
        self, filepath_or_data: Union[Tuple[np.ndarray, int], str], language="auto", print_rtf=False
    ):
        """
        Offline recognition entry point.

        Args:
            filepath_or_data: Path of an audio file, or a (waveform, sample_rate) tuple.
                A waveform whose rate differs from self.sample_rate is resampled.
            language: One of ``language_options``.
            print_rtf: Print real-time factor, latency and audio length when True.

        Returns:
            The recognised text.
        """
        assert not self.streaming, "This method is for non-streaming model"

        if isinstance(filepath_or_data, str):
            waveform = self.load_data(filepath_or_data)
        else:
            waveform, sr = filepath_or_data
            if sr != self.sample_rate:
                waveform = librosa.resample(waveform, orig_sr=sr, target_sr=self.sample_rate, res_type="soxr_hq")

        total_time = waveform.shape[-1] / self.sample_rate

        start = time.time()
        asr_res = self.infer_waveform(waveform, language)
        latency = time.time() - start

        if print_rtf:
            # Empty audio has zero duration; avoid ZeroDivisionError.
            rtf = latency / total_time if total_time > 0 else float("inf")
            print(f"RTF: {rtf} Latency: {latency}s Total length: {total_time}s")
        return asr_res

    def decode(self, times, tokens):
        """
        Convert decoder output to (timestamps, text).

        Timestamps are ``step * 60`` for every token whose text is not pure
        whitespace (presumably milliseconds: 10 ms frame shift x lfr_n 6).
        The text is the concatenation of all tokens, whitespace ones included.
        """
        times_ms = [
            step * 60
            for step, token in zip(times, tokens)
            if len(self.tokens[token].strip()) != 0
        ]
        return times_ms, "".join([self.tokens[i] for i in tokens])

    def reset(self):
        """Reset streaming state (frame counter, decoder, fbank extractor, feature cache)."""
        from online_fbank import OnlineFbank

        self.cur_idx = -1
        self.decoder.reset()
        self.fbank = OnlineFbank(window_type="hamming")
        # Keep the dtype used in __init__; plain np.zeros would give float64
        # and change the dtype of the "speech" input after a reset.
        self.caches = np.zeros(self.caches_shape, dtype=np.float32)

    def get_size(self):
        """
        Number of new frames in the current chunk, in [0, chunk_size].

        Returns 0 until more than ``padding`` frames have been received;
        afterwards cycles 1..chunk_size.
        """
        effective_size = self.cur_idx + 1 - self.padding
        if effective_size <= 0:
            return 0
        return effective_size % self.chunk_size or self.chunk_size

    def stream_infer(self, audio, is_last, language="auto"):
        """
        Feed one piece of audio and yield a result each time a chunk is decoded.

        Args:
            audio: Audio samples handed to OnlineFbank.accept_waveform.
            is_last: True when this is the final piece of the stream.
            language: One of ``language_options``.

        Yields:
            dict with keys "timestamps" and "text".
        """
        assert self.streaming, "This method is for streaming model"

        language_token = np.array([self.lid_dict[language]], dtype=np.int32)

        self.fbank.accept_waveform(audio, is_last)
        features = self.fbank.get_lfr_frames(
            neg_mean=self.neg_mean, inv_stddev=self.inv_stddev
        )

        # Make sure the final call runs the loop once so the decoder is flushed.
        if is_last and len(features) == 0:
            features = self.zeros

        last_idx = len(features) - 1
        for idx, feature in enumerate(features):
            # Use a per-frame flag: overwriting `is_last` itself would turn it
            # False on the first frame and it could never become True again,
            # so the final frame of a multi-frame call was never flushed.
            frame_is_last = is_last and idx == last_idx

            # Slide the cache by one frame and append the new one at the end.
            self.caches = np.roll(self.caches, -1, axis=0)
            self.caches[-1, :] = feature
            self.cur_idx += 1
            cur_size = self.get_size()
            if cur_size != self.chunk_size and not frame_is_last:
                continue

            speech = self.caches[None, ...]
            outputs = self.model.run(
                None,
                {
                    "speech": speech,
                    "mask": self.stream_mask,
                    "language": language_token,
                },
            )
            ctc_logits, encoder_out_lens = outputs
            # Skip the leading query positions (self.query_num, 4 by default).
            probs = ctc_logits[0, self.query_num : encoder_out_lens[0]]
            probs = torch.from_numpy(probs)

            # Partial chunk: drop the positions that belong to older frames.
            if cur_size != self.chunk_size:
                probs = probs[self.chunk_size - cur_size :]
            # Non-final chunk: keep only the first chunk_size positions.
            if not frame_is_last:
                probs = probs[: self.chunk_size]

            if self.beam_size > 1:
                res = self.decoder.ctc_prefix_beam_search(
                    probs, beam_size=self.beam_size, is_last=frame_is_last
                )
                times_ms, text = self.decode(res["times"][0], res["tokens"][0])
            else:
                res = self.decoder.ctc_greedy_search(probs, is_last=frame_is_last)
                times_ms, text = self.decode(res["times"], res["tokens"])
            yield {"timestamps": times_ms, "text": text}
python/cert.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIUdmv2KOIO+jdiFDg8lLn0tla5sY8wDQYJKoZIhvcNAQEL
3
+ BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
4
+ GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTA3MDcwMzI2NDVaFw0yNjA3
5
+ MDcwMzI2NDVaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
6
+ HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggIiMA0GCSqGSIb3DQEB
7
+ AQUAA4ICDwAwggIKAoICAQCkrYAr0M7mLR0hN5tQMNeRLENsJtA7QEGlK5aJXgSs
8
+ BafXw0TOIeE0xgf4GMAx05oKKfMEZE453+VrTVUuttMA9kPli4I1+efxlQdSRv+W
9
+ F84QiUCjg9bg74GaJNX8h9rzr+9Zl94Hak/OeY1yV/5x+DG63XvGXyBPmXUm2Z9l
10
+ TZRCni18+R4PaQ6MM56OzSGmYWmlkyGw3nKiv6lb/CQFHQU1fmJxg4bggMnRkHtP
11
+ Cth++Y9lXwT+1U3CP1xDMmxLiTSX2z7/FjQ9e6d/HdhXbS98ipEQ1OT1CJIIPude
12
+ R9dMdaXAydCAof+jPkxmRU1EI9ssK+GEqx948/R+QN5cgZCLDo54b3fMpbJdsFlD
13
+ 498nTY1cmnkJVb6iUiqNoysqAPDrhfQE7hb59t4RyJock/utqg4n+X+QWKo1B/lA
14
+ gi9UZoAw7NLauzs5sPeLLX9qy+1b6hhoCOeBLOdOe6H+xW9aE0yAPJy2cM1UhGmA
15
+ OgcgXMzB8vI9zPTSmBdXJiVMdGjj2ALIVa+TiKS7mbGjzEVxCuxpR+g469c/9Puh
16
+ syGCo196/j/iw5GSimOpfUlSovY4TnFxATwq1S2XBr2b4tXihxzi8cvdJ8duemLb
17
+ Hvv4aEozzIh+CuoglEiuJ8BI6N4gttDLAYPiNiEld8DVKnD3eikU2HDh3JlrNOVB
18
+ 5QIDAQABo1MwUTAdBgNVHQ4EFgQUVM2by23+rTw4XzMrHNVxjCDpu3EwHwYDVR0j
19
+ BBgwFoAUVM2by23+rTw4XzMrHNVxjCDpu3EwDwYDVR0TAQH/BAUwAwEB/zANBgkq
20
+ hkiG9w0BAQsFAAOCAgEANg6bSEeWlUbupJSnNOWX00jRI6RkDtxv3O/qWb6q/nhT
21
+ 71zGgrdfRk2+fbrFwApDMy5VlDpqwgo76LZSrDuODZwPqc57asHsglVs/2v1h+BW
22
+ tjvQ7zn/VOp+KU7/S3oMumONvdaI0OgPEqnMlQH1hvtlyQpR25SiDDlgzOD/OfDe
23
+ 9jMJ4BCSlOXuSi/q4E1jdZRY+ja6eFlVT7elfOyS7S1kI1akqyX7TpQ5GXw6XIhn
24
+ fr7cQujq4hRpwVjPX03qS2JKna5VE1+qvLUo4xLFF20HAtpg1yS8sRXT4YISXBYk
25
+ 8AeQy470AlCCEm45hW4FxNbu820KkvfY4TqDUj4GyZWq4X5NUtscekwJYnYD8gaA
26
+ Aeyd2SiyicpHYg/tWwzyObXDBDaLDmaXq33PKBinDwiwrAA1BD+T8PqEbHP9Iv0l
27
+ SlRlSiMLiBHul5j1JgWSOFT2bXdNBd8JLmJK9tVBeQP7UpCaUaB/vXtxOS33ASAu
28
+ ReswIYQM7ZFF6hItizi6NjA24trJxdmwStrh062N6Is9nfM/Osvqc4Ms87+jZ5b/
29
+ KLHVi8ZmewhRS0UUMTwFU3RuS5Rj6mOP/5xusAr+EUGqqmX0oOgxfJYqHXIvXxHB
30
+ kP0qRpYkRNoQ4Rauu4B8dl4nYbFspPv/+dVPvJQ1g9hFMj5XM46BSkTYfLaEC6o=
31
+ -----END CERTIFICATE-----
python/download_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Speed up hf download using mirror url
4
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
5
+ from huggingface_hub import snapshot_download
6
+
7
+ current_file_path = os.path.dirname(__file__)
8
+ REPO_ROOT = "AXERA-TECH"
9
+ CACHE_PATH = os.path.join(current_file_path, "models")
10
+
11
+
12
def download_model(model_name: str) -> str:
    """
    Fetch a model snapshot from the AXERA-TECH Hugging Face space.

    The snapshot is stored under ``CACHE_PATH/<model_name>``. When that
    directory already exists no download is attempted.

    Args:
        model_name: Repository name under https://huggingface.co/AXERA-TECH.

    Returns:
        str: Local path of the model directory.
    """
    os.makedirs(CACHE_PATH, exist_ok=True)

    target_dir = os.path.join(CACHE_PATH, model_name)
    if os.path.exists(target_dir):
        # Already present locally: reuse it as-is.
        return target_dir

    print(f"Downloading {model_name}...")
    snapshot_download(
        repo_id=f"{REPO_ROOT}/{model_name}",
        local_dir=target_dir,
    )
    return target_dir
python/frontend.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
4
+ import copy
5
+
6
+ import numpy as np
7
+ import kaldi_native_fbank as knf
8
+
9
+
10
class WavFrontend:
    """Conventional frontend structure for ASR: fbank -> LFR stacking -> CMVN."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        **kwargs,
    ) -> None:
        """
        Args:
            cmvn_file: Path of the CMVN file; CMVN is skipped when falsy.
            fs: Sampling frequency in Hz.
            window: Window type for framing.
            n_mels: Number of mel bins.
            frame_length: Frame length in milliseconds.
            frame_shift: Frame shift in milliseconds.
            lfr_m: Number of frames stacked into one LFR frame.
            lfr_n: Hop, in frames, between consecutive LFR frames.
            dither: Dither value forwarded to the fbank options.
            **kwargs: Accepted and ignored.
        """
        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        # Always define the attribute so readers of `self.cmvn` get None
        # instead of AttributeError when no CMVN file was given.
        self.cmvn = self.load_cmvn() if self.cmvn_file else None
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute fbank features for a whole waveform with a fresh extractor.

        Returns:
            feat: float32 array of shape (frames, n_mels).
            feat_len: 0-d int32 array holding the frame count.
        """
        # Scale by 2**15 before feature extraction
        # (presumably float [-1, 1) -> int16 range; TODO confirm).
        waveform = waveform * (1 << 15)
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Like ``fbank`` but feeds the existing extractor instead of creating a
        new one, so frames accumulate across calls until ``reset_status``.
        """
        waveform = waveform * (1 << 15)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        # fbank_beg_idx is only ever set to 0 in this class, so every ready
        # frame is copied.
        for i in range(self.fbank_beg_idx, frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def reset_status(self):
        """Create a fresh fbank extractor and reset the begin index."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply LFR stacking (unless lfr_m == lfr_n == 1) and CMVN (if configured)."""
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        """
        Stack lfr_m consecutive frames every lfr_n frames (low frame rate).

        The input is left-padded with (lfr_m - 1) // 2 copies of the first
        frame; windows that run past the end are completed with copies of the
        last frame.

        Args:
            inputs: Array of shape (T, D).
            lfr_m: Frames per stacked output frame.
            lfr_n: Hop between output frames.

        Returns:
            float32 array of shape (ceil(T / lfr_n), lfr_m * D); an empty
            (0, lfr_m * D) array when T == 0.
        """
        num_frames = inputs.shape[0]
        if num_frames == 0:
            # No first/last frame to replicate; return an empty result
            # instead of failing on inputs[0].
            return np.zeros((0, lfr_m * inputs.shape[1]), dtype=np.float32)

        num_out = int(np.ceil(num_frames / lfr_n))
        left_pad = (lfr_m - 1) // 2
        # Length required so the last window is complete.
        needed = (num_out - 1) * lfr_n + lfr_m
        right_pad = max(0, needed - (num_frames + left_pad))

        padded = np.concatenate(
            [
                np.repeat(inputs[:1], left_pad, axis=0),
                inputs,
                np.repeat(inputs[-1:], right_pad, axis=0),
            ],
            axis=0,
        )
        windows = [
            padded[start : start + lfr_m].reshape(-1)
            for start in range(0, num_out * lfr_n, lfr_n)
        ]
        return np.stack(windows).astype(np.float32)

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data: ``(inputs + cmvn[0]) * cmvn[1]``.

        Only the first ``inputs.shape[1]`` CMVN columns are used.
        """
        _, dim = inputs.shape
        shift = self.cmvn[0:1, :dim]
        scale = self.cmvn[1:2, :dim]
        # Broadcasting over the frame axis replaces the explicit np.tile.
        return (inputs + shift) * scale

    def load_cmvn(
        self,
    ) -> np.ndarray:
        """
        Parse ``self.cmvn_file``.

        The values are taken from the ``<LearnRateCoef>`` line that directly
        follows an ``<AddShift>`` line (row 0) and a ``<Rescale>`` line
        (row 1); the first three fields and the last field of that line are
        dropped.

        Returns:
            float64 array of shape (2, dim).
        """
        with open(self.cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        means_list = []
        vars_list = []
        # Pair each line with its successor; blank lines and a header on the
        # very last line are skipped rather than raising IndexError.
        for header, payload in zip(lines, lines[1:]):
            header_items = header.split()
            if not header_items:
                continue
            values = payload.split()
            if not values or values[0] != "<LearnRateCoef>":
                continue
            if header_items[0] == "<AddShift>":
                means_list = values[3:-1]
            elif header_items[0] == "<Rescale>":
                vars_list = values[3:-1]

        means = np.array(means_list).astype(np.float64)
        scales = np.array(vars_list).astype(np.float64)
        return np.array([means, scales])
150
+
151
+
152
class WavFrontendOnline(WavFrontend):
    """
    Streaming variant of WavFrontend.

    Keeps three pieces of state between calls: leftover waveform samples
    (``input_cache``), fbank frames not yet consumed by LFR
    (``lfr_splice_cache``) and the waveform span matching those frames
    (``reserve_waveforms``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        # add variables
        # Frame length and frame shift converted from milliseconds to samples.
        self.frame_sample_length = int(
            self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.frame_shift_sample_length = int(
            self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
        )
        # NOTE(review): this sets `waveform` (singular) while the rest of the
        # class reads/writes `self.waveforms`; `get_waveforms` raises
        # AttributeError until `extract_fbank` has assigned it - confirm.
        self.waveform = None
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    @staticmethod
    # inputs already has the cached frames concatenated in front
    def apply_lfr(
        inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """
        Apply LFR stacking to streaming input.

        Args:
            inputs: Frames, indexed as (T, D) below.
            lfr_m: Frames per stacked output frame.
            lfr_n: Hop between output frames.
            is_final: When True the last incomplete window is completed with
                copies of the last frame; otherwise stacking stops there.

        Returns:
            (stacked float32 frames, frames to carry over to the next call,
            index in ``inputs`` where the carried-over frames start).
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append(
                    (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
                )
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n :]).reshape(-1)
                    for _ in range(num_padding):
                        frame = np.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break out of the loop
                    splice_idx = i
                    break
        # Convert from LFR-frame index to input-frame index, clamped to the
        # last input frame.
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        # NOTE(review): np.vstack raises on an empty list, i.e. when no
        # window could be produced - confirm callers guarantee at least one.
        LFR_outputs = np.vstack(LFR_inputs)
        return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        """Number of complete frames in ``sample_length`` samples (0 if shorter than one frame)."""
        frame_num = int(
            (sample_length - frame_sample_length) / frame_shift_sample_length + 1
        )
        return (
            frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
        )

    def fbank(
        self, input: np.ndarray, input_lengths: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute fbank for the new samples plus the cached leftover samples.

        ``input_lengths`` is not used in this method.

        Returns:
            (waveforms that were consumed, fbank features, feature lengths);
            all three are empty arrays when no complete frame is available.
        """
        self.fbank_fn = knf.OnlineFbank(self.opts)
        batch_size = input.shape[0]
        if self.input_cache is None:
            self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
        input = np.concatenate((self.input_cache, input), axis=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.input_cache: keep the samples after the last full frame shift
        self.input_cache = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = np.empty(0, dtype=np.float32)
        feats_pad = np.empty(0, dtype=np.float32)
        feats_lens = np.empty(0, dtype=np.int32)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                # Keep exactly the samples covered by the frame_num frames.
                waveforms.append(
                    waveform[
                        : (
                            (frame_num - 1) * self.frame_shift_sample_length
                            + self.frame_sample_length
                        )
                    ]
                )
                waveform = waveform * (1 << 15)

                self.fbank_fn.accept_waveform(
                    self.opts.frame_opts.samp_freq, waveform.tolist()
                )
                frames = self.fbank_fn.num_frames_ready
                mat = np.empty([frames, self.opts.mel_opts.num_bins])
                # NOTE: this inner loop reuses the name `i` of the outer
                # batch loop; `i` is not read again before the outer loop
                # rebinds it.
                for i in range(frames):
                    mat[i, :] = self.fbank_fn.get_frame(i)
                feat = mat.astype(np.float32)
                feat_len = np.array(mat.shape[0]).astype(np.int32)
                feats.append(feat)
                feats_lens.append(feat_len)

            waveforms = np.stack(waveforms)
            feats_lens = np.array(feats_lens)
            feats_pad = np.array(feats)
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the features and lengths stored by the last ``fbank`` call."""
        return self.fbanks, self.fbanks_lens

    def lfr_cmvn(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
        """
        Apply LFR and CMVN per batch item and refresh ``lfr_splice_cache``.

        Returns:
            (features, feature lengths, per-item splice frame index; the
            index is -1 when LFR is disabled).
        """
        batch_size = input.shape[0]
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            lfr_splice_frame_idx = -1
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
                )
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat)
            feat_length = mat.shape[0]
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = np.array(feats_lens)
        feats_pad = np.array(feats)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def extract_fbank(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Streaming feature extraction for one chunk (batch size must be 1).

        Returns:
            (features, feature lengths). Features are an empty array when
            there are not yet enough frames for one LFR window. On
            ``is_final`` all caches are reset before returning.
        """
        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"
        waveforms, feats, feats_lengths = self.fbank(
            input, input_lengths
        )  # input shape: B T D
        if feats.shape[0]:
            # Prepend the waveform span that belongs to the cached frames.
            self.waveforms = (
                waveforms
                if self.reserve_waveforms is None
                else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
            )
            if not self.lfr_splice_cache:
                # First chunk: seed the cache with (lfr_m - 1) // 2 copies of
                # the first frame (the LFR left padding).
                for i in range(batch_size):
                    self.lfr_splice_cache.append(
                        np.expand_dims(feats[i][0, :], axis=0).repeat(
                            (self.lfr_m - 1) // 2, axis=0
                        )
                    )

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length)
                    / self.frame_shift_sample_length
                    + 1
                )
                # On the first chunk the cache holds synthetic padding frames
                # that have no waveform counterpart.
                minus_frame = (
                    (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                )
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                    feats, feats_lengths, is_final
                )
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[
                        :,
                        reserve_frame_idx
                        * self.frame_shift_sample_length : frame_from_waveforms
                        * self.frame_shift_sample_length,
                    ]
                    sample_length = (
                        frame_from_waveforms - 1
                    ) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[
                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
                ]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate(
                        (self.lfr_splice_cache[i], feats[i]), axis=0
                    )
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            if is_final:
                # No new frames: flush whatever is left in the LFR cache.
                self.waveforms = (
                    waveforms
                    if self.reserve_waveforms is None
                    else self.reserve_waveforms
                )
                feats = np.stack(self.lfr_splice_cache)
                feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        """Return the waveform span stored by the last ``extract_fbank`` call."""
        return self.waveforms

    def cache_reset(self):
        """Drop all streaming state and create a fresh fbank extractor."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
384
+
385
+
386
def load_bytes(input):
    """
    Decode raw 16-bit PCM bytes into float32 samples.

    Args:
        input: Bytes-like object holding int16 samples
            (read with numpy's native ``np.int16`` byte order).

    Returns:
        1-D float32 numpy array with every sample divided by 32768, i.e.
        values in [-1.0, 1.0).

    Raises:
        ValueError: from ``np.frombuffer`` when the buffer size is not a
            multiple of 2 bytes.
    """
    # The dtype is fixed to int16, so the former integer/float kind checks
    # could never fire and the offset was always 0; int16 full scale is 2**15.
    pcm = np.frombuffer(input, dtype=np.int16)
    return pcm.astype(np.float32) / np.float32(32768.0)
402
+
403
+
404
class SinusoidalPositionEncoderOnline:
    """Sinusoidal positional encoding that can start at an arbitrary offset (streaming)."""

    def encode(
        self,
        positions: np.ndarray = None,
        depth: int = None,
        dtype: np.dtype = np.float32,
    ):
        """
        Build the encoding table for the given positions.

        Args:
            positions: Array whose first axis is the batch axis; all values
                are flattened into one position axis.
            depth: Feature dimension; sin and cos each fill half of it.
            dtype: dtype of the computation and of the result.

        Returns:
            Array of shape (1, number of positions, sin part + cos part).
        """
        num_rows = positions.shape[0]
        pos = positions.astype(dtype)
        half_depth = depth / 2

        # Geometric progression of timescales from 1 down to 1/10000.
        log_step = np.log(np.array([10000], dtype=dtype)) / (half_depth - 1)
        channel_idx = np.arange(half_depth).astype(dtype)
        inv_timescales = np.exp(channel_idx * (-log_step))
        inv_timescales = np.reshape(inv_timescales, [num_rows, -1])

        angles = np.reshape(pos, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
        table = np.concatenate((np.sin(angles), np.cos(angles)), axis=2)
        return table.astype(dtype)

    def forward(self, x, start_idx=0):
        """
        Add positional encoding to ``x`` of shape (batch, timesteps, dim).

        Positions are 1-based; ``start_idx`` is the number of timesteps that
        were already consumed by earlier chunks.
        """
        _, num_steps, feat_dim = x.shape
        end = start_idx + num_steps
        all_positions = np.arange(1, end + 1)[None, :]
        table = self.encode(all_positions, feat_dim, x.dtype)

        return x + table[:, start_idx:end]
434
+
435
+
436
def test():
    """
    Manual smoke test: run the frontend on an example recording.

    Relies on hard-coded paths and on ``funasr`` being importable, so it
    only works in the environment it was written for.

    Returns:
        (feat, feat_len) as produced by ``WavFrontend.lfr_cmvn``.
    """
    import librosa
    from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml

    wav_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
    config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"

    frontend_conf = read_yaml(config_file)["frontend_conf"]
    samples, _ = librosa.load(wav_path, sr=None)
    frontend = WavFrontend(cmvn_file=cmvn_file, **frontend_conf)

    # 1-D waveform in, fbank frames out
    fbank_feats, _ = frontend.fbank_online(samples)
    # LFR + CMVN on the fbank frames
    result = frontend.lfr_cmvn(fbank_feats)

    frontend.reset_status()  # clear cache
    feat, feat_len = result
    return feat, feat_len
457
+
458
+
459
+ if __name__ == "__main__":
460
+ test()
python/gradio_demo.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from SenseVoiceAx import SenseVoiceAx
4
+ import numpy as np
5
+
6
# Directory holding the compiled model and its side files.
model_root = "../sensevoice_ax650"
# Fixed number of feature frames per model call.
max_seq_len = 256

model_path = os.path.join(model_root, "sensevoice.axmodel")
cmvn_file = os.path.join(model_root, "am.mvn")
bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
token_file = os.path.join(model_root, "tokens.txt")

# Fail early when the model file is missing.
assert os.path.exists(model_path), f"model {model_path} not exist"

# Single shared recognizer, created at import time in offline mode.
model = SenseVoiceAx(
    model_path=model_path,
    cmvn_file=cmvn_file,
    token_file=token_file,
    bpe_model=bpe_model,
    max_seq_len=max_seq_len,
    beam_size=3,
    hot_words=None,
    streaming=False,
)
26
+
27
+ # 你实现的语言转文本函数
28
def speech_to_text(audio_input, lang):
    """
    Gradio callback: transcribe one recording.

    Args:
        audio_input: A tuple of (sample rate in Hz, audio data as numpy array),
            or a falsy value when no audio is present.
        lang: Language key: "auto", "zh", "en", "yue", "ja", "ko".

    Returns:
        The recognised text, or "无音频" when there is no audio.
    """
    if not audio_input:
        return "无音频"

    sr, audio_data = audio_input
    audio_data = np.asarray(audio_data)

    if np.issubdtype(audio_data.dtype, np.integer):
        # Integer PCM: scale by the dtype's maximum value.
        audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
    elif audio_data.dtype != np.float32:
        # Other float widths (e.g. float64): np.iinfo would raise on them;
        # only the dtype needs to change.
        audio_data = audio_data.astype(np.float32)

    if audio_data.ndim > 1:
        # Multi-channel audio arrives as (samples, channels); the recognizer
        # works on a 1-D waveform, so average the channels.
        audio_data = audio_data.mean(axis=1)

    asr_res = model.infer((audio_data, sr), lang, print_rtf=False)
    return asr_res
42
+
43
+
44
def main():
    """Build the Gradio UI (result box, audio input, language selector) and serve it over HTTPS."""
    language_choices = ["auto", "zh", "en", "yue", "ja", "ko"]

    with gr.Blocks() as demo:
        # Top row: recognition result.
        with gr.Row():
            output_text = gr.Textbox(label="识别结果", lines=5)

        # Bottom row: audio source and language selection.
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="录制或上传音频",
                format="wav",
            )
            lang_dropdown = gr.Dropdown(
                choices=language_choices,
                value="auto",
                label="选择音频语言",
            )

        # Re-run recognition whenever the audio changes.
        audio_input.change(
            fn=speech_to_text,
            inputs=[audio_input, lang_dropdown],
            outputs=output_text,
        )

    launch_options = {
        "server_name": "0.0.0.0",
        "ssl_certfile": "./cert.pem",
        "ssl_keyfile": "./key.pem",
        "ssl_verify": False,
    }
    demo.launch(**launch_options)
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
python/key.pem ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN PRIVATE KEY-----
2
+ MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCkrYAr0M7mLR0h
3
+ N5tQMNeRLENsJtA7QEGlK5aJXgSsBafXw0TOIeE0xgf4GMAx05oKKfMEZE453+Vr
4
+ TVUuttMA9kPli4I1+efxlQdSRv+WF84QiUCjg9bg74GaJNX8h9rzr+9Zl94Hak/O
5
+ eY1yV/5x+DG63XvGXyBPmXUm2Z9lTZRCni18+R4PaQ6MM56OzSGmYWmlkyGw3nKi
6
+ v6lb/CQFHQU1fmJxg4bggMnRkHtPCth++Y9lXwT+1U3CP1xDMmxLiTSX2z7/FjQ9
7
+ e6d/HdhXbS98ipEQ1OT1CJIIPudeR9dMdaXAydCAof+jPkxmRU1EI9ssK+GEqx94
8
+ 8/R+QN5cgZCLDo54b3fMpbJdsFlD498nTY1cmnkJVb6iUiqNoysqAPDrhfQE7hb5
9
+ 9t4RyJock/utqg4n+X+QWKo1B/lAgi9UZoAw7NLauzs5sPeLLX9qy+1b6hhoCOeB
10
+ LOdOe6H+xW9aE0yAPJy2cM1UhGmAOgcgXMzB8vI9zPTSmBdXJiVMdGjj2ALIVa+T
11
+ iKS7mbGjzEVxCuxpR+g469c/9PuhsyGCo196/j/iw5GSimOpfUlSovY4TnFxATwq
12
+ 1S2XBr2b4tXihxzi8cvdJ8duemLbHvv4aEozzIh+CuoglEiuJ8BI6N4gttDLAYPi
13
+ NiEld8DVKnD3eikU2HDh3JlrNOVB5QIDAQABAoICAAwvSZu0WbgT9XhRkNHi/fL1
14
+ bKWyIi0y03NLttns1XlUT8zPW3t0a/ac19ZxH7jFbeaQ9Qoe+99yDsZyPzpzd522
15
+ Gw7/KWrMq29SIMEN6iJqb3+vZX4pJqXtGCYA0hPbNNsGv7XdiVBQ0Efi8ZGDgPBg
16
+ 4MPGrekJ0oO2mJb/z4341V6v19t2jqqtkiXTOfOvVO071EvWh6MlH86lUibcELQ6
17
+ J0+ueCKVt0326Y1H3KqGub8nawdL+7wj/0VqAm3Ma3vpUoq77meEdpK9YAbap/kh
18
+ ZeaYACTX2SW89RqGwQArs6VU0ny3J2YlK8ZMjphF9Md1GbtPzIzQS/CwqFIL2F1Z
19
+ ojsBdUj2V2ZFdL9ZQZYbnI68xU7H+RolOqBdmB1U/M19kBVtOz+Hrzu1jnSVn+YC
20
+ dL3W/bmgk9xrGGUOK9oIiH3NaLarK1wsbUpaatQxmNNVMfIz06Z4tAbgO8KGqFhb
21
+ f92MdmnLFPjYLg93NZ7wOWBr+S25FZFd58aYwOm4D+pnbYPoq7x7eYZwm9+gSIY7
22
+ y9k4JPFlNhmAecgAhPMg9RaVzJN3qXfDHb2pT7rcJ5DE4GaJotUDv/dL9OnzGStP
23
+ QfYMERObEcaQIC14z3q/JOph0hn7gzPtYChJFxdijCTE+jf+9JKctgAiMmysLbpv
24
+ 7zpXcyXdWiRpI69N4pZ/AoIBAQDmsCXKn20H+1OJ8tYBmKcFm/ZGcHek9mUDhPgu
25
+ Bfgag+sk5tInG05gCtvlrAkRikJUuj88URjxqoIjCCmwqFXT3WOrwAx0z4IRY0Gv
26
+ vLnmslbGl34GiLvKodvsIKb20NXsCLM8ScqpVuH0GTezu6J8qHwlKqLSFuWk3BHN
27
+ hb3mcCNiw2hmvZVY1fNTpqsLRxpFW2sblYtlNc6nbCDwXHMdQEsay0CTt56SmUWK
28
+ 7dxtW217nS5rClHUm95iyIInfBqUT1KEmvcBnqnnyB3mRGt4wN6wqhyQOSjv2OQ1
29
+ z2UfKuSHtMJmnNkduuEXeOJfv4HG/p/cituGDZcTzHdNk3T3AoIBAQC2vyxBzERp
30
+ 3oMBfd3JgQpC9uEHmmxvtHRHskXnH+Fz85WznTBGTm0/RFRd2/LESZC5ODZVRAaG
31
+ GFlBpzXD5dDxHDTIzSSW4Kt/WPYg5UvkKDpxpqvTv7LFV4m7/tjgHGoirJKsZDHi
32
+ a9X1ER3ZE9GEg1ebCVIacdvL3EvzcoFYC/ZF7JN2SWfbbFdEWOrGYs1du0rJHiOX
33
+ CWEu1nVcrq+2U7IEqxLD4Ns3talxmDr3AZ4ATHqAcPNp61vAggjO78u9RSmwAPd3
34
+ NWlHk/S2Gi7ti4mQQHMmdiog4PvTanOuvybjLDh6bK895Wo8qHHoEyMZ3xkVEhQG
35
+ HGv8HmWmenUDAoIBAQC1USgzDYHOLz1nBOYuVQSaRQ6aKNXxY/TbgkzrJ6ftd1iA
36
+ JahyMmU02fQinkh2b9xY6ha/2uInOKSW0liqUHU9VBp+KTHhMiSCdChx732SlQPd
37
+ jb7xddFcoEHSY4u4HUa3AdOXBEz1MqPgj12XuFgrcOY69DsLtBGFta+MgZ1UHTnC
38
+ 6+IINuTG8UsSqcJw188PSp5yDOWGhHdMYpG1OoUELb+abLzyHfXWNgBSBUkm7yCr
39
+ c0zDt1XALU7rB7w9Oq9NeNdcAM06ibHzyvetQIPUYovmAZ73wOWrNyeQH9XUXItJ
40
+ GstdidShKHy5TTtolIZ1mTafSsjmoZHobuIqqEbbAoIBABO0BPeLKI0pmoJcqb8C
41
+ FLMnnxeMxMg+cpMQW40R2OMBjlBxUDUkW48ItPfxsPkM3Xe64dDLptBqa6UyfA+F
42
+ BcQZQG+t/pXt30+5rb/aORZ+Z969E6We841nZMhKL+Pp7F+Ur7O6kc5Rxh3IHKm9
43
+ A0gAST/D/4Auan5OYDn9TIjLsV/UpAmK3JHB2p7Z32ZIXNAQU33frAKq1jmQkdLO
44
+ Ws+TsoviTgGkir407fH7cdAT8o8hr8uNYhE3eQsGeiClphfgDyCU2hmWPqWjBC1m
45
+ IU0nUEunR0MMVnp5B23B+nsKzQyNRgGdGj/YLl4f4zgcaBpv/WpSKqqGAfaK6HbM
46
+ mTUCggEAP3G/WCFE99G0fHctcwG/5vIAhbxgs7MMZ2gs9WCm3p4IbkVIsqUpx0C8
47
+ WDQ3e3sTqfIZy2Ixr7SFfD3/W+J+2D+UyJEYigbjTj3U6BHD21lFocC393YoTcOp
48
+ xruTY3XrbJF4y1y5kShngBZXfjsuYfKwy+xhGaES+FS1bQo+L+EkwBp0VWZVd+FN
49
+ zf7//zzhllVcPpQzihnax4LCaOUFKVKm3ajalPsV8gozSJrRsEAy3SezRWNK/gos
50
+ SgY5H+4jlO8l8LvswmLRxCpyCG1nLidPMrQAzrcWbDLDkG+FSultDrIJ6oOCTtmz
51
+ oy92VpLjCy74/qSmIijKrd+A3Uawrw==
52
+ -----END PRIVATE KEY-----
python/main.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from SenseVoiceAx import SenseVoiceAx
4
+ import librosa
5
+ import time
6
+
7
+
8
def get_args():
    """Parse the command line: input file, language and streaming switch."""
    supported_languages = ["auto", "zh", "en", "yue", "ja", "ko"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        "-i",
        required=True,
        type=str,
        help="Input audio file",
    )
    parser.add_argument(
        "--language",
        "-l",
        required=False,
        type=str,
        default="auto",
        choices=supported_languages,
    )
    parser.add_argument("--streaming", action="store_true")

    parsed = parser.parse_args()
    return parsed
23
+
24
+
25
def main():
    """
    Run SenseVoice on one audio file.

    Without ``--streaming`` the whole file is passed to ``model.infer``.
    With ``--streaming`` the audio is loaded at 16 kHz, fed to
    ``model.stream_infer`` in 0.1 s chunks, and the real-time factor
    (processing time / audio duration) is printed at the end.
    """
    args = get_args()
    print(vars(args))

    input_audio = args.input
    language = args.language
    model_root = "../sensevoice_ax650"
    if not args.streaming:
        max_seq_len = 256
        model_path = os.path.join(model_root, "sensevoice.axmodel")
    else:
        max_seq_len = 26
        model_path = os.path.join(model_root, "streaming_sensevoice.axmodel")

    assert os.path.exists(model_path), f"model {model_path} not exist"

    cmvn_file = os.path.join(model_root, "am.mvn")
    bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    token_file = os.path.join(model_root, "tokens.txt")

    model = SenseVoiceAx(
        model_path,
        cmvn_file,
        token_file,
        bpe_model,
        max_seq_len=max_seq_len,
        beam_size=3,
        hot_words=None,
        streaming=args.streaming,
    )

    if not args.streaming:
        asr_res = model.infer(input_audio, language, print_rtf=True)
        print("ASR result: " + asr_res)
        return

    samples, sr = librosa.load(input_audio, sr=16000)
    # Scale the loaded samples by 32768 and hand them over as a plain list.
    # NOTE(review): presumably stream_infer expects int16-range values -
    # confirm against SenseVoiceAx.stream_infer.
    samples = (samples * 32768).tolist()
    # Use the rate actually returned by the loader for both the duration and
    # the chunk size, so the two can never disagree.
    duration = len(samples) / sr

    start = time.time()
    step = int(0.1 * sr)
    for i in range(0, len(samples), step):
        is_last = i + step >= len(samples)
        for res in model.stream_infer(samples[i : i + step], is_last, language):
            print(res)

    end = time.time()
    cost_time = end - start

    if duration > 0:
        print(f"RTF: {cost_time / duration}")
    else:
        # An empty file has no duration, so the ratio is undefined.
        print("RTF: n/a (empty audio)")


if __name__ == "__main__":
    main()
python/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ numpy<2
3
+ kaldi-native-fbank
4
+ librosa==0.9.1
5
+ fastapi
6
+ gradio==5.47.1
7
+ online-fbank
8
+ asr_decoder
9
+ resampy
10
+ soxr
python/server.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from fastapi import FastAPI, HTTPException, Body
3
+ from fastapi.responses import JSONResponse
4
+ from typing import List, Optional
5
+ import logging
6
+ from SenseVoiceAx import SenseVoiceAx
7
+ import os
8
+ import librosa
9
+
10
# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ASR Server",
    description="Automatic Speech Recognition API",
)

# Module-level holder for the recognizer; stays None until a model is loaded.
asr_model = None
18
+
19
+
20
@app.on_event("startup")
async def load_model():
    """
    Create the global ``asr_model`` when the service starts.

    Any failure is logged and re-raised.
    """
    global asr_model
    logger.info("Loading ASR model...")

    try:
        language = "auto"
        max_seq_len = 256

        # All assets live in one directory relative to the working directory.
        model_root = "../sensevoice_ax650"
        model_path = os.path.join(model_root, "sensevoice.axmodel")
        cmvn_file = os.path.join(model_root, "am.mvn")
        bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        token_file = os.path.join(model_root, "tokens.txt")

        assert os.path.exists(model_path), f"model {model_path} not exist"

        asr_model = SenseVoiceAx(
            model_path,
            cmvn_file,
            token_file,
            bpe_model,
            max_seq_len=max_seq_len,
            beam_size=3,
            hot_words=None,
            streaming=False,
        )

        print(f"language: {language}")
        print(f"model_path: {model_path}")

        logger.info("ASR model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load ASR model: {str(e)}")
        raise
59
+
60
+
61
def validate_audio_data(audio_data: List[float]) -> np.ndarray:
    """
    Validate audio samples and convert them to a float32 numpy array.

    Args:
        audio_data: Audio samples as a list of floats.

    Returns:
        A non-empty, 1-dimensional float32 array whose values are all finite.

    Raises:
        ValueError: If the data cannot be converted, is not 1-dimensional,
            is empty, or contains NaN/Inf. The message always starts with
            "Invalid audio data: " and the original error is chained.
    """
    try:
        np_array = np.array(audio_data, dtype=np.float32)

        if np_array.ndim != 1:
            raise ValueError("Audio data must be 1-dimensional")

        if len(np_array) == 0:
            raise ValueError("Audio data cannot be empty")

        # NaN/Inf (including values that overflow float32) would otherwise
        # be passed on to the recognizer unnoticed.
        if not np.isfinite(np_array).all():
            raise ValueError("Audio data must contain only finite values")

        return np_array
    except Exception as e:
        # Keep the cause attached so the original traceback is not lost.
        raise ValueError(f"Invalid audio data: {str(e)}") from e
85
+
86
+
87
@app.get("/get_language", summary="Get current language")
async def get_language():
    """
    Return the ``language`` attribute of the loaded ASR model as JSON.

    Raises:
        HTTPException: 503 when no model has been loaded yet.
    """
    # Without this guard a missing model surfaces as an AttributeError on None.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language": asr_model.language})
90
+
91
+
92
@app.get(
    "/get_language_options",
    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
)
async def get_language_options():
    """
    Return the ``language_options`` attribute of the loaded ASR model as JSON.

    Raises:
        HTTPException: 503 when no model has been loaded yet.
    """
    # Without this guard a missing model surfaces as an AttributeError on None.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")
    return JSONResponse(content={"language_options": asr_model.language_options})
98
+
99
+
100
@app.post("/asr", summary="Recognize speech from numpy audio data")
async def recognize_speech(
    audio_data: List[float] = Body(
        ..., embed=True, description="Audio data as list of floats"
    ),
    sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
    language: Optional[str] = Body("auto", description="Language"),
):
    """
    Recognize speech from audio samples sent in the request body.

    Args:
        audio_data: Audio samples as a list of floats.
        sample_rate: Audio sample rate in Hz (16000 when omitted or null).
        language: Language code ("auto" when omitted or null).

    Returns:
        JSON object with the recognized text under the "text" key.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when a
            ValueError is raised while validating or recognizing, 500 for
            any other failure.
    """
    # Checked before the try block: inside it, the generic ``except Exception``
    # handler would catch this HTTPException and turn the 503 into a 500.
    if asr_model is None:
        raise HTTPException(status_code=503, detail="ASR model not loaded")

    # Both fields are Optional, so an explicit JSON null arrives as None;
    # fall back to the declared defaults instead of passing None on.
    if sample_rate is None:
        sample_rate = 16000
    if language is None:
        language = "auto"

    try:
        logger.info(f"Received audio data with length: {len(audio_data)}")

        # Validate and convert the samples.
        np_audio = validate_audio_data(audio_data)

        # Run recognition.
        result = asr_model.infer_waveform((np_audio, sample_rate), language)

        return JSONResponse(content={"text": result})

    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Recognition error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
139
+
140
+
141
if __name__ == "__main__":
    # uvicorn is imported only when this file is executed directly.
    import uvicorn

    # Serve the app on every interface, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
python/test_wer.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from SenseVoiceAx import SenseVoiceAx
4
+ from download_utils import download_model
5
+ import logging
6
+ import re
7
+
8
+
9
def setup_logging():
    """
    Configure the root logger to write to both the console and a log file.

    Existing root handlers are removed and closed, then two INFO-level
    handlers are installed: one writing ``test_wer.log`` next to this script
    (overwritten on each call) and one writing to the console.

    Returns:
        The configured root logger.
    """
    # The log file lives in the directory that contains this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    log_file = os.path.join(script_dir, "test_wer.log")

    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    formatter = logging.Formatter(log_format, date_format)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop whatever was installed before. Closing releases the handlers'
    # resources (e.g. the previous log file) instead of leaking them.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
44
+
45
+
46
class AIShellDataset:
    """Iterable of ``(audio_path, transcript)`` pairs read from a list file."""

    def __init__(self, gt_path: str):
        """
        Load the ground-truth list.

        Args:
            gt_path: Text file with one utterance per line in the form
                ``<utterance id> <transcript>``. Audio files are resolved as
                ``aishell_S0764/<utterance id>.wav`` next to this file.

        Raises:
            AssertionError: If ``gt_path`` or the ``aishell_S0764`` folder
                does not exist.
            ValueError: If a non-blank line has no transcript after the id.
        """
        self.gt_path = gt_path
        self.dataset_dir = os.path.dirname(gt_path)
        self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
        # Start at 0 so next() also works before iter() has been called.
        self.index = 0

        # Required file and folder must exist.
        assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
        assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(gt_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                # Split only once: the transcript itself may contain spaces.
                parts = line.split(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{gt_path}:{line_no}: expected '<utterance id> <transcript>', got {line!r}"
                    )
                utt_id, gt = parts
                audio_path = os.path.join(self.voice_dir, utt_id + ".wav")
                self.data.append({"audio_path": audio_path, "gt": gt})

        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Restart iteration from the first item and return self."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
95
+
96
+
97
class CommonVoiceDataset:
    """Common Voice dataset parser yielding ``(audio_path, transcript)`` pairs."""

    def __init__(self, tsv_path: str):
        """
        Load the samples listed in a Common Voice TSV file.

        Args:
            tsv_path: Tab-separated file whose first line is a header. The
                ``path`` and ``sentence`` columns are located by header
                name; when a name is missing, column 1 (path) and column 3
                (sentence) are used. Audio files are resolved inside the
                ``clips`` folder next to this file.

        Raises:
            AssertionError: If ``tsv_path`` or the ``clips`` folder does not
                exist.
            IndexError: If a non-blank row has fewer columns than required.
        """
        self.tsv_path = tsv_path
        self.dataset_dir = os.path.dirname(tsv_path)
        self.voice_dir = os.path.join(self.dataset_dir, "clips")
        # Start at 0 so next() also works before iter() has been called.
        self.index = 0

        # Required file and folder must exist.
        assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
        assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(tsv_path, "r", encoding="utf-8") as f:
            # Locate the columns from the header so files with a different
            # column order still pick the transcript rather than another field.
            header = f.readline().rstrip("\r\n").split("\t")
            path_col = header.index("path") if "path" in header else 1
            text_col = header.index("sentence") if "sentence" in header else 3

            for line in f:
                # Only the line ending is removed, so empty fields keep
                # their position.
                line = line.rstrip("\r\n")
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                splits = line.split("\t")
                audio_path = os.path.join(self.voice_dir, splits[path_col])
                gt = splits[text_col]
                self.data.append({"audio_path": audio_path, "gt": gt})

        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Restart iteration from the first item and return self."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(audio_path, ground_truth)`` pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        self.index += 1
        return item["audio_path"], item["gt"]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
151
+
152
+
153
def get_args():
    """Parse the command line: dataset type, ground-truth file, language, sample cap."""
    dataset_choices = ["aishell", "common_voice"]
    language_choices = ["auto", "zh", "en", "yue", "ja", "ko"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        required=True,
        choices=dataset_choices,
        help="Test dataset",
    )
    parser.add_argument(
        "--gt_path",
        "-g",
        type=str,
        required=True,
        help="Test dataset ground truth file",
    )
    parser.add_argument(
        "--language",
        "-l",
        required=False,
        type=str,
        default="auto",
        choices=language_choices,
    )
    parser.add_argument(
        "--max_num",
        type=int,
        default=-1,
        required=False,
        help="Maximum test data num",
    )

    parsed = parser.parse_args()
    return parsed
182
+
183
+
184
def min_distance(word1: str, word2: str) -> int:
    """
    Return the edit distance between two strings.

    The distance is the minimum number of single-character insertions,
    deletions and replacements needed to turn ``word1`` into ``word2``.
    Only two rows of the table are kept at a time, so memory grows with
    ``len(word2)`` rather than with the product of both lengths.

    Args:
        word1: Source string.
        word2: Target string.

    Returns:
        The edit distance as a non-negative integer.
    """
    # Row for the empty prefix of word1: j insertions produce word2[:j].
    previous = list(range(len(word2) + 1))

    for i, char1 in enumerate(word1, start=1):
        # First column: i deletions turn word1[:i] into the empty string.
        current = [i]
        for j, char2 in enumerate(word2, start=1):
            if char1 == char2:
                # Matching characters cost nothing extra.
                current.append(previous[j - 1])
            else:
                replace = previous[j - 1] + 1
                insert = current[j - 1] + 1
                remove = previous[j] + 1
                current.append(min(replace, insert, remove))
        previous = current

    return previous[-1]
211
+
212
+
213
def remove_punctuation(text):
    """
    Return ``text`` without punctuation.

    Every character that is neither a word character nor whitespace is
    removed, and so is the underscore.
    """
    return re.sub(r"[^\w\s]|_", "", text)
222
+
223
+
224
def main():
    """
    Evaluate the recognizer on a dataset and log the error rate.

    For each sample, the reference and the recognized text are stripped of
    punctuation and lower-cased, and their edit distance is accumulated. The
    distance is computed per character, although the log lines label the
    rate as "WER".
    """
    logger = setup_logging()
    args = get_args()

    language = args.language
    max_num = args.max_num

    dataset_type = args.dataset.lower()
    if dataset_type == "aishell":
        dataset = AIShellDataset(args.gt_path)
    elif dataset_type == "common_voice":
        dataset = CommonVoiceDataset(args.gt_path)
    else:
        raise ValueError(f"Unknown dataset type {dataset_type}")

    model_root = "../sensevoice_ax650"
    max_seq_len = 256
    model_path = os.path.join(model_root, "sensevoice.axmodel")

    assert os.path.exists(model_path), f"model {model_path} not exist"

    cmvn_file = os.path.join(model_root, "am.mvn")
    bpe_model = os.path.join(model_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    token_file = os.path.join(model_root, "tokens.txt")

    model = SenseVoiceAx(
        model_path,
        cmvn_file,
        token_file,
        bpe_model,
        max_seq_len=max_seq_len,
        beam_size=3,
        hot_words=None,
        streaming=False,
    )

    logger.info(f"dataset: {args.dataset}")
    logger.info(f"language: {language}")
    logger.info(f"model_path: {model_path}")

    all_character_error_num = 0
    all_character_num = 0
    # Never announce more samples than the dataset holds.
    dataset_size = len(dataset)
    max_data_num = min(max_num, dataset_size) if max_num > 0 else dataset_size

    for n, (audio_path, reference) in enumerate(dataset):
        if n >= max_data_num:
            break

        reference = remove_punctuation(reference).lower()
        character_num = len(reference)
        if character_num == 0:
            # The rate is errors / reference length, so a sample with no
            # reference characters cannot be scored.
            logger.warning(
                f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} skipped: empty reference"
            )
            continue

        asr_res = model.infer(audio_path, language, print_rtf=False)
        hypothesis = remove_punctuation(asr_res).lower()

        character_error_num = min_distance(reference, hypothesis)
        character_error_rate = character_error_num / character_num * 100

        all_character_error_num += character_error_num
        all_character_num += character_num

        line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
        logger.info(line_content)

    if all_character_num == 0:
        # Nothing was scored (empty dataset or only empty references).
        logger.warning("No reference characters were scored; total WER is undefined")
        return

    total_character_error_rate = all_character_error_num / all_character_num * 100

    logger.info(f"Total WER: {total_character_error_rate}%")


if __name__ == "__main__":
    main()