from typing import List, Optional, Tuple

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

from training.datasets import LibriTTSDatasetAcoustic
def train_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    root: str = "datasets_cache/LIBRITTS",
    cache: bool = True,
    cache_dir: str = "datasets_cache",
    mem_cache: bool = False,
    url: str = "train-clean-360",
    lang: str = "en",
    selected_speaker_ids: Optional[List[int]] = None,
) -> DataLoader:
    r"""Returns the training dataloader that uses the LibriTTS dataset.

    Args:
        batch_size (int): The batch size.
        num_workers (int): The number of dataloader workers.
        root (str): The root directory of the dataset.
        cache (bool): Whether to cache the preprocessed data.
        cache_dir (str): The directory for the cache.
        mem_cache (bool): Whether to use an in-memory cache.
        url (str): The dataset subset to use (e.g. "train-clean-360").
        lang (str): The language of the dataset.
        selected_speaker_ids (Optional[List[int]]): An optional list of speaker IDs to restrict the dataset to.

    Returns:
        DataLoader: The training dataloader.
    """
    dataset = LibriTTSDatasetAcoustic(
        root=root,
        lang=lang,
        cache=cache,
        cache_dir=cache_dir,
        mem_cache=mem_cache,
        url=url,
        selected_speaker_ids=selected_speaker_ids,
    )

    train_loader = DataLoader(
        dataset,
        # On 4x80 GB GPUs: batch_size=20 fits clips up to ~10 s of audio,
        # the default batch_size=6 fits clips up to ~20.4 s.
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    return train_loader
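
# Example usage (a minimal sketch; the import path "training.loaders" is an
# assumption for illustration, adjust it to wherever this module lives):
#
#     from training.loaders import train_dataloader
#
#     loader = train_dataloader(batch_size=4, num_workers=2)
#     batch = next(iter(loader))  # one batch, collated by dataset.collate_fn
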
def train_val_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    root: str = "datasets_cache/LIBRITTS",
    cache: bool = True,
    cache_dir: str = "datasets_cache",
    mem_cache: bool = False,
    url: str = "train-clean-360",
    lang: str = "en",
    validation_split: float = 0.02,  # Fraction of the data to use for validation
) -> Tuple[DataLoader, DataLoader]:
    r"""Returns the training and validation dataloaders for the LibriTTS dataset.

    Args:
        batch_size (int): The batch size.
        num_workers (int): The number of dataloader workers.
        root (str): The root directory of the dataset.
        cache (bool): Whether to cache the preprocessed data.
        cache_dir (str): The directory for the cache.
        mem_cache (bool): Whether to use an in-memory cache.
        url (str): The dataset subset to use (e.g. "train-clean-360").
        lang (str): The language of the dataset.
        validation_split (float): The fraction of the data to use for validation.

    Returns:
        Tuple[DataLoader, DataLoader]: The training and validation dataloaders.
    """
    dataset = LibriTTSDatasetAcoustic(
        root=root,
        lang=lang,
        cache=cache,
        cache_dir=cache_dir,
        mem_cache=mem_cache,
        url=url,
    )

    # Split the dataset indices into train and validation parts
    train_indices, val_indices = train_test_split(
        list(range(len(dataset))),
        test_size=validation_split,
        random_state=42,
    )

    # Restrict each loader to its own split via Subset views over the indices
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    # dataset = LibriTTSMMDatasetAcoustic("checkpoints/libri_preprocessed_data.pt")
    train_loader = DataLoader(
        train_dataset,
        # On 4x80 GB GPUs: batch_size=20 fits clips up to ~10 s of audio,
        # the default batch_size=6 fits clips up to ~20.4 s.
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    return train_loader, val_loader
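

if __name__ == "__main__":
    # Smoke-test sketch, not part of the training pipeline: the batch size and
    # worker count are illustrative, and with the default arguments this uses
    # the train-clean-360 subset under datasets_cache/LIBRITTS, so point root/url
    # at data you already have cached before running.
    train_loader, val_loader = train_val_dataloader(batch_size=2, num_workers=2)
    # With validation_split=0.02 the subsets hold roughly 98% / 2% of the samples.
    print(len(train_loader.dataset), len(val_loader.dataset))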