inoryQwQ committed
Commit d1ae526 · 1 Parent(s): cfec603

Reformat code

Files changed (9):
  1. SenseVoiceAx.py +164 -92
  2. download_utils.py +7 -3
  3. frontend.py +45 -14
  4. gradio_demo.py +14 -23
  5. main.py +25 -13
  6. print_utils.py +3 -1
  7. server.py +39 -25
  8. test_wer.py +89 -62
  9. tokenizer.py +5 -3
SenseVoiceAx.py CHANGED
@@ -4,7 +4,7 @@ import librosa
 from frontend import WavFrontend
 import os
 import time
-from typing import List, Union
+from typing import List, Union, Optional
 from asr_decoder import CTCDecoder
 from tokenizer import SentencepiecesTokenizer
 from online_fbank import OnlineFbank
@@ -15,93 +15,117 @@ def sequence_mask(lengths, maxlen=None, dtype=np.float32):
     # 如果 maxlen 未指定,则取 lengths 中的最大值
     if maxlen is None:
         maxlen = np.max(lengths)
-
+
     # 创建一个从 0 到 maxlen-1 的行向量
     row_vector = np.arange(0, maxlen, 1)
-
+
     # 将 lengths 转换为列向量
     matrix = np.expand_dims(lengths, axis=-1)
-
+
     # 比较生成掩码
     mask = row_vector < matrix
     if mask.shape[-1] < lengths[0]:
-        mask = np.concatenate([mask, np.zeros((mask.shape[0], lengths[0] - mask.shape[-1]), dtype=np.float32)], axis=-1)
-
+        mask = np.concatenate(
+            [
+                mask,
+                np.zeros(
+                    (mask.shape[0], lengths[0] - mask.shape[-1]), dtype=np.float32
+                ),
+            ],
+            axis=-1,
+        )
+
     # 返回指定数据类型的掩码
     return mask.astype(dtype)[None, ...]
-
+
 
 def unique_consecutive_np(arr):
     """
     找出数组中连续的唯一值,模拟 torch.unique_consecutive(yseq, dim=-1)
-
+
     参数:
         arr: 一维numpy数组
-
+
     返回:
         unique_values: 去除连续重复值后的数组
     """
     if len(arr) == 0:
         return np.array([])
-
+
     if len(arr) == 1:
         return arr.copy()
-
+
     # 找出变化的位置
     diff = np.diff(arr)
     change_positions = np.where(diff != 0)[0] + 1
-
+
     # 添加起始位置
     start_positions = np.concatenate(([0], change_positions))
-
+
     # 获取唯一值(每个连续段的第一个值)
     unique_values = arr[start_positions]
-
-    return unique_values
-
 
+    return unique_values
-class Tokenizer:
-    def __init__(self, symbol_path):
-        self.symbol_tables = {}
-        with open(symbol_path, 'r') as f:
-            i = 0
-            for line in f:
-                token = line.strip()
-                self.symbol_tables[token] = i
-                i += 1
-
-    def tokens2text(self, token):
-        return self.symbol_tables[token]
 
 
 class SenseVoiceAx:
-    def __init__(self, model_path,
-                 max_len=256,
-                 beam_size=3,
-                 language="auto",
-                 hot_words=Union[List[str], None],
-                 use_itn=True,
-                 streaming=False):
+    """ SenseVoice axmodel runner """
+
+    def __init__(
+        self,
+        model_path: str,
+        max_len: int = 256,
+        beam_size: int = 3,
+        language: str = "auto",
+        hot_words: Optional[List[str]] = None,
+        use_itn: bool = True,
+        streaming: bool = False,
+    ):
+        """
+        Initialize SenseVoiceAx
+
+        Args:
+            model_path: Path of axmodel
+            max_len: Fixed shape of input of axmodel
+            beam_size: Max number of hypos to hold after each decode step
+            language: Support auto, zh(Chinese), en(English), yue(Cantonese), ja(Japanese), ko(Korean)
+            hot_words: Words that may fail to recognize,
+                special words/phrases (aka hotwords) like rare words, personalized information etc.
+            use_itn: Allow Invert Text Normalization if True,
+                ITN converts ASR model output into its written form to improve text readability,
+                For example, the ITN module replaces “one hundred and twenty-three dollars” transcribed by an ASR model with “$123.”
+            streaming: Processes audio in small segments or "chunks" sequentially and outputs text on the fly.
+                Use stream_infer method if streaming is true otherwise infer.
+
+        """
         model_path_root = os.path.dirname(model_path)
-        emb_path = os.path.join(model_path_root, '../embeddings.npy')
-        cmvn_file = os.path.join(model_path_root, '../am.mvn')
-        bpe_model = os.path.join(model_path_root, '../chn_jpn_yue_eng_ko_spectok.bpe.model')
+        emb_path = os.path.join(model_path_root, "../embeddings.npy")
+        cmvn_file = os.path.join(model_path_root, "../am.mvn")
+        bpe_model = os.path.join(
+            model_path_root, "../chn_jpn_yue_eng_ko_spectok.bpe.model"
+        )
         if streaming:
-            self.position_encoding = np.load(os.path.join(model_path_root, '../pe_streaming.npy'))
+            self.position_encoding = np.load(
+                os.path.join(model_path_root, "../pe_streaming.npy")
+            )
         else:
-            self.position_encoding = np.load(os.path.join(model_path_root, '../pe_nonstream.npy'))
+            self.position_encoding = np.load(
+                os.path.join(model_path_root, "../pe_nonstream.npy")
+            )
 
         self.streaming = streaming
         self.tokenizer = SentencepiecesTokenizer(bpemodel=bpe_model)
 
-        self.frontend = WavFrontend(cmvn_file=cmvn_file,
-                                    fs=16000,
-                                    window="hamming",
-                                    n_mels=80,
-                                    frame_length=25,
-                                    frame_shift=10,
-                                    lfr_m=7,
-                                    lfr_n=6)
+        self.frontend = WavFrontend(
+            cmvn_file=cmvn_file,
+            fs=16000,
+            window="hamming",
+            n_mels=80,
+            frame_length=25,
+            frame_shift=10,
+            lfr_m=7,
+            lfr_n=6,
+        )
         self.model = axe.InferenceSession(model_path)
         self.sample_rate = 16000
         self.blank_id = 0
@@ -109,11 +133,32 @@ class SenseVoiceAx:
         self.padding = 16
         self.input_size = 560
 
-        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
-        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
+        self.lid_dict = {
+            "auto": 0,
+            "zh": 3,
+            "en": 4,
+            "yue": 7,
+            "ja": 11,
+            "ko": 12,
+            "nospeech": 13,
+        }
+        self.lid_int_dict = {
+            24884: 3,
+            24885: 4,
+            24888: 7,
+            24892: 11,
+            24896: 12,
+            24992: 13,
+        }
         self.textnorm_dict = {"withitn": 14, "woitn": 15}
         self.textnorm_int_dict = {25016: 14, 25017: 15}
-        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
+        self.emo_dict = {
+            "unk": 25009,
+            "happy": 25001,
+            "sad": 25002,
+            "angry": 25003,
+            "neutral": 25004,
+        }
 
         self.load_embeddings(emb_path, language, use_itn)
         self.language = language
@@ -135,39 +180,48 @@ class SenseVoiceAx:
         self.caches_shape = (max_len, self.input_size)
         self.caches = np.zeros(self.caches_shape, dtype=np.float32)
         self.zeros = np.zeros((1, self.input_size), dtype=np.float32)
-        self.neg_mean, self.inv_stddev = self.frontend.cmvn[0, :], self.frontend.cmvn[1, :]
+        self.neg_mean, self.inv_stddev = (
+            self.frontend.cmvn[0, :],
+            self.frontend.cmvn[1, :],
+        )
 
         self.fbank = OnlineFbank(window_type="hamming")
-        self.masks = sequence_mask(np.array([self.max_len], dtype=np.int32), maxlen=self.max_len, dtype=np.float32)
-
+        self.masks = sequence_mask(
+            np.array([self.max_len], dtype=np.int32),
+            maxlen=self.max_len,
+            dtype=np.float32,
+        )
 
     @property
     def language_options(self):
         return list(self.lid_dict.keys())
-
+
     @property
     def textnorm_options(self):
         return list(self.textnorm_dict.keys())
-
+
     def load_embeddings(self, emb_path, language, use_itn):
         self.embeddings = np.load(emb_path, allow_pickle=True).item()
         self.language_query = self.embeddings[language]
-        self.textnorm_query = self.embeddings['withitn'] if use_itn else self.embeddings['woitn']
-        self.event_emo_query = self.embeddings['event_emo']
-        self.input_query = np.concatenate((self.textnorm_query, self.language_query, self.event_emo_query), axis=1)
+        self.textnorm_query = (
+            self.embeddings["withitn"] if use_itn else self.embeddings["woitn"]
+        )
+        self.event_emo_query = self.embeddings["event_emo"]
+        self.input_query = np.concatenate(
+            (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
+        )
         self.query_num = self.input_query.shape[1]
 
-
     def choose_language(self, language):
         self.language_query = self.embeddings[language]
-        self.input_query = np.concatenate((self.textnorm_query, self.language_query, self.event_emo_query), axis=1)
+        self.input_query = np.concatenate(
+            (self.textnorm_query, self.language_query, self.event_emo_query), axis=1
+        )
         self.language = language
 
-
     def load_data(self, filepath: str) -> np.ndarray:
         waveform, _ = librosa.load(filepath, sr=self.sample_rate)
         return waveform.flatten()
-
 
     @staticmethod
     def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
@@ -179,7 +233,6 @@ class SenseVoiceAx:
         feats = np.array(feat_res).astype(np.float32)
         return feats
 
-
     def preprocess(self, waveform):
         feats, feats_len = [], []
         for wf in [waveform]:
@@ -191,11 +244,10 @@ class SenseVoiceAx:
         feats = self.pad_feats(feats, np.max(feats_len))
         feats_len = np.array(feats_len).astype(np.int32)
         return feats, feats_len
-
 
     def postprocess(self, ctc_logits, encoder_out_lens):
         # 提取数据
-        x = ctc_logits[0, 4:encoder_out_lens[0], :]
+        x = ctc_logits[0, 4 : encoder_out_lens[0], :]
 
         # 获取最大值索引
         yseq = np.argmax(x, axis=-1)
@@ -208,7 +260,6 @@ class SenseVoiceAx:
         token_int = yseq[mask].tolist()
 
         return token_int
-
 
     def infer_waveform(self, waveform: np.ndarray, language="auto"):
         if language != self.language:
@@ -224,32 +275,46 @@ class SenseVoiceAx:
         asr_res = []
         for i in range(slice_num):
             if i == 0:
-                sub_feat = feat[:, i*slice_len:(i+1)*slice_len, :]
+                sub_feat = feat[:, i * slice_len : (i + 1) * slice_len, :]
             else:
-                sub_feat = feat[:, i*slice_len - self.padding:(i+1)*slice_len - self.padding, :]
+                sub_feat = feat[
+                    :,
+                    i * slice_len - self.padding : (i + 1) * slice_len - self.padding,
+                    :,
+                ]
             # concat query
             sub_feat = np.concatenate([self.input_query, sub_feat], axis=1)
             real_len = sub_feat.shape[1]
             if real_len < self.max_len:
-                sub_feat = np.concatenate([
-                    sub_feat,
-                    np.zeros((1, self.max_len - real_len, sub_feat.shape[-1]), dtype=np.float32)
+                sub_feat = np.concatenate(
+                    [
+                        sub_feat,
+                        np.zeros(
+                            (1, self.max_len - real_len, sub_feat.shape[-1]),
+                            dtype=np.float32,
+                        ),
                     ],
-                    axis=1)
-
-            masks = sequence_mask(np.array([self.max_len], dtype=np.int32), maxlen=real_len, dtype=np.float32)
+                    axis=1,
+                )
+
+            masks = sequence_mask(
+                np.array([self.max_len], dtype=np.int32),
+                maxlen=real_len,
+                dtype=np.float32,
+            )
 
             # start = time.time()
-            outputs = self.model.run(None, {"speech": sub_feat,
-                                            "masks": masks,
-                                            "position_encoding": self.position_encoding})
+            outputs = self.model.run(
+                None,
+                {
+                    "speech": sub_feat,
+                    "masks": masks,
+                    "position_encoding": self.position_encoding,
+                },
+            )
             ctc_logits, encoder_out_lens = outputs
-            # print(f"ctc_logits.shape: {ctc_logits.shape}")
-            # print(f"Run model take {time.time() - start}s")
 
-            # start = time.time()
             token_int = self.postprocess(ctc_logits, encoder_out_lens)
-            # print(f"Postprocess take {time.time() - start}s")
 
             if self.tokenizer is not None:
                 asr_res.append(self.tokenizer.tokens2text(token_int))
@@ -257,9 +322,12 @@ class SenseVoiceAx:
             asr_res.append(token_int)
 
         return asr_res
-
 
-    def infer(self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=True):
+    def infer(
+        self, filepath_or_data: Union[np.ndarray, str], language="auto", print_rtf=False
+    ):
+        assert not self.streaming, "This method is for non-streaming model"
+
         if isinstance(filepath_or_data, str):
             waveform = self.load_data(filepath_or_data)
         else:
@@ -284,22 +352,21 @@ class SenseVoiceAx:
             times_ms.append(step * 60)
         return times_ms, self.tokenizer.decode(tokens)
 
-
     def reset(self):
         self.cur_idx = -1
         self.decoder.reset()
         self.fbank = OnlineFbank(window_type="hamming")
         self.caches = np.zeros(self.caches_shape)
 
-
     def get_size(self):
         effective_size = self.cur_idx + 1 - self.padding
         if effective_size <= 0:
             return 0
         return effective_size % self.chunk_size or self.chunk_size
-
 
     def stream_infer(self, audio, is_last, language="auto"):
+        assert self.streaming, "This method is for streaming model"
+
         if language != self.language:
             self.choose_language(language)
 
@@ -321,13 +388,18 @@ class SenseVoiceAx:
                 continue
 
             speech = self.caches[None, ...]
-            outputs = self.model.run(None, {"speech": speech,
-                                            "masks": self.masks,
-                                            "position_encoding": self.position_encoding})
+            outputs = self.model.run(
+                None,
+                {
+                    "speech": speech,
+                    "masks": self.masks,
+                    "position_encoding": self.position_encoding,
+                },
+            )
             ctc_logits, encoder_out_lens = outputs
-            probs = ctc_logits[0, 4:encoder_out_lens[0]]
+            probs = ctc_logits[0, 4 : encoder_out_lens[0]]
             probs = torch.from_numpy(probs)
-
+
             if cur_size != self.chunk_size:
                 probs = probs[self.chunk_size - cur_size :]
             if not is_last:
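
For orientation, a minimal sketch of how the reformatted class is driven in each mode, closely following main.py in this repo; the streaming model filename, the "demo.wav" path, and the 960-sample chunk size are illustrative assumptions, not values fixed by this commit:

import librosa
from SenseVoiceAx import SenseVoiceAx

# Non-streaming: construct with streaming=False and transcribe a whole file.
pipeline = SenseVoiceAx("sensevoice_ax650/sensevoice.axmodel", max_len=256, streaming=False)
print(pipeline.infer("demo.wav", language="auto"))  # "demo.wav" is a placeholder path

# Streaming: construct with streaming=True and feed fixed-size chunks,
# flagging the last chunk so the decoder can flush.
stream = SenseVoiceAx("sensevoice_ax650/sensevoice_streaming.axmodel", streaming=True)  # hypothetical filename
samples, _ = librosa.load("demo.wav", sr=16000)
step = 960  # example chunk size in samples
for i in range(0, len(samples), step):
    is_last = i + step >= len(samples)
    for res in stream.stream_infer(samples[i : i + step], is_last):
        print(res)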
download_utils.py CHANGED
@@ -1,4 +1,5 @@
 import os
+
 # Speed up hf download using mirror url
 os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
 from huggingface_hub import snapshot_download
@@ -7,6 +8,7 @@ current_file_path = os.path.dirname(__file__)
 REPO_ROOT = "AXERA-TECH"
 CACHE_PATH = os.path.join(current_file_path, "models")
 
+
 def download_model(model_name: str) -> str:
     """
     Download model from AXERA-TECH's huggingface space.
@@ -23,7 +25,9 @@ def download_model(model_name: str) -> str:
     model_path = os.path.join(CACHE_PATH, model_name)
     if not os.path.exists(model_path):
         print(f"Downloading {model_name}...")
-        snapshot_download(repo_id=f"{REPO_ROOT}/{model_name}",
-                          local_dir=os.path.join(CACHE_PATH, model_name))
-
+        snapshot_download(
+            repo_id=f"{REPO_ROOT}/{model_name}",
+            local_dir=os.path.join(CACHE_PATH, model_name),
+        )
+
     return model_path
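
Usage, for reference, is a single call; the model name below is a hypothetical example of a repo under the AXERA-TECH namespace:

from download_utils import download_model

# First call downloads AXERA-TECH/<model_name> into ./models/<model_name>;
# later calls find the directory already present and just return the path.
local_dir = download_model("SenseVoice")  # hypothetical model name
print(local_dir)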
frontend.py CHANGED
@@ -96,7 +96,9 @@ class WavFrontend:
         T = T + (lfr_m - 1) // 2
         for i in range(T_lfr):
             if lfr_m <= T - i * lfr_n:
-                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
+                LFR_inputs.append(
+                    (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
+                )
             else:
                 # process last LFR frame
                 num_padding = lfr_m - (T - i * lfr_n)
@@ -180,7 +182,9 @@ class WavFrontendOnline(WavFrontend):
         splice_idx = T_lfr
         for i in range(T_lfr):
             if lfr_m <= T - i * lfr_n:
-                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
+                LFR_inputs.append(
+                    (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
+                )
             else:  # process last LFR frame
                 if is_final:
                     num_padding = lfr_m - (T - i * lfr_n)
@@ -201,8 +205,12 @@ class WavFrontendOnline(WavFrontend):
     def compute_frame_num(
         sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
     ) -> int:
-        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
-        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+        frame_num = int(
+            (sample_length - frame_sample_length) / frame_shift_sample_length + 1
+        )
+        return (
+            frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+        )
 
     def fbank(
         self, input: np.ndarray, input_lengths: np.ndarray
@@ -238,7 +246,9 @@ class WavFrontendOnline(WavFrontend):
         )
         waveform = waveform * (1 << 15)
 
-        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+        self.fbank_fn.accept_waveform(
+            self.opts.frame_opts.samp_freq, waveform.tolist()
+        )
         frames = self.fbank_fn.num_frames_ready
         mat = np.empty([frames, self.opts.mel_opts.num_bins])
         for i in range(frames):
@@ -291,7 +301,9 @@ class WavFrontendOnline(WavFrontend):
         assert (
             batch_size == 1
         ), "we support to extract feature online only when the batch size is equal to 1 now"
-        waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
+        waveforms, feats, feats_lengths = self.fbank(
+            input, input_lengths
+        )  # input shape: B T D
         if feats.shape[0]:
             self.waveforms = (
                 waveforms
@@ -301,7 +313,9 @@ class WavFrontendOnline(WavFrontend):
         if not self.lfr_splice_cache:
             for i in range(batch_size):
                 self.lfr_splice_cache.append(
-                    np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
+                    np.expand_dims(feats[i][0, :], axis=0).repeat(
+                        (self.lfr_m - 1) // 2, axis=0
+                    )
                 )
 
         if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
@@ -313,7 +327,9 @@ class WavFrontendOnline(WavFrontend):
                 / self.frame_shift_sample_length
                 + 1
             )
-            minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+            minus_frame = (
+                (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+            )
             feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                 feats, feats_lengths, is_final
             )
@@ -346,7 +362,9 @@ class WavFrontendOnline(WavFrontend):
         else:
             if is_final:
                 self.waveforms = (
-                    waveforms if self.reserve_waveforms is None else self.reserve_waveforms
+                    waveforms
+                    if self.reserve_waveforms is None
+                    else self.reserve_waveforms
                 )
             feats = np.stack(self.lfr_splice_cache)
             feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
@@ -377,20 +395,33 @@ def load_bytes(input):
     i = np.iinfo(middle_data.dtype)
     abs_max = 2 ** (i.bits - 1)
     offset = i.min + abs_max
-    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+    array = np.frombuffer(
+        (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32
+    )
     return array
 
 
 class SinusoidalPositionEncoderOnline:
     """Streaming Positional encoding."""
 
-    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
+    def encode(
+        self,
+        positions: np.ndarray = None,
+        depth: int = None,
+        dtype: np.dtype = np.float32,
+    ):
         batch_size = positions.shape[0]
         positions = positions.astype(dtype)
-        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
-        inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
+        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (
+            depth / 2 - 1
+        )
+        inv_timescales = np.exp(
+            np.arange(depth / 2).astype(dtype) * (-log_timescale_increment)
+        )
         inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
-        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
+        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(
+            inv_timescales, [1, 1, -1]
+        )
         encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
         return encoding.astype(dtype)
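
compute_frame_num above is the standard framing formula, frame_num = (samples - frame_len) / shift + 1, truncated, with inputs shorter than one frame yielding 0. A quick sanity check, assuming the 25 ms frame / 10 ms shift settings WavFrontend uses at 16 kHz:

# 1 s of 16 kHz audio, 25 ms frames (400 samples), 10 ms shift (160 samples)
sample_length = 16000
frame_sample_length = 400
frame_shift_sample_length = 160

frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
print(frame_num)  # 98, since int() truncates 98.5; a too-short input would return 0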
gradio_demo.py CHANGED
@@ -5,7 +5,7 @@ from tokenizer import SentencepiecesTokenizer
 from print_utils import rich_transcription_postprocess
 from download_utils import download_model
 
-use_itn = True # 标点符号预测
+use_itn = True  # 标点符号预测
 max_len = 256
 
 model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
@@ -14,11 +14,10 @@ bpemodel = "chn_jpn_yue_eng_ko_spectok.bpe.model"
 assert os.path.exists(model_path), f"model {model_path} not exist"
 
 tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
-pipeline = SenseVoiceAx(model_path,
-                        max_len=max_len,
-                        language="auto",
-                        use_itn=use_itn,
-                        tokenizer=tokenizer)
+pipeline = SenseVoiceAx(
+    model_path, max_len=max_len, language="auto", use_itn=use_itn, tokenizer=tokenizer
+)
+
 
 def speech_to_text(audio_path, lang):
     """
@@ -27,7 +26,7 @@ def speech_to_text(audio_path, lang):
     """
     if not audio_path:
         return "无音频"
-
+
     pipeline.choose_language(language=lang)
     asr_res = pipeline.infer(audio_path, print_rtf=True)
     res = " ".join([rich_transcription_postprocess(i) for i in asr_res])
@@ -38,34 +37,26 @@ def speech_to_text(audio_path, lang):
 def main():
     with gr.Blocks() as demo:
         with gr.Row():
-            output_text = gr.Textbox(
-                label="识别结果",
-                lines=5
-            )
-
+            output_text = gr.Textbox(label="识别结果", lines=5)
 
         with gr.Row():
             audio_input = gr.Audio(
-                sources=["upload"],
-                type="filepath",
-                label="录制或上传音频",
-                format="mp3"
+                sources=["upload"], type="filepath", label="录制或上传音频", format="mp3"
             )
             lang_dropdown = gr.Dropdown(
                 choices=["auto", "zh", "en", "yue", "ja", "ko"],
                 value="auto",
-                label="选择音频语言"
+                label="选择音频语言",
             )
 
         audio_input.change(
-            fn=speech_to_text,
-            inputs=[audio_input, lang_dropdown],
-            outputs=output_text
+            fn=speech_to_text, inputs=[audio_input, lang_dropdown], outputs=output_text
         )
 
     demo.launch(
-        server_name="0.0.0.0",
-    )
+        server_name="0.0.0.0",
+    )
+
 
 if __name__ == "__main__":
-    main()
+    main()
main.py CHANGED
@@ -8,8 +8,17 @@ import time
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input", "-i", required=True, type=str, help="Input audio file")
-    parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
+    parser.add_argument(
+        "--input", "-i", required=True, type=str, help="Input audio file"
+    )
+    parser.add_argument(
+        "--language",
+        "-l",
+        required=False,
+        type=str,
+        default="auto",
+        choices=["auto", "zh", "en", "yue", "ja", "ko"],
+    )
     parser.add_argument("--streaming", action="store_true")
     return parser.parse_args()
 
@@ -19,7 +28,7 @@ def main():
 
     input_audio = args.input
     language = args.language
-    use_itn = True # 标点符号预测
+    use_itn = True  # 标点符号预测
     if not args.streaming:
         max_len = 256
         model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
@@ -35,14 +44,16 @@ def main():
     print(f"model_path: {model_path}")
     print(f"streaming: {args.streaming}")
 
-    pipeline = SenseVoiceAx(model_path,
-                            max_len=max_len,
-                            beam_size=3,
-                            language="auto",
-                            hot_words=None,
-                            use_itn=True,
-                            streaming=args.streaming)
-
+    pipeline = SenseVoiceAx(
+        model_path,
+        max_len=max_len,
+        beam_size=3,
+        language="auto",
+        hot_words=None,
+        use_itn=True,
+        streaming=args.streaming,
+    )
+
     if not args.streaming:
         asr_res = pipeline.infer(input_audio, print_rtf=True)
         print("ASR result: " + asr_res)
@@ -57,11 +68,12 @@ def main():
         is_last = i + step >= len(samples)
         for res in pipeline.stream_infer(samples[i : i + step], is_last):
             print(res)
-
+
         end = time.time()
         cost_time = end - start
 
     print(f"RTF: {cost_time / duration}")
 
+
 if __name__ == "__main__":
-    main()
+    main()
print_utils.py CHANGED
@@ -90,6 +90,7 @@ def format_str_v2(s):
         s = s.replace(emoji + " ", emoji)
     return s.strip()
 
+
 def rich_transcription_postprocess(s):
     def get_emo(s):
         return s[-1] if s[-1] in emo_set else None
@@ -116,6 +117,7 @@ def rich_transcription_postprocess(s):
     new_s = new_s.replace("The.", " ")
     return new_s.strip()
 
+
 def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
     res = "".join([rich_transcription_postprocess(i) for i in asr_res])
 
@@ -126,4 +128,4 @@ def rich_print_asr_res(asr_res, will_print=True, remove_punc=False):
     if will_print:
         print(res)
 
-    return res
+    return res
server.py CHANGED
@@ -20,6 +20,7 @@ app = FastAPI(title="ASR Server", description="Automatic Speech Recognition API"
 # 全局变量存储模型
 asr_model = None
 
+
 @app.on_event("startup")
 async def load_model():
     """
@@ -27,11 +28,11 @@ async def load_model():
     """
     global asr_model
     logger.info("Loading ASR model...")
-
+
     try:
         # 模型加载
         language = "auto"
-        use_itn = True # 标点符号预测
+        use_itn = True  # 标点符号预测
         max_len = 256
 
         model_path = os.path.join("sensevoice_ax650", "sensevoice.axmodel")
@@ -44,63 +45,74 @@ async def load_model():
         print(f"model_path: {model_path}")
 
         tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
-        asr_model = SenseVoiceAx(model_path,
-                                 max_len=max_len,
-                                 language=language,
-                                 use_itn=use_itn,
-                                 tokenizer=tokenizer)
-
+        asr_model = SenseVoiceAx(
+            model_path,
+            max_len=max_len,
+            language=language,
+            use_itn=use_itn,
+            tokenizer=tokenizer,
+        )
+
         logger.info("ASR model loaded successfully")
     except Exception as e:
         logger.error(f"Failed to load ASR model: {str(e)}")
         raise
 
+
 def validate_audio_data(audio_data: List[float]) -> np.ndarray:
     """
     验证并转换音频数据为numpy数组
-
+
     参数:
     - audio_data: 浮点数列表表示的音频数据
-
+
     返回:
     - 验证后的numpy数组
     """
     try:
         # 转换为numpy数组
         np_array = np.array(audio_data, dtype=np.float32)
-
+
         # 验证数据有效性
         if np_array.ndim != 1:
             raise ValueError("Audio data must be 1-dimensional")
-
+
         if len(np_array) == 0:
             raise ValueError("Audio data cannot be empty")
-
+
         return np_array
     except Exception as e:
         raise ValueError(f"Invalid audio data: {str(e)}")
-
+
+
 @app.get("/get_language", summary="Get current language")
 async def get_language():
     return JSONResponse(content={"language": asr_model.language})
 
-@app.get("/get_language_options", summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]")
+
+@app.get(
+    "/get_language_options",
+    summary="Get possible language options, possible options include [auto, zh, en, yue, ja, ko]",
+)
 async def get_language_options():
     return JSONResponse(content={"language_options": asr_model.language_options})
 
+
 @app.post("/asr", summary="Recognize speech from numpy audio data")
 async def recognize_speech(
-    audio_data: List[float] = Body(..., embed=True, description="Audio data as list of floats"),
+    audio_data: List[float] = Body(
+        ..., embed=True, description="Audio data as list of floats"
+    ),
     sample_rate: Optional[int] = Body(16000, description="Audio sample rate in Hz"),
-    language: Optional[str] = Body("auto", description="Language")
+    language: Optional[str] = Body("auto", description="Language"),
 ):
     """
     接收numpy数组格式的音频数据并返回识别结果
-
+
     参数:
     - audio_data: 浮点数列表表示的音频数据
     - sample_rate: 音频采样率(默认16000Hz)
-
+
     返回:
    - JSON包含识别文本
     """
@@ -108,19 +120,19 @@ async def recognize_speech(
         # 检查模型是否已加载
         if asr_model is None:
             raise HTTPException(status_code=503, detail="ASR model not loaded")
-
+
         logger.info(f"Received audio data with length: {len(audio_data)}")
-
+
         # 验证并转换数据
         np_audio = validate_audio_data(audio_data)
         if sample_rate != asr_model.sample_rate:
             np_audio = librosa.resample(np_audio, sample_rate, asr_model.sample_rate)
-
+
         # 调用模型进行识别
         result = asr_model.infer_waveform(np_audio, language)
-
+
         return JSONResponse(content={"text": result})
-
+
     except ValueError as e:
         logger.error(f"Validation error: {str(e)}")
         raise HTTPException(status_code=400, detail=str(e))
@@ -128,6 +140,8 @@ async def recognize_speech(
         logger.error(f"Recognition error: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
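
A sketch of a matching client for the /asr route; because audio_data is declared with embed=True, all three fields travel in one JSON body, and the host/port mirror the uvicorn.run call above ("demo.wav" is a placeholder path):

import librosa
import requests

# Load audio as float samples and post them to the server.
samples, sr = librosa.load("demo.wav", sr=16000)
resp = requests.post(
    "http://localhost:8000/asr",
    json={"audio_data": samples.tolist(), "sample_rate": sr, "language": "auto"},
)
print(resp.json()["text"])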
test_wer.py CHANGED
@@ -14,35 +14,35 @@ def setup_logging():
     # 获取脚本所在目录
     script_dir = os.path.dirname(os.path.abspath(__file__))
     log_file = os.path.join(script_dir, "test_wer.log")
-
+
     # 配置日志格式
-    log_format = '%(asctime)s - %(levelname)s - %(message)s'
-    date_format = '%Y-%m-%d %H:%M:%S'
-
+    log_format = "%(asctime)s - %(levelname)s - %(message)s"
+    date_format = "%Y-%m-%d %H:%M:%S"
+
     # 创建logger
     logger = logging.getLogger()
     logger.setLevel(logging.INFO)
-
+
     # 清除现有的handler
     for handler in logger.handlers[:]:
         logger.removeHandler(handler)
-
+
     # 创建文件handler
-    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
+    file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
     file_handler.setLevel(logging.INFO)
     file_formatter = logging.Formatter(log_format, date_format)
     file_handler.setFormatter(file_formatter)
-
+
     # 创建控制台handler
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
     console_formatter = logging.Formatter(log_format, date_format)
     console_handler.setFormatter(console_formatter)
-
+
     # 添加handler到logger
     logger.addHandler(file_handler)
     logger.addHandler(console_handler)
-
+
     return logger
 
 
@@ -50,21 +50,21 @@ class AIShellDataset:
     def __init__(self, gt_path: str):
         """
         初始化数据集
-
+
         Args:
             json_path: voice.json文件的路径
         """
         self.gt_path = gt_path
         self.dataset_dir = os.path.dirname(gt_path)
         self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
-
+
         # 检查必要文件和文件夹是否存在
         assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
         assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
-
+
         # 加载数据
         self.data = []
-        with open(gt_path, 'r', encoding='utf-8') as f:
+        with open(gt_path, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 audio_path, gt = line.split(" ")
@@ -74,50 +74,50 @@ class AIShellDataset:
         # 使用logging而不是print
         logger = logging.getLogger()
         logger.info(f"加载了 {len(self.data)} 条数据")
-
+
     def __iter__(self):
         """返回迭代器"""
         self.index = 0
         return self
-
+
     def __next__(self):
         """返回下一个数据项"""
         if self.index >= len(self.data):
             raise StopIteration
-
+
         item = self.data[self.index]
         audio_path = item["audio_path"]
         ground_truth = item["gt"]
-
+
         self.index += 1
         return audio_path, ground_truth
-
+
     def __len__(self):
         """返回数据集大小"""
         return len(self.data)
-
+
 
 class CommonVoiceDataset:
     """Common Voice数据集解析器"""
-
+
     def __init__(self, tsv_path: str):
         """
         初始化数据集
-
+
         Args:
             json_path: voice.json文件的路径
         """
         self.tsv_path = tsv_path
         self.dataset_dir = os.path.dirname(tsv_path)
         self.voice_dir = os.path.join(self.dataset_dir, "clips")
-
+
         # 检查必要文件和文件夹是否存在
         assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
         assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
-
+
         # 加载JSON数据
         self.data = []
-        with open(tsv_path, 'r', encoding='utf-8') as f:
+        with open(tsv_path, "r", encoding="utf-8") as f:
             f.readline()
             for line in f:
                 line = line.strip()
@@ -126,79 +126,101 @@ class CommonVoiceDataset:
             gt = splits[3]
             audio_path = os.path.join(self.voice_dir, audio_path)
             self.data.append({"audio_path": audio_path, "gt": gt})
-
+
         # 使用logging而不是print
         logger = logging.getLogger()
         logger.info(f"加载了 {len(self.data)} 条数据")
-
+
     def __iter__(self):
         """返回迭代器"""
         self.index = 0
         return self
-
+
     def __next__(self):
         """返回下一个数据项"""
         if self.index >= len(self.data):
             raise StopIteration
-
+
         item = self.data[self.index]
         audio_path = item["audio_path"]
         ground_truth = item["gt"]
-
+
         self.index += 1
         return audio_path, ground_truth
-
+
     def __len__(self):
         """返回数据集大小"""
         return len(self.data)
-
+
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
-    parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
-    parser.add_argument("--language", "-l", required=False, type=str, default="auto", choices=["auto", "zh", "en", "yue", "ja", "ko"])
-    parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
+    parser.add_argument(
+        "--dataset",
+        "-d",
+        type=str,
+        required=True,
+        choices=["aishell", "common_voice"],
+        help="Test dataset",
+    )
+    parser.add_argument(
+        "--gt_path",
+        "-g",
+        type=str,
+        required=True,
+        help="Test dataset ground truth file",
+    )
+    parser.add_argument(
+        "--language",
+        "-l",
+        required=False,
+        type=str,
+        default="auto",
+        choices=["auto", "zh", "en", "yue", "ja", "ko"],
+    )
+    parser.add_argument(
+        "--max_num", type=int, default=-1, required=False, help="Maximum test data num"
+    )
    return parser.parse_args()
 
 
 def min_distance(word1: str, word2: str) -> int:
-
+
     row = len(word1) + 1
     column = len(word2) + 1
-
-    cache = [ [0]*column for i in range(row) ]
-
+
+    cache = [[0] * column for i in range(row)]
+
     for i in range(row):
         for j in range(column):
-
-            if i ==0 and j ==0:
+
+            if i == 0 and j == 0:
                 cache[i][j] = 0
-            elif i == 0 and j!=0:
+            elif i == 0 and j != 0:
                 cache[i][j] = j
-            elif j == 0 and i!=0:
+            elif j == 0 and i != 0:
                 cache[i][j] = i
             else:
-                if word1[i-1] == word2[j-1]:
-                    cache[i][j] = cache[i-1][j-1]
+                if word1[i - 1] == word2[j - 1]:
+                    cache[i][j] = cache[i - 1][j - 1]
                 else:
-                    replace = cache[i-1][j-1] + 1
-                    insert = cache[i][j-1] + 1
-                    remove = cache[i-1][j] + 1
-
+                    replace = cache[i - 1][j - 1] + 1
+                    insert = cache[i][j - 1] + 1
+                    remove = cache[i - 1][j] + 1
+
                     cache[i][j] = min(replace, insert, remove)
-
-    return cache[row-1][column-1]
+
+    return cache[row - 1][column - 1]
 
 
 def remove_punctuation(text):
     # 定义正则表达式模式,匹配所有标点符号
     # 这个模式包括常见的标点符号和中文标点
-    pattern = r'[^\w\s]|_'
-
+    pattern = r"[^\w\s]|_"
+
     # 使用sub方法将所有匹配的标点符号替换为空字符串
-    cleaned_text = re.sub(pattern, '', text)
-
+    cleaned_text = re.sub(pattern, "", text)
+
     return cleaned_text
 
 
@@ -207,7 +229,7 @@ def main():
     args = get_args()
 
     language = args.language
-    use_itn = False # 标点符号预测
+    use_itn = False  # 标点符号预测
     max_num = args.max_num
 
     dataset_type = args.dataset.lower()
@@ -230,7 +252,9 @@ def main():
     logger.info(f"model_path: {model_path}")
 
     tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
-    pipeline = SenseVoiceAx(model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256)
+    pipeline = SenseVoiceAx(
+        model_path, language=language, use_itn=use_itn, tokenizer=tokenizer, max_len=256
+    )
 
     # Iterate over dataset
     hyp = []
@@ -242,8 +266,10 @@ def main():
         reference = remove_punctuation(reference).lower()
 
         asr_res = pipeline.infer(audio_path, print_rtf=False)
-        hypothesis = rich_print_asr_res(asr_res, will_print=False, remove_punc=True).lower()
-        hypothesis = emoji.replace_emoji(hypothesis, replace='')
+        hypothesis = rich_print_asr_res(
+            asr_res, will_print=False, remove_punc=True
+        ).lower()
+        hypothesis = emoji.replace_emoji(hypothesis, replace="")
 
         character_error_num = min_distance(reference, hypothesis)
         character_num = len(reference)
@@ -254,7 +280,7 @@ def main():
 
         hyp.append(hypothesis)
         references.append(reference)
-
+
         line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
         logger.info(line_content)
 
@@ -265,5 +291,6 @@ def main():
 
     logger.info(f"Total WER: {total_character_error_rate}%")
 
+
 if __name__ == "__main__":
-    main()
+    main()
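
min_distance is a plain Levenshtein-distance DP, and the per-utterance score is edit distance divided by reference length. A tiny sanity check with made-up strings:

from test_wer import min_distance

reference = "今天天气不错"
hypothesis = "今天天气不对"

# One substituted character out of six reference characters.
errors = min_distance(reference, hypothesis)  # 1
print(f"{100 * errors / len(reference):.1f}%")  # 16.7%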
tokenizer.py CHANGED
@@ -52,7 +52,9 @@ class BaseTokenizer(ABC):
 
         self.unk_symbol = unk_symbol
         if self.unk_symbol not in self.token2id:
-            raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
+            raise RuntimeError(
+                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+            )
         self.unk_id = self.token2id[self.unk_symbol]
 
     def encode(self, text, **kwargs):
@@ -84,7 +86,7 @@ class BaseTokenizer(ABC):
     @abstractmethod
     def tokens2text(self, tokens: Iterable[str]) -> str:
         raise NotImplementedError
-
+
 
 class SentencepiecesTokenizer(BaseTokenizer):
     def __init__(self, bpemodel: Union[Path, str], **kwargs):
@@ -130,4 +132,4 @@ class SentencepiecesTokenizer(BaseTokenizer):
         return self.decode(*args, **kwargs)
 
     def tokens2ids(self, *args, **kwargs):
-        return self.encode(*args, **kwargs)
+        return self.encode(*args, **kwargs)
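
A round-trip sketch, assuming the bundled BPE model file sits in the working directory; given the aliases above, tokens2ids behaves like encode and tokens2text like decode:

from tokenizer import SentencepiecesTokenizer

tokenizer = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")

ids = tokenizer.encode("hello world")  # text -> ids
print(ids)
print(tokenizer.decode(ids))  # ids -> text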