| |
| |
| |
| |
| |
| |
|
|
import logging
import math
from typing import Any, List, NamedTuple, Optional

import numpy as np
import torch

from fairseq.data import (
    ConcatDataset,
    FairseqDataset,
    FileAudioDataset,
    LanguagePairDataset,
    data_utils,
)
from fairseq.data.resampling_dataset import ResamplingDataset
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ModalityDatasetItem(NamedTuple):
    """One per-modality dataset together with its batching limits.

    Attributes:
        datasetname: identifier for the modality (e.g. speech vs. text).
        dataset: the underlying dataset for this modality.
        max_positions: per-field maximum example lengths used to filter
            over-long examples before batching.
        max_tokens: optional cap on total tokens per batch for this modality.
        max_sentences: optional cap on sentences per batch for this modality.
    """

    datasetname: str
    # ``Any`` (typing.Any), not the builtin ``any`` function the original
    # used — the dataset type varies per modality.
    dataset: Any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None
|
|
|
|
def resampling_dataset_present(ds):
    """Return True if ``ds`` is, or transitively wraps, a ResamplingDataset.

    Unwraps ConcatDataset members and any object exposing a ``.dataset``
    wrapper attribute.
    """
    pending = [ds]
    while pending:
        candidate = pending.pop()
        if isinstance(candidate, ResamplingDataset):
            return True
        if isinstance(candidate, ConcatDataset):
            pending.extend(candidate.datasets)
        elif hasattr(candidate, "dataset"):
            pending.append(candidate.dataset)
    return False
