inoryQwQ committed on
Commit
3399058
·
1 Parent(s): 40049a2

Delete unnecessary files

Browse files
model_convert/model_wrapper.py DELETED
@@ -1,431 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch import Tensor
4
-
5
- from fireredasr.models.module.conformer_encoder import ConformerEncoder
6
- from fireredasr.models.module.transformer_decoder import (
7
- TransformerDecoder,
8
- DecoderLayer,
9
- DecoderMultiHeadAttention,
10
- DecoderScaledDotProductAttention,
11
- PositionalEncoding
12
- )
13
-
14
-
15
def DecoderScaledDotProductAttentionForward(
    self: DecoderScaledDotProductAttention,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor
):
    """Scaled dot-product attention that takes an additive mask.

    Scores are ``q @ k^T`` (axes 2 and 3 of ``k`` swapped) divided by
    ``self.temperature``. When ``mask`` is not ``None`` it is added to the
    scores before the softmax, so it is expected to hold ``0`` for visible
    positions and ``-inf`` for hidden ones.

    Returns:
        The softmax-weighted combination of ``v``.
    """
    scores = torch.matmul(q, k.transpose(2, 3)) / self.temperature
    if mask is not None:
        # Additive mask, e.g. [[[0, 0, 0, 0, ..., -inf, -inf]]]
        scores = scores + mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
31
-
32
- DecoderScaledDotProductAttention.forward = DecoderScaledDotProductAttentionForward
33
-
34
-
35
- """
36
- The purpose of this is to allow the exported onnx model
37
- to only need to pass in the token id of the decoding result
38
- of the previous time step when performing decoding inference at each time step,
39
- rather than the token id of all previous time steps.
40
- """
41
def PositionalEncodingForward(
    self: PositionalEncoding,
    offset: Tensor
):
    """Return the positional-encoding entry at index ``offset - 1``.

    The table ``self.pe`` is cut to its first ``offset`` positions along
    axis 1, copied and detached, and only the last remaining position is
    returned. This lets a decoding step receive just the encoding for the
    current position instead of the encodings for all previous steps.
    """
    prefix = self.pe[:, :offset]
    snapshot = prefix.clone().detach()
    return snapshot[:, -1]
46
-
47
- PositionalEncoding.forward = PositionalEncodingForward
48
-
49
-
50
- """
51
- NOTE(Lianghu): Why do that?
52
-
53
- When exporting the onnx model using the original padding_position_is_0 function,
54
- the dynamic batch does not work properly for the exported onnx model.
55
-
56
- The code in the original padding_position_is_0 function is as follows:
57
- ```py
58
- def padding_position_is_0(...):
59
- N, T = padded_input.size()[:2]
60
- mask = torch.ones((N, T)).to(padded_input.device)
61
- ...
62
- ```
63
-
64
- Because when exporting onnx, N and T are considered constants.
65
- Should be N = padded_input.size(0) and T = padded_input.size(1).
66
- """
67
def padding_position_is_0(self: ConformerEncoder,
                          padded_input: Tensor,
                          input_lengths: Tensor):
    """Build a ``uint8`` mask that is 1 on valid frames and 0 on padding.

    The time extent is read with ``padded_input.size(1)`` rather than by
    unpacking ``size()``, which the accompanying note in this file says is
    needed for the exported ONNX model to keep dynamic sizes.

    Args:
        padded_input: Tensor whose axis 1 is the (padded) time axis.
        input_lengths: One valid length per batch entry.

    Returns:
        Tensor of shape ``(N, 1, T)`` and dtype ``torch.uint8``.
    """
    num_frames = padded_input.size(1)
    frame_index = torch.arange(num_frames, device=padded_input.device)
    # (1, T) < (N, 1) broadcasts to (N, T)
    valid = frame_index.unsqueeze(0) < input_lengths.unsqueeze(1)
    return valid.unsqueeze(dim=1).to(torch.uint8)
77
-
78
-
79
- ConformerEncoder.padding_position_is_0 = padding_position_is_0
80
-
81
class AudioEncoderTensorCache(nn.Module):
    """Encoder wrapper that also precomputes per-layer cross-attention K/V.

    ``forward`` runs the encoder once, projects its output through every
    decoder layer's ``cross_attn.w_ks`` / ``cross_attn.w_vs``, and converts
    the encoder mask into an additive float mask.
    """

    def __init__(self,
                 encoder: ConformerEncoder,
                 decoder: TransformerDecoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input: Tensor, input_length: Tensor):
        """Encode ``input`` and return ``(cross_k, cross_v, additive_mask)``.

        ``cross_k`` and ``cross_v`` are stacked along a new leading axis with
        one entry per decoder layer. In the returned mask, entries that were
        0 become ``-inf`` and entries that were 1 become ``0.0``.
        """
        encoder_output, _, encoder_mask = self.encoder(input, input_length)

        # One (key, value) projection pair per decoder layer.
        projections = [
            (layer.cross_attn.w_ks(encoder_output),
             layer.cross_attn.w_vs(encoder_output))
            for layer in self.decoder.layer_stack
        ]
        stacked_k = torch.stack([key for key, _ in projections])
        stacked_v = torch.stack([value for _, value in projections])

        additive_mask = encoder_mask.to(torch.float32)
        additive_mask = additive_mask.masked_fill(additive_mask == 0, -torch.inf)
        additive_mask = additive_mask.masked_fill(additive_mask == 1, 0.0)

        return stacked_k, stacked_v, additive_mask
107
-
108
-
109
class DecoderMultiHeadSelfAttention(nn.Module):
    """Self-attention wrapper that writes new K/V into caller-owned caches.

    The projections for the incoming step(s) are written in place into the
    trailing positions of ``k_cache`` / ``v_cache``; attention then runs over
    the whole cache. ``loop`` is stored but not read by ``forward``.
    """

    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        self.loop = loop

    def forward(self,
                x: Tensor,
                k_cache: Tensor,
                v_cache: Tensor,
                mask: Tensor):
        """Attend ``x`` over the updated cache.

        Returns:
            ``(output, k_cache, v_cache)`` where both caches are the same
            tensor objects that were passed in, modified in place.
        """
        attn = self.multiHeadSelfAttention
        bs = x.size(0)

        # The current time step is t; k_cache and v_cache hold the cached
        # self-attention keys and values for time steps [0: t-1].
        q = attn.w_qs(x)
        new_k = attn.w_ks(x)
        new_v = attn.w_vs(x)

        # Overwrite the trailing slots of the caches with the new projections.
        k_cache[:, -new_k.shape[1]:, :] = new_k
        v_cache[:, -new_v.shape[1]:, :] = new_v

        head_shape = (bs, -1, attn.n_head, attn.d_k)
        q = q.view(*head_shape).transpose(1, 2)
        k = k_cache.view(*head_shape).transpose(1, 2)
        v = v_cache.view(*head_shape).transpose(1, 2)

        if mask is not None:
            # Insert a head axis so the mask broadcasts across heads.
            mask = mask.unsqueeze(1)

        output = attn.attention(q, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, attn.d_model)
        output = attn.dropout(attn.fc(output))

        return output, k_cache, v_cache
155
-
156
-
157
class DecoderMultiHeadSelfAttentionV2(nn.Module):
    """Self-attention wrapper that returns a new cache instead of mutating.

    With ``loop=True`` the oldest cache position is dropped and the new K/V
    projections are appended, so the cache length stays constant. With
    ``loop=False`` the cache inputs are ignored and the fresh projections
    become the cache.
    """

    def __init__(self, multiHeadSelfAttention: DecoderMultiHeadAttention, loop: bool = False):
        super().__init__()
        self.multiHeadSelfAttention = multiHeadSelfAttention
        self.loop = loop

    def forward(self,
                x: Tensor,
                k_cache: Tensor,
                v_cache: Tensor,
                mask: Tensor):
        """Attend ``x`` over the (rolled or replaced) cache.

        Returns:
            ``(output, k_cache, v_cache)`` with the updated caches.
        """
        attn = self.multiHeadSelfAttention
        bs = x.size(0)

        # The current time step is t; k_cache and v_cache hold the cached
        # self-attention keys and values for time steps [0: t-1].
        q = attn.w_qs(x)
        new_k = attn.w_ks(x)
        new_v = attn.w_vs(x)

        if self.loop:
            # Sliding window: drop the first slot, append the new projections.
            k_cache = torch.cat([k_cache[:, 1:, :], new_k], 1)
            v_cache = torch.cat([v_cache[:, 1:, :], new_v], 1)
        else:
            k_cache = new_k
            v_cache = new_v

        head_shape = (bs, -1, attn.n_head, attn.d_k)
        q = q.view(*head_shape).transpose(1, 2)
        k = k_cache.view(*head_shape).transpose(1, 2)
        v = v_cache.view(*head_shape).transpose(1, 2)

        if mask is not None:
            # Insert a head axis so the mask broadcasts across heads.
            mask = mask.unsqueeze(1)

        output = attn.attention(q, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, attn.d_model)
        output = attn.dropout(attn.fc(output))

        return output, k_cache, v_cache
203
-
204
-
205
class DecoderMultiHeadCrossAttention(nn.Module):
    """Cross-attention wrapper that consumes already-projected keys/values.

    Only the query projection (``w_qs``) is applied here; ``k`` and ``v`` are
    reshaped into heads as given.
    """

    def __init__(self, multiHeadCrossAttention: DecoderMultiHeadAttention):
        super().__init__()
        self.multiHeadCrossAttention = multiHeadCrossAttention

    def forward(self,
                x: Tensor,
                k: Tensor,
                v: Tensor,
                mask: Tensor):
        """Attend the projected ``x`` over ``k`` / ``v`` and return the output."""
        attn = self.multiHeadCrossAttention
        bs = x.size(0)
        head_shape = (bs, -1, attn.n_head, attn.d_k)

        q = attn.w_qs(x).view(*head_shape).transpose(1, 2)
        k = k.view(*head_shape).transpose(1, 2)
        v = v.view(*head_shape).transpose(1, 2)

        if mask is not None:
            # Insert a head axis so the mask broadcasts across heads.
            mask = mask.unsqueeze(1)

        output = attn.attention(q, k, v, mask)
        output = output.transpose(1, 2).contiguous().view(bs, -1, attn.d_model)
        return attn.dropout(attn.fc(output))
