vllm_fairseq/fairseq-0.10.2/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    lengths_to_mask,
)
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction

from . import register_monotonic_attention


@with_incremental_state
class MonotonicAttention(nn.Module):
    """
    Abstract class of monotonic attentions
    """

    def __init__(self, args):
        self.eps = args.attention_eps
        self.mass_preservation = args.mass_preservation

        self.noise_mean = args.noise_mean
        self.noise_var = args.noise_var

        self.energy_bias_init = args.energy_bias_init
        self.energy_bias = (
            nn.Parameter(self.energy_bias_init * torch.ones([1]))
            if args.energy_bias is True
            else 0
        )

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--no-mass-preservation', action="store_false",
                            dest="mass_preservation",
                            help='Do not stay on the last token when decoding')
        parser.add_argument('--mass-preservation', action="store_true",
                            dest="mass_preservation",
                            help='Stay on the last token when decoding')
        parser.set_defaults(mass_preservation=True)

        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')

        parser.add_argument('--energy-bias', action="store_true", default=False,
                            help='Bias for energy')
        parser.add_argument('--energy-bias-init', type=float, default=-2.0,
                            help='Initial value of the bias for energy')

        parser.add_argument('--attention-eps', type=float, default=1e-6,
                            help='Epsilon when calculating expected attention')
        # fmt: on
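    # Example usage of the flags defined above (values shown are the defaults):
    #   --mass-preservation --noise-var 1.0 --noise-mean 0.0 \
    #   --energy-bias-init -2.0 --attention-eps 1e-6
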
    def p_choose(self, *args):
        raise NotImplementedError

    def input_projections(self, *args):
        raise NotImplementedError

    def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
        """
        Calculating monotonic energies

        ============================================================
        Expected input size
        q_proj: bsz * num_heads, tgt_len, self.head_dim
        k_proj: bsz * num_heads, src_len, self.head_dim
        key_padding_mask: bsz, src_len
        """
        bsz, tgt_len, embed_dim = q_proj.size()
        bsz = bsz // self.num_heads
        src_len = k_proj.size(1)

        attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
        attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            attn_energy = attn_energy.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                float("-inf"),
            )

        return attn_energy

    def expected_alignment_train(self, p_choose, key_padding_mask):
        """
        Calculating expected alignment for MMA
        A mask is not needed here because p_choose will be 0 if masked.

        q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + alpha_{i-1,j}
        alpha_{i,j} = p_{i,j} * q_{i,j}

        Parallel solution:
        alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / cumprod(1 - p_i))

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        """
        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()

        # cumprod_1mp: bsz * num_heads, tgt_len, src_len
        cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
        cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

        init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
        init_attention[:, :, 0] = 1.0

        previous_attn = [init_attention]

        for i in range(tgt_len):
            # p_choose: bsz * num_heads, tgt_len, src_len
            # cumprod_1mp_clamp: bsz * num_heads, tgt_len, src_len
            # previous_attn[i]: bsz * num_heads, 1, src_len
            # alpha_i: bsz * num_heads, src_len
            alpha_i = (
                p_choose[:, i]
                * cumprod_1mp[:, i]
                * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
            ).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))

        # alpha: bsz * num_heads, tgt_len, src_len
        alpha = torch.cat(previous_attn[1:], dim=1)

        if self.mass_preservation:
            # Last token has the residual probabilities
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)

        assert not torch.isnan(alpha).any(), "NaN detected in alpha."

        return alpha
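    # Illustrative example of the parallel recurrence above: with previous-step
    # attention alpha_{i-1} = [1, 0, 0] and p_i = [0.1, 0.5, 0.9]:
    #   exclusive_cumprod(1 - p_i)       = [1.0, 0.9, 0.45]
    #   cumsum(alpha_{i-1} / cumprod)    = [1.0, 1.0, 1.0]
    #   alpha_i = p_i * cumprod * cumsum = [0.1, 0.45, 0.405]
    # The row sums to 0.955 < 1; with mass_preservation the last entry is
    # overwritten with 1 - (0.1 + 0.45) = 0.45 so that the mass sums to 1.
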
    def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state):
        """
        Calculating monotonic alignment for MMA during inference time

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        key_padding_mask: bsz, src_len
        incremental_state: dict
        """
        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()
        # One token at a time
        assert tgt_len == 1
        p_choose = p_choose[:, 0, :]

        monotonic_cache = self._get_monotonic_buffer(incremental_state)

        # prev_monotonic_step: bsz, num_heads
        bsz = bsz_num_heads // self.num_heads
        prev_monotonic_step = monotonic_cache.get(
            "step", p_choose.new_zeros([bsz, self.num_heads]).long()
        )
        bsz, num_heads = prev_monotonic_step.size()
        assert num_heads == self.num_heads
        assert bsz * num_heads == bsz_num_heads

        # p_choose: bsz, num_heads, src_len
        p_choose = p_choose.view(bsz, num_heads, src_len)

        if key_padding_mask is not None:
            src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
        else:
            src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len

        # src_lengths: bsz, num_heads
        src_lengths = src_lengths.expand_as(prev_monotonic_step)
        # new_monotonic_step: bsz, num_heads
        new_monotonic_step = prev_monotonic_step

        step_offset = 0
        if key_padding_mask is not None:
            if key_padding_mask[:, 0].any():
                # left_pad_source = True:
                step_offset = key_padding_mask.sum(dim=-1, keepdim=True)

        max_steps = src_lengths - 1 if self.mass_preservation else src_lengths

        # finish_read: bsz, num_heads
        finish_read = new_monotonic_step.eq(max_steps)

        while finish_read.sum().item() < bsz * self.num_heads:
            # p_choose: bsz, num_heads, src_len
            # only choose the p at monotonic steps
            # p_choose_i: bsz, num_heads
            p_choose_i = (
                p_choose.gather(
                    2,
                    (step_offset + new_monotonic_step)
                    .unsqueeze(2)
                    .clamp(0, src_len - 1),
                )
            ).squeeze(2)

            # action: bsz, num_heads, computed only on unfinished heads
            # action == 1 (p_choose_i < 0.5): keep reading, advance one source step
            # action == 0 (p_choose_i >= 0.5): stop here and write
            action = (
                (p_choose_i < 0.5)
                .type_as(prev_monotonic_step)
                .masked_fill(finish_read, 0)
            )
            # dist = torch.distributions.bernoulli.Bernoulli(p_choose)
            # action = dist.sample().type_as(finish_read) * (1 - finish_read)

            new_monotonic_step += action

            finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
            # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read

        monotonic_cache["step"] = new_monotonic_step

        # alpha: bsz * num_heads, 1, src_len
        # new_monotonic_step: bsz, num_heads
        alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
            1,
            (step_offset + new_monotonic_step)
            .view(bsz * self.num_heads, 1)
            .clamp(0, src_len - 1),
            1,
        )

        if not self.mass_preservation:
            alpha = alpha.masked_fill(
                (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
            )

        alpha = alpha.unsqueeze(1)

        self._set_monotonic_buffer(incremental_state, monotonic_cache)

        return alpha
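    # Summary of the inference-time policy implemented above: starting from its
    # cached step, each head keeps reading while p_choose < 0.5 at the current
    # source position, and stops once p_choose >= 0.5 or max_steps is reached.
    # The returned alpha is therefore a one-hot distribution over source
    # positions for every head.
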
    def v_proj_output(self, value):
        raise NotImplementedError

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        incremental_state=None,
        *args,
        **kwargs,
    ):
        tgt_len, bsz, embed_dim = query.size()
        src_len = value.size(0)

        # stepwise prob
        # p_choose: bsz * self.num_heads, tgt_len, src_len
        p_choose = self.p_choose(query, key, key_padding_mask)

        # expected alignment alpha
        # bsz * self.num_heads, tgt_len, src_len
        if incremental_state is not None:
            alpha = self.expected_alignment_infer(
                p_choose, key_padding_mask, incremental_state
            )
        else:
            alpha = self.expected_alignment_train(p_choose, key_padding_mask)

        # expected attention beta
        # bsz * self.num_heads, tgt_len, src_len
        beta = self.expected_attention(
            alpha, query, key, value, key_padding_mask, incremental_state
        )

        attn_weights = beta

        v_proj = self.v_proj_output(value)

        attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.out_proj(attn)

        beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
        alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
        p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)

        return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose}

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(incremental_state, new_order)

        input_buffer = self._get_monotonic_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_monotonic_buffer(incremental_state, input_buffer)

    def _get_monotonic_buffer(self, incremental_state):
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def _set_monotonic_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
            buffer,
        )

    def get_pointer(self, incremental_state):
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def get_fastest_pointer(self, incremental_state):
        return self.get_pointer(incremental_state)["step"].max(0)[0]

    def set_pointer(self, incremental_state, p_choose):
        curr_pointer = self.get_pointer(incremental_state)
        if len(curr_pointer) == 0:
            buffer = torch.zeros_like(p_choose)
        else:
            buffer = self.get_pointer(incremental_state)["step"]

        buffer += (p_choose < 0.5).type_as(buffer)

        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
            {"step": buffer},
        )


@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
    def __init__(self, args):
        MultiheadAttention.__init__(
            self,
            embed_dim=args.decoder_embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        MonotonicAttention.__init__(self, args)

        self.k_in_proj = {"monotonic": self.k_proj}
        self.q_in_proj = {"monotonic": self.q_proj}
        self.v_in_proj = {"output": self.v_proj}

    def input_projections(self, query, key, value, name):
        """
        Prepare inputs for multihead attention

        ============================================================
        Expected input size
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        value: src_len, bsz, embed_dim
        name: monotonic or soft
        """
        if query is not None:
            bsz = query.size(1)
            q = self.q_in_proj[name](query)
            q *= self.scaling
            q = (
                q.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            q = None

        if key is not None:
            bsz = key.size(1)
            k = self.k_in_proj[name](key)
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            k = None

        if value is not None:
            bsz = value.size(1)
            v = self.v_in_proj[name](value)
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            v = None

        return q, k, v

    def p_choose(self, query, key, key_padding_mask=None):
        """
        Calculating step wise prob for reading and writing
        1 to read, 0 to write

        ============================================================
        Expected input size
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        key_padding_mask: bsz, src_len
        """
        # prepare inputs
        q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")

        # attention energy
        attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)

        noise = 0
        if self.training:
            # add noise here to encourage discreteness
            # (note: noise_var is passed to torch.normal as the standard deviation)
            noise = (
                torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
                .type_as(attn_energy)
                .to(attn_energy.device)
            )

        p_choose = torch.sigmoid(attn_energy + noise)
        _, _, tgt_len, src_len = p_choose.size()

        # p_choose: bsz * self.num_heads, tgt_len, src_len
        return p_choose.view(-1, tgt_len, src_len)
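    # The pre-sigmoid Gaussian noise used during training follows Raffel et al.
    # (2017), "Online and Linear-Time Attention by Enforcing Monotonic
    # Alignments": it pushes the energies toward the saturated regions of the
    # sigmoid, so the soft p_choose used for the expected alignment behaves
    # more like the discrete read/write decisions taken at inference.
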
    def expected_attention(self, alpha, *args):
        """
        For MMA-H, beta = alpha
        """
        return alpha

    def v_proj_output(self, value):
        _, _, v_proj = self.input_projections(None, None, value, "output")
        return v_proj


@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard):
    def __init__(self, args):
        super().__init__(args)
        self.init_soft_attention()

    def init_soft_attention(self):
        self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
        self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_in_proj["soft"] = self.k_proj_soft
        self.q_in_proj["soft"] = self.q_proj_soft

        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(
                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
            nn.init.xavier_uniform_(
                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
        else:
            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)

    def expected_attention(
        self, alpha, query, key, value, key_padding_mask, incremental_state
    ):
        # Monotonic infinite lookback (MILk) attention: spread each head's
        # monotonic attention mass alpha over all source positions it has read
        bsz_x_num_heads, tgt_len, src_len = alpha.size()
        bsz = int(bsz_x_num_heads / self.num_heads)

        q, k, _ = self.input_projections(query, key, None, "soft")
        soft_energy = self.attn_energy(q, k, key_padding_mask)

        assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]

        soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            monotonic_cache = self._get_monotonic_buffer(incremental_state)
            monotonic_step = monotonic_cache["step"] + 1
            step_offset = 0
            if key_padding_mask is not None:
                if key_padding_mask[:, 0].any():
                    # left_pad_source = True:
                    step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
            monotonic_step += step_offset
            mask = lengths_to_mask(
                monotonic_step.view(-1), soft_energy.size(2), 1
            ).unsqueeze(1)

            soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
            beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
        else:
            # bsz * num_heads, tgt_len, src_len
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)

            if key_padding_mask is not None:
                if key_padding_mask.any():
                    exp_soft_energy_cumsum = (
                        exp_soft_energy_cumsum.view(
                            -1, self.num_heads, tgt_len, src_len
                        )
                        .masked_fill(
                            key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
                        )
                        .view(-1, tgt_len, src_len)
                    )

            inner_items = alpha / exp_soft_energy_cumsum

            beta = exp_soft_energy * torch.cumsum(
                inner_items.flip(dims=[2]), dim=2
            ).flip(dims=[2])

            beta = self.dropout_module(beta)

        assert not torch.isnan(beta).any(), "NaN detected in beta."

        return beta
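    # The flip / cumsum / flip above computes a suffix sum, implementing the
    # MILk expected attention from Arivazhagan et al. (2019), "Monotonic
    # Infinite Lookback Attention for Simultaneous Machine Translation":
    #   beta_{i,j} = exp(u_{i,j}) * sum_{k >= j} [alpha_{i,k} / sum_{l <= k} exp(u_{i,l})]
    # where u is the soft attention energy.

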
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback):
    def __init__(self, args):
        super().__init__(args)
        self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
        self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
        self.waitk_lagging = args.waitk_lagging
        assert (
            self.waitk_lagging > 0
        ), f"Lagging has to be larger than 0, got {self.waitk_lagging}."

    @staticmethod
    def add_args(parser):
        super(
            MonotonicMultiheadAttentionWaitk,
            MonotonicMultiheadAttentionWaitk,
        ).add_args(parser)

        parser.add_argument(
            "--waitk-lagging", type=int, required=True, help="Wait k lagging"
        )

    def p_choose(
        self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
    ):
        """
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        key_padding_mask: bsz, src_len
        """
        src_len, bsz, _ = key.size()
        tgt_len, bsz, _ = query.size()
        p_choose = query.new_ones(bsz, tgt_len, src_len)
        # Keep only the diagonal j == i + waitk_lagging - 1: target step i
        # writes exactly after reading the first i + k - 1 source tokens
        p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
        p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)

        if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
            # Left pad source
            # add -1 to the end
            p_choose = p_choose.masked_fill(
                key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
            )
            p_choose = convert_padding_direction(
                p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
            )
            p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)

            # remove -1
            p_choose[p_choose.eq(-1)] = 0

        # Extend to each head
        p_choose = (
            p_choose.contiguous()
            .unsqueeze(1)
            .expand(-1, self.num_heads, -1, -1)
            .contiguous()
            .view(-1, tgt_len, src_len)
        )

        return p_choose
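

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original fairseq
# file). It assumes fairseq 0.10.2 with this examples package importable and
# should be run as a module so the relative import above resolves, e.g.
#   python -m examples.simultaneous_translation.modules.monotonic_multihead_attention
# The args namespace below is hypothetical; real runs receive these values
# from the training argument parser.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        decoder_embed_dim=16,
        decoder_attention_heads=4,
        encoder_embed_dim=16,
        attention_dropout=0.0,
        attention_eps=1e-6,
        mass_preservation=True,
        noise_mean=0.0,
        noise_var=1.0,
        energy_bias=False,
        energy_bias_init=-2.0,
        waitk_lagging=3,
    )
    attn = MonotonicMultiheadAttentionWaitk(args)

    # Inputs follow fairseq's (length, batch, embed_dim) layout
    tgt_len, src_len, bsz = 5, 7, 2
    query = torch.rand(tgt_len, bsz, args.decoder_embed_dim)
    key = torch.rand(src_len, bsz, args.encoder_embed_dim)
    value = torch.rand(src_len, bsz, args.encoder_embed_dim)

    out, extra = attn(query, key, value)
    print(out.shape)             # expected: torch.Size([5, 2, 16])
    print(extra["alpha"].shape)  # expected: torch.Size([2, 4, 5, 7])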