# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import bisect
import logging

import numpy as np
from torch.utils.data.dataloader import default_collate

from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)

class MultitaskDataset(FairseqDataset):
    """Concatenate several FairseqDatasets and address them as a single dataset."""

    @staticmethod
    def cumsum(sequence):
        # Cumulative sub-dataset lengths, used to map a global index to (dataset_idx, sample_idx).
        r, s = [], 0
        for e in sequence:
            curr_len = len(e)
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
        super(MultitaskDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        else:
            logger.info('set sample ratio to ' + str(sample_ratios))
        if batch_ratio is not None:
            logger.info('batch ratio is ' + str(batch_ratio))
        self.batch_ratio = batch_ratio
        self.sample_ratios = sample_ratios
        self._ordered_indices = None
        self._update_size()

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        if isinstance(sample, dict):
            sample["dataset_idx"] = dataset_idx
        else:
            sample = sample + (dataset_idx,)
        return sample

    def _update_size(self):
        self.cumulative_sizes = self.cumsum(self.datasets)
        self.real_sizes = [len(d) for d in self.datasets]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if samples is not None and len(samples) > 0:
            if isinstance(samples[0], dict):
                dataset_idx = samples[0]["dataset_idx"]
            else:
                dataset_idx = samples[0][-1]
                samples = [sample[:-1] for sample in samples]
        else:
            dataset_idx = 0

        if hasattr(self.datasets[dataset_idx], "collater"):
            return self.datasets[dataset_idx].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds in self.datasets:
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(ds.sizes)
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(ds.sizes[0])
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = [
                dataset.ordered_indices()
                for dataset in self.datasets
            ]
        return np.arange(len(self))

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        if not hasattr(self, "max_tokens"):
            self.max_tokens = max_tokens
        if not hasattr(self, "max_sentences"):
            self.max_sentences = max_sentences
        if not hasattr(self, "required_batch_size_multiple"):
            self.required_batch_size_multiple = required_batch_size_multiple
        batch_samplers = []
        for i, dataset in enumerate(self.datasets):
            batch_sampler = dataset.batch_by_size(
                self._ordered_indices[i],
                max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            # Shift local sample indices into the global index space of the
            # concatenated dataset.
            if i > 0:
                for batch in batch_sampler:
                    batch += self.cumulative_sizes[i - 1]
            # Optionally sub-sample the batches of this dataset.
            if self.sample_ratios[i] != 1.0:
                batch_sampler = np.array(batch_sampler)
                batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
                batch_sampler = list(batch_sampler)
                logger.info(
                    'Adjust batch by ratio ' + str(self.sample_ratios[i])
                    + ' and the number of batches is ' + str(int(len(batch_sampler)))
                    + ' for dataset ' + str(i)
                )
            batch_samplers.extend(batch_sampler)
        return batch_samplers

    def filter_indices_by_size(self, indices, max_positions):
        """
        Filter each sub-dataset independently, then update the round robin to work
        on the filtered sub-datasets.
        """
        if not hasattr(self, "max_positions"):
            self.max_positions = max_positions
        ignored_some = False
        for i in range(len(self.datasets)):
            self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
                self._ordered_indices[i], self.max_positions[i]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from dataset {i} have invalid sizes and will be skipped, "
                    f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
                )
        logger.info('update dataset size')
        self._update_size()
        # Since we are modifying _ordered_indices in place, it is no longer possible
        # to return valid ignored indices. Hopefully the warning above is enough to
        # debug. Ideally we would receive ignore_invalid_inputs so that we could
        # raise a proper error message.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)

    def shuffle_batches(self, batches, seed):
        logger.info("shuffle batches")
        new_batches_fromlist = []
        new_batches_notlist = []
        new_batches = []
        with data_utils.numpy_seed(seed):
            np.random.shuffle(batches)
            # Batches that arrive as lists are pre-grouped chunks from the speech
            # side; everything else is an individual batch from the text side.
            for batch in batches:
                if isinstance(batch, list):
                    new_batches_fromlist.append(batch)
                else:
                    new_batches_notlist.append(batch)
            logger.info("Got " + str(len(new_batches_fromlist)) + " chunks from the speech side")
            logger.info("Got " + str(sum(len(batch_list) for batch_list in new_batches_fromlist)) + " batches from the speech side")
            logger.info("Got " + str(len(new_batches_notlist)) + " batches from the text side")
            if len(new_batches_fromlist) == 0:
                return new_batches_notlist
            st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
            logger.info("Got st_ratio " + str(st_ratio))
            # Spread the text batches evenly across the speech chunks, shuffling
            # within each chunk so the two tasks interleave.
            last_idx = 0
            for i in range(len(new_batches_fromlist)):
                if i == len(new_batches_fromlist) - 1:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
                else:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
                np.random.shuffle(new_batches_fromlist[i])
                new_batches.extend(new_batches_fromlist[i])
                last_idx = last_idx + st_ratio
        logger.info("Finish shuffle")
        return new_batches

    def reset_batch_sampler(self):
        logger.info("reset batch sampler")
        self._ordered_indices = [
            self.datasets[i].ordered_indices()
            for i in range(len(self.datasets))
        ]
        self.filter_indices_by_size(None, None)

        batch_samplers = self.batch_by_size(
            None,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple,
        )
        return batch_samplers
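
# ------------------------------------------------------------------
# Usage sketch (not part of the original module): one plausible way to wrap two
# FairseqDataset instances with this class. The `speech_dataset` / `text_dataset`
# names and the ratio values are illustrative placeholders, not values taken
# from the ArTST recipes.
#
#   multitask = MultitaskDataset(
#       [speech_dataset, text_dataset],  # any FairseqDataset instances
#       sample_ratios=[1.0, 0.5],        # keep all speech batches, sample half of the text batches
#       batch_ratio=[1, 2],              # allow 2x max_tokens when batching the text dataset
#   )
#   indices = multitask.ordered_indices()   # also caches per-dataset orderings
#   batches = multitask.batch_by_size(indices, max_tokens=1_000_000)
# ------------------------------------------------------------------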