|
|
import logging |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000 |
|
|
from tqdm import tqdm |
|
|
from torch.utils.data import DataLoader |
|
|
import os |
|
|
import utils.inc_net |
|
|
from utils import factory |
|
|
import torch |
|
|
import copy |
|
|
import random |
|
|
|
|
|
class DataManager(object):
    """Load a dataset and expose it as a sequence of class-incremental tasks.

    Targets are remapped so that a class's label equals its position in
    ``self._class_order``.  The classes are then partitioned into a first
    task of ``init_cls`` classes followed by tasks of ``increment`` classes,
    with any remainder forming the last task.
    """

    def __init__(self, dataset_name, shuffle, seed, init_cls, increment, attack=False):
        """Load ``dataset_name`` and compute the per-task class counts.

        Args:
            dataset_name: Name understood by ``_get_idata``.
            shuffle: If true, use a seeded random class order; otherwise the
                dataset's own ``class_order``.
            seed: Seed passed to ``np.random.seed`` when ``shuffle`` is true.
            init_cls: Number of classes in the first task.
            increment: Number of classes in each subsequent task.
            attack: If true, ``_common_trsf`` is set to ``None`` and
                ``get_dataset`` builds its pipelines without it.
        """
        self.dataset_name = dataset_name
        self.attack = attack
        self._setup_data(dataset_name, shuffle, seed, attack=self.attack)
        assert init_cls <= len(self._class_order), "No enough classes."
        self._increments = [init_cls]
        while sum(self._increments) + increment < len(self._class_order):
            self._increments.append(increment)
        # Whatever is left over after the full-size tasks forms the last task.
        offset = len(self._class_order) - sum(self._increments)
        if offset > 0:
            self._increments.append(offset)

    @property
    def nb_tasks(self):
        """Number of tasks the class list was split into."""
        return len(self._increments)

    def get_task_size(self, task):
        """Return the number of classes introduced by task ``task``."""
        return self._increments[task]

    def get_accumulate_tasksize(self,task):
        """Return the number of classes seen up to and including ``task``."""
        return sum(self._increments[:task+1])

    def get_total_classnum(self):
        """Return the total number of classes in the dataset."""
        return len(self._class_order)

    def get_dataset(
        self, indices, source, mode, appendent=None, ret_data=False, m_rate=None
    ):
        """Build a ``DummyDataset`` holding the samples of the given classes.

        Args:
            indices: Iterable of (remapped) class labels to include.
            source: ``"train"`` or ``"test"`` -- which split to read from.
            mode: ``"train"``, ``"flip"`` or ``"test"`` -- which transform
                pipeline to attach.
            appendent: Optional ``(data, targets)`` pair appended after the
                selected classes; ignored when ``None`` or empty.
            ret_data: If true, also return the raw arrays.
            m_rate: If not ``None``, each class is sub-sampled through
                ``_select_rmm`` with this rate.

        Returns:
            ``DummyDataset``, or ``(data, targets, DummyDataset)`` when
            ``ret_data`` is true.

        Raises:
            ValueError: On an unknown ``source`` or ``mode``.
        """
        x, y = self._source_arrays(source)

        # In attack mode the common transforms are left out of every pipeline,
        # and "train" uses the test-time transforms instead of the training ones.
        if mode == "train":
            if self.attack:
                trsf = transforms.Compose([*self._test_trsf,])
            else:
                trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
        elif mode == "flip":
            if self.attack:
                trsf = transforms.Compose(
                    [
                        *self._test_trsf,
                        transforms.RandomHorizontalFlip(p=1.0),
                    ]
                )
            else:
                trsf = transforms.Compose(
                    [
                        *self._test_trsf,
                        transforms.RandomHorizontalFlip(p=1.0),
                        *self._common_trsf,
                    ]
                )
        elif mode == "test":
            if self.attack:
                trsf = transforms.Compose([*self._test_trsf,])
            else:
                trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
        else:
            raise ValueError("Unknown mode {}.".format(mode))

        data, targets = [], []
        for idx in indices:
            if m_rate is None:
                class_data, class_targets = self._select(
                    x, y, low_range=idx, high_range=idx + 1
                )
            else:
                class_data, class_targets = self._select_rmm(
                    x, y, low_range=idx, high_range=idx + 1, m_rate=m_rate
                )
            data.append(class_data)
            targets.append(class_targets)

        if appendent is not None and len(appendent) != 0:
            appendent_data, appendent_targets = appendent
            data.append(appendent_data)
            targets.append(appendent_targets)

        data, targets = np.concatenate(data), np.concatenate(targets)

        if ret_data:
            return data, targets, DummyDataset(data, targets, trsf, self.use_path)
        else:
            return DummyDataset(data, targets, trsf, self.use_path)

    def get_finetune_dataset(self,known_classes,total_classes,source,mode,appendent,type="ratio"):
        """Build a dataset mixing appended old-class samples with new-class samples.

        All samples of classes ``[0, known_classes)`` found in ``appendent``
        are kept.  For every class in ``[known_classes, total_classes)`` the
        same number of samples is drawn without replacement from ``source``.

        Args:
            known_classes: Number of old classes.
            total_classes: Number of old plus new classes.
            source: ``"train"`` or ``"test"``.
            mode: ``"train"`` or ``"test"`` transform pipeline.
            appendent: ``(data, targets)`` pair holding the old-class samples.
            type: ``"ratio"`` scales the new-class total by
                ``(total_classes - known_classes) / known_classes``;
                ``"same"`` uses the old-class total unchanged.

        Returns:
            ``DummyDataset`` over the combined samples.

        Raises:
            ValueError: On an unknown ``source`` or ``mode``.
            AssertionError: On an unknown ``type``.
        """
        x, y = self._source_arrays(source)
        trsf = self._plain_transform(mode)

        val_data = []
        val_targets = []

        old_num_tot = 0
        appendent_data, appendent_targets = appendent

        for idx in range(0, known_classes):
            append_data, append_targets = self._select(appendent_data, appendent_targets,
                                                       low_range=idx, high_range=idx+1)
            num = len(append_data)
            if num == 0:
                continue
            old_num_tot += num
            val_data.append(append_data)
            val_targets.append(append_targets)

        if type == "ratio":
            new_num_tot = int(old_num_tot*(total_classes-known_classes)/known_classes)
        elif type == "same":
            new_num_tot = old_num_tot
        else:
            assert 0, "not implemented yet"

        # Every new class contributes the same number of samples.
        new_num_average = int(new_num_tot/(total_classes-known_classes))
        for idx in range(known_classes,total_classes):
            class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
            val_indx = np.random.choice(len(class_data),new_num_average, replace=False)
            val_data.append(class_data[val_indx])
            val_targets.append(class_targets[val_indx])

        val_data = np.concatenate(val_data)
        val_targets = np.concatenate(val_targets)
        return DummyDataset(val_data, val_targets, trsf, self.use_path)

    def get_dataset_with_split(
        self, indices, source, mode, appendent=None, val_samples_per_class=0
    ):
        """Build a train dataset and a validation dataset for the given classes.

        For each class, ``val_samples_per_class`` samples are drawn without
        replacement for validation and the rest go to training.  When
        ``appendent`` is given, each class label from 0 up to its largest
        target is split the same way.

        Args:
            indices: Iterable of (remapped) class labels to include.
            source: ``"train"`` or ``"test"``.
            mode: ``"train"`` or ``"test"`` transform pipeline (used for both
                returned datasets).
            appendent: Optional ``(data, targets)`` pair of extra samples.
            val_samples_per_class: Validation samples taken from every class.

        Returns:
            ``(train_dataset, val_dataset)``, both ``DummyDataset``.

        Raises:
            ValueError: On an unknown ``source`` or ``mode``.
        """
        x, y = self._source_arrays(source)
        trsf = self._plain_transform(mode)

        train_data, train_targets = [], []
        val_data, val_targets = [], []
        for idx in indices:
            class_data, class_targets = self._select(
                x, y, low_range=idx, high_range=idx + 1
            )
            val_indx = np.random.choice(
                len(class_data), val_samples_per_class, replace=False
            )
            train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
            val_data.append(class_data[val_indx])
            val_targets.append(class_targets[val_indx])
            train_data.append(class_data[train_indx])
            train_targets.append(class_targets[train_indx])

        if appendent is not None:
            appendent_data, appendent_targets = appendent
            for idx in range(0, int(np.max(appendent_targets)) + 1):
                append_data, append_targets = self._select(
                    appendent_data, appendent_targets, low_range=idx, high_range=idx + 1
                )
                val_indx = np.random.choice(
                    len(append_data), val_samples_per_class, replace=False
                )
                train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
                val_data.append(append_data[val_indx])
                val_targets.append(append_targets[val_indx])
                train_data.append(append_data[train_indx])
                train_targets.append(append_targets[train_indx])

        train_data, train_targets = np.concatenate(train_data), np.concatenate(
            train_targets
        )
        val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)

        return DummyDataset(
            train_data, train_targets, trsf, self.use_path
        ), DummyDataset(val_data, val_targets, trsf, self.use_path)

    def _setup_data(self, dataset_name, shuffle, seed, attack=False):
        """Download/load the dataset, pick the class order and remap targets.

        When ``shuffle`` is true this reseeds NumPy's global RNG with ``seed``.
        """
        idata = _get_idata(dataset_name)
        idata.download_data()

        # Data
        self._train_data, self._train_targets = idata.train_data, idata.train_targets
        self._test_data, self._test_targets = idata.test_data, idata.test_targets
        self.use_path = idata.use_path

        # Transforms
        self._train_trsf = idata.train_trsf
        self._test_trsf = idata.test_trsf
        if attack:
            self._common_trsf = None
        else:
            self._common_trsf = idata.common_trsf

        # Order
        order = list(range(len(np.unique(self._train_targets))))
        if shuffle:
            np.random.seed(seed)
            order = np.random.permutation(len(order)).tolist()
        else:
            order = idata.class_order
        self._class_order = order
        logging.info(self._class_order)

        # Map indices: a class's new label is its position in the class order.
        self._train_targets = _map_new_class_index(
            self._train_targets, self._class_order
        )
        self._test_targets = _map_new_class_index(self._test_targets, self._class_order)

    def _source_arrays(self, source):
        """Return the ``(data, targets)`` pair of the ``"train"``/``"test"`` split."""
        if source == "train":
            return self._train_data, self._train_targets
        if source == "test":
            return self._test_data, self._test_targets
        raise ValueError("Unknown data source {}.".format(source))

    def _plain_transform(self, mode):
        """Return the ``"train"``/``"test"`` pipeline followed by the common transforms.

        NOTE(review): with ``attack=True`` ``_common_trsf`` is ``None``, so
        unpacking it here raises ``TypeError`` -- confirm whether the callers
        of this helper are meant to be usable in attack mode.
        """
        if mode == "train":
            return transforms.Compose([*self._train_trsf, *self._common_trsf])
        if mode == "test":
            return transforms.Compose([*self._test_trsf, *self._common_trsf])
        raise ValueError("Unknown mode {}.".format(mode))

    @staticmethod
    def _take(x, idxes):
        """Gather ``x[i]`` for every ``i`` in ``idxes`` (ndarray or plain sequence)."""
        if isinstance(x, np.ndarray):
            return x[idxes]
        return [x[i] for i in idxes]

    def _select(self, x, y, low_range, high_range):
        """Return the samples and targets whose label lies in ``[low_range, high_range)``."""
        idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
        return self._take(x, idxes), y[idxes]

    def _select_rmm(self, x, y, low_range, high_range, m_rate):
        """Like ``_select`` but keeps a random ``(1 - m_rate)`` fraction of the samples.

        The kept positions are drawn with ``np.random.randint``, i.e. with
        replacement, so a sample may appear more than once.  ``m_rate == 0``
        keeps every sample exactly once.
        """
        assert m_rate is not None
        idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
        # Skip the draw when nothing matched: there is nothing to sub-sample.
        if m_rate != 0 and len(idxes) > 0:
            selected_idxes = np.random.randint(
                0, len(idxes), size=int((1 - m_rate) * len(idxes))
            )
            idxes = np.sort(idxes[selected_idxes])
        return self._take(x, idxes), y[idxes]

    def getlen(self, index):
        """Return how many training samples carry the (remapped) label ``index``."""
        y = self._train_targets
        # Count the matches; summing ``np.where`` output would add up positions.
        return np.sum(y == index)
