shethjenil committed
Commit ae4543e · verified · 1 Parent(s): c4574c2

Update modeling_conformer.py

Files changed (1)
  1. modeling_conformer.py +265 -265
modeling_conformer.py CHANGED
@@ -1,265 +1,265 @@
+ from datetime import timedelta
+ import gc
+ import json
+ from huggingface_hub import hf_hub_download
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ import librosa
+ from torch import nn
+ from transformers import Wav2Vec2ConformerModel
+ from torch_state_bridge import state_bridge
+ from torch.nn.utils.rnn import pad_sequence
+ from safetensors.torch import load_file
+ import webrtcvad
+ from torch.utils.data import Dataset, DataLoader
+ import srt
+
+ # output sequence length after repeat_num strided conv layers (kernel 3, stride 2, total padding 2)
+ def calc_length(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
+     add_pad = all_paddings - kernel_size
+     for _ in range(repeat_num):
+         lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
+     return lengths
+
+ class ChunkedData(Dataset):
+     def __init__(self, wav, sr):
+         if sr != 16000: wav = torchaudio.functional.resample(wav, sr, 16000)
+         wav = wav.mean(0, keepdim=True)
+         self.data, self.ts = self.make_chunks(wav)
+
+     def __len__(self): return len(self.data)
+     def __getitem__(self, i): return self.data[i], self.ts[i]
+
+     # split the waveform into ~10-15 s chunks, preferring to end a chunk on non-speech frames found by WebRTC VAD
+     def make_chunks(self, wav, sr=16000, ag=2, min_s=10, max_s=15, ms=30):
+         w = (wav * 32768).clamp(-32768, 32767).short().squeeze(0)
+         fl = int(sr * ms / 1000)
+         nf = len(w) // fl
+         w = w[: nf * fl]
+         fr = w.view(nf, fl)
+         vad = webrtcvad.Vad(ag)
+         sp = torch.zeros(nf, dtype=torch.bool)
+         for i, f in enumerate(fr):
+             try: sp[i] = vad.is_speech(f.cpu().numpy().tobytes(), sr)
+             except Exception: pass
+         seg, s = [], None
+         for i, v in enumerate(sp):
+             if v and s is None: s = i
+             elif not v and s is not None: seg.append((s, i)); s = None
+         if s is not None: seg.append((s, len(sp)))
+         cs, ts, st = [], [], 0
+         mn, mx, N = int(min_s * sr), int(max_s * sr), len(w)
+         while st < N:
+             ed = min(st + mx, N)
+             f = ed // fl
+             while f < len(sp) and sp[f]:
+                 f += 1; ed = min(f * fl, N)
+                 if ed - st > mx * 1.5: break
+             if ed - st < mn and ed < N: ed = min(st + mn, N)
+             cs.append(wav[:, st:ed].squeeze())
+             ts.append([round(st / sr, 2), round(ed / sr, 2)])
+             st = ed
+         return cs, torch.tensor(ts)
+
+
+
+ def padding_audio(batch):
+     audios, times = zip(*batch)
+     return pad_sequence(audios, batch_first=True), torch.tensor([audio.numel() for audio in audios]), torch.stack(times)
+
+ class Op(nn.Module):
+     def __init__(self, func, allow_self=False):
+         super().__init__()
+         self.func = func
+         self.allow_self = allow_self
+
+     def forward(self, x):
+         if self.allow_self:
+             return self.func(self, x)
+         return self.func(x)
+
+ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
+     def __init__(self, config):
+         self.language = config.languages[0]
+         if len(config.languages) > 1:
+             config.hidden_size = 1024
+             config.num_hidden_layers = 24
+             config.conv_depthwise_kernel_size = 9
+             config.conv_stride = [2,2,2]
+             config.conv_kernel = [3,3,3]
+             config.conv_dim = [256,256,256]
+             config.feat_extract_norm = "group"
+             config.intermediate_size = 4096
+         config.num_feat_extract_layers = len(config.conv_dim)
+         config.lstm_layer = 2
+
+         self.cache_length = None
+         self.hop, self.preemph, self.eps, self.pad_to = 160, 0.97, 2**-24, 16
+         self.denorm = (2 ** config.num_feat_extract_layers) * self.hop / config.sampling_rate
+         self.scaler = config.hidden_size ** (1/2)
+         super().__init__(config)
+         self.eval()
+
+     def init_weights(self):
+         del self.encoder.pos_conv_embed
+         config = self.config
+         self.enc = nn.Linear(config.hidden_size, config.joint_hidden)
+         self.pred = nn.Linear(config.pred_hidden, config.joint_hidden)
+         self.joint = nn.Linear(config.joint_hidden, config.vocab_size // 22 + 1)
+         self.embed = nn.Embedding(config.vocab_size+1, config.pred_hidden, padding_idx=config.vocab_size)
+         self.lstm = nn.LSTM(config.pred_hidden, config.pred_hidden, config.lstm_layer, batch_first=True)
+         self.act = nn.ReLU(inplace=True)
+         self.spec = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=160, win_length=400, center=False)
+         self.mask_layer = Op(lambda self_obj,x : x.masked_fill(self_obj.cache_pad_mask.unsqueeze(1), 0),True)
+         self.mel_fb = nn.Parameter(torch.tensor(librosa.filters.mel(sr=self.config.sampling_rate, n_fft=512, n_mels=80)),False)
+
+         for idx,l in enumerate(self.feature_extractor.conv_layers):
+             if len(self.config.languages) == 1 or idx == 0:
+                 l.conv = nn.Conv2d(l.conv.in_channels,l.conv.out_channels,l.conv.kernel_size[0],l.conv.stride,1)
+                 l.layer_norm = nn.Identity()
+             else:
+                 l.conv = nn.Sequential(nn.Conv2d(l.conv.in_channels,l.conv.out_channels,l.conv.kernel_size[0],l.conv.stride,1,groups=l.conv.out_channels),nn.Conv2d(l.conv.in_channels,l.conv.out_channels, 1))
+
+         self.feature_extractor.conv_layers.append(Op(lambda x : x.transpose(1, 2)))
+         self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * int(calc_length(torch.tensor(80.),repeat_num=self.config.num_feat_extract_layers)),config.hidden_size)
+         self.feature_projection.layer_norm = Op(lambda x:x.permute(0, 2, 1, 3).flatten(2))
+         for l in self.encoder.layers:
+             l.conv_module.glu = nn.Sequential(l.conv_module.glu,self.mask_layer)
+             l.conv_module.pointwise_conv1.bias = nn.Parameter(torch.empty(l.conv_module.pointwise_conv1.out_channels))
+             l.conv_module.pointwise_conv2.bias = nn.Parameter(torch.empty(l.conv_module.pointwise_conv2.out_channels))
+             l.conv_module.depthwise_conv.bias = nn.Parameter(torch.empty(l.conv_module.depthwise_conv.out_channels))
+         self.encoder.layer_norm = nn.Identity()
+         if len(self.config.languages) > 1:
+             self.lang_joint_net = nn.ModuleDict({l: nn.Linear(config.joint_hidden, config.vocab_size // 22 + 1) for l in config.languages})
+         return super().init_weights()
+
+     def _mask_hidden_states(self, hidden_states, mask_time_indices = None, attention_mask = None):
+         hidden_states = hidden_states * self.scaler
+         self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
+         return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
+
+     # log-mel features with pre-emphasis and per-utterance mean/variance normalization over the valid frames
+     def preprocessing(self, x):
+         x, l = x
+         l = (l // self.hop + 1).long()
+         x = torch.cat((x[:, :1], x[:, 1:] - self.preemph * x[:, :-1]), 1)
+         x = (self.mel_fb @ self.spec(x) + self.eps).log()
+         T = x.size(-1)
+         m = torch.arange(T, device=x.device)[None] >= l[:, None]
+         x = x.masked_fill(m[:, None], 0)
+         μ = x.sum(-1) / l[:, None]
+         σ = (((x - μ[..., None])**2).sum(-1) / (l[:, None] - 1) + 1e-5).sqrt()
+         x = ((x - μ[..., None]) / σ[..., None]).masked_fill(m[:, None], 0)
+         self.cache_length = calc_length(l, repeat_num=self.config.num_feat_extract_layers).long()
+         return F.pad(x, (0, (-T) % self.pad_to)).transpose(1, 2)
+
+     def forward(self, input_values):
+         return self.postprocessing(super().forward(self.preprocessing(input_values)).last_hidden_state)
+
+     @torch.inference_mode()
+     def transcribe(self,wav,sr,batch_size):
+         device = next(self.parameters()).device
+         subtitles = []
+         for batch, lengths, timestamp in DataLoader(ChunkedData(wav, sr),batch_size,collate_fn=padding_audio):
+             batch = batch.to(device)
+             lengths = lengths.to(device)
+             timestamp = timestamp.to(device)
+             subtitles.extend(self.make_srt(self.forward((batch, lengths)),timestamp))
+             # each yield is the SRT track composed from every chunk decoded so far
+             yield srt.compose(subtitles)
+             torch.cuda.empty_cache()
+             gc.collect()
+
+     # remap checkpoint keys from the original encoder/decoder/joint naming to this module's parameter names
+     def load_state_dict(self, state_dict, strict=True, assign=False):
+         del state_dict['ctc_decoder.decoder_layers.0.bias']
+         del state_dict['ctc_decoder.decoder_layers.0.weight']
+         state_dict['preprocessor.featurizer.fb'] = state_dict['preprocessor.featurizer.fb'].squeeze(0)
+         changes = """
+ preprocessor.featurizer.fb,mel_fb
+ preprocessor.featurizer.window,spec.window
+ norm_feed_forward1,ffn1_layer_norm
+ norm_feed_forward2,ffn2_layer_norm
+ feed_forward1.linear1,ffn1.intermediate_dense
+ feed_forward1.linear2,ffn1.output_dense
+ feed_forward2.linear1,ffn2.intermediate_dense
+ feed_forward2.linear2,ffn2.output_dense
+ norm_self_att,self_attn_layer_norm
+ norm_out,final_layer_norm
+ norm_conv,conv_module.layer_norm
+ .conv.,.conv_module.
+ decoder.prediction.dec_rnn.lstm,lstm
+ decoder.prediction.embed,embed
+ joint.enc,enc
+ joint.pred,pred
+ joint.joint_net.2,lang_joint_net
+ encoder.pre_encode.conv_module.0,feature_extractor.conv_layers.0.conv
+ encoder.pre_encode.out,feature_projection.projection
+ """
+         if len(self.config.languages) == 1:
+             changes += f"""lang_joint_net.{self.language},joint
+ encoder.pre_encode.conv_module.{{n}},feature_extractor.conv_layers.{{(n/2)}}.conv"""
+         else:
+             state_dict["joint.weight"] = self.joint.weight.clone()
+             state_dict["joint.bias"] = self.joint.bias.clone()
+             changes += """encoder.pre_encode.conv_module.{n},encoder.pre_encode.conv_module.{(n-2)}
+ encoder.pre_encode.conv_module.{n},feature_extractor.conv_layers.{(n//3+1)}.conv.{(n%3)}
+ """
+         # apply the rename rules above; the {n} patterns cover the numbered pre-encode conv layers
+         state_dict = state_bridge(state_dict, changes)
+         if len(self.config.languages) == 1:
+             state_dict = {k: v for k, v in state_dict.items() if "lang_joint_net" not in k}
+         return super().load_state_dict(state_dict, strict, assign)
+
+     # greedy RNN-T decoding: for each encoder frame, emit up to max_symbols_per_step non-blank tokens
+     def postprocessing(self, x):
+         if len(self.config.languages) > 1:
+             self.joint.load_state_dict(self.lang_joint_net[self.language].state_dict())
+         B = x.size(0)
+         last = x.new_full((B, 1), self.config.blank_id, dtype=torch.long)
+         h, tok, st = None, [[] for _ in range(B)], [[] for _ in range(B)]
+         for t, e in enumerate(x.unbind(1)):
+             v = t < self.cache_length
+             if not v.any(): break
+             e = e[:, None]
+             for _ in range(self.config.max_symbols_per_step):
+                 p, h2 = self.lstm(self.embed(last), h)
+                 lg = self.joint(self.act(self.enc(e) + self.pred(p))).squeeze(1)
+                 n = torch.where(v, lg.argmax(-1), self.config.blank_id)
+                 b = n.eq(self.config.blank_id)
+                 if b.all(): break
+                 a = v & ~b
+                 for i in a.nonzero().flatten().tolist():
+                     tok[i].append(n[i]); st[i].append(t * self.denorm)
+                 last = torch.where(a[:, None], n[:, None], last)
+                 if h is None: h = h2
+                 else:
+                     k = (b | ~v).view(1, -1, 1)
+                     h = (torch.where(k, h[0], h2[0]), torch.where(k, h[1], h2[1]))
+         self.cache_length = None
-         return torch.tensor(tok), torch.tensor(st)
+         return [torch.tensor(i) for i in tok], [torch.tensor(i) for i in st]
+
+     def make_srt(self, x, ts):
+         t , s = x
+         start_token_segment = self.config.languages.index(self.language) * self.joint.out_features
+         all_tokens, all_starts, all_ends = [], [], []
+         for tokens, starts, (s, e) in zip(t,s, ts):
+             tokens += start_token_segment
+             starts += s
+             all_tokens.append(tokens)
+             all_starts.append(starts)
+             all_ends.append(torch.cat([starts[1:], e[None]]))
+             all_tokens.append(torch.tensor([-1]))
+             all_starts.append(torch.tensor([e]))
+             all_ends.append(torch.tensor([e + 0.005]))
+         return [srt.Subtitle(i,timedelta(seconds=float(st)),timedelta(seconds=float(en)),"<line>" if tok == -1 else self.config.vocab[int(tok)]) for i, (tok, st, en) in enumerate(zip(torch.cat(all_tokens), torch.cat(all_starts), torch.cat(all_ends)), 1)]
+
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, config = None, language=None,**kwargs):
+         if language:
+             config.languages = [language]
+             config.vocab = ['<unk>'] + json.load(open(hf_hub_download(pretrained_model_name_or_path, "vocab.json")))['small'][language]
+         else:
+             temp_vocab = json.load(open(hf_hub_download(pretrained_model_name_or_path, "vocab.json")))['large']
+             config.vocab = []
+             for i in sorted(config.languages):
+                 config.vocab.extend(['<unk>'] + temp_vocab[i])
+         model = cls(config)
+         model.load_state_dict(load_file(hf_hub_download(pretrained_model_name_or_path, f"{language or 'all'}.safetensors")))
+         return model
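
A minimal usage sketch for the class above, not part of the commit. The repository id, language code, and audio path are placeholders, and loading the configuration through AutoConfig with trust_remote_code=True (or importing the module directly) is an assumption about how the repo is wired up; the from_pretrained() and transcribe() calls follow the signatures defined in the file.

import torchaudio
from transformers import AutoConfig
from modeling_conformer import Wav2Vec2ConformerRNNT

repo = "user/conformer-rnnt-asr"  # hypothetical repo id
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)  # assumes the repo ships a custom config

# language=... selects the single-language checkpoint "<lang>.safetensors";
# omitting it loads "all.safetensors" for every entry in config.languages.
model = Wav2Vec2ConformerRNNT.from_pretrained(repo, config=config, language="hi")

wav, sr = torchaudio.load("speech.wav")  # any sample rate; resampled to 16 kHz inside ChunkedData
srt_text = ""
for srt_text in model.transcribe(wav, sr, batch_size=4):
    pass  # each yield is the SRT track composed from all chunks decoded so far
print(srt_text)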