import datetime
import logging
import time

import torch
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    ListDataset,
    data_utils,
    iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
    MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction


def get_time_gap(s, e):
    return str(
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    )
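

# Example (illustrative, hypothetical timestamps): the local-timezone offsets
# cancel in the subtraction, so this is just str() of a datetime.timedelta:
#
#   get_time_gap(1.0, 91.5)  # -> "0:01:30.500000"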


logger = logging.getLogger(__name__)


@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of languages that are being supported
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported
            languages to their dictionaries
        training (bool): whether the task should be configured for training
            or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
                            action=FileContentsAction)
        parser.add_argument('--keep-inference-langtok', action='store_true',
                            help='keep language tokens in inference output (e.g. for analysis or debugging)')

        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
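
    # Illustrative invocation (a sketch; the exact flags come from
    # SamplingMethod and MultilingualDatasetManager and may differ across
    # fairseq versions):
    #
    #   fairseq-train data-bin \
    #       --task translation_multi_simple_epoch \
    #       --lang-pairs en-de,en-fr \
    #       --encoder-langtok src --decoder-langtok \
    #       --sampling-method temperature --sampling-temperature 1.5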

    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            # inference: a single pair built from --source-lang/--target-lang
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs is usually all of lang_pairs for multilingual
        # translation; subclasses (e.g. multitask setups) may override it to
        # evaluate on a different subset
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs is used when building encoder-decoder pairs in
        # build_model(); subclasses may build models for pairs other than the
        # input lang_pairs
        self.model_lang_pairs = self.lang_pairs
        self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
        self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
        self.check_dicts(self.dicts, self.source_langs, self.target_langs)

        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method
        )

    def check_dicts(self, dicts, source_langs, target_langs):
        if self.args.source_dict is not None or self.args.target_dict is not None:
            # separate source/target dictionaries were supplied explicitly,
            # so there is no shared-dictionary invariant to check
            return
        src_dict = dicts[source_langs[0]]
        tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source "
                "languages; TranslationMultiSimpleEpochTask only supports one "
                "shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target "
                "languages; TranslationMultiSimpleEpochTask only supports one "
                "shared dictionary across all target languages"
            )

    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(
            cls.load_dictionary, args, **kwargs
        )
        return cls(args, langs, dicts, training)
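
    # Construction sketch (illustrative; `args` is a parsed argument
    # namespace with the multilingual data options set):
    #
    #   from fairseq import tasks
    #   task = tasks.setup_task(args)  # dispatches to setup_task() above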

    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        # shard_epoch may legitimately stay None for non-sharded data;
        # initialize it so the logging below never hits an unbound local
        shard_epoch = None
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard; returning here also
                        # avoids always reloading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from the virtual data size and the
            # virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
        logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
        logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        if split in self.datasets:
            del self.datasets[split]
            logger.info("old dataset deleted manually")
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        self.datasets[split] = self.data_manager.load_dataset(
            split,
            self.training,
            epoch=epoch,
            combine=combine,
            shard_epoch=shard_epoch,
            **kwargs,
        )
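
    # Sketch of typical use (illustrative; the fairseq trainer normally calls
    # this for you once per epoch):
    #
    #   task.load_dataset("train", epoch=1)
    #   train_data = task.datasets["train"]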

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the translation_multi_simple_epoch "
                "task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
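
    # Illustrative use from interactive decoding (sketch; `lines` are raw
    # source sentences and `task` is a configured instance of this class):
    #
    #   tokens = [task.source_dictionary.encode_line(l) for l in lines]
    #   lengths = [t.numel() for t in tokens]
    #   dataset = task.build_dataset_for_inference(tokens, lengths)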

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not getattr(args, "keep_inference_langtok", False):
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if tgt_langtok_spec:
                tgt_lang_tok = self.data_manager.get_decoder_langtok(
                    self.args.target_lang, tgt_langtok_spec
                )
                extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
                extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}

        # forward seq_gen_cls instead of discarding it (the previous code
        # always passed seq_gen_cls=None to the parent)
        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )

    def build_model(self, args, from_checkpoint=False):
        return super().build_model(args, from_checkpoint)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    # force the target-language token to be the first decoded
                    # token by passing it as a per-sentence prefix
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    src_tokens = sample["net_input"]["src_tokens"]
                    bsz = src_tokens.size(0)
                    prefix_tokens = (
                        torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
                    )
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
            else:
                # the language token replaces bos/eos, so pass it as bos_token
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    if tgt_langtok_spec
                    else self.target_dictionary.eos(),
                )
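
    # Decoding sketch (illustrative; `generator` comes from
    # task.build_generator(models, args) and `sample` from a batch iterator):
    #
    #   hypos = task.inference_step(generator, models, sample)
    #   best_tokens = hypos[0][0]["tokens"]  # top hypothesis, first sentence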

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        return self.data_manager.get_source_dictionary(self.source_langs[0])

    @property
    def target_dictionary(self):
        return self.data_manager.get_target_dictionary(self.target_langs[0])

    def create_batch_sampler_func(
        self,
        max_positions,
        ignore_invalid_inputs,
        max_tokens,
        max_sentences,
        required_batch_size_multiple=1,
        seed=1,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [
                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
            ]
            split = splits[0] if len(splits) > 0 else None
            # initialize the dataset with the correct starting epoch
            if epoch is not None:
                dataset.set_epoch(epoch)

            start_time = time.time()
            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")

            # get indices ordered by example size
            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # filter out examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs
                )
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # create mini-batches with the given size constraints
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler

        return construct_batch_sampler
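
    # The returned closure is handed to iterators.EpochBatchIterator below,
    # which invokes it as construct_batch_sampler(dataset, epoch) once per
    # epoch. Sketch (hypothetical sizes):
    #
    #   make_sampler = task.create_batch_sampler_func(
    #       max_positions=(1024, 1024), ignore_invalid_inputs=True,
    #       max_tokens=4096, max_sentences=None,
    #   )
    #   batches = make_sampler(task.datasets["train"], epoch=1)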

    # we need to override get_batch_iterator because we want to reset the
    # epoch iterator each time
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            skip_remainder_batch (bool, optional): if set, discard the last
                batch in each training epoch, as it is often smaller than the
                requested batch size (default: False).
            grouped_shuffling (bool, optional): group batches with each group
                containing num_shards batches and shuffle groups. Reduces
                difference between sequence lengths among workers for batches
                sorted by length (default: False).
            update_epoch_batch_itr (bool, optional): if true then do not use
                the cached batch iterator for the epoch (default: False).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
        """
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        if self.args.sampling_method == "RoundRobin":
            batch_iter = super().get_batch_iterator(
                dataset,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs,
                required_batch_size_multiple=required_batch_size_multiple,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=num_workers,
                epoch=epoch,
                data_buffer_size=data_buffer_size,
                disable_iterator_cache=disable_iterator_cache,
                skip_remainder_batch=skip_remainder_batch,
                update_epoch_batch_itr=update_epoch_batch_itr,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions,
            ignore_invalid_inputs,
            max_tokens,
            max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
        )

        # note: the iterator is deliberately not cached here, so the batch
        # sampler is rebuilt (via the closure above) at each epoch
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
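
    # End-to-end sketch (illustrative; assumes `models`, `args`, and
    # binarized data for the configured language pairs already exist):
    #
    #   task.load_dataset("valid")
    #   generator = task.build_generator(models, args)
    #   itr = task.get_batch_iterator(
    #       dataset=task.datasets["valid"], max_tokens=4096,
    #   ).next_epoch_itr(shuffle=False)
    #   for sample in itr:
    #       hypos = task.inference_step(generator, models, sample)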