inoryQwQ commited on
Commit
c6f1198
·
1 Parent(s): c215883

Update model, no need for decoder_main

Browse files
axmodel/decoder_loop.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6b109e1281135a673b613c1fd92f5d12d64e02d1f3da47561c142bbc57295d5d
3
- size 446759232
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c9b3e351557d20846f50d819e18c59d6f10a8adfc40322e5e3034b404b3e038
3
+ size 435136795
axmodel/decoder_loop_u8.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c34f5617f86ad6759bcef16df3b8c2be74660e33b05f1447c52d6c6cf3dcc1e1
3
- size 447207512
 
 
 
 
axmodel/decoder_main.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dc95af79976bd25aa2b13fe62d99ff5e9b03a3d9ce1ea26bfc8b7c7502a4b9b0
3
- size 506408654
 
 
 
 
axmodel/decoder_main_u8.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ebf1de8db552335580fba7e83d2d89e9479518a99bdc7728b04b6975b3eb2b88
3
- size 511355470
 
 
 
 
axmodel/encoder.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1ceef85b578ecffec2e6eaee4dc27987c0e342f109b14a375376935121c5a2c
3
- size 851312087
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cc674ba54cf0e57f3c7dffa3824cd53700e4e7709827893f8708c4958e116c1
3
+ size 851656147
fireredasr_axmodel.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fireredasr.data.asr_feat import ASRFeatExtractor
2
+ from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
+
4
+ import axengine as axe
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from torch import Tensor
9
+ from typing import Tuple, List, Dict
10
+ import os
11
+ import time
12
+
13
+ INF = 1e10
14
+
15
+ def to_numpy(tensor):
16
+ if isinstance(tensor, np.ndarray):
17
+ return tensor
18
+ if tensor.requires_grad:
19
+ return tensor.detach().cpu().numpy()
20
+ else:
21
+ return tensor.cpu().numpy()
22
+
23
+
24
+ def set_finished_beam_score_to_zero(scores, is_finished):
25
+ NB, B = scores.size()
26
+ is_finished = is_finished.float()
27
+ mask_score = torch.tensor([0.0] + [-INF]*(B-1)).float()
28
+ mask_score = mask_score.view(1, B).repeat(NB, 1)
29
+ return scores * (1 - is_finished) + mask_score * is_finished
30
+
31
+
32
+ def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
33
+ is_finished = is_finished.long()
34
+ return ys * (1 - is_finished) + eos_id * is_finished
35
+
36
+
37
+ class FireRedASRAxModel:
38
+ def __init__(
39
+ self,
40
+ encoder_path: str,
41
+ decoder_loop_path: str,
42
+ cmvn_file: str,
43
+ dict_file: str,
44
+ spm_model_path: str,
45
+ providers=['AxEngineExecutionProvider'],
46
+ decode_max_len=128,
47
+ audio_dur=10
48
+ ):
49
+ # NOTE: 参考whisper设置的最大的解码长度
50
+ # FireRedASR-AED 模型支持的最长语音为 60s
51
+ # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
52
+ self.decode_max_len = decode_max_len
53
+
54
+ self.decoder_hidden_dim = 1280
55
+ self.audio_dur = audio_dur
56
+ self.max_feat_len = self.calc_feat_len(audio_dur)
57
+ self.num_decoder_blocks = 16
58
+ self.blank_id = 0
59
+ self.sos_id = 3
60
+ self.eos_id = 4
61
+ self.pad_id = 2
62
+
63
+ self.feature_extractor = ASRFeatExtractor(cmvn_file)
64
+ self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
65
+
66
+ self.init_encoder(encoder_path, providers)
67
+ self.init_decoder_loop(decoder_loop_path, providers)
68
+ self.pe = self.init_pe(decoder_loop_path)
69
+
70
+ def init_encoder(self, encoder_path, providers=None):
71
+ self.encoder = axe.InferenceSession(
72
+ encoder_path,
73
+ providers=providers
74
+ )
75
+
76
+ def init_decoder_loop(self, decoder_path, providers=None):
77
+ self.decoder_loop = axe.InferenceSession(
78
+ decoder_path,
79
+ providers=providers
80
+ )
81
+
82
+ def init_pe(self, decoder_path):
83
+ decoder_path = os.path.dirname(decoder_path)
84
+ decoder_path = os.path.join(decoder_path, "pe.npy")
85
+
86
+ return np.load(decoder_path)
87
+
88
+ def run_encoder(self, input: np.ndarray,
89
+ input_length: np.ndarray
90
+ ) -> Tuple[Tensor, Tensor, Tensor]:
91
+ n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
92
+ None,
93
+ {
94
+ "encoder_input": input,
95
+ "encoder_input_lengths": input_length
96
+ }
97
+ )
98
+ return (
99
+ n_layer_cross_k,
100
+ n_layer_cross_v,
101
+ cross_attn_mask
102
+ )
103
+
104
+ def decode_loop_one_token(
105
+ self,
106
+ tokens: np.ndarray,
107
+ n_layer_self_k_cache: np.ndarray,
108
+ n_layer_self_v_cache: np.ndarray,
109
+ n_layer_cross_k_cache: np.ndarray,
110
+ n_layer_cross_v_cache: np.ndarray,
111
+ pe: np.ndarray,
112
+ self_attn_mask: np.ndarray,
113
+ cross_attn_mask: np.ndarray
114
+ ) -> Tuple[Tensor, Tensor, Tensor]:
115
+ logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
116
+ None,
117
+ {
118
+ "tokens": tokens,
119
+ "in_n_layer_self_k_cache": n_layer_self_k_cache,
120
+ "in_n_layer_self_v_cache": n_layer_self_v_cache,
121
+ "n_layer_cross_k": n_layer_cross_k_cache,
122
+ "n_layer_cross_v": n_layer_cross_v_cache,
123
+ "pe": pe,
124
+ "self_attn_mask": self_attn_mask,
125
+ "cross_attn_mask": cross_attn_mask,
126
+ }
127
+ )
128
+ return (
129
+ logits,
130
+ out_n_layer_self_k_cache,
131
+ out_n_layer_self_v_cache
132
+ )
133
+
134
+ def run_decoder(
135
+ self,
136
+ n_layer_cross_k,
137
+ n_layer_cross_v,
138
+ cross_attn_mask,
139
+ beam_size,
140
+ nbest
141
+ ):
142
+
143
+ num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
144
+ encoder_out_length = cross_attn_mask.shape[-1]
145
+
146
+ cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
147
+ cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
148
+ 1, beam_size, 1, 1
149
+ ).view(beam_size * batch_size, -1, encoder_out_length)
150
+
151
+ n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
152
+ n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
153
+ n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
154
+ 1, 1, beam_size, 1, 1
155
+ ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
156
+ n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
157
+ 1, 1, beam_size, 1, 1
158
+ ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
159
+
160
+ prediction_tokens = torch.ones(
161
+ beam_size * batch_size, 1).fill_(self.sos_id).long()
162
+ tokens = prediction_tokens
163
+ offset = torch.zeros(1, dtype=torch.int64)
164
+ n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
165
+ batch_size, beam_size
166
+ )
167
+
168
+ scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
169
+ scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
170
+ is_finished = torch.zeros_like(scores)
171
+
172
+ self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)
173
+
174
+ for i in range(self.decode_max_len):
175
+
176
+ tokens = to_numpy(tokens).astype(np.int32)
177
+ n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
178
+ n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
179
+ n_layer_cross_k = to_numpy(n_layer_cross_k)
180
+ n_layer_cross_v = to_numpy(n_layer_cross_v)
181
+ cross_attn_mask = to_numpy(cross_attn_mask)
182
+
183
+ self_attn_mask = np.zeros((batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32)
184
+ self_attn_mask[:, :, :self.decode_max_len - offset[0] - 1] = -np.inf
185
+
186
+ logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
187
+ to_numpy(tokens),
188
+ to_numpy(n_layer_self_k_cache),
189
+ to_numpy(n_layer_self_v_cache),
190
+ to_numpy(n_layer_cross_k),
191
+ to_numpy(n_layer_cross_v),
192
+ self.pe[offset],
193
+ self_attn_mask,
194
+ to_numpy(cross_attn_mask)
195
+ )
196
+
197
+ offset += 1
198
+ logits = torch.from_numpy(logits)
199
+
200
+ logits = logits.squeeze(1)
201
+ t_scores = F.log_softmax(logits, dim=-1)
202
+ t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
203
+ t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
204
+ t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)
205
+
206
+ scores = scores + t_topB_scores
207
+
208
+ scores = scores.view(batch_size, beam_size * beam_size)
209
+ scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
210
+ scores = scores.view(-1, 1)
211
+
212
+ topB_row_number_in_each_B_rows_of_ys = torch.div(
213
+ topB_score_ids, beam_size).view(batch_size * beam_size)
214
+ stride = beam_size * torch.arange(batch_size).view(
215
+ batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
216
+ topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
217
+
218
+ prediction_tokens = prediction_tokens[topB_row_number_in_ys]
219
+ t_ys = torch.gather(
220
+ t_topB_ys.view(batch_size, beam_size * beam_size),
221
+ dim=1, index=topB_score_ids
222
+ ).view(beam_size * batch_size, 1)
223
+
224
+ tokens = t_ys
225
+
226
+ prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
227
+
228
+ n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
229
+ n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
230
+
231
+ for i, self_k_cache in enumerate(n_layer_self_k_cache):
232
+ n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
233
+
234
+ for i, self_v_cache in enumerate(n_layer_self_v_cache):
235
+ n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
236
+
237
+ is_finished = t_ys.eq(self.eos_id)
238
+ if is_finished.sum().item() == beam_size * batch_size:
239
+ break
240
+
241
+ scores = scores.view(batch_size, beam_size)
242
+ prediction_valid_token_lengths = torch.sum(
243
+ torch.ne(
244
+ prediction_tokens.view(batch_size, beam_size, -1),
245
+ self.eos_id),
246
+ dim=-1
247
+ ).int()
248
+
249
+ nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
250
+ index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
251
+ nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
252
+ nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
253
+ nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
254
+ batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
255
+ nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
256
+ for i in range(batch_size):
257
+ i_best_hyps: List[Dict[str, torch.Tensor]] = []
258
+ for j, score in enumerate(nbest_scores[i]):
259
+ hyp = {
260
+ "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
261
+ "score": score
262
+ }
263
+ i_best_hyps.append(hyp)
264
+ nbest_hyps.append(i_best_hyps)
265
+
266
+ return nbest_hyps
267
+
268
+ def get_initialized_self_cache(self,
269
+ batch_size,
270
+ beam_size
271
+ ) -> Tuple[Tensor, Tensor]:
272
+ n_layer_self_k_cache = torch.zeros(
273
+ self.num_decoder_blocks,
274
+ batch_size * beam_size,
275
+ self.decode_max_len,
276
+ self.decoder_hidden_dim,
277
+ )
278
+ n_layer_self_v_cache = torch.zeros(
279
+ self.num_decoder_blocks,
280
+ batch_size * beam_size,
281
+ self.decode_max_len,
282
+ self.decoder_hidden_dim,
283
+ )
284
+ return n_layer_self_k_cache, n_layer_self_v_cache
285
+
286
+ def calc_feat_len(self, audio_dur):
287
+ import math
288
+ sample_rate = 16000
289
+ frame_length = 25 * sample_rate / 1000
290
+ frame_shift = 10 * sample_rate / 1000
291
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
292
+ return length
293
+
294
+ def transcribe(self,
295
+ batch_wav_path: List[str],
296
+ beam_size: int = 1,
297
+ nbest: int = 1
298
+ ) -> List[Dict]:
299
+ feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
300
+ # print(f"feats.shape: {feats.shape}")
301
+ if feats.shape[1] < self.max_feat_len:
302
+ feats = np.concatenate([feats, np.zeros((1, self.max_feat_len - feats.shape[1], 80), dtype=np.float32)], axis=1)
303
+ feats = feats[:, :self.max_feat_len, :]
304
+ lengths = torch.minimum(lengths, torch.tensor(self.max_feat_len))
305
+
306
+ feats = to_numpy(feats)
307
+ lengths = to_numpy(lengths).astype(np.int32)
308
+
309
+ start_time = time.time()
310
+ n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
311
+ to_numpy(feats),
312
+ to_numpy(lengths)
313
+ )
314
+ # print(f"run encoder take {(time.time() - start_time) * 1000}ms")
315
+ nbest_hyps = self.run_decoder(n_layer_cross_k,
316
+ n_layer_cross_v,
317
+ cross_attn_mask,
318
+ beam_size,
319
+ nbest,
320
+ )
321
+ transcribe_durations = time.time() - start_time
322
+ results: List[Dict] = []
323
+ for wav, hyp in zip(batch_wav_path, nbest_hyps):
324
+ hyp = hyp[0]
325
+ hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
326
+ score = hyp["score"].item()
327
+ text = self.tokenizer.detokenize(hyp_ids)
328
+ results.append(
329
+ {
330
+ "wav": wav,
331
+ "text": text,
332
+ "score": score
333
+ }
334
+ )
335
+
336
+ return results, wav_durations, transcribe_durations
test_ax_model.py CHANGED
@@ -1,546 +1,30 @@
1
- from fireredasr.data.asr_feat import ASRFeatExtractor
2
- from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
-
4
- import axengine as axe
5
- import torch
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from torch import Tensor
9
- from typing import Tuple, List, Dict
10
  import argparse
