|
|
""" |
|
|
Manage beam search info structure. |
|
|
Heavily borrowed from OpenNMT-py. |
|
|
For code in OpenNMT-py, please check the following link (maybe in oldest version): |
|
|
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py |
|
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
class Constants:
    """Special-token ids together with their surface forms.

    A freshly constructed instance carries the built-in default ids
    (PAD=0, UNK=1, BOS=2, EOS=3).  Use :meth:`from_tokenizer` to obtain an
    instance whose ids are looked up in a tokenizer vocabulary instead.
    """

    def __init__(self):
        # Default integer ids of the four special tokens.
        self.PAD, self.UNK, self.BOS, self.EOS = 0, 1, 2, 3
        # Surface forms used as vocabulary keys for those tokens.
        self.PAD_WORD, self.UNK_WORD = '[PAD]', '[UNK]'
        self.BOS_WORD, self.EOS_WORD = '[CLS]', '[SEP]'

    @classmethod
    def from_tokenizer(cls, tokenizer):
        """Build an instance whose ids come from ``tokenizer.vocab``.

        Args:
            tokenizer: object exposing a ``vocab`` mapping that contains the
                four special-token strings as keys.

        Returns:
            A ``Constants`` instance with PAD/UNK/BOS/EOS replaced by the
            ids found in the vocabulary.
        """
        consts = cls()
        vocab = tokenizer.vocab
        # Resolve each id through its matching ``*_WORD`` surface form.
        for token in ('PAD', 'UNK', 'BOS', 'EOS'):
            setattr(consts, token, vocab[getattr(consts, token + '_WORD')])
        return consts
|
|
|
|
|
class Beam():
    """Beam search state for one sequence being decoded.

    The beam keeps, per decoding step, the token chosen for every live
    hypothesis (``next_ys``) and the index of the hypothesis it extends
    (``prev_ks``), so that complete sequences can be rebuilt by walking the
    back-pointers.
    """

    def __init__(self, size, device=False, tokenizer=None):
        """Create an empty beam whose hypotheses all start with BOS.

        Args:
            size: number of hypotheses kept alive (beam width).
            device: device on which the beam tensors are allocated.  ``False``
                (the historical default) or ``None`` selects torch's default
                device.
            tokenizer: optional tokenizer used to look up the special-token
                ids; when omitted the built-in default ids are used.
        """
        if tokenizer is None:
            self.constants = Constants()
        else:
            self.constants = Constants.from_tokenizer(tokenizer)

        # ``False`` is the legacy "no device given" sentinel; normalise it to
        # None so the tensor factories fall back to the default device.
        # (``is False`` on purpose: the integer device index 0 must survive.)
        if device is False:
            device = None

        self.size = size
        self._done = False

        # Cumulative score of every hypothesis currently on the beam.
        self.scores = torch.zeros((size,), dtype=torch.float, device=device)
        # Score vectors as they were before each call to ``advance``.
        self.all_scores = []

        # Back-pointers: for each step, the beam slot each entry extends.
        self.prev_ks = []

        # Tokens emitted at each step; step 0 is BOS for every slot.
        self.next_ys = [torch.full((size,), self.constants.BOS, dtype=torch.long, device=device)]

    def get_current_state(self):
        """Get the outputs for the current timestep."""
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        """Get the backpointers for the current timestep.

        Raises:
            IndexError: if ``advance`` has not been called yet.
        """
        return self.prev_ks[-1]

    @property
    def done(self):
        """Whether the top-of-beam hypothesis has produced EOS."""
        return self._done

    def advance(self, word_prob, word_length=None):
        """Update beam status and check if finished or not.

        Args:
            word_prob: 2-D tensor of per-hypothesis token scores; dimension 1
                indexes the vocabulary.  Scores are added to the running
                hypothesis scores, so they are presumably log-probabilities.
            word_length: unused; kept for interface compatibility.

        Returns:
            True once the best hypothesis ends in EOS, otherwise False.
        """
        num_words = word_prob.size(1)

        if len(self.prev_ks) > 0:
            # Add each hypothesis' running score to its candidate extensions.
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: every slot holds the same BOS prefix, so only the
            # first row is considered to avoid duplicate candidates.
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.view(-1)
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)

        self.all_scores.append(self.scores)
        self.scores = best_scores

        # best_scores_id indexes the flattened (beam x vocabulary) array:
        # the quotient is the originating beam slot, the remainder the token.
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)

        # End condition is when the top-of-beam token is EOS.
        if self.next_ys[-1][0].item() == self.constants.EOS:
            self._done = True

        return self._done

    def sort_scores(self):
        """Sort the scores in descending order; returns (values, indices)."""
        return torch.sort(self.scores, 0, True)

    def get_the_best_score_and_idx(self):
        """Get the score of the best in the beam and its beam index."""
        scores, ids = self.sort_scores()
        # Scores are sorted in descending order, so the best entry is first.
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        """Get the decoded sequence for the current timestep.

        Returns:
            A long tensor with one row per hypothesis, best score first, each
            row starting with BOS.  It lives on the same device as the beam.
        """
        if len(self.next_ys) == 1:
            return self.next_ys[0].unsqueeze(1)

        _, keys = self.sort_scores()
        hyps = [[self.constants.BOS] + self.get_hypothesis(k) for k in keys]
        # Build the result on the beam's own device rather than always on CPU.
        return torch.tensor(hyps, dtype=torch.long, device=self.next_ys[0].device)

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: beam slot (int or 0-d index tensor) of the final token.

        Returns:
            List of token ids (Python ints) in decoding order, without BOS.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k].item())
            # Follow the back-pointer to the slot this token extended.
            k = self.prev_ks[j][k]
        return hyp[::-1]
|
|
|