| | import os |
| | import torch |
| | import pickle |
| | import random |
| | import numpy as np |
| |
|
| | from PIL import Image |
| | from torchvision import datasets |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.utils.data.distributed import DistributedSampler |
| | from continuum.datasets import TinyImageNet200 |
| | from continuum import ClassIncremental |
| |
|
class ContinualDatasets:
    """Container of per-task ``DataLoader`` objects for a class-incremental split.

    Task 0 covers ``init_cls_num`` classes and every later task covers
    ``inc_cls_num`` classes. All loaders are created eagerly in the constructor
    and stored in ``self.dataloaders`` (index == task index).
    """

    def __init__(self, dataset, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers, config):
        """Record the split configuration and build all task loaders.

        Args:
            dataset: Dataset identifier; ``'binary_cifar100'`` and
                ``'tiny-imagenet'`` get dedicated handling, anything else is
                handed to ``SingleDataset`` unchanged.
            mode: Split name; ``'train'`` changes what ``get_loader`` returns.
            task_num: Number of tasks (and therefore loaders) to create.
            init_cls_num: Number of classes in the first task.
            inc_cls_num: Number of classes in each subsequent task.
            data_root: Root directory passed to the underlying datasets.
            cls_map: Class-index to class-name mapping forwarded to ``SingleDataset``.
            trfms: Transform forwarded to ``SingleDataset``.
            batchsize: Batch size of every loader.
            num_workers: Worker count of every loader.
            config: Mapping read for ``'class_order'``, ``'seed'`` and
                ``'pin_memory'`` on the tiny-imagenet path.
        """
        self.dataset = dataset
        self.mode = mode
        self.task_num = task_num
        self.init_cls_num = init_cls_num
        self.inc_cls_num = inc_cls_num
        self.data_root = data_root
        self.cls_map = cls_map
        self.trfms = trfms
        self.batchsize = batchsize
        self.num_workers = num_workers
        self.config = config

        if self.dataset == 'binary_cifar100':
            # Instantiated only for its download side effect; the object is discarded.
            datasets.CIFAR100(self.data_root, download=True)

        self.create_loaders()

    def create_loaders(self):
        """Populate ``self.dataloaders`` with one shuffled loader per task."""
        self.dataloaders = []
        if self.dataset == 'tiny-imagenet':
            self._create_tiny_imagenet_loaders()
        else:
            self._create_generic_loaders()

    def _wrap_in_loader(self, task_dataset, pin_memory):
        """Build the loader configuration shared by every task."""
        return DataLoader(
            task_dataset,
            shuffle=True,
            batch_size=self.batchsize,
            drop_last=False,
            num_workers=self.num_workers,
            pin_memory=pin_memory,
        )

    def _create_tiny_imagenet_loaders(self):
        """Create the task loaders from a continuum ``ClassIncremental`` scenario."""
        if 'class_order' in self.config:
            order = self.config['class_order']
        else:
            # No explicit order configured: derive one from the configured seed.
            # This reseeds the global ``random`` module.
            order = list(range(200))
            random.seed(self.config['seed'])
            random.shuffle(order)

        source = TinyImageNet200(self.data_root, train=self.mode == 'train', download=True)
        scenario = ClassIncremental(
            source,
            initial_increment=self.init_cls_num,
            increment=self.inc_cls_num,
            class_order=order,
        )

        # Class ids of each task, taken from ``order`` in task order.
        ids_by_task = [order[:self.init_cls_num]]
        for offset in range(self.init_cls_num, len(order), self.inc_cls_num):
            ids_by_task.append(order[offset:offset + self.inc_cls_num])

        # One line per class; the name is the last tab-separated field.
        names_file = os.path.join(os.getcwd(), "core", "data", "dataset_reqs", "tinyimagenet_classes.txt")
        with open(names_file, "r") as handle:
            name_of_class = [row.rsplit("\t", 1)[-1] for row in handle.read().splitlines()]

        for task in range(self.task_num):
            task_slice = scenario[task:task + 1]

            # init=False skips SingleDataset's own listing; the sample lists are
            # taken from the scenario slice's private ``_x`` / ``_y`` attributes.
            task_dataset = SingleDataset(self.dataset, self.data_root, self.mode, self.init_cls_num, self.inc_cls_num, self.cls_map, self.trfms, init=False)
            task_dataset.images = task_slice._x
            task_dataset.labels = task_slice._y
            task_dataset.labels_name = [name_of_class[cid] for cid in ids_by_task[task]]

            self.dataloaders.append(self._wrap_in_loader(task_dataset, self.config['pin_memory']))

    def _create_generic_loaders(self):
        """Create the task loaders from ``SingleDataset`` class-index ranges."""
        for task in range(self.task_num):
            if task == 0:
                first, width = 0, self.init_cls_num
            else:
                first, width = self.init_cls_num + (task - 1) * self.inc_cls_num, self.inc_cls_num
            last = first + width

            task_dataset = SingleDataset(self.dataset, self.data_root, self.mode, self.init_cls_num, self.inc_cls_num, self.cls_map, self.trfms, first, last)
            self.dataloaders.append(self._wrap_in_loader(task_dataset, False))

    def get_loader(self, task_idx):
        """Return the loader(s) for ``task_idx``.

        In ``'train'`` mode this is the single loader of that task; in any
        other mode it is the list of loaders for tasks ``0..task_idx``.
        """
        assert 0 <= task_idx < self.task_num
        if self.mode != 'train':
            return self.dataloaders[:task_idx + 1]
        return self.dataloaders[task_idx]
