Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """Average Attention module.""" | |
| import torch | |
| import torch.nn as nn | |
| from .position_ffn import PositionwiseFeedForward | |
class AverageAttention(nn.Module):
    """
    Average Attention module from
    "Accelerating Neural Transformer via an Average Attention Network"
    :cite:`DBLP:journals/corr/abs-1805-00631`.

    Args:
       model_dim (int): size of the last dimension of the inputs; the
           gated output has the same size
       dropout (float): dropout parameter, only forwarded to the
           feed-forward layer built when ``aan_useffn`` is set
       aan_useffn (bool): if True, the cumulative average is passed
           through a ``PositionwiseFeedForward`` layer before gating
    """

    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
        # Initialise nn.Module first so that every later attribute
        # assignment goes through a fully set-up module.
        super(AverageAttention, self).__init__()
        self.model_dim = model_dim
        self.aan_useffn = aan_useffn
        if aan_useffn:
            self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
                                                         dropout)
        # One projection yields both gates (input and forget), hence the
        # doubled output size; it is split again in ``forward``.
        self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

    def cumulative_average_mask(self, batch_size, inputs_len, device):
        """
        Builds the mask to compute the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3

        Row ``i`` of the mask holds ``1 / (i + 1)`` in columns ``0..i`` and
        zeros elsewhere, so multiplying it with a sequence averages every
        prefix of that sequence.

        Args:
            batch_size (int): batch size
            inputs_len (int): length of the inputs
            device (torch.device): device on which the mask is created

        Returns:
            (FloatTensor):

            * A Tensor of shape ``(batch_size, input_len, input_len)``
              (an expanded view, the batch entries share storage)
        """
        triangle = torch.tril(torch.ones(inputs_len, inputs_len,
                                         dtype=torch.float, device=device))
        weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
            / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
        # Transposing turns the weights into a column, so row i of the
        # lower triangle is scaled by 1 / (i + 1).
        mask = triangle * weights.transpose(0, 1)
        return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)

    def cumulative_average(self, inputs, mask_or_step,
                           layer_cache=None, step=None):
        """
        Computes the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)

        Args:
            inputs (FloatTensor): sequence to average
                ``(batch_size, input_len, dimension)``
            mask_or_step: if cache is set, this is assumed
                to be the current step of the
                dynamic decoding. Otherwise, it is the mask matrix
                used to compute the cumulative average.
            layer_cache: a dictionary containing the cumulative average
                of the previous step under the key ``"prev_g"``; it is
                updated in place with the new average.
            step: ignored, the step is always read from ``mask_or_step``;
                kept so existing callers keep working.

        Returns:
            a tensor of the same shape and type as ``inputs``.
        """
        if layer_cache is None:
            mask = mask_or_step
            return torch.matmul(mask.to(inputs.dtype), inputs)

        # Incremental decoding: fold the new input into the running mean,
        # new_mean = (x + step * old_mean) / (step + 1).
        step = mask_or_step
        average_attention = (inputs + step *
                             layer_cache["prev_g"]) / (step + 1)
        layer_cache["prev_g"] = average_attention
        return average_attention

    def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """
        Args:
            inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
            mask: unused by this module; kept in the signature so
                existing callers keep working
            layer_cache: optional dictionary holding the running average
                under ``"prev_g"``; when given, the average is computed
                incrementally instead of with the cumulative mask
            step: current decoding step, required when ``layer_cache``
                is given

        Returns:
            (FloatTensor, FloatTensor):

            * gating_outputs ``(batch_size, input_len, model_dim)``
            * average_outputs average attention
                ``(batch_size, input_len, model_dim)``
        """
        batch_size = inputs.size(0)
        inputs_len = inputs.size(1)
        average_outputs = self.cumulative_average(
            inputs, self.cumulative_average_mask(batch_size,
                                                 inputs_len, inputs.device)
            if layer_cache is None else step, layer_cache=layer_cache)
        if self.aan_useffn:
            average_outputs = self.average_layer(average_outputs)
        gating_outputs = self.gating_layer(torch.cat((inputs,
                                                      average_outputs), -1))
        # First half of the projection gates the raw inputs, second half
        # gates the averaged context.
        input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs
        return gating_outputs, average_outputs