from collections import namedtuple

import numpy as np
import torch
from fairseq import utils


DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)
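
# DecoderOut is the rolling state threaded through the refinement loop:
# ``output_tokens`` / ``output_scores`` hold the current hypothesis and its
# per-token scores, ``attn`` optionally carries cross-attention weights,
# ``step`` / ``max_step`` track the iteration count, and ``history`` (only
# populated when retain_history=True) collects the tokens produced at every
# iteration.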


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            models: list of models; if reranking is enabled, the last model
                is used as the reranker
            eos_penalty: if > 0.0, penalizes early stopping during decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is
                the source length
            beam_size: length-beam size; each sentence is decoded with this
                many candidate target lengths
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: keep dropout enabled during inference
            adaptive: stop refining a sentence once its output stops changing
            retain_history: keep the intermediate output of every refinement step
            reranking: rescore the length-beam candidates with the last model
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models
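
    # A minimal usage sketch (hypothetical setup: ``task``, ``model`` and
    # ``sample`` come from the usual fairseq loading code and are not defined
    # in this file):
    #
    #   generator = IterativeRefinementGenerator(
    #       task.target_dictionary, models=[model], max_iter=10
    #   )
    #   hypos = generator.generate([model], sample)
    #   best = hypos[0][0]  # best hypothesis for the first sentence
    #   print(best["tokens"], best["score"])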

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # put the models in eval mode (disables dropout) unless dropout is
        # explicitly retained for inference
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translations for each example"

            # the last model in the list acts as the reranker; the rest decode
            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize: encode the source once, then build the initial target
        # hypothesis (e.g. from the model's length prediction)
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate the initial hypotheses based on the length beam:
            # every sentence is repeated beam_size times, each copy seeded
            # with a different candidate target length
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size
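
            # For example (hypothetical sizes), with bsz=2 and beam_size=3 the
            # order above is [0, 0, 0, 1, 1, 1]: the encoder states of sentence
            # 0 are repeated three times, then those of sentence 1, so the beam
            # copies of a sentence stay adjacent in the enlarged batch.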

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            # compare the previous (x) and current (y) outputs: pad the shorter
            # one to a common length, then flag the sentences whose tokens did
            # not change at all in this iteration
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a
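
        # For instance (hypothetical tokens), if a sentence's previous output
        # was [<s> A B </s>] and the current iteration again produces
        # [<s> A B </s>], the (x == y).all(1) test marks it as terminated;
        # any substitution, insertion or deletion keeps it active.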

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            # strip padding and package one finished sentence in fairseq's
            # standard hypothesis format
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }
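
        # The sentence-level "score" is the mean of the per-token scores, so
        # hypotheses with different lengths stay comparable; "alignment" is,
        # for each target token, the source position with the largest
        # attention weight.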

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )
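
            # One refinement pass: the model reads its own previous output and
            # returns a revised hypothesis (e.g. re-predicting masked tokens or
            # applying insertion/deletion edits, depending on the model).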

            if self.adaptive:
                # terminate a sentence early if its output has stopped changing
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )
            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:
                # the iteration budget is exhausted: force-terminate everything
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # stop once every sentence in the batch is terminated
            if terminated.sum() == terminated.size(0):
                break

            # for the next step, keep only the unfinished sentences
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()
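
            # Shrinking the batch this way means later iterations only pay for
            # the sentences still being refined; ``sent_idxs`` keeps the
            # mapping back to each sentence's original position in the batch.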

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from the length beam: keep only the
            # highest-scoring candidate length for every sentence
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]
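
            # Since beam copies are laid out contiguously, sentence i's
            # candidates occupy indices beam_size * i .. beam_size * (i + 1) - 1,
            # and the argmax selects the best candidate length per sentence.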

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        def rebuild_batch(finalized):
            # re-pad the variable-length finalized hypotheses into one batch
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        # fairseq's autoregressive decoders assume the target starts with EOS
        final_output_tokens[:, 0] = self.eos

        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        # score every candidate with teacher forcing: gather the log-probability
        # of each target token, zero out padding, and length-normalize
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )
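
        # i.e. score(y) = (1 / |y|) * sum_t log p(y_t | y_<t, x) under the
        # reranker, so candidates with different lengths remain comparable.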

        # overwrite the refinement scores with the reranker's scores
        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized