Exclibur's picture
Upload folder using huggingface_hub
f1c1609 verified
"""
Manage beam search info structure.
Heavily borrowed from OpenNMT-py.
For the original code in OpenNMT-py, see the following link (it may only exist in older versions of the repository):
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
"""
import torch
class Constants():
    """Special-token ids together with the token strings they stand for.

    A freshly constructed instance carries fixed default ids (0..3); use
    :meth:`from_tokenizer` to take the ids from a tokenizer vocabulary
    instead.
    """

    def __init__(self):
        # Default ids, used when no tokenizer is available.
        self.PAD, self.UNK, self.BOS, self.EOS = 0, 1, 2, 3
        # Token strings used as keys when looking ids up in a vocabulary.
        self.PAD_WORD, self.UNK_WORD = '[PAD]', '[UNK]'
        self.BOS_WORD, self.EOS_WORD = '[CLS]', '[SEP]'

    @classmethod
    def from_tokenizer(cls, tokenizer):
        """Build an instance whose ids are read from ``tokenizer.vocab``.

        Each of PAD/UNK/BOS/EOS is replaced by the vocabulary entry for the
        matching ``*_WORD`` string. A missing entry propagates whatever
        error the ``tokenizer.vocab[...]`` lookup raises.
        """
        instance = cls()
        for symbol in ('PAD', 'UNK', 'BOS', 'EOS'):
            token = getattr(instance, symbol + '_WORD')
            setattr(instance, symbol, tokenizer.vocab[token])
        return instance
class Beam():
    """Beam-search bookkeeping for one sequence being decoded.

    The beam keeps, per time-step, the chosen token ids (``next_ys``) and
    the backpointers into the previous step (``prev_ks``), plus the running
    score of each of its ``size`` hypotheses (``scores``).
    """

    def __init__(self, size, device=None, tokenizer=None):
        """Create an empty beam.

        Args:
            size: number of hypotheses kept at every step.
            device: torch device for the internal tensors. ``None`` means
                torch's default device. ``False`` (the historical default,
                which torch factory functions do not accept) is treated
                the same as ``None``.
            tokenizer: optional object exposing a ``vocab`` mapping; when
                given, the special-token ids are read from it, otherwise
                the default ``Constants`` ids are used.
        """
        if device is False:
            device = None

        if tokenizer is None:
            self.constants = Constants()
        else:
            self.constants = Constants.from_tokenizer(tokenizer)

        self.size = size
        self._done = False

        # The score for each hypothesis on the beam.
        self.scores = torch.zeros((size,), dtype=torch.float, device=device)
        self.all_scores = []

        # The backpointers at each time-step.
        self.prev_ks = []

        # The outputs at each time-step; step 0 is BOS for every hypothesis.
        self.next_ys = [
            torch.full((size,), self.constants.BOS, dtype=torch.long, device=device)
        ]

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        """True once the top hypothesis of the latest step ended in EOS."""
        return self._done

    def advance(self, word_prob, word_length=None):
        """Update beam status and check if finished or not.

        Args:
            word_prob: 2-D tensor indexed as (beam, word) holding the score
                of every candidate word for every current hypothesis
                (presumably log-probabilities, since scores are summed --
                confirm against the caller).
            word_length: unused; kept for interface compatibility.

        Returns:
            The ``done`` flag after this step.
        """
        num_words = word_prob.size(1)

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: every hypothesis is the same BOS prefix, so only
            # the first row is considered to avoid duplicate candidates.
            beam_lk = word_prob[0]

        # reshape (rather than view) also copes with non-contiguous input.
        flat_beam_lk = beam_lk.reshape(-1)

        # Largest `size` candidates, returned in descending score order.
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)

        self.all_scores.append(self.scores)
        self.scores = best_scores

        # best_scores_id indexes a flattened (beam x word) array, so we
        # need to recover which beam and which word each score came from.
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)

        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0].item() == self.constants.EOS:
            self._done = True

        return self._done

    def sort_scores(self):
        "Sort the scores in descending order; returns (scores, indices)."
        return torch.sort(self.scores, 0, True)

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam, and its beam index."
        scores, ids = self.sort_scores()
        # Scores are sorted in descending order, so the best entry is at 0.
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        """Get the decoded sequence for the current timestep.

        Returns a long tensor with one row per hypothesis, best first, each
        row starting with BOS. The tensor lives on the same device as the
        beam's scores.
        """
        if len(self.next_ys) == 1:
            return self.next_ys[0].unsqueeze(1)

        _, keys = self.sort_scores()
        hyps = [[self.constants.BOS] + self.get_hypothesis(k) for k in keys]
        return torch.tensor(hyps, dtype=torch.long, device=self.scores.device)

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: index of the hypothesis in the latest step.

        Returns:
            List of token ids (Python ints) from the first generated token
            to the latest one; BOS is not included.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k].item())
            # Follow the backpointer to the parent hypothesis.
            k = self.prev_ks[j][k]
        return hyp[::-1]