|
|
|
|
|
|
|
|
class DummyDataset(Dataset):
    """Dataset over parallel ``images``/``labels`` sequences with a transform.

    Each image is either a path opened through ``pil_loader`` (``use_path``
    true) or an array converted with ``Image.fromarray``.
    """

    def __init__(self, images, labels, trsf, use_path=False):
        """Store the samples; ``images`` and ``labels`` must have equal length."""
        assert len(images) == len(labels), "Data size error!"
        self.images, self.labels = images, labels
        self.trsf, self.use_path = trsf, use_path

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(idx, transformed_image, label)`` for sample ``idx``."""
        raw = self.images[idx]
        pil_image = pil_loader(raw) if self.use_path else Image.fromarray(raw)
        return idx, self.trsf(pil_image), self.labels[idx]
|
|
|
|
|
|
|
|
def _map_new_class_index(y, order):
    """Replace every label in ``y`` by its position in ``order``.

    Returns a NumPy array; ``list.index`` raises ``ValueError`` for a label
    that does not occur in ``order``.
    """
    remapped = [order.index(label) for label in y]
    return np.array(remapped)
|
|
|
|
|
|
|
|
def _get_idata(dataset_name):
    """Instantiate the dataset wrapper matching ``dataset_name`` (case-insensitive).

    Raises:
        NotImplementedError: If the name is not one of the known datasets.
    """
    key = dataset_name.lower()
    if key == "cifar10":
        return iCIFAR10()
    if key == "cifar100":
        return iCIFAR100()
    if key == "imagenet1000":
        return iImageNet1000()
    if key == "imagenet100":
        return iImageNet100()
    raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
|
|
|
|
|
|
|
|
def get_dataloader(data_manager, batch_size=32,
                   start_class=0, end_class=10,
                   train=False, shuffle=True, num_workers=0):
    """Build a ``DataLoader`` over classes ``[start_class, end_class)``.

    ``train`` selects both the split and the transform mode requested from
    ``data_manager.get_dataset``: ``"train"`` when true, ``"test"`` otherwise.
    """
    split = "train" if train else "test"
    class_ids = np.arange(start_class, end_class)
    dataset = data_manager.get_dataset(class_ids, source=split, mode=split)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
|
|
|
|
|
|
|
|
def pil_loader(path):
    """
    Open the image file at ``path`` and return it converted to RGB.

    Ref:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    """
    # The conversion happens while the file is still open.
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")
|
|
|
|
|
|
|
|
def accimage_loader(path):
    """
    Load ``path`` with accimage, falling back to ``pil_loader`` on ``IOError``.

    Ref:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    accimage is an accelerated Image loader and preprocessor leveraging Intel IPP.
    accimage is available on conda-forge.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # accimage could not read the file; retry with the PIL-based loader.
        image = pil_loader(path)
    return image
