boringKey's picture
Upload 236 files
5fee096 verified
"""
@inproceedings{rebuffi2017icarl,
title={icarl: Incremental classifier and representation learning},
author={Rebuffi, Sylvestre-Alvise and Kolesnikov, Alexander and Sperl, Georg and Lampert, Christoph H},
booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition},
pages={2001--2010},
year={2017}
}
https://arxiv.org/abs/1611.07725
"""
import copy
import os
from copy import deepcopy
from typing import Iterator

import numpy as np
import PIL
import PIL.Image
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
class Model(nn.Module):
    """A backbone followed by a single linear classification head.

    ``backbone(x)`` must return a mapping whose ``'features'`` entry has a
    last dimension of size ``feat_dim``; the head maps it to ``num_class``
    logits.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super(Model, self).__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        # linear head: feat_dim -> num_class logits
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x):
        """Same as :meth:`get_logits`."""
        return self.get_logits(x)

    def get_logits(self, x):
        """Return the classifier logits for the input batch ``x``."""
        backbone_out = self.backbone(x)
        embedding = backbone_out['features']
        return self.classifier(embedding)
class ICarl(nn.Module):
    """iCaRL incremental learner.

    Trains ``self.network`` with cross-entropy over all classes seen so far,
    plus (from the second task on) a knowledge-distillation (KD) term against
    a frozen copy of the network from the previous task. Once one class mean
    per seen class is available, inference uses a nearest-class-mean (NCM)
    rule over L2-normalised backbone features.

    Keyword args (all required): ``device``; ``init_cls_num`` (classes in
    the first task); ``inc_cls_num`` (classes added by each later task);
    ``task_num`` (stored, not used in this class).
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__()
        # device setting
        self.device = kwargs['device']
        # current task index
        self.cur_task_id = 0
        # class indexes introduced by the current task
        self.cur_cls_indexes = None
        # Build model structure
        self.network = Model(backbone, feat_dim, num_class)
        # frozen copy of the previous task's network, used as KD teacher;
        # None until the first after_task call
        self.old_network = None
        # the previous class num before this task
        self.prev_cls_num = 0
        # the total class num containing this task
        self.accu_cls_num = 0
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.task_num = kwargs['task_num']
        # one unit-norm prototype (mean feature) per class, set by after_task
        self.class_means = None

    def train(self, mode=True):
        """Set training mode on this module but keep the KD teacher in eval.

        ``old_network`` is a registered submodule, so the inherited
        ``nn.Module.train`` would also switch it into training mode. If the
        backbone contains mode-dependent layers (e.g. BatchNorm/Dropout),
        the teacher's outputs and running statistics would then drift.
        """
        super().train(mode)
        old_network = getattr(self, 'old_network', None)
        if old_network is not None:
            old_network.eval()
        return self

    def get_parameters(self, config):
        """Return the parameters to optimise: only the current network's.

        ``config`` is accepted for interface compatibility and unused.
        """
        return self.network.parameters()

    def observe(self, data):
        """Forward a training batch and compute the iCaRL loss.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.
        Returns:
            ``(pred, batch_accuracy, loss)``. This method does not call
            ``loss.backward()``.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logits, loss = self.criterion(x, y)
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    @torch.no_grad()
    def inference(self, data):
        """Predict labels for a batch and return ``(pred, batch_accuracy)``.

        Uses NCM classification once ``class_means`` holds exactly one
        prototype per seen class (i.e. after ``after_task`` for the current
        task). Before that, it falls back to the arg-max of the network
        logits restricted to the seen classes.
        """
        if self.class_means is not None and len(self.class_means) == self.accu_cls_num:
            return self.NCM_classify(data)
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logits = self.network(x)[:, :self.accu_cls_num]
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def NCM_classify(self, data):
        """Classify a batch by the nearest class prototype in ``class_means``.

        The predicted value is a row index into ``class_means``. Rows are
        stacked in ascending label order by ``calc_class_mean``, so the index
        equals the label when the buffer covers labels ``0..accu_cls_num-1``
        (assumed; not checked here).
        """
        def metric(x, y):
            """Calculate the pair-wise squared euclidean distance between `x` and `y`.
            Args:
                x (Tensor): to be calculated for distance, with shape (N, D)
                y (Tensor): to be calculated for distance, with shape (M, D), where D is embedding size.
            Returns:
                pair distance tensor with shape (N, M)
                and dist[i][j] represent the distance between x[i] and y[j]
            """
            n = x.size(0)
            m = y.size(0)
            x = x.unsqueeze(1).expand(n, m, -1)
            y = y.unsqueeze(0).expand(n, m, -1)
            return torch.pow(x - y, 2).sum(2)  # (N, M)

        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        feats = self.network.backbone(x)['features']
        feats = feats.view(feats.size(0), -1)
        # Put features on the unit sphere like the prototypes built in
        # calc_class_mean. For unit-norm prototypes the arg-min below is
        # unaffected by this scaling; it keeps the distances comparable.
        feats = F.normalize(feats, p=2, dim=1)
        distance = metric(feats, self.class_means)
        pred = torch.argmin(distance, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def forward(self, x):
        """Return logits for the classes seen so far, shape ``(N, accu_cls_num)``."""
        return self.network(x)[:, :self.accu_cls_num]

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the seen-class count and record the current task's classes.

        Only ``self.cur_task_id`` is consulted; the other arguments are
        accepted for interface compatibility.
        """
        if self.cur_task_id == 0:
            self.accu_cls_num = self.init_cls_num
        else:
            self.accu_cls_num += self.inc_cls_num
        self.cur_cls_indexes = np.arange(self.prev_cls_num, self.accu_cls_num)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Snapshot the KD teacher, update the exemplar buffer, rebuild prototypes.

        The buffer calls (``reduce_old_data`` / ``update``) are implemented
        elsewhere. They presumably shrink old exemplar sets and add exemplars
        for the current classes; see the buffer implementation.
        """
        # freeze a copy of the network as KD teacher: eval mode, no gradients
        self.old_network = copy.deepcopy(self.network)
        self.old_network.eval()
        self.old_network.requires_grad_(False)
        self.prev_cls_num = self.accu_cls_num
        # update buffer
        buffer.reduce_old_data(self.cur_task_id, self.accu_cls_num)
        val_transform = test_loaders[0].dataset.trfms
        buffer.update(self.network, train_loader, val_transform,
                      self.cur_task_id, self.accu_cls_num, self.cur_cls_indexes,
                      self.device)
        # compute class mean vectors from the samples in the buffer
        self.class_means = self.calc_class_mean(buffer,
                                                train_loader,
                                                val_transform,
                                                self.device).to(self.device)
        self.cur_task_id += 1

    def criterion(self, x, y):
        """Return ``(logits over seen classes, loss)`` for a batch.

        The loss is cross-entropy over all seen classes, plus a KD term on
        the previous classes once a teacher exists (``cur_task_id > 0``).
        """
        def _KD_loss(pred, soft, T=2):
            """
            Soft-target cross-entropy between temperature-scaled logits,
            averaged over the batch.
            Code Reference:
            KD loss function is borrowed from: https://github.com/G-U-N/PyCIL/blob/master/models/icarl.py
            """
            pred = torch.log_softmax(pred / T, dim=1)
            soft = torch.softmax(soft / T, dim=1)
            return -1 * torch.mul(soft, pred).sum() / pred.shape[0]

        cur_logits = self.network(x)[:, :self.accu_cls_num]
        loss_clf = F.cross_entropy(cur_logits, y)
        if self.cur_task_id > 0:
            # the teacher is a fixed target: build no autograd graph through it
            with torch.no_grad():
                old_logits = self.old_network(x)
            loss_kd = _KD_loss(
                cur_logits[:, :self.prev_cls_num],
                old_logits[:, :self.prev_cls_num],
            )
            loss = loss_clf + loss_kd
        else:
            loss = loss_clf
        return cur_logits, loss

    def calc_class_mean(self, buffer, train_loader, val_transform, device):
        """Compute one L2-normalised mean feature vector per class in ``buffer``.

        Images listed in ``buffer.images`` are loaded from
        ``<data_root>/<mode>/<path>`` (both taken from
        ``train_loader.dataset``) and passed through ``val_transform``.

        Returns:
            CPU tensor of shape ``(num_classes_in_buffer, D)``, rows in
            ascending label order. The tensor is empty (0 rows) if the
            buffer is empty.
        """
        # mini dataset simulating all samples in the buffer
        class miniBufferDataset(Dataset):
            def __init__(self, root, mode, image_list, label_list, transforms):
                self.data_root = root
                self.mode = mode
                self.images = image_list
                self.labels = label_list
                self.transforms = transforms

            def __getitem__(self, idx):
                img_path = self.images[idx]
                label = self.labels[idx]
                # `with` closes the underlying file once the pixels are converted
                with PIL.Image.open(os.path.join(self.data_root, self.mode, img_path)) as img:
                    image = img.convert("RGB")
                image = self.transforms(image)
                return {"image": image, "label": label}

            def __len__(self):
                return len(self.labels)

        root_path = train_loader.dataset.data_root
        mode = train_loader.dataset.mode
        ds = miniBufferDataset(root_path, mode, buffer.images, buffer.labels, val_transform)
        icarl_loader = DataLoader(ds,
                                  batch_size=train_loader.batch_size,
                                  shuffle=False,
                                  num_workers=train_loader.num_workers,
                                  pin_memory=train_loader.pin_memory)

        # compute normalised features for all buffer samples
        extracted_features = []
        extracted_targets = []
        was_training = self.network.training
        self.network.eval()
        try:
            with torch.no_grad():
                for data in icarl_loader:
                    images = data['image'].to(device)
                    feats = self.network.backbone(images)['features']
                    feats = feats.view(feats.size(0), -1)
                    extracted_features.append(F.normalize(feats, p=2, dim=1).cpu())
                    extracted_targets.append(torch.as_tensor(data['label']).cpu())
        finally:
            # restore the caller's mode instead of leaving the network in eval
            self.network.train(was_training)

        if not extracted_features:
            return torch.empty(0, self.network.feat_dim)
        extracted_features = torch.cat(extracted_features)
        extracted_targets = torch.cat(extracted_targets)

        all_class_means = []
        # torch.unique sorts ascending, so row i belongs to the i-th smallest label
        for curr_cls in torch.unique(extracted_targets):
            cls_feats = extracted_features[extracted_targets == curr_cls]
            # mean of the normalised exemplar features, re-projected to unit norm
            all_class_means.append(F.normalize(cls_feats.mean(0), p=2, dim=0))
        return torch.stack(all_class_means)