11
  import os
12
  import time
13
  import logging
14
 
 
 
15
  logger = logging.getLogger()
16
  logger.setLevel(logging.INFO)
17
  logger_stream_hander = logging.StreamHandler()
18
  logger_stream_hander.setLevel("INFO")
19
  logger.addHandler(logger_stream_hander)
20
 
21
-
22
- INF = 1e10
23
-
24
-
25
- def to_numpy(tensor):
26
- if isinstance(tensor, np.ndarray):
27
- return tensor
28
- if tensor.requires_grad:
29
- return tensor.detach().cpu().numpy()
30
- else:
31
- return tensor.cpu().numpy()
32
-
33
-
34
- def set_finished_beam_score_to_zero(scores, is_finished):
35
- NB, B = scores.size()
36
- is_finished = is_finished.float()
37
- mask_score = torch.tensor([0.0] + [-INF]*(B-1)).float()
38
- mask_score = mask_score.view(1, B).repeat(NB, 1)
39
- return scores * (1 - is_finished) + mask_score * is_finished
40
-
41
-
42
- def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
43
- is_finished = is_finished.long()
44
- return ys * (1 - is_finished) + eos_id * is_finished
45
-
46
-
47
- class FireRedASROnnxModel:
48
- def __init__(
49
- self,
50
- encoder_path: str,
51
- decoder_path: str,
52
- cmvn_file: str,
53
- dict_file: str,
54
- spm_model_path: str,
55
- providers=['AXCLRTExecutionProvider', 'AxEngineExecutionProvider'],
56
- decode_max_len=128
57
- ):
58
- # NOTE: 参考whisper设置的最大的解码长度
59
- # FireRedASR-AED 模型支持的最长语音为 60s
60
- # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
61
- self.decode_max_len = decode_max_len
62
-
63
- self.decoder_hidden_dim = 1280
64
- self.num_decoder_blocks = 16
65
- self.blank_id = 0
66
- self.sos_id = 3
67
- self.eos_id = 4
68
- self.pad_id = 2
69
-
70
- self.feature_extractor = ASRFeatExtractor(cmvn_file)
71
- self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
72
- self.encoder = None
73
- self.decoder = None
74
-
75
- self.init_encoder(encoder_path, providers)
76
- self.init_decoder_main(decoder_path, providers)
77
- self.init_decoder_loop(decoder_path, providers)
78
- self.pe = self.init_pe(decoder_path)
79
-
80
- def init_encoder(self, encoder_path, providers=None):
81
- start_time = time.time()
82
- self.encoder = axe.InferenceSession(
83
- encoder_path,
84
- # sess_options=self.session_opts,
85
- providers=providers
86
- )
87
- end_time = time.time()
88
- logger.info(f"load encoder cost {end_time - start_time} seconds")
89
-
90
- def init_decoder_main(self, decoder_path, providers=None):
91
- decoder_path = os.path.dirname(decoder_path)
92
- decoder_path = os.path.join(decoder_path, "decoder_main.axmodel")
93
- start_time = time.time()
94
- self.decoder_main = axe.InferenceSession(
95
- decoder_path,
96
- # sess_options=self.session_opts,
97
- providers=providers
98
- )
99
- end_time = time.time()
100
- logger.info(f"load decoder_main cost {end_time - start_time} seconds")
101
-
102
- # input_names = [i.name for i in self.decoder_main.get_inputs()]
103
- # print(f"decoder_main.input_names: {input_names}")
104
-
105
- def init_decoder_loop(self, decoder_path, providers=None):
106
- decoder_path = os.path.dirname(decoder_path)
107
- decoder_path = os.path.join(decoder_path, "decoder_loop.axmodel")
108
-
109
- start_time = time.time()
110
- self.decoder_loop = axe.InferenceSession(
111
- decoder_path,
112
- # sess_options=self.session_opts,
113
- providers=providers
114
- )
115
- end_time = time.time()
116
- logger.info(f"load decoder_loop cost {end_time - start_time} seconds")
117
-
118
- # input_names = [i.name for i in self.decoder_loop.get_inputs()]
119
- # print(f"decoder_loop.input_names: {input_names}")
120
-
121
- def init_pe(self, decoder_path):
122
- decoder_path = os.path.dirname(decoder_path)
123
- decoder_path = os.path.join(decoder_path, "pe.npy")
124
-
125
- return np.load(decoder_path)
126
-
127
- def run_encoder(self, input: np.ndarray,
128
- input_length: np.ndarray
129
- ) -> Tuple[Tensor, Tensor, Tensor]:
130
- n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
131
- None,
132
- {
133
- "encoder_input": input,
134
- "encoder_input_lengths": input_length
135
- }
136
- )
137
- # n_layer_cross_k, n_layer_cross_v, cross_attn_mask = \
138
- # outputs["n_layer_cross_k"], outputs["n_layer_cross_v"], outputs["cross_attn_mask"]
139
- return (
140
- n_layer_cross_k,
141
- n_layer_cross_v,
142
- cross_attn_mask
143
- )
144
-
145
- def decode_one_token(
146
- self,
147
- tokens: np.ndarray,
148
- n_layer_self_k_cache: np.ndarray,
149
- n_layer_self_v_cache: np.ndarray,
150
- n_layer_cross_k_cache: np.ndarray,
151
- n_layer_cross_v_cache: np.ndarray,
152
- offset: np.ndarray,
153
- self_attn_mask: np.ndarray,
154
- cross_attn_mask: np.ndarray
155
- ) -> Tuple[Tensor, Tensor, Tensor]:
156
- print("decode:")
157
- print(f"tokens.shape: {tokens.shape}")
158
- print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
159
- print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
160
- print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
161
- print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
162
- print(f"offset.shape: {offset.shape}")
163
- print(f"self_attn_mask.shape: {self_attn_mask.shape}")
164
- print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
165
- # print(f"self_attn_mask: {self_attn_mask}")
166
-
167
- logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
168
- None,
169
- {
170
- self.decoder.get_inputs()[0].name: tokens,
171
- self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
172
- self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
173
- self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
174
- self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
175
- self.decoder.get_inputs()[5].name: offset,
176
- self.decoder.get_inputs()[6].name: self_attn_mask,
177
- self.decoder.get_inputs()[7].name: cross_attn_mask,
178
- }
179
- )
180
- return (
181
- logits,
182
- out_n_layer_self_k_cache,
183
- out_n_layer_self_v_cache
184
- )
185
-
186
- def decode_main_one_token(
187
- self,
188
- tokens: np.ndarray,
189
- n_layer_self_k_cache: np.ndarray,
190
- n_layer_self_v_cache: np.ndarray,
191
- n_layer_cross_k_cache: np.ndarray,
192
- n_layer_cross_v_cache: np.ndarray,
193
- pe: np.ndarray,
194
- self_attn_mask: np.ndarray,
195
- cross_attn_mask: np.ndarray
196
- ) -> Tuple[Tensor, Tensor, Tensor]:
197
- # print("decode_main:")
198
- # print(f"tokens.shape: {tokens.shape}")
199
- # print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
200
- # print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
201
- # print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
202
- # print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
203
- # print(f"pe.shape: {pe.shape}")
204
- # print(f"self_attn_mask.shape: {self_attn_mask.shape}")
205
- # print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
206
-
207
- logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_main.run(
208
- None,
209
- {
210
- "tokens": tokens,
211
- # self.decoder_main.get_inputs()[1].name: n_layer_self_k_cache,
212
- "n_layer_cross_k": n_layer_cross_k_cache,
213
- "n_layer_cross_v": n_layer_cross_v_cache,
214
- # "pe": pe,
215
- # "self_attn_mask": self_attn_mask,
216
- "cross_attn_mask": cross_attn_mask,
217
- # self.decoder_main.get_inputs()[7].name: cross_attn_mask,
218
- }
219
- )
220
- # logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = \
221
- # outputs["logits"], outputs["out_n_layer_self_k_cache"], outputs["out_n_layer_self_v_cache"]
222
- return (
223
- logits,
224
- out_n_layer_self_k_cache,
225
- out_n_layer_self_v_cache
226
- )
227
-
228
- def decode_loop_one_token(
229
- self,
230
- tokens: np.ndarray,
231
- n_layer_self_k_cache: np.ndarray,
232
- n_layer_self_v_cache: np.ndarray,
233
- n_layer_cross_k_cache: np.ndarray,
234
- n_layer_cross_v_cache: np.ndarray,
235
- pe: np.ndarray,
236
- self_attn_mask: np.ndarray,
237
- cross_attn_mask: np.ndarray
238
- ) -> Tuple[Tensor, Tensor, Tensor]:
239
- # print("decode_loop:")
240
- # print(f"tokens.shape: {tokens.shape}")
241
- # print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
242
- # print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
243
- # print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
244
- # print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
245
- # print(f"pe.shape: {pe.shape}")
246
- # print(f"self_attn_mask.shape: {self_attn_mask.shape}")
247
- # print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")
248
-
249
- logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
250
- None,
251
- {
252
- "tokens": tokens,
253
- "in_n_layer_self_k_cache": n_layer_self_k_cache,
254
- "in_n_layer_self_v_cache": n_layer_self_v_cache,
255
- "n_layer_cross_k": n_layer_cross_k_cache,
256
- "n_layer_cross_v": n_layer_cross_v_cache,
257
- "pe": pe,
258
- "self_attn_mask": self_attn_mask,
259
- "cross_attn_mask": cross_attn_mask,
260
- }
261
- )
262
- # logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = \
263
- # outputs["logits"], outputs["out_n_layer_self_k_cache"], outputs["out_n_layer_self_v_cache"]
264
- return (
265
- logits,
266
- out_n_layer_self_k_cache,
267
- out_n_layer_self_v_cache
268
- )
269
-
270
- def run_decoder(
271
- self,
272
- n_layer_cross_k,
273
- n_layer_cross_v,
274
- cross_attn_mask,
275
- beam_size,
276
- nbest
277
- ):
278
-
279
- num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
280
- encoder_out_length = cross_attn_mask.shape[-1]
281
-
282
- cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
283
- cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
284
- 1, beam_size, 1, 1
285
- ).view(beam_size * batch_size, -1, encoder_out_length)
286
-
287
- n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
288
- n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
289
- n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
290
- 1, 1, beam_size, 1, 1
291
- ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
292
- n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
293
- 1, 1, beam_size, 1, 1
294
- ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
295
-
296
- prediction_tokens = torch.ones(
297
- beam_size * batch_size, 1).fill_(self.sos_id).long()
298
- tokens = prediction_tokens
299
- offset = torch.zeros(1, dtype=torch.int64)
300
- n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
301
- batch_size, beam_size
302
- )
303
-
304
- scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
305
- scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
306
- is_finished = torch.zeros_like(scores)
307
-
308
- # self_attn_mask = torch.zeros(
309
- # batch_size * beam_size,
310
- # 1, 1
311
- # )
312
- self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)
313
-
314
- results = [self.sos_id]
315
- for i in range(self.decode_max_len):
316
-
317
- # self_attn_mask = torch.empty(
318
- # batch_size * beam_size,
319
- # prediction_tokens.shape[-1], prediction_tokens.shape[-1]
320
- # ).fill_(-np.inf).triu_(1)
321
- # self_attn_mask = self_attn_mask[:, -1:, :]
322
- # self_attn_mask = to_numpy(self_attn_mask)
323
-
324
- # logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_one_token(
325
- # to_numpy(tokens),
326
- # to_numpy(n_layer_self_k_cache),
327
- # to_numpy(n_layer_self_v_cache),
328
- # to_numpy(n_layer_cross_k),
329
- # to_numpy(n_layer_cross_v),
330
- # to_numpy(offset),
331
- # to_numpy(self_attn_mask),
332
- # to_numpy(cross_attn_mask)
333
- # )
334
-
335
- tokens = to_numpy(tokens).astype(np.int32)
336
- n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
337
- n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
338
- n_layer_cross_k = to_numpy(n_layer_cross_k)
339
- n_layer_cross_v = to_numpy(n_layer_cross_v)
340
- cross_attn_mask = to_numpy(cross_attn_mask)
341
-
342
- self_attn_mask = np.zeros((batch_size * beam_size, 1, self.decode_max_len), dtype=np.float32)
343
- self_attn_mask[:, :, :self.decode_max_len - offset[0] - 1] = -np.inf
344
-
345
- # for name, npy in zip(
346
- # ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"],
347
- # [tokens, n_layer_self_k_cache, n_layer_self_v_cache, n_layer_cross_k, n_layer_cross_v, self.pe[offset], self_attn_mask, cross_attn_mask]
348
- # ):
349
- # file_path = os.path.join(decoder_data_path, name)
350
- # os.makedirs(file_path, exist_ok=True)
351
- # np.save(os.path.join(file_path, f"{i}.npy"), npy)
352
-
353
- if i == 0:
354
- start_time = time.time()
355
- logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
356
- to_numpy(tokens),
357
- to_numpy(n_layer_self_k_cache),
358
- to_numpy(n_layer_self_v_cache),
359
- to_numpy(n_layer_cross_k),
360
- to_numpy(n_layer_cross_v),
361
- self.pe[offset],
362
- self_attn_mask,
363
- to_numpy(cross_attn_mask)
364
- )
365
- print(f"run decoder_main take {(time.time() - start_time) * 1000}ms")
366
- else:
367
- start_time = time.time()
368
- logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
369
- to_numpy(tokens),
370
- to_numpy(n_layer_self_k_cache),
371
- to_numpy(n_layer_self_v_cache),
372
- to_numpy(n_layer_cross_k),
373
- to_numpy(n_layer_cross_v),
374
- self.pe[offset],
375
- self_attn_mask,
376
- to_numpy(cross_attn_mask)
377
- )
378
- print(f"run decoder_loop take {(time.time() - start_time) * 1000}ms")
379
-
380
- offset += 1
381
- logits = torch.from_numpy(logits)
382
-
383
- logits = logits.squeeze(1)
384
- t_scores = F.log_softmax(logits, dim=-1)
385
- t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
386
- t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
387
- t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)
388
-
389
- scores = scores + t_topB_scores
390
-
391
- scores = scores.view(batch_size, beam_size * beam_size)
392
- scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
393
- scores = scores.view(-1, 1)
394
-
395
- topB_row_number_in_each_B_rows_of_ys = torch.div(
396
- topB_score_ids, beam_size).view(batch_size * beam_size)
397
- stride = beam_size * torch.arange(batch_size).view(
398
- batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
399
- topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
400
-
401
- prediction_tokens = prediction_tokens[topB_row_number_in_ys]
402
- t_ys = torch.gather(
403
- t_topB_ys.view(batch_size, beam_size * beam_size),
404
- dim=1, index=topB_score_ids
405
- ).view(beam_size * batch_size, 1)
406
-
407
- tokens = t_ys
408
-
409
- prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
410
-
411
- n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
412
- n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
413
-
414
- for i, self_k_cache in enumerate(n_layer_self_k_cache):
415
- n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
416
-
417
- for i, self_v_cache in enumerate(n_layer_self_v_cache):
418
- n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
419
-
420
- is_finished = t_ys.eq(self.eos_id)
421
- if is_finished.sum().item() == beam_size * batch_size:
422
- break
423
-
424
- scores = scores.view(batch_size, beam_size)
425
- prediction_valid_token_lengths = torch.sum(
426
- torch.ne(
427
- prediction_tokens.view(batch_size, beam_size, -1),
428
- self.eos_id),
429
- dim=-1
430
- ).int()
431
-
432
- nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
433
- index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
434
- nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
435
- nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
436
- nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
437
- batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
438
- nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
439
- for i in range(batch_size):
440
- i_best_hyps: List[Dict[str, torch.Tensor]] = []
441
- for j, score in enumerate(nbest_scores[i]):
442
- hyp = {
443
- "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
444
- "score": score
445
- }
446
- i_best_hyps.append(hyp)
447
- nbest_hyps.append(i_best_hyps)
448
-
449
- return nbest_hyps
450
-
451
- def get_initialized_self_cache(self,
452
- batch_size,
453
- beam_size
454
- ) -> Tuple[Tensor, Tensor]:
455
- n_layer_self_k_cache = torch.zeros(
456
- self.num_decoder_blocks,
457
- batch_size * beam_size,
458
- self.decode_max_len,
459
- self.decoder_hidden_dim,
460
- )
461
- n_layer_self_v_cache = torch.zeros(
462
- self.num_decoder_blocks,
463
- batch_size * beam_size,
464
- self.decode_max_len,
465
- self.decoder_hidden_dim,
466
- )
467
- return n_layer_self_k_cache, n_layer_self_v_cache
468
-
469
- def calc_feat_len(self, audio_dur):
470
- import math
471
- sample_rate = 16000
472
- frame_length = 25 * sample_rate / 1000
473
- frame_shift = 10 * sample_rate / 1000
474
- length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
475
- return length
476
-
477
- def transcribe(self,
478
- batch_wav_path: List[str],
479
- beam_size: int = 1,
480
- nbest: int = 1
481
- ) -> List[Dict]:
482
- feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
483
- # print(f"feats.shape: {feats.shape}")
484
- maxlen = self.calc_feat_len(10)
485
- if feats.shape[1] < maxlen:
486
- feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
487
- feats = feats[:, :maxlen, :]
488
-
489
- # encoder_data_path = os.path.join("calib_dataset", "encoder", os.path.basename(batch_wav_path[0]))
490
- # decoder_data_path = os.path.join("calib_dataset", "decoder", os.path.basename(batch_wav_path[0]))
491
- # os.makedirs(encoder_data_path, exist_ok=True)
492
- # os.makedirs(decoder_data_path, exist_ok=True)
493
-
494
- feats = to_numpy(feats)
495
- lengths = to_numpy(lengths).astype(np.int32)
496
-
497
- # for name, npy in zip(["encoder_input", "encoder_input_lengths"], [feats, lengths]):
498
- # file_path = os.path.join(encoder_data_path, name + ".npy")
499
- # np.save(file_path, npy)
500
-
501
- start_time = time.time()
502
- n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
503
- to_numpy(feats),
504
- to_numpy(lengths)
505
- )
506
- print(f"run encoder take {(time.time() - start_time) * 1000}ms")
507
- nbest_hyps = self.run_decoder(n_layer_cross_k,
508
- n_layer_cross_v,
509
- cross_attn_mask,
510
- beam_size,
511
- nbest,
512
- )
513
- transcribe_durations = time.time() - start_time
514
- results: List[Dict] = []
515
- for wav, hyp in zip(batch_wav_path, nbest_hyps):
516
- hyp = hyp[0]
517
- hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
518
- score = hyp["score"].item()
519
- text = self.tokenizer.detokenize(hyp_ids)
520
- results.append(
521
- {
522
- "wav": wav,
523
- "text": text,
524
- "score": score
525
- }
526
- )
527
-
528
- return results, wav_durations, transcribe_durations
529
-
530
 
