Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # Copyright 2017-present, Facebook, Inc. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Definitions of model layers/NN modules""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ------------------------------------------------------------------------------ | |
| # Modules | |
| # ------------------------------------------------------------------------------ | |
class StackedBRNN(nn.Module):
    """Stacked Bi-directional RNNs.

    Differs from standard PyTorch library in that it has the option to save
    and concat the hidden states between layers. (i.e. the output hidden size
    for each sequence input is num_layers * hidden_size).
    """

    def __init__(self, input_size, hidden_size, num_layers,
                 dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
                 concat_layers=False, padding=False):
        """
        Args:
            input_size: dim of each per-timestep input vector.
            hidden_size: dim of each direction's hidden state.
            num_layers: number of stacked bidirectional layers.
            dropout_rate: dropout applied to each layer's input.
            dropout_output: if True, also apply dropout to the final output.
            rnn_type: RNN constructor (nn.LSTM / nn.GRU / nn.RNN).
            concat_layers: if True, concat every layer's output instead of
                returning only the top layer's.
            padding: if True, always use the (slower) pack/pad path even in
                training.
        """
        super(StackedBRNN, self).__init__()
        self.padding = padding
        self.dropout_output = dropout_output
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.concat_layers = concat_layers
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            # Layers after the first consume the bidirectional output
            # (2 * hidden_size) of the layer below.
            input_size = input_size if i == 0 else 2 * hidden_size
            self.rnns.append(rnn_type(input_size, hidden_size,
                                      num_layers=1,
                                      bidirectional=True))

    def forward(self, x, x_mask):
        """Encode either padded or non-padded sequences.

        Can choose to either handle or ignore variable length sequences.
        Always handle padding in eval.

        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            x_encoded: batch * len * hdim_encoded
        """
        if x_mask.data.sum() == 0:
            # No padding necessary.
            output = self._forward_unpadded(x, x_mask)
        elif self.padding or not self.training:
            # Pad if we care or if its during eval.
            output = self._forward_padded(x, x_mask)
        else:
            # We don't care.
            output = self._forward_unpadded(x, x_mask)
        return output.contiguous()

    def _forward_unpadded(self, x, x_mask):
        """Faster encoding that ignores any padding."""
        # Transpose batch and sequence dims (RNNs here are seq-first).
        x = x.transpose(0, 1)
        # Encode all layers
        outputs = [x]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]
            # Apply dropout to hidden input
            if self.dropout_rate > 0:
                rnn_input = F.dropout(rnn_input,
                                      p=self.dropout_rate,
                                      training=self.training)
            # Forward
            rnn_output = self.rnns[i](rnn_input)[0]
            outputs.append(rnn_output)
        # Concat hidden layers
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]
        # Transpose back
        output = output.transpose(0, 1)
        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output

    def _forward_padded(self, x, x_mask):
        """Slower (significantly), but more precise, encoding that handles
        padding.
        """
        # Compute sorted sequence lengths.
        # BUGFIX: no .squeeze() here -- sum(1) already yields shape (batch,),
        # and squeezing collapsed a batch of size 1 to a 0-dim tensor,
        # breaking the sort/index_select below.
        lengths = x_mask.data.eq(0).long().sum(1)
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = lengths[idx_sort].tolist()
        # Sort x by decreasing length (required by pack_padded_sequence).
        x = x.index_select(0, idx_sort)
        # Transpose batch and sequence dims
        x = x.transpose(0, 1)
        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)
        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]
            # Apply dropout to input
            if self.dropout_rate > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_rate,
                                          training=self.training)
                rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
                                                        rnn_input.batch_sizes)
            outputs.append(self.rnns[i](rnn_input)[0])
        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]
        # Concat hidden layers or take final
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]
        # Transpose and unsort
        output = output.transpose(0, 1)
        output = output.index_select(0, idx_unsort)
        # Pad up to original batch sequence length (pad_packed_sequence only
        # pads to the longest sequence actually present in the batch).
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)
        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output