|
|
|
|
|
|
|
|
def default_loader(path):
    """
    Load ``path`` with the loader matching torchvision's current image backend.

    Ref:
    https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == "accimage"
    loader = accimage_loader if use_accimage else pil_loader
    return loader(path)
|
|
|
|
|
def _make_task_loader(data_manager, total_classes, batch_size, train):
    """Build a shuffled loader over classes ``[0, total_classes)`` of one split."""
    split = "train" if train else "test"
    dataset = data_manager.get_dataset(
        np.arange(0, total_classes), source=split, mode=split
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)


def load_all_task_models(args, checkpoint_dir, data_manager, batch_size,
                         device='cuda', train=False, weights=None, load_type='model_loader'):
    """Restore per-task models and/or data loaders from saved checkpoints.

    Args:
        args: Configuration dict; ``args["model_name"]`` selects the learner
            created through ``factory.get_model``.
        checkpoint_dir: Directory scanned for ``*.pkl`` checkpoints when
            ``weights`` is ``None``.
        data_manager: Object providing ``get_task_size`` and ``get_dataset``.
        batch_size: Batch size of every created loader.
        device: ``map_location`` for ``torch.load`` and target device of the
            restored networks.
        train: If true the loaders read the train split with train
            transforms, otherwise the test split with test transforms.
        weights: Optional path to a single checkpoint.  When given, only that
            checkpoint is loaded and ``checkpoint_dir``/``load_type`` are
            not used.
        load_type: A string; models are restored if it contains ``'model'``
            and loaders are built if it contains ``'loader'``.

    Returns:
        ``(model_list, loader_list)`` when ``weights`` is ``None`` (a list
        stays empty if its part was not requested), otherwise a single
        ``(model, loader)`` pair.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    if weights is None:
        want_models = 'model' in load_type
        want_loaders = 'loader' in load_type

        model_list = []
        loader_list = []
        # NOTE(review): the sort is lexicographic, so e.g. "task_10.pkl" would
        # come before "task_2.pkl" -- confirm the checkpoint naming scheme.
        ckpts = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith('.pkl'))
        known_classes = 0

        model = None
        if want_models:
            model = factory.get_model(args["model_name"], args)

        for ckpt_file in ckpts:
            # Both branches read the checkpoint (the loader branch needs its
            # 'tasks' entry), so load it whenever either one is requested.
            ckpt = None
            if want_models or want_loaders:
                ckpt_path = os.path.join(checkpoint_dir, ckpt_file)
                ckpt = torch.load(ckpt_path, map_location=device)

            if want_models:
                model.incremental_train(data_manager)
                model._network.load_state_dict(ckpt['model_state_dict'])
                model._network.to(device)
                model._network.eval()
                # Snapshot this task's state before the learner moves on.
                model_list.append(copy.deepcopy(model))
                model.after_task()

            if want_loaders:
                # Prefer the task id stored in the checkpoint; otherwise parse
                # it from the trailing "_<n>" of the file name.
                if 'tasks' in ckpt:
                    cur_task = ckpt['tasks']
                else:
                    cur_task = int(ckpt_file.split('_')[-1].split('.')[0])
                total_classes = known_classes + data_manager.get_task_size(cur_task)
                loader_list.append(
                    _make_task_loader(data_manager, total_classes, batch_size, train)
                )
                known_classes = total_classes

        return model_list, loader_list

    else:
        model = factory.get_model(args["model_name"], args)
        ckpt = torch.load(weights, map_location=device)
        model.incremental_train(data_manager)
        model._network.load_state_dict(ckpt['model_state_dict'])
        model._network.to(device)
        model._network.eval()

        # NOTE(review): the class count is hard-coded to 10 on this path.
        total_classes = 10
        loader = _make_task_loader(data_manager, total_classes, batch_size, train)

        return model, loader
