import os
import torch
import pickle
import random
import numpy as np
from PIL import Image
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from continuum.datasets import TinyImageNet200
from continuum import ClassIncremental


class ContinualDatasets:
    def __init__(self, dataset, mode, task_num, init_cls_num, inc_cls_num, data_root,
                 cls_map, trfms, batchsize, num_workers, config):
        self.mode = mode
        self.task_num = task_num
        self.init_cls_num = init_cls_num
        self.inc_cls_num = inc_cls_num
        self.data_root = data_root
        self.cls_map = cls_map
        self.trfms = trfms
        self.batchsize = batchsize
        self.num_workers = num_workers
        self.config = config
        self.dataset = dataset

        if self.dataset == 'binary_cifar100':
            # Make sure the CIFAR-100 pickle files are present before building the task splits.
            datasets.CIFAR100(self.data_root, download=True)

        self.create_loaders()

    def create_loaders(self):
        self.dataloaders = []

        if self.dataset == 'tiny-imagenet':
            # Build the class-incremental scenario with continuum, using either the class
            # order given in the config or a seeded random permutation of the 200 classes.
            if 'class_order' in self.config:
                class_order = self.config['class_order']
            else:
                class_order = list(range(200))
                random.seed(self.config['seed'])
                random.shuffle(class_order)

            scenario = ClassIncremental(
                TinyImageNet200(self.data_root, train=self.mode == 'train', download=True),
                initial_increment=self.init_cls_num,
                increment=self.inc_cls_num,
                class_order=class_order
            )
            class_ids_per_task = (
                [class_order[:self.init_cls_num]] +
                [class_order[i:i + self.inc_cls_num]
                 for i in range(self.init_cls_num, len(class_order), self.inc_cls_num)]
            )

            with open(os.path.join(os.getcwd(), "core", "data", "dataset_reqs",
                                   "tinyimagenet_classes.txt"), "r") as f:
                lines = f.read().splitlines()
            classes_names = [line.split("\t")[-1] for line in lines]

            for t in range(self.task_num):
                cur_scenario = scenario[t:t + 1]
                dataset = SingleDataset(self.dataset, self.data_root, self.mode,
                                        self.init_cls_num, self.inc_cls_num,
                                        self.cls_map, self.trfms, init=False)
                dataset.images = cur_scenario._x
                dataset.labels = cur_scenario._y
                dataset.labels_name = [classes_names[class_id] for class_id in class_ids_per_task[t]]
                self.dataloaders.append(DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=self.batchsize,
                    drop_last=False,
                    num_workers=self.num_workers,
                    pin_memory=self.config['pin_memory']
                ))
        else:
            # Folder-per-class datasets: task i covers classes [start_idx, end_idx).
            for i in range(self.task_num):
                start_idx = 0 if i == 0 else (self.init_cls_num + (i - 1) * self.inc_cls_num)
                end_idx = start_idx + (self.init_cls_num if i == 0 else self.inc_cls_num)
                self.dataloaders.append(DataLoader(
                    SingleDataset(self.dataset, self.data_root, self.mode,
                                  self.init_cls_num, self.inc_cls_num,
                                  self.cls_map, self.trfms, start_idx, end_idx),
                    shuffle=True,
                    batch_size=self.batchsize,
                    drop_last=False,
                    num_workers=self.num_workers,
                    pin_memory=False
                ))

    def get_loader(self, task_idx):
        assert 0 <= task_idx < self.task_num
        if self.mode == 'train':
            return self.dataloaders[task_idx]
        else:
            return self.dataloaders[:task_idx + 1]

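
# Usage sketch for get_loader (illustrative only; `benchmark` stands for an already
# constructed ContinualDatasets instance, see the __main__ block at the end of this file):
#
#     loader = benchmark.get_loader(2)
#
# If the instance was built with mode='train', this returns the single DataLoader for
# task 2; if it was built with mode='test', it returns the list of DataLoaders for
# tasks 0..2, so the caller can evaluate on every task seen so far.
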
class ImbalancedDatasets(ContinualDatasets):
    """Long-tailed variant: each class is subsampled according to an imbalance schedule."""

    def __init__(self, dataset, mode, task_num, init_cls_num, inc_cls_num, data_root,
                 cls_map, trfms, batchsize, num_workers, config,
                 imb_type='exp', imb_factor=0.002, shuffle=False):
        self.imb_type = imb_type
        self.imb_factor = imb_factor
        self.shuffle = shuffle
        super().__init__(dataset, mode, task_num, init_cls_num, inc_cls_num, data_root,
                         cls_map, trfms, batchsize, num_workers, config)

    def create_loaders(self):
        self.dataloaders = []
        cls_num = self.init_cls_num + self.inc_cls_num * (self.task_num - 1)
        img_num_list = self._get_img_num_per_cls(cls_num, self.imb_type, self.imb_factor)

        if self.shuffle:
            # Shuffle the per-class counts both across and within task-sized groups.
            grouped_img_nums = [img_num_list[i:i + self.inc_cls_num]
                                for i in range(0, cls_num, self.inc_cls_num)]
            np.random.shuffle(grouped_img_nums)
            for group in grouped_img_nums:
                np.random.shuffle(group)
            img_num_list = [num for group in grouped_img_nums for num in group]

        for i in range(self.task_num):
            start_idx = 0 if i == 0 else (self.init_cls_num + (i - 1) * self.inc_cls_num)
            end_idx = start_idx + (self.init_cls_num if i == 0 else self.inc_cls_num)
            dataset = SingleDataset(self.dataset, self.data_root, self.mode,
                                    self.init_cls_num, self.inc_cls_num,
                                    self.cls_map, self.trfms, start_idx, end_idx)

            # Subsample each class of this task down to its assigned count. Note that the
            # schedule is consumed in inc_cls_num-sized chunks, which assumes
            # init_cls_num == inc_cls_num.
            new_imgs, new_labels = [], []
            labels_np = np.array(dataset.labels, dtype=np.int64)
            classes = np.unique(labels_np)
            for the_class, the_img_num in zip(
                    classes, img_num_list[i * self.inc_cls_num:(i + 1) * self.inc_cls_num]):
                idx = np.nonzero(labels_np == the_class)[0]
                np.random.shuffle(idx)
                selec_idx = idx[:the_img_num]
                new_imgs.extend([dataset.images[j] for j in selec_idx])
                new_labels.extend([the_class] * the_img_num)
            dataset.images = new_imgs
            dataset.labels = new_labels

            self.dataloaders.append(DataLoader(
                dataset,
                batch_size=self.batchsize,
                drop_last=False
            ))

    def _get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        # The largest class size is taken from the first class folder on disk.
        img_max = len(os.listdir(os.path.join(self.data_root, self.mode, self.cls_map[0])))
        img_num_per_cls = []
        if imb_type == 'exp':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(max(int(num), 1))
        elif imb_type == 'exp_re':
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(max(int(num), 1))
            img_num_per_cls.reverse()
        elif imb_type == 'exp_max':
            # Every class in a group gets the exponential count of the group's first class.
            cls_per_group = cls_num // self.task_num
            for cls_idx in range(cls_num):
                if (cls_idx + 1) % cls_per_group == 1:
                    num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'exp_max_re':
            cls_per_group = cls_num // self.task_num
            for cls_idx in range(cls_num):
                if (cls_idx + 1) % cls_per_group == 1:
                    num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
            img_num_per_cls.reverse()
        elif imb_type == 'exp_min':
            # Every class in a group gets the exponential count of the group's last class.
            cls_per_group = cls_num // self.task_num
            for cls_idx in range(cls_num):
                if (cls_idx + 1) % cls_per_group == 1:
                    num = img_max * (imb_factor ** ((cls_idx + cls_per_group - 1) / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'half':
            cls_per_group = cls_num // self.task_num
            ratio = 2
            num = 1
            for cls_idx in range(cls_num):
                if num > img_max:
                    num = img_max
                img_num_per_cls.append(int(num))
                if (cls_idx + 1) % cls_per_group == 0:
                    num *= ratio
            img_num_per_cls.reverse()
        elif imb_type == 'half_re':
            cls_per_group = cls_num // self.task_num
            ratio = 2
            num = 1
            for cls_idx in range(cls_num):
                if num > img_max:
                    num = img_max
                img_num_per_cls.append(int(num))
                if (cls_idx + 1) % cls_per_group == 0:
                    num *= ratio
        elif imb_type == 'halfbal':
            # Balanced counterpart of 'half': the halving per-task budget is spread evenly
            # over all classes.
            cls_per_group = cls_num // self.task_num
            N = img_max * cls_per_group
            total = 0
            for i in range(self.task_num):
                total += N / (2 ** i)
            per_class_count = int(total / cls_num)
            img_num_per_cls.extend([per_class_count] * cls_num)
        elif imb_type == 'oneshot':
            img_num_per_cls.extend([1] * cls_num)
        elif imb_type == 'step':
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        elif imb_type == 'fewshot':
            for cls_idx in range(cls_num):
                num = img_max if cls_idx < 50 else img_max * 0.01
                img_num_per_cls.append(int(num))
        else:
            # Balanced fallback: every class keeps img_max samples.
            img_num_per_cls.extend([int(img_max)] * cls_num)
        return img_num_per_cls

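
# Quick sanity check of the default 'exp' schedule (standalone sketch; img_max = 500 and
# cls_num = 100 are assumed CIFAR-100-style values, not read from disk):
#
#     img_max, cls_num, imb_factor = 500, 100, 0.002
#     counts = [max(int(img_max * imb_factor ** (c / (cls_num - 1.0))), 1) for c in range(cls_num)]
#     assert counts[0] == 500 and counts[-1] == 1   # head class keeps every image, tail keeps one
#
# The per-class counts decay exponentially from the head class to the tail class.
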
class SingleDataset(Dataset):
    def __init__(self, dataset, data_root, mode, init_cls_num, inc_cls_num, cls_map, trfms,
                 start_idx=-1, end_idx=-1, init=True):
        super().__init__()
        self.dataset = dataset
        self.data_root = data_root
        self.mode = mode
        self.init_cls_num = init_cls_num
        self.inc_cls_num = inc_cls_num
        self.cls_map = cls_map
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.trfms = trfms

        if init:
            self.images, self.labels, self.labels_name = self._init_datalist()

    def __getitem__(self, idx):
        if self.dataset == 'binary_cifar100':
            # CIFAR-100 samples are kept in memory as HxWx3 arrays.
            image = Image.fromarray(np.uint8(self.images[idx]))
        elif self.dataset == 'tiny-imagenet':
            # continuum provides absolute image paths.
            image = Image.open(self.images[idx]).convert("RGB")
        else:
            # Folder-per-class datasets store paths relative to <data_root>/<mode>.
            image = Image.open(os.path.join(self.data_root, self.mode, self.images[idx])).convert("RGB")

        label = self.labels[idx]
        image = self.trfms(image)
        return {"image": image, "label": label}

    def __len__(self):
        return len(self.labels)

    def _init_datalist(self):
        imgs, labels, labels_name = [], [], []
        if self.dataset == 'binary_cifar100':
            # Decode the raw CIFAR-100 pickle and keep only classes in [start_idx, end_idx).
            with open(os.path.join(self.data_root, 'cifar-100-python', self.mode), 'rb') as f:
                load_data = pickle.load(f, encoding='latin1')
            for data, label in zip(load_data['data'], load_data['fine_labels']):
                if label in range(self.start_idx, self.end_idx):
                    r = data[:1024].reshape(32, 32)
                    g = data[1024:2048].reshape(32, 32)
                    b = data[2048:].reshape(32, 32)
                    imgs.append(np.dstack((r, g, b)))
                    labels.append(label)
                    labels_name.append(label)
        else:
            for id in range(self.start_idx, self.end_idx):
                img_list = [self.cls_map[id] + '/' + pic_path
                            for pic_path in os.listdir(os.path.join(self.data_root, self.mode, self.cls_map[id]))]
                imgs.extend(img_list)
                labels.extend([id] * len(img_list))
                labels_name.append(self.cls_map[id])
        return imgs, labels, labels_name

    def get_class_names(self):
        return self.labels_name
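

if __name__ == "__main__":
    # Minimal smoke test (sketch, not part of the library): builds a 10-task class-incremental
    # benchmark from a folder-per-class dataset. The data root, transform, and config values
    # below are placeholder assumptions and need to match the actual experiment setup.
    from torchvision import transforms

    data_root = "./data/cifar100"  # assumed layout: <root>/<mode>/<class_folder>/<image>
    trfms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    cls_map = dict(enumerate(sorted(os.listdir(os.path.join(data_root, "train")))))
    config = {"seed": 0, "pin_memory": False}

    benchmark = ContinualDatasets(
        dataset="cifar100", mode="train", task_num=10, init_cls_num=10, inc_cls_num=10,
        data_root=data_root, cls_map=cls_map, trfms=trfms, batchsize=64, num_workers=2,
        config=config,
    )
    batch = next(iter(benchmark.get_loader(0)))
    print(batch["image"].shape, batch["label"][:8])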