Spaces:
Runtime error
Runtime error
Dit-document-layout-analysis
/
unilm
/decoding
/GAD
/block_plugins
/tasks
/translation_lev_modified.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from dataclasses import dataclass, field | |
| from math import log | |
| import torch | |
| from fairseq import utils | |
| from fairseq.data import LanguagePairDataset | |
| from fairseq.dataclass import ChoiceEnum | |
| from fairseq.tasks import register_task | |
| from fairseq.tasks.translation import TranslationConfig, TranslationTask, load_langpair_dataset | |
| from fairseq.utils import new_arange | |
| import logging | |
| from omegaconf import II | |
| import numpy as np | |
# Names of the noise schemes that ``inject_noise`` knows how to apply.
NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask", "block_mask"])


# The decorator is required: without it the ``field(...)`` declarations below
# stay raw ``dataclasses.Field`` objects instead of becoming config fields
# with the stated defaults.
@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
    """Task options added on top of ``TranslationConfig``.

    ``train_step`` combines ``start_p``, ``minus_p`` and ``total_up`` into
    ``context_p = start_p - minus_p * min(1, update_num / total_up)``.
    """

    # Which noise scheme ``inject_noise`` applies to the targets.
    noise: NOISE_CHOICES = field(
        default="random_delete",
        metadata={"help": "type of noise"},
    )
    # Value of ``context_p`` at update 0.
    start_p: float = field(
        default=0.5, metadata={"help": "start prob"}
    )
    # Total amount subtracted from ``start_p`` once ``total_up`` updates are done.
    minus_p: float = field(
        default=0.2, metadata={"help": "minus prob"}
    )
    # Number of updates over which ``context_p`` is annealed.
    total_up: int = field(
        default=300000, metadata={"help": "total updates"}
    )
    # Number of ``<unk>`` placeholders the "block_mask" scheme appends per row.
    block_size: int = field(
        default=5, metadata={"help": "block size"}
    )