234
-
235
-
236
class ResidualAttentionBlockTensorCache(nn.Module):
    """One decoder layer rebuilt around the cache-aware attention wrappers.

    Norms and the MLP are taken from the wrapped ``decoder_layer``; its
    ``self_attn`` / ``cross_attn`` are wrapped in
    ``DecoderMultiHeadSelfAttention`` / ``DecoderMultiHeadCrossAttention``.
    """

    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__()
        self.original_decoder_layer = decoder_layer
        self.self_attn = DecoderMultiHeadSelfAttention(decoder_layer.self_attn, loop)
        self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)

    def forward(self,
                x: Tensor,
                self_k_cache: Tensor,
                self_v_cache: Tensor,
                cross_k: Tensor,
                cross_v: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """Run self-attention, cross-attention and the MLP, each with a residual.

        Returns:
            ``(x, self_k_cache, self_v_cache)`` with the caches as returned by
            the self-attention wrapper.
        """
        layer = self.original_decoder_layer

        # Original note: q.shape (B, 1, dim)
        attn_out, k_updated, v_updated = self.self_attn(
            layer.self_attn_norm(x), self_k_cache, self_v_cache, self_attn_mask)
        x = x + attn_out

        x = x + self.cross_attn(
            layer.cross_attn_norm(x), cross_k, cross_v, cross_attn_mask)

        x = x + layer.mlp(layer.mlp_norm(x))

        return x, k_updated, v_updated
266
-
267
-
268
class ResidualAttentionBlockTensorCacheV2(nn.Module):
    """One decoder layer rebuilt around the V2 self-attention wrapper.

    Identical in structure to the non-V2 block, except that self-attention
    goes through ``DecoderMultiHeadSelfAttentionV2``.
    """

    def __init__(self, decoder_layer: DecoderLayer, loop: bool = False):
        super().__init__()
        self.original_decoder_layer = decoder_layer
        self.self_attn = DecoderMultiHeadSelfAttentionV2(decoder_layer.self_attn, loop)
        self.cross_attn = DecoderMultiHeadCrossAttention(decoder_layer.cross_attn)

    def forward(self,
                x: Tensor,
                self_k_cache: Tensor,
                self_v_cache: Tensor,
                cross_k: Tensor,
                cross_v: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """Run self-attention, cross-attention and the MLP, each with a residual.

        Returns:
            ``(x, self_k_cache, self_v_cache)`` with the caches as returned by
            the self-attention wrapper.
        """
        layer = self.original_decoder_layer

        # Original note: q.shape (B, 1, dim)
        attn_out, k_updated, v_updated = self.self_attn(
            layer.self_attn_norm(x), self_k_cache, self_v_cache, self_attn_mask)
        x = x + attn_out

        x = x + self.cross_attn(
            layer.cross_attn_norm(x), cross_k, cross_v, cross_attn_mask)

        x = x + layer.mlp(layer.mlp_norm(x))

        return x, k_updated, v_updated
298
-
299
-
300
class TextDecoderTensorCache(nn.Module):
    """Decoder wrapper whose self-attention caches are updated in place.

    Each layer of ``decoder.layer_stack`` is wrapped in a
    ``ResidualAttentionBlockTensorCache``. The wrappers are kept in a plain
    Python list (``self.blocks``), not an ``nn.ModuleList``.
    """

    def __init__(self, decoder: TransformerDecoder):
        super().__init__()
        self.decoder = decoder

        self.blocks = [
            ResidualAttentionBlockTensorCache(original_layer)
            for original_layer in self.decoder.layer_stack
        ]

    def forward(self,
                tokens: Tensor,
                n_layer_self_k_cache: Tensor,
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor,
                n_layer_cross_v: Tensor,
                offset: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """Decode ``tokens`` and return ``(logits, k_cache, v_cache)``.

        For every layer ``i`` the first ``offset[0] + tokens.shape[-1]``
        positions of the layer's cache are handed to the block and then
        written back, so the returned caches are the input tensors updated
        in place.

        TODO(Lianghu): Integrate self_attn_mask into the model inference process
        instead of passing it in through an external interface.
        """
        embedded = self.decoder.tgt_word_emb(tokens) * self.decoder.scale
        x = self.decoder.dropout(
            embedded + self.decoder.positional_encoding(offset + 1)
        )

        # Number of cache positions in use once the current tokens are added.
        used = offset[0] + tokens.shape[-1]
        for i, block in enumerate(self.blocks):
            x, layer_k, layer_v = block(
                x,
                n_layer_self_k_cache[i, :, :used, :],
                n_layer_self_v_cache[i, :, :used, :],
                n_layer_cross_k[i],
                n_layer_cross_v[i],
                self_attn_mask,
                cross_attn_mask
            )
            n_layer_self_k_cache[i, :, :used, :] = layer_k
            n_layer_self_v_cache[i, :, :used, :] = layer_v

        logits = self.decoder.tgt_word_prj(self.decoder.layer_norm_out(x))

        return logits, n_layer_self_k_cache, n_layer_self_v_cache
349
-
350
-
351
class TextDecoderTensorCacheV2(nn.Module):
    """Decoder wrapper that returns freshly assembled self-attention caches.

    Each layer of ``decoder.layer_stack`` is wrapped in a
    ``ResidualAttentionBlockTensorCacheV2``. The positional embedding is
    supplied by the caller rather than looked up inside the model.
    """

    def __init__(self, decoder: TransformerDecoder, loop: bool = False):
        super().__init__()
        self.decoder = decoder
        self.loop = loop

        self.blocks = [
            ResidualAttentionBlockTensorCacheV2(original_layer, loop)
            for original_layer in self.decoder.layer_stack
        ]

    def forward(self,
                tokens: Tensor,
                n_layer_self_k_cache: Tensor,
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor,
                n_layer_cross_v: Tensor,
                positional_embedding: Tensor,
                self_attn_mask: Tensor,
                cross_attn_mask: Tensor):
        """Decode ``tokens`` and return ``(logits, k_cache, v_cache)``.

        With ``loop=True`` each block's returned cache is used as is. With
        ``loop=False`` it is left-padded with zeros along axis 1 up to the
        length of the corresponding input cache. The per-layer results are
        concatenated along a new leading axis.

        TODO(Lianghu): Integrate self_attn_mask into the model inference process
        instead of passing it in through an external interface.
        """
        x = self.decoder.dropout(
            self.decoder.tgt_word_emb(tokens) * self.decoder.scale
            + positional_embedding
        )

        k_per_layer = []
        v_per_layer = []
        for i, block in enumerate(self.blocks):
            k_in = n_layer_self_k_cache[i, :, :, :]
            v_in = n_layer_self_v_cache[i, :, :, :]

            x, k_new, v_new = block(
                x,
                k_in,
                v_in,
                n_layer_cross_k[i],
                n_layer_cross_v[i],
                self_attn_mask,
                cross_attn_mask
            )

            if not self.loop:
                # Left-pad with zeros so the output cache has the input length.
                n_audio, n_text_ctx, ntext_state = k_in.shape
                k_pad = torch.zeros(
                    [n_audio, n_text_ctx - k_new.shape[1], ntext_state]
                ).to(k_new.device)
                v_pad = torch.zeros(
                    [n_audio, n_text_ctx - v_new.shape[1], ntext_state]
                ).to(v_new.device)
                k_new = torch.cat((k_pad, k_new), 1)
                v_new = torch.cat((v_pad, v_new), 1)

            k_per_layer.append(k_new.unsqueeze(0))
            v_per_layer.append(v_new.unsqueeze(0))

        n_layer_self_k_cache = torch.cat(k_per_layer, 0)
        n_layer_self_v_cache = torch.cat(v_per_layer, 0)

        logits = self.decoder.tgt_word_prj(self.decoder.layer_norm_out(x))

        return logits, n_layer_self_k_cache, n_layer_self_v_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_convert/to_onnx.py DELETED
@@ -1,525 +0,0 @@
1
- import model_wrapper
2
- from fireredasr.models.fireredasr import FireRedAsrAed
3
-
4
- import torch
5
- import onnx
6
- import onnxruntime
7
- from onnxruntime.quantization import QuantType, quantize_dynamic
8
- import onnxslim
9
- from onnx.external_data_helper import convert_model_to_external_data
10
- import numpy as np
11
- import math
12
- import kaldiio
13
-
14
- import os
15
- import argparse
16
- from typing import Dict, Any
17
-
18
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that require grad are detached first, since ``.numpy()`` cannot
    be called on them directly.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
23
-
24
-
25
def load_model(model_path):
    """Load a FireRedASR-AED checkpoint and return ``(model, args)``.

    The checkpoint is expected to be a mapping with ``"args"`` and
    ``"model_state_dict"`` entries. Weights are loaded with ``strict=True``.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects, so only
    load checkpoints from a trusted source.
    """
    def _keep_storage(storage, loc):
        # Return the storage as deserialized instead of remapping by ``loc``.
        return storage

    package = torch.load(model_path,
                         map_location=_keep_storage,
                         weights_only=False)
    model_args = package["args"]
    model = FireRedAsrAed.from_args(model_args)
    model.load_state_dict(package["model_state_dict"], strict=True)
    return model, model_args
32
-
33
-
34
def read_kaldi_cmvn(kaldi_cmvn_file):
    """Read Kaldi CMVN statistics and return ``(means, inverse_stddevs)``.

    The matrix loaded from ``kaldi_cmvn_file`` must have two rows. Its last
    column holds the frame count in row 0; the remaining columns hold per-dim
    sums (row 0) and sums of squares (row 1). Variances below ``1e-20`` are
    floored to that value before the inverse square root is taken.

    Returns:
        Two lists of Python floats, one entry per feature dimension.
    """
    assert os.path.exists(kaldi_cmvn_file)
    stats = kaldiio.load_mat(kaldi_cmvn_file)
    assert stats.shape[0] == 2

    dim = stats.shape[-1] - 1
    count = stats[0, dim]
    assert count >= 1

    variance_floor = 1e-20
    means = []
    inverse_stddevs = []
    for d in range(dim):
        mean = stats[0, d] / count
        variance = (stats[1, d] / count) - mean*mean
        if variance < variance_floor:
            variance = variance_floor
        means.append(mean.item())
        inverse_stddevs.append(1.0 / math.sqrt(variance))
    return means, inverse_stddevs
53
-
54
-
55
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file in place.

    All existing ``metadata_props`` entries are removed, then one entry is
    added per item of ``meta_data`` with the value converted via ``str``.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    props = model.metadata_props
    while props:
        props.pop()

    for key, value in meta_data.items():
        props.add(key=key, value=str(value))

    onnx.save(model, filename)
75
-
76
-
77
def calc_feat_len(audio_dur, sample_rate=16000, frame_length_ms=25, frame_shift_ms=10):
    """Return the number of feature frames for ``audio_dur`` seconds of audio.

    Computes ``floor((num_samples - frame_length) / frame_shift) + 1`` with
    the frame length and shift converted from milliseconds to samples. The
    defaults reproduce the previously hard-coded 16 kHz / 25 ms / 10 ms setup.

    Args:
        audio_dur: Audio duration in seconds.
        sample_rate: Samples per second.
        frame_length_ms: Analysis window length in milliseconds.
        frame_shift_ms: Hop between windows in milliseconds.

    Returns:
        The frame count, never negative: audio shorter than one window
        yields 0 instead of a negative number.
    """
    import math

    frame_length = frame_length_ms * sample_rate / 1000
    frame_shift = frame_shift_ms * sample_rate / 1000
    length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
    # A negative frame count is meaningless (e.g. as a tensor dimension).
    return max(0, length)
84
-
85
-
86
def export_encoder(fireredasr_model, args, model_args):
    """Export the encoder to ONNX, verify it, and write an int8 variant.

    Steps:
      1. Wrap the encoder in ``model_wrapper.AudioEncoderTensorCache`` and run
         it once on random input to obtain reference outputs.
      2. Export to ``<args.encoder>/encoder.onnx`` and re-save it with all
         tensors in a single external ``encoder.data`` file.
      3. Run the exported model with onnxruntime and print any mismatch
         against the reference outputs (mismatches do not abort the export).
      4. Dynamically quantize ``MatMul`` weights to int8 into
         ``<args.encoder_int8>/encoder_int8.onnx`` and attach metadata
         (model info plus CMVN statistics read from ``args.cmvn``) to that
         int8 file only.

    Args:
        fireredasr_model: Model providing ``.encoder`` and ``.decoder``.
        args: Parsed CLI arguments (``encoder``, ``encoder_int8``, ``cmvn``).
        model_args: Model hyper-parameters stored in the checkpoint.

    Returns:
        ``(n_layer_cross_k, n_layer_cross_v, cross_attn_mask)`` produced by
        the PyTorch reference run.
    """
    encoder = model_wrapper.AudioEncoderTensorCache(
        fireredasr_model.encoder,
        fireredasr_model.decoder)
    encoder.eval()

    # Dummy encoder input used for the reference run and for tracing.
    encoder_input = torch.randn(1, calc_feat_len(10), 80)
    encoder_input_lengths = torch.tensor([100], dtype=torch.int64)

    n_layer_cross_k, n_layer_cross_v, cross_attn_mask = encoder(
        encoder_input,
        encoder_input_lengths
    )

    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(args.encoder, exist_ok=True)
    onnx_encoder_file = os.path.join(args.encoder, "encoder.onnx")

    # dynamic_axes is not passed, so the export uses the dummy input shapes.
    with torch.no_grad():
        torch.onnx.export(
            encoder,
            (encoder_input, encoder_input_lengths),
            onnx_encoder_file,
            export_params=True,
            do_constant_folding=True,
            opset_version=16,
            verbose=False,
            input_names=["encoder_input",
                         "encoder_input_lengths"],
            output_names=["n_layer_cross_k",
                          "n_layer_cross_v",
                          "cross_attn_mask"],
            external_data=True
        )

    # Re-save so that every tensor lives in one external data file.
    external_filename = os.path.basename(onnx_encoder_file).split(".onnx")[0]
    external_location = f"./{external_filename}.data"
    model = onnx.load(onnx_encoder_file)
    convert_model_to_external_data(
        model,
        all_tensors_to_one_file=True,
        location=external_location,
        size_threshold=0,
        convert_attribute=False
    )
    onnx.save_model(
        model,
        onnx_encoder_file,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=external_location,
        size_threshold=0
    )

    onnx.checker.check_model(onnx_encoder_file, True)
    ort_session = onnxruntime.InferenceSession(onnx_encoder_file)
    session_inputs = ort_session.get_inputs()
    ort_inputs = {
        session_inputs[0].name: to_numpy(encoder_input),
        session_inputs[1].name: to_numpy(encoder_input_lengths),
    }
    ort_outputs = ort_session.run(None, ort_inputs)

    # Report, but do not fail on, numerical differences.
    references = (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
    for index, reference in enumerate(references):
        try:
            np.testing.assert_allclose(
                to_numpy(reference), ort_outputs[index], rtol=1e-03, atol=1e-05)
        except AssertionError as e:
            print(e)

    print("export onnx encoder done.")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
    print("Generate int8 quantization models")

    os.makedirs(args.encoder_int8, exist_ok=True)
    onnx_encoder_int8_file = os.path.join(args.encoder_int8, "encoder_int8.onnx")
    quantize_dynamic(
        model_input=onnx_encoder_file,
        model_output=onnx_encoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    cmvn_mean, cmvn_inv_stddev = read_kaldi_cmvn(args.cmvn)
    cmvn_mean = [str(m) for m in cmvn_mean]
    cmvn_inv_stddev = [str(istd) for istd in cmvn_inv_stddev]

    encoder_meta_data = {
        "model_type": "FireRedAsrAED-L",
        "maintainer": "LiangHu",
        "feat_dim": model_args.idim,
        "feat_type": "fbank",
        "num_decoder_layers": model_args.n_layers_dec,
        "num_head": model_args.n_head,
        "head_dim": model_args.d_model // model_args.n_head,
        "max_len": 448,
        "sos": model_args.sos_id,
        "eos": model_args.eos_id,
        "cmvn_mean": ','.join(cmvn_mean),
        "cmvn_inv_stddev": ','.join(cmvn_inv_stddev)
    }

    # Only the int8 model receives metadata; the float model is left as is.
    add_meta_data(onnx_encoder_int8_file, encoder_meta_data)

    return n_layer_cross_k, n_layer_cross_v, cross_attn_mask
223
-
224
-
225
def export_decoder(fireredasr_model, args,
                   n_layer_cross_k,
                   n_layer_cross_v,
                   cross_attn_mask):
    """Export the decoder to ONNX in three variants plus an int8 model.

    Steps:
      1. Tile the encoder outputs ``beam_size`` (3) times and build dummy
         tokens, zero-filled self-attention caches, an offset and a
         self-attention mask.
      2. Export ``model_wrapper.TextDecoderTensorCache`` to
         ``<args.decoder>/decoder.onnx`` with dynamic axes, run it with
         onnxruntime and print any mismatch against the PyTorch outputs.
      3. Dynamically quantize ``MatMul`` weights to int8 into
         ``<args.decoder_int8>/decoder_int8.onnx``.
      4. Export ``model_wrapper.TextDecoderTensorCacheV2`` twice, with
         ``loop=False`` (``decoder_main.onnx``) and ``loop=True``
         (``decoder_loop.onnx``), and save the positional-encoding table
         to ``<args.decoder>/pe.npy``.

    Args:
        fireredasr_model: Model providing ``.decoder``.
        args: Parsed CLI arguments (``decoder``, ``decoder_int8``).
        n_layer_cross_k: Cross-attention keys returned by ``export_encoder``.
        n_layer_cross_v: Cross-attention values returned by ``export_encoder``.
        cross_attn_mask: Additive cross-attention mask from ``export_encoder``.
    """
    beam_size = 3
    # Self-attention cache geometry used for the dummy inputs.
    cache_len = 448
    cache_dim = 1280

    decoder = model_wrapper.TextDecoderTensorCache(
        fireredasr_model.decoder)
    decoder.eval()

    num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
    encoder_out_length = cross_attn_mask.size(-1)

    # preparing for batch beam search
    cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
        1, beam_size, 1, 1).view(beam_size * batch_size, -1, encoder_out_length)
    n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
        1, 1, beam_size, 1, 1
    ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
    n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
        1, 1, beam_size, 1, 1
    ).view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
    tokens = torch.ones(beam_size * batch_size, 1).fill_(decoder.decoder.sos_id).long()

    cache_shape = (len(decoder.blocks), batch_size * beam_size, cache_len, cache_dim)
    n_layer_self_k_cache = torch.zeros(cache_shape)
    n_layer_self_v_cache = torch.zeros(cache_shape)
    offset = torch.zeros(1, dtype=torch.int64)
    self_attn_mask = torch.empty(batch_size * beam_size,
                                 tokens.shape[-1], tokens.shape[-1]
                                 ).fill_(-np.inf).triu_(1)
    self_attn_mask = self_attn_mask[:, -1:, :]

    # Reference run. TextDecoderTensorCache writes into the cache tensors it
    # is given, so the caches exported below already contain this step.
    logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
        self_attn_mask,
        cross_attn_mask
    )

    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(args.decoder, exist_ok=True)
    onnx_decoder_file = os.path.join(args.decoder, "decoder.onnx")

    output_names = ["logits",
                    "out_n_layer_self_k_cache",
                    "out_n_layer_self_v_cache"]

    with torch.no_grad():
        torch.onnx.export(
            decoder,
            (tokens,
             n_layer_self_k_cache,
             n_layer_self_v_cache,
             n_layer_cross_k,
             n_layer_cross_v,
             offset,
             self_attn_mask,
             cross_attn_mask),
            onnx_decoder_file,
            export_params=True,
            opset_version=13,
            verbose=False,
            input_names=["tokens",
                         "in_n_layer_self_k_cache",
                         "in_n_layer_self_v_cache",
                         "n_layer_cross_k",
                         "n_layer_cross_v",
                         "offset",
                         "self_attn_mask",
                         "cross_attn_mask"],
            output_names=output_names,
            dynamic_axes={
                "tokens": {0: "n_audio", 1: "n_tokens"},
                "in_n_layer_self_k_cache": {1: "n_audio"},
                "in_n_layer_self_v_cache": {1: "n_audio"},
                "n_layer_cross_k": {1: "n_audio", 2: "T"},
                "n_layer_cross_v": {1: "n_audio", 2: "T"},
                "self_attn_mask": {0: "n_audio", 2: "T"},
                "cross_attn_mask": {0: "n_audio", 2: "T"},
            },
            external_data=True
        )

    onnx.checker.check_model(onnx_decoder_file)
    ort_session = onnxruntime.InferenceSession(onnx_decoder_file)

    feeds = [to_numpy(t) for t in (tokens,
                                   n_layer_self_k_cache,
                                   n_layer_self_v_cache,
                                   n_layer_cross_k,
                                   n_layer_cross_v,
                                   offset,
                                   self_attn_mask,
                                   cross_attn_mask)]
    session_inputs = ort_session.get_inputs()
    ort_inputs = {session_inputs[i].name: feed for i, feed in enumerate(feeds)}
    ort_outputs = ort_session.run(None, ort_inputs)

    # Report, but do not fail on, numerical differences.
    references = (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
    for index, reference in enumerate(references):
        try:
            np.testing.assert_allclose(
                to_numpy(reference), ort_outputs[index], rtol=1e-03, atol=1e-05)
        except AssertionError as e:
            print(e)

    print("export onnx decoder done.")

    os.makedirs(args.decoder_int8, exist_ok=True)
    onnx_decoder_int8_file = os.path.join(args.decoder_int8, "decoder_int8.onnx")
    quantize_dynamic(
        model_input=onnx_decoder_file,
        model_output=onnx_decoder_int8_file,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    # "decoder main" (loop=False) followed by "decoder loop" (loop=True).
    # Both are exported without dynamic_axes, i.e. with the dummy shapes.
    v2_variants = (
        ("decoder_main", False),
        ("decoder_loop", True),
    )
    for variant, loop in v2_variants:
        decoder = model_wrapper.TextDecoderTensorCacheV2(
            fireredasr_model.decoder, loop=loop)
        decoder.eval()

        pe = decoder.decoder.positional_encoding.pe[0]
        if loop:
            # Persist the table that pe[0] below is taken from.
            pe_file = os.path.join(args.decoder, "pe.npy")
            np.save(pe_file, pe.numpy())

        onnx_decoder_file = os.path.join(args.decoder, f"{variant}.onnx")

        with torch.no_grad():
            torch.onnx.export(
                decoder,
                (tokens,
                 n_layer_self_k_cache,
                 n_layer_self_v_cache,
                 n_layer_cross_k,
                 n_layer_cross_v,
                 pe[0],
                 self_attn_mask,
                 cross_attn_mask),
                onnx_decoder_file,
                export_params=True,
                opset_version=13,
                verbose=False,
                input_names=["tokens",
                             "in_n_layer_self_k_cache",
                             "in_n_layer_self_v_cache",
                             "n_layer_cross_k",
                             "n_layer_cross_v",
                             "pe",
                             "self_attn_mask",
                             "cross_attn_mask"],
                output_names=output_names,
                external_data=True
            )
        print(f"Export {variant} to {onnx_decoder_file}")
474
-
475
-
476
def parse_args():
    """Parse the command-line arguments of the export script.

    All options are required strings: ``--model``, ``--encoder``,
    ``--decoder``, ``--encoder_int8``, ``--decoder_int8`` and ``--cmvn``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="export FireRedASR-AED torch model to onnx")

    # (flag, help text); every option is a required string.
    options = (
        ("--model", "Path to FireRedASR-AED torch model"),
        ("--encoder", "Dir to the exported onnx encoder"),
        ("--decoder", "Dir to the exported onnx decoder"),
        ("--encoder_int8", "Dir to the exported onnx encoder after int8 quantization"),
        # Help text previously said "encoder" here (copy/paste slip).
        ("--decoder_int8", "Dir to the exported onnx decoder after int8 quantization"),
        ("--cmvn", "cmvn.ark file"),
    )
    for flag, help_text in options:
        parser.add_argument(
            flag,
            type=str,
            required=True,
            help=help_text
        )
    return parser.parse_args()
515
-
516
-
517
def main():
    """Export the encoder first, then the decoder, of the loaded model.

    The three values returned by ``export_encoder`` are passed straight on
    to ``export_decoder``.
    """
    cli_args = parse_args()
    fireredasr_model, model_args = load_model(cli_args.model)
    encoder_outputs = export_encoder(fireredasr_model, cli_args, model_args)
    n_layer_cross_k, n_layer_cross_v, cross_attn_mask = encoder_outputs
    export_decoder(
        fireredasr_model,
        cli_args,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
    )


# Run the export only when executed as a script.
if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_decoder.py DELETED
@@ -1,640 +0,0 @@
1
- from fireredasr.data.asr_feat import ASRFeatExtractor
2
- from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
-
4
- import onnxruntime as ort
5
- # import axengine as axe
6
- import torch
7
- import torch.nn.functional as F
8
- import numpy as np
9
- from torch import Tensor
10
- from typing import Tuple, List, Dict
11
- import argparse
12
- import os
13
- import time
14
- import logging
15
-
16
- logger = logging.getLogger()
17
- logger.setLevel(logging.INFO)
18
- logger_stream_hander = logging.StreamHandler()
19
- logger_stream_hander.setLevel("INFO")
20
- logger.addHandler(logger_stream_hander)
21
-
22
-
23
- INF = 1e10
24
-
25
-
26
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array.

    NumPy arrays are returned unchanged (same object). Anything else is
    treated as a torch tensor: it is detached first when it requires grad,
    then moved to CPU and converted.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
33
-
34
-
35
def set_finished_beam_score_to_zero(scores, is_finished):
    """Replace the candidate scores of finished rows with ``[0, -INF, ...]``.

    ``scores`` is 2-D (rows x candidates). Rows where ``is_finished`` is set
    keep only their first candidate, at score 0, with every other candidate
    pushed to ``-INF``. Unfinished rows are returned unchanged.
    """
    num_rows, num_candidates = scores.size()
    finished = is_finished.float()
    frozen_row = torch.tensor([0.0] + [-INF] * (num_candidates - 1)).float()
    frozen = frozen_row.view(1, num_candidates).repeat(num_rows, 1)
    return scores * (1 - finished) + frozen * finished
41
-
42
-
43
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Overwrite the token ids of finished rows with ``eos_id``.

    ``is_finished`` is cast to integers and used as a 0/1 blend weight, so
    unfinished rows keep their original ids.
    """
    finished = is_finished.long()
    unfinished = 1 - finished
    return ys * unfinished + eos_id * finished
46
-
47
-
48
class FireRedASROnnxModel:
    """Decoder-side test harness for the exported FireRedASR-AED ONNX models.

    ``__init__`` loads ``decoder_main.onnx``, ``decoder_loop.onnx`` and
    ``pe.npy`` from the directory that contains ``decoder_path``. No encoder
    session is created: ``transcribe`` reads pre-computed encoder outputs
    from the ``encoder_output/<wav basename>/`` directory instead.
    """

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        providers=("CPUExecutionProvider",)
    ):
        """Create the feature extractor, tokenizer and decoder sessions.

        Args:
            encoder_path: Accepted for interface compatibility; not used here
                because no encoder session is loaded.
            decoder_path: Path whose directory holds ``decoder_main.onnx``,
                ``decoder_loop.onnx`` and ``pe.npy``.
            cmvn_file: Passed to ``ASRFeatExtractor``.
            dict_file: Passed to ``ChineseCharEnglishSpmTokenizer``.
            spm_model_path: Passed to ``ChineseCharEnglishSpmTokenizer``.
            providers: Execution providers handed to onnxruntime.
        """
        # The default is an immutable tuple (no list shared between calls);
        # convert it back to the list that was passed on originally.
        if isinstance(providers, tuple):
            providers = list(providers)

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        # NOTE: maximum decoding length, set following Whisper.
        # The FireRedASR-AED model supports audio of at most 60 s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        # NOTE(review): presumably these must match the exported decoder - confirm.
        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        # Neither session is created here; run_encoder / decode_one_token
        # only work after these attributes have been set.
        self.encoder = None
        self.decoder = None

        self.init_decoder_main(decoder_path, providers)
        self.init_decoder_loop(decoder_path, providers)
        self.pe = self.init_pe(decoder_path)

    def _load_session(self, model_path, providers, label):
        """Create an onnxruntime session for ``model_path`` and log the load time."""
        start_time = time.time()
        session = ort.InferenceSession(
            model_path,
            sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load {label} cost {end_time - start_time} seconds")
        return session

    @staticmethod
    def _run_positional(session, arrays):
        """Feed ``arrays`` to the session inputs in declaration order.

        Returns the three session outputs (logits, self K cache, self V cache).
        """
        feed = {
            session.get_inputs()[position].name: array
            for position, array in enumerate(arrays)
        }
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = session.run(None, feed)
        return logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache

    def init_decoder(self, decoder_path, providers=None):
        """Load a single-graph decoder from ``decoder_path`` into ``self.decoder``."""
        self.decoder = self._load_session(decoder_path, providers, "decoder")

    def init_decoder_main(self, decoder_path, providers=None):
        """Load ``decoder_main.onnx`` from the directory of ``decoder_path``."""
        model_path = os.path.join(os.path.dirname(decoder_path), "decoder_main.onnx")
        self.decoder_main = self._load_session(model_path, providers, "decoder_main")

        input_names = [i.name for i in self.decoder_main.get_inputs()]
        print(f"decoder_main.input_names: {input_names}")

    def init_decoder_loop(self, decoder_path, providers=None):
        """Load ``decoder_loop.onnx`` from the directory of ``decoder_path``."""
        model_path = os.path.join(os.path.dirname(decoder_path), "decoder_loop.onnx")
        self.decoder_loop = self._load_session(model_path, providers, "decoder_loop")

        input_names = [i.name for i in self.decoder_loop.get_inputs()]
        print(f"decoder_loop.input_names: {input_names}")

    def init_pe(self, decoder_path):
        """Load ``pe.npy`` from the directory of ``decoder_path``."""
        return np.load(os.path.join(os.path.dirname(decoder_path), "pe.npy"))

    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run ``self.encoder`` and return (cross K, cross V, cross-attention mask).

        ``self.encoder`` is ``None`` after ``__init__``, so a session must be
        assigned before this method can be used.
        """
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
            None,
            {
                "encoder_input": input,
                "encoder_input_lengths": input_length.astype(np.int32)
            }
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )

    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one step of the single-graph decoder (``self.decoder``).

        Prints the shape of every input, then feeds all eight arguments to
        the session inputs in order. ``self.decoder`` is only set by
        ``init_decoder``, which ``__init__`` does not call.
        """
        print("decode:")
        print(f"tokens.shape: {tokens.shape}")
        print(f"n_layer_self_k_cache.shape: {n_layer_self_k_cache.shape}")
        print(f"n_layer_self_v_cache.shape: {n_layer_self_v_cache.shape}")
        print(f"n_layer_cross_k_cache.shape: {n_layer_cross_k_cache.shape}")
        print(f"n_layer_cross_v_cache.shape: {n_layer_cross_v_cache.shape}")
        print(f"offset.shape: {offset.shape}")
        print(f"self_attn_mask.shape: {self_attn_mask.shape}")
        print(f"cross_attn_mask.shape: {cross_attn_mask.shape}")

        return self._run_positional(
            self.decoder,
            (
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                offset,
                self_attn_mask,
                cross_attn_mask,
            ),
        )

    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the first decoding step through ``decoder_main``.

        The two self-attention cache arguments are accepted so the signature
        matches ``decode_loop_one_token`` but are not fed: the session only
        receives its first six inputs (tokens, cross K, cross V, pe,
        self-attention mask, cross-attention mask).
        """
        return self._run_positional(
            self.decoder_main,
            (
                tokens,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            ),
        )

    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run a subsequent decoding step through ``decoder_loop``.

        All eight arguments are fed to the session inputs in order.
        """
        return self._run_positional(
            self.decoder_loop,
            (
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            ),
        )

    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest
    ):
        """Beam-search decode from pre-computed encoder outputs.

        Args:
            n_layer_cross_k: NumPy array unpacked as
                (num_layer, batch_size, Ti, encoder_out_dim).
            n_layer_cross_v: NumPy array viewed with the same shape.
            cross_attn_mask: NumPy array whose last axis is the encoder
                output length.
            beam_size: Number of beams kept per utterance.
            nbest: Number of hypotheses returned per utterance.

        Returns:
            One list per utterance, each holding ``nbest`` dicts with
            ``"token_ids"`` (the leading SOS is sliced off and the slice ends
            at the count of non-EOS tokens) and ``"score"``.
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]
        num_beams = beam_size * batch_size

        # Tile every encoder-side tensor so each beam gets its own copy.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(num_beams, -1, encoder_out_length)

        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, num_beams, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, num_beams, Ti, encoder_out_dim)

        # The encoder-side inputs never change during decoding: convert once.
        n_layer_cross_k = to_numpy(n_layer_cross_k)
        n_layer_cross_v = to_numpy(n_layer_cross_v)
        cross_attn_mask = to_numpy(cross_attn_mask)

        prediction_tokens = torch.ones(num_beams, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        # Only the first beam of each utterance starts alive (score 0);
        # the others start at -INF so the first step does not pick duplicates.
        scores = torch.tensor([0.0] + [-INF] * (beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(num_beams, 1)
        is_finished = torch.zeros_like(scores)

        # The same all-zero mask is passed at every step.
        self_attn_mask = np.zeros((num_beams, 1, 1), dtype=np.float32)

        for step in range(self.decode_max_len):
            # Step 0 uses decoder_main; later steps use decoder_loop.
            if step == 0:
                decode_step = self.decode_main_one_token
            else:
                decode_step = self.decode_loop_one_token
            logits, n_layer_self_k_cache, n_layer_self_v_cache = decode_step(
                to_numpy(tokens),
                to_numpy(n_layer_self_k_cache),
                to_numpy(n_layer_self_v_cache),
                n_layer_cross_k,
                n_layer_cross_v,
                self.pe[offset],
                self_attn_mask,
                cross_attn_mask
            )

            offset += 1
            logits = torch.from_numpy(logits)

            logits = logits.squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            # Keep the best beam_size of the beam_size**2 expansions per utterance.
            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map each surviving expansion back to the row (beam) it came
            # from; .long() below truncates the quotient to an integer row.
            topB_row_number_in_each_B_rows_of_ys = torch.div(
                topB_score_ids, beam_size).view(num_beams)
            stride = beam_size * torch.arange(batch_size).view(
                batch_size, 1).repeat(1, beam_size).view(num_beams)
            topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()

            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(num_beams, 1)

            tokens = t_ys

            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            # Reorder the self-attention caches to follow the surviving beams.
            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)

            for layer in range(n_layer_self_k_cache.shape[0]):
                n_layer_self_k_cache[layer] = n_layer_self_k_cache[layer][topB_row_number_in_ys]

            for layer in range(n_layer_self_v_cache.shape[0]):
                n_layer_self_v_cache[layer] = n_layer_self_v_cache[layer][topB_row_number_in_ys]

            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == num_beams:
                break

        scores = scores.view(batch_size, beam_size)
        # Valid length = number of non-EOS tokens (this count includes SOS).
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        nbest_prediction_tokens = prediction_tokens.view(num_beams, -1)[index.view(-1)]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            num_beams)[index.view(-1)].view(batch_size, -1)

        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for utt in range(batch_size):
            utt_hyps: List[Dict[str, torch.Tensor]] = []
            for rank, score in enumerate(nbest_scores[utt]):
                valid_len = nbest_prediction_valid_token_lengths[utt, rank]
                utt_hyps.append(
                    {
                        "token_ids": nbest_prediction_tokens[utt, rank, 1:valid_len],
                        "score": score
                    }
                )
            nbest_hyps.append(utt_hyps)

        return nbest_hyps

    def get_initialized_self_cache(self,
                                   batch_size,
                                   beam_size
                                   ) -> Tuple[Tensor, Tensor]:
        """Return zero-filled self-attention K and V caches.

        Each cache has shape (num_decoder_blocks, batch_size * beam_size,
        decode_max_len, decoder_hidden_dim).
        """
        cache_shape = (
            self.num_decoder_blocks,
            batch_size * beam_size,
            self.decode_max_len,
            self.decoder_hidden_dim,
        )
        n_layer_self_k_cache = torch.zeros(*cache_shape)
        n_layer_self_v_cache = torch.zeros(*cache_shape)
        return n_layer_self_k_cache, n_layer_self_v_cache

    def calc_feat_len(self, audio_dur):
        """Return the frame count for ``audio_dur`` seconds of 16 kHz audio.

        Uses a 25 ms frame length and a 10 ms frame shift.
        """
        import math
        sample_rate = 16000
        frame_length = 25 * sample_rate / 1000
        frame_shift = 10 * sample_rate / 1000
        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
        return length

    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> Tuple[List[Dict], List, float]:
        """Decode the utterances in ``batch_wav_path`` from cached encoder outputs.

        The encoder outputs are loaded from
        ``encoder_output/<basename of the first wav>/``, so in practice only
        single-file batches are supported.

        Returns:
            ``(results, wav_durations, transcribe_durations)``. ``results``
            holds one ``{"wav", "text", "score"}`` dict per input,
            ``wav_durations`` is whatever the feature extractor reports
            (presumably one duration per wav - confirm), and
            ``transcribe_durations`` is the decoder wall time in seconds.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        # Pad/truncate features to the frame count of 10 s of audio.
        # NOTE(review): feats is not consumed below (the encoder outputs come
        # from disk); kept so the script's behaviour is unchanged.
        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            feats = np.concatenate([feats, np.zeros((1, maxlen - feats.shape[1], 80), dtype=np.float32)], axis=1)
        feats = feats[:, :maxlen, :]

        encoder_data_path = os.path.join("encoder_output", os.path.basename(batch_wav_path[0]))

        n_layer_cross_k = np.load(os.path.join(encoder_data_path, "n_layer_cross_k.npy"))
        n_layer_cross_v = np.load(os.path.join(encoder_data_path, "n_layer_cross_v.npy"))
        cross_attn_mask = np.load(os.path.join(encoder_data_path, "cross_attn_mask.npy"))

        start_time = time.time()

        nbest_hyps = self.run_decoder(n_layer_cross_k,
                                      n_layer_cross_v,
                                      cross_attn_mask,
                                      beam_size,
                                      nbest
                                      )
        transcribe_durations = time.time() - start_time

        results: List[Dict] = []
        for wav, hyp in zip(batch_wav_path, nbest_hyps):
            best_hyp = hyp[0]
            hyp_ids = [int(token_id) for token_id in best_hyp["token_ids"].cpu()]
            score = best_hyp["score"].item()
            text = self.tokenizer.detokenize(hyp_ids)
            results.append(
                {
                    "wav": wav,
                    "text": text,
                    "score": score
                }
            )

        return results, wav_durations, transcribe_durations
522
-
523
-
524
def parse_args():
    """Parse the command-line options of the decoder test script.

    Every option is optional and falls back to the default listed below.
    """
    # (flag, type, default, help), registered in this order.
    option_table = (
        ("--encoder", str, "axmodel/encoder.axmodel", "Path to onnx encoder"),
        ("--decoder", str, "onnx_decoder/decoder_main.onnx", "Path to onnx decoder"),
        ("--cmvn", str, "axmodel/cmvn.ark", "Path to cmvn"),
        ("--dict", str, "axmodel/dict.txt", "Path to dict"),
        ("--spm_model", str, "axmodel/train_bpe1000.model", "Path to spm model"),
        ("--wavlist", str, "wavlist.txt", "File to wav path list"),
        ("--hypo", str, "hypo_encoder.txt", "File of hypos"),
        ("--beam_size", int, 3, ""),
        ("--nbest", int, 1, ""),
    )

    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    for flag, value_type, fallback, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=fallback, help=help_text)

    return parser.parse_args()
582
-
583
-
584
def parse_wavlist(wavlist: str):
    """Read wav paths, one per line, from the text file ``wavlist``.

    Lines are stripped of surrounding whitespace. Blank lines are skipped
    silently (they used to be reported as a missing file with an empty
    name). A path that does not exist is reported on stdout and skipped.

    Returns:
        List of the existing paths, in file order.
    """
    wavpaths = []
    with open(wavlist) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank or whitespace-only line: nothing to check.
                continue
            if not os.path.exists(line):
                print(f"{line} doesn't exist.")
                continue
            wavpaths.append(line)

    return wavpaths
595
-
596
-
597
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def main():
    """Decode every wav in the list, log per-file and overall RTF, write hypotheses.

    One line ``"<text> (<wav>)"`` is written to ``args.hypo`` per result.
    The hypothesis file is held in a ``with`` block so it is closed even if
    decoding raises, and the real-time factors are reported as NaN instead
    of raising ``ZeroDivisionError`` when a duration (or the whole wav list)
    is empty.
    """
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model)

    total_wav_durations = 0
    total_transcribe_durations = 0
    with open(args.hypo, "wt") as wf:
        wavlist = parse_wavlist(args.wavlist)

        for wav in wavlist:
            # The model is driven one file at a time.
            batch_wav = [wav]
            results, wav_durations, transcribe_durations = onnx_model.transcribe(
                batch_wav, args.beam_size, args.nbest)

            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            rtf = _safe_ratio(transcribe_durations, wav_durations)
            logger.info(f"(Real time factor) RTF: {rtf}")
            for result in results:
                logger.info(f"wav: {result['wav']}")
                logger.info(f"text: {result['text']}")
                logger.info(f"score: {result['score']}")
                logger.info("")
                wf.write(f"{result['text']} ({result['wav']})\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    avg_ref = _safe_ratio(total_transcribe_durations, total_wav_durations)
    logger.info(f"AVG RTF: {avg_ref}")


# Run the test only when executed as a script.
if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_encoder.py DELETED
@@ -1,646 +0,0 @@
1
- from fireredasr.data.asr_feat import ASRFeatExtractor
2
- from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
-
4
- import onnxruntime as ort
5
- import axengine as axe
6
- import torch
7
- import torch.nn.functional as F
8
- import numpy as np
9
- from torch import Tensor
10
- from typing import Tuple, List, Dict
11
- import argparse
12
- import os
13
- import time
14
- import logging
15
-
16
# Configure the root logger: INFO threshold, with records emitted through a
# single StreamHandler that is itself capped at INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel("INFO")
logger.addHandler(logger_stream_hander)


# Large finite magnitude; the beam-search helpers below use -INF as the score
# of beams that must never be selected.
INF = 1e10
24
-
25
-
26
def to_numpy(tensor):
    """Return ``tensor`` as a numpy array.

    A ``np.ndarray`` is handed back unchanged (same object). Anything else is
    treated as a torch tensor: it is detached first when it tracks gradients,
    moved to the CPU and converted with ``.numpy()``.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
33
-
34
-
35
def set_finished_beam_score_to_zero(scores, is_finished):
    """Replace the candidate scores of finished beams with ``[0, -INF, ...]``.

    Args:
        scores: 2-D tensor of per-beam candidate scores; its two dimensions
            are read as ``(NB, B)``.
        is_finished: 0/1 (or bool) tensor that broadcasts against ``scores``;
            a non-zero entry marks the row as finished.

    Returns:
        A tensor shaped like ``scores``. Unfinished rows are unchanged;
        finished rows become ``[0.0, -INF, ..., -INF]`` so that only their
        first candidate can survive the next top-k and their accumulated
        score stops changing.
    """
    NB, B = scores.size()
    finished = is_finished.float()
    # Build the mask on the same device as ``scores`` so the helper also
    # works when decoding does not run on the default (CPU) device.
    mask_score = torch.tensor(
        [0.0] + [-INF] * (B - 1), device=scores.device
    ).float()
    # ``expand`` gives the (NB, B) broadcast view without copying the row.
    mask_score = mask_score.view(1, B).expand(NB, B)
    return scores * (1 - finished) + mask_score * finished
41
-
42
-
43
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Overwrite the token ids of finished beams with ``eos_id``.

    ``is_finished`` is cast to an integer 0/1 tensor and used as an arithmetic
    selector: entries where it is 0 keep their value from ``ys``, entries
    where it is 1 become ``eos_id``.
    """
    done = is_finished.long()
    keep = 1 - done
    return ys * keep + eos_id * done
46
-
47
-
48
class FireRedASROnnxModel:
    """Test harness around an exported FireRedASR-AED encoder/decoder pair.

    The encoder is loaded through ``axe.InferenceSession``; the decoder
    variants are loaded through ``ort.InferenceSession`` on the CPU provider.
    The constructor only loads the encoder and the positional-encoding table,
    so ``transcribe`` in this script just runs the encoder and dumps its
    outputs to disk. The decoder sessions must be loaded explicitly with the
    ``init_decoder*`` methods before any ``decode_*`` / ``run_decoder`` call.
    """

    # Immutable default so the provider list can never be shared and mutated
    # across instances (the original used a mutable list default).
    _DEFAULT_PROVIDERS = ('AXCLRTExecutionProvider', 'AxEngineExecutionProvider')

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        providers=_DEFAULT_PROVIDERS
    ):
        """Load the feature extractor, tokenizer, encoder session and PE table.

        Args:
            encoder_path: Path handed to ``axe.InferenceSession``.
            decoder_path: Decoder model path; only forwarded to ``init_pe``
                here because the decoder sessions are not loaded by default.
            cmvn_file: Argument for ``ASRFeatExtractor``.
            dict_file: First argument for ``ChineseCharEnglishSpmTokenizer``.
            spm_model_path: Second argument for the tokenizer.
            providers: Execution providers for the encoder session.
        """
        providers = list(providers) if providers is not None else None

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        # NOTE: the maximum decoding length is set with reference to whisper.
        # The FireRedASR-AED model supports audio of at most 60s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        self.encoder = None
        self.decoder = None
        self.decoder_main = None
        self.decoder_loop = None

        self.init_encoder(encoder_path, providers)
        # The decoder sessions are deliberately not loaded here; call
        # init_decoder / init_decoder_main / init_decoder_loop when needed.
        self.pe = self.init_pe(decoder_path)

    def init_encoder(self, encoder_path, providers=None):
        """Create the encoder session and log how long loading took.

        ``self.session_opts`` is not passed to this session.
        """
        start_time = time.time()
        self.encoder = axe.InferenceSession(
            encoder_path,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load encoder cost {end_time - start_time} seconds")

    def _load_cpu_session(self, model_path, label):
        """Load ``model_path`` with ONNX Runtime on the CPU provider and log the load time."""
        start_time = time.time()
        session = ort.InferenceSession(
            model_path,
            sess_options=self.session_opts,
            providers=['CPUExecutionProvider']
        )
        end_time = time.time()
        logger.info(f"load {label} cost {end_time - start_time} seconds")
        return session

    def init_decoder(self, decoder_path, providers=None):
        """Load the single-graph decoder into ``self.decoder``.

        ``providers`` is accepted for signature symmetry but ignored: the
        decoder always runs on ``CPUExecutionProvider``.
        """
        self.decoder = self._load_cpu_session(decoder_path, "decoder")

    def init_decoder_main(self, decoder_path, providers=None):
        """Load ``decoder_main.onnx`` from the directory of ``decoder_path``.

        ``providers`` is ignored (CPU only). The session's input names are
        printed for inspection.
        """
        model_path = os.path.join(os.path.dirname(decoder_path), "decoder_main.onnx")
        self.decoder_main = self._load_cpu_session(model_path, "decoder_main")

        input_names = [i.name for i in self.decoder_main.get_inputs()]
        print(f"decoder_main.input_names: {input_names}")

    def init_decoder_loop(self, decoder_path, providers=None):
        """Load ``decoder_loop.onnx`` from the directory of ``decoder_path``.

        ``providers`` is ignored (CPU only). The session's input names are
        printed for inspection.
        """
        model_path = os.path.join(os.path.dirname(decoder_path), "decoder_loop.onnx")
        self.decoder_loop = self._load_cpu_session(model_path, "decoder_loop")

        input_names = [i.name for i in self.decoder_loop.get_inputs()]
        print(f"decoder_loop.input_names: {input_names}")

    def init_pe(self, decoder_path):
        """Load the positional-encoding table from ``axmodel/pe.npy``.

        NOTE(review): ``decoder_path`` is ignored and the location is
        hard-coded relative to the working directory; confirm this is intended.
        """
        pe_path = os.path.join("axmodel", "pe.npy")

        return np.load(pe_path)

    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the encoder session.

        Args:
            input: Fed as ``encoder_input``.
            input_length: Fed as ``encoder_input_lengths`` after a cast to int32.

        Returns:
            ``(n_layer_cross_k, n_layer_cross_v, cross_attn_mask)``, the three
            session outputs in order.
        """
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run(
            None,
            {
                "encoder_input": input,
                "encoder_input_lengths": input_length.astype(np.int32)
            }
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )

    @staticmethod
    def _run_session(session, label, arrays):
        """Feed ``arrays`` to ``session`` positionally and return its three outputs.

        The i-th array is bound to the name of the session's i-th input.

        Raises:
            RuntimeError: if the session has not been loaded.
            ValueError: if the session declares fewer inputs than ``arrays``.
        """
        if session is None:
            raise RuntimeError(
                f"{label} session is not loaded; call the matching init_* method first"
            )
        input_names = [node.name for node in session.get_inputs()]
        if len(input_names) < len(arrays):
            raise ValueError(
                f"{label} declares {len(input_names)} inputs but {len(arrays)} were supplied"
            )
        logits, out_self_k_cache, out_self_v_cache = session.run(
            None, dict(zip(input_names, arrays))
        )
        return logits, out_self_k_cache, out_self_v_cache

    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one decoding step on ``self.decoder``.

        The eight arguments are bound, in order, to the session's inputs 0-7.

        Returns:
            ``(logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)``.

        Raises:
            RuntimeError: if ``init_decoder`` has not been called.
        """
        return self._run_session(
            self.decoder,
            "decoder",
            [
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                offset,
                self_attn_mask,
                cross_attn_mask,
            ]
        )

    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one decoding step on ``self.decoder_main``.

        Only six inputs are fed, bound in order to session inputs 0-5:
        tokens, cross K, cross V, ``pe``, self-attention mask and
        cross-attention mask. The two self-cache arguments are accepted for
        signature parity with the other decode methods but are not fed.

        Raises:
            RuntimeError: if ``init_decoder_main`` has not been called.
        """
        return self._run_session(
            getattr(self, "decoder_main", None),
            "decoder_main",
            [
                tokens,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            ]
        )

    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one decoding step on ``self.decoder_loop``.

        The eight arguments are bound, in order, to the session's inputs 0-7.

        Raises:
            RuntimeError: if ``init_decoder_loop`` has not been called.
        """
        return self._run_session(
            getattr(self, "decoder_loop", None),
            "decoder_loop",
            [
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            ]
        )

    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest
    ):
        """Beam-search decode with ``self.decoder`` and return n-best hypotheses.

        Args:
            n_layer_cross_k: numpy array unpacked as
                ``(num_layer, batch_size, Ti, encoder_out_dim)``.
            n_layer_cross_v: numpy array with the same layout.
            cross_attn_mask: numpy array whose last dimension is the encoder
                output length.
            beam_size: Number of beams kept per utterance.
            nbest: Number of hypotheses returned per utterance.

        Returns:
            One list per utterance, each holding ``nbest`` dicts with keys
            ``"token_ids"`` (tensor of ids with the leading SOS dropped,
            cut at the count of non-EOS tokens) and ``"score"``.
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]
        num_beams = beam_size * batch_size

        # Tile every encoder-side tensor so each beam has its own copy.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(num_beams, -1, encoder_out_length)

        n_layer_cross_k = torch.from_numpy(n_layer_cross_k)
        n_layer_cross_v = torch.from_numpy(n_layer_cross_v)
        n_layer_cross_k = n_layer_cross_k.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, num_beams, Ti, encoder_out_dim)
        n_layer_cross_v = n_layer_cross_v.unsqueeze(2).repeat(
            1, 1, beam_size, 1, 1
        ).view(num_layer, num_beams, Ti, encoder_out_dim)

        # These never change inside the loop: convert to numpy once.
        cross_k_np = to_numpy(n_layer_cross_k)
        cross_v_np = to_numpy(n_layer_cross_v)
        cross_mask_np = to_numpy(cross_attn_mask)

        prediction_tokens = torch.ones(num_beams, 1).fill_(self.sos_id).long()
        tokens = prediction_tokens
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        # Only the first beam of each utterance starts alive; the others get
        # -INF so the first top-k does not pick duplicates.
        scores = torch.tensor([0.0] + [-INF]*(beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(num_beams, 1)
        is_finished = torch.zeros_like(scores)

        # Base offsets of each utterance's beams in the flattened beam axis.
        stride = beam_size * torch.arange(batch_size).view(
            batch_size, 1).repeat(1, beam_size).view(num_beams)

        for _step in range(self.decode_max_len):
            # The last row of a causal (upper-triangular -inf) mask over the
            # current prefix is all zeros, so build that row directly instead
            # of materialising the full L x L mask on every step.
            self_attn_mask = np.zeros(
                (num_beams, 1, prediction_tokens.shape[-1]), dtype=np.float32
            )

            logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_one_token(
                to_numpy(tokens),
                to_numpy(n_layer_self_k_cache),
                to_numpy(n_layer_self_v_cache),
                cross_k_np,
                cross_v_np,
                to_numpy(offset),
                self_attn_mask,
                cross_mask_np
            )

            offset += 1
            logits = torch.from_numpy(logits).squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            scores = scores + t_topB_scores

            # Pick the best beam_size of the beam_size * beam_size candidates.
            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Map each surviving candidate back to the beam it extends.
            parent_beam = torch.div(
                topB_score_ids, beam_size, rounding_mode="floor"
            ).view(num_beams)
            topB_row_number_in_ys = parent_beam.long() + stride.long()

            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(num_beams, 1)

            tokens = t_ys
            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            # Reorder the self-attention caches of all layers at once so they
            # follow their surviving beams.
            n_layer_self_k_cache = torch.from_numpy(
                n_layer_self_k_cache)[:, topB_row_number_in_ys]
            n_layer_self_v_cache = torch.from_numpy(
                n_layer_self_v_cache)[:, topB_row_number_in_ys]

            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == num_beams:
                break

        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        flat_index = index.view(-1)
        nbest_prediction_tokens = prediction_tokens.view(num_beams, -1)[flat_index]
        nbest_prediction_tokens = nbest_prediction_tokens.view(batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            num_beams)[flat_index].view(batch_size, -1)

        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for utt in range(batch_size):
            utt_hyps: List[Dict[str, torch.Tensor]] = []
            for rank, score in enumerate(nbest_scores[utt]):
                valid_len = nbest_prediction_valid_token_lengths[utt, rank]
                utt_hyps.append({
                    # Position 0 holds SOS, so slice from 1.
                    "token_ids": nbest_prediction_tokens[utt, rank, 1:valid_len],
                    "score": score
                })
            nbest_hyps.append(utt_hyps)

        return nbest_hyps

    def get_initialized_self_cache(self,
                                   batch_size,
                                   beam_size
                                   ) -> Tuple[Tensor, Tensor]:
        """Return two independent zero tensors for the self-attention K/V caches.

        Each has shape ``(num_decoder_blocks, batch_size * beam_size,
        decode_max_len, decoder_hidden_dim)``.
        """
        cache_shape = (
            self.num_decoder_blocks,
            batch_size * beam_size,
            self.decode_max_len,
            self.decoder_hidden_dim,
        )
        n_layer_self_k_cache = torch.zeros(*cache_shape)
        n_layer_self_v_cache = torch.zeros(*cache_shape)
        return n_layer_self_k_cache, n_layer_self_v_cache

    def calc_feat_len(self, audio_dur):
        """Return the number of feature frames for ``audio_dur`` seconds of audio.

        Uses a 16 kHz sample rate, 25 ms frame length and 10 ms frame shift:
        ``floor((samples - frame_length) / frame_shift) + 1``.
        """
        import math
        sample_rate = 16000
        frame_length = 25 * sample_rate / 1000
        frame_shift = 10 * sample_rate / 1000
        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
        return length

    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> None:
        """Run the encoder on the given wavs and dump its outputs as ``.npy`` files.

        Features are zero-padded or truncated to the frame count of 10 s of
        audio, the encoder is run, and ``n_layer_cross_k``,
        ``n_layer_cross_v`` and ``cross_attn_mask`` are saved under
        ``encoder_output/<basename of the first wav>/``.

        Decoding is disabled in this script, so ``beam_size`` and ``nbest``
        are accepted but unused and nothing is returned.

        NOTE(review): ``lengths`` is passed to the encoder unchanged even when
        the features are truncated to ``maxlen``; confirm the encoder tolerates
        lengths larger than the input.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")

        feats = to_numpy(feats)
        lengths = to_numpy(lengths)

        maxlen = self.calc_feat_len(10)
        if feats.shape[1] < maxlen:
            # Pad with the batch's own batch size and feature dimension rather
            # than assuming a single utterance with 80-dim features.
            padding = np.zeros(
                (feats.shape[0], maxlen - feats.shape[1], feats.shape[2]),
                dtype=np.float32
            )
            feats = np.concatenate([feats, padding], axis=1)
        feats = feats[:, :maxlen, :]

        encoder_data_path = os.path.join("encoder_output", os.path.basename(batch_wav_path[0]))
        os.makedirs(encoder_data_path, exist_ok=True)

        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
            feats,
            lengths
        )

        encoder_outputs = {
            "n_layer_cross_k": n_layer_cross_k,
            "n_layer_cross_v": n_layer_cross_v,
            "cross_attn_mask": cross_attn_mask,
        }
        for name, npy in encoder_outputs.items():
            file_path = os.path.join(encoder_data_path, name + ".npy")
            np.save(file_path, npy)
