# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import warnings

import torch

from fairseq import metrics, search, tokenizer, utils
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary


logger = logging.getLogger(__name__)


class FairseqTask(object):
    """
    Base class for tasks.

    A task stores dictionaries and loaded dataset splits, and offers helpers
    for loading/iterating over datasets, building the model, criterion and
    generator, and computing the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""

    @staticmethod
    def logging_outputs_can_be_summed(criterion) -> bool:
        """
        Tell whether the logging outputs returned by `train_step` and
        `valid_step` can be summed across workers prior to calling
        `aggregate_logging_outputs`. Setting this to True will improve
        distributed training speed.

        The decision is delegated to *criterion*.
        """
        can_be_summed = criterion.logging_outputs_can_be_summed()
        return can_be_summed

    def __init__(self, args):
        self.args = args
        # loaded dataset splits, keyed by split name (see `dataset`)
        self.datasets = {}
        # epoch iterators already built, keyed by dataset object
        # (see `get_batch_iterator`)
        self.dataset_to_epoch_iter = {}

    @classmethod
    def load_dictionary(cls, filename):
        """Load a dictionary from a file.

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        """Build a dictionary from the given files.

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final
                dictionary, including special symbols
            padding_factor (int): can be used to pad the dictionary size to
                be a multiple of 8, which is important on some hardware
                (e.g., Nvidia Tensor Cores).
        """
        dictionary = Dictionary()
        for path in filenames:
            Dictionary.add_file_to_dictionary(
                path, dictionary, tokenizer.tokenize_line, workers
            )
        dictionary.finalize(
            threshold=threshold, nwords=nwords, padding_factor=padding_factor
        )
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Set up the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        """
        Return True when ``args.data`` contains the OS path separator.

        *split* is accepted but not consulted by this base implementation.
        """
        data_spec = getattr(self.args, "data", "")
        return os.pathsep in data_spec

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split; must be implemented by subclasses.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a previously loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*

        Raises:
            KeyError: if *split* has not been loaded
            TypeError: if the stored object is not a FairseqDataset
        """
        from fairseq.data import FairseqDataset

        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)

        loaded = self.datasets[split]
        if isinstance(loaded, FairseqDataset):
            return loaded
        raise TypeError("Datasets are expected to be of type FairseqDataset")

    def filter_indices_by_size(self, indices, dataset, max_positions, ignore_invalid_inputs):
        """
        Drop examples that are too large.

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions: max sentence length supported by the model
            ignore_invalid_inputs (bool): when False, raise an Exception if
                any sample is rejected; when True, log a warning and skip
                the rejected samples

        Returns:
            np.array: array of filtered sample indices
        """
        kept, rejected = dataset.filter_indices_by_size(indices, max_positions)

        # len() rather than truthiness: `rejected` may be an array
        if len(rejected) == 0:
            return kept

        if not ignore_invalid_inputs:
            first_rejected = rejected[0]
            raise Exception((
                'Size of sample #{} is invalid (={}) since max_positions={}, '
                'skip this example with --skip-invalid-size-inputs-valid-test'
            ).format(first_rejected, dataset.size(first_rejected), max_positions))

        logger.warning((
            '{} samples have invalid sizes and will be skipped, '
            'max_positions={}, first few sample ids={}'
        ).format(len(rejected), max_positions, rejected[:10]))
        return kept

    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        The iterator is cached per dataset: once one has been built for
        *dataset*, later calls return that same iterator and the remaining
        arguments are not looked at.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size
                to be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for
                data loading. 0 means the data will be loaded in the main
                process (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over
                the given dataset split
        """
        # The default task treats datasets as static, so the same iterator
        # is handed out across epochs. Tasks with dynamic datasets can
        # override this method.
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # tell the dataset which epoch we are starting from
        dataset.set_epoch(epoch)

        # sample indices in the dataset's preferred order, seeded for
        # reproducibility
        with data_utils.numpy_seed(seed):
            ordered = dataset.ordered_indices()

        # size filtering only applies when a limit was given
        if max_positions is not None:
            ordered = self.filter_indices_by_size(
                ordered, dataset, max_positions, ignore_invalid_inputs
            )

        # group the indices into mini-batches under the size constraints
        batches = dataset.batch_by_size(
            ordered,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # build the reusable, sharded iterator and remember it
        buffer_size = getattr(self.args, 'data_buffer_size', 0)
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batches,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=buffer_size,
        )
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter

    def build_model(self, args):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(args, self)
        if getattr(args, 'tpu', False):
            model.prepare_for_tpu_()
        return quantization_utils.quantize_model_scalar(model, args)

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(args, self)

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        """
        Build the object used for generation (or reference scoring).

        When ``args.score_reference`` is set a ``SequenceScorer`` is
        returned. Otherwise a search strategy is chosen from *args*
        (sampling, diverse beam search, length-constrained beam search,
        diverse siblings search, or plain beam search by default) and a
        sequence generator is constructed around it.

        Args:
            models: models handed to the generator constructor
            args (argparse.Namespace): generation options, read via getattr
                with defaults
            seq_gen_cls (optional): generator class to instantiate; when
                None, ``SequenceGeneratorWithAlignment`` is used if
                ``args.print_alignment`` is set, else ``SequenceGenerator``
            extra_gen_cls_kwargs (dict, optional): extra keyword arguments
                forwarded to the generator constructor

        Raises:
            ValueError: if more than one of the mutually exclusive search
                options is enabled
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )

        # Read the options that decide the search strategy.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)

        # At most one of these strategies may be requested at a time.
        exclusive_options = [
            sampling,
            diverse_beam_groups > 0,
            match_source_len,
            diversity_rate > 0,
        ]
        num_selected = sum(int(option) for option in exclusive_options)
        if num_selected > 1:
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        # Pick the search strategy; beam search is the fallback.
        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # Useful for tagging applications where the output length should
            # equal the input length, so the length constraints are
            # hardcoded for simplicity.
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Resolve the generator class unless the caller supplied one.
        generator_cls = seq_gen_cls
        if generator_cls is None:
            generator_cls = (
                SequenceGeneratorWithAlignment
                if getattr(args, "print_alignment", False)
                else SequenceGenerator
            )
        extra_kwargs = extra_gen_cls_kwargs or {}

        return generator_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by
        *criterion* for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)

        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = criterion(model, sample)

        if ignore_grad:
            # in-place scaling: the backward pass below still runs
            loss *= 0

        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)

        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """
        Put *model* in eval mode and compute the criterion on *sample* with
        gradient tracking disabled.

        Returns:
            tuple: the loss, the sample size and the logging output, as
            produced by *criterion*
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        """
        Run ``generator.generate`` on *sample* with gradient tracking
        disabled and return its result.
        """
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """[deprecated] Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        with metrics.aggregate() as agg:
            self.reduce_metrics(logging_outputs, criterion)
            smoothed_values = agg.get_smoothed_values()
        return smoothed_values

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # Backward compatibility: a task that overrides the deprecated
        # aggregate_logging_outputs gets its override called, and the
        # resulting values are logged as plain scalars.
        base_impl = FairseqTask.aggregate_logging_outputs
        own_impl = getattr(self, "aggregate_logging_outputs").__func__
        if own_impl is not base_impl:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            aggregated = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for name, value in aggregated.items():
                metrics.log_scalar(name, value)
            return

        # words-per-batch / words-per-second need token counts
        if any("ntokens" in log for log in logging_outputs):
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)
        else:
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )

        # batch size needs sentence counts
        if any("nsentences" in log for log in logging_outputs):
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)
        else:
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )

        # the criterion class logs its own metrics
        criterion.__class__.reduce_metrics(logging_outputs)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError