"""
The full encoder-decoder model, built on top of the base seq2seq modules.
"""
import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import stanza.models.common.seq2seq_constant as constant
from stanza.models.common import utils
from stanza.models.common.seq2seq_modules import LSTMAttention
from stanza.models.common.beam import Beam
from stanza.models.common.seq2seq_constant import UNK_ID
logger = logging.getLogger('stanza')
class Seq2SeqModel(nn.Module):
"""
A complete encoder-decoder model, with optional attention.
    A model which makes use of a contextual_embedding (such as a charlm)
    can use unsaved_modules to skip saving that module.
"""
def __init__(self, args, emb_matrix=None, contextual_embedding=None):
super().__init__()
self.unsaved_modules = []
self.vocab_size = args['vocab_size']
self.emb_dim = args['emb_dim']
self.hidden_dim = args['hidden_dim']
        self.nlayers = args['num_layers'] # number of encoder layers; the decoder always has 1 layer
self.emb_dropout = args.get('emb_dropout', 0.0)
self.dropout = args['dropout']
self.pad_token = constant.PAD_ID
self.max_dec_len = args['max_dec_len']
self.top = args.get('top', 1e10)
self.args = args
self.emb_matrix = emb_matrix
self.add_unsaved_module("contextual_embedding", contextual_embedding)
logger.debug("Building an attentional Seq2Seq model...")
logger.debug("Using a Bi-LSTM encoder")
self.num_directions = 2
self.enc_hidden_dim = self.hidden_dim // 2
self.dec_hidden_dim = self.hidden_dim
self.use_pos = args.get('pos', False)
self.pos_dim = args.get('pos_dim', 0)
self.pos_vocab_size = args.get('pos_vocab_size', 0)
self.pos_dropout = args.get('pos_dropout', 0)
self.edit = args.get('edit', False)
self.num_edit = args.get('num_edit', 0)
self.copy = args.get('copy', False)
self.emb_drop = nn.Dropout(self.emb_dropout)
self.drop = nn.Dropout(self.dropout)
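        # the third positional argument is padding_idx: the PAD row of the
        # embedding receives no gradient updates during training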
self.embedding = nn.Embedding(self.vocab_size, self.emb_dim, self.pad_token)
self.input_dim = self.emb_dim
if self.contextual_embedding is not None:
self.input_dim += self.contextual_embedding.hidden_dim()
self.encoder = nn.LSTM(self.input_dim, self.enc_hidden_dim, self.nlayers, \
bidirectional=True, batch_first=True, dropout=self.dropout if self.nlayers > 1 else 0)
self.decoder = LSTMAttention(self.emb_dim, self.dec_hidden_dim, \
batch_first=True, attn_type=self.args['attn_type'])
self.dec2vocab = nn.Linear(self.dec_hidden_dim, self.vocab_size)
if self.use_pos and self.pos_dim > 0:
logger.debug("Using POS in encoder")
self.pos_embedding = nn.Embedding(self.pos_vocab_size, self.pos_dim, self.pad_token)
self.pos_drop = nn.Dropout(self.pos_dropout)
if self.edit:
edit_hidden = self.hidden_dim//2
self.edit_clf = nn.Sequential(
nn.Linear(self.hidden_dim, edit_hidden),
nn.ReLU(),
nn.Linear(edit_hidden, self.num_edit))
if self.copy:
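            # sigmoid(copy_gate(h)) gives, per decoder state, the probability
            # of copying a token from the source rather than generating one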
self.copy_gate = nn.Linear(self.dec_hidden_dim, 1)
SOS_tensor = torch.LongTensor([constant.SOS_ID])
self.register_buffer('SOS_tensor', SOS_tensor)
self.init_weights()
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)
def init_weights(self):
# initialize embeddings
init_range = constant.EMB_INIT_RANGE
if self.emb_matrix is not None:
if isinstance(self.emb_matrix, np.ndarray):
self.emb_matrix = torch.from_numpy(self.emb_matrix)
assert self.emb_matrix.size() == (self.vocab_size, self.emb_dim), \
"Input embedding matrix must match size: {} x {}".format(self.vocab_size, self.emb_dim)
self.embedding.weight.data.copy_(self.emb_matrix)
else:
self.embedding.weight.data.uniform_(-init_range, init_range)
# decide finetuning
if self.top <= 0:
logger.debug("Do not finetune embedding layer.")
self.embedding.weight.requires_grad = False
elif self.top < self.vocab_size:
logger.debug("Finetune top {} embeddings.".format(self.top))
self.embedding.weight.register_hook(lambda x: utils.keep_partial_grad(x, self.top))
else:
logger.debug("Finetune all embeddings.")
# initialize pos embeddings
if self.use_pos:
self.pos_embedding.weight.data.uniform_(-init_range, init_range)
def zero_state(self, inputs):
batch_size = inputs.size(0)
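        # read the device off the registered SOS buffer so the zero states
        # are created on the same device as the rest of the model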
device = self.SOS_tensor.device
h0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
c0 = torch.zeros(self.encoder.num_layers*2, batch_size, self.enc_hidden_dim, requires_grad=False, device=device)
return h0, c0
def encode(self, enc_inputs, lens):
""" Encode source sequence. """
h0, c0 = self.zero_state(enc_inputs)
packed_inputs = nn.utils.rnn.pack_padded_sequence(enc_inputs, lens, batch_first=True)
packed_h_in, (hn, cn) = self.encoder(packed_inputs, (h0, c0))
h_in, _ = nn.utils.rnn.pad_packed_sequence(packed_h_in, batch_first=True)
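        # concatenate the top layer's backward (hn[-1]) and forward (hn[-2])
        # final states to initialize the single-layer decoder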
hn = torch.cat((hn[-1], hn[-2]), 1)
cn = torch.cat((cn[-1], cn[-2]), 1)
return h_in, (hn, cn)
def decode(self, dec_inputs, hn, cn, ctx, ctx_mask=None, src=None, never_decode_unk=False):
""" Decode a step, based on context encoding and source context states."""
dec_hidden = (hn, cn)
decoder_output = self.decoder(dec_inputs, dec_hidden, ctx, ctx_mask, return_logattn=self.copy)
if self.copy:
h_out, dec_hidden, log_attn = decoder_output
else:
h_out, dec_hidden = decoder_output
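        # flatten (batch, time) into one dimension so dec2vocab is a single
        # matrix multiply, then restore the original shape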
h_out_reshape = h_out.contiguous().view(h_out.size(0) * h_out.size(1), -1)
decoder_logits = self.dec2vocab(h_out_reshape)
decoder_logits = decoder_logits.view(h_out.size(0), h_out.size(1), -1)
log_probs = self.get_log_prob(decoder_logits)
if self.copy:
copy_logit = self.copy_gate(h_out)
if self.use_pos:
                # the first attention position is the prepended UPOS
                # pseudo-token, which can never be copied as output
                log_attn = log_attn[:, :, 1:]
                # renormalize over the remaining source positions
                log_attn = torch.log_softmax(log_attn, -1)
# calculate copy probability for each word in the vocab
log_copy_prob = torch.nn.functional.logsigmoid(copy_logit) + log_attn
# scatter logsumexp
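            # subtract the per-position max before exponentiating so exp()
            # cannot overflow; mx is added back after the log below
            # (the standard log-sum-exp stabilization)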
mx = log_copy_prob.max(-1, keepdim=True)[0]
log_copy_prob = log_copy_prob - mx
# here we make space in the log probs for vocab items
# which might be copied from the encoder side, but which
# were not known at training time
# note that such an item cannot possibly be predicted by
# the model as a raw output token
# however, the copy gate might score high on copying a
# previously unknown vocab item
copy_prob = torch.exp(log_copy_prob)
copied_vocab_shape = list(log_probs.size())
if torch.max(src) >= copied_vocab_shape[-1]:
copied_vocab_shape[-1] = torch.max(src) + 1
copied_vocab_prob = log_probs.new_zeros(copied_vocab_shape)
scattered_copy = src.unsqueeze(1).expand(src.size(0), copy_prob.size(1), src.size(1))
            # fill in the copy tensor with the copy probs of each source token;
            # positions which receive no copy mass are masked below to a very
            # large negative log probability
copied_vocab_prob = copied_vocab_prob.scatter_add(-1, scattered_copy, copy_prob)
zero_mask = (copied_vocab_prob == 0)
log_copied_vocab_prob = torch.log(copied_vocab_prob.masked_fill(zero_mask, 1e-12)) + mx
log_copied_vocab_prob = log_copied_vocab_prob.masked_fill(zero_mask, -1e12)
# combine with normal vocab probability
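            # log(1 - sigmoid(x)) == -log(1 + exp(x)): the log probability
            # mass left over for the ordinary generation distribution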
log_nocopy_prob = -torch.log(1 + torch.exp(copy_logit))
if log_probs.shape[-1] < copied_vocab_shape[-1]:
# for previously unknown vocab items which are in the encoder,
# we reuse the UNK_ID prediction
# this gives a baseline number which we can combine with
# the copy gate prediction
# technically this makes log_probs no longer represent
# a probability distribution when looking at unknown vocab
# this is probably not a serious problem
                # an example of this arises in the Lemmatizer: consider an
                # English plural containing the character "ã" instead of "a".
                # if "ã" is not known from the training data, the lemmatizer
                # would ordinarily be unable to output it, and so the seq2seq
                # model would have no chance to depluralize
                # "ãntennae" -> "ãntenna".  however, if we temporarily add
                # "ã" to the encoder vocab and then let the copy gate accept
                # that letter, the seq2seq model will want to copy that
                # particular vocab item.  this allows the Lemmatizer to
                # produce "ã" directly, instead of producing UNK and then
                # going back to the input text to figure out which UNK it
                # intended to produce
new_log_probs = log_probs.new_zeros(copied_vocab_shape)
new_log_probs[:, :, :log_probs.shape[-1]] = log_probs
new_log_probs[:, :, log_probs.shape[-1]:] = new_log_probs[:, :, UNK_ID].unsqueeze(2)
log_probs = new_log_probs
log_probs = log_probs + log_nocopy_prob
log_probs = torch.logsumexp(torch.stack([log_copied_vocab_prob, log_probs]), 0)
if never_decode_unk:
log_probs[:, :, UNK_ID] = float("-inf")
return log_probs, dec_hidden
def embed(self, src, src_mask, pos, raw):
embed_src = src.clone()
embed_src[embed_src >= self.vocab_size] = UNK_ID
enc_inputs = self.emb_drop(self.embedding(embed_src))
batch_size = enc_inputs.size(0)
if self.use_pos:
assert pos is not None, "Missing POS input for seq2seq lemmatizer."
pos_inputs = self.pos_drop(self.pos_embedding(pos))
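            # prepend the POS embedding as an extra first "token" of the
            # source sequence, and widen the mask to match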
enc_inputs = torch.cat([pos_inputs.unsqueeze(1), enc_inputs], dim=1)
pos_src_mask = src_mask.new_zeros([batch_size, 1])
src_mask = torch.cat([pos_src_mask, src_mask], dim=1)
if raw is not None and self.contextual_embedding is not None:
raw_inputs = self.contextual_embedding(raw)
if self.use_pos:
raw_zeros = raw_inputs.new_zeros((raw_inputs.shape[0], 1, raw_inputs.shape[2]))
raw_inputs = torch.cat([raw_inputs, raw_zeros], dim=1)
enc_inputs = torch.cat([enc_inputs, raw_inputs], dim=2)
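        # sequence lengths = the number of positions where the mask equals
        # PAD_ID, i.e. the non-padding positions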
src_lens = list(src_mask.data.eq(constant.PAD_ID).long().sum(1))
return enc_inputs, batch_size, src_lens, src_mask
def forward(self, src, src_mask, tgt_in, pos=None, raw=None):
# prepare for encoder/decoder
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
dec_inputs = self.emb_drop(self.embedding(tgt_in))
log_probs, _ = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src)
return log_probs, edit_logits
def get_log_prob(self, logits):
logits_reshape = logits.view(-1, self.vocab_size)
log_probs = F.log_softmax(logits_reshape, dim=1)
if logits.dim() == 2:
return log_probs
return log_probs.view(logits.size(0), logits.size(1), logits.size(2))
def predict_greedy(self, src, src_mask, pos=None, raw=None, never_decode_unk=False):
""" Predict with greedy decoding. """
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
# greedy decode by step
dec_inputs = self.embedding(self.SOS_tensor)
dec_inputs = dec_inputs.expand(batch_size, dec_inputs.size(0), dec_inputs.size(1))
done = [False for _ in range(batch_size)]
total_done = 0
max_len = 0
output_seqs = [[] for _ in range(batch_size)]
while total_done < batch_size and max_len < self.max_dec_len:
log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
            assert log_probs.size(1) == 1, "Decoder must produce exactly one step of output."
_, preds = log_probs.squeeze(1).max(1, keepdim=True)
            # if an out-of-vocabulary token is predicted via the copy
            # mechanism, use the UNK embedding for it
dec_inputs = preds.clone()
dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
dec_inputs = self.embedding(dec_inputs) # update decoder inputs
max_len += 1
for i in range(batch_size):
if not done[i]:
token = preds.data[i][0].item()
if token == constant.EOS_ID:
done[i] = True
total_done += 1
else:
output_seqs[i].append(token)
return output_seqs, edit_logits
def predict(self, src, src_mask, pos=None, beam_size=5, raw=None, never_decode_unk=False):
""" Predict with beam search. """
if beam_size == 1:
return self.predict_greedy(src, src_mask, pos, raw, never_decode_unk=never_decode_unk)
enc_inputs, batch_size, src_lens, src_mask = self.embed(src, src_mask, pos, raw)
# (1) encode source
h_in, (hn, cn) = self.encode(enc_inputs, src_lens)
if self.edit:
edit_logits = self.edit_clf(hn)
else:
edit_logits = None
# (2) set up beam
with torch.no_grad():
h_in = h_in.data.repeat(beam_size, 1, 1) # repeat data for beam search
src_mask = src_mask.repeat(beam_size, 1)
# repeat decoder hidden states
hn = hn.data.repeat(beam_size, 1)
cn = cn.data.repeat(beam_size, 1)
device = self.SOS_tensor.device
beam = [Beam(beam_size, device) for _ in range(batch_size)]
def update_state(states, idx, positions, beam_size):
""" Select the states according to back pointers. """
for e in states:
br, d = e.size()
s = e.contiguous().view(beam_size, br // beam_size, d)[:,idx]
s.data.copy_(s.data.index_select(0, positions))
# (3) main loop
for i in range(self.max_dec_len):
dec_inputs = torch.stack([b.get_current_state() for b in beam]).t().contiguous().view(-1, 1)
            # if an out-of-vocabulary token is predicted via the copy
            # mechanism, use the UNK embedding for it
dec_inputs[dec_inputs >= self.vocab_size] = UNK_ID
dec_inputs = self.embedding(dec_inputs)
log_probs, (hn, cn) = self.decode(dec_inputs, hn, cn, h_in, src_mask, src=src, never_decode_unk=never_decode_unk)
log_probs = log_probs.view(beam_size, batch_size, -1).transpose(0,1).contiguous() # [batch, beam, V]
# advance each beam
done = []
for b in range(batch_size):
is_done = beam[b].advance(log_probs.data[b])
if is_done:
done += [b]
# update beam state
update_state((hn, cn), b, beam[b].get_current_origin(), beam_size)
if len(done) == batch_size:
break
# back trace and find hypothesis
all_hyp, all_scores = [], []
for b in range(batch_size):
scores, ks = beam[b].sort_best()
all_scores += [scores[0]]
k = ks[0]
hyp = beam[b].get_hyp(k)
hyp = utils.prune_hyp(hyp)
hyp = [i.item() for i in hyp]
all_hyp += [hyp]
return all_hyp, edit_logits
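

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds a tiny model and runs one training-style
# forward pass plus greedy prediction.  The hyperparameter values and the
# 'soft' attn_type are illustrative assumptions for this demo, not
# recommended settings.
if __name__ == '__main__':
    demo_args = {
        'vocab_size': 100,
        'emb_dim': 32,
        'hidden_dim': 64,
        'num_layers': 1,
        'dropout': 0.0,
        'max_dec_len': 10,
        'attn_type': 'soft',
    }
    model = Seq2SeqModel(demo_args)
    model.eval()
    # two source sequences of lengths 5 and 3, sorted by decreasing length as
    # pack_padded_sequence requires; the mask is True at padding positions.
    # token ids start at 4 to avoid the reserved special ids (an assumption
    # about the vocab layout)
    src = torch.randint(4, 100, (2, 5))
    src_mask = torch.tensor([[False, False, False, False, False],
                             [False, False, False, True, True]])
    tgt_in = torch.full((2, 4), constant.SOS_ID, dtype=torch.long)
    log_probs, edit_logits = model(src, src_mask, tgt_in)
    print(log_probs.shape)   # (batch, tgt_len, vocab_size)
    output_seqs, _ = model.predict_greedy(src, src_mask)
    print(output_seqs)       # one list of predicted token ids per input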