File size: 5,356 Bytes
c668e80 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import torch
import torch.nn as nn
def collapse_copy_scores(
    scores, batch, tgt_vocab, src_vocabs=None, batch_dim=1, batch_offset=None
):
    """Fold copy scores of in-vocabulary source words into the target vocab.

    ``scores`` covers an extended dictionary: the first ``len(tgt_vocab)``
    columns are target-vocabulary entries and the columns after them are
    per-example source-vocabulary entries. For every source word that
    ``tgt_vocab`` maps to a non-zero index, the copy score is added to the
    matching target column and the copy column is overwritten with
    ``1e-10``. ``scores`` is modified in place.

    Args:
        scores (Tensor): scores such that ``scores[:, b]`` (``batch_dim``
            equal to 1) or ``scores[b]`` (otherwise) is a 2-D slice with
            the extended dictionary on its dimension 1.
        batch (dict): must provide ``"src_ex_vocab"`` (one source vocab per
            example) when ``src_vocabs`` is None; otherwise must provide
            ``"indices"``, used to look the source vocab up in
            ``src_vocabs``.
        tgt_vocab: target vocabulary supporting ``len()`` and
            ``tgt_vocab[token]``; a lookup result of 0 is treated as
            "not in the target vocabulary".
        src_vocabs (optional): source vocabs indexed by the values of
            ``batch["indices"]``.
        batch_dim (int): dimension of ``scores`` that enumerates examples.
        batch_offset (optional): maps position ``b`` to a row of
            ``batch["indices"]``; when None, ``b`` itself is used.

    Returns:
        Tensor: the same ``scores`` object, updated in place.
    """
    offset = len(tgt_vocab)
    for b in range(scores.size(batch_dim)):
        if src_vocabs is None:
            src_vocab = batch["src_ex_vocab"][b]
        else:
            batch_id = batch_offset[b] if batch_offset is not None else b
            index = batch["indices"].data[batch_id]
            src_vocab = src_vocabs[index]

        blank = []  # extended-dictionary columns to clear
        fill = []  # target-vocabulary columns receiving the mass
        # Source index 0 is never collapsed (presumably the unknown token;
        # confirm against the vocab builder).
        for i in range(1, len(src_vocab)):
            ti = tgt_vocab[src_vocab.ids_to_tokens[i]]
            if ti != 0:
                blank.append(offset + i)
                fill.append(ti)
        if not blank:
            continue

        # Build integer index tensors directly on the device of ``scores``:
        # no float round-trip and no dependency on ``batch["indices"]``.
        blank = torch.tensor(blank, dtype=torch.long, device=scores.device)
        fill = torch.tensor(fill, dtype=torch.long, device=scores.device)
        # Indexing yields a view, so the in-place ops below update ``scores``.
        score = scores[:, b] if batch_dim == 1 else scores[b]
        score.index_add_(1, fill, score.index_select(1, blank))
        score.index_fill_(1, blank, 1e-10)
    return scores
class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`.

    These networks consider copying words
    directly from the source sequence.

    The copy generator is an extended version of the standard
    generator that computes three values.

    * :math:`p_{softmax}` the standard softmax over `tgt_dict`
    * :math:`p(z)` the probability of copying a word from
      the source
    * :math:`p_{copy}` the probability of copying a particular word,
      taken from the attention distribution directly.

    The model returns a distribution over the extended dictionary,
    computed as

    :math:`p(w) = p(z=1)  p_{copy}(w)  +  p(z=0)  p_{softmax}(w)`

    Args:
       input_size (int): size of input representation
       output_size (int): size of output vocabulary
       pad_idx (int): output column whose probability is forced to zero
    """

    def __init__(self, input_size, output_size, pad_idx):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.linear_copy = nn.Linear(input_size, 1)
        self.pad_idx = pad_idx

    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Rows of ``hidden`` and ``attn`` are expected in batch-major order
        (all target steps of example 0, then example 1, ...), so that row
        ``r`` is paired with ``src_map[r // tlen]``.

        Args:
           hidden (FloatTensor): hidden output ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, slen)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab.
               ``(batch, src_len, extra_words)``

        Returns:
           FloatTensor: ``(batch x tlen, output_size + extra_words)``; the
           first ``output_size`` columns hold the generation probabilities
           and the rest hold the copy probabilities.
        """
        _, slen = attn.size()
        batch, _, cvocab = src_map.size()

        # Original probabilities; padding gets -inf so its softmax mass is 0.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float("inf")
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)

        # Group rows per example so each one is multiplied with its own
        # src_map; flattening back keeps the same row order as out_prob.
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(mul_attn.reshape(batch, -1, slen), src_map)
        copy_prob = copy_prob.reshape(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion.

    Negative log-likelihood over the extended dictionary: the probability
    of a gold token is its copy probability, plus its generation
    probability wherever copying is not required.

    Args:
        vocab_size (int): offset of the first copy column in ``scores``.
        force_copy (bool): when True, the generation probability is only
            added where ``align`` equals ``unk_index``.
        unk_index (int): value of ``align``/``target`` denoting unknown.
        ignore_index (int): target value whose loss is set to zero.
        eps (float): added to copy probabilities to avoid ``log(0)``.
    """

    def __init__(
        self, vocab_size, force_copy, unk_index=0, ignore_index=-100, eps=1e-20
    ):
        super().__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
                whose sum along dim 1 is less than or equal to 1, i.e. cols
                softmaxed.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``

        Returns:
            FloatTensor: per-position loss with the same shape as
            ``target``; positions equal to ``ignore_index`` are 0.
        """
        ignored = target == self.ignore_index
        # ``gather`` rejects out-of-range indices such as the default -100,
        # so ignored positions read column 0; their loss is zeroed below.
        safe_target = target.masked_fill(ignored, 0)

        # probabilities assigned by the model to the gold targets
        vocab_probs = scores.gather(1, safe_target.unsqueeze(1)).squeeze(1)

        # probability of tokens copied from source
        copy_ix = align.unsqueeze(1) + self.vocab_size
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        # Set scores for unk to 0 and add eps to avoid -inf logs
        unk_align = align == self.unk_index
        copy_tok_probs = copy_tok_probs.masked_fill(unk_align, 0) + self.eps

        # find the indices in which you do not use the copy mechanism
        non_copy = unk_align
        if not self.force_copy:
            non_copy = non_copy | (target != self.unk_index)

        probs = torch.where(non_copy, copy_tok_probs + vocab_probs, copy_tok_probs)
        loss = -probs.log()
        # Drop padding.
        return loss.masked_fill(ignored, 0)
|