File size: 1,522 Bytes
c02d17f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from torch import Tensor
from torch.nn import MultiheadAttention
from torch.nn import functional as F
from typing import Optional, Tuple
class MultiheadSelfAttention(MultiheadAttention):
    """Self-attention variant of ``MultiheadAttention`` with an optional token output.

    Besides the regular attention output, ``forward`` can return "tokens":
    the input passed through the value slice of the packed in-projection and
    then through the output projection, with no attention mixing applied.
    """

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        return_tokens: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run self-attention and optionally compute per-position tokens.

        Args:
            query, key, value: Must all be the *same* tensor object
                (self-attention); this is enforced with an assertion.
            key_padding_mask: Forwarded to ``F.multi_head_attention_forward``.
            need_weights: Forwarded to ``F.multi_head_attention_forward``;
                the attention weights themselves are not returned.
            attn_mask: Forwarded to ``F.multi_head_attention_forward``.
            return_tokens: When true, also compute the value-projected,
                out-projected input (see class docstring).

        Returns:
            ``(attn_output, tokens)`` where ``tokens`` is ``None`` unless
            ``return_tokens`` is true.

        Notes:
            * Dropout is fixed to 0 and ``bias_k`` / ``bias_v`` /
              ``add_zero_attn`` are disabled here regardless of how the
              module was constructed.
            * NOTE(review): inputs are handed to
              ``F.multi_head_attention_forward`` as-is, so a module built
              with ``batch_first=True`` is presumably not honoured - confirm
              against callers.
        """
        assert query is value and value is key  # self-attention

        tokens: Optional[Tensor] = None
        if return_tokens:
            # in_projection: only the last ``embed_dim`` output features of the
            # packed projection are needed, so slice the weight/bias rows first
            # instead of computing all 3 * embed_dim features and discarding
            # two thirds of them.
            value_weight = self.in_proj_weight[-self.embed_dim:]
            value_bias = (
                None if self.in_proj_bias is None
                else self.in_proj_bias[-self.embed_dim:]
            )
            tokens = F.linear(value, value_weight, bias=value_bias)
            # out_projection
            tokens = F.linear(tokens, self.out_proj.weight, bias=self.out_proj.bias)

        # The attention weights are computed (when need_weights is set) but
        # intentionally not part of this method's return value.
        attn_output, _attn_weights = F.multi_head_attention_forward(
            query=query, key=key, value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask)
        return attn_output, tokens
|