Siddarth committed
Commit 872069c · 1 Parent(s): 42f1fc5

Upload UpdatedTransformer.py

Files changed (1)
  1. UpdatedTransformer.py +648 -0
UpdatedTransformer.py ADDED
@@ -0,0 +1,648 @@
import torch
from torch.nn import functional as F
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from typing import Optional, Any, Tuple, List
import math
import warnings


def _in_projection_packed(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
          same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return F.linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)


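# Illustrative sketch (not part of the original file): a tiny check of the packed
# in-projection on the self-attention path, where a single (3E, E) weight is applied
# once and the result is chunked into q', k', v'. Shapes and names below are
# assumptions chosen for demonstration only.
def _example_in_projection_packed():
    E = 8
    x = torch.randn(5, 2, E)                 # (seq, batch, embed)
    w = torch.randn(3 * E, E)                # q, k, v weights packed along dim 0
    b = torch.zeros(3 * E)
    q, k, v = _in_projection_packed(x, x, x, w, b)   # q is k is v -> packed path
    assert q.shape == k.shape == v.shape == x.shape
    # The query projection corresponds to the first E rows of the packed weight:
    assert torch.allclose(q, F.linear(x, w[:E], b[:E]), atol=1e-5)

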
def _in_projection(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w_q: torch.Tensor,
    w_k: torch.Tensor,
    w_v: torch.Tensor,
    b_q: Optional[torch.Tensor] = None,
    b_k: Optional[torch.Tensor] = None,
    b_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Performs the in-projection step of the attention operation. This is simply
    a triple of linear projections, with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.

    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
          number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
          number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
          number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`

        Output: in output triple :math:`(q', k', v')`,
        - q': :math:`[Qdims..., Eq]`
        - k': :math:`[Kdims..., Eq]`
        - v': :math:`[Vdims..., Eq]`

    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)


def _scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
          and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
          shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
          have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn


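# Illustrative sketch (assumed shapes, not from the original file): exercising the
# batched attention primitive above with an additive float mask, where 0 keeps a
# position and a large negative value suppresses it.
def _example_scaled_dot_product_attention():
    B, Nt, Ns, E = 2, 3, 5, 4
    q = torch.randn(B, Nt, E)
    k = torch.randn(B, Ns, E)
    v = torch.randn(B, Ns, E)
    mask = torch.zeros(Nt, Ns)
    mask[:, -1] = -1e9                       # block the last source position
    out, attn = _scaled_dot_product_attention(q, k, v, attn_mask=mask)
    assert out.shape == (B, Nt, E)
    assert attn.shape == (B, Nt, Ns)
    assert attn[..., -1].max() < 1e-6        # masked column receives ~zero weight

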
def multi_head_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: torch.Tensor,
    in_proj_bias: Optional[torch.Tensor],
    bias_k: Optional[torch.Tensor],
    bias_v: Optional[torch.Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: torch.Tensor,
    out_proj_bias: Optional[torch.Tensor],
    training: bool = True,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[torch.Tensor] = None,
    k_proj_weight: Optional[torch.Tensor] = None,
    v_proj_weight: Optional[torch.Tensor] = None,
    static_k: Optional[torch.Tensor] = None,
    static_v: Optional[torch.Tensor] = None,
    minf: float = -1e9,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with ``minf``.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the projection weights for query, key,
            and value in different forms. If False, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        minf: large negative value used to fill masked positions instead of ``float("-inf")``.


    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, minf)

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, minf)
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None


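# Illustrative smoke test (names and shapes are assumptions, not from the original
# file) for the functional path above, using packed projections and the default
# large-negative `minf` fill instead of -inf.
def _example_multi_head_attention_forward():
    L, S, N, E, H = 4, 6, 2, 8, 2
    query = torch.randn(L, N, E)
    key = torch.randn(S, N, E)
    value = torch.randn(S, N, E)
    in_proj_weight = torch.randn(3 * E, E)
    in_proj_bias = torch.zeros(3 * E)
    out_proj_weight = torch.randn(E, E)
    out_proj_bias = torch.zeros(E)
    out, weights = multi_head_attention_forward(
        query, key, value, E, H,
        in_proj_weight, in_proj_bias,
        None, None, False, 0.0,
        out_proj_weight, out_proj_bias,
        training=False)
    assert out.shape == (L, N, E)            # (4, 2, 8)
    assert weights.shape == (N, L, S)        # (2, 4, 6), averaged over heads

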
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class neko_MultiheadAttention(torch.nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> multihead_attn = neko_MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(neko_MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = torch.nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = torch.nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = torch.nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = torch.nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(neko_MultiheadAttention, self).__setstate__(state)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
              will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length.

              If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
              the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights


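# Usage sketch for the module above (shapes assumed; the interface mirrors
# torch.nn.MultiheadAttention). Not part of the original file.
def _example_neko_multihead_attention():
    mha = neko_MultiheadAttention(embed_dim=16, num_heads=4)
    query = torch.randn(7, 3, 16)            # (L, N, E); batch_first defaults to False
    key = value = torch.randn(9, 3, 16)      # (S, N, E)
    key_padding_mask = torch.zeros(3, 9, dtype=torch.bool)
    key_padding_mask[:, -2:] = True          # ignore the last two source positions
    out, weights = mha(query, key, value, key_padding_mask=key_padding_mask)
    assert out.shape == (7, 3, 16)
    assert weights.shape == (3, 7, 9)

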
class neko_TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = neko_TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

        Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = neko_TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(neko_TransformerEncoderLayer, self).__init__()
        self.self_attn = neko_MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(neko_TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


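# Illustrative end-to-end sketch (random data, no masking; not part of the original
# file): stacking the encoder layer manually into a small two-layer encoder.
def _example_neko_transformer_encoder_layer():
    layers = torch.nn.ModuleList(
        [neko_TransformerEncoderLayer(d_model=32, nhead=4, dim_feedforward=64)
         for _ in range(2)])
    src = torch.rand(10, 8, 32)              # (S, N, E)
    out = src
    for layer in layers:
        out = layer(out)
    assert out.shape == (10, 8, 32)


if __name__ == "__main__":
    # Run the illustrative examples added above as a quick sanity check.
    _example_in_projection_packed()
    _example_scaled_dot_product_attention()
    _example_multi_head_attention_forward()
    _example_neko_multihead_attention()
    _example_neko_transformer_encoder_layer()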