class SeqAttnMatch(nn.Module):
    """Given sequences X and Y, match sequence Y to each element in X.

    * o_i = sum(alpha_j * y_j) for i in X
    * alpha_j = softmax(y_j * x_i)
    """

    def __init__(self, input_size, identity=False):
        super(SeqAttnMatch, self).__init__()
        # identity=True compares the raw vectors with no learned projection.
        self.linear = None if identity else nn.Linear(input_size, input_size)

    def forward(self, x, y, y_mask):
        """
        Args:
            x: batch * len1 * hdim
            y: batch * len2 * hdim
            y_mask: batch * len2 (1 for padding, 0 for true)
        Output:
            matched_seq: batch * len1 * hdim
        """
        # Project both sequences through the shared linear + ReLU, if any.
        if self.linear is None:
            x_proj, y_proj = x, y
        else:
            x_proj = F.relu(self.linear(x.view(-1, x.size(2))).view(x.size()))
            y_proj = F.relu(self.linear(y.view(-1, y.size(2))).view(y.size()))

        # Pairwise scores: batch * len1 * len2.
        scores = x_proj.bmm(y_proj.transpose(2, 1))

        # Mask out padded y positions before normalizing.
        expanded_mask = y_mask.unsqueeze(1).expand(scores.size())
        scores.data.masked_fill_(expanded_mask.data, -float('inf'))

        # Softmax over len2, then average y with the attention weights.
        alpha = F.softmax(scores.view(-1, y.size(1)), dim=-1)
        alpha = alpha.view(-1, x.size(1), y.size(1))
        return alpha.bmm(y)
class BilinearSeqAttn(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t y:

    * o_i = softmax(x_i'Wy) for x_i in X.

    Optionally don't normalize output weights.
    """

    def __init__(self, x_size, y_size, identity=False, normalize=True):
        super(BilinearSeqAttn, self).__init__()
        self.normalize = normalize
        # identity=True: plain dot product, no learned transformation.
        self.linear = nn.Linear(y_size, x_size) if not identity else None

    def forward(self, x, y, x_mask):
        """
        Args:
            x: batch * len * hdim1
            y: batch * hdim2
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha = batch * len
        """
        projected = y if self.linear is None else self.linear(y)
        # Bilinear scores x_i'Wy: batch * len.
        scores = x.bmm(projected.unsqueeze(2)).squeeze(2)
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        if not self.normalize:
            # exp(-inf) = 0, so masked positions drop out naturally.
            return scores.exp()
        if self.training:
            # In training we output log-softmax for NLL.
            return F.log_softmax(scores, dim=-1)
        # ...Otherwise 0-1 probabilities.
        return F.softmax(scores, dim=-1)
class LinearSeqAttn(nn.Module):
    """Self attention over a sequence:

    * o_i = softmax(Wx_i) for x_i in X.
    """

    def __init__(self, input_size):
        super(LinearSeqAttn, self).__init__()
        # Single scalar score per timestep.
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha: batch * len
        """
        batch, seq_len = x.size(0), x.size(1)
        # Score every timestep, then reshape back to batch * len.
        scores = self.linear(x.view(-1, x.size(-1))).view(batch, seq_len)
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        return F.softmax(scores, dim=-1)
| # ------------------------------------------------------------------------------ | |
| # Functional | |
| # ------------------------------------------------------------------------------ | |
def uniform_weights(x, x_mask):
    """Return uniform weights over non-masked x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        x_mask: batch * len (1 for padding, 0 for true)
    Output:
        alpha: batch * len, uniform over unmasked positions, 0 on padding.
    """
    alpha = torch.ones(x.size(0), x.size(1))
    if x.data.is_cuda:
        alpha = alpha.cuda()
    # Zero out padded positions, then renormalize each row.
    alpha = alpha * x_mask.eq(0).float()
    # BUGFIX: sum(1) drops the dim on modern PyTorch, so the old
    # .expand(alpha.size()) was invalid; keepdim + broadcasting divides
    # each row by its own count of unmasked positions.
    alpha = alpha / alpha.sum(1, keepdim=True)
    return alpha
def weighted_avg(x, weights):
    """Return a weighted average of x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        weights: batch * len, sum(dim = 1) = 1
    Output:
        x_avg: batch * hdim
    """
    # (batch, 1, len) @ (batch, len, hdim) -> (batch, 1, hdim), then drop
    # the singleton middle dim.
    w = weights.unsqueeze(1)
    return w.bmm(x).squeeze(1)