| | from functools import partial |
| |
|
| | import torch |
| | from torch import Tensor |
| | import math |
| | import torch.nn.functional as F |
| |
|
| | from . import register_monotonic_attention |
| | from .monotonic_multihead_attention import ( |
| | MonotonicAttention, |
| | MonotonicInfiniteLookbackAttention, |
| | WaitKAttention |
| | ) |
| | from typing import Dict, Optional |
| |
|
| |
|
def fixed_pooling_monotonic_attention(monotonic_attention):
    """Class-decorator factory for fixed-stride "pre-decision" pooling.

    Wraps a monotonic attention class so that read/write decisions are
    computed on a pooled version of the encoder states (one decision per
    ``--fixed-pre-decision-ratio`` encoder steps) and then expanded back
    to the full source length.

    Returns a ``functools.partial`` that, applied to a placeholder class,
    builds a subclass of ``monotonic_attention`` renamed after the
    placeholder (see the registered classes at the bottom of this file).
    """

    def create_model(monotonic_attention, klass):
        class FixedStrideMonotonicAttention(monotonic_attention):
            def __init__(self, args):
                # Pre-set attributes that the parent __init__ may read;
                # the parent is free to overwrite them afterwards.
                self.waitk_lagging = 0
                self.num_heads = 0
                self.noise_mean = 0.0
                self.noise_var = 0.0
                super().__init__(args)
                self.pre_decision_type = args.fixed_pre_decision_type
                self.pre_decision_ratio = args.fixed_pre_decision_ratio
                self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold
                # A ratio of 1 would make pooling a no-op; plain monotonic
                # attention should be used instead in that case.
                assert self.pre_decision_ratio > 1

                if args.fixed_pre_decision_type == "average":
                    # Mean over each window of `pre_decision_ratio` steps.
                    # ceil_mode=True keeps a pooled step for a partial
                    # trailing window instead of dropping it.
                    self.pooling_layer = torch.nn.AvgPool1d(
                        kernel_size=self.pre_decision_ratio,
                        stride=self.pre_decision_ratio,
                        ceil_mode=True,
                    )
                elif args.fixed_pre_decision_type == "last":

                    def last(key):
                        # Keep the last frame of each stride window along
                        # dim 2. If the length is not a multiple of the
                        # ratio, also append the final frame so the tail
                        # is represented (mirrors AvgPool1d's ceil_mode).
                        if key.size(2) < self.pre_decision_ratio:
                            return key
                        else:
                            k = key[
                                :,
                                :,
                                self.pre_decision_ratio - 1:: self.pre_decision_ratio,
                            ].contiguous()
                            if key.size(-1) % self.pre_decision_ratio != 0:
                                k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous()
                            return k

                    self.pooling_layer = last
                else:
                    raise NotImplementedError

            @staticmethod
            def add_args(parser):
                """Add fixed pre-decision options on top of the parent's."""
                # Two-argument super(): the zero-argument form is not
                # available inside a staticmethod, so name the class
                # explicitly to reach the parent's add_args.
                super(
                    FixedStrideMonotonicAttention, FixedStrideMonotonicAttention
                ).add_args(parser)
                parser.add_argument(
                    "--fixed-pre-decision-ratio",
                    type=int,
                    required=True,
                    help=(
                        "Ratio for the fixed pre-decision,"
                        "indicating how many encoder steps will start"
                        "simultaneous decision making process."
                    ),
                )
                parser.add_argument(
                    "--fixed-pre-decision-type",
                    default="average",
                    choices=["average", "last"],
                    help="Pooling type",
                )
                parser.add_argument(
                    "--fixed-pre-decision-pad-threshold",
                    type=float,
                    default=0.3,
                    help="If a part of the sequence has pad"
                    ",the threshold the pooled part is a pad.",
                )

            def insert_zeros(self, x: Tensor) -> Tensor:
                """Upsample pooled probabilities back to encoder resolution.

                ``x`` has shape ``(bsz * num_heads, tgt_len, pooled_src_len)``.
                The transposed convolution uses a kernel of length ``stride``
                that is 1 only at its last tap, so each pooled probability
                lands on the LAST position of its stride window and every
                other position is zero. Output source length is
                ``pooled_src_len * stride``.
                """
                bsz_num_heads, tgt_len, src_len = x.size()
                stride = self.pre_decision_ratio
                # (1, 1, 1) ones kernel, left-padded with stride-1 zeros.
                weight = F.pad(torch.ones(1, 1, 1).to(x), (stride - 1, 0))
                x_upsample = F.conv_transpose1d(
                    x.view(-1, src_len).unsqueeze(1),
                    weight,
                    stride=stride,
                    padding=0,
                )
                return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1)

            def p_choose(
                self,
                query: Optional[Tensor],
                key: Optional[Tensor],
                key_padding_mask: Optional[Tensor] = None,
                incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
            ):
                """Compute selection probabilities at full encoder resolution.

                Pools ``key`` (and the padding mask) along time, runs the
                parent's ``p_choose_from_qk`` on the pooled sequence, then
                upsamples the result back to ``src_len``.

                ``query``/``key`` are time-first: the code below reads
                ``src_len = key.size(0)`` and ``batch_size = query.size(1)``.
                Returns a tensor asserted to be
                ``(batch_size * num_heads, tgt_len, src_len)``.
                """
                assert key is not None
                assert query is not None
                src_len = key.size(0)
                tgt_len = query.size(0)
                batch_size = query.size(1)

                # Pool over the time axis: move time to the last dim for
                # the 1-D pooling layer, then move it back.
                key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2)

                if key_padding_mask is not None:
                    # A pooled step counts as pad only when the fraction of
                    # padded frames in its window exceeds the threshold.
                    key_padding_mask_pool = (
                        self.pooling_layer(key_padding_mask.unsqueeze(0).float())
                        .squeeze(0)
                        .gt(self.pre_decision_pad_threshold)
                    )
                    # The first pooled step is never treated as pad.
                    key_padding_mask_pool[:, 0] = 0
                else:
                    key_padding_mask_pool = None

                if incremental_state is not None:
                    # During incremental decoding, ceil-mode pooling emits a
                    # pooled step for a still-incomplete trailing window;
                    # drop it until the window has filled up.
                    if (
                        max(1, math.floor(key.size(0) / self.pre_decision_ratio))
                    ) < key_pool.size(0):
                        key_pool = key_pool[:-1]
                        if key_padding_mask_pool is not None:
                            key_padding_mask_pool = key_padding_mask_pool[:-1]

                p_choose_pooled = self.p_choose_from_qk(
                    query,
                    key_pool,
                    key_padding_mask_pool,
                    incremental_state=incremental_state,
                )

                # Expand pooled decisions back to one slot per encoder step.
                p_choose = self.insert_zeros(p_choose_pooled)

                if p_choose.size(-1) < src_len:
                    # Upsampled length fell short of src_len; pad with zeros.
                    p_choose = torch.cat(
                        [
                            p_choose,
                            torch.zeros(
                                p_choose.size(0),
                                tgt_len,
                                src_len - p_choose.size(-1)
                            ).to(p_choose)
                        ],
                        dim=2
                    )
                else:
                    # Upsampled length overshot src_len; truncate, then
                    # restore the last pooled probability (its slot may
                    # have been cut off by the truncation).
                    p_choose = p_choose[:, :, :src_len]
                    p_choose[:, :, -1] = p_choose_pooled[:, :, -1]

                assert list(p_choose.size()) == [
                    batch_size * self.num_heads,
                    tgt_len,
                    src_len,
                ]

                return p_choose

        # Present the generated class under the placeholder's name so the
        # registry and checkpoints see a stable, meaningful class name.
        FixedStrideMonotonicAttention.__name__ = klass.__name__
        return FixedStrideMonotonicAttention

    return partial(create_model, monotonic_attention)
