Spaces:
Running
Running
| import pytorch_lightning as L | |
| from torch.utils.data import DataLoader, random_split | |
| import torch | |
| import time | |
| import webdataset as wds | |
| from torch.utils.data import default_collate | |
| import math | |
| from PIL import Image | |
| class ImageDataModule(L.LightningDataModule): | |
| def __init__( | |
| self, | |
| train_dataset, | |
| val_dataset, | |
| test_dataset, | |
| full_batch_size, | |
| num_workers, | |
| eval_batch_size=None, | |
| num_nodes=1, | |
| num_devices=1, | |
| val_proportion=0.1, | |
| ): | |
| super().__init__() | |
| self._builders = { | |
| "train": train_dataset, | |
| "val": val_dataset, | |
| "test": test_dataset, | |
| } | |
| self.num_workers = num_workers | |
| self.collate_fn = dict_collate_fn() | |
| self.full_batch_size = full_batch_size | |
| self.train_batch_size = full_batch_size // (num_nodes * num_devices) | |
| if eval_batch_size is None: | |
| self.eval_batch_size = self.train_batch_size | |
| self.full_eval_batch_size = self.full_batch_size | |
| else: | |
| self.eval_batch_size = eval_batch_size // (num_nodes * num_devices) | |
| self.full_eval_batch_size = eval_batch_size | |
| print(f"Each GPU will receive {self.train_batch_size} images for training") | |
| print(f"Each GPU will receive {self.eval_batch_size} images for evaluation") | |
| self.val_proportion = val_proportion | |
| self.world_size = num_nodes * num_devices | |
| def setup(self, stage=None): | |
| """Setup the datamodule. | |
| Args: | |
| stage (str): stage of the datamodule | |
| Is be one of "fit" or "test" or None | |
| """ | |
| print("Stage", stage) | |
| start_time = time.time() | |
| if stage == "fit" or stage is None: | |
| self.train_dataset = self._builders["train"]() | |
| self.train_dataset, self.num_train_batches = self.get_webdataset_length( | |
| self.train_dataset, | |
| dict_collate_fn(), | |
| self.full_batch_size, | |
| self.train_batch_size, | |
| ) | |
| self.val_dataset = self._builders["val"]() | |
| self.val_dataset, self.num_val_batches = self.get_webdataset_length( | |
| self.val_dataset, | |
| dict_collate_fn(), | |
| self.full_eval_batch_size, | |
| self.eval_batch_size, | |
| 0, | |
| ) | |
| print(f"Train dataset size: {len(self.train_dataset)}") | |
| print(f"Val dataset size: {len(self.val_dataset)}") | |
| else: | |
| self.test_dataset = self._builders["test"]() | |
| self.test_dataset, self.num_test_batches = self.get_webdataset_length( | |
| self.test_dataset, | |
| dict_collate_fn(), | |
| self.full_eval_batch_size, | |
| self.eval_batch_size, | |
| self.num_workers, | |
| ) | |
| print(f"Test dataset size: {len(self.test_dataset)}") | |
| end_time = time.time() | |
| print(f"Setup took {(end_time - start_time):.2f} seconds") | |
| def train_dataloader(self): | |
| return wds.WebLoader( | |
| self.train_dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=self.num_workers, | |
| # persistent_workers=self.num_workers > 1, | |
| ).with_length(self.num_train_batches) | |
| # return DataLoader( | |
| # self.train_dataset, | |
| # batch_size=self.batch_size, | |
| # shuffle=True, | |
| # pin_memory=False, | |
| # drop_last=True, | |
| # num_workers=self.num_workers, | |
| # collate_fn=self.train_dataset.collate_fn, | |
| # ) | |
| def val_dataloader(self): | |
| return wds.WebLoader( | |
| self.val_dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=0, | |
| ).with_length(self.num_val_batches) | |
| def test_dataloader(self): | |
| return wds.WebLoader( | |
| self.test_dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=0, | |
| ).with_length(self.num_test_batches) | |
| def get_webdataset_length( | |
| self, dataset, collate_fn, full_batch_size, batch_size, num_workers=0 | |
| ): | |
| dataset = dataset.compose( | |
| wds.batched( | |
| batch_size, | |
| partial=self.world_size > 1, | |
| collation_fn=collate_fn, | |
| # dict_collate_and_pad(["flan_t5_xl"], max_length=256), | |
| ) | |
| ) | |
| num_samples = dataset.num_samples | |
| if self.world_size > 1: | |
| num_batches = math.ceil(num_samples / full_batch_size) | |
| num_workers = max(1, num_workers) | |
| num_worker_batches = math.ceil(num_batches / num_workers) | |
| num_batches = num_worker_batches * num_workers | |
| num_samples = num_batches * full_batch_size | |
| dataset = dataset.with_epoch(num_worker_batches).with_length( | |
| num_worker_batches | |
| ) | |
| else: | |
| num_batches = math.ceil(num_samples / batch_size) | |
| dataset = dataset.with_epoch(num_batches).with_length(num_batches) | |
| return dataset, num_batches | |
| def dict_collate_fn(): | |
| def dict_collate(batch): | |
| output_dict = {} | |
| if isinstance(batch[0], dict): | |
| for key in batch[0].keys(): | |
| output_dict[key] = dict_collate([item[key] for item in batch]) | |
| else: | |
| # Check if the batch contains PIL images | |
| if isinstance(batch[0], Image.Image): | |
| output_dict = batch # Return list of PIL images directly | |
| else: | |
| output_dict = default_collate(batch) | |
| return output_dict | |
| return dict_collate | |