import math
import warnings
from typing import Optional, Tuple, List

import torch
from torch.nn import functional as F
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear


def _in_projection_packed(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`(E * 3)` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # self-attention
            return F.linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
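

# Illustrative sketch, not part of the original module: shapes for the packed
# in-projection in the self-attention case, where q, k and v are the same tensor.
# The helper name and all sizes below are assumptions for demonstration only.
def _demo_in_projection_packed():
    E, L, N = 8, 5, 2                        # embed dim, seq len, batch (assumed)
    x = torch.randn(L, N, E)                 # shared q/k/v input
    w = torch.randn(3 * E, E)                # q/k/v weights packed along dim 0
    b = torch.randn(3 * E)                   # q/k/v biases packed the same way
    q, k, v = _in_projection_packed(x, x, x, w, b)
    # one fused matmul, then a chunk: each output keeps the input's shape
    assert q.shape == k.shape == v.shape == (L, N, E)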


def _in_projection(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w_q: torch.Tensor,
    w_k: torch.Tensor,
    w_v: torch.Tensor,
    b_q: Optional[torch.Tensor] = None,
    b_k: Optional[torch.Tensor] = None,
    b_v: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Performs the in-projection step of the attention operation. This is simply
    a triple of linear projections, with shape constraints on the weights which
    ensure embedding dimension uniformity in the projected outputs.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected.
        w_q, w_k, w_v: weights for q, k and v, respectively.
        b_q, b_k, b_v: optional biases for q, k and v, respectively.

    Shape:
        Inputs:
        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
            number of leading dimensions.
        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
            number of leading dimensions.
        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
            number of leading dimensions.
        - w_q: :math:`(Eq, Eq)`
        - w_k: :math:`(Eq, Ek)`
        - w_v: :math:`(Eq, Ev)`
        - b_q: :math:`(Eq)`
        - b_k: :math:`(Eq)`
        - b_v: :math:`(Eq)`

        Output: in output triple :math:`(q', k', v')`,
        - q': :math:`[Qdims..., Eq]`
        - k': :math:`[Kdims..., Eq]`
        - v': :math:`[Vdims..., Eq]`
    """
    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
    assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
    assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
    assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
    assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
    assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
    assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
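

# Illustrative sketch, not part of the original module: separate projections
# where key and value carry different embedding dimensions than the query and
# are all mapped to Eq. Helper name and sizes are assumptions for demonstration.
def _demo_in_projection():
    Eq, Ek, Ev = 8, 6, 4                     # assumed query/key/value dims
    q = torch.randn(5, 2, Eq)
    k = torch.randn(7, 2, Ek)
    v = torch.randn(7, 2, Ev)
    w_q = torch.randn(Eq, Eq)
    w_k = torch.randn(Eq, Ek)
    w_v = torch.randn(Eq, Ev)
    q2, k2, v2 = _in_projection(q, k, v, w_q, w_k, w_v)
    # every projected tensor ends with the query embedding dimension
    assert q2.shape[-1] == k2.shape[-1] == v2.shape[-1] == Eq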


def _scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.
        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
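

# Illustrative sketch, not part of the original module: scaled dot product
# attention with an additive causal mask, built the same way the forward pass
# below consumes float masks (added to the logits before softmax, using the
# module's -1e9 "minf" convention). Helper name and sizes are assumptions.
def _demo_scaled_dot_product_attention():
    B, Nt, Ns, E = 2, 4, 4, 8                # assumed batch/lengths/dim
    q = torch.randn(B, Nt, E)
    k = torch.randn(B, Ns, E)
    v = torch.randn(B, Ns, E)
    # 2D float mask broadcasts over the batch; -1e9 above the diagonal
    causal = torch.triu(torch.full((Nt, Ns), -1e9), diagonal=1)
    out, attn = _demo_out = _scaled_dot_product_attention(q, k, v, attn_mask=causal)
    assert out.shape == (B, Nt, E) and attn.shape == (B, Nt, Ns)
    # each row attends only to positions at or before its own index
    assert torch.all(attn.triu(diagonal=1).abs() < 1e-6)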


def multi_head_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: torch.Tensor,
    in_proj_bias: Optional[torch.Tensor],
    bias_k: Optional[torch.Tensor],
    bias_v: Optional[torch.Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: torch.Tensor,
    out_proj_bias: Optional[torch.Tensor],
    training: bool = True,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[torch.Tensor] = None,
    k_proj_weight: Optional[torch.Tensor] = None,
    v_proj_weight: Optional[torch.Tensor] = None,
    static_k: Optional[torch.Tensor] = None,
    static_v: Optional[torch.Tensor] = None,
    minf: float = -1e9,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
| r""" | |
| Args: | |
| query, key, value: map a query and a set of key-value pairs to an output. | |
| See "Attention Is All You Need" for more details. | |
| embed_dim_to_check: total dimension of the model. | |
| num_heads: parallel attention heads. | |
| in_proj_weight, in_proj_bias: input projection weight and bias. | |
| bias_k, bias_v: bias of the key and value sequences to be added at dim=0. | |
| add_zero_attn: add a new batch of zeros to the key and | |
| value sequences at dim=1. | |
| dropout_p: probability of an element to be zeroed. | |
| out_proj_weight, out_proj_bias: the output projection weight and bias. | |
| training: apply dropout if is ``True``. | |
| key_padding_mask: if provided, specified padding elements in the key will | |
| be ignored by the attention. This is an binary mask. When the value is True, | |
| the corresponding value on the attention layer will be filled with -inf. | |
| need_weights: output attn_output_weights. | |
| attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all | |
| the batches while a 3D mask allows to specify a different mask for the entries of each batch. | |
| use_separate_proj_weight: the function accept the proj. weights for query, key, | |
| and value in different forms. If false, in_proj_weight will be used, which is | |
| a combination of q_proj_weight, k_proj_weight, v_proj_weight. | |
| q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. | |
| static_k, static_v: static key and value used for attention operators. | |
| Shape: | |
| Inputs: | |
| - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is | |
| the embedding dimension. | |
| - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is | |
| the embedding dimension. | |
| - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is | |
| the embedding dimension. | |
| - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. | |
| If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions | |
| will be unchanged. If a BoolTensor is provided, the positions with the | |
| value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. | |
| - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. | |
| 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, | |
| S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked | |
| positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend | |
| while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` | |
| are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor | |
| is provided, it will be added to the attention weight. | |
| - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, | |
| N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. | |
| - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, | |
| N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. | |
| Outputs: | |
| - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, | |
| E is the embedding dimension. | |
| - attn_output_weights: :math:`(N, L, S)` where N is the batch size, | |
| L is the target sequence length, S is the source sequence length. | |
| """ | |
    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)
    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, \
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, \
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, \
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)
    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, minf)

    # convert mask to float
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, minf)
        attn_mask = new_attn_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
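

# Illustrative sketch, not part of the original module: calling the functional
# form directly with freshly sampled packed weights, the way
# neko_MultiheadAttention does below. Helper name and sizes are assumptions.
def _demo_multi_head_attention_forward():
    L, S, N, E, H = 4, 6, 2, 8, 2            # tgt len, src len, batch, embed, heads
    query = torch.randn(L, N, E)
    key = torch.randn(S, N, E)
    value = torch.randn(S, N, E)
    in_proj_weight = torch.randn(3 * E, E)   # packed q/k/v projection
    in_proj_bias = torch.zeros(3 * E)
    out_proj_weight = torch.randn(E, E)
    out_proj_bias = torch.zeros(E)
    attn_output, attn_weights = multi_head_attention_forward(
        query, key, value, E, H, in_proj_weight, in_proj_bias,
        None, None, False, 0.0, out_proj_weight, out_proj_bias,
        training=False)
    assert attn_output.shape == (L, N, E)
    assert attn_weights.shape == (N, L, S)   # weights averaged over heads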


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
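

# Illustrative sketch, not part of the original module: resolving an activation
# by name, as neko_TransformerEncoderLayer does in its constructor below.
def _demo_get_activation_fn():
    act = _get_activation_fn("gelu")
    assert act is F.gelu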


class neko_MultiheadAttention(torch.nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::
        >>> multihead_attn = neko_MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(neko_MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        if not self._qkv_same_embed_dim:
            self.q_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = torch.nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = torch.nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)
        if bias:
            self.in_proj_bias = torch.nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        if add_bias_kv:
            self.bias_k = torch.nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = torch.nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True
        super(neko_MultiheadAttention, self).__setstate__(state)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast
                across all batches, while a 3D mask allows a different mask for each entry of the batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
                the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
                the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
                the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
                If a ByteTensor is provided, the non-zero positions will be ignored while the zero
                positions will be unchanged. If a BoolTensor is provided, the positions with the
                value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
                source sequence length.
                If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
                length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
                the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
                while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
                are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
                is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
                E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
                L is the target sequence length, S is the source sequence length.
        """
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
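

# Illustrative sketch, not part of the original module: the module wrapper in
# the default (seq, batch, embed) layout, used for self-attention. Helper name
# and sizes are assumptions for demonstration only.
def _demo_neko_multihead_attention():
    E, H, L, N = 8, 2, 5, 3                  # assumed embed dim, heads, seq, batch
    mha = neko_MultiheadAttention(E, H)
    x = torch.randn(L, N, E)
    attn_output, attn_weights = mha(x, x, x)  # q, k, v are the same tensor
    assert attn_output.shape == (L, N, E)
    assert attn_weights.shape == (N, L, L)    # head-averaged attention weights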


class neko_TransformerEncoderLayer(torch.nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = neko_TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = neko_TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(neko_TransformerEncoderLayer, self).__init__()
        self.self_attn = neko_MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of the feedforward model
        self.linear1 = torch.nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm1 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = torch.nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(neko_TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in the Transformer class.
        """
        # self-attention block with residual connection and layer norm
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # position-wise feedforward block with residual connection and layer norm
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
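

# Illustrative sketch, not part of the original module: one encoder layer in
# the default (seq, batch, feature) layout, mirroring the docstring example
# above. Helper name and sizes are assumptions for demonstration only.
def _demo_neko_transformer_encoder_layer():
    layer = neko_TransformerEncoderLayer(d_model=32, nhead=4)
    src = torch.rand(10, 2, 32)              # (seq, batch, d_model)
    out = layer(src)
    # the encoder layer preserves the input shape
    assert out.shape == src.shape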