| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
| import logging |
| import os |
| import math |
| import torch |
| from typing import Dict, Optional |
|
|
| from fairseq import search |
| from fairseq.data import FairseqDataset, iterators |
| from fairseq.optim.amp_optimizer import AMPOptimizer |
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.tasks import FairseqTask, register_task |
| from omegaconf import DictConfig |
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class OFAConfig(FairseqDataclass):
    """Configuration options for the ``ofa`` task.

    Each option carries a ``help`` entry in its field metadata. Declaration
    order matters for a dataclass and is kept exactly as originally defined.
    """

    # --- data sources ----------------------------------------------------
    data: Optional[str] = field(
        metadata=dict(
            help=(
                "comma separated path to data list, will be iterated upon during epochs "
                "in round-robin manner; valid data are always in the last"
            ),
        ),
        default=None,
    )
    selected_cols: Optional[str] = field(
        metadata=dict(help="selected cols"),
        default=None,
    )
    # Directory expected to hold dict.txt (loaded by OFATask.setup_task) and
    # the GPT-2 BPE files encoder.json / vocab.bpe (used by OFATask.build_model).
    bpe_dir: Optional[str] = field(
        metadata=dict(help="bpe dir"),
        default=None,
    )

    # --- length limits ---------------------------------------------------
    # The *_positions pair is what OFATask.max_positions() returns.
    max_source_positions: int = field(
        metadata=dict(help="max number of tokens in the source sequence"),
        default=1024,
    )
    max_target_positions: int = field(
        metadata=dict(help="max number of tokens in the target sequence"),
        default=1024,
    )
    max_src_length: int = field(
        metadata=dict(help="the maximum src sequence length"),
        default=128,
    )
    max_tgt_length: int = field(
        metadata=dict(help="the maximum target sequence length"),
        default=30,
    )

    # --- vocabulary extensions and model options -------------------------
    # OFATask.setup_task appends code_dict_size "<code_i>" symbols and
    # num_bins "<bin_i>" symbols to both dictionaries.
    code_dict_size: int = field(
        metadata=dict(help="code dict size"),
        default=8192,
    )
    patch_image_size: int = field(
        metadata=dict(help="patch image size"),
        default=480,
    )
    num_bins: int = field(
        metadata=dict(help="number of quantization bins"),
        default=1000,
    )

    # --- preprocessing / decoding ----------------------------------------
    imagenet_default_mean_and_std: bool = field(
        metadata=dict(help="imagenet normalize"),
        default=False,
    )
    # Handed to the sequence generator by OFATask.build_generator.
    constraint_range: Optional[str] = field(
        metadata=dict(help="constraint range"),
        default=None,
    )
