R-Kentaren committed on
Commit
e9d3545
·
verified ·
1 Parent(s): db00b54

Create fairseq/fairseq.py

Browse files
Files changed (1) hide show
  1. fairseq/fairseq.py +1490 -0
fairseq/fairseq.py ADDED
@@ -0,0 +1,1490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import sys
4
+ import math
5
+ import uuid
6
+ import torch
7
+ import types
8
+ import contextlib
9
+
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+ from torch import nn
14
+ from torch.nn import Parameter
15
+ from omegaconf import DictConfig, open_dict
16
+
17
+ os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
18
+ os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
19
+
20
class Dictionary:
    """Minimal stand-in for fairseq's Dictionary; accepts and ignores all arguments.

    Only needed so that pickled checkpoint objects referencing
    ``fairseq.data.dictionary.Dictionary`` can be deserialized.
    """

    def __init__(self, *args, **kwargs):
        pass


# Fabricate a fake ``fairseq`` package hierarchy and register it in
# sys.modules so torch.load can unpickle fairseq checkpoints without the
# real fairseq library being installed.
fairseq = types.ModuleType("fairseq")
fairseq_data = types.ModuleType("fairseq.data")
fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")

fairseq_data_dictionary.Dictionary = Dictionary
fairseq.data = fairseq_data
fairseq_data.dictionary = fairseq_data_dictionary

for _stub_name, _stub_module in (
    ("fairseq", fairseq),
    ("fairseq.data", fairseq_data),
    ("fairseq.data.dictionary", fairseq_data_dictionary),
):
    sys.modules[_stub_name] = _stub_module
35
+
36
+
37
def load_model(filename):
    """Load a fairseq HuBERT checkpoint from ``filename``.

    Returns ``([model], cfg, task)`` where ``model`` is a ``HubertModel``
    with the checkpoint weights loaded, and ``cfg``/``task`` are
    ``Model_Config`` wrappers around the stored configuration dicts.
    """
    # weights_only=False is required: fairseq checkpoints contain pickled
    # config objects, not just tensors.  Passing it explicitly is more robust
    # than relying on the TORCH_FORCE_* environment variables set above.
    # NOTE(security): pickle can execute arbitrary code — only load
    # checkpoints from trusted sources.
    state = torch.load(filename, map_location="cpu", weights_only=False)

    model = HubertModel(HubertConfig(**state['cfg']['model']))
    # strict=False: checkpoints may carry extra keys (e.g. training-only
    # parameters) that this inference implementation does not define.
    model.load_state_dict(state['model'], strict=False)

    cfg = Model_Config(state["cfg"])
    task = Model_Config(state["cfg"]["task"])

    return [model], cfg, task
47
+
48
def softmax(x, dim, onnx_trace=False):
    """Softmax over ``dim``, computed in float32 for numerical stability."""
    if onnx_trace:
        # ONNX export path: cast the input itself so the graph stays simple.
        return F.softmax(x.float(), dim=dim)
    return F.softmax(x, dim=dim, dtype=torch.float32)
50
+
51
def log_softmax(x, dim, onnx_trace=False):
    """Log-softmax over ``dim``, computed in float32 for numerical stability."""
    if onnx_trace:
        # ONNX export path: cast the input itself so the graph stays simple.
        return F.log_softmax(x.float(), dim=dim)
    return F.log_softmax(x, dim=dim, dtype=torch.float32)
53
+
54
def eval_str_dict(x, type=dict):
    """Coerce ``x`` into a dict: strings are evaluated, other values pass through.

    NOTE(security): uses ``eval`` on the string — never call with untrusted input.
    (The ``type`` parameter is kept for interface compatibility; it is unused.)
    """
    if x is None:
        return None
    return eval(x) if isinstance(x, str) else x
58
+
59
def with_incremental_state(cls):
    """Class decorator that mixes ``FairseqIncrementalState`` into ``cls``.

    Prepends the mixin to the class's bases (removing any previous
    occurrence) so decorated classes gain incremental-state bookkeeping.
    """
    other_bases = tuple(base for base in cls.__bases__ if base is not FairseqIncrementalState)
    cls.__bases__ = (FairseqIncrementalState,) + other_bases
    return cls
