| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | import json |
| | import math |
| | from functools import partial |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.utils.data import IterableDataset |
| | from cosyvoice.utils.file_utils import read_lists, read_json_lists |
| |
|
| |
|
class Processor(IterableDataset):
    """Lazily apply a generator-style function on top of another dataset.

    Each time the processor is iterated, ``f`` is invoked as
    ``f(iterator_over_source, *args, **kw)`` and whatever it returns is used
    as the iterator, so several processors can be stacked into a pipeline.
    """

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source, self.f = source, f
        self.args, self.kw = args, kw

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped source."""
        self.source.set_epoch(epoch)

    def __iter__(self):
        """Return ``f`` applied to a fresh iterator over the source."""
        assert self.source is not None
        assert callable(self.f)
        upstream = iter(self.source)
        return self.f(upstream, *self.args, **self.kw)

    def apply(self, f):
        """Stack another stage ``f`` on top of this processor.

        The new stage is called with this processor's extra positional and
        keyword arguments.
        """
        assert callable(f)
        extra_args, extra_kw = self.args, self.kw
        return Processor(self, f, *extra_args, **extra_kw)
| |
|
| |
|
class DistributedSampler:
    """Select which data-list indices belong to the current rank and worker.

    Rank/world size come from ``torch.distributed`` (defaulting to 0/1 when
    it is not initialized); worker id/count come from
    ``torch.utils.data.get_worker_info()`` (defaulting to 0/1 outside a
    DataLoader worker).
    """

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """Refresh rank, world size, worker id and worker count.

        Returns:
            dict: keys ``rank``, ``world_size``, ``worker_id`` and
            ``num_workers`` with the refreshed values.
        """
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        """Store the epoch; it seeds the shuffle in ``sample``."""
        self.epoch = epoch

    @staticmethod
    def _pad_to(indexes, size):
        """Repeat a non-empty ``indexes`` list and truncate it to ``size``."""
        return (indexes * math.ceil(size / len(indexes)))[:size]

    def sample(self, data):
        """Sample data according to rank/world_size/num_workers.

        Only the length of ``data`` is used. When ``partition`` is set, the
        indices are (optionally) shuffled with ``random.Random(self.epoch)``
        and split across ranks; shuffling is skipped when ``partition`` is
        off. The result is always split across DataLoader workers. If there
        are fewer indices than ranks/workers, the list is repeated so every
        rank/worker receives at least one index.

        Args:
            data(List): input data list

        Returns:
            List[int]: indices into ``data`` assigned to this rank/worker;
            empty when ``data`` is empty.
        """
        indexes = list(range(len(data)))
        if not indexes:
            # Nothing to distribute; padding below would divide by zero.
            return indexes

        if self.partition:
            if self.shuffle:
                # Same seed on every rank so all ranks agree on the permutation.
                random.Random(self.epoch).shuffle(indexes)
            if len(indexes) < self.world_size:
                indexes = self._pad_to(indexes, self.world_size)
            indexes = indexes[self.rank::self.world_size]

        if len(indexes) < self.num_workers:
            indexes = self._pad_to(indexes, self.num_workers)
        return indexes[self.worker_id::self.num_workers]
| |
|
| |
|
class DataList(IterableDataset):
    """Iterate over the entries of a data list chosen by a sampler.

    Every yielded item is a dict ``{'src': entry}`` merged with the info dict
    returned by the sampler's ``update()`` call at the start of iteration.
    """

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        """Forward the epoch number to the sampler."""
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        info = self.sampler.update()
        for idx in self.sampler.sample(self.lists):
            yield {'src': self.lists[idx], **info}
| |
|
| |
|
def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """Construct a dataset: a DataList wrapped by one Processor per stage.

    Entries of the data list are shuffled (by epoch seed, when ``shuffle``
    and ``partition`` are on) and split across ranks/workers by DataList's
    sampler; any sample-level shuffling is up to the stages in
    ``data_pipeline``.

    Args:
        data_list_file: path passed to ``read_lists`` to obtain the entries.
        data_pipeline: sequence of callables; each is wrapped in a
            ``Processor`` in order and called with ``mode=mode``. The
            caller's sequence is not modified.
        mode: ``'train'`` or ``'inference'`` (anything else fails the assert).
        gan: when exactly ``True``, the last stage also gets ``gan=True``.
        shuffle: forwarded to ``DataList``.
        partition(bool): whether to do data partition in terms of rank.
        tts_file: inference only; JSON file whose top-level keys are looked
            up in the ``prompt_utt2data`` mapping. The parsed object is
            passed to the first stage as ``tts_data``.
        prompt_utt2data: inference only; path passed to ``read_json_lists``.

    Returns:
        The outermost ``Processor`` of the chain.

    Note:
        A ``tts_file`` key missing from the prompt mapping makes the lookup
        fail (a KeyError if the mapping is a dict).
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(tts_file, encoding='utf-8') as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        wanted = {utt2lists[utt] for utt in tts_data.keys()}
        # Keep the data-list order (deduplicated) rather than set iteration
        # order: str hashing is randomized per process, so separately
        # launched ranks would otherwise build differently ordered lists and
        # the rank-wise partition would overlap / drop entries.
        lists = list(dict.fromkeys(x for x in lists if x in wanted))
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    # Work on a copy so repeated calls with the same pipeline list don't
    # keep re-wrapping the caller's stages.
    data_pipeline = list(data_pipeline)
    if mode == 'inference':
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
| |
|