| from typing import Optional, Dict |
| from torch import Tensor |
| import torch |
|
|
|
|
def waitk(
    query, key, waitk_lagging: int, num_heads: int, key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
    """
    Hard wait-k read/write schedule as a p_choose matrix.

    Target step t attends (p_choose == 1) to source position
    t + waitk_lagging - 1; steps past the end of the source keep
    attending to the last source position. For wait-k = 3, src_len = 6:

        [0, 0, 1, 0, 0, 0]
        [0, 0, 0, 1, 0, 0]
        [0, 0, 0, 0, 1, 0]
        [0, 0, 0, 0, 0, 1]
        [0, 0, 0, 0, 0, 1]

    Args:
        query: (tgt_len, bsz, embed_dim) decoder states — assumed shape,
            only ``.size()`` / dtype / device are used.
        key: (src_len, bsz, embed_dim) encoder states.
        waitk_lagging: the "k" in wait-k; number of source tokens read
            before the first write.
        num_heads: attention heads; the schedule is replicated per head.
        key_padding_mask: (bsz, src_len) bool mask, True at pad positions.
            Only *left* padding shifts the schedule.
        incremental_state: fairseq-style incremental state; when present,
            ``incremental_state["steps"]["tgt"]`` is the 1-based decoding
            step and only the current step's row is returned.

    Returns:
        (bsz * num_heads, tgt_len, src_len) 0/1 tensor (tgt_len is 1 in
        incremental mode), same dtype/device as ``query``.
    """
    if incremental_state is not None:
        # During inference the query length is always 1; the actual target
        # step count is tracked in the incremental state.
        tgt_len = incremental_state["steps"]["tgt"]
        assert tgt_len is not None
        tgt_len = int(tgt_len)
    else:
        tgt_len, bsz, _ = query.size()

    max_src_len, bsz, _ = key.size()

    if max_src_len < waitk_lagging:
        # Not enough source read yet: pure "read" policy, all-zero p_choose.
        if incremental_state is not None:
            tgt_len = 1
        return query.new_zeros(
            bsz * num_heads, tgt_len, max_src_len
        )

    # Flattened (tgt_len * max_src_len) index of the source position each
    # target step attends to: t * max_src_len + (t + waitk_lagging - 1).
    activate_indices_offset = (
        (
            torch.arange(tgt_len) * (max_src_len + 1)
            + waitk_lagging - 1
        )
        .unsqueeze(0)
        .expand(bsz, tgt_len)
        .to(query)
        .long()
    )

    if key_padding_mask is not None:
        if key_padding_mask[:, 0].any():
            # Left padding: shift every index past the padded prefix.
            activate_indices_offset += (
                key_padding_mask.sum(dim=1, keepdim=True)
            )

    # Clamp each step's index into its OWN row so steps past the wait-k
    # diagonal attend to the last source position. (A single flat clamp
    # here would alias those steps onto an earlier row, leaving their own
    # rows all-zero.)
    row_end = (
        ((torch.arange(tgt_len) + 1) * max_src_len - 1)
        .unsqueeze(0)
        .to(activate_indices_offset)
    )
    activate_indices_offset = torch.min(
        activate_indices_offset.clamp(min=0), row_end
    )

    p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query)

    p_choose = p_choose.scatter(
        1,
        activate_indices_offset,
        1.0
    ).view(bsz, tgt_len, max_src_len)

    if incremental_state is not None:
        # Only the current (last) decoding step matters at inference time.
        p_choose = p_choose[:, -1:]
        tgt_len = 1

    # Replicate the schedule across heads:
    # (bsz, tgt, src) -> (bsz * num_heads, tgt, src)
    p_choose = (
        p_choose.contiguous()
        .unsqueeze(1)
        .expand(-1, num_heads, -1, -1)
        .contiguous()
        .view(-1, tgt_len, max_src_len)
    )

    return p_choose
|
|
|
|
def hard_aligned(q_proj: Optional[Tensor], k_proj: Optional[Tensor], attn_energy, noise_mean: float = 0.0, noise_var: float = 0.0, training: bool = True):
    """
    Compute step-wise read/write probabilities from attention energies.

    1 means read another source token, 0 means write a target token.
    During training, Gaussian noise is added to the energies before the
    sigmoid to push the probabilities toward discrete 0/1 decisions.

    Args:
        q_proj, k_proj: unused here; kept for interface compatibility
            with callers that pass projected queries/keys.
        attn_energy: 4-D energy tensor; last two dims are
            (tgt_len, src_len) — presumably (bsz, num_heads, tgt, src),
            verify against caller.
        noise_mean: mean of the training-time noise.
        noise_var: spread of the training-time noise.
            NOTE(review): this is passed to ``torch.normal`` as the
            standard deviation, not the variance, despite its name.
        training: whether to inject noise.

    Returns:
        (bsz * num_heads, tgt_len, src_len) probabilities in (0, 1).
    """
    if not training:
        perturbed = attn_energy
    else:
        # Perturb the energies so the sigmoid saturates toward 0/1.
        perturbation = (
            torch.normal(noise_mean, noise_var, attn_energy.size())
            .type_as(attn_energy)
            .to(attn_energy.device)
        )
        perturbed = attn_energy + perturbation

    p_choose = torch.sigmoid(perturbed)
    # Unpacking enforces a 4-D input, exactly as before.
    _, _, tgt_len, src_len = p_choose.size()

    # Fold batch and head dims together for downstream monotonic attention.
    return p_choose.view(-1, tgt_len, src_len)
|
|