531
  def parse_args():
532
- parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
533
  parser.add_argument(
534
  "--encoder",
535
  type=str,
536
  default="axmodel/encoder.axmodel",
537
- help="Path to onnx encoder"
538
  )
539
  parser.add_argument(
540
- "--decoder",
541
  type=str,
542
- default="axmodel/decoder_main.axmodel",
543
- help="Path to onnx decoder"
544
  )
545
  parser.add_argument(
546
  "--cmvn",
@@ -585,10 +69,16 @@ def parse_args():
585
  help=""
586
  )
587
  parser.add_argument(
588
- "--max_len",
589
  type=int,
590
  default=128,
591
- help=""
 
 
 
 
 
 
592
  )
593
 
594
  return parser.parse_args()
@@ -611,12 +101,14 @@ def main():
611
  args = parse_args()
612
  print(args)
613
 
614
- onnx_model = FireRedASROnnxModel(args.encoder,
615
- args.decoder,
 
616
  args.cmvn,
617
  args.dict,
618
  args.spm_model,
619
- decode_max_len=args.max_len
 
620
  )
621
 
622
  wf = open(args.hypo, "wt")
@@ -626,7 +118,7 @@ def main():
626
  total_transcribe_durations = 0
627
  for wav in wavlist:
628
  batch_wav = [wav]
