| |
| |
|
|
| import math |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| try: |
| from apex.normalization import FusedLayerNorm as LayerNorm |
| except ModuleNotFoundError: |
| from torch.nn import LayerNorm |
|
|
| from .multiway_network import MultiwayWrapper |
| from .xpos_relative_position import XPOS |
|
|
|
|
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of width ``embed_dim // num_heads``, combined with
    softmax attention, and projected back to ``embed_dim``.  Optional
    features: incremental key/value caching, XPOS applied to queries and
    keys, an additive attention mask, a key padding mask, an additive
    relative-position bias, and a layer norm applied before the output
    projection ("sub-LN").
    """

    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
        one_attn=False,
    ):
        """Build the projections and optional sub-modules.

        Args:
            args: Configuration object.  ``args.xpos_rel_pos`` is always
                read; ``args.xpos_scale_base`` and ``args.layernorm_eps``
                are read only when the corresponding feature is enabled.
                It is also forwarded to ``MultiwayWrapper``.
            embed_dim: Width of the query input and of the output.
            num_heads: Number of attention heads.  ``forward`` reshapes the
                projections to ``num_heads * (embed_dim // num_heads)``, so
                ``embed_dim`` is expected to be divisible by ``num_heads``.
            dropout: Probability for the dropout applied to the attention
                probabilities.
            self_attention: Selects self-attention mode.
            encoder_decoder_attention: Selects encoder-decoder mode.
                Exactly one of the two mode flags must be true.
            subln: When true (and in self-attention mode), a layer norm is
                applied to the attention output before ``out_proj``.
            one_attn: When true, the four projections are plain
                ``nn.Linear`` layers; otherwise each is wrapped in
                ``MultiwayWrapper``.
        """
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** (-0.5)
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        # XOR: exactly one of the two attention modes must be selected.
        assert self.self_attention ^ self.encoder_decoder_attention
        if one_attn:
            self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        else:
            self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.inner_attn_ln = (
            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
            if subln and self.self_attention
            else None
        )
        self.dropout_module = torch.nn.Dropout(dropout)
        self.xpos = (
            XPOS(self.head_dim, args.xpos_scale_base)
            if args.xpos_rel_pos and self.self_attention
            else None
        )

    def reset_parameters(self):
        """Re-initialise the projection weights with Xavier-uniform values.

        The q/k/v weights use gain ``1 / sqrt(2)``; the output projection
        uses the default gain and its bias is set to zero.

        NOTE(review): this reads ``.weight`` / ``.bias`` directly on the
        projections, which holds for ``nn.Linear``; whether the object
        returned by ``MultiwayWrapper`` exposes them is not visible here.
        """
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
        """Compute attention of ``query`` over ``key`` / ``value``.

        Args:
            query: Tensor of shape ``(bsz, tgt_len, embed_dim)``.
            key: Tensor of shape ``(bsz, src_len, *)``.
            value: Tensor whose first two dimensions are ``(bsz, src_len)``.
            incremental_state: Optional dict used as a key/value cache.
                Cached ``"prev_key"`` / ``"prev_value"`` entries are
                prepended to the new keys/values along the sequence axis
                and the dict is updated in place with the concatenation,
                stored as ``(bsz, num_heads, total_len, head_dim)``.
            key_padding_mask: Optional mask; truthy entries are filled with
                ``-inf`` before the softmax.  It is broadcast against
                ``(bsz, num_heads, tgt_len, src_len)`` after two
                ``unsqueeze`` calls -- presumably ``(bsz, src_len)``.
            attn_mask: Optional additive mask.  A mask whose rank differs
                from 3 gets a leading broadcast dimension; a rank-3 mask is
                repeated ``num_heads`` times along dim 0 -- presumably
                ``(tgt_len, src_len)`` and ``(bsz, tgt_len, src_len)``
                respectively.
            rel_pos: Optional additive bias, viewed as
                ``(bsz * num_heads, tgt_len, src_len)``.

        Returns:
            Tuple ``(attn, attn_weights)``: ``attn`` has shape
            ``(bsz, tgt_len, embed_dim)``; ``attn_weights`` holds the
            post-softmax, pre-dropout probabilities with shape
            ``(num_heads, bsz, tgt_len, src_len)``.

        Raises:
            AssertionError: If the query width differs from ``embed_dim``,
                the batch sizes of ``query`` and ``key`` differ, ``value``
                is ``None``, or the leading dimensions of ``value`` do not
                match ``(bsz, src_len)``.
        """
        bsz, tgt_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

        key_bsz, src_len, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        # Parenthesised on purpose: `assert bsz, src_len == ...` would only
        # test `bsz` and treat the comparison as the failure message.
        assert (bsz, src_len) == tuple(value.shape[:2]), (
            f"value shape {tuple(value.shape)} does not match "
            f"(bsz, src_len) = {(bsz, src_len)}"
        )

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Scale the queries once, then split into heads and fold the head
        # axis into the batch axis so torch.bmm can be used.
        q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

        if incremental_state is not None:
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
                prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            # From here on src_len includes the cached positions.
            src_len = k.size(1)

        if self.xpos is not None:
            # With a cache, the query is offset to the last key position;
            # keys always start at offset 0.
            if incremental_state is not None:
                offset = src_len - 1
            else:
                offset = 0
            k = self.xpos(k, offset=0, downscale=True)
            q = self.xpos(q, offset=offset, downscale=False)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            attn_weights = torch.nan_to_num(attn_weights)
            if len(attn_mask.shape) != len(attn_weights.shape):
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        # Softmax is computed in float32 and cast back to the input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        # (bsz * heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)

        return attn, attn_weights
|
|