| logger = logging.getLogger(__name__) | |
| class TranslationLevenshteinModifiedTask(TranslationTask): | |
| """ | |
| Translation (Sequence Generation) task for Levenshtein Transformer | |
| See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_. | |
| """ | |
| cfg: TranslationLevenshteinConfig | |
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load one dataset split and store it in ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): selects the data path to read; epoch 1 maps to the
            first configured path and later epochs cycle through the rest.
        combine (bool): forwarded unchanged to ``load_langpair_dataset``.
        **kwargs: accepted for signature compatibility and ignored.
    """
    cfg = self.cfg
    data_dirs = utils.split_paths(cfg.data)
    assert len(data_dirs) > 0
    # One data directory per epoch, wrapping around when they run out.
    chosen_dir = data_dirs[(epoch - 1) % len(data_dirs)]

    # Language codes come straight from the task configuration.
    src_lang = cfg.source_lang
    tgt_lang = cfg.target_lang

    loader_options = dict(
        combine=combine,
        dataset_impl=cfg.dataset_impl,
        upsample_primary=cfg.upsample_primary,
        left_pad_source=cfg.left_pad_source,
        left_pad_target=cfg.left_pad_target,
        max_source_positions=cfg.max_source_positions,
        max_target_positions=cfg.max_target_positions,
        truncate_source=cfg.truncate_source,
    )
    self.datasets[split] = load_langpair_dataset(
        chosen_dir,
        split,
        src_lang,
        self.src_dict,
        tgt_lang,
        self.tgt_dict,
        **loader_options,
    )
def inject_noise(self, target_tokens):
    """Corrupt a batch of target tokens according to ``self.cfg.noise``.

    Args:
        target_tokens (torch.Tensor): 2-D tensor of token ids (the helpers
            reduce and sort along dimension 1), compared against the
            special symbols of ``self.tgt_dict``.

    Returns:
        For ``"block_mask"`` a tuple ``(prev_target_tokens,
        padded_target_tokens)`` of equally shaped tensors. For every other
        scheme a single tensor; ``"no_noise"`` returns the input itself.

    Raises:
        NotImplementedError: if ``self.cfg.noise`` is not a known scheme.
    """

    def _random_delete(target_tokens):
        # Drop a random subset of ordinary tokens and close the gaps.
        pad = self.tgt_dict.pad()
        bos = self.tgt_dict.bos()
        eos = self.tgt_dict.eos()

        max_len = target_tokens.size(1)
        target_mask = target_tokens.eq(pad)
        target_score = target_tokens.clone().float().uniform_()
        target_score.masked_fill_(
            target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
        )
        # Padding gets the highest score so it sorts after every real token.
        target_score.masked_fill_(target_mask, 1)
        target_score, target_rank = target_score.sort(1)
        target_length = target_mask.size(1) - target_mask.float().sum(
            1, keepdim=True
        )

        # do not delete <bos> and <eos> (we assign 0 score for them)
        target_cutoff = (
            2
            + (
                (target_length - 2)
                * target_score.new_zeros(target_score.size(0), 1).uniform_()
            ).long()
        )
        target_cutoff = target_score.sort(1)[1] >= target_cutoff

        # Blank the tokens past the cutoff, then restore the original order
        # of the survivors (deleted slots are pushed to the end).
        prev_target_tokens = (
            target_tokens.gather(1, target_rank)
            .masked_fill_(target_cutoff, pad)
            .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
        )
        # Trim columns that are padding in every row.
        prev_target_tokens = prev_target_tokens[
            :, : prev_target_tokens.ne(pad).sum(1).max()
        ]

        return prev_target_tokens

    def _random_mask(target_tokens):
        # Replace a random number of ordinary tokens with <unk>.
        pad = self.tgt_dict.pad()
        bos = self.tgt_dict.bos()
        eos = self.tgt_dict.eos()
        unk = self.tgt_dict.unk()

        target_masks = (
            target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
        )
        target_score = target_tokens.clone().float().uniform_()
        # Special symbols score above any random draw, so they rank last.
        target_score.masked_fill_(~target_masks, 2.0)
        target_length = target_masks.sum(1).float()
        target_length = target_length * target_length.clone().uniform_()
        target_length = target_length + 1  # make sure to mask at least one token.

        _, target_rank = target_score.sort(1)
        target_cutoff = new_arange(target_rank) < target_length[:, None].long()
        prev_target_tokens = target_tokens.masked_fill(
            target_cutoff.scatter(1, target_rank, target_cutoff), unk
        )
        return prev_target_tokens

    def _full_mask(target_tokens):
        # Replace everything except <bos>, <eos> and padding with <unk>.
        pad = self.tgt_dict.pad()
        bos = self.tgt_dict.bos()
        eos = self.tgt_dict.eos()
        unk = self.tgt_dict.unk()

        target_mask = (
            target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
        )
        return target_tokens.masked_fill(~target_mask, unk)

    def _block_mask(target_tokens):
        # Keep a random-length prefix of each row and append ``block_size``
        # <unk> placeholders after it. Also return the targets padded and
        # trimmed to the same width as the noised tensor.
        # NOTE(review): the prefix copy assumes right-padded targets - confirm.
        block_size = self.cfg.block_size
        pad = self.tgt_dict.pad()
        unk = self.tgt_dict.unk()
        bsz, max_len = target_tokens.size(0), target_tokens.size(1)

        target_length = target_tokens.ne(pad).sum(1).float()
        cutoff_length = target_length * target_length.clone().uniform_()
        cutoff_length = cutoff_length.int() + 1  # make sure to mask at least one token.
        # Length of the kept prefix per row. Clamp so an all-pad row yields
        # an empty prefix rather than a negative slice bound. One tolist()
        # replaces a tensor-to-int conversion on every loop iteration.
        remain_lengths = (
            (target_length.int() - cutoff_length).clamp(min=0).tolist()
        )

        # Fill with the dictionary's real pad index (the previous
        # ``torch.ones`` silently assumed pad == 1), directly in the
        # dtype and on the device of the input.
        prev_target_tokens = target_tokens.new_full((bsz, max_len + block_size), pad)
        padded_target_tokens = target_tokens.new_full((bsz, max_len + block_size), pad)
        padded_target_tokens[:, :max_len] = target_tokens

        for row, remain_length in enumerate(remain_lengths):
            prev_target_tokens[row, :remain_length] = target_tokens[row, :remain_length]
            prev_target_tokens[row, remain_length:remain_length + block_size] = unk

        # Both outputs are cut to the longest non-pad row of the noised tensor.
        width = int(prev_target_tokens.ne(pad).sum(1).max())
        return prev_target_tokens[:, :width], padded_target_tokens[:, :width]

    if self.cfg.noise == "random_delete":
        return _random_delete(target_tokens)
    elif self.cfg.noise == "random_mask":
        return _random_mask(target_tokens)
    elif self.cfg.noise == "block_mask":
        return _block_mask(target_tokens)
    elif self.cfg.noise == "full_mask":
        return _full_mask(target_tokens)
    elif self.cfg.noise == "no_noise":
        return target_tokens
    else:
        raise NotImplementedError
def build_generator(self, models, args, **unused):
    """Create the iterative-refinement generator used for decoding.

    Args:
        models: not used; present so the signature matches the API for
            SequenceGenerator.
        args: object whose ``iter_decode_*``, ``decoding_format`` and
            ``retain_iter_history`` attributes configure the generator.
            Missing attributes fall back to the defaults listed below.
        **unused: ignored.

    Returns:
        An ``IterativeRefinementGenerator`` over ``self.target_dictionary``.
    """
    from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

    def opt(name, fallback):
        # Read an optional setting from ``args``.
        return getattr(args, name, fallback)

    # Decoding stops adaptively unless the maximum iteration count is forced.
    force_max_iter = opt("iter_decode_force_max_iter", False)

    return IterativeRefinementGenerator(
        self.target_dictionary,
        eos_penalty=opt("iter_decode_eos_penalty", 0.0),
        max_iter=opt("iter_decode_max_iter", 10),
        beam_size=opt("iter_decode_with_beam", 1),
        reranking=opt("iter_decode_with_external_reranker", False),
        decoding_format=opt("decoding_format", None),
        adaptive=not force_max_iter,
        retain_history=opt("retain_iter_history", False),
    )
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    """Wrap source sentences in a dataset suitable for generation.

    Args:
        src_tokens: source token sequences, passed to ``LanguagePairDataset``.
        src_lengths: lengths matching ``src_tokens``.
        constraints: must be ``None``; constrained decoding is unsupported.

    Returns:
        A source-only ``LanguagePairDataset`` built without a prepended BOS.

    Raises:
        NotImplementedError: if ``constraints`` is given.
    """
    if constraints is None:
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            append_bos=False,
        )

    # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
    raise NotImplementedError(
        "Constrained decoding with the translation_lev task is not supported"
    )
def train_step(
    self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
    """Run one training step: noise the targets, then forward and backward.

    Args:
        sample (dict): mini-batch; ``sample["target"]`` is noised and the
            keys ``"glat"`` and ``"prev_target"`` are written into it.
        model: model to train; switched to training mode.
        criterion: callable ``criterion(model, sample)`` returning
            ``(loss, sample_size, logging_output)``.
        optimizer: object whose ``backward(loss)`` runs the backward pass.
        update_num (int): current update count, used to anneal ``context_p``.
        ignore_grad (bool): when true the loss is multiplied by zero
            before the backward pass.

    Returns:
        tuple: ``(loss, sample_size, logging_output)`` from the criterion.
    """
    model.train()

    # Fraction of the annealing schedule completed, clipped to [0, 1].
    train_ratio = max(0, min(1, update_num / self.cfg.total_up))
    sample["glat"] = {"context_p": self.cfg.start_p - self.cfg.minus_p * train_ratio}

    # Only the "block_mask" scheme returns a (prev_target, target) pair; the
    # other schemes return the noised tensor alone, and unpacking that would
    # raise (or silently split a batch of exactly two rows).
    noised = self.inject_noise(sample["target"])
    if isinstance(noised, tuple):
        sample["prev_target"], sample["target"] = noised
    else:
        sample["prev_target"] = noised

    with torch.autograd.profiler.record_function("forward"):
        loss, sample_size, logging_output = criterion(model, sample)
    if ignore_grad:
        loss *= 0
    with torch.autograd.profiler.record_function("backward"):
        optimizer.backward(loss)
    return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
    """Run one validation step and optionally collect BLEU statistics.

    Args:
        sample (dict): mini-batch; ``sample["target"]`` is noised and
            ``"prev_target"`` is written into it.
        model: model to evaluate; switched to evaluation mode.
        criterion: callable ``criterion(model, sample)`` returning
            ``(loss, sample_size, logging_output)``.

    Returns:
        tuple: ``(loss, sample_size, logging_output)``. When
        ``self.cfg.eval_bleu`` is set, ``logging_output`` also carries the
        ``_bleu_*`` length, count and total entries.
    """
    model.eval()
    with torch.no_grad():
        # Only the "block_mask" scheme returns a (prev_target, target) pair;
        # the other schemes return the noised tensor alone, and unpacking
        # that would raise (or silently split a batch of exactly two rows).
        noised = self.inject_noise(sample["target"])
        if isinstance(noised, tuple):
            sample["prev_target"], sample["target"] = noised
        else:
            sample["prev_target"] = noised
        loss, sample_size, logging_output = criterion(model, sample)

    EVAL_BLEU_ORDER = 4
    if self.cfg.eval_bleu:
        # NOTE(review): under "block_mask" ``sample["target"]`` has already
        # been replaced by the trimmed targets at this point, so the BLEU
        # reference is built from those - confirm this is intended.
        bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
        logging_output["_bleu_sys_len"] = bleu.sys_len
        logging_output["_bleu_ref_len"] = bleu.ref_len
        # we split counts into separate entries so that they can be
        # summed efficiently across workers using fast-stat-sync
        assert len(bleu.counts) == EVAL_BLEU_ORDER
        for order in range(EVAL_BLEU_ORDER):
            logging_output["_bleu_counts_" + str(order)] = bleu.counts[order]
            logging_output["_bleu_totals_" + str(order)] = bleu.totals[order]
    return loss, sample_size, logging_output
def _inference_with_bleu(self, generator, sample, model):
    """Generate hypotheses for ``sample`` and score them with sacreBLEU.

    Args:
        generator: generator handed to ``self.inference_step``.
        sample (dict): mini-batch; ``sample["target"]`` supplies the references.
        model: the single model used for generation.

    Returns:
        The object returned by ``sacrebleu.corpus_bleu`` for this batch.
    """
    import sacrebleu

    def to_text(token_ids, escape_unk=False):
        # The default unknown string in fairseq is `<unk>`, but
        # this is tokenized by sacrebleu as `< unk >`, inflating
        # BLEU scores. Instead, we use a somewhat more verbose
        # alternative that is unlikely to appear in the real
        # reference, but doesn't get split into multiple tokens.
        unk_placeholder = "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"
        text = self.tgt_dict.string(
            token_ids.int().cpu(),
            self.cfg.eval_bleu_remove_bpe,
            unk_string=unk_placeholder,
        )
        return self.tokenizer.decode(text) if self.tokenizer else text

    gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)

    hyps = []
    refs = []
    for idx, candidates in enumerate(gen_out):
        # Only the first hypothesis of each sentence is scored.
        hyps.append(to_text(candidates[0]["tokens"]))
        reference_ids = utils.strip_pad(sample["target"][idx], self.tgt_dict.pad())
        # don't count <unk> as matches to the hypo
        refs.append(to_text(reference_ids, escape_unk=True))

    if self.cfg.eval_bleu_print_samples:
        logger.info("example hypothesis: " + hyps[0])
        logger.info("example reference: " + refs[0])

    # Already-tokenized evaluation turns off sacreBLEU's own tokenizer.
    bleu_options = {"tokenize": "none"} if self.cfg.eval_tokenized_bleu else {}
    return sacrebleu.corpus_bleu(hyps, [refs], **bleu_options)