| from typing import Optional |
| from pytorch_lightning import LightningDataModule |
| from torch.utils.data import DataLoader |
| from fengshen.data.mmap_index_dataset import MMapIndexDataset |
|
|
|
|
class MMapDataModule(LightningDataModule):
    """Lightning data module exposing train/valid/test ``MMapIndexDataset`` splits.

    All three datasets are constructed in ``__init__`` from the paths carried
    by ``args``. Each ``*_dataloader`` method wraps the matching dataset in a
    ``DataLoader`` configured from ``self.hparams``.
    """

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register this module's command-line options on ``parent_args``.

        Args:
            parent_args: Parser-like object exposing ``add_argument_group``.

        Returns:
            The same ``parent_args`` object that was passed in.
        """
        parser = parent_args.add_argument_group('MMAP DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--eval_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--train_datas', default=['./train_datas'],
                            type=str, nargs='+')
        parser.add_argument('--valid_datas', default=['./valid_datas'],
                            type=str, nargs='+')
        parser.add_argument('--test_datas', default=['./test_datas'],
                            type=str, nargs='+')
        parser.add_argument('--input_tensor_name', default=['input_ids'],
                            type=str, nargs='+')
        return parent_args

    def __init__(
        self,
        collate_fn,
        args,
        **kwargs,
    ):
        """Build the three dataset splits and record ``args`` as hyperparameters.

        Args:
            collate_fn: Callable handed to every ``DataLoader`` as its
                ``collate_fn``.
            args: Object providing ``train_datas``, ``valid_datas``,
                ``test_datas`` and ``input_tensor_name`` here, plus the
                ``*_batchsize`` / ``num_workers`` values read back later
                through ``self.hparams``.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        self.collate_fn = collate_fn
        # NOTE(review): the datasets are opened eagerly at construction time,
        # not in setup(); presumably cheap because they are memory-mapped --
        # confirm against MMapIndexDataset.
        self.train_dataset = MMapIndexDataset(args.train_datas, args.input_tensor_name)
        self.valid_dataset = MMapIndexDataset(args.valid_datas, args.input_tensor_name)
        self.test_dataset = MMapIndexDataset(args.test_datas, args.input_tensor_name)
        self.save_hyperparameters(args)

    def setup(self, stage: Optional[str] = None) -> None:
        """Delegate to the parent ``setup``; no extra work is done here."""
        return super().setup(stage)

    def _make_dataloader(self, dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the shared worker/collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training dataset."""
        return self._make_dataloader(
            self.train_dataset, self.hparams.train_batchsize, shuffle=True)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation dataset, in dataset order."""
        # Evaluation is not shuffled so batches are reproducible between runs.
        return self._make_dataloader(
            self.valid_dataset, self.hparams.eval_batchsize, shuffle=False)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test dataset, in dataset order."""
        # Evaluation is not shuffled so batches are reproducible between runs.
        return self._make_dataloader(
            self.test_dataset, self.hparams.test_batchsize, shuffle=False)
|
|