528
-
529
-
530
def parse_args(argv=None):
    """Parse the command-line options of the encoder test script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to encoder axmodel"
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="onnx_decoder/decoder.onnx",
        help="Path to onnx decoder"
    )
    parser.add_argument(
        "--cmvn",
        type=str,
        default="axmodel/cmvn.ark",
        help="Path to cmvn"
    )
    parser.add_argument(
        "--dict",
        type=str,
        default="axmodel/dict.txt",
        help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model"
    )
    parser.add_argument(
        "--wavlist",
        type=str,
        default="wavlist.txt",
        help="File to wav path list"
    )
    parser.add_argument(
        "--hypo",
        type=str,
        default="hypo_axmodel.txt",
        help="File of hypos"
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=3,
        help="Beam size used for beam-search decoding"
    )
    parser.add_argument(
        "--nbest",
        type=int,
        default=1,
        help="Number of best hypotheses kept per utterance"
    )

    return parser.parse_args(argv)
588
-
589
-
590
def parse_wavlist(wavlist: str):
    """Read a wav list file and return the listed paths that exist on disk.

    Args:
        wavlist: Path to a text file with one wav path per line.

    Returns:
        The existing paths in file order, stripped of surrounding whitespace.
        Blank lines are skipped silently; a non-existent path is reported on
        stdout and skipped.
    """
    wavpaths = []
    with open(wavlist, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line is not a missing file; do not report it.
                continue
            if not os.path.exists(line):
                print(f"{line} doesn't exist.")
                continue
            wavpaths.append(line)

    return wavpaths
601
-
602
-
603
def main():
    """Run the encoder over every wav in the wav list.

    Builds ``FireRedASROnnxModel`` from the command-line arguments and calls
    ``transcribe`` once per existing wav path. Decoding and RTF reporting are
    disabled in this script, so the hypothesis file is created (or truncated)
    but nothing is written to it.
    """
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model)

    # The context manager closes the hypothesis file even if transcription
    # raises part-way through the list.
    with open(args.hypo, "wt"):
        wavlist = parse_wavlist(args.wavlist)

        for wav in wavlist:
            batch_wav = [wav]
            onnx_model.transcribe(batch_wav, args.beam_size, args.nbest)


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_onnx_model.py DELETED
@@ -1,684 +0,0 @@
1
- from fireredasr.data.asr_feat import ASRFeatExtractor
2
- from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
3
-
4
- import onnxruntime as ort
5
- import torch
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from torch import Tensor
9
- from typing import Tuple, List, Dict
10
- import argparse
11
- import os
12
- import time
13
- import logging
14
-
15
# Module-wide logging setup: a stream handler that lets INFO records through
# is attached to the root logger, whose own threshold is also set to INFO.
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel("INFO")

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logger_stream_hander)


# Large finite magnitude; its negation is used as the "impossible" score
# when beam-search scores are initialised and masked.
INF = 1e10
23
-
24
-
25
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array.

    A NumPy array is handed back unchanged. Any other input is treated as a
    torch tensor: it is detached first when it tracks gradients, then moved
    to the CPU and converted with ``.numpy()``.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
32
-
33
-
34
def set_finished_beam_score_to_zero(scores, is_finished):
    """Replace the candidate scores of finished rows with a fixed mask row.

    ``scores`` must be 2-D (rows x candidates). Wherever ``is_finished`` is
    non-zero the row becomes ``[0, -INF, -INF, ...]``, so only the first
    candidate of a finished row stays viable and it adds nothing to the
    running score. Unfinished rows are returned unchanged.

    NOTE(review): ``is_finished`` is presumably one flag per row, shaped
    (rows, 1) so that it broadcasts over the candidates -- confirm with the
    caller.
    """
    num_rows, num_candidates = scores.size()
    finished = is_finished.float()
    mask_row = torch.tensor([0.0] + [-INF] * (num_candidates - 1)).float()
    tiled_mask = mask_row.view(1, num_candidates).repeat(num_rows, 1)
    still_active = 1 - finished
    return scores * still_active + tiled_mask * finished
40
-
41
-
42
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Force the candidate tokens of finished rows to ``eos_id``.

    ``is_finished`` is cast to integers and used as a 0/1 selector: entries
    of ``ys`` where the flag is 0 are kept, entries where it is 1 are
    replaced by ``eos_id``.
    """
    finished = is_finished.long()
    still_active = 1 - finished
    return ys * still_active + eos_id * finished