| |
|
class ImbalancedDatasets(ContinualDatasets):
    """Per-task loaders whose classes are subsampled to an imbalanced size profile.

    A per-class image budget is computed by ``_get_img_num_per_cls`` and then
    applied to one ``SingleDataset`` per task by randomly keeping at most that
    many images of each class.
    """

    def __init__(self, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers,
                 imb_type='exp', imb_factor=0.002, shuffle=False, dataset=None, config=None):
        """Store the imbalance settings and build the loaders.

        Args:
            mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map,
            trfms, batchsize, num_workers: Forwarded to ``ContinualDatasets``.
            imb_type: Name of the budget profile (see ``_get_img_num_per_cls``).
            imb_factor: Ratio used by the exponential and step profiles.
            shuffle: If true, randomly permute the budgets (group-wise) before
                assigning them to classes. It does not shuffle samples.
            dataset: Dataset identifier forwarded to the parent and to
                ``SingleDataset``. New optional parameter; the default keeps
                the folder-per-class layout this class reads from.
            config: Configuration object forwarded to the parent. New optional
                parameter.
        """
        # These must exist before the parent constructor runs, because it
        # calls create_loaders(), which reads them.
        self.imb_type = imb_type
        self.imb_factor = imb_factor
        self.shuffle = shuffle
        # The parent signature starts with `dataset` and ends with `config`;
        # omitting them made every construction fail with a TypeError.
        super().__init__(dataset, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers, config)

    def create_loaders(self):
        """Populate ``self.dataloaders`` with one subsampled dataset per task."""
        self.dataloaders = []
        cls_num = self.init_cls_num + self.inc_cls_num * (self.task_num - 1)
        img_num_list = self._get_img_num_per_cls(cls_num, self.imb_type, self.imb_factor)

        if self.shuffle:
            # Budgets are shuffled in groups of inc_cls_num: first the group
            # order, then the order inside each group. A step of 0 would make
            # range() raise, so fall back to a single group.
            # NOTE(review): groups line up with tasks only when
            # init_cls_num == inc_cls_num - confirm that is the intended use.
            group_size = self.inc_cls_num if self.inc_cls_num > 0 else max(cls_num, 1)
            grouped_img_nums = [img_num_list[i:i + group_size] for i in range(0, cls_num, group_size)]
            np.random.shuffle(grouped_img_nums)
            for group in grouped_img_nums:
                np.random.shuffle(group)
            img_num_list = [num for group in grouped_img_nums for num in group]

        for i in range(self.task_num):
            start_idx = 0 if i == 0 else (self.init_cls_num + (i - 1) * self.inc_cls_num)
            end_idx = start_idx + (self.init_cls_num if i == 0 else self.inc_cls_num)
            dataset = SingleDataset(self.dataset, self.data_root, self.mode, self.init_cls_num, self.inc_cls_num, self.cls_map, self.trfms, start_idx, end_idx)

            new_imgs, new_labels = [], []
            labels_np = np.array(dataset.labels, dtype=np.int64)
            for the_class in np.unique(labels_np):
                # Labels are global class indices, so the budget is looked up
                # by class id. This stays correct when init_cls_num differs
                # from inc_cls_num or when a class has no images at all.
                the_img_num = img_num_list[the_class]
                idx = np.nonzero(labels_np == the_class)[0]
                np.random.shuffle(idx)
                selec_idx = idx[:the_img_num]
                new_imgs.extend(dataset.images[j] for j in selec_idx)
                # A class may hold fewer images than its budget; emit exactly
                # one label per kept image so both lists stay aligned.
                new_labels.extend([the_class] * len(selec_idx))
            dataset.images = new_imgs
            dataset.labels = new_labels

            # NOTE(review): samples are served in class order (no shuffle=True),
            # unlike ContinualDatasets - confirm this is intended.
            self.dataloaders.append(DataLoader(
                dataset,
                batch_size=self.batchsize,
                drop_last=False,
                num_workers=self.num_workers,
            ))

    def _get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the image budget of every class as a list of ``cls_num`` ints.

        ``img_max`` is the number of entries in the folder of class 0
        (``data_root/mode/cls_map[0]``). Supported ``imb_type`` values:

        - ``'exp'`` / ``'exp_re'``: exponential decay from ``img_max`` down to
          ``img_max * imb_factor`` (at least 1), plain or reversed.
        - ``'exp_max'`` / ``'exp_max_re'``: every group of classes shares the
          exponential value of the group's first class, plain or reversed.
        - ``'exp_min'``: every group shares the exponential value taken
          ``cls_per_group - 1`` positions after the group's first class.
        - ``'half'`` / ``'half_re'``: budget starts at 1 and doubles after each
          group, capped at ``img_max``; ``'half'`` reverses the list.
        - ``'halfbal'``: one uniform budget derived from a halving series.
        - ``'oneshot'``: one image per class.
        - ``'step'``: first half ``img_max``, remainder ``img_max * imb_factor``.
        - ``'fewshot'``: classes below index 50 get ``img_max``, the rest 1 %.
        - anything else: ``img_max`` for every class.
        """
        img_max = len(os.listdir(os.path.join(self.data_root, self.mode, self.cls_map[0])))
        if cls_num <= 0:
            return []

        def exp_count(position):
            # With a single class the denominator would be zero; keep the
            # full budget in that case.
            exponent = position / (cls_num - 1.0) if cls_num > 1 else 0.0
            return img_max * (imb_factor ** exponent)

        # Classes per group for the grouped profiles; never below 1 so the
        # modulo tests further down are always defined.
        cls_per_group = max(cls_num // max(self.task_num, 1), 1)

        img_num_per_cls = []
        if imb_type in ('exp', 'exp_re'):
            img_num_per_cls = [max(int(exp_count(cls_idx)), 1) for cls_idx in range(cls_num)]
            if imb_type == 'exp_re':
                img_num_per_cls.reverse()

        elif imb_type in ('exp_max', 'exp_max_re', 'exp_min'):
            offset = cls_per_group - 1 if imb_type == 'exp_min' else 0
            num = 0
            for cls_idx in range(cls_num):
                # First class of a group. Written as `cls_idx % g == 0` rather
                # than `(cls_idx + 1) % g == 1`, which never fires for g == 1.
                if cls_idx % cls_per_group == 0:
                    num = exp_count(cls_idx + offset)
                img_num_per_cls.append(int(num))
            if imb_type == 'exp_max_re':
                img_num_per_cls.reverse()

        elif imb_type in ('half', 'half_re'):
            ratio = 2
            num = 1
            for cls_idx in range(cls_num):
                num = min(num, img_max)
                img_num_per_cls.append(int(num))
                # Double the budget once a group has been completed.
                if (cls_idx + 1) % cls_per_group == 0:
                    num *= ratio
            if imb_type == 'half':
                img_num_per_cls.reverse()

        elif imb_type == 'halfbal':
            group_total = img_max * cls_per_group
            total = 0
            for i in range(self.task_num):
                total += group_total / (2 ** i)
            per_class_count = int(total / cls_num)
            img_num_per_cls.extend([per_class_count] * cls_num)

        elif imb_type == 'oneshot':
            img_num_per_cls.extend([1] * cls_num)

        elif imb_type == 'step':
            head = cls_num // 2
            img_num_per_cls.extend([int(img_max)] * head)
            # The remainder (not a second `cls_num // 2`) so an odd class
            # count still yields one budget per class.
            img_num_per_cls.extend([int(img_max * imb_factor)] * (cls_num - head))

        elif imb_type == 'fewshot':
            for cls_idx in range(cls_num):
                num = img_max if cls_idx < 50 else img_max * 0.01
                img_num_per_cls.append(int(num))

        else:
            img_num_per_cls.extend([int(img_max)] * cls_num)

        return img_num_per_cls
| |
|
class SingleDataset(Dataset):
    """Dataset holding the samples of one class-index range ``[start_idx, end_idx)``.

    Samples are kept in three parallel/associated lists:
    ``images`` (arrays for ``'binary_cifar100'``, paths otherwise),
    ``labels`` (one class index per image) and ``labels_name``.
    """

    def __init__(self, dataset, data_root, mode, init_cls_num, inc_cls_num, cls_map, trfms, start_idx=-1, end_idx=-1, init=True):
        """Store the configuration and optionally load the sample lists.

        Args:
            dataset: Dataset identifier; selects the loading and decoding path.
            data_root: Root directory of the data.
            mode: Split name; used as sub-directory (or file name for CIFAR).
            init_cls_num: Stored as-is; not used inside this class.
            inc_cls_num: Stored as-is; not used inside this class.
            cls_map: Class-index to folder-name mapping (folder layout only).
            trfms: Callable applied to every PIL image in ``__getitem__``.
            start_idx: First class index (inclusive) to load.
            end_idx: Last class index (exclusive) to load.
            init: When false, nothing is read from disk and the sample lists
                start empty so the caller can assign them afterwards.
        """
        super().__init__()
        self.dataset = dataset
        self.data_root = data_root
        self.mode = mode
        self.init_cls_num = init_cls_num
        self.inc_cls_num = inc_cls_num
        self.cls_map = cls_map
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.trfms = trfms

        if init:
            self.images, self.labels, self.labels_name = self._init_datalist()
        else:
            # Keep the object usable (len(), get_class_names()) until the
            # caller fills the lists in, instead of raising AttributeError.
            self.images, self.labels, self.labels_name = [], [], []

    def __getitem__(self, idx):
        """Return ``{"image": transformed image, "label": label}`` for ``idx``."""
        if self.dataset == 'binary_cifar100':
            image = Image.fromarray(np.uint8(self.images[idx]))
        elif self.dataset == 'tiny-imagenet':
            # Entries are opened as given, without prefixing data_root.
            with Image.open(self.images[idx]) as img:
                image = img.convert("RGB")
        else:
            # Entries are paths relative to data_root/mode.
            with Image.open(os.path.join(self.data_root, self.mode, self.images[idx])) as img:
                image = img.convert("RGB")

        label = self.labels[idx]
        image = self.trfms(image)

        return {"image": image, "label": label}

    def __len__(self):
        """Return the number of samples (length of ``labels``)."""
        return len(self.labels)

    def _init_datalist(self):
        """Load ``(images, labels, labels_name)`` for ``[start_idx, end_idx)``."""
        imgs, labels, labels_name = [], [], []

        if self.dataset == 'binary_cifar100':
            # NOTE: unpickling executes arbitrary code; this file must come
            # from a trusted source.
            with open(os.path.join(self.data_root, 'cifar-100-python', self.mode), 'rb') as f:
                load_data = pickle.load(f, encoding='latin1')

            for data, label in zip(load_data['data'], load_data['fine_labels']):
                if self.start_idx <= label < self.end_idx:
                    # Each row stores three consecutive 1024-value planes;
                    # stack them into a 32x32x3 array.
                    r = data[:1024].reshape(32, 32)
                    g = data[1024:2048].reshape(32, 32)
                    b = data[2048:].reshape(32, 32)
                    imgs.append(np.dstack((r, g, b)))
                    labels.append(label)
                    # NOTE(review): this appends one entry per image, whereas
                    # the folder branch appends one per class - confirm intent.
                    labels_name.append(label)
        else:
            for cls_idx in range(self.start_idx, self.end_idx):
                cls_dir = self.cls_map[cls_idx]
                # os.listdir order is arbitrary; sort so sample order (and any
                # seeded subsampling built on it) is reproducible.
                file_names = sorted(os.listdir(os.path.join(self.data_root, self.mode, cls_dir)))
                img_list = [cls_dir + '/' + pic_path for pic_path in file_names]
                imgs.extend(img_list)
                labels.extend([cls_idx] * len(img_list))
                labels_name.append(cls_dir)

        return imgs, labels, labels_name

    def get_class_names(self):
        """Return ``labels_name`` as currently stored."""
        return self.labels_name