""" Define the Attention Layer of the model.
"""

from __future__ import print_function, division

import torch

from torch.autograd import Variable
from torch.nn import Module
from torch.nn.parameter import Parameter


class Attention(Module):
| | """ |
| | Computes a weighted average of the different channels across timesteps. |
| | Uses 1 parameter pr. channel to compute the attention value for a single timestep. |
| | """ |
| |
|
| | def __init__(self, attention_size, return_attention=False): |
| | """ Initialize the attention layer |
| | |
| | # Arguments: |
| | attention_size: Size of the attention vector. |
| | return_attention: If true, output will include the weight for each input token |
| | used for the prediction |
| | |
| | """ |
| | super(Attention, self).__init__() |
| | self.return_attention = return_attention |
| | self.attention_size = attention_size |
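        # Trainable context vector with one weight per input channel (see the
        # class docstring); filled below with small Gaussian noise (std=0.05).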
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)

    def __repr__(self):
        s = '{name}({attention_size}, return_attention={return_attention})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
| | """ Forward pass. |
| | |
| | # Arguments: |
| | inputs (Torch.Variable): Tensor of input sequences |
| | input_lengths (torch.LongTensor): Lengths of the sequences |
| | |
| | # Return: |
| | Tuple with (representations and attentions if self.return_attention else None). |
| | """ |
        logits = inputs.matmul(self.attention_vector)
        unnorm_ai = (logits - logits.max()).exp()
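
        # Build a (batch, max_len) mask from input_lengths so that padded
        # timesteps receive zero attention weight.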
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0)
        mask = Variable((idxes < input_lengths.unsqueeze(1)).float())
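
        # Zero out scores at padded positions, then renormalize so the weights
        # of each sequence sum to 1.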
        masked_weights = unnorm_ai * mask
        att_sums = masked_weights.sum(dim=1, keepdim=True)
        attentions = masked_weights.div(att_sums)
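
        # Scale every timestep of the inputs by its attention weight.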
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))
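
        # Sum over the time dimension to get one fixed-size vector per sequence.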
        representations = weighted.sum(dim=1)
        return (representations, attentions if self.return_attention else None)
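
# Minimal usage sketch (illustrative; not part of the original module). The
# shapes are assumptions inferred from the forward pass above: `inputs` is
# (batch, max_len, attention_size) and `input_lengths` holds each sequence's
# true (unpadded) length.
if __name__ == '__main__':
    batch_size, max_len, hidden_size = 4, 10, 8
    layer = Attention(attention_size=hidden_size, return_attention=True)
    inputs = Variable(torch.randn(batch_size, max_len, hidden_size))
    input_lengths = torch.LongTensor([10, 7, 5, 2])
    representations, attentions = layer(inputs, input_lengths)
    print(representations.size())  # (4, 8): one vector per sequence
    print(attentions.size())       # (4, 10): one weight per timestep
    # Each row of `attentions` sums to 1 over its first input_lengths[i] steps.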