| |
|
| |
|
| | import os |
| | import torch.nn as nn |
| |
|
| | from continuum import ClassIncremental, InstanceIncremental |
| | from continuum.datasets import ( |
| | CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50, |
| | fgvc_aircraft, Caltech101, DTD, EuroSAT, flowers102, food101, |
| | MNIST, OxfordPet, SUN397 |
| |
|
| | ) |
| | from .utils import get_dataset_class_names |
| |
|
| |
|
class ImageNet1000(ImageFolderDataset):
    """Continuum dataset for ImageNet-1k stored as a tree-like (ImageFolder-style) structure.

    The root folder is expected to contain a ``train`` and a ``val`` sub-folder;
    ``train`` selects which of the two is loaded.

    :param data_path: Root folder holding the ``train`` and ``val`` splits.
    :param train: Load ``<data_path>/train`` if True, otherwise ``<data_path>/val``.
    :param download: Dummy parameter, forwarded unchanged to the parent class.
    """

    def __init__(
        self,
        data_path: str,
        train: bool = True,
        download: bool = False,
    ):
        super().__init__(data_path=data_path, train=train, download=download)

    def get_data(self):
        """Load the split selected by ``self.train`` through the parent ``get_data``.

        ``self.data_path`` is pointed at the split folder only for the duration
        of the parent call and then restored. Calling this method more than once
        therefore does not keep appending ``train``/``val`` to the path.
        """
        root = self.data_path
        split = "train" if self.train else "val"
        # The parent's get_data() relies on self.data_path to locate the folder,
        # so redirect it temporarily instead of overwriting it for good.
        self.data_path = os.path.join(root, split)
        try:
            return super().get_data()
        finally:
            self.data_path = root
| |
|
| |
|
def get_dataset(cfg, is_train, transforms=None):
    """Build the continuum dataset selected by ``cfg.dataset`` and its class names.

    :param cfg: Config object. Reads ``dataset``, ``dataset_root`` and ``workdir``.
        For ``imagenet100`` it also reads the optional ``imagenet100_split_dir``
        (folder containing ``train_100.txt`` / ``val_100.txt``); when absent, the
        historical hard-coded location is used.
    :param is_train: Load the training split if True, otherwise the evaluation split.
    :param transforms: Unused; kept so existing call sites keep working.
    :return: ``(dataset, classes_names)``.
    :raises ValueError: If ``cfg.dataset`` is not a supported dataset name.
    """
    if cfg.dataset == "cifar100":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = CIFAR100(
            data_path=data_path,
            download=True,
            train=is_train,
        )
        # The continuum wrapper exposes the underlying dataset, which carries the class list.
        classes_names = dataset.dataset.classes

    elif cfg.dataset == "tinyimagenet":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = TinyImageNet200(
            data_path,
            train=is_train,
            download=True,
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet100":
        # ImageNet100 is read directly from dataset_root (no per-dataset sub-folder).
        data_path = cfg.dataset_root
        # `or` also covers config objects that return None for unset keys.
        split_dir = getattr(cfg, "imagenet100_split_dir", None) or (
            "/home/kangborui/ClProject/MoE-Adapters4CL-cross-guild-fusion/cil/dataset_reqs/imagenet100_splits"
        )
        split_file = "train_100.txt" if is_train else "val_100.txt"
        dataset = ImageNet100(
            data_path,
            train=is_train,
            data_subset=os.path.join(split_dir, split_file),
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet1000":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = ImageNet1000(
            data_path,
            train=is_train,
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "core50":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = Core50(
            data_path,
            scenario="domains",
            classification="category",
            train=is_train,
        )
        classes_names = [
            "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
            "glasses", "balls", "markers", "cups", "remote controls",
        ]

    else:
        # Previously the exception was only constructed, not raised, which led to
        # an UnboundLocalError on `dataset` at the return below.
        raise ValueError(f"'{cfg.dataset}' is a invalid dataset.")

    return dataset, classes_names
| |
|
| |
|
def build_cl_scenarios(cfg, is_train, transforms) -> tuple:
    """Build the continual-learning scenario described by ``cfg`` and return it with class names.

    :param cfg: Config object. Reads ``scenario`` and everything ``get_dataset`` reads.
        For the ``class`` scenario it also reads ``initial_increment``,
        ``increment`` and ``class_order``.
    :param is_train: Build the scenario on the training split if True.
    :param transforms: Object whose ``.transforms`` attribute is passed to continuum
        (presumably a torchvision ``Compose`` — confirm against callers).
    :return: ``(scenario, classes_names)``.
    :raises NotImplementedError: For the ``task-agnostic`` scenario.
    :raises ValueError: For any other unknown scenario name.
    """
    # Validate the scenario before loading the (potentially slow) dataset. Both
    # errors were previously constructed but never raised, which ended in an
    # UnboundLocalError on `scenario`.
    if cfg.scenario == "task-agnostic":
        raise NotImplementedError("Method has not been implemented. Soon be added.")
    if cfg.scenario not in ("class", "domain"):
        raise ValueError(
            f"You have entered `{cfg.scenario}` which is not a defined scenario, "
            "please choose from {'class', 'domain', 'task-agnostic'}."
        )

    dataset, classes_names = get_dataset(cfg, is_train)

    if cfg.scenario == "class":
        scenario = ClassIncremental(
            dataset,
            initial_increment=cfg.initial_increment,
            increment=cfg.increment,
            transformations=transforms.transforms,
            class_order=cfg.class_order,
        )
    else:  # "domain"
        scenario = InstanceIncremental(
            dataset,
            transformations=transforms.transforms,
        )

    return scenario, classes_names