45
-
46
-
47
class FireRedASROnnxModel:
    """FireRedASR-AED inference driven by ONNX Runtime sessions.

    The constructor loads an encoder session, three decoder sessions
    (``decoder.onnx`` plus ``decoder_main.onnx`` and ``decoder_loop.onnx``
    from the same directory) and a ``pe.npy`` table that is indexed by the
    decoding step. ``transcribe`` is the entry point: feature extraction,
    encoder, beam-search decoding and detokenisation.
    """

    def __init__(
        self,
        encoder_path: str,
        decoder_path: str,
        cmvn_file: str,
        dict_file: str,
        spm_model_path: str,
        providers=None
    ):
        """Load the tokenizer, the feature extractor and every ONNX session.

        Args:
            encoder_path: Path of the encoder ONNX model.
            decoder_path: Path of the decoder ONNX model; its directory must
                also contain ``decoder_main.onnx``, ``decoder_loop.onnx`` and
                ``pe.npy``.
            cmvn_file: CMVN file handed to ``ASRFeatExtractor``.
            dict_file: Dictionary file handed to the tokenizer.
            spm_model_path: SentencePiece model handed to the tokenizer.
            providers: ONNX Runtime execution providers. ``None`` selects
                ``["CPUExecutionProvider"]``.
        """
        # A fresh list per call instead of a shared mutable default argument.
        if providers is None:
            providers = ["CPUExecutionProvider"]

        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        # NOTE: the maximum decoding length follows the value used by Whisper.
        # FireRedASR-AED supports audio of at most 60 s.
        # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations
        self.decode_max_len = 448

        # Decoder geometry and special token ids used by the beam search.
        self.decoder_hidden_dim = 1280
        self.num_decoder_blocks = 16
        self.blank_id = 0
        self.sos_id = 3
        self.eos_id = 4
        self.pad_id = 2

        self.feature_extractor = ASRFeatExtractor(cmvn_file)
        self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path)
        self.encoder = None
        self.decoder = None

        self.init_encoder(encoder_path, providers)
        self.init_decoder(decoder_path, providers)
        self.init_decoder_main(decoder_path, providers)
        self.init_decoder_loop(decoder_path, providers)
        self.pe = self.init_pe(decoder_path)

    def _load_session(self, model_path, providers, label):
        """Create an inference session for ``model_path`` and log the load time."""
        start_time = time.time()
        session = ort.InferenceSession(
            model_path,
            sess_options=self.session_opts,
            providers=providers
        )
        end_time = time.time()
        logger.info(f"load {label} cost {end_time - start_time} seconds")
        return session

    @staticmethod
    def _sibling_path(decoder_path, filename):
        """Return the path of ``filename`` located next to ``decoder_path``."""
        return os.path.join(os.path.dirname(decoder_path), filename)

    @staticmethod
    def _run_session(session, values):
        """Run ``session`` feeding ``values`` to its inputs in declaration order.

        The i-th value is bound to the i-th declared input; supplying more
        values than the session has inputs raises ``IndexError``.
        """
        declared_inputs = session.get_inputs()
        feed = {
            declared_inputs[position].name: value
            for position, value in enumerate(values)
        }
        return session.run(None, feed)

    def init_encoder(self, encoder_path, providers=None):
        """Load the encoder session into ``self.encoder``."""
        self.encoder = self._load_session(encoder_path, providers, "encoder")

    def init_decoder(self, decoder_path, providers=None):
        """Load the decoder session into ``self.decoder``."""
        self.decoder = self._load_session(decoder_path, providers, "decoder")

    def init_decoder_main(self, decoder_path, providers=None):
        """Load ``decoder_main.onnx`` (next to ``decoder_path``) and print its inputs."""
        model_path = self._sibling_path(decoder_path, "decoder_main.onnx")
        self.decoder_main = self._load_session(model_path, providers, "decoder_main")

        input_names = [i.name for i in self.decoder_main.get_inputs()]
        print(f"decoder_main.input_names: {input_names}")

    def init_decoder_loop(self, decoder_path, providers=None):
        """Load ``decoder_loop.onnx`` (next to ``decoder_path``) and print its inputs."""
        model_path = self._sibling_path(decoder_path, "decoder_loop.onnx")
        self.decoder_loop = self._load_session(model_path, providers, "decoder_loop")

        input_names = [i.name for i in self.decoder_loop.get_inputs()]
        print(f"decoder_loop.input_names: {input_names}")

    def init_pe(self, decoder_path):
        """Load and return the ``pe.npy`` array stored next to ``decoder_path``."""
        return np.load(self._sibling_path(decoder_path, "pe.npy"))

    def run_encoder(self, input: np.ndarray,
                    input_length: np.ndarray
                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the encoder session.

        Returns:
            ``(n_layer_cross_k, n_layer_cross_v, cross_attn_mask)`` exactly as
            produced by the session.
        """
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self._run_session(
            self.encoder, (input, input_length)
        )
        return (
            n_layer_cross_k,
            n_layer_cross_v,
            cross_attn_mask
        )

    def decode_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        offset: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one step on ``self.decoder``, which takes an ``offset`` input.

        Returns:
            ``(logits, self_k_cache, self_v_cache)`` as produced by the session.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self._run_session(
            self.decoder,
            (
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                offset,
                self_attn_mask,
                cross_attn_mask,
            )
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def decode_main_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one step on ``self.decoder_main``.

        The self-attention caches are accepted so the signature matches the
        other ``decode_*`` methods, but they are not fed to the session: only
        six inputs are bound.

        Returns:
            ``(logits, self_k_cache, self_v_cache)`` as produced by the session.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self._run_session(
            self.decoder_main,
            (
                tokens,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            )
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def decode_loop_one_token(
        self,
        tokens: np.ndarray,
        n_layer_self_k_cache: np.ndarray,
        n_layer_self_v_cache: np.ndarray,
        n_layer_cross_k_cache: np.ndarray,
        n_layer_cross_v_cache: np.ndarray,
        pe: np.ndarray,
        self_attn_mask: np.ndarray,
        cross_attn_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one step on ``self.decoder_loop``, which takes a ``pe`` input.

        Returns:
            ``(logits, self_k_cache, self_v_cache)`` as produced by the session.
        """
        logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self._run_session(
            self.decoder_loop,
            (
                tokens,
                n_layer_self_k_cache,
                n_layer_self_v_cache,
                n_layer_cross_k_cache,
                n_layer_cross_v_cache,
                pe,
                self_attn_mask,
                cross_attn_mask,
            )
        )
        return (
            logits,
            out_n_layer_self_k_cache,
            out_n_layer_self_v_cache
        )

    def run_decoder(
        self,
        n_layer_cross_k,
        n_layer_cross_v,
        cross_attn_mask,
        beam_size,
        nbest,
        decoder_data_path
    ):
        """Beam-search decode the encoder outputs with ``self.decoder_loop``.

        Args:
            n_layer_cross_k: Encoder output, 4-D NumPy array unpacked as
                (layers, batch, Ti, dim).
            n_layer_cross_v: Encoder output with the same shape.
            cross_attn_mask: NumPy mask; its last axis is the encoder length.
            beam_size: Number of beams kept per utterance.
            nbest: Number of hypotheses returned per utterance.
            decoder_data_path: Not used by the decoding itself; kept for the
                calibration-dump tooling that passes it in.

        Returns:
            One list per utterance, each holding ``nbest`` dicts with keys
            ``"token_ids"`` (tokens after the start token, up to the number
            of non-EOS tokens) and ``"score"``.
        """
        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
        encoder_out_length = cross_attn_mask.shape[-1]
        num_beams = beam_size * batch_size

        # Repeat every encoder-side tensor once per beam.
        cross_attn_mask = torch.from_numpy(cross_attn_mask).to(torch.float32)
        cross_attn_mask = cross_attn_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1
        ).view(num_beams, -1, encoder_out_length)

        def tile_per_beam(cross_cache):
            return torch.from_numpy(cross_cache).unsqueeze(2).repeat(
                1, 1, beam_size, 1, 1
            ).view(num_layer, num_beams, Ti, encoder_out_dim)

        # These inputs never change between steps, so convert them once.
        cross_k_input = to_numpy(tile_per_beam(n_layer_cross_k))
        cross_v_input = to_numpy(tile_per_beam(n_layer_cross_v))
        cross_mask_input = to_numpy(cross_attn_mask)
        self_attn_mask = np.zeros((num_beams, 1, 1), dtype=np.float32)

        prediction_tokens = torch.full(
            (num_beams, 1), self.sos_id, dtype=torch.long)
        tokens = prediction_tokens
        # Kept as a one-element tensor: it is used directly to index self.pe.
        offset = torch.zeros(1, dtype=torch.int64)
        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
            batch_size, beam_size
        )

        # Only the first beam of each utterance starts with a usable score,
        # so the first step does not select the same token beam_size times.
        scores = torch.tensor([0.0] + [-INF] * (beam_size - 1)).float()
        scores = scores.repeat(batch_size).view(num_beams, 1)
        is_finished = torch.zeros_like(scores)

        # Beam offset of every row inside the flattened (batch * beam) axis.
        stride = beam_size * torch.arange(batch_size).view(
            batch_size, 1).repeat(1, beam_size).view(num_beams)

        for _ in range(self.decode_max_len):
            logits, n_layer_self_k_cache, n_layer_self_v_cache = self.decode_loop_one_token(
                to_numpy(tokens),
                to_numpy(n_layer_self_k_cache),
                to_numpy(n_layer_self_v_cache),
                cross_k_input,
                cross_v_input,
                self.pe[offset],
                self_attn_mask,
                cross_mask_input
            )
            offset += 1

            logits = torch.from_numpy(logits).squeeze(1)
            t_scores = F.log_softmax(logits, dim=-1)
            t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
            t_topB_scores = set_finished_beam_score_to_zero(t_topB_scores, is_finished)
            t_topB_ys = set_finished_beam_y_to_eos(t_topB_ys, is_finished, self.eos_id)

            # Extend every beam by its candidates and keep the best
            # beam_size continuations per utterance.
            scores = scores + t_topB_scores
            scores = scores.view(batch_size, beam_size * beam_size)
            scores, topB_score_ids = torch.topk(scores, k=beam_size, dim=1)
            scores = scores.view(-1, 1)

            # Source beam of each survivor: explicit floor division instead
            # of a float division that is truncated afterwards.
            source_beam = torch.div(
                topB_score_ids, beam_size, rounding_mode="floor"
            ).view(num_beams)
            topB_row_number_in_ys = source_beam.long() + stride.long()

            prediction_tokens = prediction_tokens[topB_row_number_in_ys]
            t_ys = torch.gather(
                t_topB_ys.view(batch_size, beam_size * beam_size),
                dim=1, index=topB_score_ids
            ).view(num_beams, 1)

            tokens = t_ys
            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)

            # Reorder the self-attention caches to follow the surviving
            # beams (axis 1 is the flattened batch * beam axis).
            n_layer_self_k_cache = torch.index_select(
                torch.from_numpy(n_layer_self_k_cache), 1, topB_row_number_in_ys)
            n_layer_self_v_cache = torch.index_select(
                torch.from_numpy(n_layer_self_v_cache), 1, topB_row_number_in_ys)

            is_finished = t_ys.eq(self.eos_id)
            if is_finished.sum().item() == num_beams:
                break

        scores = scores.view(batch_size, beam_size)
        prediction_valid_token_lengths = torch.sum(
            torch.ne(
                prediction_tokens.view(batch_size, beam_size, -1),
                self.eos_id),
            dim=-1
        ).int()

        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
        index = nbest_ids + beam_size * torch.arange(batch_size).view(batch_size, 1).long()
        flat_index = index.view(-1)
        nbest_prediction_tokens = prediction_tokens.view(num_beams, -1)[flat_index]
        nbest_prediction_tokens = nbest_prediction_tokens.view(
            batch_size, nbest_ids.size(1), -1)
        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
            num_beams)[flat_index].view(batch_size, -1)

        nbest_hyps: List[List[Dict[str, torch.Tensor]]] = []
        for utt in range(batch_size):
            utt_hyps: List[Dict[str, torch.Tensor]] = []
            for rank, score in enumerate(nbest_scores[utt]):
                valid_length = nbest_prediction_valid_token_lengths[utt, rank]
                utt_hyps.append(
                    {
                        # Position 0 is the start token, so it is skipped.
                        "token_ids": nbest_prediction_tokens[utt, rank, 1:valid_length],
                        "score": score
                    }
                )
            nbest_hyps.append(utt_hyps)

        return nbest_hyps

    def get_initialized_self_cache(self,
                                   batch_size,
                                   beam_size
                                   ) -> Tuple[Tensor, Tensor]:
        """Return two independent zero self-attention caches (keys, values).

        Each has shape ``(num_decoder_blocks, batch_size * beam_size,
        decode_max_len, decoder_hidden_dim)``.
        """
        cache_shape = (
            self.num_decoder_blocks,
            batch_size * beam_size,
            self.decode_max_len,
            self.decoder_hidden_dim,
        )
        n_layer_self_k_cache = torch.zeros(*cache_shape)
        n_layer_self_v_cache = torch.zeros(*cache_shape)
        return n_layer_self_k_cache, n_layer_self_v_cache

    def calc_feat_len(self, audio_dur):
        """Return the frame count for ``audio_dur`` seconds of audio.

        Uses a 16 kHz sample rate, a 25 ms frame length and a 10 ms shift.
        """
        import math
        sample_rate = 16000
        frame_length = 25 * sample_rate / 1000
        frame_shift = 10 * sample_rate / 1000
        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
        return length

    def transcribe(self,
                   batch_wav_path: List[str],
                   beam_size: int = 1,
                   nbest: int = 1
                   ) -> Tuple:
        """Transcribe the given wav files.

        The features are zero-padded or truncated along the frame axis to
        the frame count of 10 seconds of audio before the encoder runs.

        Returns:
            ``(results, wav_durations, transcribe_durations)``: ``results``
            has one ``{"wav", "text", "score"}`` dict per input (best
            hypothesis only), ``wav_durations`` is passed through from the
            feature extractor, and ``transcribe_durations`` is the wall-clock
            time in seconds spent in the encoder and decoder.
        """
        feats, lengths, wav_durations = self.feature_extractor(batch_wav_path)
        print(f"feats.shape: {feats.shape}")
        maxlen = self.calc_feat_len(10)

        feats = to_numpy(feats)
        lengths = to_numpy(lengths)

        missing_frames = maxlen - feats.shape[1]
        if missing_frames > 0:
            # Pad using the real batch size and feature dimension instead of
            # hard-coded values.
            padding = np.zeros(
                (feats.shape[0], missing_frames, feats.shape[2]),
                dtype=feats.dtype
            )
            feats = np.concatenate([feats, padding], axis=1)
        feats = feats[:, :maxlen, :]
        # NOTE(review): ``lengths`` is not adjusted when the features are
        # truncated -- confirm the encoder tolerates lengths above ``maxlen``.

        decoder_data_path = os.path.join(
            "calib_dataset", "decoder", os.path.basename(batch_wav_path[0]))

        start_time = time.time()
        n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
            feats,
            lengths
        )
        nbest_hyps = self.run_decoder(n_layer_cross_k,
                                      n_layer_cross_v,
                                      cross_attn_mask,
                                      beam_size,
                                      nbest,
                                      decoder_data_path)
        transcribe_durations = time.time() - start_time

        results: List[Dict] = []
        for wav, hyps in zip(batch_wav_path, nbest_hyps):
            best = hyps[0]
            hyp_ids = [int(token) for token in best["token_ids"].cpu()]
            results.append(
                {
                    "wav": wav,
                    "text": self.tokenizer.detokenize(hyp_ids),
                    "score": best["score"].item()
                }
            )

        return results, wav_durations, transcribe_durations
