import os

import torch.nn as nn
from continuum import ClassIncremental, InstanceIncremental
from continuum.datasets import (
    CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50,
    fgvc_aircraft, Caltech101, DTD, EuroSAT, flowers102, food101, MNIST,
    OxfordPet, SUN397
)

from .utils import get_dataset_class_names


# Fallback directory holding the ImageNet-100 split files
# ("train_100.txt" / "val_100.txt"). Kept for backward compatibility;
# set ``cfg.imagenet100_split_dir`` to point somewhere else.
_DEFAULT_IMAGENET100_SPLIT_DIR = (
    "/home/kangborui/ClProject/MoE-Adapters4CL-cross-guild-fusion/cil/"
    "dataset_reqs/imagenet100_splits"
)

# Class names returned for the "core50" dataset.
_CORE50_CLASS_NAMES = (
    "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
    "glasses", "balls", "markers", "cups", "remote controls",
)


class ImageNet1000(ImageFolderDataset):
    """Continuum dataset for ImageNet-1k stored in a tree-like folder structure.

    The root folder is expected to contain a ``train`` and a ``val``
    sub-folder, each in the layout consumed by ``ImageFolderDataset``.

    :param data_path: Root folder containing the ``train`` and ``val`` splits.
    :param train: If True, load the ``train`` split; otherwise ``val``.
    :param download: Dummy parameter, forwarded to the parent class.
    """

    def __init__(
        self,
        data_path: str,
        train: bool = True,
        download: bool = False,
    ):
        super().__init__(data_path=data_path, train=train, download=download)

    def get_data(self):
        """Point ``self.data_path`` at the requested split, then load it."""
        # Capture the un-suffixed root on the first call so that repeated
        # calls do not keep appending "train"/"val" to an already-suffixed
        # path (e.g. ".../train/train").
        if getattr(self, "_split_root", None) is None:
            self._split_root = self.data_path
        split = "train" if self.train else "val"
        self.data_path = os.path.join(self._split_root, split)
        return super().get_data()


def get_dataset(cfg, is_train, transforms=None):
    """Instantiate the continuum dataset selected by ``cfg.dataset``.

    :param cfg: Config object. Reads ``dataset`` and ``dataset_root``; some
        datasets also read ``workdir``, and "imagenet100" optionally reads
        ``imagenet100_split_dir``.
    :param is_train: Whether to load the training split.
    :param transforms: Unused here; kept for interface compatibility.
    :return: ``(dataset, classes_names)``.
    :raises ValueError: If ``cfg.dataset`` is not a supported dataset name.
    """
    if cfg.dataset == "cifar100":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = CIFAR100(data_path=data_path, download=True, train=is_train)
        classes_names = dataset.dataset.classes

    elif cfg.dataset == "tinyimagenet":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = TinyImageNet200(data_path, train=is_train, download=True)
        # NOTE(review): class names come from a project helper keyed on
        # (workdir, dataset); its file format is not visible here.
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet100":
        # ImageNet-100 reads directly from dataset_root (no per-dataset subdir).
        data_path = cfg.dataset_root
        # `or` (rather than only a getattr default) also covers configs that
        # return None for missing keys.
        split_dir = (
            getattr(cfg, "imagenet100_split_dir", None)
            or _DEFAULT_IMAGENET100_SPLIT_DIR
        )
        subset_file = "train_100.txt" if is_train else "val_100.txt"
        dataset = ImageNet100(
            data_path,
            train=is_train,
            data_subset=os.path.join(split_dir, subset_file),
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet1000":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = ImageNet1000(data_path, train=is_train)
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "core50":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = Core50(
            data_path,
            scenario="domains",
            classification="category",
            train=is_train,
        )
        # Fresh list per call so callers may mutate it safely.
        classes_names = list(_CORE50_CLASS_NAMES)

    else:
        # Raise (not merely construct) the error so unsupported names fail
        # here instead of as an UnboundLocalError at the return below.
        raise ValueError(f"'{cfg.dataset}' is a invalid dataset.")

    return dataset, classes_names


def build_cl_scenarios(cfg, is_train, transforms) -> tuple:
    """Build the continual-learning scenario described by ``cfg.scenario``.

    :param cfg: Config object. Reads ``scenario`` plus the fields used by
        ``get_dataset``; the "class" scenario also reads
        ``initial_increment``, ``increment`` and ``class_order``.
    :param is_train: Whether to build the scenario on the training split.
    :param transforms: A composed transform exposing a ``.transforms`` list
        (e.g. a torchvision ``Compose``); that list is handed to the scenario.
    :return: ``(scenario, classes_names)``.
    :raises NotImplementedError: For the "task-agnostic" scenario.
    :raises ValueError: For any other unknown scenario name, or an unknown
        dataset (propagated from ``get_dataset``).
    """
    dataset, classes_names = get_dataset(cfg, is_train)

    if cfg.scenario == "class":
        scenario = ClassIncremental(
            dataset,
            initial_increment=cfg.initial_increment,
            increment=cfg.increment,
            transformations=transforms.transforms,  # Convert Compose into list
            class_order=cfg.class_order,
        )
    elif cfg.scenario == "domain":
        scenario = InstanceIncremental(
            dataset,
            transformations=transforms.transforms,
        )
    elif cfg.scenario == "task-agnostic":
        raise NotImplementedError("Method has not been implemented. Soon be added.")
    else:
        # Only the first fragment is an f-string, so the second uses plain
        # single braces (doubled braces would be printed literally).
        raise ValueError(
            f"You have entered `{cfg.scenario}` which is not a defined scenario, "
            "please choose from {'class', 'domain', 'task-agnostic'}."
        )

    return scenario, classes_names