# pgps-demo/model/decoder/transformer.py
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from model.module import *  # provides Score_Multi
from utils.utils import sequence_mask


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout_rate=0.2):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: [B, S, d_model] with S <= max_len
        pe: [1, max_len, d_model], sliced to the first S positions
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
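
# Note (added for clarity): pe[0, pos, 2i] = sin(pos / 10000^(2i / d_model)) and
# pe[0, pos, 2i + 1] = cos(pos / 10000^(2i / d_model)), the standard sinusoidal
# positional encoding; forward() slices the buffer to the input length and adds
# it to the token embeddings before dropout.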


class TransformerDecoder(nn.Module):
    def __init__(self, cfg, tgt_lang,
                 d_model=256, nhead=8, num_decoder_layers=4, dim_feedforward=1024, dropout=0.2):
        super(TransformerDecoder, self).__init__()
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self.position_dec = PositionalEncoding(d_model=d_model)
        self.score = Score_Multi(cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # Target ids below var_start are ordinary vocabulary tokens and are
        # embedded by embedding_tgt; ids at or above var_start refer to problem
        # variables and reuse encoder states (see get_var_encoder_outputs).
        self.var_start = tgt_lang.var_start
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.cfg = cfg
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]

    def _reset_parameters(self):
        """
        Initialize all matrix-shaped parameters of the model with Xavier uniform.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
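    # Note (added): the Xavier re-initialization above also overwrites the
    # padding_idx row of embedding_tgt; padding_idx only blocks gradient
    # updates, so after this call the padding embedding is no longer zero.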

    def get_square_subsequent_mask(self, sz):
        """
        Generate a square causal mask for the sequence. Masked positions are
        filled with True; unmasked positions are filled with False.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 0).transpose(0, 1)
        return mask.cuda()
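    # Example (added): sz == 3 yields
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # so each position may attend only to itself and earlier positions.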

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index=expand_var_pos)
        return var_embeddings
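    # Worked example (added): with var_pos == [[0, 2]], gather selects
    # encoder_outputs[0, 0] and encoder_outputs[0, 2], i.e. each variable
    # token is represented by the encoder state at the source position
    # where that variable occurs.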

    def forward(self, memory, len_src, tgt, len_tgt, var_pos, len_var, is_train=False):
        """
        memory: B x S1 x H
        len_src: B
        tgt: B x S2
        len_tgt: B
        var_pos: B x S3 (var_size)
        len_var: B
        """
        # Cache the per-batch state shared by the train and test branches:
        # variable embeddings gathered from the encoder, the candidate mask
        # over (fixed vocabulary + per-sample variables), and the memory
        # padding mask.
        self.embedding_var = self.get_var_encoder_outputs(memory, var_pos)  # B x S3 x H
        self.candi_mask = sequence_mask(self.var_start + len_var)  # B x (no_var_size + var_size)
        self.memory_key_padding_mask = ~sequence_mask(len_src)  # B x S1
        if is_train:
            return self._forward_train(memory, tgt, len_tgt)
        else:
            return self._forward_test(memory)

    def _forward_train(self, memory, tgt, len_tgt):
        # masks
        tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
        tgt_key_padding_mask = ~sequence_mask(len_tgt)
        # Target embedding: vocabulary tokens (< var_start) use embedding_tgt,
        # variable tokens (>= var_start) use the gathered encoder states.
        tgt_novar_id = torch.clamp(tgt, max=self.var_start - 1)  # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id)  # B x S2 x H
        tgt_var_id = torch.clamp(tgt - self.var_start, min=0)  # B x S2
        var_embeddings = self.embedding_var.gather(
            dim=1,
            index=tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # B x S2 x H
        choose_mask = ((tgt < self.var_start).unsqueeze(2)
                       .repeat(1, 1, self.cfg.decoder_embedding_size))
        emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings)  # B x S2 x H
        # positional encoding
        emb_tgt = self.position_dec(emb_tgt)
        output = self.decoder(  # B x S2 x H after the final permute
            emb_tgt.permute(1, 0, 2),
            memory.permute(1, 0, 2),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=self.memory_key_padding_mask,
        ).permute(1, 0, 2)
        # candidate weight embedding: fixed vocabulary rows + variable rows
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(len(len_tgt), 1))  # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1)  # B x (no_var_size + var_size) x H
        candi_score = self.score(  # B x S2 x (no_var_size + var_size)
            output,
            embedding_weight_all,
            self.candi_mask,
        )
        # Drop the final step: the prediction at position t scores token t + 1,
        # so the returned scores align with the shifted targets tgt[:, 1:].
        return candi_score[:, :-1, :].clone()

    def _forward_test(self, memory):
        exp_outputs = []
        for sample_id in range(memory.size(0)):
            # Expand this sample's encoder state to beam_size rows.
            rem_size = self.cfg.beam_size
            memory_item = memory[sample_id:sample_id + 1].repeat(rem_size, 1, 1)  # beam_size x S1 x H
            memory_key_padding_mask = self.memory_key_padding_mask[sample_id:sample_id + 1].repeat(rem_size, 1)  # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id + 1].repeat(rem_size, 1, 1)  # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1))  # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1)  # beam_size x (no_var_size + var_size) x H
            candi_mask = self.candi_mask[sample_id:sample_id + 1].repeat(rem_size, 1)  # beam_size x (no_var_size + var_size)
            candi_exp_output = []
            candi_score_output = []
            tgt = torch.LongTensor([[self.sos_id]] * rem_size).cuda()  # rem_size x 1
            len_tgt = torch.LongTensor([1] * rem_size).cuda()  # rem_size
            current_score = torch.FloatTensor([[0.0]] * rem_size).cuda()  # rem_size x 1
            current_exp_list = [[self.sos_id]] * rem_size
            for i in range(self.cfg.max_output_len):
                # masks
                tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
                tgt_key_padding_mask = ~sequence_mask(len_tgt)
                # input embedding (same vocabulary/variable mixing as training)
                tgt_novar_id = torch.clamp(tgt, max=self.var_start - 1)  # rem_size x S
                novar_embedding = self.embedding_tgt(tgt_novar_id)  # rem_size x S x H
                tgt_var_id = torch.clamp(tgt - self.var_start, min=0)  # rem_size x S
                var_embeddings = embedding_var[:rem_size].gather(
                    dim=1,
                    index=tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # rem_size x S x H
                choose_mask = (tgt < self.var_start).unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)  # rem_size x S x H
                emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings)  # rem_size x S x H
                # positional encoding
                emb_tgt = self.position_dec(emb_tgt)
                output = self.decoder(  # rem_size x S x H after the final permute
                    emb_tgt.permute(1, 0, 2),
                    memory_item[:rem_size].permute(1, 0, 2),
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask[:rem_size],
                ).permute(1, 0, 2)
                candi_score = self.score(  # rem_size x S x (no_var_size + var_size)
                    output,
                    embedding_weight_all[:rem_size],
                    candi_mask[:rem_size],
                )
                if i == 0:
                    # All beams start from the same [SOS], so keep a single row
                    # to avoid expanding identical hypotheses.
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1) + current_score  # rem_size x (no_var_size + var_size)
                topv, topi = new_score.view(-1).topk(rem_size)
                exp_list = []
                score_list = []
                for tv, ti in zip(topv, topi):
                    idx = ti.item()
                    x = idx // candi_score.size(-1)  # beam index
                    y = idx % candi_score.size(-1)  # token id
                    if y != self.eos_id:
                        exp_list.append(current_exp_list[x] + [y])
                        # Collect scores only for surviving beams, keeping them
                        # aligned with exp_list (the original sliced topv
                        # blindly, which could pair a finished beam's score
                        # with a surviving expression).
                        score_list.append(float(tv))
                    else:
                        candi_exp_output.append(current_exp_list[x][1:])
                        candi_score_output.append(float(tv))
                if len(exp_list) == 0:
                    break
                tgt = torch.LongTensor(exp_list).cuda()  # rem_size x S
                len_tgt = torch.LongTensor([len(item) for item in exp_list]).cuda()  # rem_size
                current_exp_list = exp_list
                rem_size = len(exp_list)
                current_score = torch.FloatTensor(score_list).unsqueeze(1).cuda()  # rem_size x 1
            # Return completed hypotheses sorted by score, best first.
            if len(candi_exp_output) > 0:
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output))
            else:
                exp_outputs.append([])
        return exp_outputs
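

# Minimal smoke-test sketch (added; not part of the original repo). The `cfg`
# and `tgt_lang` stand-ins below carry only the attributes this file actually
# reads; the real objects come from the PGPS training configuration. It assumes
# a CUDA device, since the module hardcodes .cuda() calls, and it assumes
# decoder_embedding_size == d_model so the mixed embeddings fit the decoder and
# that sequence_mask/Score_Multi behave as used above.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(decoder_hidden_size=256, decoder_embedding_size=256,
                          beam_size=2, max_output_len=8)
    tgt_lang = SimpleNamespace(var_start=8, word2index={"[SOS]": 1, "[EOS]": 2})
    decoder = TransformerDecoder(cfg, tgt_lang).cuda()

    B, S1, S3 = 2, 12, 3
    memory = torch.randn(B, S1, 256).cuda()         # fake encoder outputs
    len_src = torch.tensor([12, 10]).cuda()         # source lengths
    var_pos = torch.randint(0, S1, (B, S3)).cuda()  # variable positions in source
    len_var = torch.tensor([3, 2]).cuda()           # number of variables per sample
    print(decoder(memory, len_src, None, None, var_pos, len_var, is_train=False))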