""" The full encoder-decoder model, built on top of the base seq2seq modules. """ import logging import torch from torch import nn import torch.nn.functional as F import numpy as np import stanza.models.common.seq2seq_constant as constant from stanza.models.common import utils from stanza.models.common.seq2seq_modules import LSTMAttention from stanza.models.common.beam import Beam from stanza.models.common.seq2seq_constant import UNK_ID logger = logging.getLogger('stanza') class Seq2SeqModel(nn.Module): """ A complete encoder-decoder model, with optional attention. A parent class which makes use of the contextual_embedding (such as a charlm) can make use of unsaved_modules when saving. """ def __init__(self, args, emb_matrix=None, contextual_embedding=None): super().__init__() self.unsaved_modules = [] self.vocab_size = args['vocab_size'] self.emb_dim = args['emb_dim'] self.hidden_dim = args['hidden_dim'] self.nlayers = args['num_layers'] # encoder layers, decoder layers = 1 self.emb_dropout = args.get('emb_dropout', 0.0) self.dropout = args['dropout'] self.pad_token = constant.PAD_ID self.max_dec_len = args['max_dec_len'] self.top = args.get('top', 1e10) self.args = args self.emb_matrix = emb_matrix self.add_unsaved_module("contextual_embedding", contextual_embedding) logger.debug("Building an attentional Seq2Seq model...") logger.debug("Using a Bi-LSTM encoder") self.num_directions = 2 self.enc_hidden_dim = self.hidden_dim // 2 self.dec_hidden_dim = self.hidden_dim self.use_pos = args.get('pos', False) self.pos_dim = args.get('pos_dim', 0) self.pos_vocab_size = args.get('pos_vocab_size', 0) self.pos_dropout = args.get('pos_dropout', 0) self.edit = args.get('edit', False) self.num_edit = args.get('num_edit', 0) self.copy = args.get('copy', False) self.emb_drop = nn.Dropout(self.emb_dropout) self.drop = nn.Dropout(self.dropout) self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token) self.input_dim = self.emb_dim if self.contextual_embedding is not None: self.input_dim += self.contextual_embedding.hidden_dim() self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \ bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0) self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \ batch_first=True, attn_type=self.args['attn_type']) self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size) if self.use_pos and self.pos_dim > 0: logger.debug("Using POS in encoder") self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token) self.pos_drop = nn.Dropout(self.pos_dropout) if self.edit: edit_hidden = self.hidden_dim//2 self.edit_clf = nn.Sequential( nn.Linear(self.hidden_dim, edit_hidden), nn.ReLU(), nn.Linear(edit_hidden, self.num_edit)) if self.copy: self.copy_gate = nn.Linear(self.dec_hidden_dim, 1) SOS_tensor = torch.LongTensor([constant.SOS_ID]) self.register_buffer('SOS_tensor', SOS_tensor) self.init_weights() def add_unsaved_module(self, name, module): self.unsaved_modules += [name] setattr(self, name, module) def init_weights(self): # initialize embeddings init_range = constant.EMB_INIT_RANGE if self.emb_matrix is not None: if isinstance(self.emb_matrix, np.ndarray): self.emb_matrix = torch.from_numpy(self.emb_matrix) assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \ "Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim) self.embedding.weight.data.copy_(self.emb_matrix) else: self.embedding.weight.data.uniform_(-init_range, init_range) # decide finetuning if self.top <= 0: logger.debug("Do not finetune embedding layer.") self.embedding.weight.requires_grad = False elif self.top < self.vocab_size: logger.debug("Finetune top {} embeddings.".format(self.top)) self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top)) else: logger.debug("Finetune all embeddings.") # initialize pos embeddings if self.use_pos: self.pos_embedding.weight.data.uniform_(-init_range, init_range) def zero_state(self, inputs): batch_size = inputs.size(0) device = self.SOS_tensor.device h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device) c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device) return h0, c0 def encode(self, enc_inputs, lens): """ Encode source sequence. """ h0, c0 = self.zero_state(enc_inputs) packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True) packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0)) h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True) hn = torch.cat((hn[-1], hn[-2]), 1) cn = torch.cat((cn[-1], cn[-2]), 1) return h_in, (hn, cn) def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False): """ Decode a step, based on context encoding and source context states.""" dec_hidden = (hn, cn) decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy) if self.copy: h_out, dec_hidden, log_attn = decoder_output else: h_out, dec_hidden = decoder_output h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1) decoder_logits = self.dec2vocab(h_out_reshape) decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1) log_probs = self.get_log_prob(decoder_logits) if self.copy: copy_logit = self.copy_gate(h_out) if self.use_pos: # can't copy the UPOS log_attn = log_attn[:, :, 1:] # renormalize log_attn = torch.log_softmax(log_attn, -1) # calculate copy probability for each word in the vocab log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn # scatter logsumexp mx = log_copy_prob.max(-1, keepdim=True)[0] log_copy_prob = log_copy_prob - mx # here we make space in the log probs for vocab items # which might be copied from the encoder side, but which # were not known at training time # note that such an item cannot possibly be predicted by # the model as a raw output token # however, the copy gate might score high on copying a # previously unknown vocab item copy_prob = torch.exp(log_copy_prob) copied_vocab_shape = list(log_probs.size()) if torch.max(src) >= copied_vocab_shape[-1]: copied_vocab_shape[-1] = torch.max(src) + 1 copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape) scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1)) # fill in the copy tensor with the copy probs of each character # the rest of the copy tensor will be filled with -largenumber copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob) zero_mask = (copied_vocab_prob == 0) log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12) # combine with normal vocab probability log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit)) if log_probs.shape[-1] < copied_vocab_shape[-1]: # for previously unknown vocab items which are in the encoder, # we reuse the UNK_ID prediction # this gives a baseline number which we can combine with # the copy gate prediction # technically this makes log_probs no longer represent # a probability distribution when looking at unknown vocab # this is probably not a serious problem # an example of this usage is in the Lemmatizer, such as a # plural word in English with the character "ã" in it instead of "a" # if "ã" is not known in the training data, the lemmatizer would # ordinarily be unable to output it, and thus the seq2seq model # would have no chance to depluralize "ãntennae" -> "ãntenna" # however, if we temporarily add "ã" to the encoder vocab, # then let the copy gate accept that letter, we find the Lemmatizer # seq2seq model will want to copy that particular vocab item # this allows the Lemmatizer to produce "ã" instead of requiring # that it produces UNK, then going back to the input text to # figure out which UNK it intended to produce new_log_probs = log_probs.new_zeros(copied_vocab_shape) new_log_probs[:, :, :log_probs.shape[-1]] = log_probs new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2) log_probs = new_log_probs log_probs = log_probs + log_nocopy_prob log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0) if never_decode_unk: log_probs[:, :, UNK_ID] = float("-inf") return log_probs, dec_hidden def embed(self, src, src_mask, pos, raw): embed_src = src.clone() embed_src[embed_src >= self.vocab_size] = UNK_ID enc_inputs = self.emb_drop(self.embedding(embed_src)) batch_size = enc_inputs.size(0) if self.use_pos: assert pos is not None, "Missing POS input for seq2seq lemmatizer." pos_inputs = self.pos_drop(self.pos_embedding(pos)) enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1) pos_src_mask = src_mask.new_zeros([batch_size, 1]) src_mask = torch.cat([pos_src_mask, src_mask], dim=1) if raw is not None and self.contextual_embedding is not None: raw_inputs = self.contextual_embedding(raw) if self.use_pos: raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2])) raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1) enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2) src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1)) return enc_inputs, batch_size, src_lens, src_mask def forward(self, src, src_mask, tgt_in, pos=None, raw=None): # prepare for encoder/decoder enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) if self.edit: edit_logits = self.edit_clf(hn) else: edit_logits = None dec_inputs = self.emb_drop(self.embedding(tgt_in)) log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src) return log_probs, edit_logits def get_log_prob(self, logits): logits_reshape = logits.view(-1, self.vocab_size) log_probs = F.log_softmax(logits_reshape, dim=1) if logits.dim() == 2: return log_probs return log_probs.view(logits.size(0), logits.size(1), logits.size(2)) def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False): """ Predict with greedy decoding. """ enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) if self.edit: edit_logits = self.edit_clf(hn) else: edit_logits = None # greedy decode by step dec_inputs = self.embedding(self.SOS_tensor) dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1)) done = [False for _ in range(batch_size)] total_done = 0 max_len = 0 output_seqs = [[] for _ in range(batch_size)] while total_done < batch_size and max_len < self.max_dec_len: log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk) assert log_probs.size(1) == 1, "Output must have 1-step of output." _, preds = log_probs.squeeze(1).max(1, keepdim=True) # if a unlearned character is predicted via the copy mechanism, # use the UNK embedding for it dec_inputs = preds.clone() dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID dec_inputs = self.embedding(dec_inputs) # update decoder inputs max_len += 1 for i in range(batch_size): if not done[i]: token = preds.data[i][0].item() if token == constant.EOS_ID: done[i] = True total_done += 1 else: output_seqs[i].append(token) return output_seqs, edit_logits def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False): """ Predict with beam search. """ if beam_size == 1: return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk) enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw) # (1) encode source h_in, (hn, cn) = self.encode(enc_inputs, src_lens) if self.edit: edit_logits = self.edit_clf(hn) else: edit_logits = None # (2) set up beam with torch.no_grad(): h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search src_mask = src_mask.repeat(beam_size, 1) # repeat decoder hidden states hn = hn.data.repeat(beam_size, 1) cn = cn.data.repeat(beam_size, 1) device = self.SOS_tensor.device beam = [Beam(beam_size, device) for _ in range(batch_size)] def update_state(states, idx, positions, beam_size): """ Select the states according to back pointers. """ for e in states: br, d = e.size() s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx] s.data.copy_(s.data.index_select(0, positions)) # (3) main loop for i in range(self.max_dec_len): dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1) # if a unlearned character is predicted via the copy mechanism, # use the UNK embedding for it dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID dec_inputs = self.embedding(dec_inputs) log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk) log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V] # advance each beam done = [] for b in range(batch_size): is_done = beam[b].advance(log_probs.data[b]) if is_done: done += [b] # update beam state update_state((hn, cn), b, beam[b].get_current_origin(), beam_size) if len(done) == batch_size: break # back trace and find hypothesis all_hyp, all_scores = [], [] for b in range(batch_size): scores, ks = beam[b].sort_best() all_scores += [scores[0]] k = ks[0] hyp = beam[b].get_hyp(k) hyp = utils.prune_hyp(hyp) hyp = [i.item() for i in hyp] all_hyp += [hyp] return all_hyp, edit_logits