from typing import Optional
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched


class MultiheadAttention(Module):
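    """ONNX-export-friendly stand-in for ``torch.nn.MultiheadAttention``.

    ``forward`` delegates to ``multi_head_attention_forward_patched`` so the
    attention step can carry an external key/value ``cache`` through ONNX
    export. ``linear1_cls``/``linear2_cls`` let callers swap in custom
    projection layers; otherwise the parameter layout matches the stock module.
    """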
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

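        # Two construction paths: plain ``Linear`` keeps raw parameters exactly
        # like ``torch.nn.MultiheadAttention``; any other linear class builds
        # real submodules and aliases ``in_proj_weight``/``in_proj_bias`` to
        # their parameters.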
        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
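        """Match the stock init: Xavier-uniform projection weights, zero
        biases, Xavier-normal ``bias_k``/``bias_v``."""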
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old checkpoints saved before ``_qkv_same_embed_dim``
        # existed in the state dict.
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tensor:
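        """Batch-first self-attention through the ONNX-patched kernel.

        Note the differences from ``torch.nn.MultiheadAttention``: ``key`` and
        ``value`` are ignored (``query`` feeds all three projections), and only
        the attention output is returned, never the attention weights.
        """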
        # The patched kernel works on (seq_len, batch, embed_dim), so transpose
        # the batch-first input and reuse ``query`` for key and value.
        query = key = value = query.transpose(1, 0)
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        return attn_output.transpose(1, 0)
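

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module. Constructing the
    # module only needs torch; actually running forward() also needs the
    # patched AR kernel plus a ``cache`` in whatever format that kernel
    # expects, so the call is only sketched in comments below.
    mha = MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
    print(mha.in_proj_weight.shape)  # torch.Size([1536, 512]), fused q/k/v
    print(mha.out_proj.weight.shape)  # torch.Size([512, 512])
    # x = torch.randn(2, 10, 512)      # (batch, seq_len, embed_dim)
    # out = mha(x, x, x, cache=...)    # cache layout is defined by the kernel
    # assert out.shape == x.shape      # output keeps the batch-first shape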