Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from fairseq.search import Search | |
class NoisyChannelBeamSearch(Search):
    """Beam search that ranks candidates by a combined fw/bw score.

    At every step the cumulative ``fw`` log-probabilities are combined with
    the ``bw`` log-probabilities (see :meth:`combine_fw_bw`), the best
    candidates are selected by that combined score, and the matching ``fw``
    and ``lm`` scores of the selected candidates are returned alongside.

    NOTE(review): ``fw``/``bw``/``lm`` presumably stand for the forward
    (direct) model, the backward (channel) model and the language model of a
    noisy-channel setup -- confirm against the caller.
    """

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)
        # Allocated lazily by ``_init_buffers`` on the first ``step`` call.
        self.fw_scores_buf = None
        self.lm_scores_buf = None

    def _init_buffers(self, t):
        """Allocate the work buffers once, on the device of tensor ``t``.

        This does not delegate to ``super()._init_buffers``; every buffer
        that :meth:`step` touches is created here instead. Score buffers
        take their type from ``t`` (via ``t.new()``); index buffers are
        ``LongTensor`` on ``t.device``.
        """
        if self.fw_scores_buf is None:
            self.scores_buf = t.new()
            self.indices_buf = torch.LongTensor().to(device=t.device)
            self.beams_buf = torch.LongTensor().to(device=t.device)
            self.fw_scores_buf = t.new()
            self.lm_scores_buf = t.new()

    def combine_fw_bw(self, combine_method, fw_cum, bw, step):
        """Combine cumulative ``fw`` scores with ``bw`` scores.

        Args:
            combine_method: ``"noisy_channel"`` or ``"lm_only"``.
            fw_cum: cumulative ``fw`` log-probabilities.
            bw: ``bw`` log-probabilities; must be addable to ``fw_cum``.
            step: zero-based decoding step.

        Returns:
            ``bw + fw_cum / (step + 1)`` for ``"noisy_channel"`` and
            ``bw + fw_cum`` for ``"lm_only"``.

        Raises:
            ValueError: if ``combine_method`` is not one of the two
                supported values.
        """
        if combine_method == "noisy_channel":
            # Normalise the cumulative fw score by the number of steps so far.
            fw_norm = fw_cum.div(step + 1)
            return bw + fw_norm
        if combine_method == "lm_only":
            return bw + fw_cum
        # Previously an unknown method fell through to ``return lprobs`` with
        # ``lprobs`` never assigned, surfacing as an opaque UnboundLocalError.
        raise ValueError(
            "Unknown combine_method {!r}; expected 'noisy_channel' or "
            "'lm_only'".format(combine_method)
        )

    def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
        """Select the top candidates for one decoding step.

        Args:
            step: zero-based decoding step.
            fw_lprobs: 3-D tensor unpacked as ``(bsz, beam_size, vocab_size)``.
            scores: tensor indexed as ``scores[:, :, step - 1]`` to fetch the
                cumulative score of each hypothesis; unused when ``step == 0``.
            bw_lprobs: 3-D tensor that is added to the cumulative ``fw``
                scores (assumed to have a layout compatible with
                ``fw_lprobs`` -- TODO confirm).
            lm_lprobs: tensor that can be viewed as ``(bsz, -1)`` and indexed
                with the same flat candidate indices as the combined scores.
            combine_method: forwarded to :meth:`combine_fw_bw`.

        Returns:
            Tuple ``(scores_buf, fw_scores_buf, lm_scores_buf, indices_buf,
            beams_buf)``: combined scores of the selected candidates, their
            cumulative ``fw`` scores, their ``lm`` scores, their indices
            within the vocabulary and the beam each one came from.

        Raises:
            ValueError: if ``combine_method`` is not supported.
        """
        self._init_buffers(fw_lprobs)
        bsz, beam_size, vocab_size = fw_lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
            bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
            # nothing to add since we are at the first step
            fw_lprobs_cum = fw_lprobs
        else:
            # make probs contain cumulative scores for each hypothesis
            raw_scores = scores[:, :, step - 1].unsqueeze(-1)
            fw_lprobs_cum = fw_lprobs.add(raw_scores)

        combined_lprobs = self.combine_fw_bw(
            combine_method, fw_lprobs_cum, bw_lprobs, step
        )
        flat_combined = combined_lprobs.view(bsz, -1)

        # choose the top k according to the combined noisy channel model score
        torch.topk(
            flat_combined,
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                flat_combined.size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )

        # save corresponding fw and lm scores
        self.fw_scores_buf = torch.gather(
            fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf
        )
        self.lm_scores_buf = torch.gather(
            lm_lprobs.view(bsz, -1), 1, self.indices_buf
        )

        # Project back into relative indices and beams. The beam index must be
        # derived before ``indices_buf`` is reduced modulo vocab_size in place.
        self.beams_buf = self.indices_buf // vocab_size
        self.indices_buf.fmod_(vocab_size)
        return (
            self.scores_buf,
            self.fw_scores_buf,
            self.lm_scores_buf,
            self.indices_buf,
            self.beams_buf,
        )