629
- results, wav_durations, transcribe_durations = onnx_model.transcribe(
630
  batch_wav, args.beam_size, args.nbest)
631
 
632
  wav_durations = sum(wav_durations)
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import os
3
  import time
4
  import logging
5
 
6
+ from fireredasr_axmodel import FireRedASRAxModel
7
+
8
  logger = logging.getLogger()
9
  logger.setLevel(logging.INFO)
10
  logger_stream_hander = logging.StreamHandler()
11
  logger_stream_hander.setLevel("INFO")
12
  logger.addHandler(logger_stream_hander)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def parse_args():
16
+ parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")
17
  parser.add_argument(
18
  "--encoder",
19
  type=str,
20
  default="axmodel/encoder.axmodel",
21
+ help="Path to axmodel encoder"
22
  )
23
  parser.add_argument(
24
+ "--decoder_loop",
25
  type=str,
26
+ default="axmodel/decoder_loop.axmodel",
27
+ help="Path to axmodel decoder loop"
28
  )
29
  parser.add_argument(
30
  "--cmvn",
 
69
  help=""
70
  )
71
  parser.add_argument(
72
+ "--decode_max_len",
73
  type=int,
74
  default=128,
75
+ help="max token len"
76
+ )
77
+ parser.add_argument(
78
+ "--max_dur",
79
+ type=int,
80
+ default=10,
81
+ help="max audio len"
82
  )
