Spaces:
Build error
Build error
| import torchvision | |
| import lightning as L | |
| from torch.utils.data import DataLoader | |
| from utils.transforms import train_transform, test_transform | |
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked with a keyword argument.

    The transform, when given, is called as ``transform(image=...)`` and the
    ``"image"`` entry of its result is used as the sample. This is the calling
    convention of dict-returning augmentation pipelines; presumably
    Albumentations -- confirm against ``utils.transforms``.

    Args:
        root: Directory where the dataset is stored or downloaded to.
        train: Select the training split when true, the test split otherwise.
        download: Forwarded to the parent constructor's ``download`` flag.
        transform: Optional callable accepting ``image=`` and returning a
            mapping that contains an ``"image"`` key.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is the raw entry of ``self.data`` when no transform is set,
        otherwise the ``"image"`` value produced by the transform.
        """
        sample = self.data[index]
        target = self.targets[index]

        # Without a transform the stored sample is handed back untouched.
        if self.transform is None:
            return sample, target

        return self.transform(image=sample)["image"], target
class CIFARDataModule(L.LightningDataModule):
    """Lightning data module serving CIFAR-10 via ``Cifar10SearchDataset``.

    The training loader uses the ``train=True`` split with ``train_transform``.
    The validation and test loaders both use the ``train=False`` split with
    ``test_transform`` (the module does not carve out a separate validation
    set).

    Args:
        data_dir: Root directory for the dataset files.
        batch_size: Batch size used by every dataloader.
        shuffle: Whether the *training* dataloader shuffles its samples.
            Validation and test dataloaders never shuffle.
        num_workers: Number of worker processes for every dataloader.
    """

    def __init__(
        self, data_dir="data", batch_size=512, shuffle=True, num_workers=4
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        """Download both CIFAR-10 splits into ``data_dir`` if needed.

        Downloading here, rather than as a side effect of ``setup``, follows
        the LightningDataModule guidance: ``prepare_data`` is the hook meant
        for download work, so concurrent processes running ``setup`` find the
        files already on disk instead of racing to fetch them.
        """
        for is_train in (True, False):
            torchvision.datasets.CIFAR10(
                root=self.data_dir, train=is_train, download=True
            )

    def setup(self, stage=None):
        """Instantiate the train, validation and test datasets.

        ``stage`` is accepted for hook compatibility but not used: all three
        datasets are always created. The dataset's default ``download=True``
        is kept so that calling ``setup`` directly (without ``prepare_data``)
        continues to work.
        """
        self.train_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=True, transform=train_transform
        )
        self.val_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )
        self.test_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

    def _make_loader(self, dataset, shuffle):
        """Build a ``DataLoader`` with this module's batch size and workers."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the training dataloader (shuffled according to ``shuffle``)."""
        return self._make_loader(self.train_dataset, self.shuffle)

    def val_dataloader(self):
        """Return the validation dataloader; evaluation order is kept fixed."""
        # Evaluation data is never shuffled so runs are reproducible and
        # Lightning does not warn about a shuffled validation loader.
        return self._make_loader(self.val_dataset, False)

    def test_dataloader(self):
        """Return the test dataloader; evaluation order is kept fixed."""
        return self._make_loader(self.test_dataset, False)