import torch
class Beam:
    """Beam-search state for decoding a single sentence.

    Tracks, per decoding step, the chosen token ids (``next_ys``) and the
    backpointer to the parent beam each token extended (``prev_ks``), plus
    the cumulative log-probability of every live hypothesis.
    """

    def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
                 start_token_id=1, end_token_id=2):
        """
        Args:
            beam_size: number of hypotheses kept alive at each step.
            min_length: number of steps before EOS may be emitted.
                NOTE(review): compared against len(next_ys), which already
                includes the start token — confirm the intended off-by-one.
            n_top: how many finished hypotheses done() waits for.
            ranker: optional re-ranker; stored but not used in this class.
            start_token_id: BOS vocabulary id used to seed every beam.
            end_token_id: EOS vocabulary id that terminates a hypothesis.
        """
        self.beam_size = beam_size
        self.min_length = min_length
        # Fix: the original assigned self.ranker twice; once is enough.
        self.ranker = ranker
        self.end_token_id = end_token_id
        self.n_top = n_top
        self.top_sentence_ended = False
        # Backpointers: prev_ks[t][b] = parent beam of beam b at step t+1.
        self.prev_ks = []
        # Step 0: every beam starts from the BOS token.
        self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)]
        # Cumulative log-probability of each live beam.
        self.current_scores = torch.FloatTensor(beam_size).zero_()
        self.all_scores = []
        # (score, timestep, beam_index) triples for beams that emitted EOS.
        self.finished = []

    def advance(self, next_log_probs):
        """Advance the beam one decoding step.

        Args:
            next_log_probs: (beam_size, vocab_size) log-probabilities of the
                next token for each live beam. NOTE: mutated in place when
                the minimum-length constraint suppresses EOS.
        """
        vocabulary_size = next_log_probs.size(1)
        # Forbid EOS while the hypotheses are shorter than min_length.
        # (Vectorized form of the original per-beam loop — same effect.)
        if len(self.next_ys) < self.min_length:
            next_log_probs[:, self.end_token_id] = -1e10

        if len(self.prev_ks) > 0:
            # Extend each beam's cumulative score with the new step.
            beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
            # Don't let EOS have children: kill every continuation of a
            # beam whose last token was EOS.
            last_y = self.next_ys[-1]
            for beam_index in range(last_y.size(0)):
                if last_y[beam_index] == self.end_token_id:
                    beam_scores[beam_index] = -1e10  # -1e20 raises error when executing
        else:
            # First step: all beams are identical, so expand only beam 0
            # to avoid selecting beam_size copies of the same token.
            beam_scores = next_log_probs[0]

        flat_beam_scores = beam_scores.view(-1)
        top_scores, top_score_ids = flat_beam_scores.topk(
            k=self.beam_size, dim=0, largest=True, sorted=True)
        self.current_scores = top_scores
        self.all_scores.append(self.current_scores)

        # Decompose the flat indices into (parent beam, token id).
        prev_k = top_score_ids // vocabulary_size  # (beam_size,)
        next_y = top_score_ids - prev_k * vocabulary_size  # (beam_size,)
        self.prev_ks.append(prev_k)
        self.next_ys.append(next_y)

        # Record any beam that just emitted EOS as a finished hypothesis.
        for beam_index, last_token_id in enumerate(next_y):
            if last_token_id == self.end_token_id:
                self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))
        # When the best-ranked beam emits EOS, the top sentence is complete.
        if next_y[0] == self.end_token_id:
            self.top_sentence_ended = True

    def get_current_state(self):
        """Return all chosen tokens so far as a (beam_size, steps) tensor."""
        return torch.stack(self.next_ys, dim=1)

    def get_current_origin(self):
        """Return the backpointers chosen at the most recent step."""
        return self.prev_ks[-1]

    def done(self):
        """True once the top beam ended AND n_top hypotheses are finished."""
        return self.top_sentence_ended and len(self.finished) >= self.n_top

    def get_hypothesis(self, timestep, k):
        """Walk backpointers from (timestep, beam k) to recover the tokens.

        Returns:
            List of token-id tensors in left-to-right order (BOS excluded).
        """
        hypothesis = []
        for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
            hypothesis.append(self.next_ys[j + 1][k])
            # for RNN, [:, k, :], and for transformer, [k, :, :]
            k = self.prev_ks[j][k]
        return hypothesis[::-1]

    def sort_finished(self, minimum=None):
        """Return finished hypotheses sorted by score, best first.

        Args:
            minimum: if given, pad ``finished`` with still-live beams (in
                rank order) until at least this many hypotheses exist.

        Returns:
            (scores, ks): parallel lists of scores and (timestep, beam)
            pairs usable with get_hypothesis().
        """
        if minimum is not None:
            i = 0
            # Add live beams until we have the minimum number of outputs.
            while len(self.finished) < minimum:
                s = self.current_scores[i]
                self.finished.append((s, len(self.next_ys) - 1, i))
                i += 1
        self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
        scores = [sc for sc, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks
|