| """ Multi Step Attention for CNN """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| SCALE_WEIGHT = 0.5**0.5 | |


def seq_linear(linear, x):
    """Apply ``linear`` to a 4-D ``(batch, hidden_size, length, 1)`` tensor."""
    batch, hidden_size, length, _ = x.size()
    h = linear(
        torch.transpose(x, 1, 2).contiguous().view(batch * length, hidden_size)
    )
    return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2)
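
# Shape sketch for seq_linear (hypothetical sizes, for illustration only):
#   linear = nn.Linear(8, 8)
#   x = torch.randn(2, 8, 5, 1)     # (batch, hidden_size, length, 1)
#   seq_linear(linear, x).shape     # -> torch.Size([2, 8, 5, 1])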


class ConvMultiStepAttention(nn.Module):
    """
    Conv attention takes a key matrix, a value matrix and a query vector.
    The attention weights are computed from the key matrix and the query
    vector, then used in a weighted sum over the value matrix. The same
    operation is repeated after each decoder conv layer.
    """

    def __init__(self, input_size):
        super(ConvMultiStepAttention, self).__init__()
        self.linear_in = nn.Linear(input_size, input_size)
        self.mask = None

    def apply_mask(self, mask):
        """Set the mask of padded positions applied to the scores in ``forward``."""
        self.mask = mask

    def forward(
        self, base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine
    ):
        """
        Args:
            base_target_emb: target embedding tensor,
                ``(batch, channel, height, width)``
            input_from_dec: output of the decoder conv layer,
                ``(batch, channel, height, width)``
            encoder_out_top: key matrix used to compute the attention
                weights; the top output of the encoder conv stack
            encoder_out_combine: value matrix for the attention-weighted
                sum; the combination of the source embeddings and the top
                encoder output
        """
        # Combine the decoder conv output with the target embedding
        # (residual-style sum, rescaled by SCALE_WEIGHT) to form the query.
        preatt = seq_linear(self.linear_in, input_from_dec)
        target = (base_target_emb + preatt) * SCALE_WEIGHT
        target = torch.squeeze(target, 3)        # (batch, channel, height)
        target = torch.transpose(target, 1, 2)   # (batch, height, channel)

        # Attention scores of each query position against the encoder keys.
        pre_attn = torch.bmm(target, encoder_out_top)
        if self.mask is not None:
            # Out-of-place masked_fill instead of mutating ``.data``, which
            # would silently bypass autograd.
            pre_attn = pre_attn.masked_fill(self.mask, -float("inf"))
        attn = F.softmax(pre_attn, dim=2)

        # Weighted sum over the encoder values, then restore the
        # (batch, channel, height, 1) layout.
        context_output = torch.bmm(attn, torch.transpose(encoder_out_combine, 1, 2))
        context_output = torch.transpose(torch.unsqueeze(context_output, 3), 1, 2)
        return context_output, attn
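

if __name__ == "__main__":
    # Minimal smoke test: a sketch under assumed shapes, not part of the
    # original module. channel doubles as the attention input_size.
    batch, channel, tgt_len, src_len = 2, 8, 5, 7
    attn_layer = ConvMultiStepAttention(channel)
    base_target_emb = torch.randn(batch, channel, tgt_len, 1)
    input_from_dec = torch.randn(batch, channel, tgt_len, 1)
    encoder_out_top = torch.randn(batch, channel, src_len)
    encoder_out_combine = torch.randn(batch, channel, src_len)

    # Optional padding mask (hypothetical: hide the last source position).
    mask = torch.zeros(batch, tgt_len, src_len, dtype=torch.bool)
    mask[:, :, -1] = True
    attn_layer.apply_mask(mask)

    context, attn = attn_layer(
        base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine
    )
    print(context.shape)  # torch.Size([2, 8, 5, 1])
    print(attn.shape)     # torch.Size([2, 5, 7]); masked column is all zeros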