| |
| |
| |
|
|
| import copy |
| import warnings |
|
|
| import torch.utils.data |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| from src.efficientvit.apps.data_provider.random_resolution import RRSController |
| from src.efficientvit.models.utils import val2tuple |
|
|
| __all__ = ["parse_image_size", "random_drop_data", "DataProvider"] |
|
|
|
|
def parse_image_size(size: int or str) -> tuple[int, int]:
    """Convert an image-size specification into a two-element tuple.

    Args:
        size: Either a string of integers joined by ``"-"`` (for example
            ``"224-256"``), or any non-string value, which is forwarded to
            ``val2tuple(size, 2)``. A string holding a single integer
            (``"224"``) is treated as a square size. If a string has more
            than two components, only the first two are used.

    Returns:
        A pair of ints built from the first two components of the string, or
        whatever ``val2tuple(size, 2)`` returns for non-string input
        (presumably the value expanded to a pair; confirm against
        ``val2tuple``).

    Raises:
        ValueError: If a component of a string ``size`` is not an integer
            literal.
    """
    if isinstance(size, str):
        dims = [int(val) for val in size.split("-")]
        if len(dims) == 1:
            # "224" used to fail with IndexError; interpret it as 224 x 224
            # so a string size behaves like the equivalent int size.
            return dims[0], dims[0]
        return dims[0], dims[1]
    return val2tuple(size, 2)
