inoryQwQ committed on
Commit
d56de90
·
1 Parent(s): 5d4703e

Optimize RTF and CER

README.md CHANGED
@@ -6,7 +6,7 @@ license: apache-2.0
6
 
7
  Deployment of the Xiaohongshu ASR (FireRedASR) AED-L model on the AX650N. Original project: [https://github.com/FireRedTeam/FireRedASR](https://github.com/FireRedTeam/FireRedASR)
8
 
9
- The converted model is placed in the axmodel directory; Chinese and English are currently supported, with audio input up to 10 seconds.
10
 
11
  ## Model Conversion
12
 
@@ -50,121 +50,11 @@ pip install axengine-0.1.3-py3-none-any.whl
50
  conda activate fireredasr
51
  python test_ax_model.py
52
  ```
 
53
 
54
- The output is as follows:
55
- ```
56
- [INFO] Available providers: ['AxEngineExecutionProvider']
57
- Namespace(encoder='axmodel/encoder.axmodel', decoder='axmodel/decoder_main.axmodel', cmvn='axmodel/cmvn.ark', dict='axmodel/dict.txt', spm_model='axmodel/train_bpe1000.model', wavlist='wavlist.txt', hypo='hypo_axmodel.txt', beam_size=3, nbest=1, max_len=128)
58
- [WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
59
- [INFO] Using provider: AxEngineExecutionProvider
60
- [INFO] Chip type: ChipType.MC50
61
- [INFO] VNPU type: VNPUType.DISABLED
62
- [INFO] Engine version: 2.12.0s
63
- [INFO] Model type: 2 (triple core)
64
- [INFO] Compiler version: 4.2 9555977e
65
- load encoder cost 2.764460325241089 seconds
66
- [WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
67
- [INFO] Using provider: AxEngineExecutionProvider
68
- [INFO] Model type: 2 (triple core)
69
- [INFO] Compiler version: 4.2 9555977e
70
- load decoder_main cost 16.36833119392395 seconds
71
- [WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
72
- [INFO] Using provider: AxEngineExecutionProvider
73
- [INFO] Model type: 2 (triple core)
74
- [INFO] Compiler version: 4.2 9555977e
75
- load decoder_loop cost 16.194183826446533 seconds
76
- run encoder take 196.9749927520752ms
77
- run decoder_main take 130.2931308746338ms
78
- run decoder_loop take 165.5733585357666ms
79
- run decoder_loop take 109.67779159545898ms
80
- run decoder_loop take 101.15742683410645ms
81
- run decoder_loop take 110.09836196899414ms
82
- run decoder_loop take 100.29029846191406ms
83
- run decoder_loop take 109.33351516723633ms
84
- run decoder_loop take 100.37779808044434ms
85
- run decoder_loop take 109.72428321838379ms
86
- run decoder_loop take 100.42023658752441ms
87
- run decoder_loop take 101.71890258789062ms
88
- run decoder_loop take 100.09407997131348ms
89
- run decoder_loop take 110.25619506835938ms
90
- run decoder_loop take 100.54206848144531ms
91
- run decoder_loop take 101.93896293640137ms
92
- ['wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav']
93
- Durations: 1.8
94
- Transcribe Durations: 2.5527637004852295
95
- (Real time factor) RTF: 1.4182020558251274
96
- wav: wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
97
- text: 我有的时候说不清楚你们知道吗
98
- score: -0.9156361222267151
99
-
100
- run encoder take 180.2656650543213ms
101
- run decoder_main take 91.42565727233887ms
102
- run decoder_loop take 105.18240928649902ms
103
- run decoder_loop take 100.56614875793457ms
104
- run decoder_loop take 100.9066104888916ms
105
- run decoder_loop take 100.9068489074707ms
106
- run decoder_loop take 102.90265083312988ms
107
- run decoder_loop take 100.50129890441895ms
108
- run decoder_loop take 110.12482643127441ms
109
- run decoder_loop take 100.65031051635742ms
110
- run decoder_loop take 110.09883880615234ms
111
- run decoder_loop take 105.48877716064453ms
112
- run decoder_loop take 100.32439231872559ms
113
- run decoder_loop take 106.08601570129395ms
114
- run decoder_loop take 100.79813003540039ms
115
- run decoder_loop take 100.4643440246582ms
116
- run decoder_loop take 100.30460357666016ms
117
- ['wav/TEST_MEETING_T0000000001_S00000.wav']
118
- Durations: 12.369
119
- Transcribe Durations: 2.464834690093994
120
- (Real time factor) RTF: 0.19927517908432324
121
- wav: wav/TEST_MEETING_T0000000001_S00000.wav
122
- text: 好首先说一下刚才这个
123
- score: -0.5064160823822021
124
-
125
- run encoder take 172.59907722473145ms
126
- run decoder_main take 91.79949760437012ms
127
- run decoder_loop take 105.04364967346191ms
128
- run decoder_loop take 100.62885284423828ms
129
- run decoder_loop take 101.89318656921387ms
130
- run decoder_loop take 100.42643547058105ms
131
- run decoder_loop take 109.7562313079834ms
132
- ['wav/IT0011W0001.wav']
133
- Durations: 1.992
134
- Transcribe Durations: 1.0302071571350098
135
- (Real time factor) RTF: 0.5171722676380571
136
- wav: wav/IT0011W0001.wav
137
- text: 换一首歌
138
- score: -0.016501454636454582
139
-
140
- run encoder take 173.07257652282715ms
141
- run decoder_main take 91.48693084716797ms
142
- run decoder_loop take 105.42607307434082ms
143
- run decoder_loop take 100.10981559753418ms
144
- run decoder_loop take 100.4478931427002ms
145
- run decoder_loop take 100.23713111877441ms
146
- run decoder_loop take 100.10337829589844ms
147
- run decoder_loop take 100.29196739196777ms
148
- run decoder_loop take 101.7463207244873ms
149
- run decoder_loop take 100.8148193359375ms
150
- run decoder_loop take 109.99274253845215ms
151
- run decoder_loop take 105.45015335083008ms
152
- run decoder_loop take 100.59380531311035ms
153
- run decoder_loop take 100.73733329772949ms
154
- run decoder_loop take 100.4335880279541ms
155
- run decoder_loop take 109.68661308288574ms
156
- ['wav/BAC009S0764W0121.wav']
157
- Durations: 4.2039375
158
- Transcribe Durations: 2.3024709224700928
159
- (Real time factor) RTF: 0.5476938994621334
160
- wav: wav/BAC009S0764W0121.wav
161
- text: 甚至出现交易几乎停滞的情况
162
- score: -0.11461181938648224
163
-
164
- total wav durations: 20.364937500000003
165
- total transcribe durations: 8.350276470184326
166
- AVG RTF: 0.4100320204854213
167
 
168
- ```
169
 
170
- `hypo_axmodel.txt` contains the recognition results
 
6
 
7
  Deployment of the Xiaohongshu ASR (FireRedASR) AED-L model on the AX650N. Original project: [https://github.com/FireRedTeam/FireRedASR](https://github.com/FireRedTeam/FireRedASR)
8
 
9
+ The converted model is placed in the axmodel directory; Chinese and English are currently supported, with audio input up to 10 seconds. Audio longer than 10 seconds is split with VAD before inference.
10
 
11
  ## Model Conversion
12
 
 
50
  conda activate fireredasr
51
  python test_ax_model.py
52
  ```
53
+ `hypo_axmodel.txt` contains the recognition results
54
 
55
+ ## Performance
56
 
57
+ RTF ~= 0.3
58
+
59
+ CER (on a custom dataset): 3.45%
60
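For reference, the two metrics reported above can be computed as follows. This is a minimal sketch, not code from this repo; the sample values are taken from the log in this commit (8.35 s to transcribe 20.36 s of audio).

```python
def rtf(transcribe_seconds: float, audio_seconds: float) -> float:
    # Real-time factor: processing time divided by audio duration
    return transcribe_seconds / audio_seconds

def cer(ref: str, hyp: str) -> float:
    # Character error rate: Levenshtein distance over reference length,
    # computed with a single-row dynamic-programming table
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                           # deletion
                        dp[j - 1] + 1,                        # insertion
                        prev + (ref[i - 1] != hyp[j - 1]))    # substitution
            prev = cur
    return dp[n] / max(m, 1)

if __name__ == "__main__":
    print(round(rtf(8.350276470184326, 20.364937500000003), 2))  # 0.41, the AVG RTF above
    print(cer("换一首歌", "换一首曲"))  # 1 substitution over 4 chars = 0.25
```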
 
 
axmodel/decoder_loop.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c9b3e351557d20846f50d819e18c59d6f10a8adfc40322e5e3034b404b3e038
3
- size 435136795
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2912496e6804027f2dc77c903f6b2f76678603dd616e662b78e3f226bcaa91a
3
+ size 416269694
fireredasr/data/asr_feat.py CHANGED
@@ -42,7 +42,7 @@ class ASRFeatExtractor:
42
 
43
  lengths = torch.tensor([feat.size(0) for feat in feats]).long()
44
  feats_pad = self.pad_feat(feats, 0.0)
45
- return feats_pad, lengths, dur
46
 
47
  def pad_feat(self, xs, pad_value):
48
  # type: (List[Tensor], int) -> Tensor
 
42
 
43
  lengths = torch.tensor([feat.size(0) for feat in feats]).long()
44
  feats_pad = self.pad_feat(feats, 0.0)
45
+ return feats_pad.numpy(), lengths, dur
46
 
47
  def pad_feat(self, xs, pad_value):
48
  # type: (List[Tensor], int) -> Tensor
fireredasr_axmodel.py CHANGED
@@ -10,6 +10,7 @@ from typing import Tuple, List, Dict
10
  import os
11
  import time
12
  import torchaudio
 
13
 
14
  try:
15
  torchaudio.set_audio_backend("soundfile")
@@ -44,18 +45,30 @@ def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
44
  return ys * (1 - is_finished) + eos_id * is_finished
45
 
46
 
47
  class FireRedASRAxModel:
48
- def __init__(
49
- self,
50
- encoder_path: str,
51
- decoder_loop_path: str,
52
- cmvn_file: str,
53
- dict_file: str,
54
- spm_model_path: str,
55
- providers=["AxEngineExecutionProvider"],
56
- decode_max_len=128,
57
- audio_dur=10,
58
- ):
59
  # NOTE: maximum decode length, set with reference to whisper
60
  # The FireRedASR-AED model supports audio up to 60 s
61
  # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
@@ -79,6 +92,21 @@ class FireRedASRAxModel:
79
 
80
  self.vad_model = load_silero_vad()
81
 
82
  def init_encoder(self, encoder_path, providers=None):
83
  self.encoder = axe.InferenceSession(encoder_path, providers=providers)
84
 
@@ -90,7 +118,7 @@ class FireRedASRAxModel:
90
  decoder_path = os.path.join(decoder_path, "pe.npy")
91
 
92
  return np.load(decoder_path)
93
-
94
  def run_encoder(
95
  self, input: np.ndarray, input_length: np.ndarray
96
  ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -98,7 +126,7 @@ class FireRedASRAxModel:
98
  None, {"encoder_input": input, "encoder_input_lengths": input_length}
99
  )
100
  return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
101
-
102
  def decode_loop_one_token(
103
  self,
104
  tokens: np.ndarray,
@@ -128,271 +156,485 @@ class FireRedASRAxModel:
128
  },
129
  )
130
  return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
131
-
132
- def run_decoder(
133
- self, n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
134
- ):
135
-
136
  num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
137
  encoder_out_length = cross_attn_mask.shape[-1]
138
 
139
- cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
140
- cross_attn_mask = (
141
- cross_attn_mask.unsqueeze(1)
142
- .repeat(1, beam_size, 1, 1)
143
- .view(beam_size * batch_size, -1, encoder_out_length)
144
- )
145
-
146
- n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
147
- n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
148
- n_layer_cross_k = (
149
- n_layer_cross_k.unsqueeze(2)
150
- .repeat(1, 1, beam_size, 1, 1)
151
- .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
152
- )
153
- n_layer_cross_v = (
154
- n_layer_cross_v.unsqueeze(2)
155
- .repeat(1, 1, beam_size, 1, 1)
156
- .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
157
- )
158
-
159
- prediction_tokens = (
160
- torch.ones(beam_size * batch_size, 1).fill_(self.sos_id).long()
161
- )
162
- tokens = prediction_tokens
163
- offset = torch.zeros(1, dtype=torch.int64)
164
- n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
165
  batch_size, beam_size
166
  )
167
-
168
- scores = torch.tensor([0.0] + [-INF] * (beam_size - 1)).float()
169
- scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
170
- is_finished = torch.zeros_like(scores)
171
-
172
- self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)
173
-
174
- for i in range(self.decode_max_len):
175
-
176
- tokens = to_numpy(tokens).astype(np.int32)
177
- n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
178
- n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
179
- n_layer_cross_k = to_numpy(n_layer_cross_k)
180
- n_layer_cross_v = to_numpy(n_layer_cross_v)
181
- cross_attn_mask = to_numpy(cross_attn_mask)
182
-
183
- self_attn_mask = np.zeros(
184
- (batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32
185
  )
186
- self_attn_mask[:, :, : self.decode_max_len - offset[0] - 1] = -np.inf
187
-
188
- (
189
- logits,
190
- n_layer_self_k_cache,
191
- n_layer_self_v_cache,
192
- ) = self.decode_loop_one_token(
193
- to_numpy(tokens),
194
- to_numpy(n_layer_self_k_cache),
195
- to_numpy(n_layer_self_v_cache),
196
- to_numpy(n_layer_cross_k),
197
- to_numpy(n_layer_cross_v),
198
- self.pe[offset],
199
- self_attn_mask,
200
- to_numpy(cross_attn_mask),
201
  )
202
-
203
- offset += 1
204
- logits = torch.from_numpy(logits)
205
-
206
- logits = logits.squeeze(1)
207
  t_scores = F.log_softmax(logits, dim=-1)
208
- t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
209
- t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
210
- t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)
211
-
212
- scores = scores + t_topB_scores
213
-
214
- scores = scores.view(batch_size, beam_size * beam_size)
215
- scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
216
- scores = scores.view(-1, 1)
217
-
218
- topB_row_number_in_each_B_rows_of_ys = torch.div(
219
- topB_score_ids, beam_size
220
- ).view(batch_size * beam_size)
221
- stride = beam_size * torch.arange(batch_size).view(batch_size, 1).repeat(
222
- 1, beam_size
223
- ).view(batch_size * beam_size)
224
- topB_row_number_in_ys = (
225
- topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
226
  )
227
-
228
- prediction_tokens = prediction_tokens[topB_row_number_in_ys]
229
- t_ys = torch.gather(
230
- t_topB_ys.view(batch_size, beam_size * beam_size),
231
- dim=1,
232
- index=topB_score_ids,
233
- ).view(beam_size * batch_size, 1)
234
-
235
- tokens = t_ys
236
-
237
- prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
238
-
239
- n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
240
- n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
241
-
242
- for i, self_k_cache in enumerate(n_layer_self_k_cache):
243
- n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
244
-
245
- for i, self_v_cache in enumerate(n_layer_self_v_cache):
246
- n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
247
-
248
- is_finished = t_ys.eq(self.eos_id)
249
- if is_finished.sum().item() == beam_size * batch_size:
250
  break
251
-
252
  scores = scores.view(batch_size, beam_size)
253
- prediction_valid_token_lengths = torch.sum(
254
  torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
255
- dim=-1,
256
  ).int()
257
-
258
  nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
259
- index = (
260
- nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
261
- )
262
- nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[
263
- index.view(-1)
264
- ]
265
- nbest_prediction_tokens = nbest_prediction_tokens.view(
266
- batch_size, nbest_ids.size(1), -1
267
- )
268
- nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
269
- batch_size * beam_size
270
- )[index.view(-1)].view(batch_size, -1)
271
-
272
- # batch_size is always 1
273
- i_best_hyps: List[Dict[str, torch.Tensor]] = []
274
  for j, score in enumerate(nbest_scores[0]):
275
  hyp = {
276
- "token_ids": nbest_prediction_tokens[
277
- 0, j, 1 : nbest_prediction_valid_token_lengths[0, j]
278
- ],
279
  "score": score,
280
  }
281
- i_best_hyps.append(hyp)
282
-
283
- return i_best_hyps
284
-
285
- def get_initialized_self_cache(
286
- self, batch_size, beam_size
287
- ) -> Tuple[Tensor, Tensor]:
288
- n_layer_self_k_cache = torch.zeros(
289
- self.num_decoder_blocks,
290
- batch_size * beam_size,
291
- self.decode_max_len,
292
- self.decoder_hidden_dim,
293
- )
294
- n_layer_self_v_cache = torch.zeros(
295
- self.num_decoder_blocks,
296
- batch_size * beam_size,
297
- self.decode_max_len,
298
- self.decoder_hidden_dim,
299
- )
300
- return n_layer_self_k_cache, n_layer_self_v_cache
301
-
302
- def calc_feat_len(self, audio_dur):
303
- import math
304
-
305
- sample_rate = self.sample_rate
306
- frame_length = 25 * sample_rate / 1000
307
- frame_shift = 10 * sample_rate / 1000
308
- length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
309
- return length
310
-
311
- def collect_chunks(self, wav, speech_timestamps, audio_dur, sample_rate):
312
- max_chunk_samples = int(audio_dur * sample_rate)
313
- chunks = []
314
- for ts in speech_timestamps:
315
- start, end = ts["start"], ts["end"]
316
- cur_chunk = wav[start:end]
317
- if (
318
- len(chunks) > 0
319
- and chunks[-1].shape[0] + cur_chunk.shape[0] < max_chunk_samples
320
- ):
321
- chunks[-1] = torch.concat([chunks[-1], cur_chunk], dim=0)
322
- else:
323
- if cur_chunk.shape[0] > max_chunk_samples:
324
- # greedy split if one chunk is too big
325
- chunks.append(cur_chunk[:max_chunk_samples])
326
- chunks.append(cur_chunk[max_chunk_samples:])
327
- else:
328
- chunks.append(cur_chunk)
329
- return chunks
330
-
331
- def transcribe(
332
- self, batch_wav_path: List[str], beam_size: int = 1, nbest: int = 1
333
  ) -> List[Dict]:
334
-
335
- # Run vad, greedy split audio to fit audio_dur
336
- try:
337
- wav = read_audio(batch_wav_path[0], sampling_rate=self.sample_rate)
338
- except Exception as e:
339
- print("Please run apt install libsndfile1 first")
340
- raise e
341
-
342
- max_chunk_samples = int(self.sample_rate * self.audio_dur)
343
- if wav.shape[0] < max_chunk_samples:
344
- chunks = [wav]
345
- else:
346
- speech_timestamps = get_speech_timestamps(
347
- wav,
348
- self.vad_model,
349
- return_seconds=False, # Return speech timestamps in seconds (default is samples)
350
- )
351
- chunks = self.collect_chunks(
352
- wav, speech_timestamps, self.audio_dur, self.sample_rate
353
- )
354
- # print(f"Split to {len(chunks)} chunks")
355
-
356
- transcribe_durations = 0
357
- wav_durations = []
 
358
  tokens = []
359
  for chunk in chunks:
360
- chunk = (chunk.clamp(-1, 1) * 32768).to(torch.int16)
361
- feats, lengths, wav_duration = self.feature_extractor.run_chunk(
362
- chunk, self.sample_rate
363
- )
364
-
365
  wav_durations.append(wav_duration)
366
-
367
- if feats.shape[1] < self.max_feat_len:
368
- feats = np.concatenate(
369
- [
370
- feats,
371
- np.zeros(
372
- (1, self.max_feat_len - feats.shape[1], 80),
373
- dtype=np.float32,
374
- ),
375
- ],
376
- axis=1,
377
- )
378
- feats = feats[:, : self.max_feat_len, :]
379
- lengths = torch.minimum(lengths, torch.tensor(self.max_feat_len))
380
-
381
- feats = to_numpy(feats)
382
- lengths = to_numpy(lengths).astype(np.int32)
383
-
384
  start_time = time.time()
385
  n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
386
- to_numpy(feats), to_numpy(lengths)
387
  )
388
- # print(f"run encoder take {(time.time() - start_time) * 1000}ms")
389
- nbest_hyps = self.run_decoder(
390
  n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
391
  )
392
- tokens.extend([int(id) for id in nbest_hyps[0]["token_ids"].cpu()])
393
-
394
- transcribe_durations += time.time() - start_time
395
-
396
  text = self.tokenizer.detokenize(tokens)
397
-
398
- return {"text": text}, wav_durations, transcribe_durations
10
  import os
11
  import time
12
  import torchaudio
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
 
15
  try:
16
  torchaudio.set_audio_backend("soundfile")
 
45
  return ys * (1 - is_finished) + eos_id * is_finished
46
 
47
 
48
+ def expand_for_beam_search(n_layer_cross_k, beam_size):
49
+ """Method 1: expand_dims + tile + reshape (fastest)"""
50
+ num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
51
+
52
+ # Insert a new dimension at axis 2
53
+ expanded = np.expand_dims(n_layer_cross_k, axis=2)
54
+ # Use tile instead of repeat for better performance
55
+ tiled = np.tile(expanded, (1, 1, beam_size, 1, 1))
56
+ # Reshape to fold the beam dimension into the batch dimension
57
+ reshaped = tiled.reshape(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
58
+
59
+ return reshaped
60
+
61
+
62
  class FireRedASRAxModel:
63
+ def __init__(self,
64
+ encoder_path: str,
65
+ decoder_loop_path: str,
66
+ cmvn_file: str,
67
+ dict_file: str,
68
+ spm_model_path: str,
69
+ providers=["AxEngineExecutionProvider"],
70
+ decode_max_len=128,
71
+ audio_dur=10):
72
  # NOTE: maximum decode length, set with reference to whisper
73
  # The FireRedASR-AED model supports audio up to 60 s
74
  # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
 
92
 
93
  self.vad_model = load_silero_vad()
94
 
95
+ # Pre-allocate memory
96
+ self._preallocated_memory()
97
+ # Use CUDA if available
98
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
99
+ # print(f"Using device: {self.device}")
100
+
101
+ def calc_feat_len(self, audio_dur):
102
+ import math
103
+
104
+ sample_rate = self.sample_rate
105
+ frame_length = 25 * sample_rate / 1000
106
+ frame_shift = 10 * sample_rate / 1000
107
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
108
+ return length
109
+
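The frame count returned by `calc_feat_len` above follows the standard Fbank framing formula (25 ms window, 10 ms shift). A standalone restatement, assuming a 16 kHz sample rate:

```python
import math

def calc_feat_len(audio_dur: float, sample_rate: int = 16000) -> int:
    # Number of Fbank frames for a 25 ms window with a 10 ms shift
    frame_length = 25 * sample_rate / 1000   # 400 samples at 16 kHz
    frame_shift = 10 * sample_rate / 1000    # 160 samples at 16 kHz
    return math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1

print(calc_feat_len(10))  # 998 frames for the 10 s maximum input
```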
110
  def init_encoder(self, encoder_path, providers=None):
111
  self.encoder = axe.InferenceSession(encoder_path, providers=providers)
112
 
 
118
  decoder_path = os.path.join(decoder_path, "pe.npy")
119
 
120
  return np.load(decoder_path)
121
+
122
  def run_encoder(
123
  self, input: np.ndarray, input_length: np.ndarray
124
  ) -> Tuple[Tensor, Tensor, Tensor]:
 
126
  None, {"encoder_input": input, "encoder_input_lengths": input_length}
127
  )
128
  return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
129
+
130
  def decode_loop_one_token(
131
  self,
132
  tokens: np.ndarray,
 
156
  },
157
  )
158
  return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
159
+
160
+ def _preallocated_memory(self):
161
+ """Pre-allocate commonly used buffers"""
162
+ # Precompute self_attn_mask templates
163
+ self.self_attn_mask_templates = {}
164
+ for offset in range(self.decode_max_len):
165
+ mask = np.zeros((1, 1, self.decode_max_len), dtype=np.float32)
166
+ mask[:, :, :self.decode_max_len - offset - 1] = -np.inf
167
+ self.self_attn_mask_templates[offset] = mask
168
+
169
+ # Pre-allocate the beam-search scores template
170
+ self.beam_scores_template = torch.tensor(
171
+ [0.0] + [-INF] * (self.decode_max_len - 1)
172
+ ).float()
173
+
174
+ def transcribe(
175
+ self,
176
+ batch_wav_path: List[str],
177
+ beam_size: int = 1,
178
+ nbest: int = 1,
179
+ use_parallel: bool = False
180
+ ) -> List[Dict]:
181
+ """Optimized transcription method"""
182
+
183
+ # 1. Optimized VAD and chunking
184
+ chunks = self._optimized_vad_split(batch_wav_path[0])
185
+
186
+ if use_parallel and len(chunks) > 1:
187
+ return self._parallel_transcribe(chunks, beam_size, nbest)
188
+ else:
189
+ return self._sequential_transcribe(chunks, beam_size, nbest)
190
+
191
+ def _optimized_vad_split(self, wav_path: str) -> List[torch.Tensor]:
192
+ """Optimized VAD-based chunking"""
193
+ import torchaudio
194
+
195
+ # Load the audio and resample to the target rate if needed
196
+ try:
197
+ wav, sr = torchaudio.load(wav_path)
198
+ if sr != self.sample_rate:
199
+ wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
200
+ except Exception:
201
+ # Fall back to silero_vad's read_audio
202
+ from silero_vad import read_audio
203
+ wav = read_audio(wav_path, sampling_rate=self.sample_rate)
204
+ wav = wav.unsqueeze(0)
205
+
206
+ wav = wav.squeeze(0)
207
+
208
+ # Fast path: if the audio is short, return it directly
209
+ max_chunk_samples = int(self.sample_rate * self.audio_dur)
210
+ if wav.shape[0] < max_chunk_samples:
211
+ return [wav]
212
+
213
+ # Use tuned VAD parameters
214
+ speech_timestamps = get_speech_timestamps(
215
+ wav,
216
+ self.vad_model,
217
+ threshold=0.5, # higher threshold to reduce spurious speech detection
218
+ min_speech_duration_ms=250, # minimum speech segment
219
+ min_silence_duration_ms=100, # minimum silence segment
220
+ return_seconds=False,
221
+ )
222
+
223
+ # Optimized chunk merging
224
+ return self._optimized_collect_chunks(wav, speech_timestamps)
225
+
226
+ def _optimized_collect_chunks(
227
+ self,
228
+ wav: torch.Tensor,
229
+ speech_timestamps: List[Dict]
230
+ ) -> List[torch.Tensor]:
231
+ """Optimized chunk-merging algorithm"""
232
+ max_chunk_samples = int(self.sample_rate * self.audio_dur)
233
+ chunks = []
234
+ current_chunk = []
235
+ current_length = 0
236
+
237
+ for ts in speech_timestamps:
238
+ start, end = ts["start"], ts["end"]
239
+ chunk_length = end - start
240
+
241
+ if current_length + chunk_length <= max_chunk_samples:
242
+ current_chunk.append((start, end))
243
+ current_length += chunk_length
244
+ else:
245
+ if current_chunk:
246
+ # Merge the accumulated segments into one chunk
247
+ merged = torch.cat([wav[s:e] for s, e in current_chunk])
248
+ chunks.append(merged)
249
+
250
+ if chunk_length > max_chunk_samples:
251
+ # Split oversized chunks
252
+ num_splits = (chunk_length + max_chunk_samples - 1) // max_chunk_samples
253
+ for i in range(num_splits):
254
+ s = start + i * max_chunk_samples
255
+ e = min(start + (i + 1) * max_chunk_samples, end)
256
+ chunks.append(wav[s:e])
257
+ current_chunk = []
258
+ current_length = 0
259
+ else:
260
+ current_chunk = [(start, end)]
261
+ current_length = chunk_length
262
+
263
+ # Flush the last chunk
264
+ if current_chunk:
265
+ merged = torch.cat([wav[s:e] for s, e in current_chunk])
266
+ chunks.append(merged)
267
+
268
+ return chunks
269
+
270
+ def _optimized_decode_loop(
271
+ self,
272
+ n_layer_cross_k: np.ndarray,
273
+ n_layer_cross_v: np.ndarray,
274
+ cross_attn_mask: np.ndarray,
275
+ beam_size: int,
276
+ nbest: int
277
+ ) -> List[Dict]:
278
+ """Optimized decode loop"""
279
+
280
  num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
281
  encoder_out_length = cross_attn_mask.shape[-1]
282
 
283
+ n_layer_cross_k = expand_for_beam_search(n_layer_cross_k, beam_size)
284
+ n_layer_cross_v = expand_for_beam_search(n_layer_cross_v, beam_size)
285
+
286
+ batch_size, Ti, encoder_out_length = cross_attn_mask.shape
287
+
288
+ # Insert a new dimension at axis 1
289
+ expanded = np.expand_dims(cross_attn_mask, axis=1)
290
+ # Use tile instead of repeat for better performance
291
+ tiled = np.tile(expanded, (1, beam_size, 1, 1))
292
+ # Reshape to fold the beam dimension into the batch dimension
293
+ cross_attn_mask = tiled.reshape(beam_size * batch_size, Ti, encoder_out_length)
294
+
295
+ # Optimized cache initialization
296
+ n_layer_self_k_cache, n_layer_self_v_cache = self._optimized_init_self_cache(
297
  batch_size, beam_size
298
  )
299
+
300
+ # Pre-allocate tokens and scores
301
+ tokens = torch.full(
302
+ (beam_size * batch_size, 1),
303
+ self.sos_id,
304
+ dtype=torch.int32, device=self.device
305
+ )
306
+ scores = self.beam_scores_template[:beam_size].repeat(batch_size).view(
307
+ batch_size * beam_size, 1
308
+ ).to(self.device)
309
+ is_finished = torch.zeros_like(scores, dtype=torch.bool, device=self.device)
310
+
311
+ # Pre-allocate prediction_tokens
312
+ prediction_tokens = tokens.clone()
313
+
314
+ pe_np = self.pe
315
+
316
+ for offset in range(self.decode_max_len):
317
+ # Use the precomputed mask template
318
+ self_attn_mask = np.repeat(
319
+ self.self_attn_mask_templates[offset],
320
+ beam_size * batch_size,
321
+ axis=0
322
  )
323
+
324
+ # Use numpy arrays directly to avoid conversions
325
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = (
326
+ self.decode_loop_one_token(
327
+ tokens.cpu().numpy().astype(np.int32),
328
+ n_layer_self_k_cache,
329
+ n_layer_self_v_cache,
330
+ n_layer_cross_k,
331
+ n_layer_cross_v,
332
+ pe_np[offset],
333
+ self_attn_mask,
334
+ cross_attn_mask
335
+ )
336
  )
337
+
338
+ logits = torch.from_numpy(logits).to(self.device).squeeze(1)
339
  t_scores = F.log_softmax(logits, dim=-1)
340
+
341
+ # Optimized beam search
342
+ tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished = (
343
+ self._optimized_beam_search(
344
+ t_scores, tokens, scores, prediction_tokens,
345
+ n_layer_self_k_cache, n_layer_self_v_cache,
346
+ is_finished, beam_size, batch_size
347
+ )
348
  )
349
+
350
+ if is_finished.all():
351
  break
352
+
353
+ # return self._extract_results(scores, prediction_tokens, batch_size, beam_size, nbest)
354
+ return self.extract_results_numpy_vectorized(scores.numpy(), prediction_tokens.numpy(), batch_size, beam_size, nbest)
355
+
356
+
357
+ def _optimized_beam_search(
358
+ self,
359
+ t_scores: torch.Tensor,
360
+ tokens: torch.Tensor,
361
+ scores: torch.Tensor,
362
+ prediction_tokens: torch.Tensor,
363
+ n_layer_self_k_cache: torch.Tensor,
364
+ n_layer_self_v_cache: torch.Tensor,
365
+ is_finished: torch.Tensor,
366
+ beam_size: int,
367
+ batch_size: int
368
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
369
+ """Optimized beam-search step"""
370
+
371
+ # Use torch in-place operations
372
+ t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
373
+
374
+ # Handle finished beams
375
+ if is_finished.any():
376
+ # In-place ops to avoid allocating new tensors
377
+ t_topB_scores.masked_fill_(is_finished, 0.0)
378
+ t_topB_scores[:, 1:].masked_fill_(is_finished, -INF)
379
+ t_topB_ys.masked_fill_(is_finished, self.eos_id)
380
+
381
+ # Update scores
382
+ scores = scores + t_topB_scores
383
+
384
+ # Optimized top-k selection
385
+ scores_2d = scores.view(batch_size, beam_size * beam_size)
386
+ top_scores, top_ids = torch.topk(scores_2d, k=beam_size, dim=1)
387
+ scores = top_scores.view(-1, 1)
388
+
389
+ # Compute beam indices
390
+ topB_row_number_in_each_B_rows_of_ys = torch.div(top_ids, beam_size, rounding_mode='floor')
391
+ stride = beam_size * torch.arange(batch_size, device=self.device).view(batch_size, 1)
392
+ topB_row_number_in_ys = (topB_row_number_in_each_B_rows_of_ys + stride).view(-1)
393
+
394
+ # Update tokens and prediction_tokens
395
+ tokens = torch.gather(
396
+ t_topB_ys.view(batch_size, beam_size * beam_size),
397
+ dim=1,
398
+ index=top_ids,
399
+ ).view(beam_size * batch_size, 1)
400
+
401
+ prediction_tokens = torch.cat([
402
+ prediction_tokens[topB_row_number_in_ys],
403
+ tokens
404
+ ], dim=1)
405
+
406
+ # Update caches (in place)
407
+ for i in range(n_layer_self_k_cache.shape[0]):
408
+ n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
409
+ n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
410
+
411
+ # Update finished state
412
+ is_finished = tokens.eq(self.eos_id)
413
+
414
+ return tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished
415
+
416
+ def _optimized_init_self_cache(
417
+ self, batch_size: int, beam_size: int
418
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
419
+ """Optimized self-cache initialization"""
420
+ shape = (
421
+ self.num_decoder_blocks,
422
+ batch_size * beam_size,
423
+ self.decode_max_len,
424
+ self.decoder_hidden_dim
425
+ )
426
+ n_layer_self_k_cache = np.zeros(shape, dtype=np.float32)
427
+ n_layer_self_v_cache = np.zeros(shape, dtype=np.float32)
428
+ return n_layer_self_k_cache, n_layer_self_v_cache
429
+
430
+ def _extract_results(
431
+ self,
432
+ scores: torch.Tensor,
433
+ prediction_tokens: torch.Tensor,
434
+ batch_size: int,
435
+ beam_size: int,
436
+ nbest: int
437
+ ) -> List[Dict]:
438
+ """Extract results"""
439
  scores = scores.view(batch_size, beam_size)
440
+ valid_lengths = torch.sum(
441
  torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
442
+ dim=-1
443
  ).int()
444
+
445
  nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
446
+ index = nbest_ids + beam_size * torch.arange(batch_size, device=self.device).unsqueeze(1)
447
+
448
+ nbest_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
449
+ nbest_tokens = nbest_tokens.view(batch_size, nbest_ids.size(1), -1)
450
+
451
+ results = []
452
  for j, score in enumerate(nbest_scores[0]):
453
  hyp = {
454
+ "token_ids": nbest_tokens[0, j, 1:valid_lengths[0, nbest_ids[0, j]]],
455
  "score": score,
456
  }
457
+ results.append(hyp)
458
+
459
+ return results
460
+
461
+
462
+ def extract_results_numpy_vectorized(
463
+ self,
464
+ scores: np.ndarray,
465
+ prediction_tokens: np.ndarray,
466
+ batch_size: int,
467
+ beam_size: int,
468
+ nbest: int,
469
+ eos_id: int = 4
470
  ) -> List[Dict]:
471
+ """Vectorized NumPy implementation"""
472
+
473
+ # 1. Reshape and compute valid lengths
474
+ scores_2d = scores.reshape(batch_size, beam_size)
475
+ tokens_3d = prediction_tokens.reshape(batch_size, beam_size, -1)
476
+
477
+ # Valid length excludes eos_id
478
+ valid_lengths = np.sum(tokens_3d != eos_id, axis=-1).astype(np.int32)
479
+
480
+ # 2. Partial sort with argpartition (faster than argsort)
481
+ # Get the indices of the largest nbest elements
482
+ # argpartition is O(n) vs argsort's O(n log n)
483
+ partitioned_indices = np.argpartition(-scores_2d, nbest-1, axis=1)[:, :nbest]
484
+
485
+ # 对每个batch内的topk进行排序
486
+ nbest_scores = np.take_along_axis(scores_2d, partitioned_indices, axis=1)
487
+ sorted_order = np.argsort(-nbest_scores, axis=1)
488
+
489
+ # 应用排序
490
+ nbest_ids = np.take_along_axis(partitioned_indices, sorted_order, axis=1)
491
+ nbest_scores = np.take_along_axis(nbest_scores, sorted_order, axis=1)
492
+
493
+ # 3. 计算全局索引
494
+ batch_indices = np.arange(batch_size)[:, np.newaxis]
495
+ global_indices = nbest_ids + beam_size * batch_indices
496
+ flat_global_indices = global_indices.reshape(-1)
497
+
498
+ # 4. 提取tokens
499
+ flat_tokens = prediction_tokens.reshape(-1, prediction_tokens.shape[-1])
500
+ nbest_tokens = flat_tokens[flat_global_indices]
501
+ nbest_tokens = nbest_tokens.reshape(batch_size, nbest, -1)
502
+
503
+ # 5. 提取对应的有效长度
504
+ nbest_valid_lengths = np.take_along_axis(valid_lengths, nbest_ids, axis=1)
505
+
506
+ # 6. 构建结果
507
+ results = []
508
+
509
+ for b in range(batch_size):
510
+ batch_results = []
511
+ for j in range(nbest):
512
+ valid_len = nbest_valid_lengths[b, j]
513
+
514
+ # 提取token_ids(跳过<sos>)
515
+ token_ids = nbest_tokens[b, j, 1:valid_len]
516
+
517
+ hyp = {
518
+ "token_ids": token_ids.tolist(),
519
+ "score": float(nbest_scores[b, j]),
520
+ }
521
+ batch_results.append(hyp)
522
+
523
+ # 如果是批量处理,可以按batch返回
524
+ # 这里假设batch_size=1,直接返回第一个batch的结果
525
+ if b == 0:
526
+ results = batch_results
527
+
528
+ return results
+
+
+    def _sequential_transcribe(
+        self,
+        chunks: List[torch.Tensor],
+        beam_size: int,
+        nbest: int
+    ) -> Dict:
+        """Sequential (single-threaded) transcription."""
         tokens = []
+        wav_durations = []
+        transcribe_duration = 0
+
         for chunk in chunks:
+            # Optimized feature extraction
+            feats, lengths, wav_duration = self._optimized_feature_extraction(chunk)
             wav_durations.append(wav_duration)
+
+            # Run the encoder and decoder
             start_time = time.time()
             n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+                feats, lengths.astype(np.int32)
             )
+
+            nbest_hyps = self._optimized_decode_loop(
                 n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
             )
+
+            tokens.extend([int(id) for id in nbest_hyps[0]["token_ids"]])
+            transcribe_duration += time.time() - start_time
+
         text = self.tokenizer.detokenize(tokens)
+        return {"text": text}, wav_durations, transcribe_duration
+
+    def _parallel_transcribe(
+        self,
+        chunks: List[torch.Tensor],
+        beam_size: int,
+        nbest: int
+    ) -> Dict:
+        """Parallel (multi-threaded) transcription."""
+        import threading
+
+        results = []
+        lock = threading.Lock()
+
+        def process_chunk(chunk_idx, chunk):
+            try:
+                # Feature extraction
+                feats, lengths, wav_duration = self._optimized_feature_extraction(chunk)
+
+                # Encoder
+                n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+                    feats, lengths.astype(np.int32)
+                )
+
+                # Decoder
+                nbest_hyps = self._optimized_decode_loop(
+                    n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+                )
+
+                with lock:
+                    results.append({
+                        'chunk_idx': chunk_idx,
+                        'tokens': [int(id) for id in nbest_hyps[0]["token_ids"]],
+                        'duration': wav_duration
+                    })
+            except Exception as e:
+                print(f"Error processing chunk {chunk_idx}: {e}")
+
+        # Process chunks in parallel with a ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=min(4, len(chunks))) as executor:
+            futures = []
+            for i, chunk in enumerate(chunks):
+                future = executor.submit(process_chunk, i, chunk)
+                futures.append(future)
+
+            # Wait for all tasks to finish
+            for future in as_completed(futures):
+                future.result()
+
+        # Merge results in chunk order
+        results.sort(key=lambda x: x['chunk_idx'])
+        tokens = []
+        wav_durations = []
+
+        for result in results:
+            tokens.extend(result['tokens'])
+            wav_durations.append(result['duration'])
+
+        text = self.tokenizer.detokenize(tokens)
+        return {"text": text}, wav_durations, 0  # per-chunk time is hard to attribute when running in parallel
+
+    def _optimized_feature_extraction(
+        self,
+        chunk: torch.Tensor
+    ) -> Tuple[np.ndarray, np.ndarray, float]:
+        """Optimized feature extraction."""
+        # Use 32767 so that a clamped value of 1.0 stays within the int16 range
+        chunk = (chunk.clamp(-1, 1) * 32767).to(torch.int16)
+        feats, lengths, wav_duration = self.feature_extractor.run_chunk(
+            chunk, self.sample_rate
+        )
+
+        # Pad to the fixed model input length (np.pad allocates a padded copy)
+        if feats.shape[1] < self.max_feat_len:
+            pad_width = ((0, 0), (0, self.max_feat_len - feats.shape[1]), (0, 0))
+            feats = np.pad(feats, pad_width, mode='constant', constant_values=0)
+
+        feats = feats[:, :self.max_feat_len, :]
+        lengths = np.minimum(lengths, self.max_feat_len)
+
+        return feats, lengths, wav_duration
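
The vectorized n-best extraction above uses `np.argpartition` for an O(n) partial top-k and then sorts only the k survivors. A standalone sketch of that pattern (function and variable names are illustrative, not part of the repo):

```python
import numpy as np

def topk_desc(scores_2d: np.ndarray, k: int):
    """Return the top-k scores and their indices per row, sorted descending."""
    # argpartition places the k largest elements first (unordered): O(n)
    part = np.argpartition(-scores_2d, k - 1, axis=1)[:, :k]
    vals = np.take_along_axis(scores_2d, part, axis=1)
    # Sort only the k candidates: O(k log k)
    order = np.argsort(-vals, axis=1)
    return (np.take_along_axis(vals, order, axis=1),
            np.take_along_axis(part, order, axis=1))

scores = np.array([[0.1, 0.9, 0.4, 0.7]])
vals, ids = topk_desc(scores, 2)
print(vals, ids)  # [[0.9 0.7]] [[1 3]]
```

For the small beam sizes used here the savings over a full `argsort` are modest, but the pattern scales well when the beam grows.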
test_ax_model.py CHANGED
@@ -44,7 +44,7 @@ def parse_args():
     parser.add_argument(
         "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
     )
-    parser.add_argument("--beam_size", type=int, default=3, help="")
+    parser.add_argument("--beam_size", type=int, default=1, help="")
     parser.add_argument("--nbest", type=int, default=1, help="")
     parser.add_argument("--decode_max_len", type=int, default=128, help="max token len")
     parser.add_argument("--max_dur", type=int, default=10, help="max audio len")
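
Both test scripts now default `--beam_size` to 1, i.e. greedy search, trading a little accuracy for a lower RTF, in line with this commit's goal. A hedged sketch of what beam_size=1 reduces to (`step_logits` and `sos_id=3` are illustrative stand-ins, not the repo's API; `eos_id=4` matches the decoder code above):

```python
import numpy as np

def greedy_decode(step_logits, sos_id=3, eos_id=4, max_len=128):
    """Greedy (beam_size=1) decoding: pick the argmax token at each step.

    `step_logits(tokens)` stands in for one decoder forward pass and
    returns the vocabulary logits for the next token.
    """
    tokens = [sos_id]
    for _ in range(max_len):
        next_id = int(np.argmax(step_logits(tokens)))
        if next_id == eos_id:
            break
        tokens.append(next_id)
    return tokens[1:]  # drop <sos>

# Toy "decoder": prefers token 7 twice, then emits <eos>
def fake_logits(tokens):
    logits = np.zeros(10)
    logits[7 if len(tokens) < 3 else 4] = 1.0
    return logits

print(greedy_decode(fake_logits))  # [7, 7]
```

With a single hypothesis there is no candidate bookkeeping, so each decode step is one forward pass instead of `beam_size` scored expansions.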
test_wer.py CHANGED
@@ -183,12 +183,6 @@ def get_args():
183
  default="axmodel/encoder.axmodel",
184
  help="Path to onnx encoder",
185
  )
186
- parser.add_argument(
187
- "--decoder_main",
188
- type=str,
189
- default="axmodel/decoder_main.axmodel",
190
- help="Path to axmodel decoder main",
191
- )
192
  parser.add_argument(
193
  "--decoder_loop",
194
  type=str,
@@ -213,7 +207,7 @@ def get_args():
213
  parser.add_argument(
214
  "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
215
  )
216
- parser.add_argument("--beam_size", type=int, default=3, help="")
217
  parser.add_argument("--nbest", type=int, default=1, help="")
218
  parser.add_argument("--max_len", type=int, default=128, help="")
219
  return parser.parse_args()
 
183
  default="axmodel/encoder.axmodel",
184
  help="Path to onnx encoder",
185
  )
 
 
 
 
 
 
186
  parser.add_argument(
187
  "--decoder_loop",
188
  type=str,
 
207
  parser.add_argument(
208
  "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
209
  )
210
+ parser.add_argument("--beam_size", type=int, default=1, help="")
211
  parser.add_argument("--nbest", type=int, default=1, help="")
212
  parser.add_argument("--max_len", type=int, default=128, help="")
213
  return parser.parse_args()
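
The padding in `_optimized_feature_extraction` pins every chunk's feature length to a fixed `max_feat_len` because the compiled axmodel expects static input shapes. A minimal sketch of that pad-or-trim step (the `1000`-frame limit and 80-dim features are illustrative values, not the model's actual shapes):

```python
import numpy as np

MAX_FEAT_LEN = 1000  # illustrative; the real value comes from the exported model

def pad_or_trim(feats: np.ndarray, max_len: int = MAX_FEAT_LEN) -> np.ndarray:
    """Zero-pad or trim the time axis of (batch, time, dim) features to max_len."""
    t = feats.shape[1]
    if t < max_len:
        feats = np.pad(feats, ((0, 0), (0, max_len - t), (0, 0)))
    return feats[:, :max_len, :]

x = np.ones((1, 3, 80), dtype=np.float32)
y = pad_or_trim(x, max_len=5)
print(y.shape)  # (1, 5, 80)
```

The original `lengths` value is kept alongside the padded array so the encoder can mask out the zero-padded frames.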