yangrongzhao committed on
Commit
f21b604
·
1 Parent(s): 1d01163

Add model convert

Browse files
model_convert/model_wrapper.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+ from fireredasr.models.module.conformer_encoder import ConformerEncoder
6
+ from fireredasr.models.module.transformer_decoder import (
7
+ TransformerDecoder,
8
+ DecoderLayer,
9
+ DecoderMultiHeadAttention,
10
+ DecoderScaledDotProductAttention,
11
+ PositionalEncoding
12
+ )
13
+
14
+
15
+ def DecoderScaledDotProductAttentionForward(
16
+ self: DecoderScaledDotProductAttention,
17
+ q: Tensor,
18
+ k: Tensor,
19
+ v: Tensor,
20
+ mask: Tensor
21
+ ):
22
+ attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
23
+ if mask is not None:
24
+ # mask is such as [[[0, 0, 0, 0, ..., -inf, -inf]]]
25
+ attn = attn + mask
26
+ attn = torch.softmax(attn, dim=-1)
27
+ else:
28
+ attn = torch.softmax(attn, dim=-1)
29
+ output = torch.matmul(attn, v)
30
+ return output
31
+
32
+ DecoderScaledDotProductAttention.forward = DecoderScaledDotProductAttentionForward
33
+
34
+
35
+ """
36
+ The purpose of this is to allow the exported onnx model
37
+ to only need to pass in the token id of the decoding result
38
+ of the previous time step when performing decoding inference at each time step,
39
+ rather than the token id of all previous time steps.
40
+ """
41
+ def PositionalEncodingForward(
42
+ self: PositionalEncoding,
43
+ offset: Tensor
44
+ ):
45
+ return self.pe[:, :offset].clone().detach()[:, -1]
46
+
47
+ PositionalEncoding.forward = PositionalEncodingForward
48
+
49
+
50
+ """
51
+ NOTE(Lianghu): Why do that?
52
+
53
+ When exporting the onnx model using original padding_position_is_0 funciton,
54
+ the dynamic batch does not work properly for the exported onnx model.
55
+
56
+ The code in the original padding_position_is_0 function is as follows:
57
+ ```py
58
+ def padding_position_is_0(...):
59
+ N, T = padded_input.size()[:2]
60
+ mask = torch.ones((N, T)).to(padded_input.device)
61
+ ...
62
+ ```
63
+
64
+ Because when exporting onnx, N and T are considered constants.
65
+ Should be N = padded_input.size(0) and T = padded_input.size(1).
66
+ """
67
+ def padding_position_is_0(self: ConformerEncoder,
68
+ padded_input: Tensor,
69
+ input_lengths: Tensor):
70
+ N = padded_input.size(0)
71
+ T = padded_input.size(1)
72
+ seq_range = torch.arange(T, device=padded_input.device).unsqueeze(0) # shape: (1, T)
73
+ input_lengths_exp = input_lengths.unsqueeze(1) # shape: (N, 1)
74
+ mask = seq_range < input_lengths_exp # shape: (N, T)
75
+ mask = mask.unsqueeze(dim=1)
76
+ return mask.to(torch.uint8)
77
+
78
+
79
+ ConformerEncoder.padding_position_is_0 = padding_position_is_0
80
+
81
+ class AudioEncoderTensorCache(nn.Module):
82
+ def __init__(self,
83
+ encoder: ConformerEncoder,
84
+ decoder: TransformerDecoder):
85
+ super().__init__()
86
+ self.encoder = encoder
87
+ self.decoder = decoder
88
+
89
+ def forward(self, input: Tensor, input_length: Tensor):
90
+ encoder_output, _, encoder_mask = self.encoder(input, input_length)
91
+
92
+ n_layer_cross_k_list = []
93
+ n_layer_cross_v_list = []
94
+
95
+ for layer in self.decoder.layer_stack:
96
+ # layer: DecoderLayer
97
+ n_layer_cross_k_list.append(layer.cross_attn.w_ks(encoder_output))
98
+ n_layer_cross_v_list.append(layer.cross_attn.w_vs(encoder_output))
99
+
100
+ encoder_mask = encoder_mask.to(torch.float32)
101
+ encoder_mask[encoder_mask == 0] = -torch.inf
102
+ encoder_mask[encoder_mask == 1] = 0.0
103
+
104
+ return (torch.stack(n_layer_cross_k_list),
105
+ torch.stack(n_layer_cross_v_list),
106
+ encoder_mask)
107
+
108
+
109
+ class DecoderMultiHeadSelfAttention(nn.Module):
110
+ def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
111
+ super().__init__()
112
+ self.multiHeadSelfAttention = multiHeadSelfAttention
113
+ self.loop = loop
114
+
115
+ def forward(self,
116
+ x: Tensor,
117
+ k_cache: Tensor,
118
+ v_cache: Tensor,
119
+ mask: Tensor):
120
+ bs = x.size(0)
121
+
122
+ # 当前时间步为 t
123
+ # k_cache 和 v_cache 是 时间步 [0: t-1] 的 self_attn_k 和 self_attn_v 的缓存
124
+ q = self.multiHeadSelfAttention.w_qs(x)
125
+ k = self.multiHeadSelfAttention.w_ks(x)
126
+ v = self.multiHeadSelfAttention.w_vs(x)
127
+
128
+ k_cache[:, -k.shape[1] :, :] = k
129
+ v_cache[:, -v.shape[1] :, :] = v
130
+ # if self.loop:
131
+ # k_cache = torch.cat([k_cache[:, 1:, :], k], 1)
132
+ # v_cache = torch.cat([v_cache[:, 1:, :], v], 1)
133
+ # else:
134
+ # k_cache = k
135
+ # v_cache = v
136
+
137
+ q = q.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
138
+ k = k_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
139
+ v = v_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
140
+ k = k.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
141
+ v = v.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
142
+ q = q.transpose(1, 2)
143
+ k = k.transpose(1, 2)
144
+ v = v.transpose(1, 2)
145
+
146
+ if mask is not None:
147
+ mask = mask.unsqueeze(1)
148
+
149
+ output = self.multiHeadSelfAttention.attention(q, k, v, mask)
150
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadSelfAttention.d_model)
151
+ output = self.multiHeadSelfAttention.fc(output)
152
+ output = self.multiHeadSelfAttention.dropout(output)
153
+
154
+ return output, k_cache, v_cache
155
+
156
+
157
+ class DecoderMultiHeadSelfAttentionV2(nn.Module):
158
+ def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
159
+ super().__init__()
160
+ self.multiHeadSelfAttention = multiHeadSelfAttention
161
+ self.loop = loop
162
+
163
+ def forward(self,
164
+ x: Tensor,
165
+ k_cache: Tensor,
166
+ v_cache: Tensor,
167
+ mask: Tensor):
168
+ bs = x.size(0)
169
+
170
+ # 当前时间步为 t
171
+ # k_cache 和 v_cache 是 时间步 [0: t-1] 的 self_attn_k 和 self_attn_v 的缓存
172
+ q = self.multiHeadSelfAttention.w_qs(x)
173
+ k = self.multiHeadSelfAttention.w_ks(x)
174
+ v = self.multiHeadSelfAttention.w_vs(x)
175
+
176
+ # k_cache[:, -k.shape[1] :, :] = k
177
+ # v_cache[:, -v.shape[1] :, :] = v
178
+ if self.loop:
179
+ k_cache = torch.cat([k_cache[:, 1:, :], k], 1)
180
+ v_cache = torch.cat([v_cache[:, 1:, :], v], 1)
181
+ else:
182
+ k_cache = k
183
+ v_cache = v
184
+
185
+ q = q.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
186
+ k = k_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
187
+ v = v_cache.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
188
+ k = k.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
189
+ v = v.view(bs, -1, self.multiHeadSelfAttention.n_head, self.multiHeadSelfAttention.d_k)
190
+ q = q.transpose(1, 2)
191
+ k = k.transpose(1, 2)
192
+ v = v.transpose(1, 2)
193
+
194
+ if mask is not None:
195
+ mask = mask.unsqueeze(1)
196
+
197
+ output = self.multiHeadSelfAttention.attention(q, k, v, mask)
198
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadSelfAttention.d_model)
199
+ output = self.multiHeadSelfAttention.fc(output)
200
+ output = self.multiHeadSelfAttention.dropout(output)
201
+
202
+ return output, k_cache, v_cache
203
+
204
+
205
+ class DecoderMultiHeadCrossAttention(nn.Module):
206
+ def __init__(self, multiHeadCrossAttention: DecoderMultiHeadAttention):
207
+ super().__init__()
208
+ self.multiHeadCrossAttention = multiHeadCrossAttention
209
+
210
+ def forward(self,
211
+ x: Tensor,
212
+ k: Tensor,
213
+ v: Tensor,
214
+ mask: Tensor):
215
+ bs = x.size(0)
216
+ x = self.multiHeadCrossAttention.w_qs(x)
217
+ x = x.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)
218
+ k = k.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)
219
+ v = v.view(bs, -1, self.multiHeadCrossAttention.n_head, self.multiHeadCrossAttention.d_k)
220
+
221
+ x = x.transpose(1, 2)
222
+ k = k.transpose(1, 2)
223
+ v = v.transpose(1, 2)
224
+
225
+ if mask is not None:
226
+ mask = mask.unsqueeze(1)
227
+
228
+ output = self.multiHeadCrossAttention.attention(x, k, v, mask)
229
+ output = output.transpose(1, 2).contiguous().view(bs, -1, self.multiHeadCrossAttention.d_model)
230
+ output = self.multiHeadCrossAttention.fc(output)
231
+ output = self.multiHeadCrossAttention.dropout(output)
232
+
233
+ return output
234
+
235
+
236
+ class ResidualAttentionBlockTensorCache(nn.Module):
237
+ def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
238
+ super().__init__()
239
+ self.original_decoder_layer = decoder_layer
240
+ self.self_attn = DecoderMultiHeadSelfAttention(decoder_layer.self_attn, loop)
241
+ self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)
242
+
243
+ def forward(self,
244
+ x: Tensor,
245
+ self_k_cache: Tensor,
246
+ self_v_cache: Tensor,
247
+ cross_k: Tensor,
248
+ cross_v: Tensor,
249
+ self_attn_mask: Tensor,
250
+ cross_attn_mask: Tensor):
251
+ # q.shape (B, 1, dim)
252
+ x_self_attn_norm = self.original_decoder_layer.self_attn_norm(x)
253
+ self_attn_x, self_k_cache_updated, self_v_cache_updated = self.self_attn(
254
+ x_self_attn_norm, self_k_cache, self_v_cache, self_attn_mask)
255
+
256
+ x = x + self_attn_x
257
+
258
+ residual = x
259
+ x_cross_attn_norm = self.original_decoder_layer.cross_attn_norm(x)
260
+ x_cross_attn = self.cross_attn(x_cross_attn_norm, cross_k, cross_v, cross_attn_mask)
261
+ x = residual + x_cross_attn
262
+
263
+ x = x + self.original_decoder_layer.mlp(self.original_decoder_layer.mlp_norm(x))
264
+
265
+ return x, self_k_cache_updated, self_v_cache_updated
266
+
267
+
268
+ class ResidualAttentionBlockTensorCacheV2(nn.Module):
269
+ def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
270
+ super().__init__()
271
+ self.original_decoder_layer = decoder_layer
272
+ self.self_attn = DecoderMultiHeadSelfAttentionV2(decoder_layer.self_attn, loop)
273
+ self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)
274
+
275
+ def forward(self,
276
+ x: Tensor,
277
+ self_k_cache: Tensor,
278
+ self_v_cache: Tensor,
279
+ cross_k: Tensor,
280
+ cross_v: Tensor,
281
+ self_attn_mask: Tensor,
282
+ cross_attn_mask: Tensor):
283
+ # q.shape (B, 1, dim)
284
+ x_self_attn_norm = self.original_decoder_layer.self_attn_norm(x)
285
+ self_attn_x, self_k_cache_updated, self_v_cache_updated = self.self_attn(
286
+ x_self_attn_norm, self_k_cache, self_v_cache, self_attn_mask)
287
+
288
+ x = x + self_attn_x
289
+
290
+ residual = x
291
+ x_cross_attn_norm = self.original_decoder_layer.cross_attn_norm(x)
292
+ x_cross_attn = self.cross_attn(x_cross_attn_norm, cross_k, cross_v, cross_attn_mask)
293
+ x = residual + x_cross_attn
294
+
295
+ x = x + self.original_decoder_layer.mlp(self.original_decoder_layer.mlp_norm(x))
296
+
297
+ return x, self_k_cache_updated, self_v_cache_updated
298
+
299
+
300
+ class TextDecoderTensorCache(nn.Module):
301
+ def __init__(self, decoder: TransformerDecoder):
302
+ super().__init__()
303
+ self.decoder = decoder
304
+
305
+ self.blocks = []
306
+ for original_layer in self.decoder.layer_stack:
307
+ self.blocks.append(
308
+ ResidualAttentionBlockTensorCache(original_layer))
309
+
310
+ def forward(self,
311
+ tokens: Tensor,
312
+ n_layer_self_k_cache: Tensor,
313
+ n_layer_self_v_cache: Tensor,
314
+ n_layer_cross_k: Tensor,
315
+ n_layer_cross_v: Tensor,
316
+ offset: Tensor,
317
+ self_attn_mask: Tensor,
318
+ cross_attn_mask: Tensor):
319
+ """
320
+ TODO(Lianghu): Integrate self_attn_mask into the model inference process
321
+ instead of passing it in through an external interface.
322
+ """
323
+ x = self.decoder.dropout(
324
+ self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
325
+ self.decoder.positional_encoding(offset + 1)
326
+ )
327
+
328
+ i = 0
329
+ for block in self.blocks:
330
+ self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
331
+ self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
332
+ x, self_k_cache, self_v_cache = block(
333
+ x,
334
+ self_k_cache,
335
+ self_v_cache,
336
+ n_layer_cross_k[i],
337
+ n_layer_cross_v[i],
338
+ self_attn_mask,
339
+ cross_attn_mask
340
+ )
341
+ n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
342
+ n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
343
+ i += 1
344
+
345
+ output = self.decoder.layer_norm_out(x)
346
+ logits = self.decoder.tgt_word_prj(output)
347
+
348
+ return logits, n_layer_self_k_cache, n_layer_self_v_cache
349
+
350
+
351
+ class TextDecoderTensorCacheV2(nn.Module):
352
+ def __init__(self, decoder: TransformerDecoder, loop: bool = False):
353
+ super().__init__()
354
+ self.decoder = decoder
355
+ self.loop = loop
356
+
357
+ self.blocks = []
358
+ for original_layer in self.decoder.layer_stack:
359
+ self.blocks.append(
360
+ ResidualAttentionBlockTensorCacheV2(original_layer, loop))
361
+
362
+ def forward(self,
363
+ tokens: Tensor,
364
+ n_layer_self_k_cache: Tensor,
365
+ n_layer_self_v_cache: Tensor,
366
+ n_layer_cross_k: Tensor,
367
+ n_layer_cross_v: Tensor,
368
+ positional_embedding: Tensor,
369
+ self_attn_mask: Tensor,
370
+ cross_attn_mask: Tensor):
371
+ """
372
+ TODO(Lianghu): Integrate self_attn_mask into the model inference process
373
+ instead of passing it in through an external interface.
374
+ """
375
+ x = self.decoder.dropout(
376
+ self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
377
+ positional_embedding
378
+ )
379
+ # if self.loop:
380
+ # x = self.decoder.dropout(
381
+ # self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
382
+ # positional_embedding
383
+ # )
384
+ # else:
385
+ # x = self.decoder.dropout(
386
+ # self.decoder.tgt_word_emb(tokens) * self.decoder.scale +
387
+ # self.decoder.positional_encoding.pe[:, : tokens.shape[-1]]
388
+ # )
389
+
390
+ i = 0
391
+ self_k_cache_out = []
392
+ self_v_cache_out = []
393
+ for block in self.blocks:
394
+ self_k_cache = n_layer_self_k_cache[i, :, :, :]
395
+ self_v_cache = n_layer_self_v_cache[i, :, :, :]
396
+ if self.loop:
397
+ x, self_k_cache, self_v_cache = block(
398
+ x,
399
+ self_k_cache,
400
+ self_v_cache,
401
+ n_layer_cross_k[i],
402
+ n_layer_cross_v[i],
403
+ self_attn_mask,
404
+ cross_attn_mask
405
+ )
406
+ self_k_cache_out.append(self_k_cache.unsqueeze(0))
407
+ self_v_cache_out.append(self_v_cache.unsqueeze(0))
408
+ else:
409
+ n_audio, n_text_ctx, ntext_state = self_k_cache.shape
410
+
411
+ x, self_k_cache, self_v_cache = block(
412
+ x,
413
+ self_k_cache,
414
+ self_v_cache,
415
+ n_layer_cross_k[i],
416
+ n_layer_cross_v[i],
417
+ self_attn_mask,
418
+ cross_attn_mask
419
+ )
420
+ self_k_cache_out.append(torch.cat((torch.zeros([n_audio, n_text_ctx - self_k_cache.shape[1], ntext_state]).to(self_k_cache.device), self_k_cache), 1).unsqueeze(0))
421
+ self_v_cache_out.append(torch.cat((torch.zeros([n_audio, n_text_ctx - self_v_cache.shape[1], ntext_state]).to(self_v_cache.device), self_v_cache), 1).unsqueeze(0))
422
+
423
+ i += 1
424
+
425
+ n_layer_self_k_cache = torch.cat(self_k_cache_out, 0)
426
+ n_layer_self_v_cache = torch.cat(self_v_cache_out, 0)
427
+
428
+ output = self.decoder.layer_norm_out(x)
429
+ logits = self.decoder.tgt_word_prj(output)
430
+
431
+ return logits, n_layer_self_k_cache, n_layer_self_v_cache
model_convert/to_onnx.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import model_wrapper
2
+ from fireredasr.models.fireredasr import FireRedAsrAed
3
+
4
+ import torch
5
+ import onnx
6
+ import onnxruntime
7
+ from onnxruntime.quantization import QuantType, quantize_dynamic
8
+ import onnxslim
9
+ from onnx.external_data_helper import convert_model_to_external_data
10
+ import numpy as np
11
+ import math
12
+ import kaldiio
13
+
14
+ import os
15
+ import argparse
16
+ from typing import Dict, Any
17
+
18
+ def to_numpy(tensor):
19
+ if tensor.requires_grad:
20
+ return tensor.detach().cpu().numpy()
21
+ else:
22
+ return tensor.cpu().numpy()
23
+
24
+
25
+ def load_model(model_path):
26
+ package = torch.load(model_path,
27
+ map_location=lambda storage,
28
+ loc: storage, weights_only=False)
29
+ model = FireRedAsrAed.from_args(package["args"])
30
+ model.load_state_dict(package["model_state_dict"], strict=True)
31
+ return model, package["args"]
32
+
33
+
34
+ def read_kaldi_cmvn(kaldi_cmvn_file):
35
+ assert os.path.exists(kaldi_cmvn_file)
36
+ stats = kaldiio.load_mat(kaldi_cmvn_file)
37
+ assert stats.shape[0] == 2
38
+ dim = stats.shape[-1] - 1
39
+ count = stats[0, dim]
40
+ assert count >= 1
41
+ floor = 1e-20
42
+ means = []
43
+ inverse_std_variences = []
44
+ for d in range(dim):
45
+ mean = stats[0, d] / count
46
+ means.append(mean.item())
47
+ varience = (stats[1, d] / count) - mean*mean
48
+ if varience < floor:
49
+ varience = floor
50
+ istd = 1.0 / math.sqrt(varience)
51
+ inverse_std_variences.append(istd)
52
+ return means, inverse_std_variences
53
+
54
+
55
+ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
56
+ """Add meta data to an ONNX model. It is changed in-place.
57
+
58
+ Args:
59
+ filename:
60
+ Filename of the ONNX model to be changed.
61
+ meta_data:
62
+ Key-value pairs.
63
+ """
64
+ model = onnx.load(filename)
65
+
66
+ while len(model.metadata_props):
67
+ model.metadata_props.pop()
68
+
69
+ for key, value in meta_data.items():
70
+ meta = model.metadata_props.add()
71
+ meta.key = key
72
+ meta.value = str(value)
73
+
74
+ onnx.save(model, filename)
75
+
76
+
77
+ def calc_feat_len(audio_dur):
78
+ import math
79
+ sample_rate = 16000
80
+ frame_length = 25 * sample_rate / 1000
81
+ frame_shift = 10 * sample_rate / 1000
82
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
83
+ return length
84
+
85
+
86
+ def export_encoder(fireredasr_model, args, model_args):
87
+ encoder = model_wrapper.AudioEncoderTensorCache(
88
+ fireredasr_model.encoder,
89
+ fireredasr_model.decoder)
90
+ encoder.eval()
91
+
92
+ # forge encoder input
93
+ encoder_input = torch.randn(1, calc_feat_len(10), 80)
94
+ encoder_input_lengths = torch.tensor([100], dtype=torch.int64)
95
+
96
+ n_layer_cross_k, n_layer_cross_v, cross_attn_mask = encoder(
97
+ encoder_input,
98
+ encoder_input_lengths
99
+ )
100
+
101
+ if not os.path.exists(args.encoder):
102
+ os.makedirs(args.encoder)
103
+ onnx_encoder_file = os.path.join(args.encoder, "encoder.onnx")
104
+
105
+ with torch.no_grad():
106
+ torch.onnx.export(
107
+ encoder,
108
+ (encoder_input, encoder_input_lengths),
109
+ onnx_encoder_file,
110
+ export_params=True,
111
+ do_constant_folding=True,
112
+ opset_version=16,
113
+ verbose=False,
114
+ input_names=["encoder_input",
115
+ "encoder_input_lengths"],
116
+ output_names=["n_layer_cross_k",
117
+ "n_layer_cross_v",
118
+ "cross_attn_mask"],
119
+ # dynamic_axes={
120
+ # "encoder_input": {
121
+ # 0: "batch_size",
122
+ # 1: "input_length"
123
+ # },
124
+ # "encoder_input_lengths": {
125
+ # 0: "batch_size"
126
+ # },
127
+ # "n_layer_cross_k": {
128
+ # 1: "batch_size",
129
+ # 2: "length"
130
+ # },
131
+ # "n_layer_cross_v": {
132
+ # 1: "batch_size",
133
+ # 2: "length"
134
+ # },
135
+ # "cross_attn_mask": {
136
+ # 0: "batch_size",
137
+ # 2: "length"
138
+ # }
139
+ # },
140
+ external_data=True
141
+ )
142
+
143
+ external_filename = os.path.basename(onnx_encoder_file).split(".onnx")[0]
144
+ model = onnx.load(onnx_encoder_file)
145
+ convert_model_to_external_data(
146
+ model,
147
+ all_tensors_to_one_file=True,
148
+ location=f"./{external_filename}.data",
149
+ size_threshold=0,
150
+ convert_attribute=False
151
+ )
152
+
153
+ onnx.save_model(
154
+ model,
155
+ onnx_encoder_file,
156
+ save_as_external_data=True,
157
+ all_tensors_to_one_file=True,
158
+ location=f"./{external_filename}.data",
159
+ size_threshold=0
160
+ )
161
+
162
+ onnx.checker.check_model(onnx_encoder_file, True)
163
+ ort_session = onnxruntime.InferenceSession(onnx_encoder_file)
164
+ onnx_encoder_input = to_numpy(encoder_input)
165
+ onxx_encoder_input_lengths = to_numpy(encoder_input_lengths)
166
+ ort_inputs = {ort_session.get_inputs()[0].name: onnx_encoder_input,
167
+ ort_session.get_inputs()[1].name: onxx_encoder_input_lengths}
168
+ ort_outputs = ort_session.run(None, ort_inputs)
169
+
170
+ try:
171
+ np.testing.assert_allclose(to_numpy(n_layer_cross_k), ort_outputs[0], rtol=1e-03, atol=1e-05)
172
+ except AssertionError as e:
173
+ print(e)
174
+ try:
175
+ np.testing.assert_allclose(to_numpy(n_layer_cross_v), ort_outputs[1], rtol=1e-03, atol=1e-05)
176
+ except AssertionError as e:
177
+ print(e)
178
+ try:
179
+ np.testing.assert_allclose(to_numpy(cross_attn_mask), ort_outputs[2], rtol=1e-03, atol=1e-05)
180
+ except AssertionError as e:
181
+ print(e)
182
+
183
+ print("export onnx encoder done.")
184
+
185
+ # Generate int8 quantization models
186
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
187
+ print("Generate int8 quantization models")
188
+
189
+ if not os.path.exists(args.encoder_int8):
190
+ os.mkdir(args.encoder_int8)
191
+ onnx_encoder_int8_file = "encoder_int8.onnx"
192
+ onnx_encoder_int8_file = os.path.join(args.encoder_int8, onnx_encoder_int8_file)
193
+ quantize_dynamic(
194
+ model_input=onnx_encoder_file,
195
+ model_output=onnx_encoder_int8_file,
196
+ op_types_to_quantize=["MatMul"],
197
+ weight_type=QuantType.QInt8,
198
+ )
199
+
200
+ cmvn_mean, cmvn_inv_stddev = read_kaldi_cmvn(args.cmvn)
201
+ cmvn_mean = [str(m) for m in cmvn_mean]
202
+ cmvn_inv_stddev = [str(istd) for istd in cmvn_inv_stddev]
203
+
204
+ encoder_meta_data = {
205
+ "model_type": "FireRedAsrAED-L",
206
+ "maintainer": "LiangHu",
207
+ "feat_dim": model_args.idim,
208
+ "feat_type": "fbank",
209
+ "num_decoder_layers": model_args.n_layers_dec,
210
+ "num_head": model_args.n_head,
211
+ "head_dim": model_args.d_model // model_args.n_head,
212
+ "max_len": 448,
213
+ "sos": model_args.sos_id,
214
+ "eos": model_args.eos_id,
215
+ "cmvn_mean": ','.join(cmvn_mean),
216
+ "cmvn_inv_stddev": ','.join(cmvn_inv_stddev)
217
+ }
218
+
219
+ # add_meta_data(onnx_encoder_file, encoder_meta_data)
220
+ add_meta_data(onnx_encoder_int8_file, encoder_meta_data)
221
+
222
+ return n_layer_cross_k, n_layer_cross_v, cross_attn_mask
223
+
224
+
225
+ def export_decoder(fireredasr_model, args,
226
+ n_layer_cross_k,
227
+ n_layer_cross_v,
228
+ cross_attn_mask):
229
+ beam_size = 3
230
+
231
+ decoder = model_wrapper.TextDecoderTensorCache(
232
+ fireredasr_model.decoder)
233
+ decoder.eval()
234
+
235
+ num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
236
+ encoder_out_length = cross_attn_mask.size(-1)
237
+
238
+ # preparing for batch beam search
239
+ cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
240
+ 1, beam_size, 1, 1).view(beam_size * batch_size, -1, encoder_out_length)
241
+ n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
242
+ 1, 1, beam_size, 1, 1
243
+ ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
244
+ n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
245
+ 1, 1, beam_size, 1, 1
246
+ ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
247
+ tokens = torch.ones(beam_size * batch_size, 1).fill_(decoder.decoder.sos_id).long()
248
+
249
+ n_layer_self_k_cache = torch.zeros(
250
+ (
251
+ len(decoder.blocks),
252
+ batch_size * beam_size,
253
+ 448,
254
+ 1280
255
+ )
256
+ )
257
+ n_layer_self_v_cache = torch.zeros(
258
+ (
259
+ len(decoder.blocks),
260
+ batch_size * beam_size,
261
+ 448,
262
+ 1280
263
+ )
264
+ )
265
+ offset = torch.zeros(1, dtype=torch.int64)
266
+ self_attn_mask = torch.empty(batch_size * beam_size,
267
+ tokens.shape[-1], tokens.shape[-1]
268
+ ).fill_(-np.inf).triu_(1) # fill_(-np.inf)
269
+ self_attn_mask = self_attn_mask[:, -1:, :]
270
+
271
+ logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
272
+ tokens,
273
+ n_layer_self_k_cache,
274
+ n_layer_self_v_cache,
275
+ n_layer_cross_k,
276
+ n_layer_cross_v,
277
+ offset,
278
+ self_attn_mask,
279
+ cross_attn_mask
280
+ )
281
+
282
+ if not os.path.exists(args.decoder):
283
+ os.makedirs(args.decoder)
284
+ onnx_decoder_file = os.path.join(args.decoder, "decoder.onnx")
285
+
286
+ with torch.no_grad():
287
+ torch.onnx.export(
288
+ decoder,
289
+ (tokens,
290
+ n_layer_self_k_cache,
291
+ n_layer_self_v_cache,
292
+ n_layer_cross_k,
293
+ n_layer_cross_v,
294
+ offset,
295
+ self_attn_mask,
296
+ cross_attn_mask),
297
+ onnx_decoder_file,
298
+ export_params=True,
299
+ opset_version=13,
300
+ verbose=False,
301
+ input_names=["tokens",
302
+ "in_n_layer_self_k_cache",
303
+ "in_n_layer_self_v_cache",
304
+ "n_layer_cross_k",
305
+ "n_layer_cross_v",
306
+ "offset",
307
+ "self_attn_mask",
308
+ "cross_attn_mask"],
309
+ output_names=["logits",
310
+ "out_n_layer_self_k_cache",
311
+ "out_n_layer_self_v_cache"],
312
+ dynamic_axes={
313
+ "tokens": {0: "n_audio", 1: "n_tokens"},
314
+ "in_n_layer_self_k_cache": {1: "n_audio"},
315
+ "in_n_layer_self_v_cache": {1: "n_audio"},
316
+ "n_layer_cross_k": {1: "n_audio", 2: "T"},
317
+ "n_layer_cross_v": {1: "n_audio", 2: "T"},
318
+ "self_attn_mask": {0: "n_audio", 2: "T"},
319
+ "cross_attn_mask": {0: "n_audio", 2: "T"},
320
+ },
321
+ external_data=True
322
+ )
323
+
324
+ onnx.checker.check_model(onnx_decoder_file)
325
+ ort_session = onnxruntime.InferenceSession(onnx_decoder_file)
326
+
327
+ onnx_tokens = to_numpy(tokens)
328
+ onnx_n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
329
+ onnx_n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
330
+ onnx_n_layer_cross_k = to_numpy(n_layer_cross_k)
331
+ onnx_n_layer_cross_v = to_numpy(n_layer_cross_v)
332
+ onnx_offset = to_numpy(offset)
333
+ onnx_self_attn_mask = to_numpy(self_attn_mask)
334
+ onnx_cross_attn_mask = to_numpy(cross_attn_mask)
335
+
336
+ ort_inputs = {ort_session.get_inputs()[0].name: onnx_tokens,
337
+ ort_session.get_inputs()[1].name: onnx_n_layer_self_k_cache,
338
+ ort_session.get_inputs()[2].name: onnx_n_layer_self_v_cache,
339
+ ort_session.get_inputs()[3].name: onnx_n_layer_cross_k,
340
+ ort_session.get_inputs()[4].name: onnx_n_layer_cross_v,
341
+ ort_session.get_inputs()[5].name: onnx_offset,
342
+ ort_session.get_inputs()[6].name: onnx_self_attn_mask,
343
+ ort_session.get_inputs()[7].name: onnx_cross_attn_mask}
344
+ ort_outputs = ort_session.run(None, ort_inputs)
345
+
346
+ try:
347
+ np.testing.assert_allclose(to_numpy(logits), ort_outputs[0], rtol=1e-03, atol=1e-05)
348
+ except AssertionError as e:
349
+ print(e)
350
+ try:
351
+ np.testing.assert_allclose(to_numpy(out_n_layer_self_k_cache), ort_outputs[1], rtol=1e-03, atol=1e-05)
352
+ except AssertionError as e:
353
+ print(e)
354
+ try:
355
+ np.testing.assert_allclose(to_numpy(out_n_layer_self_v_cache), ort_outputs[2], rtol=1e-03, atol=1e-05)
356
+ except AssertionError as e:
357
+ print(e)
358
+
359
+ print("export onnx decoder done.")
360
+
361
+ if not os.path.exists(args.decoder_int8):
362
+ os.mkdir(args.decoder_int8)
363
+ onnx_decoder_int8_file = "decoder_int8.onnx"
364
+ onnx_decoder_int8_file = os.path.join(args.decoder_int8, onnx_decoder_int8_file)
365
+ quantize_dynamic(
366
+ model_input=onnx_decoder_file,
367
+ model_output=onnx_decoder_int8_file,
368
+ op_types_to_quantize=["MatMul"],
369
+ weight_type=QuantType.QInt8,
370
+ )
371
+
372
+ # decoder main
373
+ decoder = model_wrapper.TextDecoderTensorCacheV2(
374
+ fireredasr_model.decoder, loop=False)
375
+ decoder.eval()
376
+
377
+ self_attn_mask = torch.empty(batch_size * beam_size,
378
+ tokens.shape[-1], tokens.shape[-1]
379
+ ).fill_(-np.inf).triu_(1) # fill_(-np.inf)
380
+ self_attn_mask = self_attn_mask[:, -1:, :]
381
+
382
+ pe = decoder.decoder.positional_encoding.pe[0]
383
+
384
+ onnx_decoder_file = os.path.join(args.decoder, "decoder_main.onnx")
385
+
386
+ with torch.no_grad():
387
+ torch.onnx.export(
388
+ decoder,
389
+ (tokens,
390
+ n_layer_self_k_cache,
391
+ n_layer_self_v_cache,
392
+ n_layer_cross_k,
393
+ n_layer_cross_v,
394
+ pe[0],
395
+ self_attn_mask,
396
+ cross_attn_mask),
397
+ onnx_decoder_file,
398
+ export_params=True,
399
+ opset_version=13,
400
+ verbose=False,
401
+ input_names=["tokens",
402
+ "in_n_layer_self_k_cache",
403
+ "in_n_layer_self_v_cache",
404
+ "n_layer_cross_k",
405
+ "n_layer_cross_v",
406
+ "pe",
407
+ "self_attn_mask",
408
+ "cross_attn_mask"],
409
+ output_names=["logits",
410
+ "out_n_layer_self_k_cache",
411
+ "out_n_layer_self_v_cache"],
412
+ # dynamic_axes={
413
+ # "tokens": {0: "n_audio", 1: "n_tokens"},
414
+ # "in_n_layer_self_k_cache": {1: "n_audio"},
415
+ # "in_n_layer_self_v_cache": {1: "n_audio"},
416
+ # "n_layer_cross_k": {1: "n_audio", 2: "T"},
417
+ # "n_layer_cross_v": {1: "n_audio", 2: "T"},
418
+ # "self_attn_mask": {0: "n_audio", 2: "T"},
419
+ # "cross_attn_mask": {0: "n_audio", 2: "T"},
420
+ # },
421
+ external_data=True
422
+ )
423
+ print(f"Export decoder_main to {onnx_decoder_file}")
424
+
425
+ # decoder loop
426
+ decoder = model_wrapper.TextDecoderTensorCacheV2(
427
+ fireredasr_model.decoder, loop=True)
428
+ decoder.eval()
429
+
430
+ pe = decoder.decoder.positional_encoding.pe[0]
431
+ pe_file = os.path.join(args.decoder, "pe.npy")
432
+ np.save(pe_file, pe.numpy())
433
+
434
+ onnx_decoder_file = os.path.join(args.decoder, "decoder_loop.onnx")
435
+
436
+ with torch.no_grad():
437
+ torch.onnx.export(
438
+ decoder,
439
+ (tokens,
440
+ n_layer_self_k_cache,
441
+ n_layer_self_v_cache,
442
+ n_layer_cross_k,
443
+ n_layer_cross_v,
444
+ pe[0],
445
+ self_attn_mask,
446
+ cross_attn_mask),
447
+ onnx_decoder_file,
448
+ export_params=True,
449
+ opset_version=13,
450
+ verbose=False,
451
+ input_names=["tokens",
452
+ "in_n_layer_self_k_cache",
453
+ "in_n_layer_self_v_cache",
454
+ "n_layer_cross_k",
455
+ "n_layer_cross_v",
456
+ "pe",
457
+ "self_attn_mask",
458
+ "cross_attn_mask"],
459
+ output_names=["logits",
460
+ "out_n_layer_self_k_cache",
461
+ "out_n_layer_self_v_cache"],
462
+ # dynamic_axes={
463
+ # "tokens": {0: "n_audio", 1: "n_tokens"},
464
+ # "in_n_layer_self_k_cache": {1: "n_audio"},
465
+ # "in_n_layer_self_v_cache": {1: "n_audio"},
466
+ # "n_layer_cross_k": {1: "n_audio", 2: "T"},
467
+ # "n_layer_cross_v": {1: "n_audio", 2: "T"},
468
+ # "self_attn_mask": {0: "n_audio", 2: "T"},
469
+ # "cross_attn_mask": {0: "n_audio", 2: "T"},
470
+ # },
471
+ external_data=True
472
+ )
473
+ print(f"Export decoder_loop to {onnx_decoder_file}")
474
+
475
+
476
+ def parse_args():
477
+ parser = argparse.ArgumentParser(description="export FireRedASR-AED torch model to onnx")
478
+ parser.add_argument(
479
+ "--model",
480
+ type=str,
481
+ required=True,
482
+ help="Path to FireRedASR-AED torch model"
483
+ )
484
+ parser.add_argument(
485
+ "--encoder",
486
+ type=str,
487
+ required=True,
488
+ help="Dir to the exported onnx encoder"
489
+ )
490
+ parser.add_argument(
491
+ "--decoder",
492
+ type=str,
493
+ required=True,
494
+ help="Dir to the exported onnx decoder"
495
+ )
496
+ parser.add_argument(
497
+ "--encoder_int8",
498
+ type=str,
499
+ required=True,
500
+ help="Dir to the exported onnx encoder after int8 quantization"
501
+ )
502
+ parser.add_argument(
503
+ "--decoder_int8",
504
+ type=str,
505
+ required=True,
506
+ help="Dir to the exported onnx encoder after int8 quantization"
507
+ )
508
+ parser.add_argument(
509
+ "--cmvn",
510
+ type=str,
511
+ required=True,
512
+ help="cmvn.ark file"
513
+ )
514
+ return parser.parse_args()
515
+
516
+
517
+ def main():
518
+ args = parse_args()
519
+ fireredasr_model, model_args = load_model(args.model)
520
+ n_layer_cross_k, n_layer_cross_v, cross_attn_mask = export_encoder(fireredasr_model, args, model_args)
521
+ export_decoder(fireredasr_model, args, n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
522
+
523
+
524
+ if __name__ == "__main__":
525
+ main()
test_ax_model.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fireredasr.data.asr_feat import ASRFeatExtractor
2
+ from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
+
4
+ import onnxruntime as ort
5
+ import axengine as axe
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from torch import Tensor
10
+ from typing import Tuple, List, Dict
11
+ import argparse
12
+ import os
13
+ import time
14
+ import logging
15
+
16
+ logger = logging.getLogger()
17
+ logger.setLevel(logging.INFO)
18
+ logger_stream_hander = logging.StreamHandler()
19
+ logger_stream_hander.setLevel("INFO")
20
+ logger.addHandler(logger_stream_hander)
21
+
22
+
23
+ INF = 1e10
24
+
25
+
26
def to_numpy(tensor):
    """Return *tensor* as a numpy array; numpy inputs pass through unchanged.

    Gradient-tracking torch tensors are detached first, since ``.numpy()``
    refuses to operate on tensors that require grad.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    data = tensor.detach() if tensor.requires_grad else tensor
    return data.cpu().numpy()
33
+
34
+
35
def set_finished_beam_score_to_zero(scores, is_finished):
    """Freeze the candidate scores of beams that have already emitted EOS.

    For a finished beam the first candidate gets score 0 (so the running
    total stops changing) and the other beam_size - 1 candidates get -INF
    (so they can never be selected). Unfinished beams keep their scores.
    """
    num_rows, beam = scores.size()
    finished = is_finished.float()
    frozen = torch.tensor([0.0] + [-INF] * (beam - 1)).float()
    frozen = frozen.view(1, beam).repeat(num_rows, 1)
    return scores * (1 - finished) + frozen * finished
41
+
42
+
43
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Force beams that have finished to keep emitting *eos_id*.

    Unfinished beams keep their predicted token ids unchanged.
    """
    done = is_finished.long()
    return ys * (1 - done) + eos_id * done
46
+
47
+
48
class FireRedASROnnxModel:
    """Beam-search ASR pipeline running exported FireRedASR-AED models on AXEngine.

    The encoder and two decoder variants — a prefill "main" model (first step)
    and an incremental "loop" model (subsequent steps, with KV caches) — are
    loaded as .axmodel files. Feature extraction and tokenization reuse the
    original fireredasr Python components.
    """

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        # NOTE(review): mutable default argument; harmless here because it is
        # only read, but a tuple or None default would be safer.
        providers=['AXCLRTExecutionProvider', 'AxEngineExecutionProvider']
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        # session_opts.log_severity_level = 1
        # Kept only for the (currently disabled) ONNX Runtime decoder path.
        self.session_opts = session_opts

        # Max decode length follows whisper's setting; FireRedASR-AED supports
        # audio of at most 60 s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        # Decoder geometry and special-token ids — must match the exported model.
        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        self.encoder = None
        self.decoder = None

        self.init_encoder(encoder_path, providers)
        # self.init_decoder(decoder_path, providers)  # single-model ORT decoder disabled
        self.init_decoder_main(decoder_path, providers)
        self.init_decoder_loop(decoder_path, providers)
        # Precomputed positional encodings, indexed by decode step offset.
        self.pe = self.init_pe(decoder_path)

    def init_encoder(self, encoder_path, providers=None):
        """Load the encoder .axmodel and log how long loading took."""
        start_time = time.time()
        self.encoder = axe.InferenceSession(
            encoder_path,
            # sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load encoder cost {end_time - start_time} seconds")

    def init_decoder(self, decoder_path, providers=None):
        """Load a single-file ONNX decoder via onnxruntime.

        NOTE(review): not called from __init__ (commented out there), so
        self.decoder stays None and decode_one_token would fail if used.
        """
        start_time = time.time()
        self.decoder = ort.InferenceSession(
            decoder_path,
            # sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load decoder cost {end_time - start_time} seconds")

    def init_decoder_main(self, decoder_path, providers=None):
        """Load decoder_main.axmodel (prefill step) from the decoder's directory."""
        decoder_path = os.path.dirname(decoder_path)
        decoder_path = os.path.join(decoder_path, "decoder_main.axmodel")
        start_time = time.time()
        self.decoder_main = axe.InferenceSession(
            decoder_path,
            # sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load decoder_main cost {end_time - start_time} seconds")

    def init_decoder_loop(self, decoder_path, providers=None):
        """Load decoder_loop.axmodel (incremental step) from the decoder's directory."""
        decoder_path = os.path.dirname(decoder_path)
        decoder_path = os.path.join(decoder_path, "decoder_loop.axmodel")

        start_time = time.time()
        self.decoder_loop = axe.InferenceSession(
            decoder_path,
            # sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load decoder_loop cost {end_time - start_time} seconds")

    def init_pe(self, decoder_path):
        """Load the precomputed positional-encoding table (pe.npy) next to the decoder."""
        decoder_path = os.path.dirname(decoder_path)
        decoder_path = os.path.join(decoder_path, "pe.npy")

        return np.load(decoder_path)

    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    # Annotation fixed: axengine returns numpy arrays, not torch Tensors.
                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the encoder; return per-layer cross-attention K/V and the cross mask."""
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
            None,
            {
                "encoder_input": input,
                "encoder_input_lengths": input_length
            }
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )

    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """One decode step through the single-file ORT decoder (debug path).

        NOTE(review): requires init_decoder() to have run; __init__ currently
        skips it, so this method is effectively unused here.
        """
        print("decode:")
        print(f"tokens.shape: {tokens.shape}")
        print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
        print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
        print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
        print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
        print(f"offset.shape: {offset.shape}")
        print(f"self_attn_mask.shape: {self_attn_mask.shape}")
        print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")

        # Inputs are bound positionally via the model's declared input order.
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            None,
            {
                self.decoder.get_inputs()[0].name: tokens,
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder.get_inputs()[5].name: offset,
                self.decoder.get_inputs()[6].name: self_attn_mask,
                self.decoder.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """First (prefill) decode step via decoder_main.axmodel.

        The self-attention KV-cache inputs are accepted for signature symmetry
        with decode_loop_one_token but are NOT fed to the model — decoder_main
        produces the initial caches itself.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_main.run(
            None,
            {
                "tokens": tokens,
                "n_layer_cross_k": n_layer_cross_k_cache,
                "n_layer_cross_v": n_layer_cross_v_cache,
                "pe": pe,
                "self_attn_mask": self_attn_mask,
                "cross_attn_mask": cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Incremental decode step via decoder_loop.axmodel, consuming and
        returning the per-layer self-attention KV caches."""
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
            None,
            {
                "tokens": tokens,
                "in_n_layer_self_k_cache": n_layer_self_k_cache,
                "in_n_layer_self_v_cache": n_layer_self_v_cache,
                "n_layer_cross_k": n_layer_cross_k_cache,
                "n_layer_cross_v": n_layer_cross_v_cache,
                "pe": pe,
                "self_attn_mask": self_attn_mask,
                "cross_attn_mask": cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest
    ):
        """Autoregressive beam search over the two-phase axmodel decoder.

        Args:
            n_layer_cross_k/v: (num_layer, batch, Ti, dim) encoder cross K/V.
            cross_attn_mask: additive cross-attention mask from the encoder.
            beam_size: beam width; nbest: hypotheses to return per utterance.

        Returns:
            nbest_hyps: per-utterance list of {"token_ids", "score"} dicts.
        """

        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]

        # Tile the cross-attention mask beam_size times: every beam of an
        # utterance attends to the same encoder output.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(beam_size * batch_size, -1, encoder_out_length)

        # Same tiling for the per-layer cross K/V caches.
        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)

        # Every beam starts with a single SOS token.
        prediction_tokens = torch.ones(
            beam_size * batch_size, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        # Only beam 0 starts "alive" (score 0); the rest get -INF so the first
        # expansion cannot pick duplicates.
        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
        is_finished = torch.zeros_like(scores)

        # Decoding one token at a time: the causal self-attention mask is a
        # single all-zero (i.e. "attend to everything cached") row.
        self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)

        # NOTE(review): `results` is never used below — dead variable.
        results = [self.sos_id]
        for i in range(self.decode_max_len):

            tokens = to_numpy(tokens).astype(np.int32)
            n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
            n_layer_cross_k = to_numpy(n_layer_cross_k)
            n_layer_cross_v = to_numpy(n_layer_cross_v)
            cross_attn_mask = to_numpy(cross_attn_mask)

            # Step 0 uses the prefill model (which creates the self caches);
            # every later step uses the incremental loop model.
            # self.pe[offset] selects the positional encoding for this step.
            if i == 0:
                logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
                    to_numpy(tokens),
                    to_numpy(n_layer_self_k_cache),
                    to_numpy(n_layer_self_v_cache),
                    to_numpy(n_layer_cross_k),
                    to_numpy(n_layer_cross_v),
                    self.pe[offset],
                    self_attn_mask,
                    to_numpy(cross_attn_mask)
                )
            else:
                logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
                    to_numpy(tokens),
                    to_numpy(n_layer_self_k_cache),
                    to_numpy(n_layer_self_v_cache),
                    to_numpy(n_layer_cross_k),
                    to_numpy(n_layer_cross_v),
                    self.pe[offset],
                    self_attn_mask,
                    to_numpy(cross_attn_mask)
                )

            offset += 1
            logits = torch.from_numpy(logits)

            # (NB, 1, V) -> (NB, V); expand each beam into beam_size candidates.
            logits = logits.squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            # Prune beam_size * beam_size candidates back down to beam_size.
            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map each surviving candidate back to the beam row it came from.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, beam_size).view(batch_size * beam_size)
            stride = beam_size * torch.arange(batch_size).view(
                batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            # Reorder histories, pick the surviving tokens, and extend.
            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(beam_size * batch_size, 1)

            tokens = t_ys

            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            # Reorder the self-attention KV caches to follow the surviving beams.
            # NOTE(review): the inner `i` shadows the step counter; safe only
            # because nothing after these loops reads the step index.
            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)

            for i, self_k_cache in enumerate(n_layer_self_k_cache):
                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]

            for i, self_v_cache in enumerate(n_layer_self_v_cache):
                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]

            # Stop once every beam of every utterance has emitted EOS.
            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == beam_size * batch_size:
                break

        # Valid length per beam = count of non-EOS tokens (includes leading SOS).
        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        # Select the nbest beams per utterance and package the hypotheses.
        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for i in range(batch_size):
            i_best_hyps: List[Dict[str, torch.Tensor]] = []
            for j, score in enumerate(nbest_scores[i]):
                hyp = {
                    # Slice [1:length] drops the leading SOS token.
                    "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
                    "score": score
                }
                i_best_hyps.append(hyp)
            nbest_hyps.append(i_best_hyps)

        return nbest_hyps

    def get_initialized_self_cache(self,
                                   batch_size,
                                   beam_size
                                   ) -> Tuple[Tensor, Tensor]:
        """Return zero-filled self-attention K/V caches sized for the full
        (num_blocks, batch*beam, max_len, hidden) decode run."""
        n_layer_self_k_cache = torch.zeros(
            self.num_decoder_blocks,
            batch_size * beam_size,
            self.decode_max_len,
            self.decoder_hidden_dim,
        )
        n_layer_self_v_cache = torch.zeros(
            self.num_decoder_blocks,
            batch_size * beam_size,
            self.decode_max_len,
            self.decoder_hidden_dim,
        )
        return n_layer_self_k_cache, n_layer_self_v_cache

    def calc_feat_len(self, audio_dur):
        """Number of fbank frames for *audio_dur* seconds at 16 kHz
        (25 ms window, 10 ms shift)."""
        import math
        sample_rate = 16000
        frame_length = 25 * sample_rate / 1000
        frame_shift = 10 * sample_rate / 1000
        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
        return length

    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   # Annotation fixed: three values are returned, not just results.
                   # NOTE(review): wav_durations' element type comes from
                   # ASRFeatExtractor — assumed List[float], confirm.
                   ) -> Tuple[List[Dict], List[float], float]:
        """Transcribe a batch of wav files; return (results, wav_durations,
        transcribe_seconds)."""
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        # Pad/truncate features to a fixed 10 s window — presumably because the
        # compiled axmodel has a static input shape (confirm against export).
        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
        feats = feats[:, :maxlen, :]

        feats = to_numpy(feats)
        lengths = to_numpy(lengths).astype(np.int32)

        start_time = time.time()
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
            to_numpy(feats),
            to_numpy(lengths)
        )
        nbest_hyps = self.run_decoder(n_layer_cross_k,
                                      n_layer_cross_v,
                                      cross_attn_mask,
                                      beam_size,
                                      nbest,
                                      )
        transcribe_durations = time.time() - start_time
        results: List[Dict] = []
        for wav, hyp in zip(batch_wav_path, nbest_hyps):
            # Only the best hypothesis per utterance is detokenized.
            hyp = hyp[0]
            hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
            score = hyp["score"].item()
            text = self.tokenizer.detokenize(hyp_ids)
            results.append(
                {
                    "wav": wav,
                    "text": text,
                    "score": score
                }
            )

        return results, wav_durations, transcribe_durations
538
+
539
+
540
def parse_args(argv=None):
    """Parse command-line arguments for the axmodel ASR test driver.

    Args:
        argv: Optional list of argument strings (defaults to sys.argv[1:]);
            exposed so the parser can be driven programmatically/in tests.

    Returns:
        argparse.Namespace with model/resource paths and decoding options.
    """
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to onnx encoder"
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="axmodel/decoder_main.axmodel",
        help="Path to onnx decoder"
    )
    parser.add_argument(
        "--cmvn",
        type=str,
        default="axmodel/cmvn.ark",
        help="Path to cmvn"
    )
    parser.add_argument(
        "--dict",
        type=str,
        default="axmodel/dict.txt",
        help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model"
    )
    parser.add_argument(
        "--wavlist",
        type=str,
        default="wavlist.txt",
        help="File to wav path list"
    )
    parser.add_argument(
        "--hypo",
        type=str,
        default="hypo_axmodel.txt",
        help="File of hypos"
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=3,
        # FIX: original help strings were empty.
        help="Beam width used during decoding"
    )
    parser.add_argument(
        "--nbest",
        type=int,
        default=1,
        help="Number of hypotheses to keep per utterance"
    )

    return parser.parse_args(argv)
598
+
599
+
600
def parse_wavlist(wavlist: str):
    """Read a text file listing one wav path per line.

    Blank lines are skipped silently; paths that do not exist are reported
    on stdout and skipped.

    Args:
        wavlist: Path to the list file.

    Returns:
        List of existing wav paths, in file order.
    """
    wavpaths = []
    with open(wavlist) as f:
        for line in f:
            line = line.strip()
            # FIX: skip empty/trailing lines instead of reporting
            # "' doesn't exist." for them.
            if not line:
                continue
            if not os.path.exists(line):
                print(f"{line} doesn't exist.")
                continue
            wavpaths.append(line)

    return wavpaths
611
+
612
+
613
def main():
    """Transcribe every wav in --wavlist with the AX models, write hypotheses
    to --hypo, and log per-file and aggregate real-time factors (RTF)."""
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model,
                                     )

    wavlist = parse_wavlist(args.wavlist)

    total_wav_durations = 0
    total_transcribe_durations = 0
    # FIX: context manager guarantees the hypo file is closed (and flushed)
    # even if transcription raises part-way through.
    with open(args.hypo, "wt") as wf:
        for wav in wavlist:
            batch_wav = [wav]
            results, wav_durations, transcribe_durations = onnx_model.transcribe(
                batch_wav, args.beam_size, args.nbest)

            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            rtf = transcribe_durations / wav_durations
            logger.info(f"(Real time factor) RTF: {rtf}")
            for result in results:
                logger.info(f"wav: {result['wav']}")
                logger.info(f"text: {result['text']}")
                logger.info(f"score: {result['score']}")
                logger.info("")
                wf.write(f"{result['text']} ({result['wav']})\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    # FIX: an empty wavlist used to raise ZeroDivisionError here.
    if total_wav_durations > 0:
        avg_rtf = total_transcribe_durations / total_wav_durations
        logger.info(f"AVG RTF: {avg_rtf}")


if __name__ == "__main__":
    main()
test_decoder.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fireredasr.data.asr_feat import ASRFeatExtractor
2
+ from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
+
4
+ import onnxruntime as ort
5
+ # import axengine as axe
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from torch import Tensor
10
+ from typing import Tuple, List, Dict
11
+ import argparse
12
+ import os
13
+ import time
14
+ import logging
15
+
16
+ logger = logging.getLogger()
17
+ logger.setLevel(logging.INFO)
18
+ logger_stream_hander = logging.StreamHandler()
19
+ logger_stream_hander.setLevel("INFO")
20
+ logger.addHandler(logger_stream_hander)
21
+
22
+
23
+ INF = 1e10
24
+
25
+
26
def to_numpy(tensor):
    """Coerce *tensor* to a numpy array.

    numpy arrays are returned as-is; torch tensors are moved to CPU, with a
    detach() first when they track gradients (``.numpy()`` rejects those).
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
33
+
34
+
35
def set_finished_beam_score_to_zero(scores, is_finished):
    """Stop score accumulation for beams that have already ended.

    A finished beam contributes [0, -INF, ..., -INF] — the first candidate
    freezes the total and the rest become unselectable — while unfinished
    beams keep their raw candidate scores.
    """
    rows, width = scores.size()
    done = is_finished.float()
    freeze_row = torch.tensor([0.0] + [-INF] * (width - 1)).float()
    freeze_row = freeze_row.view(1, width).repeat(rows, 1)
    return scores * (1 - done) + freeze_row * done
41
+
42
+
43
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Replace predictions of finished beams with *eos_id*; others unchanged."""
    mask = is_finished.long()
    return ys * (1 - mask) + eos_id * mask
46
+
47
+
48
+ class FireRedASROnnxModel:
49
+ def __init__(
50
+ self,
51
+ encoder_path: str,
52
+ decoder_path: str,
53
+ cmvn_file: str,
54
+ dict_file: str,
55
+ spm_model_path: str,
56
+ providers=['CPUExecutionProvider']
57
+ ):
58
+ session_opts = ort.SessionOptions()
59
+ session_opts.inter_op_num_threads = 1
60
+ session_opts.intra_op_num_threads = 1
61
+ # session_opts.log_severity_level = 1
62
+ self.session_opts = session_opts
63
+
64
+ # NOTE: 参考whisper设置的最大的解码长度
65
+ # FireRedASR-AED 模型支持的最长语音为 60s
66
+ # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
67
+ self.decode_max_len = 448
68
+
69
+ self.decoder_hidden_dim = 1280
70
+ self.num_decoder_blocks = 16
71
+ self.blank_id = 0
72
+ self.sos_id = 3
73
+ self.eos_id = 4
74
+ self.pad_id = 2
75
+
76
+ self.feature_extractor = ASRFeatExtractor(cmvn_file)
77
+ self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
78
+ self.encoder = None
79
+ self.decoder = None
80
+
81
+ # self.init_encoder(encoder_path, providers)
82
+ # self.init_decoder(decoder_path, providers)
83
+ self.init_decoder_main(decoder_path, providers)
84
+ self.init_decoder_loop(decoder_path, providers)
85
+ self.pe = self.init_pe(decoder_path)
86
+
87
+ # def init_encoder(self, encoder_path, providers=None):
88
+ # start_time = time.time()
89
+ # self.encoder = axe.InferenceSession(
90
+ # encoder_path,
91
+ # # sess_options=self.session_opts,
92
+ # providers=providers
93
+ # )
94
+ # end_time = time.time()
95
+ # logger.info(f"load encoder cost {end_time - start_time} seconds")
96
+
97
+ def init_decoder(self, decoder_path, providers=None):
98
+ start_time = time.time()
99
+ self.decoder = ort.InferenceSession(
100
+ decoder_path,
101
+ sess_options=self.session_opts,
102
+ providers=providers
103
+ )
104
+ end_time = time.time()
105
+ logger.info(f"load decoder cost {end_time - start_time} seconds")
106
+
107
+ def init_decoder_main(self, decoder_path, providers=None):
108
+ decoder_path = os.path.dirname(decoder_path)
109
+ decoder_path = os.path.join(decoder_path, "decoder_main.onnx")
110
+ start_time = time.time()
111
+ self.decoder_main = ort.InferenceSession(
112
+ decoder_path,
113
+ sess_options=self.session_opts,
114
+ providers=providers
115
+ )
116
+ end_time = time.time()
117
+ logger.info(f"load decoder_main cost {end_time - start_time} seconds")
118
+
119
+ input_names = [i.name for i in self.decoder_main.get_inputs()]
120
+ print(f"decoder_main.input_names: {input_names}")
121
+
122
+ def init_decoder_loop(self, decoder_path, providers=None):
123
+ decoder_path = os.path.dirname(decoder_path)
124
+ decoder_path = os.path.join(decoder_path, "decoder_loop.onnx")
125
+
126
+ start_time = time.time()
127
+ self.decoder_loop = ort.InferenceSession(
128
+ decoder_path,
129
+ sess_options=self.session_opts,
130
+ providers=providers
131
+ )
132
+ end_time = time.time()
133
+ logger.info(f"load decoder_loop cost {end_time - start_time} seconds")
134
+
135
+ input_names = [i.name for i in self.decoder_loop.get_inputs()]
136
+ print(f"decoder_loop.input_names: {input_names}")
137
+
138
+ def init_pe(self, decoder_path):
139
+ decoder_path = os.path.dirname(decoder_path)
140
+ decoder_path = os.path.join(decoder_path, "pe.npy")
141
+
142
+ return np.load(decoder_path)
143
+
144
+ def run_encoder(self, input: np.ndarray,
145
+ input_length: np.ndarray
146
+ ) -> Tuple[Tensor, Tensor, Tensor]:
147
+ n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
148
+ None,
149
+ {
150
+ "encoder_input": input,
151
+ "encoder_input_lengths": input_length.astype(np.int32)
152
+ }
153
+ )
154
+ return (
155
+ n_layer_cross_k,
156
+ n_layer_cross_v,
157
+ cross_attn_mask
158
+ )
159
+
160
    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the single-graph decoder for one autoregressive step.

        Inputs are fed by positional order of the exported ONNX graph:
        tokens, self-attention K/V caches, cross-attention K/V, the cache
        write offset, and the two attention masks.

        Returns (logits, updated self K cache, updated self V cache).
        """
        # Debug tracing of all input shapes (left active in this variant).
        print("decode:")
        print(f"tokens.shape: {tokens.shape}")
        print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
        print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
        print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
        print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
        print(f"offset.shape: {offset.shape}")
        print(f"self_attn_mask.shape: {self_attn_mask.shape}")
        print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")

        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            None,
            {
                self.decoder.get_inputs()[0].name: tokens,
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder.get_inputs()[5].name: offset,
                self.decoder.get_inputs()[6].name: self_attn_mask,
                self.decoder.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
200
+
201
    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the "main" decoder model for the first decode step.

        NOTE(review): the exported decoder_main graph takes no self-attention
        cache inputs, so ``n_layer_self_k_cache``/``n_layer_self_v_cache`` are
        accepted but never fed; the fresh caches the graph emits are returned
        instead. Hence input indices 1..5 map to cross-K, cross-V, positional
        encoding, and the two masks.

        Returns (logits, new self K cache, new self V cache).
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_main.run(
            None,
            {
                self.decoder_main.get_inputs()[0].name: tokens,
                self.decoder_main.get_inputs()[1].name: n_layer_cross_k_cache,
                self.decoder_main.get_inputs()[2].name: n_layer_cross_v_cache,
                self.decoder_main.get_inputs()[3].name: pe,
                self.decoder_main.get_inputs()[4].name: self_attn_mask,
                self.decoder_main.get_inputs()[5].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
240
+
241
    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the "loop" decoder model for every decode step after the first.

        Unlike :meth:`decode_main_one_token`, this graph consumes the running
        self-attention K/V caches (input indices 1 and 2) in addition to the
        cross-attention K/V, positional encoding slice, and masks.

        Returns (logits, updated self K cache, updated self V cache).
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
            None,
            {
                self.decoder_loop.get_inputs()[0].name: tokens,
                self.decoder_loop.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder_loop.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder_loop.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder_loop.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder_loop.get_inputs()[5].name: pe,
                self.decoder_loop.get_inputs()[6].name: self_attn_mask,
                self.decoder_loop.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
280
+
281
    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest
    ):
        """Autoregressive beam-search decode over precomputed encoder outputs.

        Args:
            n_layer_cross_k / n_layer_cross_v: per-layer cross-attention K/V,
                shape (num_layer, batch, Ti, encoder_out_dim), as numpy arrays.
            cross_attn_mask: encoder padding mask; last dim is encoder length.
            beam_size: beams kept per utterance.
            nbest: hypotheses returned per utterance (assumes nbest <= beam_size).

        Returns:
            Per utterance, a list of ``{"token_ids", "score"}`` dicts with the
            leading sos stripped and trailing eos trimmed.
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]

        # Tile encoder outputs and mask beam_size times so every beam row
        # attends over its own utterance's encoder states.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(beam_size * batch_size, -1, encoder_out_length)

        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)

        # Every beam starts from <sos>; scores start at [0, -INF, ...] per
        # utterance so the first expansion only grows from beam 0.
        prediction_tokens = torch.ones(
            beam_size * batch_size, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
        is_finished = torch.zeros_like(scores)

        # With KV caches only the newest position is queried per step, so the
        # self-attention mask is a single all-zero (fully visible) row per beam.
        self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)

        results = [self.sos_id]
        for i in range(self.decode_max_len):
            tokens = to_numpy(tokens)
            n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
            n_layer_cross_k = to_numpy(n_layer_cross_k)
            n_layer_cross_v = to_numpy(n_layer_cross_v)
            cross_attn_mask = to_numpy(cross_attn_mask)

            # Step 0 uses the "main" graph (no self-cache inputs); subsequent
            # steps use the "loop" graph which consumes the running KV caches.
            # self.pe[offset] selects the positional-encoding slice for the
            # current decode position (offset is a 1-element int64 tensor).
            if i == 0:
                logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_main_one_token(
                    to_numpy(tokens),
                    to_numpy(n_layer_self_k_cache),
                    to_numpy(n_layer_self_v_cache),
                    to_numpy(n_layer_cross_k),
                    to_numpy(n_layer_cross_v),
                    self.pe[offset],
                    self_attn_mask,
                    to_numpy(cross_attn_mask)
                )
            else:
                logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
                    to_numpy(tokens),
                    to_numpy(n_layer_self_k_cache),
                    to_numpy(n_layer_self_v_cache),
                    to_numpy(n_layer_cross_k),
                    to_numpy(n_layer_cross_v),
                    self.pe[offset],
                    self_attn_mask,
                    to_numpy(cross_attn_mask)
                )

            offset += 1
            logits = torch.from_numpy(logits)

            # Per-beam top-B expansion, then re-rank the B*B candidates of
            # each utterance down to B surviving beams.
            logits = logits.squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            # Finished beams contribute a frozen score row and a forced <eos>.
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map candidate ids back to the parent-beam rows they came from.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, beam_size).view(batch_size * beam_size)
            stride = beam_size * torch.arange(batch_size).view(
                batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(beam_size * batch_size, 1)

            tokens = t_ys

            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)

            # Reorder the KV caches to follow the surviving beams.
            # NOTE(review): these loops shadow the outer loop variable ``i``;
            # harmless because ``for`` rebinds it each step, but confusing.
            for i, self_k_cache in enumerate(n_layer_self_k_cache):
                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]

            for i, self_v_cache in enumerate(n_layer_self_v_cache):
                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]

            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == beam_size * batch_size:
                break

        # Collect n-best hypotheses; valid length counts non-eos tokens
        # (includes the leading sos, which is stripped below via [1:...]).
        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for i in range(batch_size):
            i_best_hyps: List[Dict[str, torch.Tensor]] = []
            for j, score in enumerate(nbest_scores[i]):
                hyp = {
                    "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
                    "score": score
                }
                i_best_hyps.append(hyp)
            nbest_hyps.append(i_best_hyps)

        return nbest_hyps
446
+
447
+ def get_initialized_self_cache(self,
448
+ batch_size,
449
+ beam_size
450
+ ) -> Tuple[Tensor, Tensor]:
451
+ n_layer_self_k_cache = torch.zeros(
452
+ self.num_decoder_blocks,
453
+ batch_size * beam_size,
454
+ self.decode_max_len,
455
+ self.decoder_hidden_dim,
456
+ )
457
+ n_layer_self_v_cache = torch.zeros(
458
+ self.num_decoder_blocks,
459
+ batch_size * beam_size,
460
+ self.decode_max_len,
461
+ self.decoder_hidden_dim,
462
+ )
463
+ return n_layer_self_k_cache, n_layer_self_v_cache
464
+
465
+ def calc_feat_len(self, audio_dur):
466
+ import math
467
+ sample_rate = 16000
468
+ frame_length = 25 * sample_rate / 1000
469
+ frame_shift = 10 * sample_rate / 1000
470
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
471
+ return length
472
+
473
    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> List[Dict]:
        """Decode one batch of wav files using *precomputed* encoder outputs.

        NOTE(review): despite the annotation, this returns a 3-tuple
        ``(results, wav_durations, transcribe_durations)``.
        NOTE(review): the encoder is NOT run here — per-layer cross K/V and
        mask are loaded from ``encoder_output/<wav basename>/``; confirm those
        .npy files were produced beforehand.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        # Pad or truncate features to exactly 10 s worth of frames; feature
        # dim is assumed to be 80 (fbank) — TODO confirm against extractor.
        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
        feats = feats[:, :maxlen, :]

        encoder_data_path = os.path.join("encoder_output", os.path.basename(batch_wav_path[0]))

        n_layer_cross_k = np.load(os.path.join(encoder_data_path, "n_layer_cross_k.npy"))
        n_layer_cross_v = np.load(os.path.join(encoder_data_path, "n_layer_cross_v.npy"))
        cross_attn_mask = np.load(os.path.join(encoder_data_path, "cross_attn_mask.npy"))

        # Only the decoder's wall time is reported as "transcribe" duration.
        start_time = time.time()

        nbest_hyps = self.run_decoder(n_layer_cross_k,
                                      n_layer_cross_v,
                                      cross_attn_mask,
                                      beam_size,
                                      nbest
                                      )
        transcribe_durations = time.time() - start_time
        results: List[Dict] = []
        for wav, hyp in zip(batch_wav_path, nbest_hyps):
            # Only the best (first) hypothesis per utterance is kept.
            hyp = hyp[0]
            hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
            score = hyp["score"].item()
            text = self.tokenizer.detokenize(hyp_ids)
            results.append(
                {
                    "wav": wav,
                    "text": text,
                    "score": score
                }
            )

        return results, wav_durations, transcribe_durations
522
+
523
+
524
def parse_args():
    """Build and parse the command-line options for the test driver."""
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    # (flag, default, help) for all plain string options.
    str_options = [
        ("--encoder", "axmodel/encoder.axmodel", "Path to onnx encoder"),
        ("--decoder", "onnx_decoder/decoder_main.onnx", "Path to onnx decoder"),
        ("--cmvn", "axmodel/cmvn.ark", "Path to cmvn"),
        ("--dict", "axmodel/dict.txt", "Path to dict"),
        ("--spm_model", "axmodel/train_bpe1000.model", "Path to spm model"),
        ("--wavlist", "wavlist.txt", "File to wav path list"),
        ("--hypo", "hypo_encoder.txt", "File of hypos"),
    ]
    for flag, default, help_text in str_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument("--beam_size", type=int, default=3, help="")
    parser.add_argument("--nbest", type=int, default=1, help="")
    return parser.parse_args()
582
+
583
+
584
def parse_wavlist(wavlist: str):
    """Read wav paths (one per line) from *wavlist*, skipping missing files."""
    wavpaths = []
    with open(wavlist) as f:
        for raw in f:
            line = raw.strip()
            if os.path.exists(line):
                wavpaths.append(line)
            else:
                # Warn about — but tolerate — stale entries.
                print(f"{line} doesn't exist.")
    return wavpaths
595
+
596
+
597
def main():
    """Entry point: transcribe every wav in --wavlist, log RTF stats, and
    write hypotheses to --hypo in "text (wav)" format."""
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model)

    wavlist = parse_wavlist(args.wavlist)

    total_wav_durations = 0
    total_transcribe_durations = 0
    # "with" guarantees the hypothesis file is flushed/closed even if a
    # transcription raises (the original leaked the handle on error).
    with open(args.hypo, "wt") as wf:
        for wav in wavlist:
            batch_wav = [wav]
            results, wav_durations, transcribe_durations = onnx_model.transcribe(
                batch_wav, args.beam_size, args.nbest)

            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            rtf = transcribe_durations / wav_durations
            logger.info(f"(Real time factor) RTF: {rtf}")
            for result in results:
                logger.info(f"wav: {result['wav']}")
                logger.info(f"text: {result['text']}")
                logger.info(f"score: {result['score']}")
                logger.info("")
                wf.write(f"{result['text']} ({result['wav']})\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    # Guard against ZeroDivisionError when the wavlist is empty.
    if total_wav_durations > 0:
        avg_ref = total_transcribe_durations / total_wav_durations
        logger.info(f"AVG RTF: {avg_ref}")
test_encoder.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fireredasr.data.asr_feat import ASRFeatExtractor
2
+ from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
+
4
+ import onnxruntime as ort
5
+ import axengine as axe
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from torch import Tensor
10
+ from typing import Tuple, List, Dict
11
+ import argparse
12
+ import os
13
+ import time
14
+ import logging
15
+
16
+ logger = logging.getLogger()
17
+ logger.setLevel(logging.INFO)
18
+ logger_stream_hander = logging.StreamHandler()
19
+ logger_stream_hander.setLevel("INFO")
20
+ logger.addHandler(logger_stream_hander)
21
+
22
+
23
+ INF = 1e10
24
+
25
+
26
def to_numpy(tensor):
    """Return *tensor* as a NumPy array; ndarrays pass through unchanged.

    Gradient-tracking torch tensors are detached first, since ``.numpy()``
    refuses to operate on tensors that require grad.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
33
+
34
+
35
def set_finished_beam_score_to_zero(scores, is_finished):
    """Freeze the expansion scores of finished beams.

    For each finished row, the per-beam scores are replaced by
    ``[0, -INF, ..., -INF]`` so only the first (eos) continuation survives
    re-ranking; active rows are returned unchanged.
    """
    num_rows, beam = scores.size()
    finished = is_finished.float()
    frozen_row = torch.tensor([0.0] + [-INF] * (beam - 1)).float()
    frozen_row = frozen_row.view(1, beam).repeat(num_rows, 1)
    return scores * (1 - finished) + frozen_row * finished
41
+
42
+
43
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Force the token ids of finished beams to *eos_id*; leave active beams as-is."""
    finished = is_finished.long()
    return eos_id * finished + ys * (1 - finished)
46
+
47
+
48
+ class FireRedASROnnxModel:
49
    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        # NOTE(review): mutable default argument — safe only because it is
        # never mutated; a tuple or None sentinel would be more robust.
        providers=['AXCLRTExecutionProvider', 'AxEngineExecutionProvider']
    ):
        """FireRedASR AED inference wrapper (AxEngine encoder, ONNX decoders).

        Args:
            encoder_path: path to the AxEngine encoder model.
            decoder_path: path to the ONNX decoder (also anchors pe.npy lookup).
            cmvn_file: CMVN stats for feature normalization.
            dict_file / spm_model_path: tokenizer assets.
            providers: execution providers for the encoder session.
        """
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        # NOTE: max decode length follows whisper's setting; FireRedASR-AED
        # supports speech up to 60 s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        # Decoder geometry and special token ids used by beam search.
        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        self.encoder = None
        self.decoder = None

        # Only the encoder session is loaded in this test; the decoder
        # sessions (init_decoder / _main / _loop) are intentionally disabled.
        self.init_encoder(encoder_path, providers)
        self.pe = self.init_pe(decoder_path)
86
+
87
    def init_encoder(self, encoder_path, providers=None):
        """Load the encoder as an AxEngine InferenceSession and log the load time."""
        start_time = time.time()
        self.encoder = axe.InferenceSession(
            encoder_path,
            # sess_options are not passed to the AxEngine session here.
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load encoder cost {end_time - start_time} seconds")
96
+
97
    def init_decoder(self, decoder_path, providers=None):
        """Load the single-graph ONNX decoder on CPU.

        NOTE(review): *providers* is ignored; the session is pinned to
        CPUExecutionProvider.
        """
        start_time = time.time()
        self.decoder = ort.InferenceSession(
            decoder_path,
            sess_options=self.session_opts,
            providers=['CPUExecutionProvider']
        )
        end_time = time.time()
        logger.info(f"load decoder cost {end_time - start_time} seconds")
106
+
107
    def init_decoder_main(self, decoder_path, providers=None):
        """Load ``decoder_main.onnx`` (first-step decoder) from the directory
        containing *decoder_path*, on CPU; prints the graph's input names."""
        decoder_path = os.path.dirname(decoder_path)
        decoder_path = os.path.join(decoder_path, "decoder_main.onnx")
        start_time = time.time()
        self.decoder_main = ort.InferenceSession(
            decoder_path,
            sess_options=self.session_opts,
            providers=['CPUExecutionProvider']
        )
        end_time = time.time()
        logger.info(f"load decoder_main cost {end_time - start_time} seconds")

        input_names = [i.name for i in self.decoder_main.get_inputs()]
        print(f"decoder_main.input_names: {input_names}")
121
+
122
    def init_decoder_loop(self, decoder_path, providers=None):
        """Load ``decoder_loop.onnx`` (steady-state decoder) from the directory
        containing *decoder_path*, on CPU; prints the graph's input names."""
        decoder_path = os.path.dirname(decoder_path)
        decoder_path = os.path.join(decoder_path, "decoder_loop.onnx")

        start_time = time.time()
        self.decoder_loop = ort.InferenceSession(
            decoder_path,
            sess_options=self.session_opts,
            providers=['CPUExecutionProvider']
        )
        end_time = time.time()
        logger.info(f"load decoder_loop cost {end_time - start_time} seconds")

        input_names = [i.name for i in self.decoder_loop.get_inputs()]
        print(f"decoder_loop.input_names: {input_names}")
137
+
138
    def init_pe(self, decoder_path):
        """Load the precomputed positional-encoding table.

        NOTE(review): *decoder_path* is ignored in this variant — pe.npy is
        always read from the hard-coded "axmodel" directory (the sibling
        model_wrapper.py derives it from *decoder_path* instead).
        """
        decoder_path = os.path.join("axmodel", "pe.npy")

        return np.load(decoder_path)
142
+
143
    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the encoder session once over a padded feature batch.

        Args:
            input: fbank features for the "encoder_input" graph input.
            input_length: per-utterance frame counts; cast to int32 for the graph.

        Returns:
            (n_layer_cross_k, n_layer_cross_v, cross_attn_mask) in the order
            the exported encoder graph produces them.
        """
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
            None,
            {
                "encoder_input": input,
                "encoder_input_lengths": input_length.astype(np.int32)
            }
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )
158
+
159
    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the single-graph decoder for one autoregressive step.

        Inputs are fed by positional order of the exported ONNX graph:
        tokens, self-attention K/V caches, cross-attention K/V, the cache
        write offset, and the two attention masks.

        Returns (logits, updated self K cache, updated self V cache).
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            None,
            {
                self.decoder.get_inputs()[0].name: tokens,
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder.get_inputs()[5].name: offset,
                self.decoder.get_inputs()[6].name: self_attn_mask,
                self.decoder.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
199
+
200
    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the "main" decoder model for the first decode step.

        NOTE(review): the exported decoder_main graph takes no self-attention
        cache inputs, so the self-cache arguments are accepted but never fed;
        the fresh caches the graph emits are returned instead. Hence input
        indices 1..5 map to cross-K, cross-V, positional encoding, and masks.

        Returns (logits, new self K cache, new self V cache).
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_main.run(
            None,
            {
                self.decoder_main.get_inputs()[0].name: tokens,
                self.decoder_main.get_inputs()[1].name: n_layer_cross_k_cache,
                self.decoder_main.get_inputs()[2].name: n_layer_cross_v_cache,
                self.decoder_main.get_inputs()[3].name: pe,
                self.decoder_main.get_inputs()[4].name: self_attn_mask,
                self.decoder_main.get_inputs()[5].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
239
+
240
    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the "loop" decoder model for steps after the first.

        Unlike :meth:`decode_main_one_token`, this graph consumes the running
        self-attention K/V caches (input indices 1 and 2) in addition to the
        cross-attention K/V, positional-encoding slice, and the two masks.

        Returns (logits, updated self K cache, updated self V cache).
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
            None,
            {
                self.decoder_loop.get_inputs()[0].name: tokens,
                self.decoder_loop.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder_loop.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder_loop.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder_loop.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder_loop.get_inputs()[5].name: pe,
                self.decoder_loop.get_inputs()[6].name: self_attn_mask,
                self.decoder_loop.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
279
+
280
    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest
    ):
        """Autoregressive beam-search decode over encoder outputs.

        This variant drives the *single-graph* decoder
        (:meth:`decode_one_token`) with an explicit causal mask rebuilt each
        step; the split main/loop path is disabled here.

        Args:
            n_layer_cross_k / n_layer_cross_v: per-layer cross-attention K/V,
                shape (num_layer, batch, Ti, encoder_out_dim), as numpy arrays.
            cross_attn_mask: encoder padding mask; last dim is encoder length.
            beam_size: beams kept per utterance.
            nbest: hypotheses returned per utterance (assumes nbest <= beam_size).

        Returns:
            Per utterance, a list of ``{"token_ids", "score"}`` dicts with the
            leading sos stripped and trailing eos trimmed.
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]

        # Tile encoder outputs and mask beam_size times so every beam row
        # attends over its own utterance's encoder states.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(beam_size * batch_size, -1, encoder_out_length)

        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)

        # Every beam starts from <sos>; scores start at [0, -INF, ...] per
        # utterance so the first expansion only grows from beam 0.
        prediction_tokens = torch.ones(
            beam_size * batch_size, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
        is_finished = torch.zeros_like(scores)

        # Initial (unused) placeholder; the loop rebuilds the mask each step.
        self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)

        results = [self.sos_id]
        for i in range(self.decode_max_len):

            # Causal mask over all positions decoded so far; only the last
            # row (the query for the newest token) is kept.
            self_attn_mask = torch.empty(
                batch_size * beam_size,
                prediction_tokens.shape[-1], prediction_tokens.shape[-1]
            ).fill_(-np.inf).triu_(1)
            self_attn_mask = self_attn_mask[:, -1:, :]
            self_attn_mask = to_numpy(self_attn_mask)

            logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_one_token(
                to_numpy(tokens),
                to_numpy(n_layer_self_k_cache),
                to_numpy(n_layer_self_v_cache),
                to_numpy(n_layer_cross_k),
                to_numpy(n_layer_cross_v),
                to_numpy(offset),
                to_numpy(self_attn_mask),
                to_numpy(cross_attn_mask)
            )

            tokens = to_numpy(tokens)
            n_layer_self_k_cache = to_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = to_numpy(n_layer_self_v_cache)
            n_layer_cross_k = to_numpy(n_layer_cross_k)
            n_layer_cross_v = to_numpy(n_layer_cross_v)
            cross_attn_mask = to_numpy(cross_attn_mask)

            offset += 1
            logits = torch.from_numpy(logits)

            # Per-beam top-B expansion, then re-rank the B*B candidates of
            # each utterance down to B surviving beams.
            logits = logits.squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            # Finished beams contribute a frozen score row and a forced <eos>.
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map candidate ids back to the parent-beam rows they came from.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, beam_size).view(batch_size * beam_size)
            stride = beam_size * torch.arange(batch_size).view(
                batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(beam_size * batch_size, 1)

            tokens = t_ys

            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)

            # Reorder the KV caches to follow the surviving beams.
            # NOTE(review): these loops shadow the outer loop variable ``i``;
            # harmless because ``for`` rebinds it each step, but confusing.
            for i, self_k_cache in enumerate(n_layer_self_k_cache):
                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]

            for i, self_v_cache in enumerate(n_layer_self_v_cache):
                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]

            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == beam_size * batch_size:
                break

        # Collect n-best hypotheses; valid length counts non-eos tokens
        # (includes the leading sos, which is stripped below via [1:...]).
        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for i in range(batch_size):
            i_best_hyps: List[Dict[str, torch.Tensor]] = []
            for j, score in enumerate(nbest_scores[i]):
                hyp = {
                    "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
                    "score": score
                }
                i_best_hyps.append(hyp)
            nbest_hyps.append(i_best_hyps)

        return nbest_hyps
445
+
446
+ def get_initialized_self_cache(self,
447
+ batch_size,
448
+ beam_size
449
+ ) -> Tuple[Tensor, Tensor]:
450
+ n_layer_self_k_cache = torch.zeros(
451
+ self.num_decoder_blocks,
452
+ batch_size * beam_size,
453
+ self.decode_max_len,
454
+ self.decoder_hidden_dim,
455
+ )
456
+ n_layer_self_v_cache = torch.zeros(
457
+ self.num_decoder_blocks,
458
+ batch_size * beam_size,
459
+ self.decode_max_len,
460
+ self.decoder_hidden_dim,
461
+ )
462
+ return n_layer_self_k_cache, n_layer_self_v_cache
463
+
464
+ def calc_feat_len(self, audio_dur):
465
+ import math
466
+ sample_rate = 16000
467
+ frame_length = 25 * sample_rate / 1000
468
+ frame_shift = 10 * sample_rate / 1000
469
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
470
+ return length
471
+
472
    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> None:
        """Extract features, run the axmodel encoder, and dump its outputs.

        This variant only exercises the encoder: its three outputs are saved
        as .npy files under encoder_output/<wav basename>/. The beam-search
        decoding and result formatting are kept disabled below, so nothing is
        returned. `beam_size` / `nbest` are accepted for signature parity but
        unused here.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        # Pad or trim the feature matrix to exactly 10 s worth of frames —
        # the compiled encoder expects a fixed input shape.
        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
        feats = feats[:, :maxlen, :]

        # Outputs are grouped per input wav by basename.
        encoder_data_path = os.path.join("encoder_output", os.path.basename(batch_wav_path[0]))
        # decoder_data_path = os.path.join("calib_dataset", "decoder", os.path.basename(batch_wav_path[0]))
        os.makedirs(encoder_data_path, exist_ok=True)
        # os.makedirs(decoder_data_path, exist_ok=True)

        feats = to_numpy(feats)
        lengths = to_numpy(lengths)

        start_time = time.time()
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
            to_numpy(feats),
            to_numpy(lengths)
        )

        # Persist encoder outputs for offline inspection / calibration.
        for name, npy in zip(["n_layer_cross_k", "n_layer_cross_v", "cross_attn_mask"], [n_layer_cross_k, n_layer_cross_v, cross_attn_mask]):
            file_path = os.path.join(encoder_data_path, name + ".npy")
            np.save(file_path, npy)

        # (Decoding, RTF accounting and hypothesis formatting are intentionally
        # disabled in this export variant; test_onnx_model.py has the full
        # pipeline.)
528
+
529
+
530
def parse_args():
    """Build and parse command-line options for the axmodel test driver."""
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    # (flag, default, help) for every plain string option, in display order.
    string_options = [
        ("--encoder", "axmodel/encoder.axmodel", "Path to onnx encoder"),
        ("--decoder", "onnx_decoder/decoder.onnx", "Path to onnx decoder"),
        ("--cmvn", "axmodel/cmvn.ark", "Path to cmvn"),
        ("--dict", "axmodel/dict.txt", "Path to dict"),
        ("--spm_model", "axmodel/train_bpe1000.model", "Path to spm model"),
        ("--wavlist", "wavlist.txt", "File to wav path list"),
        ("--hypo", "hypo_axmodel.txt", "File of hypos"),
    ]
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument("--beam_size", type=int, default=3, help="")
    parser.add_argument("--nbest", type=int, default=1, help="")
    return parser.parse_args()
588
+
589
+
590
def parse_wavlist(wavlist: str):
    """Read wav paths (one per line) from `wavlist`, keeping only existing files.

    Blank lines are now skipped silently — previously a blank line fell
    through to `os.path.exists("")` and printed a confusing
    `" doesn't exist."` warning. Paths that do not exist are reported and
    dropped.

    Returns:
        list[str]: the existing paths, in file order.
    """
    wavpaths = []
    with open(wavlist) as f:
        for line in f:
            path = line.strip()
            if not path:
                continue  # ignore empty lines instead of warning about ""
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)
    return wavpaths
601
+
602
+
603
def main():
    """Driver: run the axmodel encoder over every wav in --wavlist.

    This export variant only exercises `transcribe` (which dumps encoder
    outputs to disk); the RTF accounting and hypothesis writing from the
    ONNX test script are kept disabled.
    """
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model)

    # NOTE(review): wf is opened but nothing is written in this variant —
    # consider a `with` block or dropping it once decoding stays disabled.
    wf = open(args.hypo, "wt")
    wavlist = parse_wavlist(args.wavlist)

    total_wav_durations = 0
    total_transcribe_durations = 0
    for wav in wavlist:
        batch_wav = [wav]  # one-wav batches
        onnx_model.transcribe(batch_wav, args.beam_size, args.nbest)

        # (Per-file duration/RTF logging and hypothesis writing are
        # intentionally disabled here; see test_onnx_model.py.)

    wf.close()


if __name__ == "__main__":
    main()
test_onnx_model.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fireredasr.data.asr_feat import ASRFeatExtractor
2
+ from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
+
4
+ import onnxruntime as ort
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from torch import Tensor
9
+ from typing import Tuple, List, Dict
10
+ import argparse
11
+ import os
12
+ import time
13
+ import logging
14
+
15
# Module-level logging: root logger at INFO with a stream handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Guard against piling up duplicate handlers if this module is re-imported
# (also fixes the "hander" typo in the handler variable name).
if not logger.handlers:
    logger_stream_handler = logging.StreamHandler()
    logger_stream_handler.setLevel("INFO")
    logger.addHandler(logger_stream_handler)
20
+
21
+
22
# Large finite stand-in for infinity, used to mask out beam-search scores.
INF = 1e10
23
+
24
+
25
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; NumPy arrays pass through unchanged."""
    if isinstance(tensor, np.ndarray):
        return tensor
    detached = tensor.detach() if tensor.requires_grad else tensor
    return detached.cpu().numpy()
32
+
33
+
34
def set_finished_beam_score_to_zero(scores, is_finished):
    """For finished beams, replace the top-k continuation scores with
    [0, -INF, ..., -INF] so exactly one (EOS) continuation survives pruning;
    live beams keep their scores unchanged."""
    nb, b = scores.size()
    finished = is_finished.float()
    frozen_row = torch.full((1, b), -INF)
    frozen_row[0, 0] = 0.0
    frozen = frozen_row.expand(nb, b)
    return scores * (1 - finished) + frozen * finished
40
+
41
+
42
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Force the next token of every finished beam to EOS; leave live beams as-is."""
    finished = is_finished.long()
    live = 1 - finished
    return ys * live + eos_id * finished
45
+
46
+
47
+ class FireRedASROnnxModel:
48
    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        providers=["CPUExecutionProvider"]  # NOTE(review): mutable default — safe only while never mutated; consider a None sentinel.
    ):
        """Load feature extractor, tokenizer, and all ONNX Runtime sessions.

        Args:
            encoder_path: path to the encoder ONNX model.
            decoder_path: path to the monolithic decoder ONNX model; the
                split decoder_main.onnx / decoder_loop.onnx and pe.npy are
                looked up in the same directory.
            cmvn_file: CMVN stats for the feature extractor.
            dict_file / spm_model_path: tokenizer resources.
            providers: ONNX Runtime execution providers.
        """
        session_opts = ort.SessionOptions()
        # Single-threaded sessions for deterministic, comparable timings.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        # session_opts.log_severity_level = 1
        self.session_opts = session_opts

        # NOTE: maximum decode length follows Whisper's setting.
        # FireRedASR-AED supports input audio up to 60 s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        # Decoder geometry and special token ids (must match the exported model).
        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        self.encoder = None
        self.decoder = None

        self.init_encoder(encoder_path, providers)
        self.init_decoder(decoder_path, providers)
        self.init_decoder_main(decoder_path, providers)
        self.init_decoder_loop(decoder_path, providers)
        self.pe = self.init_pe(decoder_path)
85
+
86
    def init_encoder(self, encoder_path, providers=None):
        """Create the ONNX Runtime session for the encoder and log load time."""
        start_time = time.time()
        self.encoder = ort.InferenceSession(
            encoder_path,
            sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load encoder cost {end_time - start_time} seconds")
95
+
96
    def init_decoder(self, decoder_path, providers=None):
        """Create the ONNX Runtime session for the monolithic decoder and log load time."""
        start_time = time.time()
        self.decoder = ort.InferenceSession(
            decoder_path,
            sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load decoder cost {end_time - start_time} seconds")
105
+
106
+ def init_decoder_main(self, decoder_path, providers=None):
107
+ decoder_path = os.path.dirname(decoder_path)
108
+ decoder_path = os.path.join(decoder_path, "decoder_main.onnx")
109
+ start_time = time.time()
110
+ self.decoder_main = ort.InferenceSession(
111
+ decoder_path,
112
+ sess_options=self.session_opts,
113
+ providers=providers
114
+ )
115
+ end_time = time.time()
116
+ logger.info(f"load decoder_main cost {end_time - start_time} seconds")
117
+
118
+ input_names = [i.name for i in self.decoder_main.get_inputs()]
119
+ print(f"decoder_main.input_names: {input_names}")
120
+
121
+ def init_decoder_loop(self, decoder_path, providers=None):
122
+ decoder_path = os.path.dirname(decoder_path)
123
+ decoder_path = os.path.join(decoder_path, "decoder_loop.onnx")
124
+
125
+ start_time = time.time()
126
+ self.decoder_loop = ort.InferenceSession(
127
+ decoder_path,
128
+ sess_options=self.session_opts,
129
+ providers=providers
130
+ )
131
+ end_time = time.time()
132
+ logger.info(f"load decoder_loop cost {end_time - start_time} seconds")
133
+
134
+ input_names = [i.name for i in self.decoder_loop.get_inputs()]
135
+ print(f"decoder_loop.input_names: {input_names}")
136
+
137
+ def init_pe(self, decoder_path):
138
+ decoder_path = os.path.dirname(decoder_path)
139
+ decoder_path = os.path.join(decoder_path, "pe.npy")
140
+
141
+ return np.load(decoder_path)
142
+
143
    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the ONNX encoder once.

        Args:
            input: feature batch fed to the first encoder input.
            input_length: valid-length batch fed to the second encoder input.

        Returns:
            (n_layer_cross_k, n_layer_cross_v, cross_attn_mask) — per-layer
            cross-attention K/V stacks plus the encoder attention mask, all as
            NumPy arrays straight from ONNX Runtime.
        """
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
            None,
            {
                self.encoder.get_inputs()[0].name: input,
                self.encoder.get_inputs()[1].name: input_length
            }
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )
158
+
159
    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one autoregressive step through the monolithic ONNX decoder.

        Inputs are bound positionally to the session's declared input order,
        so the argument order here must match the exported model. Returns
        (logits, updated self-K cache, updated self-V cache) as NumPy arrays.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run(
            None,
            {
                self.decoder.get_inputs()[0].name: tokens,
                self.decoder.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder.get_inputs()[5].name: offset,
                self.decoder.get_inputs()[6].name: self_attn_mask,
                self.decoder.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
199
+
200
    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """First-step decode through decoder_main.onnx.

        decoder_main takes no incoming self-attention caches — the two cache
        parameters exist only for signature parity with
        decode_loop_one_token and are NOT fed to the session; the model
        produces the initial caches itself. Returns (logits, new self-K
        cache, new self-V cache) as NumPy arrays.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_main.run(
            None,
            {
                self.decoder_main.get_inputs()[0].name: tokens,
                self.decoder_main.get_inputs()[1].name: n_layer_cross_k_cache,
                self.decoder_main.get_inputs()[2].name: n_layer_cross_v_cache,
                self.decoder_main.get_inputs()[3].name: pe,
                self.decoder_main.get_inputs()[4].name: self_attn_mask,
                self.decoder_main.get_inputs()[5].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
239
+
240
    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """One cached autoregressive step through decoder_loop.onnx.

        Unlike decoder_main, this model consumes the running self-attention
        K/V caches together with the positional-encoding slice `pe` for the
        current offset. Inputs are bound positionally to the session's
        declared input order. Returns (logits, updated self-K cache, updated
        self-V cache) as NumPy arrays.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder_loop.run(
            None,
            {
                self.decoder_loop.get_inputs()[0].name: tokens,
                self.decoder_loop.get_inputs()[1].name: n_layer_self_k_cache,
                self.decoder_loop.get_inputs()[2].name: n_layer_self_v_cache,
                self.decoder_loop.get_inputs()[3].name: n_layer_cross_k_cache,
                self.decoder_loop.get_inputs()[4].name: n_layer_cross_v_cache,
                self.decoder_loop.get_inputs()[5].name: pe,
                self.decoder_loop.get_inputs()[6].name: self_attn_mask,
                self.decoder_loop.get_inputs()[7].name: cross_attn_mask,
            }
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )
279
+
280
    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest,
        decoder_data_path
    ):
        """Batched beam search driven by the decoder_loop ONNX session.

        Args:
            n_layer_cross_k / n_layer_cross_v: stacked per-layer encoder K/V,
                NumPy, shape (num_layer, batch, Ti, encoder_out_dim).
            cross_attn_mask: additive encoder attention mask, NumPy.
            beam_size: beam width per utterance.
            nbest: number of hypotheses returned per utterance.
            decoder_data_path: directory for calibration dumps — currently
                unused; the dump code below is disabled.

        Returns:
            list (one entry per utterance) of `nbest` dicts with keys
            "token_ids" (leading SOS stripped, EOS tail excluded) and "score".
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]

        # Tile encoder outputs and mask across the beam dimension so each
        # beam row has its own copy: (..., batch, ...) -> (..., beam*batch, ...).
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(beam_size * batch_size, -1, encoder_out_length)

        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)

        # Every beam starts from SOS. Only the first beam of each utterance
        # carries score 0; the others get -INF so step 1 expands one beam.
        prediction_tokens = torch.ones(
            beam_size * batch_size, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(batch_size * beam_size, 1)
        is_finished = torch.zeros_like(scores)

        # With KV caching only the newest position is decoded each step, so a
        # (beam*batch, 1, 1) all-zero additive mask suffices.
        self_attn_mask = np.zeros((batch_size * beam_size, 1, 1), dtype=np.float32)

        # NOTE(review): `results` appears unused below — confirm before removing.
        results = [self.sos_id]
        for i in range(self.decode_max_len):
            # (Disabled alternatives kept out of the way: the monolithic
            # decoder call, per-step calibration npy dumps, and the
            # decoder_main/decoder_loop split where step 0 used decoder_main.)
            logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
                to_numpy(tokens),
                to_numpy(n_layer_self_k_cache),
                to_numpy(n_layer_self_v_cache),
                to_numpy(n_layer_cross_k),
                to_numpy(n_layer_cross_v),
                self.pe[offset],
                self_attn_mask,
                to_numpy(cross_attn_mask)
            )

            offset += 1
            logits = torch.from_numpy(logits)

            logits = logits.squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            # Finished beams contribute exactly one zero-score EOS continuation.
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            # Prune beam*beam candidates back down to beam per utterance.
            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map each surviving candidate back to the beam row it came from.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, beam_size).view(batch_size * beam_size)
            stride = beam_size * torch.arange(batch_size).view(
                batch_size, 1).repeat(1, beam_size).view(batch_size * beam_size)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            # Re-gather token histories for the surviving beams, then append
            # the newly chosen tokens.
            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(beam_size * batch_size, 1)

            tokens = t_ys

            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            # Reorder the self-attention caches to follow the surviving beams.
            # NOTE(review): these inner loops shadow the outer loop index `i`;
            # harmless since `i` is not read afterwards, but confusing.
            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)

            for i, self_k_cache in enumerate(n_layer_self_k_cache):
                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]

            for i, self_v_cache in enumerate(n_layer_self_v_cache):
                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]

            # Stop early once every beam of every utterance emitted EOS.
            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == beam_size * batch_size:
                break

        # Valid length = count of non-EOS tokens (this includes leading SOS,
        # which is stripped when hypotheses are sliced below).
        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        nbest_prediction_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            batch_size * beam_size)[index.view(-1)].view(batch_size, -1)
        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for i in range(batch_size):
            i_best_hyps: List[Dict[str, torch.Tensor]] = []
            for j, score in enumerate(nbest_scores[i]):
                hyp = {
                    "token_ids": nbest_prediction_tokens[i, j, 1:nbest_prediction_valid_token_lengths[i, j]],
                    "score": score
                }
                i_best_hyps.append(hyp)
            nbest_hyps.append(i_best_hyps)

        return nbest_hyps
466
+
467
+ def get_initialized_self_cache(self,
468
+ batch_size,
469
+ beam_size
470
+ ) -> Tuple[Tensor, Tensor]:
471
+ n_layer_self_k_cache = torch.zeros(
472
+ self.num_decoder_blocks,
473
+ batch_size * beam_size,
474
+ self.decode_max_len,
475
+ self.decoder_hidden_dim,
476
+ )
477
+ n_layer_self_v_cache = torch.zeros(
478
+ self.num_decoder_blocks,
479
+ batch_size * beam_size,
480
+ self.decode_max_len,
481
+ self.decoder_hidden_dim,
482
+ )
483
+ return n_layer_self_k_cache, n_layer_self_v_cache
484
+
485
+ def calc_feat_len(self, audio_dur):
486
+ import math
487
+ sample_rate = 16000
488
+ frame_length = 25 * sample_rate / 1000
489
+ frame_shift = 10 * sample_rate / 1000
490
+ length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
491
+ return length
492
+
493
    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> Tuple[List[Dict], List, float]:
        """Run the full ONNX pipeline (features -> encoder -> beam search).

        Returns:
            (results, wav_durations, transcribe_durations):
            results is one dict per wav with "wav", "text" and "score" for the
            single best hypothesis; transcribe_durations is wall-clock seconds
            for encoder + decoder.
            # NOTE(review): wav_durations is whatever the feature extractor
            # returns — presumably per-wav durations in seconds (main() sums
            # it); confirm against ASRFeatExtractor.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        # Pad or trim the feature matrix to exactly 10 s worth of frames —
        # the exported model expects a fixed input shape.
        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
        feats = feats[:, :maxlen, :]

        # Calibration-dump paths; directory creation is currently disabled.
        # encoder_data_path = os.path.join("calib_dataset", "encoder", os.path.basename(batch_wav_path[0]))
        decoder_data_path = os.path.join("calib_dataset", "decoder", os.path.basename(batch_wav_path[0]))
        # os.makedirs(encoder_data_path, exist_ok=True)
        # os.makedirs(decoder_data_path, exist_ok=True)

        feats = to_numpy(feats)
        lengths = to_numpy(lengths)

        start_time = time.time()
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
            to_numpy(feats),
            to_numpy(lengths)
        )
        nbest_hyps = self.run_decoder(n_layer_cross_k,
                                      n_layer_cross_v,
                                      cross_attn_mask,
                                      beam_size,
                                      nbest,
                                      decoder_data_path)
        transcribe_durations = time.time() - start_time
        results: List[Dict] = []
        for wav, hyp in zip(batch_wav_path, nbest_hyps):
            hyp = hyp[0]  # keep only the best hypothesis per utterance
            hyp_ids = [int(id) for id in hyp["token_ids"].cpu()]
            score = hyp["score"].item()
            text = self.tokenizer.detokenize(hyp_ids)
            results.append(
                {
                    "wav": wav,
                    "text": text,
                    "score": score
                }
            )

        return results, wav_durations, transcribe_durations
544
+
545
+
546
def parse_args():
    """Build and parse command-line options for the ONNX test driver."""
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    # (flag, default, help) for every plain string option, in display order.
    string_options = [
        ("--encoder", "onnx_encoder/encoder.onnx", "Path to onnx encoder"),
        ("--decoder", "onnx_decoder/decoder.onnx", "Path to onnx decoder"),
        ("--cmvn", "axmodel/cmvn.ark", "Path to cmvn"),
        ("--dict", "axmodel/dict.txt", "Path to dict"),
        ("--spm_model", "axmodel/train_bpe1000.model", "Path to spm model"),
        ("--wavlist", "wavlist.txt", "File to wav path list"),
        ("--hypo", "hypo_onnx.txt", "File of hypos"),
    ]
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument("--beam_size", type=int, default=3, help="")
    parser.add_argument("--nbest", type=int, default=1, help="")
    parser.add_argument(
        "--provider",
        default="CPUExecutionProvider",
        choices=['CUDAExecutionProvider', 'CPUExecutionProvider']
    )
    return parser.parse_args()
609
+
610
+
611
def parse_wavlist(wavlist: str):
    """Read wav paths (one per line) from `wavlist`, keeping only existing files.

    Blank lines are now skipped silently — previously a blank line fell
    through to `os.path.exists("")` and printed a confusing
    `" doesn't exist."` warning. Paths that do not exist are reported and
    dropped.

    Returns:
        list[str]: the existing paths, in file order.
    """
    wavpaths = []
    with open(wavlist) as f:
        for line in f:
            path = line.strip()
            if not path:
                continue  # ignore empty lines instead of warning about ""
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)
    return wavpaths
622
+
623
+
624
def main():
    """Transcribe every wav in --wavlist with the ONNX pipeline.

    Per file: logs duration, transcription time, and real-time factor (RTF),
    and appends "text (wav)" lines to --hypo. Finally logs the average RTF
    over the whole list.
    """
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model,
                                     [args.provider])

    # NOTE(review): consider `with open(...)` so the file is closed on error.
    wf = open(args.hypo, "wt")
    wavlist = parse_wavlist(args.wavlist)

    total_wav_durations = 0
    total_transcribe_durations = 0
    for wav in wavlist:
        batch_wav = [wav]  # one-wav batches
        results, wav_durations, transcribe_durations = onnx_model.transcribe(
            batch_wav, args.beam_size, args.nbest)

        wav_durations = sum(wav_durations)
        total_wav_durations += wav_durations
        total_transcribe_durations += transcribe_durations
        logger.info(f"{batch_wav}")
        logger.info(f"Durations: {wav_durations}")
        logger.info(f"Transcribe Durations: {transcribe_durations}")
        # RTF < 1 means faster than real time.
        rtf = transcribe_durations / wav_durations
        logger.info(f"(Real time factor) RTF: {rtf}")
        for result in results:
            logger.info(f"wav: {result['wav']}")
            logger.info(f"text: {result['text']}")
            logger.info(f"score: {result['score']}")
            logger.info("")
            wf.write(f"{result['text']} ({result['wav']})\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    avg_ref = total_transcribe_durations / total_wav_durations
    logger.info(f"AVG RTF: {avg_ref}")

    wf.close()

    # (Removed disabled helper code that tarred the calibration npy dumps.)


if __name__ == "__main__":
    main()
wavlist.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
2
+ wav/TEST_MEETING_T0000000001_S00000.wav
3
+ wav/IT0011W0001.wav
4
+ wav/BAC009S0764W0121.wav