| import torch |
|
|
class Beam:
    """Book-keeping for a single beam-search decode.

    The beam keeps ``beam_size`` hypotheses ("slots"). Each call to
    :meth:`advance` adds the per-token scores to the slots' running scores,
    keeps the ``beam_size`` best (slot, token) pairs, and records
    back-pointers so a full sequence can later be recovered with
    :meth:`get_hypothesis`. Every slot that emits ``end_token_id`` is logged
    in ``finished`` as a ``(score, timestep, beam_index)`` tuple.

    Args:
        beam_size: number of hypotheses kept at every step.
        min_length: ``end_token_id`` is masked out while ``len(next_ys)`` (which
            counts the start row) is below this value.
        n_top: number of finished hypotheses required before :meth:`done` can
            return True.
        ranker: stored on the instance; no method of this class uses it.
        start_token_id: token every hypothesis starts from.
        end_token_id: token that terminates a hypothesis.
    """

    def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
                 start_token_id=1, end_token_id=2):
        self.beam_size = beam_size
        self.min_length = min_length
        self.n_top = n_top
        self.ranker = ranker
        self.end_token_id = end_token_id

        # Becomes True once the best-scoring slot of some step emits EOS.
        self.top_sentence_ended = False

        # Back-pointers: prev_ks[j][k] is the slot of step j that slot k of
        # step j + 1 extended.
        self.prev_ks = []
        # Token chosen for each slot at each step; next_ys[0] is the start row.
        self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)]

        # Running score of each slot (sum of the next_log_probs values picked
        # along its path) and the per-step history of those scores.
        self.current_scores = torch.FloatTensor(beam_size).zero_()
        self.all_scores = []

        # (score, timestep, beam_index) for every slot that emitted EOS.
        self.finished = []

    def advance(self, next_log_probs):
        """Extend the beam by one decoding step.

        Args:
            next_log_probs: 2-D tensor whose rows correspond to the current
                slots and whose columns correspond to vocabulary tokens. On the
                first step only row 0 is used, because every slot still holds
                the same start token. The tensor is not modified.
        """
        vocabulary_size = next_log_probs.size(1)

        # The start row is created on the CPU. Move it next to the scores so
        # that torch.stack in get_current_state() never mixes devices.
        if self.next_ys[0].device != next_log_probs.device:
            self.next_ys[0] = self.next_ys[0].to(next_log_probs.device)

        if len(self.next_ys) < self.min_length:
            # Forbid ending too early. Work on a copy so the caller's tensor
            # is not changed in place.
            next_log_probs = next_log_probs.clone()
            next_log_probs[:, self.end_token_id] = -1e10

        if self.prev_ks:
            beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
            # Slots that already ended with EOS must not be extended further.
            beam_scores[self.next_ys[-1] == self.end_token_id] = -1e10
        else:
            # All slots are identical at the start. Expanding just one row
            # avoids beam_size copies of every candidate.
            beam_scores = next_log_probs[0]

        flat_beam_scores = beam_scores.reshape(-1)
        top_scores, top_score_ids = flat_beam_scores.topk(
            k=self.beam_size, dim=0, largest=True, sorted=True)

        self.current_scores = top_scores
        self.all_scores.append(top_scores)

        # Decode flat indices back into (source slot, token id).
        prev_k = top_score_ids // vocabulary_size
        next_y = top_score_ids - prev_k * vocabulary_size

        self.prev_ks.append(prev_k)
        self.next_ys.append(next_y)

        step = len(self.next_ys) - 1
        for beam_index, token_id in enumerate(next_y.tolist()):
            if token_id == self.end_token_id:
                self.finished.append((top_scores[beam_index], step, beam_index))

        if next_y[0] == self.end_token_id:
            self.top_sentence_ended = True

    def get_current_state(self):
        """Return every step's token row stacked as columns.

        The result has shape ``(beam_size, len(next_ys))``. Column ``t`` is
        ``next_ys[t]``. Columns are NOT re-aligned through the back-pointers,
        so a row is not necessarily one coherent hypothesis; use
        :meth:`get_hypothesis` for that.
        """
        return torch.stack(self.next_ys, dim=1)

    def get_current_origin(self):
        """Return the back-pointers produced by the most recent advance().

        Raises IndexError if advance() has not been called yet.
        """
        return self.prev_ks[-1]

    def done(self):
        """True once the top slot has ended and at least n_top hypotheses finished."""
        return self.top_sentence_ended and len(self.finished) >= self.n_top

    def get_hypothesis(self, timestep, k):
        """Walk the back-pointers from slot ``k`` at ``timestep`` to the start.

        Returns the tokens (as 0-d tensors) in generation order, excluding the
        start token.
        """
        hypothesis = []
        for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
            hypothesis.append(self.next_ys[j + 1][k])
            k = self.prev_ks[j][k]
        return hypothesis[::-1]

    def sort_finished(self, minimum=None):
        """Sort finished hypotheses by score, best first.

        Args:
            minimum: if given, top up ``finished`` with live slots from the
                last step, in order of their current rank, until it holds at
                least this many entries. Slots already recorded as finished at
                that step are not added twice. Topping up stops when the beam
                width is exhausted.

        Returns:
            ``(scores, ks)``. ``scores`` lists the hypothesis scores and ``ks``
            lists matching ``(timestep, beam_index)`` pairs for
            :meth:`get_hypothesis`.
        """
        if minimum is not None:
            last_step = len(self.next_ys) - 1
            recorded = {(t, k) for _, t, k in self.finished}
            for i in range(len(self.current_scores)):
                if len(self.finished) >= minimum:
                    break
                if (last_step, i) in recorded:
                    continue
                self.finished.append((self.current_scores[i], last_step, i))

        self.finished = sorted(self.finished, key=lambda entry: float(entry[0]), reverse=True)
        scores = [score for score, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks
|
|