|
|
|
|
@register_task("ofa", dataclass=OFAConfig)
class OFATask(FairseqTask):
    """Base fairseq task for OFA models.

    What this class does on top of :class:`FairseqTask`:

    * ``setup_task`` builds source/target dictionaries from
      ``<bpe_dir>/dict.txt`` and appends ``<mask>``, ``<code_i>`` and
      ``<bin_i>`` symbols to both.
    * ``get_batch_iterator`` batches a dataset into consecutive fixed-size
      index ranges instead of fairseq's size-aware batching.
    * ``build_model`` additionally attaches a GPT-2 BPE as ``self.bpe``.
    * ``build_generator`` builds a search strategy and a sequence generator,
      forwarding ``cfg.constraint_range`` to generators that accept it.
    """

    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        """Store *cfg* (through ``FairseqTask``) and the two dictionaries."""
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task.

        Loads ``dict.txt`` from ``cfg.bpe_dir`` twice, giving independent
        source and target dictionaries. Appends the same symbols, in the same
        order, to each: ``<mask>``, then ``<code_0>`` ..
        ``<code_{code_dict_size-1}>``, then ``<bin_0>`` ..
        ``<bin_{num_bins-1}>``.

        Args:
            cfg: task configuration (fields of :class:`OFAConfig`).
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A new instance of *cls*.

        Raises:
            ValueError: if ``cfg.bpe_dir`` is not set.
        """
        if cfg.bpe_dir is None:
            # os.path.join(None, ...) would otherwise fail with an opaque TypeError.
            raise ValueError("OFA task requires bpe_dir (a directory containing dict.txt)")

        dict_path = os.path.join(cfg.bpe_dir, "dict.txt")
        src_dict = cls.load_dictionary(dict_path)
        tgt_dict = cls.load_dictionary(dict_path)

        # Both dictionaries must receive identical symbols in identical order
        # so that the extra token ids agree between source and target.
        for dictionary in (src_dict, tgt_dict):
            dictionary.add_symbol("<mask>")
            for i in range(cfg.code_dict_size):
                dictionary.add_symbol("<code_{}>".format(i))
            for i in range(cfg.num_bins):
                dictionary.add_symbol("<bin_{}>".format(i))

        logger.info("source dictionary: {} types".format(len(src_dict)))
        logger.info("target dictionary: {} types".format(len(tgt_dict)))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Return an ``iterators.EpochBatchIterator`` over *dataset*.

        Batches are consecutive index ranges of ``max_sentences`` items (the
        last one may be shorter). No length filtering or token-based batching
        is done. ``max_tokens``, ``max_positions``, ``ignore_invalid_inputs``,
        ``required_batch_size_multiple``, ``shard_id`` and
        ``disable_iterator_cache`` are accepted only for interface
        compatibility and are ignored.

        ``num_shards`` is used only to compute how many batches the largest
        shard produces: ``ceil(ceil(total_rows / num_shards) / max_sentences)``.
        This shard's batch list is padded with empty batches up to that
        count, so every worker yields the same number of batches.

        The iterator is built with ``num_shards=1, shard_id=0``, presumably
        because *dataset* already holds only this worker's slice.
        TODO confirm against the dataset implementation.

        Raises:
            ValueError: if ``max_sentences`` is ``None`` or not positive.
        """
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        if max_sentences is None or max_sentences <= 0:
            # range(..., step) below needs a positive integer batch size.
            raise ValueError(
                "OFA task requires a positive batch size (max_sentences), got {!r}".format(
                    max_sentences
                )
            )

        dataset_len = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, dataset_len)))
            for start in range(0, dataset_len, max_sentences)
        ]

        # Assumes the wrapped dataset exposes get_total_row_count() for the
        # whole (unsharded) data; TODO confirm.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        # Pad (possibly by more than one batch) so all workers stay in lockstep.
        while len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )
        return epoch_iter

    def build_model(self, cfg: FairseqDataclass):
        """Build the model via ``FairseqTask.build_model``.

        Also sets ``self.bpe`` to a GPT-2 BPE built from ``encoder.json`` and
        ``vocab.bpe`` inside ``self.cfg.bpe_dir``.
        """
        model = super().build_model(cfg)
        bpe_dict = DictConfig(
            {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe"),
            }
        )
        self.bpe = self.build_bpe(bpe_dict)
        return model

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        If ``args.score_reference`` is set, a
        ``fairseq.sequence_scorer.SequenceScorer`` is returned instead.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls (type, optional): generator class to instantiate.
                Defaults to fairseq's ``SequenceGeneratorWithAlignment`` when
                ``args.print_alignment`` is set, otherwise to
                ``models.sequence_generator.SequenceGenerator``.
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to the generator. The caller's dict is not modified.
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.

        Raises:
            ValueError: if more than one of sampling, diverse beam groups,
                match_source_len and a positive diversity_rate is requested.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGeneratorWithAlignment,
        )
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)

        exclusive_options = [
            sampling,
            diverse_beam_groups > 0,
            match_source_len,
            diversity_rate > 0,
        ]
        if sum(int(opt) for opt in exclusive_options) > 1:
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            # NOTE(review): threshold is > -1 here but > 0 in the exclusivity
            # check above, so diversity_rate == 0 selects this branch and
            # shadows the constrained / prefix branches below. Preserved as-is.
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Copy so the caller's dict is never mutated.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        # constraint_range is an OFA-specific generator option. fairseq's
        # SequenceGeneratorWithAlignment forwards its kwargs to fairseq's own
        # SequenceGenerator, which does not accept it (TypeError). Callers may
        # override the configured value through extra_gen_cls_kwargs.
        is_alignment_generator = isinstance(seq_gen_cls, type) and issubclass(
            seq_gen_cls, SequenceGeneratorWithAlignment
        )
        if not is_alignment_generator:
            extra_gen_cls_kwargs.setdefault("constraint_range", self.cfg.constraint_range)

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=match_source_len,
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        The forward pass runs under ``torch.cuda.amp.autocast``, enabled only
        when *optimizer* is an ``AMPOptimizer``. ``update_num`` is forwarded to
        the criterion. ``extra_kwargs`` are accepted and ignored.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task.

        A ``(cfg.max_source_positions, cfg.max_target_positions)`` tuple.
        """
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict
|
|