from typing import Optional, Tuple

from torch import Tensor
from torch.nn import MultiheadAttention
from torch.nn import functional as F


class MultiheadSelfAttention(MultiheadAttention):
    """MultiheadAttention restricted to self-attention; can additionally
    return the value-projected input tokens, bypassing attention mixing."""

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                return_tokens: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        # Self-attention only: query, key, and value must be the same tensor.
        assert query is value and value is key
        if return_tokens:
            # in_proj_weight stacks the Q, K, and V projections along the
            # output dimension, so the last embed_dim channels of the
            # in-projection are the value projection of the input.
            tokens = F.linear(value, self.in_proj_weight, bias=self.in_proj_bias)[..., -self.embed_dim:]
            # Apply the output projection directly to the value-projected
            # tokens, skipping the attention-weighted mixing.
            tokens = F.linear(tokens, self.out_proj.weight, bias=self.out_proj.bias)
        else:
            tokens = None

        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query=query, key=key, value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            bias_k=None, bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,  # attention dropout is hard-disabled here
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask)

        # The attention weights are computed but not returned; the second
        # return value is the optional token branch instead.
        return attn_output, tokens
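

# Usage sketch (illustrative; not part of the original module). A minimal
# smoke test, assuming the default batch_first=False layout of
# nn.MultiheadAttention, i.e. inputs shaped (seq_len, batch, embed_dim);
# all sizes below are arbitrary.
if __name__ == "__main__":
    import torch

    attn = MultiheadSelfAttention(embed_dim=64, num_heads=8)
    x = torch.randn(10, 2, 64)  # (seq_len, batch, embed_dim)

    # Without return_tokens, the second output is None.
    out, tokens = attn(x, x, x)
    assert out.shape == x.shape and tokens is None

    # With return_tokens=True, tokens should equal out_proj applied to the
    # value projection of x. Verify against a by-hand computation using the
    # last embed_dim rows of the stacked in-projection parameters.
    out, tokens = attn(x, x, x, return_tokens=True)
    w_v = attn.in_proj_weight[-attn.embed_dim:]
    b_v = attn.in_proj_bias[-attn.embed_dim:]
    manual = F.linear(F.linear(x, w_v, b_v), attn.out_proj.weight, attn.out_proj.bias)
    assert torch.allclose(tokens, manual, atol=1e-6)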