|
|
|
|
|
def load_src_model(model_name, checkpoint_dir, device='cuda'):
    """Load the checkpoint at ``checkpoint_dir`` into the FOSTERNet entry.

    The checkpoint's ``'model_state_dict'`` is loaded into ``model._network``
    after ``update_fc`` is called with a fixed class count of 10, and the
    network is moved to ``device``.

    NOTE(review): ``model_name`` is never read; the ``"FOSTERNet"`` entry is
    always used.
    NOTE(review): the registry value is used as stored, without being
    called, so ``update_fc`` and ``_network`` are accessed on whatever
    ``utils.inc_net.FOSTERNet`` is -- presumably a class, in which case an
    instance was likely intended; confirm against the caller.
    NOTE(review): ``torch.load`` unpickles the file; only load trusted
    checkpoints.
    """
    registry = {'FOSTERNet': utils.inc_net.FOSTERNet}
    model = registry["FOSTERNet"]
    state = torch.load(checkpoint_dir, map_location=device)

    num_classes = 10
    model.update_fc(num_classes)
    model._network.load_state_dict(state['model_state_dict'])
    model._network.to(device)
    return model
|
|
|
|
|
def load_src_dataset(data_manager, batch_size):
    """Return a shuffled loader over the train split of classes 0-9.

    The dataset is requested with ``source="train"`` and ``mode="train"``.
    """
    num_classes = 10
    class_ids = np.arange(0, num_classes)
    train_dataset = data_manager.get_dataset(class_ids, source="train", mode="train")
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
|
|
|
|
|
def balanced_sample_from_loaders(loaders, total_batch_size):
    """Draw one combined batch with (almost) equal shares from every loader.

    Each loader contributes ``total_batch_size // len(loaders)`` samples; the
    first ``total_batch_size % len(loaders)`` loaders contribute one more.
    Batches are pulled from a fresh iterator of each loader until its share
    is filled, truncating the last batch if it overshoots.

    Args:
        loaders: Sequence of iterables yielding batches whose last two items
            are ``(inputs, targets)`` tensors.  Leading items, such as the
            sample index that ``DummyDataset`` yields first, are ignored.
        total_batch_size: Total number of samples to return.

    Returns:
        ``(x_batch, y_batch)`` concatenated along dimension 0, ordered by
        loader.

    Raises:
        ZeroDivisionError: If ``loaders`` is empty.
        StopIteration: If a loader runs out before its share is filled.
    """
    num_loaders = len(loaders)
    per_loader_sample = total_batch_size // num_loaders
    remainder = total_batch_size % num_loaders

    x_batch, y_batch = [], []

    for i, loader in enumerate(loaders):
        batch_needed = per_loader_sample + (1 if i < remainder else 0)
        data_iter = iter(loader)
        current_count = 0
        while current_count < batch_needed:
            # Accept both (x, y) and (idx, x, y) style batches.
            *_, x, y = next(data_iter)
            needed = batch_needed - current_count
            if x.shape[0] > needed:
                x = x[:needed]
                y = y[:needed]
            x_batch.append(x)
            y_batch.append(y)
            current_count += x.shape[0]

    x_batch = torch.cat(x_batch, dim=0)
    y_batch = torch.cat(y_batch, dim=0)
    return x_batch, y_batch
