inoryQwQ committed on
Commit 5d4703e
· 1 Parent(s): c6f1198

Add VAD to asr

Files changed (6):
  1. README.md +6 -0
  2. fireredasr/data/asr_feat.py +15 -0
  3. fireredasr_axmodel.py +232 -170
  4. fireredasr_onnx.py +529 -0
  5. test_ax_model.py +45 -76
  6. test_wer.py +115 -113
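For orientation, a minimal usage sketch of the VAD-chunked pipeline this commit introduces. The paths are placeholders and the import assumes the repo root is on `sys.path`; the unpacking follows the new `transcribe` return shape (`{"text": ...}`, per-chunk durations, total decode time):

```python
# Hypothetical paths; adjust to your checkout. transcribe() now runs
# silero-vad internally, greedily packs speech into <= audio_dur chunks,
# and concatenates the per-chunk token ids before detokenizing.
from fireredasr_axmodel import FireRedASRAxModel

model = FireRedASRAxModel(
    "axmodel/encoder.axmodel",
    "axmodel/decoder_loop.axmodel",
    "axmodel/cmvn.ark",
    "axmodel/dict.txt",
    "axmodel/train_bpe1000.model",
    audio_dur=10,
)
result, chunk_durs, decode_secs = model.transcribe(["demo.wav"], beam_size=3, nbest=1)
print(result["text"])
```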
README.md CHANGED
@@ -19,6 +19,12 @@ license: apache-2.0
 
 ## Install dependencies
 
+### Audio backend
+
+```
+sudo apt install libsndfile1
+```
+
 ### Python
 
 The test environment is Python 3.12; [Miniconda](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh) is recommended
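A quick sanity check after installing the backend (assumes the Python `soundfile` package from the project's requirements is already present; the attribute below is part of python-soundfile's public surface):

```python
# Confirms libsndfile is visible to the soundfile wheel.
import soundfile
print(soundfile.__libsndfile_version__)  # e.g. "1.2.2"
```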
fireredasr/data/asr_feat.py CHANGED
@@ -18,6 +18,7 @@ class ASRFeatExtractor:
         durs = []
         for wav_path in wav_paths:
             sample_rate, wav_np = kaldiio.load_mat(wav_path)
+
             dur = wav_np.shape[0] / sample_rate
             fbank = self.fbank((sample_rate, wav_np))
             if self.cmvn is not None:
@@ -28,6 +29,20 @@ class ASRFeatExtractor:
         lengths = torch.tensor([feat.size(0) for feat in feats]).long()
         feats_pad = self.pad_feat(feats, 0.0)
         return feats_pad, lengths, durs
+
+    def run_chunk(self, wav_np, sample_rate):
+        feats = []
+
+        dur = wav_np.shape[0] / sample_rate
+        fbank = self.fbank((sample_rate, wav_np))
+        if self.cmvn is not None:
+            fbank = self.cmvn(fbank)
+        fbank = torch.from_numpy(fbank).float()
+        feats.append(fbank)
+
+        lengths = torch.tensor([feat.size(0) for feat in feats]).long()
+        feats_pad = self.pad_feat(feats, 0.0)
+        return feats_pad, lengths, dur
 
     def pad_feat(self, xs, pad_value):
         # type: (List[Tensor], int) -> Tensor
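A sketch of how the new `run_chunk` is driven, mirroring `transcribe` in `fireredasr_axmodel.py` (which passes a torch int16 tensor; a NumPy array follows the same contract as the kaldiio path). Variable names here are hypothetical:

```python
import numpy as np

# `extractor` is an ASRFeatExtractor; `chunk` is a float32 waveform in [-1, 1]
# cut out of one VAD segment at 16 kHz.
chunk_i16 = (np.clip(chunk, -1, 1) * 32768).astype(np.int16)
feats, lengths, dur = extractor.run_chunk(chunk_i16, 16000)
# feats: (1, T, 80) fbank(+CMVN) features, lengths: (1,), dur: seconds
```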
fireredasr_axmodel.py CHANGED
@@ -9,9 +9,19 @@ from torch import Tensor
 from typing import Tuple, List, Dict
 import os
 import time
+import torchaudio
+
+try:
+    torchaudio.set_audio_backend("soundfile")
+except Exception as e:
+    print("Please run apt install libsndfile1 first")
+    raise e
+
+from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
 
 INF = 1e10
 
+
 def to_numpy(tensor):
     if isinstance(tensor, np.ndarray):
         return tensor
@@ -19,12 +29,12 @@ def to_numpy(tensor):
         return tensor.detach().cpu().numpy()
     else:
         return tensor.cpu().numpy()
-
-
+
+
 def set_finished_beam_score_to_zero(scores, is_finished):
     NB, B = scores.size()
     is_finished = is_finished.float()
-    mask_score = torch.tensor([0.0] + [-INF]*(B-1)).float()
+    mask_score = torch.tensor([0.0] + [-INF] * (B - 1)).float()
     mask_score = mask_score.view(1, B).repeat(NB, 1)
     return scores * (1 - is_finished) + mask_score * is_finished
 
@@ -36,21 +46,21 @@ def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
 
 class FireRedASRAxModel:
     def __init__(
-        self,
-        encoder_path: str,
+        self,
+        encoder_path: str,
         decoder_loop_path: str,
         cmvn_file: str,
-        dict_file: str,
+        dict_file: str,
         spm_model_path: str,
-        providers=['AxEngineExecutionProvider'],
+        providers=["AxEngineExecutionProvider"],
         decode_max_len=128,
-        audio_dur=10
+        audio_dur=10,
     ):
         # NOTE: maximum decode length, set by analogy with whisper
         # FireRedASR-AED supports speech of at most 60s
         # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
         self.decode_max_len = decode_max_len
-
+        self.sample_rate = 16000
         self.decoder_hidden_dim = 1280
         self.audio_dur = audio_dur
         self.max_feat_len = self.calc_feat_len(audio_dur)
@@ -59,47 +69,35 @@ class FireRedASRAxModel:
         self.sos_id = 3
         self.eos_id = 4
         self.pad_id = 2
-
+
         self.feature_extractor = ASRFeatExtractor(cmvn_file)
         self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
-
+
         self.init_encoder(encoder_path, providers)
         self.init_decoder_loop(decoder_loop_path, providers)
         self.pe = self.init_pe(decoder_loop_path)
-
+
+        self.vad_model = load_silero_vad()
+
     def init_encoder(self, encoder_path, providers=None):
-        self.encoder = axe.InferenceSession(
-            encoder_path,
-            providers=providers
-        )
+        self.encoder = axe.InferenceSession(encoder_path, providers=providers)
 
     def init_decoder_loop(self, decoder_path, providers=None):
-        self.decoder_loop = axe.InferenceSession(
-            decoder_path,
-            providers=providers
-        )
+        self.decoder_loop = axe.InferenceSession(decoder_path, providers=providers)
 
     def init_pe(self, decoder_path):
         decoder_path = os.path.dirname(decoder_path)
         decoder_path = os.path.join(decoder_path, "pe.npy")
-
+
         return np.load(decoder_path)
-
-    def run_encoder(self, input: np.ndarray,
-                    input_length: np.ndarray
+
+    def run_encoder(
+        self, input: np.ndarray, input_length: np.ndarray
     ) -> Tuple[Tensor, Tensor, Tensor]:
         n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
-            None,
-            {
-                "encoder_input": input,
-                "encoder_input_lengths": input_length
-            }
-        )
-        return (
-            n_layer_cross_k,
-            n_layer_cross_v,
-            cross_attn_mask
+            None, {"encoder_input": input, "encoder_input_lengths": input_length}
         )
+        return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
 
     def decode_loop_one_token(
         self,
@@ -110,9 +108,13 @@ class FireRedASRAxModel:
         n_layer_cross_v_cache: np.ndarray,
         pe: np.ndarray,
         self_attn_mask: np.ndarray,
-        cross_attn_mask: np.ndarray
+        cross_attn_mask: np.ndarray,
     ) -> Tuple[Tensor, Tensor, Tensor]:
-        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
+        (
+            logits,
+            out_n_layer_self_k_cache,
+            out_n_layer_self_v_cache,
+        ) = self.decoder_loop.run(
             None,
             {
                 "tokens": tokens,
@@ -123,52 +125,50 @@ class FireRedASRAxModel:
                 "pe": pe,
                 "self_attn_mask": self_attn_mask,
                 "cross_attn_mask": cross_attn_mask,
-            }
+            },
         )
-        return (
-            logits,
-            out_n_layer_self_k_cache,
-            out_n_layer_self_v_cache
-        )
-
+        return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
+
     def run_decoder(
-        self,
-        n_layer_cross_k,
-        n_layer_cross_v,
-        cross_attn_mask,
-        beam_size,
-        nbest
+        self, n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
     ):
-
+
         num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
         encoder_out_length = cross_attn_mask.shape[-1]
-
+
         cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
-        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
-            1, beam_size, 1, 1
-        ).view(beam_size * batch_size, -1, encoder_out_length)
-
+        cross_attn_mask = (
+            cross_attn_mask.unsqueeze(1)
+            .repeat(1, beam_size, 1, 1)
+            .view(beam_size * batch_size, -1, encoder_out_length)
+        )
+
         n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
         n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
-        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
-            1, 1, beam_size, 1, 1
-        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
-        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
-            1, 1, beam_size, 1, 1
-        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
-
-        prediction_tokens = torch.ones(
-            beam_size * batch_size, 1).fill_(self.sos_id).long()
+        n_layer_cross_k = (
+            n_layer_cross_k.unsqueeze(2)
+            .repeat(1, 1, beam_size, 1, 1)
+            .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
+        )
+        n_layer_cross_v = (
+            n_layer_cross_v.unsqueeze(2)
+            .repeat(1, 1, beam_size, 1, 1)
+            .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
+        )
+
+        prediction_tokens = (
+            torch.ones(beam_size * batch_size, 1).fill_(self.sos_id).long()
+        )
         tokens = prediction_tokens
         offset = torch.zeros(1, dtype=torch.int64)
         n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
             batch_size, beam_size
         )
-
-        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
+
+        scores = torch.tensor([0.0] + [-INF] * (beam_size - 1)).float()
         scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
         is_finished = torch.zeros_like(scores)
-
+
         self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)
 
         for i in range(self.decode_max_len):
@@ -180,95 +180,111 @@ class FireRedASRAxModel:
             n_layer_cross_v = to_numpy(n_layer_cross_v)
             cross_attn_mask = to_numpy(cross_attn_mask)
 
-            self_attn_mask = np.zeros((batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32)
-            self_attn_mask[:, :, :self.decode_max_len - offset[0] - 1] = -np.inf
-
-            logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
-                to_numpy(tokens),
-                to_numpy(n_layer_self_k_cache),
-                to_numpy(n_layer_self_v_cache),
-                to_numpy(n_layer_cross_k),
-                to_numpy(n_layer_cross_v),
-                self.pe[offset],
-                self_attn_mask,
-                to_numpy(cross_attn_mask)
-            )
-
+            self_attn_mask = np.zeros(
+                (batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32
+            )
+            self_attn_mask[:, :, : self.decode_max_len - offset[0] - 1] = -np.inf
+
+            (
+                logits,
+                n_layer_self_k_cache,
+                n_layer_self_v_cache,
+            ) = self.decode_loop_one_token(
+                to_numpy(tokens),
+                to_numpy(n_layer_self_k_cache),
+                to_numpy(n_layer_self_v_cache),
+                to_numpy(n_layer_cross_k),
+                to_numpy(n_layer_cross_v),
+                self.pe[offset],
+                self_attn_mask,
+                to_numpy(cross_attn_mask),
+            )
+
             offset += 1
             logits = torch.from_numpy(logits)
-
+
             logits = logits.squeeze(1)
             t_scores = F.log_softmax(logits, dim=-1)
             t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
             t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
             t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)
-
+
             scores = scores + t_topB_scores
-
+
             scores = scores.view(batch_size, beam_size * beam_size)
             scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
             scores = scores.view(-1, 1)
-
+
             topB_row_number_in_each_B_rows_of_ys = torch.div(
-                topB_score_ids, beam_size).view(batch_size * beam_size)
-            stride = beam_size * torch.arange(batch_size).view(
-                batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
-            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
-
+                topB_score_ids, beam_size
+            ).view(batch_size * beam_size)
+            stride = beam_size * torch.arange(batch_size).view(batch_size, 1).repeat(
+                1, beam_size
+            ).view(batch_size * beam_size)
+            topB_row_number_in_ys = (
+                topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
+            )
+
             prediction_tokens = prediction_tokens[topB_row_number_in_ys]
             t_ys = torch.gather(
-                t_topB_ys.view(batch_size, beam_size * beam_size),
-                dim=1, index=topB_score_ids
+                t_topB_ys.view(batch_size, beam_size * beam_size),
+                dim=1,
+                index=topB_score_ids,
             ).view(beam_size * batch_size, 1)
-
+
             tokens = t_ys
-
+
             prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
-
+
             n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
             n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
-
+
             for i, self_k_cache in enumerate(n_layer_self_k_cache):
                 n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
-
+
             for i, self_v_cache in enumerate(n_layer_self_v_cache):
                 n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
-
+
             is_finished = t_ys.eq(self.eos_id)
             if is_finished.sum().item() == beam_size * batch_size:
                 break
-
+
         scores = scores.view(batch_size, beam_size)
         prediction_valid_token_lengths = torch.sum(
-            torch.ne(
-                prediction_tokens.view(batch_size, beam_size, -1),
-                self.eos_id),
-            dim=-1
+            torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
+            dim=-1,
         ).int()
-
+
         nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
-        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
-        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
-        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
+        index = (
+            nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
+        )
+        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[
+            index.view(-1)
+        ]
+        nbest_prediction_tokens = nbest_prediction_tokens.view(
+            batch_size, nbest_ids.size(1), -1
+        )
         nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
-            batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
-        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
-        for i in range(batch_size):
-            i_best_hyps: List[Dict[str, torch.Tensor]] = []
-            for j, score in enumerate(nbest_scores[i]):
-                hyp = {
-                    "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
-                    "score": score
-                }
-                i_best_hyps.append(hyp)
-            nbest_hyps.append(i_best_hyps)
-
-        return nbest_hyps
-
-    def get_initialized_self_cache(self,
-                                   batch_size,
-                                   beam_size
-    ) -> Tuple[Tensor, Tensor]:
+            batch_size * beam_size
+        )[index.view(-1)].view(batch_size, -1)
+
+        # batch_size is always 1
+        i_best_hyps: List[Dict[str, torch.Tensor]] = []
+        for j, score in enumerate(nbest_scores[0]):
+            hyp = {
+                "token_ids": nbest_prediction_tokens[
+                    0, j, 1 : nbest_prediction_valid_token_lengths[0, j]
+                ],
+                "score": score,
+            }
+            i_best_hyps.append(hyp)
+
+        return i_best_hyps
+
+    def get_initialized_self_cache(
+        self, batch_size, beam_size
+    ) -> Tuple[Tensor, Tensor]:
         n_layer_self_k_cache = torch.zeros(
             self.num_decoder_blocks,
             batch_size * beam_size,
@@ -282,55 +298,101 @@ class FireRedASRAxModel:
             self.decoder_hidden_dim,
         )
         return n_layer_self_k_cache, n_layer_self_v_cache
-
+
     def calc_feat_len(self, audio_dur):
         import math
-        sample_rate = 16000
+
+        sample_rate = self.sample_rate
         frame_length = 25 * sample_rate / 1000
         frame_shift = 10 * sample_rate / 1000
         length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
         return length
-
-    def transcribe(self,
-                   batch_wav_path: List[str],
-                   beam_size: int = 1,
-                   nbest: int = 1
-    ) -> List[Dict]:
-        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
-        # print(f"feats.shape: {feats.shape}")
-        if feats.shape[1] < self.max_feat_len:
-            feats = np.concatenate([feats, np.zeros((1, self.max_feat_len - feats.shape[1], 80), dtype=np.float32)], axis=1)
-        feats = feats[:, :self.max_feat_len, :]
-        lengths = torch.minimum(lengths, torch.tensor(self.max_feat_len))
-
-        feats = to_numpy(feats)
-        lengths = to_numpy(lengths).astype(np.int32)
-
-        start_time = time.time()
-        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
-            to_numpy(feats),
-            to_numpy(lengths)
-        )
-        # print(f"run encoder take {(time.time() - start_time) * 1000}ms")
-        nbest_hyps = self.run_decoder(n_layer_cross_k,
-                                      n_layer_cross_v,
-                                      cross_attn_mask,
-                                      beam_size,
-                                      nbest,
-                                      )
-        transcribe_durations = time.time() - start_time
-        results: List[Dict] = []
-        for wav, hyp in zip(batch_wav_path, nbest_hyps):
-            hyp = hyp[0]
-            hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
-            score = hyp["score"].item()
-            text = self.tokenizer.detokenize(hyp_ids)
-            results.append(
-                {
-                    "wav": wav,
-                    "text": text,
-                    "score": score
-                }
-            )
-
-        return results, wav_durations, transcribe_durations
+
+    def collect_chunks(self, wav, speech_timestamps, audio_dur, sample_rate):
+        max_chunk_samples = int(audio_dur * sample_rate)
+        chunks = []
+        for ts in speech_timestamps:
+            start, end = ts["start"], ts["end"]
+            cur_chunk = wav[start:end]
+            if (
+                len(chunks) > 0
+                and chunks[-1].shape[0] + cur_chunk.shape[0] < max_chunk_samples
+            ):
+                chunks[-1] = torch.concat([chunks[-1], cur_chunk], dim=0)
+            else:
+                if cur_chunk.shape[0] > max_chunk_samples:
+                    # greedy split if one chunk is too big
+                    chunks.append(cur_chunk[:max_chunk_samples])
+                    chunks.append(cur_chunk[max_chunk_samples:])
+                else:
+                    chunks.append(cur_chunk)
+        return chunks
+
+    def transcribe(
+        self, batch_wav_path: List[str], beam_size: int = 1, nbest: int = 1
+    ) -> List[Dict]:
+
+        # Run VAD and greedily split the audio to fit audio_dur
+        try:
+            wav = read_audio(batch_wav_path[0], sampling_rate=self.sample_rate)
+        except Exception as e:
+            print("Please run apt install libsndfile1 first")
+            raise e
+
+        max_chunk_samples = int(self.sample_rate * self.audio_dur)
+        if wav.shape[0] < max_chunk_samples:
+            chunks = [wav]
+        else:
+            speech_timestamps = get_speech_timestamps(
+                wav,
+                self.vad_model,
+                return_seconds=False,  # return timestamps in samples (True would return seconds)
+            )
+            chunks = self.collect_chunks(
+                wav, speech_timestamps, self.audio_dur, self.sample_rate
+            )
+        # print(f"Split to {len(chunks)} chunks")
+
+        transcribe_durations = 0
+        wav_durations = []
+        tokens = []
+        for chunk in chunks:
+            chunk = (chunk.clamp(-1, 1) * 32768).to(torch.int16)
+            feats, lengths, wav_duration = self.feature_extractor.run_chunk(
+                chunk, self.sample_rate
+            )
+
+            wav_durations.append(wav_duration)
+
+            if feats.shape[1] < self.max_feat_len:
+                feats = np.concatenate(
+                    [
+                        feats,
+                        np.zeros(
+                            (1, self.max_feat_len - feats.shape[1], 80),
+                            dtype=np.float32,
+                        ),
+                    ],
+                    axis=1,
+                )
+            feats = feats[:, : self.max_feat_len, :]
+            lengths = torch.minimum(lengths, torch.tensor(self.max_feat_len))
+
+            feats = to_numpy(feats)
+            lengths = to_numpy(lengths).astype(np.int32)
+
+            start_time = time.time()
+            n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+                to_numpy(feats), to_numpy(lengths)
+            )
+            # print(f"run encoder take {(time.time() - start_time) * 1000}ms")
+            nbest_hyps = self.run_decoder(
+                n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+            )
+            tokens.extend([int(id) for id in nbest_hyps[0]["token_ids"].cpu()])
+
+            transcribe_durations += time.time() - start_time
+
        text = self.tokenizer.detokenize(tokens)
 
        return {"text": text}, wav_durations, transcribe_durations
fireredasr_onnx.py ADDED
@@ -0,0 +1,529 @@
+from fireredasr.data.asr_feat import ASRFeatExtractor
+from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
+
+import onnxruntime as ort
+import torch
+import torch.nn.functional as F
+import numpy as np
+from torch import Tensor
+from typing import Tuple, List, Dict
+import argparse
+import os
+import time
+import logging
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+logger_stream_hander = logging.StreamHandler()
+logger_stream_hander.setLevel("INFO")
+logger.addHandler(logger_stream_hander)
+
+
+INF = 1e10
+
+
+def to_numpy(tensor):
+    if isinstance(tensor, np.ndarray):
+        return tensor
+    if tensor.requires_grad:
+        return tensor.detach().cpu().numpy()
+    else:
+        return tensor.cpu().numpy()
+
+
+def set_finished_beam_score_to_zero(scores, is_finished):
+    NB, B = scores.size()
+    is_finished = is_finished.float()
+    mask_score = torch.tensor([0.0] + [-INF] * (B - 1)).float()
+    mask_score = mask_score.view(1, B).repeat(NB, 1)
+    return scores * (1 - is_finished) + mask_score * is_finished
+
+
+def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
+    is_finished = is_finished.long()
+    return ys * (1 - is_finished) + eos_id * is_finished
+
+
+class FireRedASROnnxModel:
+    def __init__(
+        self,
+        encoder_path: str,
+        decoder_path: str,
+        cmvn_file: str,
+        dict_file: str,
+        spm_model_path: str,
+        providers=["CUDAExecutionProvider"],
+        decode_max_len=128,
+        audio_dur=10,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+        # session_opts.log_severity_level = 1
+        self.session_opts = session_opts
+
+        # NOTE: maximum decode length, set by analogy with whisper
+        # FireRedASR-AED supports speech of at most 60s
+        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
+        self.decode_max_len = decode_max_len
+
+        self.decoder_hidden_dim = 1280
+        self.num_decoder_blocks = 16
+        self.blank_id = 0
+        self.sos_id = 3
+        self.eos_id = 4
+        self.pad_id = 2
+
+        self.feature_extractor = ASRFeatExtractor(cmvn_file)
+        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
+        self.encoder = None
+        self.decoder = None
+        self.audio_dur = audio_dur
+
+        self.init_encoder(encoder_path, providers)
+        self.init_decoder_main(decoder_path, providers)
+        self.init_decoder_loop(decoder_path, providers)
+        self.pe = self.init_pe(decoder_path)
+
+    def init_encoder(self, encoder_path, providers=None):
+        start_time = time.time()
+        self.encoder = ort.InferenceSession(
+            encoder_path, sess_options=self.session_opts, providers=providers
+        )
+        end_time = time.time()
+        logger.info(f"load encoder cost {end_time - start_time} seconds")
+
+    def init_decoder(self, decoder_path, providers=None):
+        start_time = time.time()
+        self.decoder = ort.InferenceSession(
+            decoder_path, sess_options=self.session_opts, providers=providers
+        )
+        end_time = time.time()
+        logger.info(f"load decoder cost {end_time - start_time} seconds")
+
+    def init_decoder_main(self, decoder_path, providers=None):
+        decoder_path = os.path.dirname(decoder_path)
+        decoder_path = os.path.join(decoder_path, "decoder_main.onnx")
+        start_time = time.time()
+        self.decoder_main = ort.InferenceSession(
+            decoder_path, sess_options=self.session_opts, providers=providers
+        )
+        end_time = time.time()
+        logger.info(f"load decoder_main cost {end_time - start_time} seconds")
+
+        input_names = [i.name for i in self.decoder_main.get_inputs()]
+        print(f"decoder_main.input_names: {input_names}")
+
+    def init_decoder_loop(self, decoder_path, providers=None):
+        decoder_path = os.path.dirname(decoder_path)
+        decoder_path = os.path.join(decoder_path, "decoder_loop.onnx")
+
+        start_time = time.time()
+        self.decoder_loop = ort.InferenceSession(
+            decoder_path, sess_options=self.session_opts, providers=providers
+        )
+        end_time = time.time()
+        logger.info(f"load decoder_loop cost {end_time - start_time} seconds")
+
+        input_names = [i.name for i in self.decoder_loop.get_inputs()]
+        print(f"decoder_loop.input_names: {input_names}")
+
+    def init_pe(self, decoder_path):
+        decoder_path = os.path.dirname(decoder_path)
+        decoder_path = os.path.join(decoder_path, "pe.npy")
+
+        return np.load(decoder_path)
+
+    def run_encoder(
+        self, input: np.ndarray, input_length: np.ndarray
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
+            None,
+            {
+                self.encoder.get_inputs()[0].name: input,
+                self.encoder.get_inputs()[1].name: input_length,
+            },
+        )
+        return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
+
+    def decode_one_token(
+        self,
+        tokens: np.ndarray,
+        n_layer_self_k_cache: np.ndarray,
+        n_layer_self_v_cache: np.ndarray,
+        n_layer_cross_k_cache: np.ndarray,
+        n_layer_cross_v_cache: np.ndarray,
+        offset: np.ndarray,
+        self_attn_mask: np.ndarray,
+        cross_attn_mask: np.ndarray,
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        # print("decode:")
+        # print(f"tokens.shape: {tokens.shape}")
+        # print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
+        # print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
+        # print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
+        # print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
+        # print(f"offset.shape: {offset.shape}")
+        # print(f"self_attn_mask.shape: {self_attn_mask.shape}")
+        # print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
+        # print(f"self_attn_mask: {self_attn_mask}")
+
+        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
+            None,
+            {
+                self.decoder.get_inputs()[0].name: tokens,
+                self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
+                self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
+                self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
+                self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
+                self.decoder.get_inputs()[5].name: offset,
+                self.decoder.get_inputs()[6].name: self_attn_mask,
+                self.decoder.get_inputs()[7].name: cross_attn_mask,
+            },
+        )
+        return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
+
+    def decode_main_one_token(
+        self,
+        tokens: np.ndarray,
+        n_layer_self_k_cache: np.ndarray,
+        n_layer_self_v_cache: np.ndarray,
+        n_layer_cross_k_cache: np.ndarray,
+        n_layer_cross_v_cache: np.ndarray,
+        pe: np.ndarray,
+        self_attn_mask: np.ndarray,
+        cross_attn_mask: np.ndarray,
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        # print("decode_main:")
+        # print(f"tokens.shape: {tokens.shape}")
+        # print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
+        # print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
+        # print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
+        # print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
+        # print(f"pe.shape: {pe.shape}")
+        # print(f"self_attn_mask.shape: {self_attn_mask.shape}")
+        # print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
+
+        (
+            logits,
+            out_n_layer_self_k_cache,
+            out_n_layer_self_v_cache,
+        ) = self.decoder_main.run(
+            None,
+            {
+                self.decoder_main.get_inputs()[0].name: tokens,
+                # self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
+                self.decoder_main.get_inputs()[1].name: n_layer_cross_k_cache,
+                self.decoder_main.get_inputs()[2].name: n_layer_cross_v_cache,
+                # self.decoder_main.get_inputs()[3].name: pe,
+                # self.decoder_main.get_inputs()[4].name: self_attn_mask,
+                self.decoder_main.get_inputs()[3].name: cross_attn_mask,
+                # self.decoder_main.get_inputs()[7].name: cross_attn_mask,
+            },
+        )
+        return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
+
+    def decode_loop_one_token(
+        self,
+        tokens: np.ndarray,
+        n_layer_self_k_cache: np.ndarray,
+        n_layer_self_v_cache: np.ndarray,
+        n_layer_cross_k_cache: np.ndarray,
+        n_layer_cross_v_cache: np.ndarray,
+        pe: np.ndarray,
+        self_attn_mask: np.ndarray,
+        cross_attn_mask: np.ndarray,
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        # print("decode_loop:")
+        # print(f"tokens.shape: {tokens.shape}")
+        # print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
+        # print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
+        # print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
+        # print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
+        # print(f"pe.shape: {pe.shape}")
+        # print(f"self_attn_mask.shape: {self_attn_mask.shape}")
+        # print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
+
+        (
+            logits,
+            out_n_layer_self_k_cache,
+            out_n_layer_self_v_cache,
+        ) = self.decoder_loop.run(
+            None,
+            {
+                self.decoder_loop.get_inputs()[0].name: tokens,
+                self.decoder_loop.get_inputs()[1].name: n_layer_self_k_cache,
+                self.decoder_loop.get_inputs()[2].name: n_layer_self_v_cache,
+                self.decoder_loop.get_inputs()[3].name: n_layer_cross_k_cache,
+                self.decoder_loop.get_inputs()[4].name: n_layer_cross_v_cache,
+                self.decoder_loop.get_inputs()[5].name: pe,
+                self.decoder_loop.get_inputs()[6].name: self_attn_mask,
+                self.decoder_loop.get_inputs()[7].name: cross_attn_mask,
+            },
+        )
+        return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
+
+    def run_decoder(
+        self, n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+    ):
+
+        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
+        encoder_out_length = cross_attn_mask.shape[-1]
+
+        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
+        cross_attn_mask = (
+            cross_attn_mask.unsqueeze(1)
+            .repeat(1, beam_size, 1, 1)
+            .view(beam_size * batch_size, -1, encoder_out_length)
+        )
+
+        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
+        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
+        n_layer_cross_k = (
+            n_layer_cross_k.unsqueeze(2)
+            .repeat(1, 1, beam_size, 1, 1)
+            .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
+        )
+        n_layer_cross_v = (
+            n_layer_cross_v.unsqueeze(2)
+            .repeat(1, 1, beam_size, 1, 1)
+            .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
+        )
+
+        prediction_tokens = (
+            torch.ones(beam_size * batch_size, 1).fill_(self.sos_id).long()
+        )
+        tokens = prediction_tokens
+        offset = torch.zeros(1, dtype=torch.int64)
+        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
+            batch_size, beam_size
+        )
+
+        scores = torch.tensor([0.0] + [-INF] * (beam_size - 1)).float()
+        scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
+        is_finished = torch.zeros_like(scores)
+
+        # self_attn_mask = torch.zeros(
+        #     batch_size * beam_size,
+        #     1, 1
+        # )
+
+        results = [self.sos_id]
+        for i in range(self.decode_max_len):
+
+            # ==== ORIGIN ====
+            # self_attn_mask = torch.empty(
+            #     batch_size * beam_size,
+            #     prediction_tokens.shape[-1], prediction_tokens.shape[-1]
+            # ).fill_(-np.inf).triu_(1)
+            # self_attn_mask = self_attn_mask[:, -1:, :]
+            # self_attn_mask = to_numpy(self_attn_mask)
+
+            # logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_one_token(
+            #     to_numpy(tokens),
+            #     to_numpy(n_layer_self_k_cache),
+            #     to_numpy(n_layer_self_v_cache),
+            #     to_numpy(n_layer_cross_k),
+            #     to_numpy(n_layer_cross_v),
+            #     to_numpy(offset),
+            #     to_numpy(self_attn_mask),
+            #     to_numpy(cross_attn_mask)
+            # )
+            # ==== ORIGIN ====
+
+            tokens = to_numpy(tokens)
+            n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
+            n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
+            n_layer_cross_k = to_numpy(n_layer_cross_k)
+            n_layer_cross_v = to_numpy(n_layer_cross_v)
+            cross_attn_mask = to_numpy(cross_attn_mask)
+
+            self_attn_mask = np.zeros(
+                (batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32
+            )
+            self_attn_mask[:, :, : self.decode_max_len - offset[0] - 1] = -np.inf
+
+            if i == 0:
+                (
+                    logits,
+                    n_layer_self_k_cache,
+                    n_layer_self_v_cache,
+                ) = self.decode_main_one_token(
+                    to_numpy(tokens),
+                    to_numpy(n_layer_self_k_cache),
+                    to_numpy(n_layer_self_v_cache),
+                    to_numpy(n_layer_cross_k),
+                    to_numpy(n_layer_cross_v),
+                    self.pe[0],
+                    self_attn_mask,
+                    to_numpy(cross_attn_mask),
+                )
+            else:
+                (
+                    logits,
+                    n_layer_self_k_cache,
+                    n_layer_self_v_cache,
+                ) = self.decode_loop_one_token(
+                    to_numpy(tokens),
+                    to_numpy(n_layer_self_k_cache),
+                    to_numpy(n_layer_self_v_cache),
+                    to_numpy(n_layer_cross_k),
+                    to_numpy(n_layer_cross_v),
+                    self.pe[offset],
+                    self_attn_mask,
+                    to_numpy(cross_attn_mask),
+                )
+
+            # logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
+            #     to_numpy(tokens),
+            #     to_numpy(n_layer_self_k_cache),
+            #     to_numpy(n_layer_self_v_cache),
+            #     to_numpy(n_layer_cross_k),
+            #     to_numpy(n_layer_cross_v),
+            #     self.pe[offset],
+            #     self_attn_mask,
+            #     to_numpy(cross_attn_mask)
+            # )
+
+            offset += 1
+            logits = torch.from_numpy(logits)
+
+            logits = logits.squeeze(1)
+            t_scores = F.log_softmax(logits, dim=-1)
+            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
+            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
+            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)
+
+            scores = scores + t_topB_scores
+
+            scores = scores.view(batch_size, beam_size * beam_size)
+            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
+            scores = scores.view(-1, 1)
+
+            topB_row_number_in_each_B_rows_of_ys = torch.div(
+                topB_score_ids, beam_size
+            ).view(batch_size * beam_size)
+            stride = beam_size * torch.arange(batch_size).view(batch_size, 1).repeat(
+                1, beam_size
+            ).view(batch_size * beam_size)
+            topB_row_number_in_ys = (
+                topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
+            )
+
+            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
+            t_ys = torch.gather(
+                t_topB_ys.view(batch_size, beam_size * beam_size),
+                dim=1,
+                index=topB_score_ids,
+            ).view(beam_size * batch_size, 1)
+
+            tokens = t_ys
+
+            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
+
+            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
+            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
+
+            for i, self_k_cache in enumerate(n_layer_self_k_cache):
+                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
+
+            for i, self_v_cache in enumerate(n_layer_self_v_cache):
+                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
+
+            is_finished = t_ys.eq(self.eos_id)
+            if is_finished.sum().item() == beam_size * batch_size:
+                break
+
+        scores = scores.view(batch_size, beam_size)
+        prediction_valid_token_lengths = torch.sum(
+            torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
+            dim=-1,
+        ).int()
+
+        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
+        index = (
+            nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
+        )
+        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[
+            index.view(-1)
+        ]
+        nbest_prediction_tokens = nbest_prediction_tokens.view(
+            batch_size, nbest_ids.size(1), -1
+        )
+        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
+            batch_size * beam_size
+        )[index.view(-1)].view(batch_size, -1)
+        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
+        for i in range(batch_size):
+            i_best_hyps: List[Dict[str, torch.Tensor]] = []
+            for j, score in enumerate(nbest_scores[i]):
+                hyp = {
+                    "token_ids": nbest_prediction_tokens[
+                        i, j, 1 : nbest_prediction_valid_token_lengths[i, j]
+                    ],
+                    "score": score,
+                }
+                i_best_hyps.append(hyp)
+            nbest_hyps.append(i_best_hyps)
+
+        return nbest_hyps
+
+    def get_initialized_self_cache(
+        self, batch_size, beam_size
+    ) -> Tuple[Tensor, Tensor]:
+        n_layer_self_k_cache = torch.zeros(
+            self.num_decoder_blocks,
+            batch_size * beam_size,
+            self.decode_max_len,
+            self.decoder_hidden_dim,
+        )
+        n_layer_self_v_cache = torch.zeros(
+            self.num_decoder_blocks,
+            batch_size * beam_size,
+            self.decode_max_len,
+            self.decoder_hidden_dim,
+        )
+        return n_layer_self_k_cache, n_layer_self_v_cache
+
+    def calc_feat_len(self, audio_dur):
+        import math
+
+        sample_rate = 16000
+        frame_length = 25 * sample_rate / 1000
+        frame_shift = 10 * sample_rate / 1000
+        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
+        return length
+
+    def transcribe(
+        self, batch_wav_path: List[str], beam_size: int = 1, nbest: int = 1
+    ) -> List[Dict]:
+        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
+        maxlen = self.calc_feat_len(self.audio_dur)
+        if feats.shape[1] < maxlen:
+            feats = np.concatenate(
+                [feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)],
+                axis=1,
+            )
+        feats = feats[:, :maxlen, :]
+        lengths = torch.minimum(lengths, torch.tensor(maxlen))
+
+        feats = to_numpy(feats)
+        lengths = to_numpy(lengths)
+
+        start_time = time.time()
+        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+            to_numpy(feats), to_numpy(lengths)
+        )
+        nbest_hyps = self.run_decoder(
+            n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+        )
+        transcribe_durations = time.time() - start_time
+        results: List[Dict] = []
+        for wav, hyp in zip(batch_wav_path, nbest_hyps):
+            hyp = hyp[0]
+            hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
+            score = hyp["score"].item()
+            text = self.tokenizer.detokenize(hyp_ids)
+            results.append({"wav": wav, "text": text, "score": score})
+
+        return results, wav_durations, transcribe_durations
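The beam bookkeeping in `run_decoder` is the densest part of both backends, so here is a runnable miniature of the re-ranking step with made-up scores (`torch.div(..., rounding_mode="floor")` is the modern spelling of the floor division the code relies on):

```python
# Worked example of the parent-beam indices used in run_decoder
# (batch_size=1, beam_size=3).
import torch

batch_size, beam_size = 1, 3
# scores after adding per-beam topk continuations: (batch, beam*beam)
scores = torch.tensor([[0.1, -0.5, -0.9, 0.4, -0.2, -1.0, 0.0, -0.3, -0.7]])
scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
# which source beam each surviving hypothesis came from
rows = torch.div(topB_score_ids, beam_size, rounding_mode="floor").view(-1)
print(topB_score_ids)  # tensor([[3, 0, 6]]) -> indices of the best continuations
print(rows)            # tensor([1, 0, 2])   -> their parent beams
```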
test_ax_model.py CHANGED
@@ -11,79 +11,47 @@ logger_stream_hander = logging.StreamHandler()
 logger_stream_hander.setLevel("INFO")
 logger.addHandler(logger_stream_hander)
 
-
+
 def parse_args():
     parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")
     parser.add_argument(
-        "--encoder",
-        type=str,
+        "--encoder",
+        type=str,
         default="axmodel/encoder.axmodel",
-        help="Path to axmodel encoder"
+        help="Path to axmodel encoder",
     )
     parser.add_argument(
-        "--decoder_loop",
-        type=str,
+        "--decoder_loop",
+        type=str,
         default="axmodel/decoder_loop.axmodel",
-        help="Path to axmodel decoder loop"
+        help="Path to axmodel decoder loop",
     )
     parser.add_argument(
-        "--cmvn",
-        type=str,
-        default="axmodel/cmvn.ark",
-        help="Path to cmvn"
+        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
     )
     parser.add_argument(
-        "--dict",
-        type=str,
-        default="axmodel/dict.txt",
-        help="Path to dict"
+        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
     )
     parser.add_argument(
         "--spm_model",
         type=str,
         default="axmodel/train_bpe1000.model",
-        help="Path to spm model"
-    )
-    parser.add_argument(
-        "--wavlist",
-        type=str,
-        default="wavlist.txt",
-        help="File to wav path list"
-    )
-    parser.add_argument(
-        "--hypo",
-        type=str,
-        default="hypo_axmodel.txt",
-        help="File of hypos"
-    )
-    parser.add_argument(
-        "--beam_size",
-        type=int,
-        default=3,
-        help=""
-    )
-    parser.add_argument(
-        "--nbest",
-        type=int,
-        default=1,
-        help=""
+        help="Path to spm model",
     )
     parser.add_argument(
-        "--decode_max_len",
-        type=int,
-        default=128,
-        help="max token len"
+        "--wavlist", type=str, default="wavlist.txt", help="File to wav path list"
     )
     parser.add_argument(
-        "--max_dur",
-        type=int,
-        default=10,
-        help="max audio len"
+        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
     )
-
+    parser.add_argument("--beam_size", type=int, default=3, help="")
+    parser.add_argument("--nbest", type=int, default=1, help="")
+    parser.add_argument("--decode_max_len", type=int, default=128, help="max token len")
+    parser.add_argument("--max_dur", type=int, default=10, help="max audio len")
+
     return parser.parse_args()
-
-
+
+
 def parse_wavlist(wavlist: str):
     wavpaths = []
     with open(wavlist) as f:
@@ -93,24 +61,24 @@ def parse_wavlist(wavlist: str):
             print(f"{line} doesn't exist.")
             continue
         wavpaths.append(line)
-
+
     return wavpaths
-
+
 
 def main():
     args = parse_args()
     print(args)
-
-    model = FireRedASRAxModel(args.encoder,
-                              args.decoder_main,
-                              args.decoder_loop,
-                              args.cmvn,
-                              args.dict,
-                              args.spm_model,
-                              decode_max_len=args.decode_max_len,
-                              audio_dur=args.max_dur
-                              )
-
+
+    model = FireRedASRAxModel(
+        args.encoder,
+        args.decoder_loop,
+        args.cmvn,
+        args.dict,
+        args.spm_model,
+        decode_max_len=args.decode_max_len,
+        audio_dur=args.max_dur,
+    )
+
     wf = open(args.hypo, "wt")
     wavlist = parse_wavlist(args.wavlist)
 
@@ -118,9 +86,10 @@ def main():
     total_transcribe_durations = 0
     for wav in wavlist:
         batch_wav = [wav]
-        results, wav_durations, transcribe_durations = model.transcribe(
-            batch_wav, args.beam_size, args.nbest)
-
+        result, wav_durations, transcribe_durations = model.transcribe(
+            batch_wav, args.beam_size, args.nbest
+        )
+
         wav_durations = sum(wav_durations)
         total_wav_durations += wav_durations
         total_transcribe_durations += transcribe_durations
@@ -129,19 +98,19 @@ def main():
         logger.info(f"Transcribe Durations: {transcribe_durations}")
         rtf = transcribe_durations / wav_durations
         logger.info(f"(Real time factor) RTF: {rtf}")
-        for result in results:
-            logger.info(f"wav: {result['wav']}")
-            logger.info(f"text: {result['text']}")
-            logger.info(f"score: {result['score']}")
-            logger.info("")
-            wf.write(f"{result['text']} ({result['wav']})\n")
-
+
+        text = result["text"]
+        logger.info(f"text: {text}")
+        logger.info("")
+        wf.write(f"{text}\n")
+
     logger.info(f"total wav durations: {total_wav_durations}")
     logger.info(f"total transcribe durations: {total_transcribe_durations}")
     avg_ref = total_transcribe_durations / total_wav_durations
     logger.info(f"AVG RTF: {avg_ref}")
-
+
     wf.close()
 
+
 if __name__ == "__main__":
-    main()
+    main()
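As a reading aid for the RTF this script logs: it is decode time divided by audio length, so values below 1.0 mean faster than real time. A two-line check of the arithmetic:

```python
transcribe_durations, wav_durations = 2.5, 10.0  # 2.5 s spent on 10 s of audio
print(transcribe_durations / wav_durations)      # RTF = 0.25, i.e. 4x real time
```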
test_wer.py CHANGED
@@ -10,57 +10,57 @@ def setup_logging():
10
  # 获取脚本所在目录
11
  script_dir = os.path.dirname(os.path.abspath(__file__))
12
  log_file = os.path.join(script_dir, "test_wer.log")
13
-
14
  # 配置日志格式
15
- log_format = '%(asctime)s - %(levelname)s - %(message)s'
16
- date_format = '%Y-%m-%d %H:%M:%S'
17
-
18
  # 创建logger
19
  logger = logging.getLogger()
20
  logger.setLevel(logging.INFO)
21
-
22
  # 清除现有的handler
23
  for handler in logger.handlers[:]:
24
  logger.removeHandler(handler)
25
-
26
  # 创建文件handler
27
- file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
28
  file_handler.setLevel(logging.INFO)
29
  file_formatter = logging.Formatter(log_format, date_format)
30
  file_handler.setFormatter(file_formatter)
31
-
32
  # 创建控制台handler
33
  console_handler = logging.StreamHandler()
34
  console_handler.setLevel(logging.INFO)
35
  console_formatter = logging.Formatter(log_format, date_format)
36
  console_handler.setFormatter(console_formatter)
37
-
38
  # 添加handler到logger
39
  logger.addHandler(file_handler)
40
  logger.addHandler(console_handler)
41
-
42
  return logger
43
 
44
 
45
  class AIShellDataset:
46
- def __init__(self, gt_path: str, voice_dir='wav'):
47
  """
48
  初始化数据集
49
-
50
  Args:
51
  json_path: voice.json文件的路径
52
  """
53
  self.gt_path = gt_path
54
  self.dataset_dir = os.path.dirname(gt_path)
55
  self.voice_dir = os.path.join(self.dataset_dir, voice_dir)
56
-
57
  # 检查必要文件和文件夹是否存在
58
  assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
59
  assert os.path.exists(self.voice_dir), f"文件夹不存在: {self.voice_dir}"
60
-
61
  # 加载数据
62
  self.data = []
63
- with open(gt_path, 'r', encoding='utf-8') as f:
64
  for line in f:
65
  line = line.strip()
66
  audio_path, gt = line.split(" ")
@@ -70,50 +70,50 @@ class AIShellDataset:
70
  # 使用logging而不是print
71
  logger = logging.getLogger()
72
  logger.info(f"加载了 {len(self.data)} 条数据")
73
-
74
  def __iter__(self):
75
  """返回迭代器"""
76
  self.index = 0
77
  return self
78
-
79
  def __next__(self):
80
  """返回下一个数据项"""
81
  if self.index >= len(self.data):
82
  raise StopIteration
83
-
84
  item = self.data[self.index]
85
  audio_path = item["audio_path"]
86
  ground_truth = item["gt"]
87
-
88
  self.index += 1
89
  return audio_path, ground_truth
90
-
91
  def __len__(self):
92
  """返回数据集大小"""
93
  return len(self.data)
94
-
95
 
96
  class CommonVoiceDataset:
97
  """Common Voice数据集解析器"""
98
-
99
  def __init__(self, tsv_path: str):
100
  """
101
  初始化数据集
102
-
103
  Args:
104
  json_path: voice.json文件的路径
105
  """
106
  self.tsv_path = tsv_path
107
  self.dataset_dir = os.path.dirname(tsv_path)
108
  self.voice_dir = os.path.join(self.dataset_dir, "clips")
109
-
110
  # 检查必要文件和文件夹是否存在
111
  assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
112
  assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
113
-
114
  # 加载JSON数据
115
  self.data = []
116
- with open(tsv_path, 'r', encoding='utf-8') as f:
117
  f.readline()
118
  for line in f:
119
  line = line.strip()
@@ -122,107 +122,100 @@ class CommonVoiceDataset:
                 gt = splits[2]
                 audio_path = os.path.join(self.voice_dir, audio_path)
                 self.data.append({"audio_path": audio_path, "gt": gt})
-
         # Use logging instead of print
         logger = logging.getLogger()
         logger.info(f"Loaded {len(self.data)} samples")
-
     def __iter__(self):
         """Return the iterator."""
         self.index = 0
         return self
-
     def __next__(self):
         """Return the next data item."""
         if self.index >= len(self.data):
             raise StopIteration
-
         item = self.data[self.index]
         audio_path = item["audio_path"]
         ground_truth = item["gt"]
-
         self.index += 1
         return audio_path, ground_truth
-
     def __len__(self):
         """Return the dataset size."""
         return len(self.data)


 def get_args():
-    parser = argparse.ArgumentParser(
-        prog="whisper",
-        description="Test WER on dataset"
-    )
-    parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
-    parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
-    parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
-    parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
     parser.add_argument(
-        "--encoder",
-        type=str,
-        default="axmodel/encoder.axmodel",
-        help="Path to onnx encoder"
     )
     parser.add_argument(
-        "--decoder_main",
-        type=str,
-        default="axmodel/decoder_main.axmodel",
-        help="Path to axmodel decoder main"
     )
     parser.add_argument(
-        "--decoder_loop",
-        type=str,
-        default="axmodel/decoder_loop.axmodel",
-        help="Path to axmodel decoder loop"
     )
     parser.add_argument(
-        "--cmvn",
         type=str,
-        default="axmodel/cmvn.ark",
-        help="Path to cmvn"
     )
     parser.add_argument(
-        "--dict",
         type=str,
-        default="axmodel/dict.txt",
-        help="Path to dict"
     )
     parser.add_argument(
-        "--spm_model",
         type=str,
-        default="axmodel/train_bpe1000.model",
-        help="Path to spm model"
     )
     parser.add_argument(
-        "--wavlist",
         type=str,
-        default="wavlist.txt",
-        help="File to wav path list"
     )
     parser.add_argument(
-        "--hypo",
-        type=str,
-        default="hypo_axmodel.txt",
-        help="File of hypos"
     )
     parser.add_argument(
-        "--beam_size",
-        type=int,
-        default=3,
-        help=""
     )
     parser.add_argument(
-        "--nbest",
-        type=int,
-        default=1,
-        help=""
     )
     parser.add_argument(
-        "--max_len",
-        type=int,
-        default=128,
-        help=""
     )
     return parser.parse_args()


@@ -235,42 +228,42 @@ def print_args(args):


 def min_distance(word1: str, word2: str) -> int:
-
     row = len(word1) + 1
     column = len(word2) + 1
-
-    cache = [ [0]*column for i in range(row) ]
-
     for i in range(row):
         for j in range(column):
-
-            if i ==0 and j ==0:
                 cache[i][j] = 0
-            elif i == 0 and j!=0:
                 cache[i][j] = j
-            elif j == 0 and i!=0:
                 cache[i][j] = i
             else:
-                if word1[i-1] == word2[j-1]:
-                    cache[i][j] = cache[i-1][j-1]
                 else:
-                    replace = cache[i-1][j-1] + 1
-                    insert = cache[i][j-1] + 1
-                    remove = cache[i-1][j] + 1
-
                     cache[i][j] = min(replace, insert, remove)
-
-    return cache[row-1][column-1]


 def remove_punctuation(text):
     # Define a regex pattern that matches all punctuation
     # It covers common ASCII punctuation as well as Chinese punctuation
-    pattern = r'[^\w\s]|_'
-
     # Replace every matched punctuation character with an empty string
-    cleaned_text = re.sub(pattern, '', text)
-
     return cleaned_text


@@ -292,16 +285,25 @@ def main():
     max_num = args.max_num

     # Load model
-    model = FireRedASRAxModel(args.encoder,
-                              args.decoder_main,
-                              args.decoder_loop,
-                              args.cmvn,
-                              args.dict,
-                              args.spm_model,
-                              decode_max_len=args.max_len,
-                              audio_dur=10
     )
-

     # Iterate over dataset
     references = []
@@ -313,10 +315,9 @@ def main():
     for n, (audio_path, reference) in enumerate(dataset):
         batch_uttid = [os.path.splitext(os.path.basename(audio_path))[0]]
         batch_wav = [audio_path]
-        results, _, _ = model.transcribe(
-            batch_wav, args.beam_size, args.nbest)

-        hypothesis = results[0]['text']

         hypothesis = remove_punctuation(hypothesis)
         reference = remove_punctuation(reference)
@@ -330,7 +331,7 @@ def main():

         hyp.append(hypothesis)
         references.append(reference)
-
         line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
         wer_file.write(line_content + "\n")
         logger.info(line_content)
@@ -344,5 +345,6 @@ def main():
     wer_file.write(f"Total WER: {total_character_error_rate}%")
     wer_file.close()


 if __name__ == "__main__":
     main()

     # Get the directory containing this script
     script_dir = os.path.dirname(os.path.abspath(__file__))
     log_file = os.path.join(script_dir, "test_wer.log")
+
     # Configure the log format
+    log_format = "%(asctime)s - %(levelname)s - %(message)s"
+    date_format = "%Y-%m-%d %H:%M:%S"
+
     # Create the logger
     logger = logging.getLogger()
     logger.setLevel(logging.INFO)
+
     # Remove any existing handlers
     for handler in logger.handlers[:]:
         logger.removeHandler(handler)
+
     # Create the file handler
+    file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
     file_handler.setLevel(logging.INFO)
     file_formatter = logging.Formatter(log_format, date_format)
     file_handler.setFormatter(file_formatter)
+
     # Create the console handler
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
     console_formatter = logging.Formatter(log_format, date_format)
     console_handler.setFormatter(console_formatter)
+
     # Attach the handlers to the logger
     logger.addHandler(file_handler)
     logger.addHandler(console_handler)
+
     return logger
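With the format strings above, every record goes to both the console and `test_wer.log`. A minimal sketch of the resulting output, assuming the enclosing helper (whose `def` sits above this hunk) is named `setup_logging`:

```python
# Minimal sketch; `setup_logging` is an assumed name for the helper above.
logger = setup_logging()
logger.info("Loaded 100 samples")
# Emitted to the console and appended to test_wer.log as:
# 2025-01-01 12:00:00 - INFO - Loaded 100 samples
```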
 

 class AIShellDataset:
+    def __init__(self, gt_path: str, voice_dir="wav"):
         """
         Initialize the dataset.
+
         Args:
             gt_path: path to the ground-truth transcript file
         """
         self.gt_path = gt_path
         self.dataset_dir = os.path.dirname(gt_path)
         self.voice_dir = os.path.join(self.dataset_dir, voice_dir)
+
         # Check that the required file and directory exist
         assert os.path.exists(gt_path), f"Ground-truth file does not exist: {gt_path}"
         assert os.path.exists(self.voice_dir), f"Directory does not exist: {self.voice_dir}"
+
         # Load the data
         self.data = []
+        with open(gt_path, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 audio_path, gt = line.split(" ")
@@ -70,50 +70,50 @@ class AIShellDataset:
         # Use logging instead of print
         logger = logging.getLogger()
         logger.info(f"Loaded {len(self.data)} samples")
+
     def __iter__(self):
         """Return the iterator."""
         self.index = 0
         return self
+
     def __next__(self):
         """Return the next data item."""
         if self.index >= len(self.data):
             raise StopIteration
+
         item = self.data[self.index]
         audio_path = item["audio_path"]
         ground_truth = item["gt"]
+
         self.index += 1
         return audio_path, ground_truth
+
     def __len__(self):
         """Return the dataset size."""
         return len(self.data)
+

 class CommonVoiceDataset:
     """Parser for the Common Voice dataset."""
+
     def __init__(self, tsv_path: str):
         """
         Initialize the dataset.
+
         Args:
             tsv_path: path to the Common Voice .tsv file
         """
         self.tsv_path = tsv_path
         self.dataset_dir = os.path.dirname(tsv_path)
         self.voice_dir = os.path.join(self.dataset_dir, "clips")
+
         # Check that the required file and directory exist
         assert os.path.exists(tsv_path), f"TSV file does not exist: {tsv_path}"
         assert os.path.exists(self.voice_dir), f"Audio directory does not exist: {self.voice_dir}"
+
         # Load the TSV data
         self.data = []
+        with open(tsv_path, "r", encoding="utf-8") as f:
             f.readline()
             for line in f:
                 line = line.strip()
@@ -122,107 +122,100 @@ class CommonVoiceDataset:
                 gt = splits[2]
                 audio_path = os.path.join(self.voice_dir, audio_path)
                 self.data.append({"audio_path": audio_path, "gt": gt})
+
         # Use logging instead of print
         logger = logging.getLogger()
         logger.info(f"Loaded {len(self.data)} samples")
+
     def __iter__(self):
         """Return the iterator."""
         self.index = 0
         return self
+
     def __next__(self):
         """Return the next data item."""
         if self.index >= len(self.data):
             raise StopIteration
+
         item = self.data[self.index]
         audio_path = item["audio_path"]
         ground_truth = item["gt"]
+
         self.index += 1
         return audio_path, ground_truth
+
     def __len__(self):
         """Return the dataset size."""
         return len(self.data)
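Both dataset classes expose the same `__iter__`/`__next__`/`__len__` protocol, so the evaluation loop can consume either one interchangeably. A minimal usage sketch (both paths are placeholders):

```python
# Minimal usage sketch; the transcript/tsv paths are placeholders.
dataset = AIShellDataset("data/aishell/transcript.txt")
# dataset = CommonVoiceDataset("data/common_voice/test.tsv")
for audio_path, ground_truth in dataset:
    print(audio_path, ground_truth)
print(f"{len(dataset)} utterances")
```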
 
+
 def get_args():
+    parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
     parser.add_argument(
+        "--dataset",
+        "-d",
+        type=str,
+        required=True,
+        choices=["aishell", "common_voice"],
+        help="Test dataset",
     )
     parser.add_argument(
+        "--gt_path",
+        "-g",
+        type=str,
+        required=True,
+        help="Test dataset ground truth file",
     )
     parser.add_argument(
+        "--max_num", type=int, default=-1, required=False, help="Maximum test data num"
     )
     parser.add_argument(
+        "--language",
+        "-l",
         type=str,
+        required=False,
+        default="zh",
+        help="Target language, support en, zh, ja, and others. See languages.py for more options.",
     )
     parser.add_argument(
+        "--encoder",
         type=str,
+        default="axmodel/encoder.axmodel",
+        help="Path to onnx encoder",
     )
     parser.add_argument(
+        "--decoder_main",
         type=str,
+        default="axmodel/decoder_main.axmodel",
+        help="Path to axmodel decoder main",
     )
     parser.add_argument(
+        "--decoder_loop",
         type=str,
+        default="axmodel/decoder_loop.axmodel",
+        help="Path to axmodel decoder loop",
     )
     parser.add_argument(
+        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
     )
     parser.add_argument(
+        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
     )
     parser.add_argument(
+        "--spm_model",
+        type=str,
+        default="axmodel/train_bpe1000.model",
+        help="Path to spm model",
+    )
+    parser.add_argument(
+        "--wavlist", type=str, default="wavlist.txt", help="File to wav path list"
     )
     parser.add_argument(
+        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
     )
+    parser.add_argument("--beam_size", type=int, default=3, help="")
+    parser.add_argument("--nbest", type=int, default=1, help="")
+    parser.add_argument("--max_len", type=int, default=128, help="")
     return parser.parse_args()

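Only `--dataset` and `--gt_path` are required; the model paths fall back to the `axmodel/` defaults above. A typical invocation might look like this (the ground-truth path is a placeholder):

```
python test_wer.py --dataset aishell --gt_path data/aishell/transcript.txt --max_num 100
```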
 
 
@@ -235,42 +228,42 @@ def print_args(args):


 def min_distance(word1: str, word2: str) -> int:
+
     row = len(word1) + 1
     column = len(word2) + 1
+
+    cache = [[0] * column for i in range(row)]
+
     for i in range(row):
         for j in range(column):
+
+            if i == 0 and j == 0:
                 cache[i][j] = 0
+            elif i == 0 and j != 0:
                 cache[i][j] = j
+            elif j == 0 and i != 0:
                 cache[i][j] = i
             else:
+                if word1[i - 1] == word2[j - 1]:
+                    cache[i][j] = cache[i - 1][j - 1]
                 else:
+                    replace = cache[i - 1][j - 1] + 1
+                    insert = cache[i][j - 1] + 1
+                    remove = cache[i - 1][j] + 1
+
                     cache[i][j] = min(replace, insert, remove)
+
+    return cache[row - 1][column - 1]


 def remove_punctuation(text):
     # Define a regex pattern that matches all punctuation
     # It covers common ASCII punctuation as well as Chinese punctuation
+    pattern = r"[^\w\s]|_"
+
     # Replace every matched punctuation character with an empty string
+    cleaned_text = re.sub(pattern, "", text)
+
     return cleaned_text

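`min_distance` is the standard dynamic-programming Levenshtein (edit) distance, and `remove_punctuation` strips punctuation from both strings before scoring. A worked example, assuming the per-utterance rate in `main()` is computed as `distance / len(reference)`:

```python
# Worked example; assumes the per-utterance metric is distance / reference length.
ref = remove_punctuation("今天天气不错。")  # -> "今天天气不错"
hyp = remove_punctuation("今天天汽不错")    # -> "今天天汽不错"
d = min_distance(ref, hyp)   # one substitution, so d == 1
cer = d / len(ref) * 100     # 1 / 6 * 100 ≈ 16.67
```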
 
 
@@ -292,16 +285,25 @@ def main():
     max_num = args.max_num

     # Load model
+    model = FireRedASRAxModel(
+        args.encoder,
+        args.decoder_loop,
+        args.cmvn,
+        args.dict,
+        args.spm_model,
+        decode_max_len=args.max_len,
+        audio_dur=10,
     )
+    # model = FireRedASROnnxModel(
+    #     args.encoder,
+    #     args.decoder,
+    #     args.cmvn,
+    #     args.dict,
+    #     args.spm_model,
+    #     decode_max_len=args.max_len,
+    #     audio_dur=10
+    # )
+    # model = FireRedAsr.from_pretrained("aed", "model_convert/pretrained_models/FireRedASR-AED-L")

     # Iterate over dataset
     references = []
@@ -313,10 +315,9 @@
     for n, (audio_path, reference) in enumerate(dataset):
         batch_uttid = [os.path.splitext(os.path.basename(audio_path))[0]]
         batch_wav = [audio_path]
+        results, _, _ = model.transcribe(batch_wav, args.beam_size, args.nbest)

+        hypothesis = results["text"]

         hypothesis = remove_punctuation(hypothesis)
         reference = remove_punctuation(reference)
@@ -330,7 +331,7 @@

         hyp.append(hypothesis)
         references.append(reference)
+
         line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
         wer_file.write(line_content + "\n")
         logger.info(line_content)
@@ -344,5 +345,6 @@
     wer_file.write(f"Total WER: {total_character_error_rate}%")
     wer_file.close()

+
 if __name__ == "__main__":
     main()
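The `Total WER` figure is written once at the end of the run; the aggregation itself happens outside the hunks shown here. The usual convention is total edits over total reference length, rather than a mean of per-utterance rates; roughly:

```python
# Sketch only; the actual aggregation code lies outside these hunks.
total_edits = sum(min_distance(r, h) for r, h in zip(references, hyp))
total_chars = sum(len(r) for r in references)
total_character_error_rate = round(total_edits / total_chars * 100, 2)
```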