544
-
545
-
546
def parse_args():
    """Parse the command-line options of the ONNX test script.

    Returns:
        The ``argparse.Namespace`` with ``encoder``, ``decoder``, ``cmvn``,
        ``dict``, ``spm_model``, ``wavlist``, ``hypo``, ``beam_size``,
        ``nbest`` and ``provider``.
    """
    parser = argparse.ArgumentParser(description="FireRedASROnnxModel Test")

    # (flag, value type, default, help text), registered in this order.
    typed_options = (
        ("--encoder", str, "onnx_encoder/encoder.onnx", "Path to onnx encoder"),
        ("--decoder", str, "onnx_decoder/decoder.onnx", "Path to onnx decoder"),
        ("--cmvn", str, "axmodel/cmvn.ark", "Path to cmvn"),
        ("--dict", str, "axmodel/dict.txt", "Path to dict"),
        ("--spm_model", str, "axmodel/train_bpe1000.model", "Path to spm model"),
        ("--wavlist", str, "wavlist.txt", "File to wav path list"),
        ("--hypo", str, "hypo_onnx.txt", "File of hypos"),
        ("--beam_size", int, 3, ""),
        ("--nbest", int, 1, ""),
    )
    for flag, value_type, default_value, help_text in typed_options:
        parser.add_argument(
            flag,
            type=value_type,
            default=default_value,
            help=help_text
        )

    # The execution provider is restricted to a fixed set of choices.
    parser.add_argument(
        "--provider",
        default="CPUExecutionProvider",
        choices=['CUDAExecutionProvider', 'CPUExecutionProvider']
    )

    return parser.parse_args()