|
|
|
|
|
|
|
|
class CustomDMDataset(Dataset):
    """Image-folder dataset with a reproducible random train/test split.

    ``data_dir`` is expected to contain one sub-directory per class.  Labels
    are the positions of the entries in the sorted listing of ``data_dir``;
    only ``.jpg`` and ``.png`` files inside directories are collected.

    NOTE(review): non-directory entries of ``data_dir`` still occupy a label
    index even though they contribute no images -- confirm that is intended.
    """

    def __init__(self, data_dir, transform=None, split='train', test_size=0.2, seed=0):
        """Scan ``data_dir`` and keep the requested part of the split.

        Args:
            data_dir: Root directory with one folder per class.
            transform: Optional callable applied to each opened image.
            split: ``'train'`` keeps the train part; any other value keeps
                the test part.
            test_size: Fraction of all images assigned to the test part.
            seed: Seed of the private RNG that shuffles the samples.  Two
                instances built with the same seed share one permutation, so
                their train and test parts never overlap.  ``None`` falls
                back to the global ``random`` state (an independent shuffle
                per instance unless the caller seeds it).
        """
        self.data_dir = data_dir
        self.transform = transform
        self.split = split
        self.test_size = test_size
        self.seed = seed

        self.classes = sorted(os.listdir(data_dir))
        image_paths = []
        labels = []

        for label, class_name in enumerate(self.classes):
            class_folder = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_folder):
                continue
            # Sorted so the sample order, and hence the split, is deterministic.
            for img_name in sorted(os.listdir(class_folder)):
                if img_name.endswith((".jpg", ".png")):
                    image_paths.append(os.path.join(class_folder, img_name))
                    labels.append(label)

        total_size = len(image_paths)
        train_size = total_size - int(total_size * self.test_size)

        indices = list(range(total_size))
        if seed is None:
            random.shuffle(indices)
        else:
            random.Random(seed).shuffle(indices)

        if self.split == 'train':
            chosen = indices[:train_size]
        else:
            chosen = indices[train_size:]

        self.image_paths = [image_paths[i] for i in chosen]
        self.labels = [labels[i] for i in chosen]

    def __len__(self):
        """Number of samples in the selected part of the split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed if a transform is set."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        img = Image.open(img_path)

        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
|