import torch import torch.nn as nn from utils.utils import sequence_mask from model.module import * from torch.nn import functional as F import math class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000, dropout_rate=0.2): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer("pe", pe) self.dropout = nn.Dropout(dropout_rate) def forward(self, x): """ x: [B, max_len, d_model] pe: [1, max_len, d_model] """ x = x + self.pe[:, : x.size(1)].requires_grad_(False) return self.dropout(x) class TransformerDecoder(nn.Module): def __init__(self, cfg, tgt_lang, \ d_model=256, nhead=8, num_decoder_layers=4, dim_feedforward=1024, dropout=0.2): super(TransformerDecoder, self).__init__() decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) decoder_norm = nn.LayerNorm(d_model) self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) self.position_dec = PositionalEncoding(d_model=d_model) self.score = Score_Multi(cfg.decoder_hidden_size, cfg.decoder_embedding_size) self.var_start = tgt_lang.var_start self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0) self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda() self._reset_parameters() self.d_model = d_model self.nhead = nhead self.cfg = cfg self.sos_id = tgt_lang.word2index["[SOS]"] self.eos_id = tgt_lang.word2index["[EOS]"] def _reset_parameters(self): """ Initiate parameters in the transformer model. """ for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def get_square_subsequent_mask(self, sz): """ Generate a square mask for the sequence. The masked positions are filled with True. Unmasked positions are filled with False. """ mask = (torch.triu(torch.ones(sz, sz)) == 0).transpose(0, 1) return mask.cuda() def get_var_encoder_outputs(self, encoder_outputs, var_pos): """ Arguments: encoder_outputs: B x S1 x H var_pos: B x S3 Returns: var_embeddings: B x S3 x H """ hidden_size = encoder_outputs.size(-1) expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size) var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_pos) return var_embeddings def forward(self, memory, len_src, tgt, len_tgt, var_pos, len_var, is_train=False): ''' memory: B x S1 x H len_src: B tgt: B x S2 len_tgt: B var_pos: B x S3(var_size) len_var: B ''' self.embedding_var = self.get_var_encoder_outputs(memory, var_pos) # B x S3 x H self.candi_mask = sequence_mask(self.var_start + len_var) # B x (no_var_size + var_size) self.memory_key_padding_mask = ~sequence_mask(len_src) # B x S1 if is_train: return self._forward_train(memory, tgt, len_tgt) else: return self._forward_test(memory) def _forward_train(self, memory, tgt, len_tgt): # mask tgt_mask = self.get_square_subsequent_mask(tgt.size(-1)) tgt_key_padding_mask = ~sequence_mask(len_tgt) # emb_tgt tgt_novar_id = torch.clamp(tgt, max=self.var_start-1) # B x S2 novar_embedding = self.embedding_tgt(tgt_novar_id) # B x S2 x H tgt_var_id = torch.clamp(tgt-self.var_start, min=0) # B x S2 var_embeddings = self.embedding_var.gather(dim=1, index = \ tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)) # B x S2 x H choose_mask = (tgt0: _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True)) exp_outputs.append(list(candi_exp_output)) else: exp_outputs.append([]) return exp_outputs