609
-
610
-
611
def parse_wavlist(wavlist: str):
    """Read a wav list file and return the paths that exist on disk.

    The file holds one path per line; surrounding whitespace is stripped.
    Blank lines are skipped silently. A path that does not exist is reported
    on stdout and left out of the result.

    Args:
        wavlist: Path of the list file.

    Returns:
        The existing paths, in file order.
    """
    wavpaths = []
    with open(wavlist) as f:
        for line in f:
            line = line.strip()
            if not line:
                # An empty line is not a missing wav file; ignore it rather
                # than reporting an empty file name.
                continue
            if not os.path.exists(line):
                print(f"{line} doesn't exist.")
                continue
            wavpaths.append(line)

    return wavpaths
622
-
623
-
624
def main():
    """Transcribe every wav in the list file and report timing statistics.

    For each wav the transcript is written to the hypothesis file as
    ``"<text> (<wav>)"`` and the durations and real-time factor are logged.
    Totals and the average real-time factor are logged at the end.
    """
    args = parse_args()
    print(args)

    onnx_model = FireRedASROnnxModel(args.encoder,
                                     args.decoder,
                                     args.cmvn,
                                     args.dict,
                                     args.spm_model,
                                     [args.provider])

    total_wav_durations = 0
    total_transcribe_durations = 0

    # The context manager closes the hypothesis file even when a
    # transcription fails; UTF-8 keeps non-ASCII transcripts writable
    # whatever the locale is.
    with open(args.hypo, "wt", encoding="utf-8") as wf:
        wavlist = parse_wavlist(args.wavlist)

        for wav in wavlist:
            batch_wav = [wav]
            results, wav_durations, transcribe_durations = onnx_model.transcribe(
                batch_wav, args.beam_size, args.nbest)

            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            # A zero-length wav must not abort the whole run.
            rtf = (transcribe_durations / wav_durations
                   if wav_durations else float("nan"))
            logger.info(f"(Real time factor) RTF: {rtf}")
            for result in results:
                logger.info(f"wav: {result['wav']}")
                logger.info(f"text: {result['text']}")
                logger.info(f"score: {result['score']}")
                logger.info("")
                wf.write(f"{result['text']} ({result['wav']})\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    # With an empty wav list the total duration is 0; report NaN instead of
    # raising ZeroDivisionError.
    avg_ref = (total_transcribe_durations / total_wav_durations
               if total_wav_durations else float("nan"))
    logger.info(f"AVG RTF: {avg_ref}")


# Disabled snippet kept for reference: it packs dumped calibration inputs
# into tar.gz archives.
#
# import tarfile as tf
# import glob
#
# with tf.open("./calib_dataset/encoder_input.tar.gz", "w:gz") as f:
#     for npy in glob.glob("./calib_dataset/encoder/*/encoder_input.npy"):
#         f.add(npy)
#
# with tf.open("./calib_dataset/encoder_input_lengths.tar.gz", "w:gz") as f:
#     for npy in glob.glob("./calib_dataset/encoder/*/encoder_input_lengths.npy"):
#         f.add(npy)
#
# for decoder_input in ["tokens", "n_layer_self_k_cache", "n_layer_self_v_cache", "n_layer_cross_k", "n_layer_cross_v", "pe", "self_attn_mask", "cross_attn_mask"]:
#     with tf.open(f"./calib_dataset/{decoder_input}.tar.gz", "w:gz") as f:
#         for npy in glob.glob(f"./calib_dataset/decoder/*/{decoder_input}"):
#             f.add(npy)


if __name__ == "__main__":
    main()