#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""
import torch
import torch.nn as nn

from . import layers


# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------


class RnnDocReader(nn.Module):
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args, normalize=True):
        super(RnnDocReader, self).__init__()
        # Store config
        self.args = args

        # Word embeddings (+1 for padding)
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)

        # Projection for attention weighted question
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim
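        # (e.g. with embedding_dim=300, num_features=4, and use_qemb on,
        # doc_input_size = 300 + 300 + 4 = 604 -- illustrative values only)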

        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # RNN question encoder
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Output sizes of rnn encoders
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers
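        # (the factor of 2 reflects the bidirectional RNN; e.g. hidden_size=128
        # and doc_layers=3 with concat give 2 * 128 * 3 = 768 -- illustrative)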

        # Question merging
        if args.question_merge not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' % args.question_merge)
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input.append(x2_weighted_emb)

        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Encode document with RNN
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)
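        # doc_hiddens: [batch, len_d, doc_hidden_size]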

        # Encode question with RNN + merge hiddens
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.args.question_merge == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)
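        # question_hidden: [batch, question_hidden_size]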

        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
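

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal smoke
# test with hypothetical hyperparameter values. Because of the relative import
# above, run it as a module, e.g. `python -m drqa.reader.rnn_reader`.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        vocab_size=1000, embedding_dim=300, num_features=4, use_qemb=True,
        hidden_size=128, doc_layers=3, question_layers=3, dropout_rnn=0.4,
        dropout_rnn_output=True, concat_rnn_layers=True, rnn_type='lstm',
        rnn_padding=False, question_merge='self_attn', dropout_emb=0.4,
    )
    model = RnnDocReader(args)

    # Random inputs shaped per the forward() docstring; mask value 1 marks pad.
    batch, len_d, len_q = 2, 50, 10
    x1 = torch.randint(1, args.vocab_size, (batch, len_d))
    x1_f = torch.rand(batch, len_d, args.num_features)
    x1_mask = torch.zeros(batch, len_d, dtype=torch.bool)
    x2 = torch.randint(1, args.vocab_size, (batch, len_q))
    x2_mask = torch.zeros(batch, len_q, dtype=torch.bool)

    start_scores, end_scores = model(x1, x1_f, x1_mask, x2, x2_mask)
    print(start_scores.shape, end_scores.shape)  # both [batch, len_d]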