| |
|
| |
|
@register_monotonic_attention("waitk_fixed_pre_decision")
@fixed_pooling_monotonic_attention(WaitKAttention)
class WaitKAttentionFixedStride:
    """Placeholder: replaced by a fixed-stride pooled ``WaitKAttention``
    subclass built by ``fixed_pooling_monotonic_attention`` and registered
    as ``waitk_fixed_pre_decision``."""
    pass
| |
|
| |
|
@register_monotonic_attention("hard_aligned_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicAttention)
class MonotonicAttentionFixedStride:
    """Placeholder: replaced by a fixed-stride pooled ``MonotonicAttention``
    subclass built by ``fixed_pooling_monotonic_attention`` and registered
    as ``hard_aligned_fixed_pre_decision``."""
    pass
| |
|
| |
|
@register_monotonic_attention("infinite_lookback_fixed_pre_decision")
@fixed_pooling_monotonic_attention(MonotonicInfiniteLookbackAttention)
class MonotonicInfiniteLookbackAttentionFixedStride:
    """Placeholder: replaced by a fixed-stride pooled
    ``MonotonicInfiniteLookbackAttention`` subclass built by
    ``fixed_pooling_monotonic_attention`` and registered as
    ``infinite_lookback_fixed_pre_decision``."""
    pass
| |
|