Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision | |
| import torchvision.transforms as T | |
| from mmpretrain import datasets as mmdatasets | |
| from mmpretrain.registry import TRANSFORMS | |
| from mmengine.dataset import Compose | |
| from torch import nn | |
| from torch.utils.data import Dataset as TorchDataset | |
# Maps a dataset name to a zero-argument factory function. Calling the
# factory builds and returns the (train_data, val_data) pair for that dataset.
# Entries are added by the register_* functions in this module.
DATASET_REGISTRY = {}
# Directory handed to torchvision dataset constructors as their `root`.
DATASET_PATH = "./datasets"
class MMPretrainWrapper(TorchDataset):
    """Expose an mmpretrain dataset through an ``(image, label)`` interface.

    Every raw sample of the wrapped dataset is passed through a fixed
    pipeline (load from file, resize the short edge to 256, center-crop to
    224, pack) and returned as a float image divided by 255 together with
    its label as a Python scalar.
    """

    def __init__(self, mmdataset) -> None:
        super().__init__()
        self.mmdataset = mmdataset
        # Fixed preprocessing steps, expressed as mmpretrain transform configs.
        eval_steps = [
            {"type": "LoadImageFromFile"},
            {"type": "ResizeEdge", "scale": 256, "edge": "short"},
            {"type": "CenterCrop", "crop_size": 224},
            {"type": "PackInputs"},
        ]
        self.pipeline = self.init_pipeline(eval_steps)

    def init_pipeline(self, pipeline_cfg):
        """Build each transform config with ``TRANSFORMS`` and chain them."""
        built_steps = []
        for step_cfg in pipeline_cfg:
            built_steps.append(TRANSFORMS.build(step_cfg))
        return Compose(built_steps)

    def classes(self):
        """Return the ``CLASSES`` attribute of the wrapped dataset."""
        wrapped = self.mmdataset
        return wrapped.CLASSES

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        wrapped = self.mmdataset
        return len(wrapped)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        packed = self.pipeline(self.mmdataset[index])
        # Our interface expects images in [0-1]
        # NOTE(review): confirm the channel order (RGB vs BGR) produced by
        # LoadImageFromFile matches what downstream consumers expect.
        image = packed["inputs"].float() / 255
        label = packed["data_samples"].gt_label.item()
        return image, label
def register_torchvision_dataset(
    dataset_name,
    dataset_cls,
    dataset_kwargs_train=None,
    dataset_kwargs_val=None,
):
    """Register a torchvision-style dataset under ``dataset_name``.

    A zero-argument factory is stored in ``DATASET_REGISTRY``; calling it
    builds and returns ``(train_data, val_data)``.

    dataset_name: Key under which the factory is stored.
    dataset_cls: Dataset class accepting ``root``, ``train``, ``download``
        and ``transform`` keyword arguments.
    dataset_kwargs_train: Optional extra keyword arguments for the training
        split. They are merged over the defaults, so they may override them.
    dataset_kwargs_val: Same as above, for the validation split.
    """
    # Snapshot the caller's kwargs so later mutation of the dicts passed in
    # cannot silently change what the registered factory builds.
    train_overrides = dict(dataset_kwargs_train or {})
    val_overrides = dict(dataset_kwargs_val or {})

    def build_split(is_train, overrides):
        # Defaults first, caller-supplied kwargs on top. These kwargs were
        # previously accepted by this function but never forwarded.
        split_kwargs = {
            "root": DATASET_PATH,
            "train": is_train,
            "download": True,
            "transform": T.ToTensor(),
        }
        split_kwargs.update(overrides)
        return dataset_cls(**split_kwargs)

    def instantiate_dataset():
        train_data = build_split(True, train_overrides)
        val_data = build_split(False, val_overrides)
        return train_data, val_data

    DATASET_REGISTRY[dataset_name] = instantiate_dataset
def register_mmpretrain_dataset(
    dataset_name,
    dataset_cls,
    dataset_kwargs_train=None,
    dataset_kwargs_val=None,
):
    """Register an mmpretrain dataset under ``dataset_name``.

    A zero-argument factory is stored in ``DATASET_REGISTRY``; calling it
    instantiates ``dataset_cls`` once per split, wraps each instance in
    ``MMPretrainWrapper`` and returns ``(train_data, val_data)``.

    dataset_name: Key under which the factory is stored.
    dataset_cls: Dataset class to instantiate for both splits.
    dataset_kwargs_train: Keyword arguments for the training split
        (defaults to none).
    dataset_kwargs_val: Keyword arguments for the validation split
        (defaults to none).
    """
    # ``None`` defaults avoid sharing one mutable dict across calls; the
    # copies isolate the factory from later changes to the caller's dicts.
    train_kwargs = dict(dataset_kwargs_train or {})
    val_kwargs = dict(dataset_kwargs_val or {})

    def instantiate_dataset():
        train_data = dataset_cls(**train_kwargs)
        val_data = dataset_cls(**val_kwargs)
        return MMPretrainWrapper(train_data), MMPretrainWrapper(val_data)

    DATASET_REGISTRY[dataset_name] = instantiate_dataset
def register_default_datasets():
    """Register the built-in datasets: "cifar10", "cifar100" and "imagenet"."""
    torchvision_defaults = (
        ("cifar10", torchvision.datasets.CIFAR10),
        ("cifar100", torchvision.datasets.CIFAR100),
    )
    for name, dataset_cls in torchvision_defaults:
        register_torchvision_dataset(name, dataset_cls)

    # NOTE(review): the "train" split is configured with the same validation
    # prefix and annotation file as the "val" split - confirm this is
    # intentional (e.g. only the validation images are available).
    imagenet_split = {
        "data_root": "data/imagenet",
        "data_prefix": "val",
        "ann_file": "meta/val.txt",
    }
    register_mmpretrain_dataset(
        "imagenet",
        mmdatasets.ImageNet,
        dataset_kwargs_train=dict(imagenet_split),
        dataset_kwargs_val=dict(imagenet_split),
    )
def get_dataset(dataset_name):
    """
    Returns the result of calling the factory registered for a dataset.

    The registry starts out empty and is filled by the ``register_*``
    functions in this module, so a name is only known after it has been
    registered.

    dataset_name: Name of desired dataset
    Raises ValueError if no dataset is registered under ``dataset_name``.
    """
    try:
        factory = DATASET_REGISTRY[dataset_name]
    except KeyError:
        # Name the offending key and the known ones so a typo or a missing
        # registration call is obvious from the message alone.
        raise ValueError(
            f"Requested dataset not in registry: {dataset_name!r} "
            f"(available: {list(DATASET_REGISTRY)})"
        ) from None
    return factory()