|
|
|
|
| |
| |
| |
| |
| |
| |
class MultiModalityDataset(ConcatDataset):
    """Concatenation of per-modality datasets with per-modality batching.

    Each constituent dataset keeps its own ``max_tokens`` / ``max_sentences``
    / ``max_positions`` limits.  Every batch is drawn from a single modality
    (enforced in ``collater``), and ``get_batch_samplers`` produces one batch
    sampler per modality, optionally up-/down-sampled by a ratio.
    """

    def __init__(self, datasets: List[ModalityDatasetItem]):
        # Unpack the ModalityDatasetItem records into parallel per-dataset lists.
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        # Equal concat weights: no implicit over-sampling at this level.
        weights = [1.0 for s in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        # Cache of per-dataset batch samplers, filled lazily by
        # get_raw_batch_samplers and reused across epochs when possible.
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        # Remembered so get_batch_samplers can vary its shuffle seed per epoch.
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        # Tag each sample with its source-dataset index so the collater can
        # dispatch the whole batch to the right per-modality collater.
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)

    def collater(self, samples):
        """Collate a batch of (dataset_idx, sample) pairs from ONE modality.

        Asserts that all samples come from the same constituent dataset,
        delegates to that dataset's collater, and records the modality name
        under ``net_input["mode"]``.
        """
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # All samples in the batch must share the same dataset index.
        assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])

        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]

        return samples

    def size(self, index: int):
        # Fast path: a single dataset needs no cumulative-index translation.
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.

        With multiple datasets, returns one index array per dataset (indices
        are local to each dataset) rather than a single flat array.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            # Recover this dataset's sample count from the cumulative sizes
            # and sanity-check it against len(ds).
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        """Build (or refresh) one batch sampler per dataset, caching results.

        A cached sampler is reused unless the dataset contains a
        ResamplingDataset, whose contents change per epoch and therefore
        must be re-batched.
        """
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            # Reuse the cached sampler when one exists and the dataset is
            # not re-sampled across epochs.
            if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
                ds
            ):
                logger.info(f"dataset {i} is valid and it is not re-sampled")
                continue
            # Drop over-long examples, then group into size-limited batches.
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i < len(self.raw_sub_batch_samplers):
                self.raw_sub_batch_samplers[i] = sub_batch_sampler
            else:
                self.raw_sub_batch_samplers.append(sub_batch_sampler)

    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        """Return per-dataset batch samplers scaled by ``mult_ratios``.

        Raw per-dataset batch indices are shifted into the concatenated
        index space, then each dataset's batch list is repeated
        ``floor(ratio)`` times plus a shuffled fraction for the remainder.
        """
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                # Shift local sample indices by the cumulative offset of all
                # preceding datasets in the concatenation.
                sub_batch_sampler = [
                    [y + self.cumulative_sizes[i - 1] for y in x]
                    for x in self.raw_sub_batch_samplers[i]
                ]
            else:
                sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
            smp_r = mult_ratios[i]
            if smp_r != 1:
                is_increase = "increased" if smp_r > 1 else "decreased"
                logger.info(
                    "number of batch for the dataset {} is {} from {} to {}".format(
                        self.id_to_mode[i],
                        is_increase,
                        len(sub_batch_sampler),
                        int(len(sub_batch_sampler) * smp_r),
                    )
                )
                mul_samplers = []
                # Whole repetitions of the batch list.
                for _ in range(math.floor(smp_r)):
                    mul_samplers = mul_samplers + sub_batch_sampler
                # Fractional remainder: shuffle (epoch-dependent seed) and
                # take the leading slice.
                if math.floor(smp_r) != smp_r:
                    with data_utils.numpy_seed(seed + self._cur_epoch):
                        np.random.shuffle(sub_batch_sampler)
                        smp_num = int(
                            (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
                        )
                    mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
                sub_batch_sampler = mul_samplers
            else:
                logger.info(
                    "dataset {} batch number is {} ".format(
                        self.id_to_mode[i], len(sub_batch_sampler)
                    )
                )
            batch_samplers.append(sub_batch_sampler)

        return batch_samplers
|
|
|
|
class LangPairMaskDataset(FairseqDataset):
    """Wrap a LanguagePairDataset, replacing source tokens with a noise id.

    A fraction ``mask_ratio`` of each source sentence is replaced by
    ``noise_id``, chosen either uniformly at random ("random") or from the
    end of the sentence ("tail").  BOS/EOS delimiters are never masked.
    All size/ordering/collation queries delegate to the wrapped dataset.
    """

    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        """Source lengths of the wrapped dataset."""
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        """Target lengths of the wrapped dataset."""
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # Masking replaces tokens one-for-one, so lengths are unchanged.
        return self.dataset.sizes

    def get_batch_shapes(self):
        """Delegate bucketed batch shapes; fall back to raw buckets."""
        inner = self.dataset
        if hasattr(inner, "get_batch_shapes"):
            return inner.get_batch_shapes()
        return inner.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        """Return a copy of ``sample`` whose source has noised positions."""
        tokens = sample["source"]
        n = len(tokens)
        if self.mask_type == "random":
            # Each position masked independently with probability mask_ratio.
            to_mask = torch.rand(n).le(self.mask_ratio)
        else:
            # "tail": mask only the trailing mask_ratio fraction of positions.
            tail = torch.ones(n)
            tail[: int(n * (1 - self.mask_ratio))] = 0
            to_mask = tail.eq(1)
        # Sentence delimiters are always kept intact.
        if tokens[0] == self.src_bos:
            to_mask[0] = False
        if tokens[-1] == self.src_eos:
            to_mask[-1] = False
        return {
            "id": sample["id"],
            "source": tokens.masked_fill(to_mask, self.noise_id),
            "target": sample["target"],
        }

    def __getitem__(self, index):
        item = self.dataset[index]
        return self.mask_src_tokens(item) if self.mask_ratio > 0 else item

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)
|
|
|
|
class FileAudioDatasetWrapper(FileAudioDataset):
    """Adapt FileAudioDataset batches to the seq2seq ``net_input`` layout."""

    def collater(self, samples):
        """Collate, then rename ``source`` to ``src_tokens`` and null out
        keys this pipeline does not use."""
        batch = super().collater(samples)
        if len(batch) == 0:
            return {}
        net_input = batch["net_input"]
        net_input["src_tokens"] = net_input["source"]
        net_input["prev_output_tokens"] = None
        del net_input["source"]
        net_input["src_lengths"] = None
        net_input["alignment"] = None
        return batch
|
|