import torch
import torch.nn as nn
from model.module import *
from utils import *
from torch.nn import functional as F


class DecoderRNN(nn.Module):
    """GRU-based decoder that scores both a fixed target vocabulary and
    problem-specific variable tokens whose "embeddings" are gathered from
    the encoder outputs at decode time.

    Token-id layout (from ``tgt_lang``): ids in ``[0, var_start)`` are
    non-variable tokens (spe_num + midvar_num + const_num + op_num); ids at
    or above ``var_start`` refer to per-problem variables.

    NOTE(review): the source of this file was whitespace-mangled and is
    truncated partway through ``_forward_train``; the code below is a
    re-indentation only — no tokens were changed. ``_forward_test`` is
    referenced but not visible in this chunk.
    """

    def __init__(self, cfg, tgt_lang):
        super(DecoderRNN, self).__init__()
        # token location
        self.var_start = tgt_lang.var_start  # spe_num + midvar_num + const_num + op_num
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # Embedding table covers only the non-variable part of the vocabulary;
        # variable tokens are represented by encoder states (see forward()).
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        # GRU input is [context ; token embedding], hence the summed input_size.
        self.gru = nn.GRU(
            input_size=cfg.decoder_hidden_size + cfg.decoder_embedding_size,
            hidden_size=cfg.decoder_hidden_size,
            num_layers=cfg.decoder_layers,
            dropout=cfg.dropout_rate,
            batch_first=True,
        )
        # Choose attention model (Attn/Score come from model.module via star import)
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size + cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # 1 x var_start row of candidate ids for the non-variable tokens
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).to(self.device)
        self.cfg = cfg

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """Gather the encoder hidden states located at the variable positions.

        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        # Expand positions along the hidden dimension so gather selects
        # complete H-sized vectors rather than scalars.
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index=expand_var_pos)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_src, var_pos, len_var,
                text_tgt=None, is_train=False):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: layer_num x B x H
            len_src: B
            text_tgt: B x S2
            var_pos: B x S3
            len_var: B
        Return:
            training: logits, B x S x (no_var_size+var_size)
            testing: exp_id, B x candi_size(beam_size) x exp_len
        """
        # Variable "embeddings" are the encoder states at the variable positions.
        # NOTE(review): stored on self, so forward() is stateful across calls.
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_pos)  # B x S3 x H
        # sequence_mask comes from utils (star import) — presumably a
        # length-to-boolean-mask helper; verify against its definition.
        self.src_mask = sequence_mask(len_src)  # B x S1
        self.candi_mask = sequence_mask(self.var_start + len_var)  # B x (no_var_size + var_size)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_tgt)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_tgt):
        # Teacher-forced training pass; relies on self.embedding_var /
        # self.candi_mask having been set by forward().
        all_seq_outputs = []
        batch_size = encoder_outputs.size(0)
        # initial hidden input of RNN
        rnn_hidden = problem_output
        # input embedding
        # Non-variable part: clamp variable ids down so the table lookup stays in range.
        tgt_novar_id = torch.clamp(text_tgt, max=self.var_start-1)  # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id)  # B x S2 x H
        # Variable part: shift ids so they index into the gathered encoder states.
        tgt_var_id = torch.clamp(text_tgt-self.var_start, min=0)  # B x S2
        var_embeddings = self.embedding_var.gather(dim=1, index=
            tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # B x S2 x H
        # NOTE(review): SOURCE is truncated at this point — the remainder of
        # _forward_train (and _forward_test) is not visible in this chunk.
        choose_mask = (text_tgt