| | |
| | from pytorch_lightning import LightningDataModule |
| | from AR.data.bucket_sampler import DistributedBucketSampler |
| | from AR.data.dataset import Text2SemanticDataset |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
class Text2SemanticDataModule(LightningDataModule):
    """Lightning data module that serves ``Text2SemanticDataset`` loaders.

    The training dataset is built in :meth:`setup` from the training phoneme
    and semantic paths. The validation and test loaders currently iterate over
    that same dataset object; the ``dev_*`` paths are stored on the instance
    but are not read anywhere in this class.

    Config keys read by this class:
        ``config["data"]["num_workers"]``, ``config["data"]["max_sec"]``,
        ``config["data"]["pad_val"]`` and ``config["train"]["batch_size"]``.
    """

    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        """Store the configuration and dataset paths.

        Args:
            config: Nested mapping providing the config keys listed in the
                class docstring.
            train_semantic_path: Forwarded to ``Text2SemanticDataset`` as
                ``semantic_path``.
            train_phoneme_path: Forwarded to ``Text2SemanticDataset`` as
                ``phoneme_path``.
            dev_semantic_path: Stored only; not used by this class.
            dev_phoneme_path: Stored only; not used by this class.
        """
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    @staticmethod
    def _worker_options(num_workers):
        """Return the worker-related keyword arguments for ``DataLoader``.

        ``DataLoader`` raises ``ValueError`` when ``persistent_workers=True``
        or an explicit ``prefetch_factor`` is combined with ``num_workers=0``,
        so both options are only included when worker processes are used.
        """
        options = {"num_workers": num_workers}
        if num_workers > 0:
            options["persistent_workers"] = True
            options["prefetch_factor"] = 16
        return options

    def prepare_data(self):
        """Do nothing; this module has no download or preprocessing step."""
        pass

    def setup(self, stage=None, output_logs=False):
        """Build the training dataset and alias it as the dev dataset.

        Args:
            stage: Accepted for interface compatibility; not used here.
            output_logs: Accepted for interface compatibility; not used here.
        """
        data_cfg = self.config["data"]
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=data_cfg["max_sec"],
            pad_val=data_cfg["pad_val"],
        )
        # NOTE(review): validation/test reuse the training dataset, so the
        # dev_* paths given to __init__ have no effect — confirm this is
        # intended.
        self._dev_dataset = self._train_dataset

    def train_dataloader(self):
        """Return the training loader driven by ``DistributedBucketSampler``.

        Persistent workers and prefetching are enabled only when
        ``num_workers`` is positive, so a single-process configuration
        (``num_workers == 0``) no longer fails inside ``DataLoader``.
        """
        batch_size = self.config["train"]["batch_size"]
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            **self._worker_options(self.num_workers),
        )

    def val_dataloader(self):
        """Return the validation loader (batch size 1, no shuffling)."""
        # NOTE(review): max() forces at least 12 workers regardless of the
        # configured value; kept as in the original — confirm min() was not
        # intended.
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            **self._worker_options(max(self.num_workers, 12)),
        )

    def test_dataloader(self):
        """Return the test loader (batch size 1, no shuffling, default workers)."""
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
| |
|