|
|
|
|
def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    """Split ``dataset`` into a kept part and a dropped part, reproducibly.

    A permutation of ``range(len(dataset))`` is drawn from a torch generator
    seeded with ``seed``. The first ``drop_size`` indices of that permutation
    select the dropped records and the rest select the kept records.

    Args:
        dataset: Object supporting ``len()`` whose attributes named in
            ``keys`` are indexable sequences. It is modified in place.
        drop_size: Number of records moved into the dropped dataset.
        seed: Seed for the permutation.
        keys: Names of the attributes to re-index on both datasets.

    Returns:
        ``(dataset, dropped_dataset)``: the input object, now holding only
        the kept records, and a deep copy of it holding only the dropped
        records.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    shuffled = torch.randperm(len(dataset), generator=rng).tolist()
    dropped_ids, kept_ids = shuffled[:drop_size], shuffled[drop_size:]

    # Copy first so the dropped dataset is indexed from the full, unmodified
    # records before the input dataset is trimmed.
    dropped_dataset = copy.deepcopy(dataset)
    for attr in keys:
        for target, picks in ((dropped_dataset, dropped_ids), (dataset, kept_ids)):
            setattr(target, attr, [getattr(target, attr)[i] for i in picks])
    return dataset, dropped_dataset
|
|
|
|
class DataProvider:
    """Base class that builds train / valid / test data loaders.

    Subclasses supply the datasets and transforms by implementing
    ``build_datasets``, ``build_train_transform`` and
    ``build_valid_transform``. This class parses the image-size
    configuration, optionally sub-samples the training set, wraps the
    datasets in (optionally distributed) data loaders, and keeps the
    class-level ``RRSController`` state in sync with the active image size.
    """

    # Names of the dataset attributes that hold the per-sample records; these
    # are the attributes re-indexed when a dataset is split.
    data_keys = ("samples",)
    # Per-channel normalisation statistics (the values match the commonly
    # used ImageNet RGB mean / std).
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    # Fixed seeds so that training sub-sampling (SUB_SEED) and the validation
    # split (VALID_SEED) are reproducible across runs.
    SUB_SEED = 937162211
    VALID_SEED = 2147483647

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int or None,
        valid_size: int or float or None,
        n_worker: int,
        image_size: int or list[int] or str or list[str],
        num_replicas: int or None = None,
        rank: int or None = None,
        train_ratio: float or None = None,
        drop_last: bool = False,
    ):
        """Build the datasets and their data loaders.

        Note that this silences all Python warnings process-wide and writes
        ``IMAGE_SIZE_LIST`` / ``ACTIVE_SIZE`` on the ``RRSController`` class.

        Args:
            train_batch_size: Batch size of the training loader.
            test_batch_size: Batch size of the validation and test loaders.
                A falsy value (``None`` or ``0``) falls back to
                ``train_batch_size``.
            valid_size: Stored as ``self.valid_size`` and read by
                ``sample_val_dataset``: a value in ``(0, 1)`` is a fraction
                of the training set, a value ``>= 1`` is a sample count and
                ``None`` disables the split.
            n_worker: Number of worker processes for every loader.
            image_size: One size (an int, or a string such as ``"224-256"``)
                or a list of sizes. With a list, the sizes are sorted, the
                largest becomes the active size, and the training loader is
                built with ``RRSDataLoader``.
            num_replicas: When not ``None``, loaders use a
                ``DistributedSampler`` with this many replicas.
            rank: Rank passed to the ``DistributedSampler``.
            train_ratio: When given and below ``1``, only
                ``int(train_ratio * len(train_dataset))`` randomly chosen
                training samples are kept.
            drop_last: Whether the training loader drops the last incomplete
                batch. Validation and test loaders never do.
        """
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch sizes and validation split size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configuration
        self.num_replicas = num_replicas
        self.rank = rank

        # datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()

        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            # random_drop_data returns (kept, dropped); the "dropped" part has
            # exactly the requested number of samples, so it becomes the new
            # training set.
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # data loaders
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        # Use the resolved self.test_batch_size, not the raw argument: passing
        # a raw None to DataLoader would disable automatic batching instead of
        # falling back to the training batch size.
        self.valid = self.build_dataloader(
            val_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            # No dedicated validation set: validate on the test loader.
            self.valid = self.test
        # Cache of pre-loaded sub-train batches, filled lazily by
        # build_sub_train_loader and keyed by active image size.
        self.sub_train = None

    @property
    def data_shape(self) -> tuple[int, ...]:
        """Return ``(3, size[0], size[1])`` for the active image size."""
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
        """Return the evaluation transform; must be implemented by subclasses."""
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
        """Return the training transform; must be implemented by subclasses."""
        raise NotImplementedError

    def build_datasets(self) -> tuple[any, any, any]:
        """Return ``(train, valid, test)`` datasets; must be implemented by subclasses.

        The validation and test entries may be ``None``; ``build_dataloader``
        turns a ``None`` dataset into a ``None`` loader.
        """
        raise NotImplementedError

    def build_dataloader(
        self,
        dataset: any or None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        """Wrap ``dataset`` in a data loader.

        Args:
            dataset: Dataset to wrap, or ``None``.
            batch_size: Batch size of the loader.
            n_worker: Number of worker processes.
            drop_last: Whether to drop the last incomplete batch.
            train: Whether this is the training loader. Together with a list
                ``self.image_size`` this selects ``RRSDataLoader``.

        Returns:
            ``None`` when ``dataset`` is ``None``; otherwise a loader that
            shuffles when ``self.num_replicas`` is ``None`` and uses a
            ``DistributedSampler`` when it is set.
        """
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            # Use the same ``src.`` package root as the module-level imports.
            # Importing through a different root fails unless that root is
            # also importable, and if it is, Python loads a second copy of
            # the package under a different module name.
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import (
                RRSDataLoader,
            )

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
        """Propagate the epoch to ``RRSController`` and the distributed sampler."""
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)

    def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
        """Switch the active image size and rebuild the evaluation transform.

        Args:
            new_size: New size, normalised with ``val2tuple(new_size, 2)``.
        """
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # Update the transform of the validation and test sets. Either loader
        # can be None (build_dataloader returns None for a missing dataset),
        # so skip the ones that do not exist instead of raising.
        for loader in (self.valid, self.test):
            if loader is not None:
                loader.dataset.transform = new_transform

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
        """Carve a validation set out of ``train_dataset``.

        Args:
            train_dataset: Training dataset; modified in place when a split
                is made.
            valid_transform: Transform assigned to the validation dataset.

        Returns:
            ``(train_dataset, val_dataset)``. ``val_dataset`` is ``None``
            when ``self.valid_size`` is ``None``.
        """
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                # Fractional size: proportion of the training set.
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                # Otherwise an absolute number of samples.
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
        """Return pre-loaded training batches for the active image size.

        The batches are loaded once per active image size and cached in
        ``self.sub_train``.

        NOTE(review): the cache key is the active image size only, so a later
        call with a different ``n_samples`` or ``batch_size`` returns the
        batches cached by the first call for that size.

        Args:
            n_samples: Number of training samples to draw. When the training
                set is not larger than this, the whole set is used and each
                batch is repeated ``n_samples // len(dataset)`` times.
            batch_size: Batch size of the temporary loader (incomplete last
                batches are dropped).

        Returns:
            A list of batches as yielded by the data loader.
        """
        # used for resetting BN running statistics
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct the sub-sampled dataset and its loader
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            # Keep the "dropped" part, which has exactly n_samples records.
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, True, False
        )

        # pre-fetch the data, repeating each batch when the dataset is smaller
        # than the requested number of samples
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]

        return self.sub_train[self.active_image_size]
|
|