Optimize RTF and CER

Files changed:
- README.md (+6 -116)
- axmodel/decoder_loop.axmodel (+2 -2)
- fireredasr/data/asr_feat.py (+1 -1)
- fireredasr_axmodel.py (+495 -253)
- test_ax_model.py (+1 -1)
- test_wer.py (+1 -7)

README.md (CHANGED)

@@ -6,7 +6,7 @@ license: apache-2.0
 
 Deployment of the Xiaohongshu ASR AED-L model on the AX650N. Original project: [https://github.com/FireRedTeam/FireRedASR](https://github.com/FireRedTeam/FireRedASR)
 
-The converted models are placed in the axmodel directory. Chinese and English are currently supported, with a maximum input length of 10
+The converted models are placed in the axmodel directory. Chinese and English are currently supported; input is limited to 10 seconds of audio, and longer audio is split with VAD before inference.
 
 ## Model Conversion
 
@@ -50,121 +50,11 @@ pip install axengine-0.1.3-py3-none-any.whl
 conda activate fireredasr
 python test_ax_model.py
 ```
+```hypo_axmodel.txt``` contains the recognition results
+
+## Performance
+
+RTF ~= 0.3
+
+CER (on a custom dataset): 3.45%
 
-
-```
-[INFO] Available providers: ['AxEngineExecutionProvider']
-Namespace(encoder='axmodel/encoder.axmodel', decoder='axmodel/decoder_main.axmodel', cmvn='axmodel/cmvn.ark', dict='axmodel/dict.txt', spm_model='axmodel/train_bpe1000.model', wavlist='wavlist.txt', hypo='hypo_axmodel.txt', beam_size=3, nbest=1, max_len=128)
-[WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
-[INFO] Using provider: AxEngineExecutionProvider
-[INFO] Chip type: ChipType.MC50
-[INFO] VNPU type: VNPUType.DISABLED
-[INFO] Engine version: 2.12.0s
-[INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 4.2 9555977e
-load encoder cost 2.764460325241089 seconds
-[WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
-[INFO] Using provider: AxEngineExecutionProvider
-[INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 4.2 9555977e
-load decoder_main cost 16.36833119392395 seconds
-[WARNING] Selected provider(s): ['AXCLRTExecutionProvider'] is(are) not available.
-[INFO] Using provider: AxEngineExecutionProvider
-[INFO] Model type: 2 (triple core)
-[INFO] Compiler version: 4.2 9555977e
-load decoder_loop cost 16.194183826446533 seconds
-run encoder take 196.9749927520752ms
-run decoder_main take 130.2931308746338ms
-run decoder_loop take 165.5733585357666ms
-run decoder_loop take 109.67779159545898ms
-run decoder_loop take 101.15742683410645ms
-run decoder_loop take 110.09836196899414ms
-run decoder_loop take 100.29029846191406ms
-run decoder_loop take 109.33351516723633ms
-run decoder_loop take 100.37779808044434ms
-run decoder_loop take 109.72428321838379ms
-run decoder_loop take 100.42023658752441ms
-run decoder_loop take 101.71890258789062ms
-run decoder_loop take 100.09407997131348ms
-run decoder_loop take 110.25619506835938ms
-run decoder_loop take 100.54206848144531ms
-run decoder_loop take 101.93896293640137ms
-['wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav']
-Durations: 1.8
-Transcribe Durations: 2.5527637004852295
-(Real time factor) RTF: 1.4182020558251274
-wav: wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
-text: 我有的时候说不清楚你们知道吗
-score: -0.9156361222267151
-
-run encoder take 180.2656650543213ms
-run decoder_main take 91.42565727233887ms
-run decoder_loop take 105.18240928649902ms
-run decoder_loop take 100.56614875793457ms
-run decoder_loop take 100.9066104888916ms
-run decoder_loop take 100.9068489074707ms
-run decoder_loop take 102.90265083312988ms
-run decoder_loop take 100.50129890441895ms
-run decoder_loop take 110.12482643127441ms
-run decoder_loop take 100.65031051635742ms
-run decoder_loop take 110.09883880615234ms
-run decoder_loop take 105.48877716064453ms
-run decoder_loop take 100.32439231872559ms
-run decoder_loop take 106.08601570129395ms
-run decoder_loop take 100.79813003540039ms
-run decoder_loop take 100.4643440246582ms
-run decoder_loop take 100.30460357666016ms
-['wav/TEST_MEETING_T0000000001_S00000.wav']
-Durations: 12.369
-Transcribe Durations: 2.464834690093994
-(Real time factor) RTF: 0.19927517908432324
-wav: wav/TEST_MEETING_T0000000001_S00000.wav
-text: 好首先说一下刚才这个
-score: -0.5064160823822021
-
-run encoder take 172.59907722473145ms
-run decoder_main take 91.79949760437012ms
-run decoder_loop take 105.04364967346191ms
-run decoder_loop take 100.62885284423828ms
-run decoder_loop take 101.89318656921387ms
-run decoder_loop take 100.42643547058105ms
-run decoder_loop take 109.7562313079834ms
-['wav/IT0011W0001.wav']
-Durations: 1.992
-Transcribe Durations: 1.0302071571350098
-(Real time factor) RTF: 0.5171722676380571
-wav: wav/IT0011W0001.wav
-text: 换一首歌
-score: -0.016501454636454582
-
-run encoder take 173.07257652282715ms
-run decoder_main take 91.48693084716797ms
-run decoder_loop take 105.42607307434082ms
-run decoder_loop take 100.10981559753418ms
-run decoder_loop take 100.4478931427002ms
-run decoder_loop take 100.23713111877441ms
-run decoder_loop take 100.10337829589844ms
-run decoder_loop take 100.29196739196777ms
-run decoder_loop take 101.7463207244873ms
-run decoder_loop take 100.8148193359375ms
-run decoder_loop take 109.99274253845215ms
-run decoder_loop take 105.45015335083008ms
-run decoder_loop take 100.59380531311035ms
-run decoder_loop take 100.73733329772949ms
-run decoder_loop take 100.4335880279541ms
-run decoder_loop take 109.68661308288574ms
-['wav/BAC009S0764W0121.wav']
-Durations: 4.2039375
-Transcribe Durations: 2.3024709224700928
-(Real time factor) RTF: 0.5476938994621334
-wav: wav/BAC009S0764W0121.wav
-text: 甚至出现交易几乎停滞的情况
-score: -0.11461181938648224
-
-total wav durations: 20.364937500000003
-total transcribe durations: 8.350276470184326
-AVG RTF: 0.4100320204854213
-```
 
-```hypo_axmodel.txt``` contains the recognition results

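RTF here is transcription wall-time divided by audio duration, so lower is better. For reference, this is how the figures in the removed example log combine (a minimal sketch; the two totals are taken verbatim from that log):

```
# RTF = transcription time / audio duration, using the totals from the old example log
total_audio_s = 20.364937500000003        # "total wav durations"
total_transcribe_s = 8.350276470184326    # "total transcribe durations"
print(f"AVG RTF: {total_transcribe_s / total_audio_s:.4f}")  # ~0.41 before this commit, ~0.3 after it
```
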
axmodel/decoder_loop.axmodel (CHANGED)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:a2912496e6804027f2dc77c903f6b2f76678603dd616e662b78e3f226bcaa91a
+size 416269694

fireredasr/data/asr_feat.py (CHANGED)

@@ -42,7 +42,7 @@ class ASRFeatExtractor:
 
         lengths = torch.tensor([feat.size(0) for feat in feats]).long()
         feats_pad = self.pad_feat(feats, 0.0)
-        return feats_pad, lengths, dur
+        return feats_pad.numpy(), lengths, dur
 
     def pad_feat(self, xs, pad_value):
        # type: (List[Tensor], int) -> Tensor

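The only functional change here is that the feature extractor now hands back a NumPy array instead of a torch tensor, which is what the axengine `InferenceSession` consumes. A minimal sketch of the intended call pattern — the `import axengine as axe` alias and the 998-frame / 80-dim shape are assumptions inferred from the surrounding code, not part of this diff:

```
# Sketch: padded fbank features go straight into the AX650N encoder session as NumPy.
import numpy as np
import axengine as axe  # assumed alias; the model code refers to axe.InferenceSession

encoder = axe.InferenceSession("axmodel/encoder.axmodel", providers=["AxEngineExecutionProvider"])
feats = np.zeros((1, 998, 80), dtype=np.float32)  # 10 s at 16 kHz -> 998 frames of 80-dim fbank (assumed)
lengths = np.array([998], dtype=np.int32)         # valid frame count, int32 as in the test script
outs = encoder.run(None, {"encoder_input": feats, "encoder_input_lengths": lengths})
```
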
fireredasr_axmodel.py (CHANGED)

@@ -10,6 +10,7 @@ from typing import Tuple, List, Dict
 import os
 import time
 import torchaudio
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 try:
     torchaudio.set_audio_backend("soundfile")

@@ -44,18 +45,30 @@
     return ys * (1 - is_finished) + eos_id * is_finished
 
 
+def expand_for_beam_search(n_layer_cross_k, beam_size):
+    """Approach 1: expand_dims + tile + reshape (fastest)"""
+    num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
+
+    # insert a new beam dimension at axis 2
+    expanded = np.expand_dims(n_layer_cross_k, axis=2)
+    # tile instead of repeat, which performs better
+    tiled = np.tile(expanded, (1, 1, beam_size, 1, 1))
+    # collapse beam and batch into a single dimension
+    reshaped = tiled.reshape(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
+
+    return reshaped
+
+
 class FireRedASRAxModel:
-    def __init__(
-
-        audio_dur=10,
-    ):
+    def __init__(self,
+                 encoder_path: str,
+                 decoder_loop_path: str,
+                 cmvn_file: str,
+                 dict_file: str,
+                 spm_model_path: str,
+                 providers=["AxEngineExecutionProvider"],
+                 decode_max_len=128,
+                 audio_dur=10):
         # NOTE: maximum decode length, chosen with reference to whisper
         # FireRedASR-AED supports speech of at most 60 s
         # ref: https://github.com/FireRedTeam/FireRedASR?tab=readme-ov-file#input-length-limitations

@@ -79,6 +92,21 @@
 
         self.vad_model = load_silero_vad()
 
+        # pre-allocate buffers
+        self._preallocated_memory()
+        # use CUDA if available
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # print(f"Using device: {self.device}")
+
+    def calc_feat_len(self, audio_dur):
+        import math
+
+        sample_rate = self.sample_rate
+        frame_length = 25 * sample_rate / 1000
+        frame_shift = 10 * sample_rate / 1000
+        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
+        return length
+
     def init_encoder(self, encoder_path, providers=None):
         self.encoder = axe.InferenceSession(encoder_path, providers=providers)

@@ -90,7 +118,7 @@
         decoder_path = os.path.join(decoder_path, "pe.npy")
 
         return np.load(decoder_path)
 
     def run_encoder(
         self, input: np.ndarray, input_length: np.ndarray
     ) -> Tuple[Tensor, Tensor, Tensor]:

@@ -98,7 +126,7 @@
             None, {"encoder_input": input, "encoder_input_lengths": input_length}
         )
         return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask)
 
     def decode_loop_one_token(
         self,
         tokens: np.ndarray,

@@ -128,271 +156,485 @@
             },
         )
         return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)
-
-    def
-
         num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
         encoder_out_length = cross_attn_mask.shape[-1]
 
-        n_layer_cross_v = (
-            n_layer_cross_v.unsqueeze(2)
-            .repeat(1, 1, beam_size, 1, 1)
-            .view(num_layer, beam_size * batch_size, Ti, encoder_out_dim)
-        )
-
-        prediction_tokens = (
-            torch.ones(beam_size * batch_size, 1).fill_(self.sos_id).long()
-        )
-        tokens = prediction_tokens
-        offset = torch.zeros(1, dtype=torch.int64)
-        n_layer_self_k_cache, n_layer_self_v_cache = self.get_initialized_self_cache(
             batch_size, beam_size
         )
-
-        scores
-
-            (
-                self_attn_mask,
-                to_numpy(cross_attn_mask),
             )
-
-            logits = torch.from_numpy(logits)
-            logits = logits.squeeze(1)
             t_scores = F.log_softmax(logits, dim=-1)
-
-            scores = scores.view(-1, 1)
-
-            topB_row_number_in_each_B_rows_of_ys = torch.div(
-                topB_score_ids, beam_size
-            ).view(batch_size * beam_size)
-            stride = beam_size * torch.arange(batch_size).view(batch_size, 1).repeat(
-                1, beam_size
-            ).view(batch_size * beam_size)
-            topB_row_number_in_ys = (
-                topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
             )
-
-            t_ys = torch.gather(
-                t_topB_ys.view(batch_size, beam_size * beam_size),
-                dim=1,
-                index=topB_score_ids,
-            ).view(beam_size * batch_size, 1)
-
-            tokens = t_ys
-
-            prediction_tokens = torch.cat((prediction_tokens, t_ys), dim=1)
-
-            n_layer_self_k_cache = torch.from_numpy(n_layer_self_k_cache)
-            n_layer_self_v_cache = torch.from_numpy(n_layer_self_v_cache)
-
-            for i, self_k_cache in enumerate(n_layer_self_k_cache):
-                n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
-
-            for i, self_v_cache in enumerate(n_layer_self_v_cache):
-                n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
-
-            is_finished = t_ys.eq(self.eos_id)
-            if is_finished.sum().item() == beam_size * batch_size:
                 break
-
         scores = scores.view(batch_size, beam_size)
-
             torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
-            dim=-1
         ).int()
-
         nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
-        index = (
-        )
-
-        ]
-        nbest_prediction_tokens = nbest_prediction_tokens.view(
-            batch_size, nbest_ids.size(1), -1
-        )
-        nbest_prediction_valid_token_lengths = prediction_valid_token_lengths.view(
-            batch_size * beam_size
-        )[index.view(-1)].view(batch_size, -1)
-
-        # batch_size is always 1
-        i_best_hyps: List[Dict[str, torch.Tensor]] = []
         for j, score in enumerate(nbest_scores[0]):
             hyp = {
-                "token_ids":
-                    0, j, 1 : nbest_prediction_valid_token_lengths[0, j]
-                ],
                 "score": score,
             }
-
-        return
-
-        n_layer_self_v_cache = torch.zeros(
-            self.num_decoder_blocks,
-            batch_size * beam_size,
-            self.decode_max_len,
-            self.decoder_hidden_dim,
-        )
-        return n_layer_self_k_cache, n_layer_self_v_cache
-
-    def calc_feat_len(self, audio_dur):
-        import math
-
-        sample_rate = self.sample_rate
-        frame_length = 25 * sample_rate / 1000
-        frame_shift = 10 * sample_rate / 1000
-        length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1
-        return length
-
-    def collect_chunks(self, wav, speech_timestamps, audio_dur, sample_rate):
-        max_chunk_samples = int(audio_dur * sample_rate)
-        chunks = []
-        for ts in speech_timestamps:
-            start, end = ts["start"], ts["end"]
-            cur_chunk = wav[start:end]
-            if (
-                len(chunks) > 0
-                and chunks[-1].shape[0] + cur_chunk.shape[0] < max_chunk_samples
-            ):
-                chunks[-1] = torch.concat([chunks[-1], cur_chunk], dim=0)
-            else:
-                if cur_chunk.shape[0] > max_chunk_samples:
-                    # greedy split if one chunk is too big
-                    chunks.append(cur_chunk[:max_chunk_samples])
-                    chunks.append(cur_chunk[max_chunk_samples:])
-                else:
-                    chunks.append(cur_chunk)
-        return chunks
-
-    def transcribe(
-        self, batch_wav_path: List[str], beam_size: int = 1, nbest: int = 1
     ) -> List[Dict]:
-
         tokens = []
         for chunk in chunks:
-            feats, lengths, wav_duration = self.
-                chunk, self.sample_rate
-            )
-
             wav_durations.append(wav_duration)
-
-            feats = np.concatenate(
-                [
-                    feats,
-                    np.zeros(
-                        (1, self.max_feat_len - feats.shape[1], 80),
-                        dtype=np.float32,
-                    ),
-                ],
-                axis=1,
-            )
-            feats = feats[:, : self.max_feat_len, :]
-            lengths = torch.minimum(lengths, torch.tensor(self.max_feat_len))
-
-            feats = to_numpy(feats)
-            lengths = to_numpy(lengths).astype(np.int32)
-
             start_time = time.time()
             n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
-
             )
-
-            nbest_hyps = self.
                 n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
             )
-
             text = self.tokenizer.detokenize(tokens)
-
+
+    def _preallocated_memory(self):
+        """Pre-allocate frequently used buffers."""
+        # pre-compute a self_attn_mask template for every decode offset
+        self.self_attn_mask_templates = {}
+        for offset in range(self.decode_max_len):
+            mask = np.zeros((1, 1, self.decode_max_len), dtype=np.float32)
+            mask[:, :, :self.decode_max_len - offset - 1] = -np.inf
+            self.self_attn_mask_templates[offset] = mask
+
+        # pre-allocate the initial beam-search score template
+        self.beam_scores_template = torch.tensor(
+            [0.0] + [-INF] * (self.decode_max_len - 1)
+        ).float()
+
+    def transcribe(
+        self,
+        batch_wav_path: List[str],
+        beam_size: int = 1,
+        nbest: int = 1,
+        use_parallel: bool = False
+    ) -> List[Dict]:
+        """Optimized transcription entry point."""
+
+        # 1. optimized VAD and chunking
+        chunks = self._optimized_vad_split(batch_wav_path[0])
+
+        if use_parallel and len(chunks) > 1:
+            return self._parallel_transcribe(chunks, beam_size, nbest)
+        else:
+            return self._sequential_transcribe(chunks, beam_size, nbest)
+
+    def _optimized_vad_split(self, wav_path: str) -> List[torch.Tensor]:
+        """Optimized VAD-based chunking."""
+        import torchaudio
+
+        # load the audio; fall back to silero-vad's reader if torchaudio fails
+        try:
+            wav, sr = torchaudio.load(wav_path)
+            if sr != self.sample_rate:
+                wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
+        except:
+            # use silero_vad's read_audio as a fallback
+            from silero_vad import read_audio
+            wav = read_audio(wav_path, sampling_rate=self.sample_rate)
+            wav = wav.unsqueeze(0)
+
+        wav = wav.squeeze(0)
+
+        # fast path: if the audio is short, return it directly without VAD
+        max_chunk_samples = int(self.sample_rate * self.audio_dur)
+        if wav.shape[0] < max_chunk_samples:
+            return [wav]
+
+        # tuned VAD parameters
+        speech_timestamps = get_speech_timestamps(
+            wav,
+            self.vad_model,
+            threshold=0.5,  # higher threshold, fewer silences detected as speech
+            min_speech_duration_ms=250,  # minimum speech segment
+            min_silence_duration_ms=100,  # minimum silence segment
+            return_seconds=False,
+        )
+
+        # optimized chunk merging
+        return self._optimized_collect_chunks(wav, speech_timestamps)
+
+    def _optimized_collect_chunks(
+        self,
+        wav: torch.Tensor,
+        speech_timestamps: List[Dict]
+    ) -> List[torch.Tensor]:
+        """Optimized chunk merging."""
+        max_chunk_samples = int(self.sample_rate * self.audio_dur)
+        chunks = []
+        current_chunk = []
+        current_length = 0
+
+        for ts in speech_timestamps:
+            start, end = ts["start"], ts["end"]
+            chunk_length = end - start
+
+            if current_length + chunk_length <= max_chunk_samples:
+                current_chunk.append((start, end))
+                current_length += chunk_length
+            else:
+                if current_chunk:
+                    # flush the accumulated chunk
+                    merged = torch.cat([wav[s:e] for s, e in current_chunk])
+                    chunks.append(merged)
+
+                if chunk_length > max_chunk_samples:
+                    # split an oversized segment
+                    num_splits = (chunk_length + max_chunk_samples - 1) // max_chunk_samples
+                    for i in range(num_splits):
+                        s = start + i * max_chunk_samples
+                        e = min(start + (i + 1) * max_chunk_samples, end)
+                        chunks.append(wav[s:e])
+                    current_chunk = []
+                    current_length = 0
+                else:
+                    current_chunk = [(start, end)]
+                    current_length = chunk_length
+
+        # flush the last chunk
+        if current_chunk:
+            merged = torch.cat([wav[s:e] for s, e in current_chunk])
+            chunks.append(merged)
+
+        return chunks
+
+    def _optimized_decode_loop(
+        self,
+        n_layer_cross_k: np.ndarray,
+        n_layer_cross_v: np.ndarray,
+        cross_attn_mask: np.ndarray,
+        beam_size: int,
+        nbest: int
+    ) -> List[Dict]:
+        """Optimized decode loop."""
+
+        num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape
+        encoder_out_length = cross_attn_mask.shape[-1]
+
+        n_layer_cross_k = expand_for_beam_search(n_layer_cross_k, beam_size)
+        n_layer_cross_v = expand_for_beam_search(n_layer_cross_v, beam_size)
+
+        batch_size, Ti, encoder_out_length = cross_attn_mask.shape
+
+        # insert a beam dimension at axis 1
+        expanded = np.expand_dims(cross_attn_mask, axis=1)
+        # tile instead of repeat, which performs better
+        tiled = np.tile(expanded, (1, beam_size, 1, 1))
+        # collapse beam and batch into a single dimension
+        cross_attn_mask = tiled.reshape(beam_size * batch_size, Ti, encoder_out_length)
+
+        # optimized cache initialization
+        n_layer_self_k_cache, n_layer_self_v_cache = self._optimized_init_self_cache(
+            batch_size, beam_size
+        )
+
+        # pre-allocate tokens and scores
+        tokens = torch.full(
+            (beam_size * batch_size, 1),
+            self.sos_id,
+            dtype=torch.int32, device=self.device
+        )
+        scores = self.beam_scores_template[:beam_size].repeat(batch_size).view(
+            batch_size * beam_size, 1
+        ).to(self.device)
+        is_finished = torch.zeros_like(scores, dtype=torch.bool, device=self.device)
+
+        # pre-allocate prediction_tokens
+        prediction_tokens = tokens.clone()
+
+        pe_np = self.pe
+
+        for offset in range(self.decode_max_len):
+            # use the pre-computed mask template
+            self_attn_mask = np.repeat(
+                self.self_attn_mask_templates[offset],
+                beam_size * batch_size,
+                axis=0
+            )
+
+            # work on numpy arrays directly, avoiding conversions
+            logits, n_layer_self_k_cache, n_layer_self_v_cache = (
+                self.decode_loop_one_token(
+                    tokens.cpu().numpy().astype(np.int32),
+                    n_layer_self_k_cache,
+                    n_layer_self_v_cache,
+                    n_layer_cross_k,
+                    n_layer_cross_v,
+                    pe_np[offset],
+                    self_attn_mask,
+                    cross_attn_mask
+                )
+            )
+
+            logits = torch.from_numpy(logits).to(self.device).squeeze(1)
+            t_scores = F.log_softmax(logits, dim=-1)
+
+            # optimized beam-search step
+            tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished = (
+                self._optimized_beam_search(
+                    t_scores, tokens, scores, prediction_tokens,
+                    n_layer_self_k_cache, n_layer_self_v_cache,
+                    is_finished, beam_size, batch_size
+                )
+            )
+
+            if is_finished.all():
+                break
+
+        # return self._extract_results(scores, prediction_tokens, batch_size, beam_size, nbest)
+        return self.extract_results_numpy_vectorized(scores.numpy(), prediction_tokens.numpy(), batch_size, beam_size, nbest)
+
+    def _optimized_beam_search(
+        self,
+        t_scores: torch.Tensor,
+        tokens: torch.Tensor,
+        scores: torch.Tensor,
+        prediction_tokens: torch.Tensor,
+        n_layer_self_k_cache: torch.Tensor,
+        n_layer_self_v_cache: torch.Tensor,
+        is_finished: torch.Tensor,
+        beam_size: int,
+        batch_size: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Optimized beam-search step."""
+
+        # use torch in-place operations
+        t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)
+
+        # handle beams that have already finished
+        if is_finished.any():
+            # in-place ops, avoiding new tensor allocations
+            t_topB_scores.masked_fill_(is_finished, 0.0)
+            t_topB_scores[:, 1:].masked_fill_(is_finished, -INF)
+            t_topB_ys.masked_fill_(is_finished, self.eos_id)
+
+        # accumulate scores
+        scores = scores + t_topB_scores
+
+        # top-k over the beam_size * beam_size candidates
+        scores_2d = scores.view(batch_size, beam_size * beam_size)
+        top_scores, top_ids = torch.topk(scores_2d, k=beam_size, dim=1)
+        scores = top_scores.view(-1, 1)
+
+        # map candidate indices back to beam rows
+        topB_row_number_in_each_B_rows_of_ys = torch.div(top_ids, beam_size, rounding_mode='floor')
+        stride = beam_size * torch.arange(batch_size, device=self.device).view(batch_size, 1)
+        topB_row_number_in_ys = (topB_row_number_in_each_B_rows_of_ys + stride).view(-1)
+
+        # update tokens and prediction_tokens
+        tokens = torch.gather(
+            t_topB_ys.view(batch_size, beam_size * beam_size),
+            dim=1,
+            index=top_ids,
+        ).view(beam_size * batch_size, 1)
+
+        prediction_tokens = torch.cat([
+            prediction_tokens[topB_row_number_in_ys],
+            tokens
+        ], dim=1)
+
+        # reorder the caches (in place)
+        for i in range(n_layer_self_k_cache.shape[0]):
+            n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
+            n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]
+
+        # update the finished flags
+        is_finished = tokens.eq(self.eos_id)
+
+        return tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished
+
+    def _optimized_init_self_cache(
+        self, batch_size: int, beam_size: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Optimized self-attention cache initialization."""
+        shape = (
+            self.num_decoder_blocks,
+            batch_size * beam_size,
+            self.decode_max_len,
+            self.decoder_hidden_dim
+        )
+        n_layer_self_k_cache = np.zeros(shape, dtype=np.float32)
+        n_layer_self_v_cache = np.zeros(shape, dtype=np.float32)
+        return n_layer_self_k_cache, n_layer_self_v_cache
+
+    def _extract_results(
+        self,
+        scores: torch.Tensor,
+        prediction_tokens: torch.Tensor,
+        batch_size: int,
+        beam_size: int,
+        nbest: int
+    ) -> List[Dict]:
+        """Extract the n-best results."""
+        scores = scores.view(batch_size, beam_size)
+        valid_lengths = torch.sum(
+            torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id),
+            dim=-1
+        ).int()
+
+        nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1)
+        index = nbest_ids + beam_size * torch.arange(batch_size, device=self.device).unsqueeze(1)
+
+        nbest_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)]
+        nbest_tokens = nbest_tokens.view(batch_size, nbest_ids.size(1), -1)
+
+        results = []
+        for j, score in enumerate(nbest_scores[0]):
+            hyp = {
+                "token_ids": nbest_tokens[0, j, 1:valid_lengths[0, nbest_ids[0, j]]],
+                "score": score,
+            }
+            results.append(hyp)
+
+        return results
+
+    def extract_results_numpy_vectorized(
+        self,
+        scores: np.ndarray,
+        prediction_tokens: np.ndarray,
+        batch_size: int,
+        beam_size: int,
+        nbest: int,
+        eos_id: int = 4
+    ) -> List[Dict]:
+        """Vectorized NumPy implementation."""
+
+        # 1. reshape and compute valid lengths
+        scores_2d = scores.reshape(batch_size, beam_size)
+        tokens_3d = prediction_tokens.reshape(batch_size, beam_size, -1)
+
+        # valid length = number of tokens that are not eos_id
+        valid_lengths = np.sum(tokens_3d != eos_id, axis=-1).astype(np.int32)
+
+        # 2. partial sort with argpartition (faster than argsort)
+        # take the indices of the nbest largest scores
+        # argpartition: O(n) vs argsort: O(n log n)
+        partitioned_indices = np.argpartition(-scores_2d, nbest-1, axis=1)[:, :nbest]
+
+        # sort the top-k inside each batch row
+        nbest_scores = np.take_along_axis(scores_2d, partitioned_indices, axis=1)
+        sorted_order = np.argsort(-nbest_scores, axis=1)
+
+        # apply the ordering
+        nbest_ids = np.take_along_axis(partitioned_indices, sorted_order, axis=1)
+        nbest_scores = np.take_along_axis(nbest_scores, sorted_order, axis=1)
+
+        # 3. compute global row indices
+        batch_indices = np.arange(batch_size)[:, np.newaxis]
+        global_indices = nbest_ids + beam_size * batch_indices
+        flat_global_indices = global_indices.reshape(-1)
+
+        # 4. gather the tokens
+        flat_tokens = prediction_tokens.reshape(-1, prediction_tokens.shape[-1])
+        nbest_tokens = flat_tokens[flat_global_indices]
+        nbest_tokens = nbest_tokens.reshape(batch_size, nbest, -1)
+
+        # 5. gather the matching valid lengths
+        nbest_valid_lengths = np.take_along_axis(valid_lengths, nbest_ids, axis=1)
+
+        # 6. build the results
+        results = []
+
+        for b in range(batch_size):
+            batch_results = []
+            for j in range(nbest):
+                valid_len = nbest_valid_lengths[b, j]
+
+                # token_ids, skipping <sos>
+                token_ids = nbest_tokens[b, j, 1:valid_len]
+
+                hyp = {
+                    "token_ids": token_ids.tolist(),
+                    "score": float(nbest_scores[b, j]),
+                }
+                batch_results.append(hyp)
+
+            # for batched input this could return per-batch results;
+            # batch_size is assumed to be 1, so return the first batch directly
+            if b == 0:
+                results = batch_results
+
+        return results
+
+    def _sequential_transcribe(
+        self,
+        chunks: List[torch.Tensor],
+        beam_size: int,
+        nbest: int
+    ) -> Dict:
+        """Sequential transcription (single thread)."""
+        tokens = []
+        wav_durations = []
+        transcribe_duration = 0
+
+        for chunk in chunks:
+            # optimized feature extraction
+            feats, lengths, wav_duration = self._optimized_feature_extraction(chunk)
+            wav_durations.append(wav_duration)
+
+            # run the encoder and decoder
+            start_time = time.time()
+            n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+                feats, lengths.numpy().astype(np.int32)
+            )
+
+            nbest_hyps = self._optimized_decode_loop(
+                n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+            )
+
+            tokens.extend([int(id) for id in nbest_hyps[0]["token_ids"]])
+            transcribe_duration += time.time() - start_time
+
+        text = self.tokenizer.detokenize(tokens)
+        return {"text": text}, wav_durations, transcribe_duration
+
+    def _parallel_transcribe(
+        self,
+        chunks: List[torch.Tensor],
+        beam_size: int,
+        nbest: int
+    ) -> Dict:
+        """Parallel transcription (multi-threaded)."""
+        import threading
+
+        results = []
+        lock = threading.Lock()
+
+        def process_chunk(chunk_idx, chunk):
+            try:
+                # feature extraction
+                feats, lengths, wav_duration = self._optimized_feature_extraction(chunk)
+
+                # encoder
+                n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder(
+                    feats, lengths.astype(np.int32)
+                )
+
+                # decoder
+                nbest_hyps = self._optimized_decode_loop(
+                    n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest
+                )
+
+                with lock:
+                    results.append({
+                        'chunk_idx': chunk_idx,
+                        'tokens': [int(id) for id in nbest_hyps[0]["token_ids"].cpu()],
+                        'duration': wav_duration
+                    })
+            except Exception as e:
+                print(f"Error processing chunk {chunk_idx}: {e}")
+
+        # process chunks in parallel with a ThreadPoolExecutor
+        with ThreadPoolExecutor(max_workers=min(4, len(chunks))) as executor:
+            futures = []
+            for i, chunk in enumerate(chunks):
+                future = executor.submit(process_chunk, i, chunk)
+                futures.append(future)
+
+            # wait for all tasks to finish
+            for future in as_completed(futures):
+                future.result()
+
+        # merge the results in chunk order
+        results.sort(key=lambda x: x['chunk_idx'])
+        tokens = []
+        wav_durations = []
+
+        for result in results:
+            tokens.extend(result['tokens'])
+            wav_durations.append(result['duration'])
+
+        text = self.tokenizer.detokenize(tokens)
+        return {"text": text}, wav_durations, 0  # wall time is not meaningful when chunks run in parallel
+
+    def _optimized_feature_extraction(
+        self,
+        chunk: torch.Tensor
+    ) -> Tuple[np.ndarray, np.ndarray, float]:
+        """Optimized feature extraction."""
+        chunk = (chunk.clamp(-1, 1) * 32768).to(torch.int16)
+        feats, lengths, wav_duration = self.feature_extractor.run_chunk(
+            chunk, self.sample_rate
+        )
+
+        # pad up to max_feat_len instead of concatenating a new zero block
+        if feats.shape[1] < self.max_feat_len:
+            pad_width = ((0, 0), (0, self.max_feat_len - feats.shape[1]), (0, 0))
+            feats = np.pad(feats, pad_width, mode='constant', constant_values=0)
+
+        feats = feats[:, :self.max_feat_len, :]
+        lengths = np.minimum(lengths, self.max_feat_len)
+
+        return feats, lengths, wav_duration

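The new `expand_for_beam_search` helper replaces the previous `unsqueeze(2).repeat(...).view(...)` torch sequence with NumPy `expand_dims`/`tile`/`reshape`. A small self-contained check of that equivalence (shapes invented for illustration):

```
# Sketch: tile + reshape yields the same layout as the old repeat + view path.
import numpy as np
import torch

num_layer, batch, Ti, dim, beam = 2, 1, 5, 4, 3
x = np.random.rand(num_layer, batch, Ti, dim).astype(np.float32)

# NumPy path (what expand_for_beam_search does)
np_out = np.tile(np.expand_dims(x, 2), (1, 1, beam, 1, 1)).reshape(num_layer, beam * batch, Ti, dim)

# old torch path
t_out = torch.from_numpy(x).unsqueeze(2).repeat(1, 1, beam, 1, 1).view(num_layer, beam * batch, Ti, dim)

assert np.allclose(np_out, t_out.numpy())  # same values, same (layer, beam*batch, Ti, dim) layout
```
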
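`_preallocated_memory` also caches the initial beam scores `[0, -INF, -INF, ...]`. The `-INF` entries matter because every beam starts from the same `<sos>` token: without them the first combined top-k would keep the single best token `beam_size` times instead of `beam_size` distinct candidates. A toy illustration (scores invented):

```
# Sketch: why the initial beam scores are [0, -inf, -inf, ...].
import torch

INF = float("inf")
beam = 3
step_scores = torch.tensor([[-0.1, -0.7, -1.2]]).repeat(beam, 1)  # identical for every beam at step 0

init_good = torch.tensor([[0.0], [-INF], [-INF]])   # template used by the model
init_bad = torch.zeros(beam, 1)                     # naive all-zero init

for name, init in [("with -inf init", init_good), ("all-zero init", init_bad)]:
    cand = (init + step_scores).view(1, beam * beam)          # beam*beam candidate scores
    top_scores, top_ids = torch.topk(cand, k=beam, dim=1)
    # with -inf: ids 0,1,2 -> three distinct continuations survive
    # all-zero:  ids 0,3,6 -> the same best token kept three times
    print(name, top_ids.tolist())
```
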
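`_optimized_collect_chunks` packs consecutive VAD speech segments into chunks of at most `audio_dur` seconds and splits oversized segments. A hedged sketch that exercises just that method on toy timestamps — it assumes `fireredasr_axmodel` imports cleanly on the target machine, and the dummy object only needs `sample_rate`/`audio_dur`:

```
# Sketch: packing behaviour on invented timestamps (16 kHz, 10 s chunk limit).
from types import SimpleNamespace
import torch
from fireredasr_axmodel import FireRedASRAxModel

dummy = SimpleNamespace(sample_rate=16000, audio_dur=10)
wav = torch.zeros(16000 * 30)                      # 30 s of (silent) audio
stamps = [{"start": 0, "end": 96000},              # 6 s of speech
          {"start": 112000, "end": 144000},        # 2 s  -> packed with the first (8 s total)
          {"start": 160000, "end": 480000}]        # 20 s -> flushed, then split into 10 s pieces
chunks = FireRedASRAxModel._optimized_collect_chunks(dummy, wav, stamps)
print([c.shape[0] / 16000 for c in chunks])        # expect roughly [8.0, 10.0, 10.0]
```
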
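The vectorized result extraction relies on `np.argpartition` for an O(n) partial top-k, followed by a small `argsort` over only the `nbest` survivors. The same two-step pattern in isolation (numbers invented):

```
# Sketch: top-k via argpartition + a small argsort, as used in extract_results_numpy_vectorized.
import numpy as np

scores = np.array([[0.1, 0.9, 0.4, 0.8, 0.2]])   # (batch=1, beam=5) log-prob-like scores
nbest = 2

part = np.argpartition(-scores, nbest - 1, axis=1)[:, :nbest]   # unordered top-2 indices, O(n)
top = np.take_along_axis(scores, part, axis=1)
order = np.argsort(-top, axis=1)                                # order only those 2
ids = np.take_along_axis(part, order, axis=1)
vals = np.take_along_axis(top, order, axis=1)
print(ids, vals)   # [[1 3]] [[0.9 0.8]]
```
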
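Putting the pieces together, the new entry point is `transcribe(batch_wav_path, beam_size, nbest, use_parallel)`, which VAD-splits the first wav, runs the encoder and decode loop per chunk, and returns `({"text": ...}, wav_durations, transcribe_duration)`. A hedged usage sketch — keyword arguments follow the new `__init__` signature, and the file names follow the axmodel layout used elsewhere in this repo:

```
# Sketch: driving the optimized model end to end on the AX650N.
from fireredasr_axmodel import FireRedASRAxModel

model = FireRedASRAxModel(
    encoder_path="axmodel/encoder.axmodel",
    decoder_loop_path="axmodel/decoder_loop.axmodel",
    cmvn_file="axmodel/cmvn.ark",
    dict_file="axmodel/dict.txt",
    spm_model_path="axmodel/train_bpe1000.model",
    decode_max_len=128,
    audio_dur=10,
)
result, wav_durations, elapsed = model.transcribe(["wav/BAC009S0764W0121.wav"], beam_size=1, nbest=1)
print(result["text"], sum(wav_durations), elapsed)
```
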
test_ax_model.py (CHANGED)

@@ -44,7 +44,7 @@ def parse_args():
     parser.add_argument(
         "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
     )
-    parser.add_argument("--beam_size", type=int, default=
+    parser.add_argument("--beam_size", type=int, default=1, help="")
     parser.add_argument("--nbest", type=int, default=1, help="")
     parser.add_argument("--decode_max_len", type=int, default=128, help="max token len")
     parser.add_argument("--max_dur", type=int, default=10, help="max audio len")

test_wer.py (CHANGED)

@@ -183,12 +183,6 @@ def get_args():
         default="axmodel/encoder.axmodel",
         help="Path to onnx encoder",
     )
-    parser.add_argument(
-        "--decoder_main",
-        type=str,
-        default="axmodel/decoder_main.axmodel",
-        help="Path to axmodel decoder main",
-    )
     parser.add_argument(
         "--decoder_loop",
         type=str,

@@ -213,7 +207,7 @@
     parser.add_argument(
         "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
     )
-    parser.add_argument("--beam_size", type=int, default=
+    parser.add_argument("--beam_size", type=int, default=1, help="")
     parser.add_argument("--nbest", type=int, default=1, help="")
     parser.add_argument("--max_len", type=int, default=128, help="")
     return parser.parse_args()