62
+
63
def quant_noise(module, p, block_size):
    """Apply Quant-Noise (Fan et al., 2020) to ``module``.

    During training, whole blocks of weights are zeroed with probability
    ``p`` via a forward pre-hook, and the surviving weights are rescaled by
    ``1 / (1 - p)``.  The module is returned unchanged when ``p <= 0``.
    """
    if p <= 0:
        return module

    # Only fully-connected, embedding and 2D-convolution layers are supported.
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    is_conv = module.weight.ndim == 4
    if not is_conv:
        # Input features must split evenly into blocks.
        assert module.weight.size(1) % block_size == 0
    elif module.kernel_size == (1, 1):
        # 1x1 convs are treated like linear layers over channels.
        assert module.in_channels % block_size == 0
    else:
        kernel_area = module.kernel_size[0] * module.kernel_size[1]
        assert kernel_area % block_size == 0

    def _forward_pre_hook(mod, input):
        # Weights are only perturbed while training.
        if not mod.training:
            return
        weight = mod.weight
        if not is_conv:
            in_features, out_features = weight.size(1), weight.size(0)
            mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            mask = torch.zeros(int(mod.in_channels // block_size * mod.out_channels), device=weight.device)
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
        else:
            # Drop whole (out_channel, in_channel) kernel slices.
            mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
            mask.bernoulli_(p)
            mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

        # Zero the dropped blocks and rescale the kept weights.
        mod.weight.data = (1 / (1 - p)) * weight.masked_fill(mask.to(torch.bool), 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
105
+
106
class FairseqDropout(nn.Module):
    """Dropout module that can optionally stay active at inference time."""

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        # Set via make_generation_fast_ to keep dropout live during generation.
        self.apply_during_inference = False

    def forward(self, x, inplace=False):
        is_active = self.training or self.apply_during_inference
        if self.p > 0 and is_active:
            # training=True on purpose: activity was decided above, not by
            # the module's train/eval flag.
            return F.dropout(x, p=self.p, training=True, inplace=inplace)
        return x

    def make_generation_fast_(self, name, retain_dropout=False, retain_dropout_modules=None, **kwargs):
        # When retaining dropout, enable it either globally (no module list)
        # or only for the listed module names.
        if not retain_dropout:
            return
        if retain_dropout_modules is None or self.module_name in retain_dropout_modules:
            self.apply_during_inference = True
119
+
120
class FairseqIncrementalState(object):
    """Mixin giving each instance a UUID-scoped view into a shared
    incremental-state dict (used to cache decoder state across steps)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        # Unique per instance so many modules can share one state dict.
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key):
        return f"{self._incremental_state_id}.{key}"

    def get_incremental_state(self, incremental_state, key):
        """Return the value stored under ``key``, or None when absent."""
        if incremental_state is None:
            return None
        return incremental_state.get(self._get_full_incremental_state_key(key))

    def set_incremental_state(self, incremental_state, key, value):
        """Store ``value`` under ``key``; returns the (possibly None) dict."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state
139
+
140
class FairseqDecoder(nn.Module):
    """Base decoder class: subclasses implement ``extract_features`` and
    ``output_layer``; this class provides probability normalization."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.onnx_trace = False
        self.adaptive_softmax = None

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """Project extracted features through the output layer."""
        features, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        return self.output_layer(features), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        # Implemented by subclasses.
        pass

    def output_layer(self, features, **kwargs):
        # Implemented by subclasses.
        pass

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    def get_normalized_probs_scriptable(self, net_output, log_probs, sample=None):
        """Normalize decoder output to (log-)probabilities over the vocabulary."""
        if getattr(self, "adaptive_softmax", None) is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out if log_probs else out.exp_()

        logits = net_output[0]
        if log_probs:
            return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        return softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Maximum supported input length (effectively unbounded)."""
        return 1e6

    def upgrade_state_dict_named(self, state_dict, name):
        # No legacy-key migration needed at this level.
        return state_dict

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True
181
+
182
@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Decoder base class that supports incremental (step-by-step) decoding
    with cached state and beam-size propagation."""

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        # Implemented by subclasses.
        pass

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        # Implemented by subclasses.
        pass

    def reorder_incremental_state(self, incremental_state, new_order):
        # Implemented by subclasses that cache state.
        pass

    def reorder_incremental_state_scripting(self, incremental_state, new_order):
        """Reorder every submodule's cached state after beam reordering."""
        for module in self.modules():
            if not hasattr(module, "reorder_incremental_state"):
                continue
            result = module.reorder_incremental_state(incremental_state, new_order)
            if result is not None:
                incremental_state = result

    def set_beam_size(self, beam_size):
        """Propagate the beam size to all submodules, once per change."""
        if getattr(self, "_beam_size", -1) == beam_size:
            return
        seen = set()

        def apply_set_beam_size(module):
            # Visit each submodule (except self) at most once.
            if module != self and hasattr(module, "set_beam_size") and module not in seen:
                seen.add(module)
                module.set_beam_size(beam_size)

        self.apply(apply_set_beam_size)
        self._beam_size = beam_size
213
+
214
class MultiheadAttention(FairseqIncrementalDecoder):
    """Multi-headed scaled dot-product attention, ported from fairseq.

    Supports self-attention, encoder-decoder attention (with beam-replicated
    keys), incremental decoding via a cached key/value buffer, optional
    bias_k/bias_v and zero-attention columns, quant-noise on the projections,
    and head pruning.  Tensors are time-first: (seq_len, batch, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
        super().__init__(dictionary)
        xformers_att_config = eval_str_dict(xformers_att_config)
        self.use_xformers = xformers_att_config is not None
        # xformers is unsupported in this stripped-down port; fail fast.
        if self.use_xformers: raise ImportError
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
        self.head_dim = embed_dim // num_heads
        # embed_dim must divide evenly across heads.
        assert (self.head_dim * num_heads == self.embed_dim)
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        # Self-attention requires q, k and v to share one embedding dim.
        assert not self.self_attention or self.qkv_same_dim
        self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        if add_bias_kv:
            # Learned bias vectors appended as an extra key/value position.
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else: self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self.beam_size = 1
        self.reset_parameters()
        self.onnx_trace = False
        self.skip_embed_dim_check = False
        self.init_incremental_state()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        """Xavier-initialize all projections; scaled by 1/sqrt(2) when q/k/v
        share a dim (compensates for the subsequent concatenation)."""
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)

        if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        """Rank heads by the L1 norm of their k/q/v weights+biases and return
        (start, end) weight-row ranges for the top ``num_heads_to_keep``."""
        k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            # torch.sum(...).tolist() on a 0-d tensor yields a Python float.
            k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
            q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
            v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])

        sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
        return reserve_head_index

    def _adaptive_prune_heads(self, reserve_head_index):
        """Physically shrink the q/k/v/out projections to keep only the rows
        listed in ``reserve_head_index``; updates num_heads and embed_dim."""
        new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []

        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
            new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
            new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
            # out_proj consumes head outputs, so slice its *columns*.
            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])

        # detach + requires_grad=True re-creates the tensors as fresh leaves.
        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True
        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True
        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True
        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True

        self.q_proj.weight = nn.Parameter(new_q_weight)
        self.q_proj.bias = nn.Parameter(new_q_bias)
        self.k_proj.weight = nn.Parameter(new_k_weight)
        self.k_proj.bias = nn.Parameter(new_k_bias)
        self.v_proj.weight = nn.Parameter(new_v_weight)
        self.v_proj.bias = nn.Parameter(new_v_bias)
        self.out_proj.weight = nn.Parameter(new_out_proj_weight)
        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True

    def _pad_masks(self, key_padding_mask, attn_mask):
        """Extend both masks by one zero (unmasked) column for an appended
        key/value position (bias_kv or zero-attention)."""
        if attn_mask is not None:
            shape = attn_mask.size()[:-1] + torch.Size([1])
            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)

        if key_padding_mask is not None:
            shape = key_padding_mask.size()[:-1] + torch.Size([1])
            key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)

        return key_padding_mask, attn_mask

    def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
        """Append the learned bias_k/bias_v as one extra key/value position."""
        assert self.bias_k is not None or self.bias_v is not None
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask

    def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
        """Append an all-zero key/value position (add_zero_attn)."""
        zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask

    def forward(self, query, key, value, key_padding_mask = None, incremental_state = None, need_weights = True, static_kv = False, attn_mask = None, before_softmax = False, need_head_weights = False):
        """Compute attention.

        query: (tgt_len, bsz, embed_dim); key/value: (src_len, bsz, embed_dim).
        Returns (attn, attn_weights): attn is (tgt_len, bsz, embed_dim);
        attn_weights is head-averaged (bsz, tgt_len, src_len), per-head when
        ``need_head_weights``, or None when ``need_weights`` is False.
        """
        # Per-head weights imply weights must be computed at all.
        if need_head_weights: need_weights = True
        is_tpu = query.device.type == "xla"
        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len

        if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert value is not None
                # NOTE(review): this only asserts that src_len is truthy; the
                # "key_bsz == value.shape[:2]" part is the assert *message*,
                # never checked.  Same quirk exists in upstream fairseq.
                assert src_len, key_bsz == value.shape[:2]

        # Fast path: defer to PyTorch's fused kernel when no incremental
        # state, ONNX tracing, TPU handling or custom dims are involved.
        if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # With static kv (encoder-decoder attention) the cached
                # projections never change, so skip recomputing them.
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else: saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                # Beam search replicates the batch; project the encoder keys
                # only once per source sentence (first beam entry).
                if self.beam_size > 1 and bsz == key.size(1):
                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
                    if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        # Pre-scale queries by 1/sqrt(head_dim).
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            # NOTE(review): attn_mask/key_padding_mask are passed (and
            # unpacked) swapped relative to _add_bias's parameter order.
            # Both masks are padded identically with one zero column, so the
            # net effect is unchanged; flagged for clarity.
            k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)

        # Fold heads into the batch dim: (bsz*num_heads, seq, head_dim).
        q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
        kv_bsz = bsz

        if k is not None:
            kv_bsz = k.size(1)
            k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))

        if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # Incremental decoding: concatenate cached keys/values with the
            # current step (or reuse them wholesale when static_kv).
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None

                kv_bsz = _prev_key.size(0)
                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)

                if static_kv: k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)

            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                # NOTE(review): the "or" short-circuits, so the batch-size
                # check never runs when _prev_value is not None; upstream
                # fairseq uses two separate asserts here.
                assert _prev_value is not None or kv_bsz == _prev_value.size(0)
                prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)

                if static_kv: v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)

            prev_key_padding_mask = None
            if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]

            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)

            # Write the (possibly extended) cache back.
            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask

            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)

        assert k is not None
        assert k.size(1) == src_len

        # A 0-dim key_padding_mask means "no padding".
        if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == kv_bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

        if self.encoder_decoder_attention and bsz != kv_bsz:
            # Beam search: queries are beam-replicated but keys are not;
            # einsum broadcasts keys over the beam dimension.
            attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
        else: attn_weights = torch.bmm(q, k.transpose(1, 2))

        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Mask padded source positions with -inf before the softmax
            # (TPU path avoids boolean advanced indexing).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax: return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = None

        if self.encoder_decoder_attention and bsz != kv_bsz:
            # Beam-broadcast the value aggregation, mirroring the weights.
            attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
            attn = attn.reshape((-1,) + attn.size()[-2:])
        else: attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        # Un-fold heads back into the embedding dimension.
        if self.onnx_trace and attn.size(1) == 1: attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
        else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        attn = self.out_proj(attn)
        attn_weights = None

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights: attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
        """Merge the cached padding mask with the current step's, filling any
        length gap with zeros (unpadded) so the result spans ``src_len``."""
        if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
            else: new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
            else: new_key_padding_mask = key_padding_mask.float()
        else: new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached k/v buffers after beam reordering; for
        encoder-decoder attention the cache is kept per-sentence, so the
        beam-local order is collapsed before indexing."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention:
                        # Cache already matches the new order — nothing to do.
                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
                        elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
                        else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
                    else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size

    def _get_input_buffer(self, incremental_state):
        # Returns the cached attn state, or an empty dict when absent.
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None: return result
        else: return {}

    def _set_input_buffer(self, incremental_state, buffer):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        # Hook for sparse-attention subclasses; identity here.
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Split legacy fused ``in_proj_weight``/``in_proj_bias`` checkpoint
        entries into the separate q/k/v projection keys used here."""
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # Fused weight is stacked [q; k; v] along dim 0.
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
                keys_to_remove.append(k)
                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
579
+
580
def init_bert_params(module):
    """BERT-style initialization: N(0, 0.02) for linear/embedding weights and
    attention q/k/v projections, zeros for biases and the padding embedding.
    Intended for use with ``nn.Module.apply``."""

    def normal_(data):
        # Sample on CPU, then copy back to the original device.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        for proj in (module.q_proj, module.k_proj, module.v_proj):
            normal_(proj.weight.data)
594
+
595
def make_conv_pos(e, k, g):
    """Build a convolutional positional-embedding block: a grouped,
    weight-normalized 1D conv of width ``k`` over ``e`` channels, followed by
    SamePad trimming and GELU."""
    pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)

    # Initialization scheme from wav2vec 2.0; dropout is fixed at 0 here.
    dropout = 0
    std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(pos_conv.weight, mean=0, std=std)
    nn.init.constant_(pos_conv.bias, 0)

    normed_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2)
    return nn.Sequential(normed_conv, SamePad(k), nn.GELU())
603
+
604
def is_xla_tensor(tensor):
    """Return True when ``tensor`` is a torch tensor on an XLA (TPU) device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
606
+
607
def index_put(tensor, indices, value):
    """Masked assignment that also works on XLA tensors: positions where the
    boolean ``indices`` mask is True receive ``value``.

    XLA lacks efficient in-place advanced indexing, so on TPU the result is
    built out-of-place by blending; elsewhere plain indexing is used.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if not on_xla:
        tensor[indices] = value
        return tensor

    # XLA path: broadcast the mask up to the tensor's rank, then blend.
    for _ in range(indices.dim(), tensor.dim()):
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
617
+
618
def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` so its size becomes a multiple of
    ``multiple``.

    Returns ``(padded, remainder)`` where ``remainder`` is the number of
    padded positions (0 when none were needed).  ``x`` may be None, in which
    case ``(None, 0)`` is returned.
    """
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    # F.pad consumes (left, right) pairs from the last dim backwards: leave
    # the dims after ``dim`` untouched, then pad ``remainder`` on the right.
    pad_spec = (*((0,) * (-1 - dim) * 2), 0, remainder)
    return F.pad(x, pad_spec, value=value), remainder
625
+
626
def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type = "static", mask_other = 0.0, min_masks = 0, no_overlap = False, min_space = 0, require_same_masks = True, mask_dropout = 0.0, add_masks = False, seed = None, epoch = None, indices = None, idc_select_ver = 1, num_mask_ver = 2):
    """Compute random span masks for a (batch, seq_len) shape, wav2vec2/HuBERT style.

    Returns a boolean numpy array of shape ``(bsz, all_sz)`` where True marks a
    masked timestep. When ``seed``/``epoch``/``indices`` are all provided, each
    row's randomness is seeded deterministically per sample.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    # v1 draws one global mask count up front (legacy, unseeded np.random)
    if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
    mask_idcs = []

    for i in range(bsz):
        # per-sample deterministic seed when (seed, epoch, indices) are all given
        seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            # usable (non-padded) length for this sample
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else: sz = all_sz

        if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
        elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
        else: raise ValueError

        # span lengths per mask type
        if mask_type == "static": lengths = np.full(num_mask, mask_length)
        # BUGFIX: np.random.Generator has no randint(); integers() is the
        # equivalent (high exclusive, matching legacy randint; low coerced to int)
        elif mask_type == "uniform": lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
        elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
        else: raise Exception

        if sum(lengths) == 0:
            if mask_type == "static": raise ValueError
            else: lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            # greedily place spans into the remaining free segments
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # BUGFIX: randint -> integers (Generator API, see above)
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))
                new_parts = []

                # keep leftover segments that are still big enough to host a span
                if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))

                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)

            for length in sorted(lengths, reverse=True):
                lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
                l_sum = np.sum(lens)
                if l_sum == 0: break
                # pick a segment with probability proportional to its usable size
                s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask: min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
            else: raise ValueError

            # expand each chosen start index into its full span
            mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz: raise ValueError
        mask_idcs.append(mask_idc)

    # optionally equalize the number of masked steps across the batch
    target_len = None
    if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])

    # NOTE(review): this loop reuses `rng` from the *last* batch element for
    # trimming/adding/dropout randomness — confirm that is intended
    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False

    return mask
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Construct an ``nn.LayerNorm``.

    BUGFIX: ``export`` is now accepted (and ignored) — callers in this file
    (e.g. the conformer modules) pass ``export=``, which previously raised
    TypeError. Upstream fairseq uses the flag only to pick a TorchScript-
    exportable variant; plain ``nn.LayerNorm`` covers both cases here.
    """
    return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
def prune_state_dict(state_dict, model_cfg):
    """Drop encoder/decoder layers from a checkpoint according to
    ``{encoder,decoder}_layers_to_keep`` in *model_cfg*, renumbering the kept
    layers consecutively so they load into a smaller (LayerDrop-pruned) model.

    Returns the (possibly pruned) state dict; the input dict is not modified.
    """
    arch = None
    if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))

    # nothing to prune without a config/arch (ptt_transformer is exempt)
    if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict

    def create_pruning_pass(layers_to_keep, layer_name):
        # map original layer numbers (comma-separated string) -> new consecutive indices
        keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}

    pruning_passes = []
    new_state_dict = {}

    if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # non-layer parameters are copied through unchanged
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        original_layer_number = match.group(1)
        # keep only selected layers, rewriting the layer index embedded in the key
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
                substitution_match = pruning_pass["substitution_regex"].search(layer_name)
                new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]

    # clear the flags so pruning is not applied twice on a later load
    # (ExitStack is a no-op context for non-omegaconf configs)
    with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
        if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None

    return new_state_dict
def relu_squared(x):
    """Squared-ReLU activation: ``relu(x) ** 2``."""
    return torch.square(F.relu(x))
def get_activation_fn(activation):
    """Return the activation callable named by *activation*.

    Raises RuntimeError for unknown names. Note "swish" returns the
    ``nn.SiLU`` *class* (callers instantiate it themselves).
    """

    def gelu(x):
        # compute in fp32 for numerical stability, cast back to the input dtype
        return nn.functional.gelu(x.float()).type_as(x)

    def gelu_accurate(x):
        # tanh approximation of GELU; the constant is cached on the function
        if not hasattr(gelu_accurate, "_a"):
            gelu_accurate._a = math.sqrt(2 / math.pi)
        return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))

    if activation == "relu":
        return F.relu
    if activation == "relu_squared":
        return relu_squared
    if activation in ("gelu_fast", "gelu_accurate"):
        return gelu_accurate
    if activation == "gelu":
        return gelu
    if activation == "tanh":
        return torch.tanh
    if activation == "linear":
        return lambda x: x
    if activation == "swish":
        return nn.SiLU
    raise RuntimeError
class SamePad(nn.Module):
    """Trim trailing frames introduced by symmetric conv padding.

    Even kernels produce one extra output frame; causal convs must drop the
    whole ``kernel_size - 1`` tail.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        # number of trailing time steps to cut in forward()
        self.remove = (kernel_size - 1) if causal else (1 - kernel_size % 2)

    def forward(self, x):
        r = self.remove
        return x[:, :, :-r] if r > 0 else x
class TransformerSentenceEncoderLayer(nn.Module):
    """Single transformer encoder layer (wav2vec 2.0 / HuBERT style).

    Supports both pre-LN (``layer_norm_first=True``) and post-LN arrangements.
    """

    def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
        # dropout1: after attention; dropout2: after FFN activation; dropout3: after FFN output
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
        """Run the layer on time-first input ``x`` of shape (T, B, C).

        Returns ``(x, (attn, layer_result))`` where ``layer_result`` is the FFN
        output before the final residual/norm.
        """
        residual = x

        if self.layer_norm_first:
            # pre-LN: norm -> attn -> residual; norm -> FFN -> residual
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
            x = residual + self.dropout1(x)
            residual = x
            x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
            layer_result = x
            x = residual + self.dropout3(x)
        else:
            # post-LN: attn -> residual -> norm; FFN -> residual -> norm
            # NOTE(review): self_attn_mask is not forwarded in this branch — confirm intended
            x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
            x = self.self_attn_layer_norm(residual + self.dropout1(x))
            residual = x
            x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
            layer_result = x
            x = self.final_layer_norm(residual + self.dropout3(x))

        return x, (attn, layer_result)
class AdapterFast(nn.Module):
    """A bank of ``adapter_num`` bottleneck adapters sharing one parameter
    tensor per role: per-adapter LayerNorm -> down-projection -> activation
    -> up-projection. One adapter is selected by index at forward time."""

    def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
        super().__init__()
        self.adapter_num = adapter_num
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # stacked weights/biases: dim 0 indexes the adapter
        self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
        self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
        self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
        self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        # per-adapter LayerNorm affine parameters
        self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "gelu":
            self.act_fn = nn.GELU()
        elif act_fn == "selu":
            self.act_fn = nn.SELU()
        else:
            raise ValueError

        self.reset_parameters()

    def reset_parameters(self):
        """Linear-layer style init for every adapter in the bank."""
        for idx in range(self.adapter_num):
            nn.init.kaiming_uniform_(self.W_a[idx], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.W_b[idx], a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[idx])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_a[idx], -bound, bound)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[idx])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_b[idx], -bound, bound)

        nn.init.ones_(self.ln_W)
        nn.init.zeros_(self.ln_b)

    def forward(self, x, adapter_id):
        """Apply adapter ``adapter_id`` to ``x`` (..., input_dim)."""
        idx = adapter_id
        normed = F.layer_norm(x, (self.input_dim, ), self.ln_W[idx], self.ln_b[idx])
        hidden = self.act_fn(F.linear(normed, self.W_a[idx], self.b_a[idx]))
        return F.linear(hidden, self.W_b[idx], self.b_b[idx])

    def extra_repr(self):
        return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
class FeedForwardModule(nn.Module):
    """Conformer feed-forward block: LN -> Linear -> act -> dropout -> Linear -> dropout."""

    def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
        super().__init__()
        self.layer_norm = LayerNorm(input_feat)
        self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
        self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        # get_activation_fn returns a class for "swish"; it is instantiated here
        self.activation = get_activation_fn(activation_fn)(hidden_units)

    def forward(self, x):
        h = self.layer_norm(x)
        h = self.dropout1(self.activation(self.w_1(h)))
        return self.dropout2(self.w_2(h))
class ConvolutionModule(nn.Module):
    """Conformer convolution block: LN -> pointwise conv + GLU -> depthwise
    conv -> BatchNorm -> activation -> pointwise conv -> dropout.

    Input/output are (batch, time, channel); the convs run channel-first.
    """

    def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
        super(ConvolutionModule, self).__init__()
        # odd kernel required so 'same' padding preserves the sequence length
        assert (depthwise_kernel_size - 1) % 2 == 0
        # BUGFIX: the local LayerNorm() helper takes no `export` kwarg, so
        # passing export= raised TypeError. `export` stays in the signature
        # for backward compatibility with callers.
        self.layer_norm = LayerNorm(embed_dim)
        self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
        self.batch_norm = nn.BatchNorm1d(channels)
        # get_activation_fn returns a class for "swish"; it is instantiated here
        self.activation = get_activation_fn(activation_fn)(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, time, channel) -> (batch, time, channel)."""
        x = self.layer_norm(x).transpose(1, 2)          # -> (B, C, T)
        x = self.glu(self.pointwise_conv1(x))           # 2C channels -> C via GLU
        x = self.activation(self.batch_norm(self.depthwise_conv(x)))
        x = self.dropout(self.pointwise_conv2(x))
        return x.transpose(1, 2)                        # back to (B, T, C)
def rotate_half(x):
    """Split the last dim in half and map ``(x1, x2) -> (-x2, x1)`` (RoPE helper)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Apply rotary position embeddings to *q* and *k* using cached cos/sin
    tables, starting at *offset* along the (leading) sequence axis."""
    end = offset + q.shape[0]
    cos = cos[offset:end, ...]
    sin = sin[offset:end, ...]
    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k
class RotaryPositionalEmbedding(nn.Module):
    """Caches cos/sin tables for rotary position embedding (RoPE), growing the
    cache lazily as longer sequences are requested."""

    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        # inverse frequency for every even channel index
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = 0
        # empty caches; filled on first forward with seq_len > 0
        self.cos_cached = torch.empty(0, 1, 1, dim)
        self.sin_cached = torch.empty(0, 1, 1, dim)
        self.precision = precision

    def forward(self, x, seq_len = 0):
        """Return (cos, sin) tables of shape (seq_len_cached, 1, 1, dim)."""
        # refresh only when a longer sequence than cached is requested
        if seq_len > self.seq_len_cached:
            self.seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
            self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
        return self.cos_cached, self.sin_cached
class ESPNETMultiHeadedAttention(nn.Module):
    """Multi-headed scaled dot-product attention, ESPnet style.

    External interface is time-first (T, B, C); internals are batch-first.
    The last attention map is kept on ``self.attn``.
    """

    def __init__(self, n_feat, n_head, dropout):
        super(ESPNETMultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        self.d_k = n_feat // n_head  # per-head dimension
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward_qkv(self, query, key, value, **kwargs):
        """Project q/k/v (batch-first) to (batch, head, time, d_k)."""
        bsz = query.size(0)

        def split_heads(proj, x):
            return proj(x).view(bsz, -1, self.h, self.d_k).transpose(1, 2)

        return (split_heads(self.linear_q, query),
                split_heads(self.linear_k, key),
                split_heads(self.linear_v, value))

    def forward_attention(self, value, scores, mask):
        """Softmax the scores (with optional key padding mask), weight the
        values, merge heads, and project out."""
        bsz = value.size(0)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
        self.attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(self.dropout(self.attn), value)
        context = context.transpose(1, 2).contiguous().view(bsz, -1, self.h * self.d_k)
        return self.linear_out(context)

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        """Inputs (T, B, C); returns ``(output, None)`` with output (T, B, C)."""
        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        out = self.forward_attention(v, scores, key_padding_mask)
        return out.transpose(0, 1), None
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    """Multi-headed attention with Transformer-XL style relative position
    encoding (Conformer's "rel_pos" mode)."""

    def __init__(self, n_feat, n_head, dropout, zero_triu=False):
        super().__init__(n_feat, n_head, dropout)
        self.zero_triu = zero_triu
        # projection for the positional embeddings, plus the learned content/
        # position biases (the u/v terms from Transformer-XL)
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Shift the relative-position score matrix so relative scores align
        with absolute key positions; keeps only the valid window."""
        x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
        # optionally zero out future (upper-triangular) positions
        if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
        """Inputs are time-first (T, B, C); *pos_emb* carries the relative
        position embeddings. Returns ``(output, None)``."""
        pos_emb = pos_emb.transpose(0, 1)
        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        q = q.transpose(1, 2)

        # score = (q + u)·k^T + rel_shift((q + v)·p^T), scaled by sqrt(d_k)
        return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    """Multi-headed attention with rotary position embeddings (RoPE)."""

    def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
        super().__init__(n_feat, n_head, dropout)
        # NOTE(review): `precision` is overwritten with torch.float before the
        # "fp16" check below, so that branch can never fire — confirm intended.
        precision = torch.float
        self.rotary_ndims = self.d_k
        if precision == "fp16": precision = torch.half
        self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        """Inputs are time-first (T, B, C); RoPE is applied to q/k per head,
        then standard scaled dot-product attention. Returns ``(output, None)``."""
        T, B, C = value.size()
        # split channels into heads so the rotation acts on each head's dims
        query = query.view(T, B, self.h, self.d_k)
        key = key.view(T, B, self.h, self.d_k)
        value = value.view(T, B, self.h, self.d_k)

        cos, sin = self.rotary_emb(value, seq_len=T)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)

        # merge heads back before the shared q/k/v projections
        query = query.view(T, B, self.h * self.d_k)
        key = key.view(T, B, self.h * self.d_k)
        value = value.view(T, B, self.h * self.d_k)

        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
class ConformerEncoderLayer(nn.Module):
    """Conformer block: half-step FFN -> self-attention -> conv module ->
    half-step FFN -> final LayerNorm. Input is time-first (T, B, C)."""

    def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
        # set before super().__init__() as in the original; plain attribute, safe
        self.pos_enc_type = pos_enc_type
        super(ConformerEncoderLayer, self).__init__()
        self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
        # BUGFIX: the local LayerNorm() helper takes no `export` kwarg;
        # passing export=False raised TypeError at construction.
        self.self_attn_layer_norm = LayerNorm(embed_dim)
        self.self_attn_dropout = nn.Dropout(dropout)

        # attention flavor depends on attn_type / positional-encoding type
        if attn_type == "espnet":
            if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
            elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
            elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
            else: raise Exception
        else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)

        self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
        self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
        # BUGFIX: same invalid export= kwarg removed here
        self.final_layer_norm = LayerNorm(embed_dim)

    def forward(self, x, encoder_padding_mask, position_emb = None):
        """Returns ``(x, (attn, layer_result))``; ``layer_result`` is the
        second FFN's output before the final residual/norm."""
        residual = x
        x = self.ffn1(x) * 0.5 + residual   # half-step FFN
        residual = x
        x = self.self_attn_layer_norm(x)

        # rel_pos attention needs the positional embeddings passed explicitly
        if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
        else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)

        x = self.self_attn_dropout(x)
        x = x + residual
        residual = x
        # conv module works batch-first, hence the transposes
        x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
        residual = x
        x = self.ffn2(x)
        layer_result = x
        x = self.final_layer_norm(x * 0.5 + residual)   # half-step FFN + final LN

        return x, (attn, layer_result)
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
    """Adapter so a ConformerEncoderLayer can be driven through the
    TransformerSentenceEncoderLayer calling convention used by the encoder."""

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
        # only the padding mask and positional embeddings are meaningful to
        # the conformer layer; the remaining kwargs are accepted and ignored
        return super().forward(x, self_attn_padding_mask, position_emb)
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
    """Transformer encoder layer followed by a per-corpus AdapterFast block
    applied residually; the adapter is selected by ``corpus_key``."""

    def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
        super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
        self.adapter_num = adapter_num
        self.adapter_dim = adapter_dim
        self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
        x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
        # every sample in the batch must belong to the same corpus
        assert corpus_key is not None
        assert len(set(corpus_key)) == 1
        adapted = self.adapter_layer(x, corpus_key[0])
        return x + adapted, (attn, layer_result)
class TransposeLast(nn.Module):
    """Swap the last dim with ``tranpose_dim`` (default -2); optionally index
    into the input first via ``deconstruct_idx``."""

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        item = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return item.transpose(self.tranpose_dim, -1)
class TransformerEncoder(nn.Module):
    """wav2vec 2.0-style context encoder: convolutional positional embedding
    followed by a stack of transformer/conformer layers with LayerDrop."""

    def build_encoder_layer(self, args, **kwargs):
        """Build one encoder layer per ``args.layer_type``
        (``transformer`` | ``conformer`` | ``trf_adp``)."""
        if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
        elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
        elif args.layer_type == "trf_adp":
            # adapters enabled for all layers, or only for indices in the
            # "start:stop[:step]" range given by args.adp_trf_idx
            use_adp = False
            if args.adp_trf_idx == "all": use_adp = True
            else:
                if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True

            layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first,)

        return layer

    def __init__(self, args):
        super().__init__()
        # BUGFIX: max_positions() reads self.args, which was never stored
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple
        pos_conv_depth = getattr(args, "pos_conv_depth", 1)

        if pos_conv_depth > 1:
            # deep positional conv: several conv -> SamePad -> LN -> GELU blocks
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])

            self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
        else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)

        self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop
        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
        """Run the encoder; returns ``(x, layer_results)``."""
        x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)

        # with pre-LN layers the final norm is applied here, unless an
        # intermediate layer's output was explicitly requested
        if self.layer_norm_first and layer is None: x = self.layer_norm(x)
        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
        """Core encoder pass on batch-first input (B, T, C).

        Returns ``(x, layer_results)`` where each layer result is a
        ``(x, attn, layer_result)`` triple recorded from layer ``min_layer`` on;
        stops early after ``tgt_layer`` when given.
        """
        # zero out padded timesteps before adding positional information
        if padding_mask is not None: x = index_put(x, padding_mask, 0)
        x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)

        if not self.layer_norm_first: x = self.layer_norm(x)
        # pad the time dim up to the required multiple, masking the padded tail
        x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)

        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)

        x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)  # -> (T, B, C)
        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            # LayerDrop: randomly skip whole layers during training
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                layer_check = layer

                # adapter layers additionally require the corpus key
                if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
                else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)

                if i >= min_layer: layer_results.append((x, z, lr))
                if i == tgt_layer:
                    r = x
                    break

        if r is not None: x = r
        x = x.transpose(0, 1)  # back to (B, T, C)

        # strip the padding added above from the output and recorded states
        if pad_length > 0:
            x = x[:, :-pad_length]
            def undo_pad(a, b, c):
                return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum input length supported by the encoder (from args)."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """No checkpoint upgrades needed for this module."""
        return state_dict
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm computed in fp32, with the result cast back to the input dtype
    (keeps normalization numerically stable under mixed precision)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.group_norm(input.float(), self.num_groups, weight, bias, self.eps)
        return output.type_as(input)
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm computed in fp32, with the result cast back to the input dtype
    (keeps normalization numerically stable under mixed precision)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.layer_norm(input.float(), self.normalized_shape, weight, bias, self.eps)
        return output.type_as(input)
class ConvFeatureExtractionModel(nn.Module):
    """Stack of strided 1d conv blocks turning a raw waveform (B, T) into
    feature frames (B, C, T').

    ``conv_layers`` is a list of ``(dim, kernel, stride)`` triples; in
    "default" mode the first block uses fp32 GroupNorm, in "layer_norm" mode
    every block uses fp32 LayerNorm over channels.
    """

    def __init__(self, conv_layers, dropout = 0.0, mode = "default", conv_bias = False):
        super().__init__()
        assert mode in {"default", "layer_norm"}

        def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            # at most one normalization flavor per block
            assert (is_layer_norm and is_group_norm) == False

            if is_layer_norm:
                # LayerNorm normalizes the channel dim, so transpose around it
                # (note: the norms close over `dim`, set by the loop below)
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
            if is_group_norm:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
            return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3
            (dim, k, stride) = cl
            self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=(mode == "default" and i == 0), conv_bias=conv_bias))
            in_d = dim

    def forward(self, x):
        # add the channel dim: (B, T) -> (B, 1, T)
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by ``scale``
    in the backward pass (used to damp feature-extractor gradients)."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # no gradient w.r.t. `scale`
        return ctx.scale * grad, None
class BaseFairseqModel(nn.Module):
    """Minimal base class for fairseq models: checkpoint upgrade/pruning hooks
    plus the one-way ``make_generation_fast_`` inference optimization."""

    def __init__(self):
        super().__init__()
        # set once make_generation_fast_ has been applied
        self._is_generation_fast = False

    def get_targets(self, sample, net_output):
        """Return the target tensor for loss computation from *sample*."""
        return sample["target"]

    def extract_features(self, *args, **kwargs):
        """Default to the full forward pass."""
        return self(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, model_cfg = None, args = None):
        """Upgrade old checkpoint keys, prune layers per *model_cfg*, then load."""
        self.upgrade_state_dict(state_dict)
        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Recursively give every child module a chance to rewrite old
        checkpoint keys in place before loading."""
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0: prefix += "."

            for n, c in m.named_children():
                name = prefix + n
                # prefer the named variant when a child defines it
                if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)

    def make_generation_fast_(self, **kwargs):
        """One-way inference optimization: strip weight norm, let submodules
        apply their own fast paths, then switch to eval mode."""
        if self._is_generation_fast: return
        self._is_generation_fast = True

        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):
                # module has no weight norm applied
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0: prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                # only call genuine overrides, and never on self (avoids recursion)
                if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)

        apply_make_generation_fast_(self, "")
        self.eval()
class HubertConfig:
    """Plain-attribute configuration object for :class:`HubertModel`.

    Mirrors fairseq's HubertConfig dataclass: required pre-training fields
    are positional; everything else keeps fairseq's defaults.
    """

    def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode = "default", encoder_layers = 12, encoder_embed_dim = 768, encoder_ffn_embed_dim = 3072, encoder_attention_heads = 12, activation_fn = "gelu", layer_type = "transformer", dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.0, encoder_layerdrop = 0.0, dropout_input = 0.0, dropout_features = 0.0, final_dim = 0, untie_final_proj = False, layer_norm_first = False, conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias = False, logit_temp = 0.1, target_glu = False, feature_grad_mult = 1.0, mask_length = 10, mask_prob = 0.65, mask_selection = "static", mask_other = 0.0, no_mask_overlap = False, mask_min_space = 1, mask_channel_length = 10, mask_channel_prob = 0.0, mask_channel_selection = "static", mask_channel_other = 0.0, no_mask_channel_overlap = False, mask_channel_min_space = 1, conv_pos = 128, conv_pos_groups = 16, conv_pos_batch_norm = False, latent_temp = (2, 0.5, 0.999995), skip_masked = False, skip_nomask = False, checkpoint_activations = False, required_seq_len_multiple = 2, depthwise_conv_kernel_size = 31, attn_type = "", pos_enc_type = "abs", fp16 = False):
        # -- required / task-level fields --
        self._name = _name
        self.label_rate = label_rate
        self.encoder_layers_1 = encoder_layers_1
        self.logit_temp_ctr = logit_temp_ctr
        self.num_negatives = num_negatives
        self.cross_sample_negatives = cross_sample_negatives
        self.ctr_layers = ctr_layers
        # -- encoder architecture --
        self.extractor_mode = extractor_mode
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        self.layer_type = layer_type
        # -- dropouts --
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        # FIX: was `self.dropout_input = encoder_layerdrop` (copy-paste bug),
        # which silently discarded the dropout_input argument.
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        # -- projection / feature extractor --
        self.final_dim = final_dim
        self.untie_final_proj = untie_final_proj
        self.layer_norm_first = layer_norm_first
        self.conv_feature_layers = conv_feature_layers
        self.conv_bias = conv_bias
        self.logit_temp = logit_temp
        self.target_glu = target_glu
        self.feature_grad_mult = feature_grad_mult
        # -- time masking --
        self.mask_length = mask_length
        self.mask_prob = mask_prob
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        # -- channel masking --
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        # -- positional convolution --
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.conv_pos_batch_norm = conv_pos_batch_norm
        # -- misc / loss options --
        self.latent_temp = latent_temp
        self.skip_masked = skip_masked
        self.skip_nomask = skip_nomask
        self.checkpoint_activations = checkpoint_activations
        self.required_seq_len_multiple = required_seq_len_multiple
        self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
        self.attn_type = attn_type
        self.pos_enc_type = pos_enc_type
        self.fp16 = fp16
1327
+
1328
class Model_Config(dict):
    """Dict with attribute-style access.

    Nested plain dicts are wrapped on read so chained access
    (``cfg.a.b``) works; missing keys read as ``None`` rather than
    raising AttributeError.
    """

    def __getattr__(self, key):
        value = dict.get(self, key)
        if type(value) is dict:
            return Model_Config(value)
        return value

    # Attribute writes/deletes go straight to the underlying dict.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
1335
+
1336
class HubertModel(BaseFairseqModel):
    """HuBERT model: a convolutional feature extractor followed by a
    Transformer encoder, trained with a masked-prediction objective over
    discrete target labels.

    ``cfg`` is expected to provide the fields of HubertConfig.
    """

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): eval() of a config-provided string — safe only if
        # configs are trusted; never feed untrusted input here.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # Channel dimension produced by the last conv layer.
        self.embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
        # Total temporal downsampling of the conv stack (product of strides).
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Ratio of label frame rate to feature frame rate (16 kHz audio assumed
        # by the 16000 constant — confirm against the data pipeline).
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
        # Project conv features to the encoder width only if they differ.
        self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
        # Time-masking hyperparameters.
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space
        # Channel-masking hyperparameters.
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space
        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)
        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask
        # final_dim <= 0 means "use the encoder embedding width".
        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
        # Learned embedding substituted for masked time steps.
        self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)
        # Optional GLU applied to targets/negatives before the NCE loss.
        self.target_glu = None
        if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
        self.untie_final_proj = cfg.untie_final_proj
        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
        # Hard-coded single label set of 504 classes — presumably the k-means
        # codebook size for this checkpoint; confirm against training setup.
        self.num_classes = [504]
        # One embedding row per class, concatenated across label sets.
        self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
        nn.init.uniform_(self.label_embs_concat)
1374
+
1375
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade an (old) state dict; delegates to the base class and
        returns the (possibly modified) dict for convenience."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict
1378
+
1379
+ def apply_mask(self, x, padding_mask, target_list):
1380
+ B, T, C = x.shape
1381
+ if self.mask_prob > 0:
1382
+ mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
1383
+ x[mask_indices] = self.mask_emb
1384
+ else: mask_indices = None
1385
+
1386
+ if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
1387
+ return x, mask_indices
1388
+
1389
+ def compute_nce(self, x, pos, negs):
1390
+ neg_is_pos = (pos == negs).all(-1)
1391
+ logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
1392
+ logits /= self.logit_temp
1393
+
1394
+ if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
1395
+ return logits.transpose(0, 1)
1396
+
1397
+ def forward_features(self, source):
1398
+ if self.feature_grad_mult > 0:
1399
+ features = self.feature_extractor(source)
1400
+ if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
1401
+ else:
1402
+ with torch.no_grad():
1403
+ features = self.feature_extractor(source)
1404
+ return features
1405
+
1406
+ def forward_targets(self, features, target_list):
1407
+ feat_tsz = features.size(2)
1408
+ targ_tsz = min([t.size(1) for t in target_list])
1409
+
1410
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
1411
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
1412
+ features = features[..., :feat_tsz]
1413
+
1414
+ return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
1415
+
1416
+ def forward_padding_mask(self, features, padding_mask):
1417
+ extra = padding_mask.size(1) % features.size(1)
1418
+ if extra > 0: padding_mask = padding_mask[:, :-extra]
1419
+
1420
+ return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
1421
+
1422
    def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
        """Forward pass: conv features -> (optional masking) -> Transformer.

        When ``features_only`` is True, returns encoder features without
        computing the masked/unmasked prediction logits; otherwise returns
        the NCE logits for masked and unmasked positions plus the feature
        L2 penalty.
        """
        features = self.forward_features(source)
        if target_list is not None: features, target_list = self.forward_targets(features, target_list)

        # L2 penalty on the raw conv features (regularizer).
        features_pen = features.float().pow(2).mean()

        # Conv output is (B, C, T); layer norm operates on (B, T, C).
        features = self.layer_norm(features.transpose(1, 2))
        unmasked_features = features.clone()

        if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
        if self.post_extract_proj is not None: features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else: x, mask_indices = features, None

        # output_layer is 1-indexed for callers; the encoder takes 0-indexed.
        x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
        if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}

        def compute_pred(proj_x, target, label_embs):
            # Positive = embedding of the true label; negatives = all class
            # embeddings, broadcast across the selected positions.
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)

            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)

            return self.compute_nce(proj_x, y, negs)

        # One embedding table slice per label set.
        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            # NOTE(review): assumes padding_mask is not None on this path
            # (and that mask=True so mask_indices exists) — confirm callers.
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            # With untie_final_proj, the projection is chunked so each label
            # set gets its own slice; otherwise the same projection is reused.
            logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
        else: logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
        else: logit_u_list = [None for _ in target_list]

        return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
1468
+
1469
+ def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
1470
+ res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
1471
+ return res["features"] if ret_conv else res["x"], res["padding_mask"]
1472
+
1473
+ def get_logits(self, net_output, is_masked=True):
1474
+ return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
1475
+
1476
+ def get_targets(self, net_output, is_masked=True):
1477
+ return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
1478
+
1479
+ def get_extra_losses(self, net_output):
1480
+ extra_losses, names = [], []
1481
+
1482
+ if "features_pen" in net_output:
1483
+ extra_losses.append(net_output["features_pen"])
1484
+ names.append("features_pen")
1485
+
1486
+ return extra_losses, names
1487
+
1488
    def remove_pretraining_modules(self):
        """Drop heads used only during pre-training, freeing parameters
        not needed for fine-tuning or feature extraction."""
        self.target_glu = None
        self.final_proj = None