from collections import OrderedDict, defaultdict
import json
import logging
import os
from argparse import ArgumentError

from fairseq import models, options
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from fairseq.tasks import LegacyFairseqTask, register_task

from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)

logger = logging.getLogger(__name__)

@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        try:
            parser.add_argument(
                "--max-source-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the source sequence",
            )
            parser.add_argument(
                "--max-target-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the target sequence",
            )
        except ArgumentError:
            # These options may already have been registered by another
            # component sharing the parser; argparse raises ArgumentError
            # on duplicate flags, which is safe to ignore here.
            pass

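    # A hypothetical invocation (paths are illustrative, not prescribed by
    # this file); the positional `configfile` argument above takes the place
    # of fairseq's usual data directory:
    #
    #   fairseq-train /path/to/laser.json --task laser \
    #       --weighting-alpha 0.7 --max-tokens 2000
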
    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

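    # Sketch of the JSON config consumed above (all paths hypothetical).  The
    # keys mirror the lookups in setup_task() and load_dataset(); note that
    # the last two directory components of each "src" path are parsed as
    # corpus name and language pair:
    #
    #   {
    #     "src_vocab": "/data/dict.src.txt",
    #     "tgt_vocab": "/data/dict.tgt.txt",
    #     "train": [
    #       {"id": 0,
    #        "src": "/data/europarl/en-fr/train.en",
    #        "tgt": "/data/europarl/en-fr/train.fr",
    #        "sample": 2}
    #     ]
    #   }
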
    def build_model(self, args):
        return models.build_model(args, self)

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            # Only binarized data is supported; `dictionary` is accepted for
            # API symmetry but is not needed to read the indexed files.
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)
            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            # No validation data is configured for this task; register an
            # empty mapping so that dataset("valid") still succeeds.
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            # The last two directory components of the source path identify
            # the corpus and the language pair.
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "corpus_name": corpus_name,
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

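        # Automatic weighting is temperature-style resampling: with empirical
        # frequencies f_k = n_k / sum_j n_j over the datasets that carry no
        # explicit "sample" value, each dataset gets a normalized weight
        # f_k ** alpha, and is later replicated round(vmax / weight_k) times.
        # The most frequent dataset thus stays at 1x while smaller ones are
        # upsampled; alpha < 1 flattens the distribution.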
        if self.args.weighting_alpha:
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val

        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

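            # `sample` was rounded above, so after peeling off whole copies
            # (each suffixed "-up" to keep the keys unique) the remainder must
            # be ~0 here; fractional upsampling is not supported.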
            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            # Attribute the size to this entry's own corpus (the loop variable
            # `corpus_name` from the loading loop above would be stale here).
            size_by_corpus[dataset_config["corpus_name"]] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize each sub-dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

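        # Each sub-dataset keeps its own batch sampler; the wrapper below
        # combines them into a single epoch iterator.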
        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
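
# Minimal end-to-end sketch (simplified; assumes an argparse namespace `args`
# carrying the fields this task reads, and the standard fairseq epoch-iterator
# interface on the returned object):
#
#   task = LaserTask.setup_task(args)
#   task.load_dataset("train")
#   epoch_itr = task.get_batch_iterator(
#       task.dataset("train"), max_tokens=args.max_tokens
#   )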