83
 
84
  return parser.parse_args()
 
101
  args = parse_args()
102
  print(args)
103
 
104
+ model = FireRedASRAxModel(args.encoder,
105
+ args.decoder_main,
106
+ args.decoder_loop,
107
  args.cmvn,
108
  args.dict,
109
  args.spm_model,
110
+ decode_max_len=args.decode_max_len,
111
+ audio_dur=args.max_dur
112
  )
113
 
114
  wf = open(args.hypo, "wt")
 
118
  total_transcribe_durations = 0
119
  for wav in wavlist:
120
  batch_wav = [wav]
121
+ results, wav_durations, transcribe_durations = model.transcribe(
122
  batch_wav, args.beam_size, args.nbest)
123
 
124
  wav_durations = sum(wav_durations)
test_wer.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import logging
4
+ import re
5
+ from fireredasr_axmodel import FireRedASRAxModel
6
+
7
+
8
+ def setup_logging():
9
+ """配置日志系统,同时输出到控制台和文件"""
10
+ # 获取脚本所在目录
11
+ script_dir = os.path.dirname(os.path.abspath(__file__))
12
+ log_file = os.path.join(script_dir, "test_wer.log")
13
+
14
+ # 配置日志格式
15
+ log_format = '%(asctime)s - %(levelname)s - %(message)s'
16
+ date_format = '%Y-%m-%d %H:%M:%S'
17
+
18
+ # 创建logger
19
+ logger = logging.getLogger()
20
+ logger.setLevel(logging.INFO)
21
+
22
+ # 清除现有的handler
23
+ for handler in logger.handlers[:]:
24
+ logger.removeHandler(handler)
25
+
26
+ # 创建文件handler
27
+ file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
28
+ file_handler.setLevel(logging.INFO)
29
+ file_formatter = logging.Formatter(log_format, date_format)
30
+ file_handler.setFormatter(file_formatter)
31
+
32
+ # 创建控制台handler
33
+ console_handler = logging.StreamHandler()
34
+ console_handler.setLevel(logging.INFO)
35
+ console_formatter = logging.Formatter(log_format, date_format)
36
+ console_handler.setFormatter(console_formatter)
37
+
38
+ # 添加handler到logger
39
+ logger.addHandler(file_handler)
40
+ logger.addHandler(console_handler)
41
+
42
+ return logger
43
+
44
+
45
+ class AIShellDataset:
46
+ def __init__(self, gt_path: str, voice_dir='wav'):
47
+ """
48
+ 初始化数据集
49
+
50
+ Args:
51
+ json_path: voice.json文件的路径
52
+ """
53
+ self.gt_path = gt_path
54
+ self.dataset_dir = os.path.dirname(gt_path)
55
+ self.voice_dir = os.path.join(self.dataset_dir, voice_dir)
56
+
57
+ # 检查必要文件和文件夹是否存在
58
+ assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
59
+ assert os.path.exists(self.voice_dir), f"文件夹不存在: {self.voice_dir}"
60
+
61
+ # 加载数据
62
+ self.data = []
63
+ with open(gt_path, 'r', encoding='utf-8') as f:
64
+ for line in f:
65
+ line = line.strip()
66
+ audio_path, gt = line.split(" ")
67
+ audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
68
+ self.data.append({"audio_path": audio_path, "gt": gt})
69
+
70
+ # 使用logging而不是print
71
+ logger = logging.getLogger()
72
+ logger.info(f"加载了 {len(self.data)} 条数据")
73
+
74
+ def __iter__(self):
75
+ """返回迭代器"""
76
+ self.index = 0
77
+ return self
78
+
79
+ def __next__(self):
80
+ """返回下一个数据项"""
81
+ if self.index >= len(self.data):
82
+ raise StopIteration
83
+
84
+ item = self.data[self.index]
85
+ audio_path = item["audio_path"]
86
+ ground_truth = item["gt"]
87
+
88
+ self.index += 1
89
+ return audio_path, ground_truth
90
+
91
+ def __len__(self):
92
+ """返回数据集大小"""
93
+ return len(self.data)
94
+
95
+
96
+ class CommonVoiceDataset:
97
+ """Common Voice数据集解析器"""
98
+
99
+ def __init__(self, tsv_path: str):
100
+ """
101
+ 初始化数据集
102
+
103
+ Args:
104
+ json_path: voice.json文件的路径
105
+ """
106
+ self.tsv_path = tsv_path
107
+ self.dataset_dir = os.path.dirname(tsv_path)
108
+ self.voice_dir = os.path.join(self.dataset_dir, "clips")
109
+
110
+ # 检查必要文件和文件夹是否存在
111
+ assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
112
+ assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
113
+
114
+ # 加载JSON数据
115
+ self.data = []
116
+ with open(tsv_path, 'r', encoding='utf-8') as f:
117
+ f.readline()
118
+ for line in f:
119
+ line = line.strip()
120
+ splits = line.split("\t")
121
+ audio_path = splits[1]
122
+ gt = splits[2]
123
+ audio_path = os.path.join(self.voice_dir, audio_path)
124
+ self.data.append({"audio_path": audio_path, "gt": gt})
125
+
126
+ # 使用logging而不是print
127
+ logger = logging.getLogger()
128
+ logger.info(f"加载了 {len(self.data)} 条数据")
129
+
130
+ def __iter__(self):
131
+ """返回迭代器"""
132
+ self.index = 0
133
+ return self
134
+
135
+ def __next__(self):
136
+ """返回下一个数据项"""
137
+ if self.index >= len(self.data):
138
+ raise StopIteration
139
+
140
+ item = self.data[self.index]
141
+ audio_path = item["audio_path"]
142
+ ground_truth = item["gt"]
143
+
144
+ self.index += 1
145
+ return audio_path, ground_truth
146
+
147
+ def __len__(self):
148
+ """返回数据集大小"""
149
+ return len(self.data)
150
+
151
+ def get_args():
152
+ parser = argparse.ArgumentParser(
153
+ prog="whisper",
154
+ description="Test WER on dataset"
155
+ )
156
+ parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
157
+ parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
158
+ parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
159
+ parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
160
+ parser.add_argument(
161
+ "--encoder",
162
+ type=str,
163
+ default="axmodel/encoder.axmodel",
164
+ help="Path to onnx encoder"
165
+ )
166
+ parser.add_argument(
167
+ "--decoder_main",
168
+ type=str,
169
+ default="axmodel/decoder_main.axmodel",
170
+ help="Path to axmodel decoder main"
171
+ )
172
+ parser.add_argument(
173
+ "--decoder_loop",
174
+ type=str,
175
+ default="axmodel/decoder_loop.axmodel",
176
+ help="Path to axmodel decoder loop"
177
+ )
178
+ parser.add_argument(
179
+ "--cmvn",
180
+ type=str,
181
+ default="axmodel/cmvn.ark",
182
+ help="Path to cmvn"
183
+ )
184
+ parser.add_argument(
185
+ "--dict",
186
+ type=str,
187
+ default="axmodel/dict.txt",
188
+ help="Path to dict"
189
+ )
190
+ parser.add_argument(
191
+ "--spm_model",
192
+ type=str,
193
+ default="axmodel/train_bpe1000.model",
194
+ help="Path to spm model"
195
+ )
196
+ parser.add_argument(
197
+ "--wavlist",
198
+ type=str,
199
+ default="wavlist.txt",
200
+ help="File to wav path list"
201
+ )
202
+ parser.add_argument(
203
+ "--hypo",
204
+ type=str,
205
+ default="hypo_axmodel.txt",
206
+ help="File of hypos"
207
+ )
208
+ parser.add_argument(
209
+ "--beam_size",
210
+ type=int,
211
+ default=3,
212
+ help=""
213
+ )
214
+ parser.add_argument(
215
+ "--nbest",
216
+ type=int,
217
+ default=1,
218
+ help=""
219
+ )
220
+ parser.add_argument(
221
+ "--max_len",
222
+ type=int,
223
+ default=128,
224
+ help=""
225
+ )
226
+ return parser.parse_args()
227
+
228
+
229
+ def print_args(args):
230
+ logger = logging.getLogger()
231
+ logger.info(f"dataset: {args.dataset}")
232
+ logger.info(f"gt_path: {args.gt_path}")
233
+ logger.info(f"max_num: {args.max_num}")
234
+ logger.info(f"language: {args.language}")
235
+
236
+
237
+ def min_distance(word1: str, word2: str) -> int:
238
+
239
+ row = len(word1) + 1
240
+ column = len(word2) + 1
241
+
242
+ cache = [ [0]*column for i in range(row) ]
243
+
244
+ for i in range(row):
245
+ for j in range(column):
246
+
247
+ if i ==0 and j ==0:
248
+ cache[i][j] = 0
249
+ elif i == 0 and j!=0:
250
+ cache[i][j] = j
251
+ elif j == 0 and i!=0:
252
+ cache[i][j] = i
253
+ else:
254
+ if word1[i-1] == word2[j-1]:
255
+ cache[i][j] = cache[i-1][j-1]
256
+ else:
257
+ replace = cache[i-1][j-1] + 1
258
+ insert = cache[i][j-1] + 1
259
+ remove = cache[i-1][j] + 1
260
+
261
+ cache[i][j] = min(replace, insert, remove)
262
+
263
+ return cache[row-1][column-1]
264
+
265
+
266
+ def remove_punctuation(text):
267
+ # 定义正则表达式模式,匹配所有标点符号
268
+ # 这个模式包括常见的标点符号和中文标点
269
+ pattern = r'[^\w\s]|_'
270
+
271
+ # 使用sub方法将所有匹配的标点符号替换为空字符串
272
+ cleaned_text = re.sub(pattern, '', text)
273
+
274
+ return cleaned_text
275
+
276
+
277
+ def main():
278
+ # 设置日志系统
279
+ logger = setup_logging()
280
+
281
+ args = get_args()
282
+ print_args(args)
283
+
284
+ dataset_type = args.dataset.lower()
285
+ if dataset_type == "aishell":
286
+ dataset = AIShellDataset(args.gt_path)
287
+ elif dataset_type == "common_voice":
288
+ dataset = CommonVoiceDataset(args.gt_path)
289
+ else:
290
+ raise ValueError(f"Unknown dataset type {dataset_type}")
291
+
292
+ max_num = args.max_num
293
+
294
+ # Load model
295
+ model = FireRedASRAxModel(args.encoder,
296
+ args.decoder_main,
297
+ args.decoder_loop,
298
+ args.cmvn,
299
+ args.dict,
300
+ args.spm_model,
301
+ decode_max_len=args.max_len,
302
+ audio_dur=10
303
+ )
304
+
305
+
306
+ # Iterate over dataset
307
+ references = []
308
+ hyp = []
309
+ all_character_error_num = 0
310
+ all_character_num = 0
311
+ wer_file = open("wer.txt", "w")
312
+ max_data_num = max_num if max_num > 0 else len(dataset)
313
+ for n, (audio_path, reference) in enumerate(dataset):
314
+ batch_uttid = [os.path.splitext(os.path.basename(audio_path))[0]]
315
+ batch_wav = [audio_path]
316
+ results, _, _ = model.transcribe(
317
+ batch_wav, args.beam_size, args.nbest)
318
+
319
+ hypothesis = results[0]['text']
320
+
321
+ hypothesis = remove_punctuation(hypothesis)
322
+ reference = remove_punctuation(reference)
323
+
324
+ character_error_num = min_distance(reference, hypothesis)
325
+ character_num = len(reference)
326
+ character_error_rate = character_error_num / character_num * 100
327
+
328
+ all_character_error_num += character_error_num
329
+ all_character_num += character_num
330
+
331
+ hyp.append(hypothesis)
332
+ references.append(reference)
333
+
334
+ line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
335
+ wer_file.write(line_content + "\n")
336
+ logger.info(line_content)
337
+
338
+ if n + 1 >= max_data_num:
339
+ break
340
+
341
+ total_character_error_rate = all_character_error_num / all_character_num * 100
342
+
343
+ logger.info(f"Total WER: {total_character_error_rate}%")
344
+ wer_file.write(f"Total WER: {total_character_error_rate}%")
345
+ wer_file.close()
346
+
347
+ if __name__ == "__main__":
348
+ main()