"""
The full encoder-decoder model, built on top of the base seq2seq modules.
"""
import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common import utils
from stanza.models.common.seq2seq_modules import LSTMAttention
from stanza.models.common.beam import Beam
from stanza.models.common.seq2seq_constant import UNK_ID
logger = logging.getLogger('stanza')
class Seq2SeqModel(nn.Module):
"""
A complete encoder-decoder model, with optional attention.
A parent class which uses the contextual_embedding (such as a charlm)
can take advantage of unsaved_modules when saving, so that the
contextual embedding is not written out with the model.
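
A minimal construction sketch (the hyperparameter values below are made
up for illustration; real values come from the calling model's config):

    args = {'vocab_size': 100, 'emb_dim': 50, 'hidden_dim': 100,
            'num_layers': 1, 'dropout': 0.0, 'max_dec_len': 50,
            'attn_type': 'soft'}
    model = Seq2SeqModel(args)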
"""
def __init__(self, args, emb_matrix=None, contextual_embedding=None):
super().__init__()
self.unsaved_modules = []
self.vocab_size = args['vocab_size']
self.emb_dim = args['emb_dim']
self.hidden_dim = args['hidden_dim']
self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1
self.emb_dropout = args.get('emb_dropout', 0.0)
self.dropout = args['dropout']
self.pad_token = constant.PAD_ID
self.max_dec_len = args['max_dec_len']
self.top = args.get('top', 1e10)
self.args = args
self.emb_matrix = emb_matrix
self.add_unsaved_module("contextual_embedding", contextual_embedding)
logger.debug("Building an attentional Seq2Seq model...")
logger.debug("Using a Bi-LSTM encoder")
self.num_directions = 2
self.enc_hidden_dim = self.hidden_dim // 2
self.dec_hidden_dim = self.hidden_dim
self.use_pos = args.get('pos', False)
self.pos_dim = args.get('pos_dim', 0)
self.pos_vocab_size = args.get('pos_vocab_size', 0)
self.pos_dropout = args.get('pos_dropout', 0)
self.edit = args.get('edit', False)
self.num_edit = args.get('num_edit', 0)
self.copy = args.get('copy', False)
self.emb_drop = nn.Dropout(self.emb_dropout)
self.drop = nn.Dropout(self.dropout)
        self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, padding_idx=self.pad_token)
self.input_dim = self.emb_dim
if self.contextual_embedding is not None:
self.input_dim += self.contextual_embedding.hidden_dim()
        self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers,
                               bidirectional=True, batch_first=True,
                               dropout=self.dropout if self.nlayers > 1 else 0)
        self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim,
                                     batch_first=True, attn_type=self.args['attn_type'])
self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
if self.use_pos and self.pos_dim > 0:
logger.debug("Using POS in encoder")
            self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, padding_idx=self.pad_token)
self.pos_drop = nn.Dropout(self.pos_dropout)
if self.edit:
            edit_hidden = self.hidden_dim // 2
self.edit_clf = nn.Sequential(
nn.Linear(self.hidden_dim, edit_hidden),
nn.ReLU(),
nn.Linear(edit_hidden, self.num_edit))
if self.copy:
self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)
SOS_tensor = torch.LongTensor([constant.SOS_ID])
self.register_buffer('SOS_tensor', SOS_tensor)
self.init_weights()
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)
def init_weights(self):
# initialize embeddings
init_range = constant.EMB_INIT_RANGE
if self.emb_matrix is not None:
if isinstance(self.emb_matrix, np.ndarray):
self.emb_matrix = torch.from_numpy(self.emb_matrix)
assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \
"Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
self.embedding.weight.data.copy_(self.emb_matrix)
else:
self.embedding.weight.data.uniform_(-init_range, init_range)
# decide finetuning
if self.top <= 0:
logger.debug("Do not finetune embedding layer.")
self.embedding.weight.requires_grad = False
elif self.top < self.vocab_size:
logger.debug("Finetune top {} embeddings.".format(self.top))
self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top))
else:
logger.debug("Finetune all embeddings.")
# initialize pos embeddings
if self.use_pos:
self.pos_embedding.weight.data.uniform_(-init_range, init_range)
def zero_state(self, inputs):
batch_size = inputs.size(0)
device = self.SOS_tensor.device
h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
return h0, c0
def encode(self, enc_inputs, lens):
""" Encode source sequence. """
h0, c0 = self.zero_state(enc_inputs)
packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0))
h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True)
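        # hn / cn have shape [num_layers * 2, batch, enc_hidden_dim]; join the
        # top layer's two directional final states into a single hidden_dim
        # vector to initialize the decoder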
hn = torch.cat((hn[-1], hn[-2]), 1)
cn = torch.cat((cn[-1], cn[-2]), 1)
return h_in, (hn, cn)
def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
""" Decode a step, based on context encoding and source context states."""
dec_hidden = (hn, cn)
decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
if self.copy:
h_out, dec_hidden, log_attn = decoder_output
else:
h_out, dec_hidden = decoder_output
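        # flatten (batch, time) so the vocab projection is a single matmul,
        # then restore the original shape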
h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
decoder_logits = self.dec2vocab(h_out_reshape)
decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
log_probs = self.get_log_prob(decoder_logits)
if self.copy:
copy_logit = self.copy_gate(h_out)
if self.use_pos:
                # the first encoder position holds the UPOS embedding, which
                # can never be copied, so drop it from the attention
log_attn = log_attn[:, :, 1:]
# renormalize
log_attn = torch.log_softmax(log_attn, -1)
# calculate copy probability for each word in the vocab
            log_copy_prob = F.logsigmoid(copy_logit) + log_attn
# scatter logsumexp
mx = log_copy_prob.max(-1, keepdim=True)[0]
log_copy_prob = log_copy_prob - mx
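            # standard logsumexp trick: subtract the per-row max before
            # exponentiating, i.e. logsumexp(x) = m + log(sum(exp(x - m))),
            # so scatter_add can run safely in probability space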
# here we make space in the log probs for vocab items
# which might be copied from the encoder side, but which
# were not known at training time
# note that such an item cannot possibly be predicted by
# the model as a raw output token
# however, the copy gate might score high on copying a
# previously unknown vocab item
copy_prob = torch.exp(log_copy_prob)
copied_vocab_shape = list(log_probs.size())
            max_src_id = torch.max(src).item()
            if max_src_id >= copied_vocab_shape[-1]:
                copied_vocab_shape[-1] = max_src_id + 1
copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
            # fill in the copy tensor with the copy probability of each source
            # character; positions that receive no copy mass stay at zero and
            # are masked to a large negative log probability just below
copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
zero_mask = (copied_vocab_prob == 0)
log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)
# combine with normal vocab probability
            # log(1 - sigmoid(x)) == logsigmoid(-x), which avoids overflow in exp
            log_nocopy_prob = F.logsigmoid(-copy_logit)
if log_probs.shape[-1] < copied_vocab_shape[-1]:
# for previously unknown vocab items which are in the encoder,
# we reuse the UNK_ID prediction
# this gives a baseline number which we can combine with
# the copy gate prediction
# technically this makes log_probs no longer represent
# a probability distribution when looking at unknown vocab
# this is probably not a serious problem
                # an example of this arises in the Lemmatizer: consider a
                # plural English word containing the character "ã" instead
                # of "a".  if "ã" is not in the training data, the seq2seq
                # model would ordinarily be unable to output it, and so it
                # would have no way to depluralize "ãntennae" -> "ãntenna".
                # however, if we temporarily add "ã" to the encoder vocab
                # and let the copy gate accept that letter, the model will
                # happily copy that particular vocab item.  this lets the
                # Lemmatizer produce "ã" directly, instead of producing UNK
                # and then going back to the input text to figure out which
                # UNK it intended to produce
new_log_probs = log_probs.new_zeros(copied_vocab_shape)
new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
log_probs = new_log_probs
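            # mix the two distributions in log space with logsumexp:
            # p(w) = gate * p_copy(w) + (1 - gate) * p_vocab(w)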
log_probs = log_probs + log_nocopy_prob
log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)
if never_decode_unk:
log_probs[:, :, UNK_ID] = float("-inf")
return log_probs, dec_hidden
def embed(self, src, src_mask, pos, raw):
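        # ids outside the trained vocabulary (possible when the copy mechanism
        # extends the vocab with unseen source characters) embed as UNK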
embed_src = src.clone()
embed_src[embed_src >= self.vocab_size] = UNK_ID
enc_inputs = self.emb_drop(self.embedding(embed_src))
batch_size = enc_inputs.size(0)
if self.use_pos:
assert pos is not None, "Missing POS input for seq2seq lemmatizer."
pos_inputs = self.pos_drop(self.pos_embedding(pos))
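            # prepend the POS embedding as an extra first position of the source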
enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
pos_src_mask = src_mask.new_zeros([batch_size, 1])
src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
if raw is not None and self.contextual_embedding is not None:
raw_inputs = self.contextual_embedding(raw)
if self.use_pos:
raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1)
enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)
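        # src_mask is nonzero at padding positions, so counting the PAD_ID (0)
        # entries recovers each source's true length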
src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
return enc_inputs, batch_size, src_lens, src_mask
def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
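        """ Teacher-forced training pass: encode src, decode the gold tgt_in
        sequence in parallel, and return per-step log probs plus optional
        edit classifier logits. """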
# prepare for encoder/decoder
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
dec_inputs = self.emb_drop(self.embedding(tgt_in))
log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src)
return log_probs, edit_logits
def get_log_prob(self, logits):
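        """ Log-softmax over the vocabulary dimension, preserving a
        (batch, time, vocab) shape when given 3d logits. """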
logits_reshape = logits.view(-1, self.vocab_size)
log_probs = F.log_softmax(logits_reshape, dim=1)
if logits.dim() == 2:
return log_probs
return log_probs.view(logits.size(0), logits.size(1), logits.size(2))
def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
""" Predict with greedy decoding. """
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
# greedy decode by step
dec_inputs = self.embedding(self.SOS_tensor)
dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))
done = [False for _ in range(batch_size)]
total_done = 0
max_len = 0
output_seqs = [[] for _ in range(batch_size)]
while total_done < batch_size and max_len < self.max_dec_len:
log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            assert log_probs.size(1) == 1, "Greedy decoding must produce exactly one step of output."
_, preds = log_probs.squeeze(1).max(1, keepdim=True)
            # if an unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
dec_inputs = preds.clone()
dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
dec_inputs = self.embedding(dec_inputs) # update decoder inputs
max_len += 1
for i in range(batch_size):
if not done[i]:
token = preds.data[i][0].item()
if token == constant.EOS_ID:
done[i] = True
total_done += 1
else:
output_seqs[i].append(token)
return output_seqs, edit_logits
def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
""" Predict with beam search. """
if beam_size == 1:
return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# (1) encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
# (2) set up beam
with torch.no_grad():
h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search
src_mask = src_mask.repeat(beam_size, 1)
# repeat decoder hidden states
hn = hn.data.repeat(beam_size, 1)
cn = cn.data.repeat(beam_size, 1)
device = self.SOS_tensor.device
beam = [Beam(beam_size, device) for _ in range(batch_size)]
def update_state(states, idx, positions, beam_size):
""" Select the states according to back pointers. """
for e in states:
br, d = e.size()
s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
s.data.copy_(s.data.index_select(0, positions))
# (3) main loop
for i in range(self.max_dec_len):
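            # gather each beam's most recent token; the transpose keeps the
            # beam-major layout used by the repeated encoder states, flattened
            # to [beam * batch, 1]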
dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
            # if an unlearned character is predicted via the copy mechanism,
            # use the UNK embedding for it
dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
dec_inputs = self.embedding(dec_inputs)
log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]
# advance each beam
done = []
for b in range(batch_size):
is_done = beam[b].advance(log_probs.data[b])
if is_done:
done += [b]
# update beam state
update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)
if len(done) == batch_size:
break
        # back-trace to recover the best hypothesis for each batch element
all_hyp, all_scores = [], []
for b in range(batch_size):
scores, ks = beam[b].sort_best()
all_scores += [scores[0]]
k = ks[0]
hyp = beam[b].get_hyp(k)
hyp = utils.prune_hyp(hyp)
hyp = [i.item() for i in hyp]
all_hyp += [hyp]
return all_hyp, edit_logits
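
if __name__ == '__main__':
    # A minimal smoke-test sketch, not part of the original training pipeline.
    # The hyperparameter values below are made up for illustration; real
    # values come from the calling model's config.
    test_args = {'vocab_size': 40, 'emb_dim': 16, 'hidden_dim': 32,
                 'num_layers': 1, 'dropout': 0.0, 'max_dec_len': 10,
                 'attn_type': 'soft'}
    model = Seq2SeqModel(test_args)
    # ids 0..3 are reserved for PAD/UNK/SOS/EOS in seq2seq_constant
    src = torch.randint(4, 40, (2, 5))
    src_mask = torch.zeros(2, 5, dtype=torch.bool)  # no padding
    tgt_in = torch.randint(4, 40, (2, 6))
    log_probs, edit_logits = model(src, src_mask, tgt_in)
    print(log_probs.shape)  # torch.Size([2, 6, 40])
    output_seqs, _ = model.predict(src, src_mask, beam_size=1)
    print(output_seqs)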