# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy
import warnings
from typing import Any

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from src.efficientvit.apps.data_provider.random_resolution import RRSController
from src.efficientvit.models.utils import val2tuple

__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]


def parse_image_size(size: int | str) -> tuple[int, int]:
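    """Parse an image-size spec into a 2-tuple.

    Strings are expected in the form "<a>-<b>" and are split on "-";
    plain ints are broadcast to a square size via val2tuple.
    """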
    if isinstance(size, str):
        size = [int(val) for val in size.split("-")]
        return size[0], size[1]
    else:
        return val2tuple(size, 2)


def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
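    """Randomly split `dataset` into a remaining part and a dropped part of `drop_size` samples.

    The split is reproducible through `seed`; every attribute named in `keys`
    (e.g. "samples") is re-indexed on both copies. Returns
    (remaining_dataset, dropped_dataset).
    """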
    g = torch.Generator()
    g.manual_seed(seed)  # set random seed before sampling validation set
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()

    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]

    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        setattr(
            dropped_dataset,
            key,
            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
        )
        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset


class DataProvider:
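    """Base class for dataset and dataloader construction.

    Subclasses implement build_datasets(), build_train_transform() and
    build_valid_transform(); this base class handles train/valid/test loader
    creation, optional DistributedSampler setup, random-resolution (RRS)
    training, and sub-sampled loaders for BN statistics recalibration.
    """
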
    data_keys = ("samples",)
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int | None,
        valid_size: int | float | None,
        n_worker: int,
        image_size: int | list[int] | str | list[str],
        num_replicas: int | None = None,
        rank: int | None = None,
        train_ratio: float | None = None,
        drop_last: bool = False,
    ):
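        """Build the train/valid/test datasets and their dataloaders.

        `valid_size` may be an absolute count or a fraction of the training
        set; `train_ratio` < 1 keeps only a random subset of the training
        data; `num_replicas`/`rank` enable DistributedSampler-based loading.
        """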
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch_size & valid_size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()  # e.g., 160 -> 224
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank

        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()

        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loaders (use self.test_batch_size so the fallback to
        # train_batch_size applies when test_batch_size is None)
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        self.valid = self.build_dataloader(
            val_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, self.test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test
        self.sub_train = None

    @property
    def data_shape(self) -> tuple[int, ...]:
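        """Return the (3, H, W) shape of a single sample at the active image size."""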
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] | None = None) -> Any:
        raise NotImplementedError

    def build_datasets(self) -> tuple[Any, Any, Any]:
        raise NotImplementedError

    def build_dataloader(
        self,
        dataset: Any | None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
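        """Wrap `dataset` in a dataloader, or return None if `dataset` is None.

        Uses the random-resolution RRSDataLoader for training when multiple
        image sizes are configured, otherwise torch.utils.data.DataLoader, and
        attaches a DistributedSampler when `num_replicas` is set.
        """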
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            from src.efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
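        """Update the random-resolution controller for `epoch` and reshuffle the DistributedSampler, if any."""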
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)

    def assign_active_image_size(self, new_size: int | tuple[int, int]) -> None:
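        """Switch the active evaluation resolution and rebuild the valid/test transforms accordingly."""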
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test set
        self.valid.dataset.transform = self.test.dataset.transform = new_transform

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[Any, Any]:
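        """Carve a validation split out of `train_dataset` according to `valid_size`.

        A fractional `valid_size` is interpreted relative to the training-set
        length; the split uses the fixed VALID_SEED so it is reproducible.
        Returns (train_dataset, val_dataset), where val_dataset is None when
        `valid_size` is None.
        """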
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> Any:
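        """Build (and cache) a small pre-fetched training loader of roughly `n_samples` samples.

        Used for resetting BN running statistics at the active image size; the
        per-resolution result is cached in `self.sub_train`.
        """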
        # used for resetting BN running statistics
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct dataset and dataloader
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(image_size=self.active_image_size)
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, True, False
        )

        # pre-fetch data
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]

        return self.sub_train[self.active_image_size]
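

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the upstream module).
# A concrete provider subclasses DataProvider and implements the three
# build_* hooks. The torchvision datasets/transforms and paths below are
# assumptions chosen for demonstration; any dataset exposing a `samples`
# attribute (matching `data_keys`) works with random_drop_data.
#
#   import torchvision.transforms as T
#   from torchvision.datasets import ImageFolder
#
#   class ToyImageProvider(DataProvider):
#       name = "toy_image"
#
#       def build_train_transform(self, image_size=None):
#           image_size = image_size or self.active_image_size
#           return T.Compose([
#               T.RandomResizedCrop(image_size),
#               T.ToTensor(),
#               T.Normalize(**self.mean_std),
#           ])
#
#       def build_valid_transform(self, image_size=None):
#           image_size = image_size or self.active_image_size
#           return T.Compose([
#               T.Resize(image_size),
#               T.ToTensor(),
#               T.Normalize(**self.mean_std),
#           ])
#
#       def build_datasets(self):
#           train = ImageFolder("path/to/train", self.build_train_transform())
#           train, valid = self.sample_val_dataset(train, self.build_valid_transform())
#           test = ImageFolder("path/to/val", self.build_valid_transform())
#           return train, valid, test
#
#   provider = ToyImageProvider(
#       train_batch_size=64, test_batch_size=64, valid_size=0.1,
#       n_worker=4, image_size=224,
#   )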