| | import warnings |
| | import os |
| | import math |
| | import numpy as np |
| | import torch.utils.data |
| | import torchvision.transforms as transforms |
| | import torchvision.datasets as datasets |
| | from .base_provider import DataProvider |
| | from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler |
| |
|
| | __all__ = ["Cifar10DataProvider"] |
| |
|
class Cifar10DataProvider(DataProvider):
    """Data provider for the CIFAR-10 dataset.

    Builds the train / validation / test ``DataLoader`` objects, optionally
    carving a validation split out of the training set (``valid_size``) and
    optionally wrapping datasets in distributed samplers
    (``num_replicas`` / ``rank``).  When ``image_size`` is a list, the
    elastic-resolution machinery (``MyDataLoader`` / ``MyRandomResizedCrop``)
    is used so the active image size can be switched at runtime.
    """

    # Fallback dataset location; see the ``save_path`` property for the
    # resolution order.
    DEFAULT_PATH = "./dataset/cifar10"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        valid_size=None,
        resize_scale=0.08,
        distort_color=None,
        n_worker=32,
        image_size=32,
        num_replicas=None,
        rank=None,
    ):
        """
        Args:
            save_path: dataset root directory; ``None`` falls back to
                ``DEFAULT_PATH`` (see ``save_path`` property).
            train_batch_size: batch size of the training loader.
            test_batch_size: batch size of the validation / test loaders.
            valid_size: if not ``None``, size of the validation split taken
                from the training set — an absolute count (int) or a
                fraction in (0, 1) (float).
            resize_scale: accepted for signature compatibility with other
                providers; unused for CIFAR-10 (fixed 32x32 random crop).
            distort_color: accepted for signature compatibility; unused here.
            n_worker: ``num_workers`` for every loader built by this class.
            image_size: a single int, or a list of ints to enable
                elastic-resolution training.
            num_replicas: number of distributed replicas, or ``None`` for
                single-process loading.
            rank: rank of this process when ``num_replicas`` is set.
        """
        warnings.filterwarnings("ignore")
        self._save_path = save_path

        self.image_size = image_size  # int, or list of ints (elastic mode)
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # Elastic-resolution mode: local import to avoid a hard
            # dependency when only a fixed image size is used.
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            # Copy before sorting so the caller's list is not mutated
            # in place (the original code sorted the caller's object).
            self.image_size = sorted(self.image_size)
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            # Pre-build one eval transform per candidate size so
            # assign_active_img_size() can switch cheaply.
            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(
                    img_size
                )
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # samples image sizes per batch
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            # Hold out part of the training set as a validation split.
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            # A separate dataset object so the validation split uses the
            # eval transforms rather than the training augmentation.
            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    train_indexes
                )
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    valid_indexes
                )

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        # No held-out validation split: evaluate on the test set.
        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        """Return the canonical dataset name used for registry lookups."""
        return "cifar10"

    @property
    def data_shape(self):
        """(channels, height, width) of a sample at the active image size."""
        return 3, self.active_img_size, self.active_img_size

    @property
    def n_classes(self):
        """Number of target classes in CIFAR-10."""
        return 10

    @property
    def save_path(self):
        """Dataset root: explicit path, else DEFAULT_PATH, else ~/dataset/cifar10."""
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/cifar10")
        return self._save_path

    @property
    def data_url(self):
        """Downloading is delegated to torchvision; no direct URL is exposed."""
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        """Build the CIFAR-10 training split with the given transforms."""
        return datasets.CIFAR10(
            self.train_path, train=True, transform=_transforms, download=True
        )

    def test_dataset(self, _transforms):
        """Build the CIFAR-10 test split with the given transforms."""
        return datasets.CIFAR10(
            self.valid_path, train=False, transform=_transforms, download=True
        )

    @property
    def train_path(self):
        """Directory torchvision uses for the training split."""
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        """Directory torchvision uses for the test split."""
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        """Per-channel CIFAR-10 normalization.

        NOTE(review): this property is defined but never applied by
        build_train_transform / build_valid_transform (both stop at
        ToTensor) — confirm whether normalization is intentionally
        handled elsewhere (e.g. inside the model).
        """
        return transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        )

    def build_train_transform(self, image_size=None, print_log=True):
        """Build the training augmentation pipeline.

        Args:
            image_size: accepted for interface compatibility; CIFAR-10
                always uses a fixed 32x32 random crop, so it is unused.
            print_log: accepted for interface compatibility; unused.
        """
        if image_size is None:
            image_size = self.image_size

        # Standard CIFAR augmentation: pad-and-crop plus horizontal flip.
        train_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        train_transforms += [
            transforms.ToTensor(),
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        """Build the evaluation pipeline (tensor conversion only)."""
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.ToTensor(),
        ])

    def assign_active_img_size(self, new_img_size):
        """Switch the active image size and update the eval loaders' transforms."""
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[
                self.active_img_size
            ] = self.build_valid_transform()
        # Mutate the live datasets so existing loaders pick up the change.
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        """Return (and cache) a fixed subset of training batches.

        Draws a deterministic sample of ``n_images`` training images
        (seeded with ``DataProvider.SUB_SEED``), materializes all batches
        once, and caches them per active image size so repeated calls
        (e.g. for BN recalibration) are cheap.
        """
        cache_key = "sub_train_%d" % self.active_img_size
        if self.__dict__.get(cache_key, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            # Deterministic permutation so the subset is stable across runs.
            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(
                    image_size=self.active_img_size, print_log=False
                )
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset,
                    num_replicas,
                    rank,
                    True,
                    np.array(chosen_indexes),
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    chosen_indexes
                )
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            # Materialize every batch once; later calls reuse the cache.
            self.__dict__[cache_key] = []
            for images, labels in sub_data_loader:
                self.__dict__[cache_key].append(
                    (images, labels)
                )
        return self.__dict__[cache_key]
| |
|