import math
import sys
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import search, utils
from fairseq.models import FairseqIncrementalDecoder
from fairseq.ngram_repeat_block import NGramRepeatBlock

from data import data_utils


class SequenceGenerator(nn.Module):
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
        constraint_trie=None,
        constraint_range=None,
        gen_code=False,
        gen_box=False,
        ignore_eos=False,
        zero_shot=False,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        super().__init__()
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.gen_code = gen_code
        self.gen_box = gen_box
        self.ignore_eos = ignore_eos
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.bos = tgt_dict.bos()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.bos, self.eos}
        )
        self.vocab_size = len(tgt_dict)
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        self.max_len = max_len or self.model.max_decoder_positions()

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len
        self.zero_shot = zero_shot

        if no_repeat_ngram_size > 0:
            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
        else:
            self.repeat_ngram_blocker = None

        assert temperature > 0, "--temperature must be greater than 0"

        self.search = (
            search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
        )
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithread
        # settings when the model is shared.
        self.should_set_src_lengths = (
            hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
        )

        self.model.eval()

        self.lm_model = lm_model
        self.lm_weight = lm_weight
        if self.lm_model is not None:
            self.lm_model.eval()

        self.constraint_trie = constraint_trie

        self.constraint_start = None
        self.constraint_end = None
        if constraint_range is not None:
            constraint_start, constraint_end = constraint_range.split(',')
            self.constraint_start = int(constraint_start)
            self.constraint_end = int(constraint_end)

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def forward(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        """Generate a batch of translations.

        Args:
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        # _generate expects the list of models as its first argument; pass the
        # ensemble's underlying modules through
        return self._generate(self.model.models, sample, prefix_tokens, bos_token=bos_token)

    # TODO(myleott): unused, deprecate after pytorch-translate migration
    def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            cuda (bool, optional): use GPU for generation
            timer (StopwatchMeter, optional): time generations
        """
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if "net_input" not in s:
                continue
            input = s["net_input"]
            if timer is not None:
                timer.start()
            with torch.no_grad():
                # generate expects (models, sample); pass the full sample so
                # _generate can read sample["net_input"]
                hypos = self.generate(self.model.models, s)
            if timer is not None:
                timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
            for i, id in enumerate(s["id"].data):
                # remove padding
                src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
                ref = (
                    utils.strip_pad(s["target"].data[i, :], self.pad)
                    if s["target"] is not None
                    else None
                )
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(
        self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
    ) -> List[List[Dict[str, Tensor]]]:
        """Generate translations. Match the API of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            constraints (torch.LongTensor, optional): force decoder to include
                the list of constraints
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(models, sample, **kwargs)

    def _generate(
        self,
        models,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        model = EnsembleModel(models)
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(model.models_size)
            ],
        )
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except
            # EndOfSentence and pad
            src_lengths = (
                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
            )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        elif "features" in net_input:
            src_tokens = net_input["features"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        else:
            raise Exception(
                "expected src_tokens or source in net input. input keys: "
                + str(net_input.keys())
            )

        # bsz: total number of sentences in beam
        # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, beam_size)

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = int(self.max_len_a * src_len + self.max_len_b)
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"

        # compute the encoder output for each beam
        with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
            encoder_outs = model.forward_encoder(net_input)

        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
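        # `new_order` expands each sentence to beam_size consecutive rows, e.g.
        # bsz=2, beam_size=3 gives [0, 0, 0, 1, 1, 1]. Every flat tensor below
        # (tokens, scores, attn) uses this bsz * beam_size ("bbsz") layout, so
        # row i belongs to sentence i // beam_size and beam i % beam_size.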

        # ensure encoder_outs is a List.
        assert encoder_outs is not None

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos; pad is never chosen for scoring
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for bos and eos
        tokens[:, 0] = self.bos
        attn: Optional[Tensor] = None

        # A list that indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )

        # a boolean array indicating if the sentence at the index is finished or not
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS
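        # Why 2 * beam_size: the search returns twice as many candidates as
        # beams each step, so even if beam_size of them end in EOS (and leave
        # the beam), beam_size non-EOS candidates remain to keep every
        # sentence's beam full. E.g. with beam_size=2, the top-4 candidates
        # [EOS, a, EOS, b] finalize two hypotheses while a and b continue.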

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (
            (torch.arange(0, bsz) * beam_size)
            .unsqueeze(1)
            .type_as(tokens)
            .to(src_tokens.device)
        )
        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)

        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None

        original_batch_idxs: Optional[Tensor] = None
        if "id" in sample and isinstance(sample["id"], Tensor):
            original_batch_idxs = sample["id"]
        else:
            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)

        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                    original_batch_idxs = original_batch_idxs[batch_idxs]
                model.reorder_incremental_state(incremental_states, reorder_state)
                encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
            with torch.autograd.profiler.record_function(
                "EnsembleModel: forward_decoder"
            ):
                lprobs, avg_attn_scores = model.forward_decoder(
                    tokens[:, : step + 1],
                    encoder_outs,
                    incremental_states,
                    self.temperature,
                    constraint_trie=self.constraint_trie,
                    constraint_start=self.constraint_start,
                    constraint_end=self.constraint_end,
                    gen_code=self.gen_code,
                    zero_shot=self.zero_shot,
                    prefix_tokens=prefix_tokens,
                )

            if self.lm_model is not None:
                lm_out = self.lm_model(tokens[:, : step + 1])
                probs = self.lm_model.get_normalized_probs(
                    lm_out, log_probs=True, sample=None
                )
                probs = probs[:, -1, :] * self.lm_weight
                lprobs += probs

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            # replace NaN log-probs with -inf so they are never selected
            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # restrict the vocabulary for code/box generation
            # (the index ranges follow OFA's vocabulary layout)
            if (self.gen_code or self.gen_box) and step < max_len:
                lprobs[:, :4] = -math.inf
            if self.gen_box:
                lprobs[:, -1] = -math.inf
                if (step + 1) % 5 == 0:
                    lprobs[:, self.constraint_start:59457] = -math.inf
                else:
                    lprobs[:, 59457:] = -math.inf

            # handle max length constraint: force eos (unless ignore_eos)
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf
                if self.ignore_eos:
                    lprobs[:, self.eos] = 1

            # Record attention scores, only support avg_attn_scores is a Tensor
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
                tokens
            )  # indices of hypothesis ending with eos (finished sentences)
            eos_scores = torch.empty(0).to(
                scores
            )  # scores of hypothesis ending with eos (finished sentences)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None:
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                tokens[:, : step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            # Shape of eos_mask: (batch size, beam size)
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            # Now we know what beam item(s) to finish
            # Shape: 1d list of absolute-numbered
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            if self.search.stop_on_max_len and step >= max_len:
                break
            assert step < max_len, f"{step} < {max_len}"

            # Remove finalized sentences (ones for which {beam_size}
            # finished hypotheses have been generated) from the batch.
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(
                    bsz, dtype=torch.bool, device=cand_indices.device
                )
                batch_mask[finalized_sents] = False
                batch_idxs = torch.arange(
                    bsz, device=cand_indices.device
                ).masked_select(batch_mask)

                # Choose the subset of the hypothesized constraints that will continue
                self.search.prune_sentences(batch_idxs)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos.
            # (The element-wise or is rewritten since it is not supported in
            # torchscript.)
            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )
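            # Worked example of the active_mask trick: with beam_size=2
            # (cand_size=4) and an eos_mask row [True, False, False, False],
            # active_mask is [4*1+0, 0+1, 0+2, 0+3] = [4, 1, 2, 3]; the
            # topk(..., largest=False) below then keeps candidates 1 and 2,
            # i.e. the best non-finished hypotheses.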

            # get the top beam_size active hypotheses, which are just
            # the hypos with the smallest values in active_mask.
            # {active_hypos} indicates which {beam_size} hypotheses
            # from the list of {2 * beam_size} candidates were
            # selected. Shapes: (batch size, beam size)
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # Make sure there is at least one active item for each sentence in the batch.
            assert (~cands_to_ignore).any(dim=1).all()

            # {active_bbsz_idx} denotes which beam number is continued for each
            # new hypothesis (a beam can be selected more than once).
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses

            # Set the tokens for each beam (can select the same row more than once)
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            # Select the next token for each of them
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )

            # Update constraints based on which candidates were selected for the next beam
            self.search.update_constraints(active_hypos)

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            scores = torch.tensor(
                [float(elem["score"].item()) for elem in finalized[sent]]
            )
            _, sorted_scores_indices = torch.sort(scores, descending=True)
            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], finalized[sent]
            )
        return finalized

    def _prefix_tokens(
        self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
    ):
        """Handle prefix tokens"""
        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
        prefix_mask = prefix_toks.ne(self.pad)
        if self.constraint_trie is None:
            lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
        else:
            lprobs[prefix_mask] = -math.inf
        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
        )
        # if prefix includes eos, then we should make sure tokens and
        # scores are the same across all beams
        eos_mask = prefix_toks.eq(self.eos)
        if eos_mask.any():
            # validate that the first beam matches the prefix
            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
                :, 0, 1 : step + 1
            ]
            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
            assert (first_beam == target_prefix).all()

            # copy tokens, scores and lprobs from the first beam to all beams
            tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
            scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
            lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
        return lprobs, tokens, scores

    def replicate_first_beam(self, tensor, mask, beam_size: int):
        tensor = tensor.view(-1, beam_size, tensor.size(-1))
        tensor[mask] = tensor[mask][:, :1, :]
        return tensor.view(-1, tensor.size(-1))
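
    # replicate_first_beam in one picture: with beam_size=2, a sentence whose
    # beam rows are [r0, r1] and whose mask entry is True comes out as
    # [r0, r0] -- the first beam's row is broadcast over every beam, which is
    # what forcing an eos-terminated prefix requires.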

    def finalize_hypos(
        self,
        step: int,
        bbsz_idx,
        eos_scores,
        tokens,
        scores,
        finalized: List[List[Dict[str, Tensor]]],
        finished: List[bool],
        beam_size: int,
        attn: Optional[Tensor],
        src_lengths,
        max_len: int,
    ):
        """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
        A sentence is finalized when {beam_size} finished items have been collected for it.

        Returns number of sentences (not beam items) being finalized.
        These will be removed from the batch and not processed further.
        Args:
            bbsz_idx (Tensor):
        """
        assert bbsz_idx.numel() == eos_scores.numel()

        # clone relevant token and attention tensors.
        # tokens is (batch * beam, max_len). So the index_select
        # gets the newly EOS rows, then selects cols 1..{step + 2}
        tokens_clone = tokens.index_select(0, bbsz_idx)[
            :, 1 : step + 2
        ]  # skip the first index, which is BOS

        tokens_clone[:, step] = self.eos
        attn_clone = (
            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
            if attn is not None
            else None
        )

        # compute scores per token position
        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
        pos_scores[:, step] = eos_scores
        # convert from cumulative to per-position scores
        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

        # normalize sentence-level scores
        if self.normalize_scores:
            eos_scores /= (step + 1) ** self.len_penalty

        # cum_unfin records the cumulative number of unfinished sentences.
        # It helps match indexing between (a) the original sentences
        # in the batch and (b) the current, possibly-reduced set of
        # sentences.
        cum_unfin: List[int] = []
        prev = 0
        for f in finished:
            if f:
                prev += 1
            else:
                cum_unfin.append(prev)
        cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)

        unfin_idx = bbsz_idx // beam_size
        sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)

        # Create a set of "{sent}{unfin_idx}", where
        # "unfin_idx" is the index in the current (possibly reduced)
        # list of sentences, and "sent" is the index in the original,
        # unreduced batch
        seen = (sent << 32) + unfin_idx
        unique_seen: List[int] = torch.unique(seen).tolist()
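        # The packing above puts the original sentence index in the high 32
        # bits and the current (reduced) batch index in the low 32 bits, e.g.
        # sent=3, unfin_idx=1 -> (3 << 32) + 1. torch.unique then deduplicates
        # (sent, unfin_idx) pairs in a TorchScript-friendly way; the pair is
        # unpacked below with >> 32 and subtraction.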

        if self.match_source_len:
            condition = step > torch.index_select(src_lengths, 0, unfin_idx)
            eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
        sent_list: List[int] = sent.tolist()
        for i in range(bbsz_idx.size()[0]):
            # An input sentence (among those in a batch) is finished when
            # beam_size hypotheses have been collected for it
            if len(finalized[sent_list[i]]) < beam_size:
                if attn_clone is not None:
                    # remove padding tokens from attn scores
                    hypo_attn = attn_clone[i]
                else:
                    hypo_attn = torch.empty(0)

                finalized[sent_list[i]].append(
                    {
                        "tokens": tokens_clone[i],
                        "score": eos_scores[i],
                        "attention": hypo_attn,  # src_len x tgt_len
                        "alignment": torch.empty(0),
                        "positional_scores": pos_scores[i],
                    }
                )

        newly_finished: List[int] = []
        for unique_s in unique_seen:
            # check termination conditions for this sentence
            unique_sent: int = unique_s >> 32
            unique_unfin_idx: int = unique_s - (unique_sent << 32)

            if not finished[unique_sent] and self.is_finished(
                step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
            ):
                finished[unique_sent] = True
                newly_finished.append(unique_unfin_idx)

        return newly_finished

    def is_finished(
        self,
        step: int,
        unfin_idx: int,
        max_len: int,
        finalized_sent_len: int,
        beam_size: int,
    ):
        """
        Check whether decoding for a sentence is finished, which
        occurs when the list of finalized sentences has reached the
        beam size, or when we reach the maximum length.
        """
        assert finalized_sent_len <= beam_size
        if finalized_sent_len == beam_size or step == max_len:
            return True
        return False


class EnsembleModel(nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models_size = len(models)
        # method '__len__' is not supported in ModuleList for torch script
        self.single_model = models[0]
        self.models = nn.ModuleList(models)

        self.has_incremental: bool = False
        if all(
            hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
            for m in models
        ):
            self.has_incremental = True

    def forward(self):
        pass

    def has_encoder(self):
        return hasattr(self.single_model, "encoder")

    def has_incremental_states(self):
        return self.has_incremental

    def max_decoder_positions(self):
        return min(
            [
                m.max_decoder_positions()
                for m in self.models
                if hasattr(m, "max_decoder_positions")
            ]
            + [sys.maxsize]
        )

    @torch.jit.export
    def forward_encoder(self, net_input: Dict[str, Tensor]):
        if not self.has_encoder():
            return None
        return [model.encoder.forward_torchscript(net_input) for model in self.models]

    @torch.jit.export
    def forward_decoder(
        self,
        tokens,
        encoder_outs: List[Dict[str, List[Tensor]]],
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        temperature: float = 1.0,
        constraint_trie=None,
        constraint_start=None,
        constraint_end=None,
        gen_code=False,
        zero_shot=False,
        prefix_tokens=None,
    ):
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[Dict[str, List[Tensor]]] = None
        code_mask = (tokens.new_ones(tokens.size(0)) * gen_code).bool()
        for i, model in enumerate(self.models):
            if self.has_encoder():
                encoder_out = encoder_outs[i]
            # decode each model
            if self.has_incremental_states():
                decoder_out = model.decoder.forward(
                    tokens,
                    code_masks=code_mask,
                    encoder_out=encoder_out,
                    incremental_state=incremental_states[i],
                )
            else:
                if hasattr(model, "decoder"):
                    decoder_out = model.decoder.forward(
                        tokens, code_masks=code_mask, encoder_out=encoder_out
                    )
                else:
                    decoder_out = model.forward(tokens)

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):
                    attn = decoder_out[1]
                else:
                    attn_holder = decoder_out[1]["attn"]
                    if isinstance(attn_holder, Tensor):
                        attn = attn_holder
                    elif attn_holder is not None:
                        attn = attn_holder[0]
                if attn is not None:
                    attn = attn[:, -1, :]

            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else decoder_out[1],
            )

            beam_size = (
                decoder_out_tuple[0].size(0) // prefix_tokens.size(0)
                if prefix_tokens is not None
                else 0
            )
            if constraint_trie is not None and not zero_shot:
                assert constraint_start is None and constraint_end is None
                constraint_masks = decoder_out_tuple[0].new_zeros(
                    decoder_out_tuple[0].size()
                ).bool()
                constraint_prefix_tokens = tokens.tolist()
                for token_index, constraint_prefix_token in enumerate(
                    constraint_prefix_tokens
                ):
                    prefix_len = (
                        prefix_tokens[token_index // beam_size].ne(1).sum().item()
                        if prefix_tokens is not None
                        else 0
                    )
                    if len(constraint_prefix_token) > prefix_len:
                        constraint_prefix_token = [0] + constraint_prefix_token[
                            prefix_len + 1 :
                        ]
                        constraint_nodes = constraint_trie.get_next_layer(
                            constraint_prefix_token
                        )
                        constraint_masks[token_index][:, constraint_nodes] = True
                    else:
                        constraint_masks[token_index] = True
                decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf)
            if (
                constraint_start is not None
                and constraint_end is not None
                and not zero_shot
            ):
                assert constraint_trie is None
                decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf
                decoder_out_tuple[0][:, :, constraint_end:] = -math.inf

            probs = model.get_normalized_probs(
                decoder_out_tuple, log_probs=True, sample=None
            )
            if constraint_trie is not None and zero_shot:
                assert constraint_start is None and constraint_end is None
                constraint_masks = decoder_out_tuple[0].new_zeros(
                    decoder_out_tuple[0].size()
                ).bool()
                constraint_prefix_tokens = tokens.tolist()
                for token_index, constraint_prefix_token in enumerate(
                    constraint_prefix_tokens
                ):
                    constraint_nodes = constraint_trie.get_next_layer(
                        constraint_prefix_token
                    )
                    constraint_masks[token_index][:, constraint_nodes] = True
                probs.masked_fill_(~constraint_masks, -math.inf)
            if constraint_start is not None and constraint_end is not None and zero_shot:
                assert constraint_trie is None
                probs[:, :, 4:constraint_start] = -math.inf
                probs[:, :, constraint_end:] = -math.inf
            probs = probs[:, -1, :]
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
            self.models_size
        )
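        # The expression above is ensemble averaging in the log domain:
        # logsumexp over the stacked per-model log-probs minus log(N) equals
        # log((1/N) * sum_i p_i), the log of the mean ensemble probability.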

        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn

    @torch.jit.export
    def reorder_encoder_out(
        self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
    ):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        new_outs: List[Dict[str, List[Tensor]]] = []
        if not self.has_encoder():
            return new_outs
        for i, model in enumerate(self.models):
            assert encoder_outs is not None
            new_outs.append(
                model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
            )
        return new_outs

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order,
    ):
        if not self.has_incremental_states():
            return
        for i, model in enumerate(self.models):
            model.decoder.reorder_incremental_state_scripting(
                incremental_states[i], new_order
            )


class SequenceGeneratorWithAlignment(SequenceGenerator):
    def __init__(
        self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
    ):
        """Generates translations of a given source sentence.

        Produces alignments following "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            left_pad_target (bool, optional): Whether or not the
                hypothesis should be left padded or not when they are
                teacher forced for generating alignments.
        """
        super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
        self.left_pad_target = left_pad_target

        if print_alignment == "hard":
            self.extract_alignment = utils.extract_hard_alignment
        elif print_alignment == "soft":
            self.extract_alignment = utils.extract_soft_alignment

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        finalized = super()._generate(models, sample, **kwargs)

        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        beam_size = self.beam_size
        (
            src_tokens,
            src_lengths,
            prev_output_tokens,
            tgt_tokens,
        ) = self._prepare_batch_for_alignment(sample, finalized)
        if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
            attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
        else:
            attn = [
                finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
                for i in range(bsz * beam_size)
            ]

        if src_tokens.device.type != "cpu":
            src_tokens = src_tokens.to("cpu")
            tgt_tokens = tgt_tokens.to("cpu")
            attn = [i.to("cpu") for i in attn]

        # Process the attn matrix to extract hard alignments.
        for i in range(bsz * beam_size):
            alignment = self.extract_alignment(
                attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
            )
            finalized[i // beam_size][i % beam_size]["alignment"] = alignment
        return finalized

    def _prepare_batch_for_alignment(self, sample, hypothesis):
        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        src_tokens = (
            src_tokens[:, None, :]
            .expand(-1, self.beam_size, -1)
            .contiguous()
            .view(bsz * self.beam_size, -1)
        )
        src_lengths = sample["net_input"]["src_lengths"]
        src_lengths = (
            src_lengths[:, None]
            .expand(-1, self.beam_size)
            .contiguous()
            .view(bsz * self.beam_size)
        )
        prev_output_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=True,
        )
        tgt_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=False,
        )
        return src_tokens, src_lengths, prev_output_tokens, tgt_tokens


class EnsembleModelWithAlignment(EnsembleModel):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__(models)

    def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
        avg_attn = None
        for model in self.models:
            decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
            attn = decoder_out[1]["attn"][0]
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
        if len(self.models) > 1:
            avg_attn.div_(len(self.models))
        return avg_attn
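

# Usage sketch (hypothetical setup; "checkpoint.pt" is a placeholder, and
# loading is assumed to go through fairseq's checkpoint_utils rather than
# anything this module provides):
#
#     from fairseq import checkpoint_utils
#
#     models, _cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#         ["checkpoint.pt"]
#     )
#     generator = SequenceGenerator(models, task.target_dictionary, beam_size=5)
#     # sample is a batch dict with sample["net_input"]["src_tokens"], etc.
#     hypos = generator.generate(models, sample)
#     best = hypos[0][0]["tokens"]  # top-scoring hypothesis, first sentence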