| | """ |
| | @inproceedings{rebuffi2017icarl, |
| | title={icarl: Incremental classifier and representation learning}, |
| | author={Rebuffi, Sylvestre-Alvise and Kolesnikov, Alexander and Sperl, Georg and Lampert, Christoph H}, |
| | booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition}, |
| | pages={2001--2010}, |
| | year={2017} |
| | } |
| | https://arxiv.org/abs/1611.07725 |
| | """ |
| |
|
| | from typing import Iterator |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from copy import deepcopy |
| | import numpy as np |
| | from torch.nn.parameter import Parameter |
| | from torch.utils.data import DataLoader, Dataset |
| | import PIL |
| | import os |
| | import copy |
| |
|
class Model(nn.Module):
    """Feature-extracting backbone followed by one linear classification head.

    Args:
        backbone: module whose output is indexed with the key ``'features'``.
        feat_dim: input size of the linear head.
        num_class: output size of the linear head (number of logits).
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        # Registration order (backbone first, then classifier) is kept so the
        # state-dict layout stays the same.
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(in_features=feat_dim, out_features=num_class)

    def forward(self, x):
        """Alias for :meth:`get_logits`."""
        return self.get_logits(x)

    def get_logits(self, x):
        """Return the classifier output computed on ``backbone(x)['features']``."""
        backbone_out = self.backbone(x)
        features = backbone_out['features']
        return self.classifier(features)
| | |
| | |
| |
|
class _BufferImageDataset(Dataset):
    """Dataset over the exemplar buffer's image paths and labels.

    Defined at module level (rather than inside ``calc_class_mean``) so that it
    can be pickled when DataLoader workers are started with the spawn method.
    """

    def __init__(self, root, mode, image_list, label_list, transforms):
        self.data_root = root
        self.mode = mode
        self.images = image_list
        self.labels = label_list
        self.transforms = transforms

    def __getitem__(self, idx):
        # A bare ``import PIL`` does not guarantee the ``PIL.Image`` submodule
        # is loaded, so import it explicitly here.
        from PIL import Image

        img_path = os.path.join(self.data_root, self.mode, self.images[idx])
        image = Image.open(img_path).convert("RGB")
        return {"image": self.transforms(image), "label": self.labels[idx]}

    def __len__(self):
        return len(self.labels)


class ICarl(nn.Module):
    """Class-incremental learner trained with cross-entropy plus distillation.

    Training (:meth:`criterion`) combines a cross-entropy loss on the classes
    seen so far with a knowledge-distillation loss against a snapshot of the
    network taken at the end of the previous task. Once class means have been
    computed from the exemplar buffer, :meth:`inference` classifies by nearest
    class mean; before that it falls back to the linear classifier.

    Args:
        backbone: feature extractor; its output is indexed with ``'features'``.
        feat_dim: feature size fed to the linear classifier.
        num_class: total number of classifier outputs.
        **kwargs: must contain ``device``, ``init_cls_num``, ``inc_cls_num``
            and ``task_num``.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__()

        # Device that input batches are moved to.
        self.device = kwargs['device']

        # Index of the task currently being learned (incremented in after_task).
        self.cur_task_id = 0

        # Label indexes introduced by the current task (set in before_task).
        self.cur_cls_indexes = None

        # Network being trained.
        self.network = Model(backbone, feat_dim, num_class)

        # Snapshot of the network after the previous task (distillation teacher).
        self.old_network = None

        # Number of classes known before the current task.
        self.prev_cls_num = 0

        # Number of classes known including the current task.
        self.accu_cls_num = 0

        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.task_num = kwargs['task_num']

        # One mean feature vector per class, filled in by after_task.
        self.class_means = None

    def get_parameters(self, config):
        """Return the trainable parameters (those of ``self.network`` only)."""
        return self.network.parameters()

    def observe(self, data):
        """Run one training forward pass.

        Args:
            data: mapping with ``'image'`` and ``'label'`` tensors.

        Returns:
            ``(pred, acc, loss)``: predicted labels, batch accuracy as a float
            in [0, 1], and the loss tensor from :meth:`criterion`.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        logits, loss = self.criterion(x, y)

        pred = torch.argmax(logits, dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / x.size(0), loss

    def inference(self, data):
        """Predict labels for a batch.

        Uses nearest-class-mean classification when a mean is available for
        every class seen so far; otherwise uses the linear classifier
        restricted to the seen classes.

        Returns:
            ``(pred, acc)``: predicted labels and batch accuracy.
        """
        has_all_means = (
            self.class_means is not None
            and len(self.class_means) == self.accu_cls_num
        )
        if has_all_means:
            return self.NCM_classify(data)

        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        logits = self.network(x)[:, :self.accu_cls_num]
        pred = torch.argmax(logits, dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / x.size(0)

    @staticmethod
    def _pairwise_sq_dist(x, y):
        """Return squared Euclidean distances between rows of ``x`` and ``y``.

        Args:
            x: tensor of shape (N, D).
            y: tensor of shape (M, D).

        Returns:
            Tensor of shape (N, M) where entry ``[i][j]`` is the squared
            distance between ``x[i]`` and ``y[j]``.
        """
        return (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)

    def NCM_classify(self, data):
        """Classify a batch by the nearest stored class mean.

        The predicted label is the row index of the closest entry in
        ``self.class_means``.

        Returns:
            ``(pred, acc)``: predicted labels and batch accuracy.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        feats = self.network.backbone(x)['features']
        feats = feats.view(feats.size(0), -1)
        # NOTE: the class means are unit-norm (see calc_class_mean), so the
        # argmin below is the same whether or not ``feats`` is normalised.
        distance = self._pairwise_sq_dist(feats, self.class_means)

        pred = torch.argmin(distance, dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / x.size(0)

    def forward(self, x):
        """Return logits restricted to the classes seen so far.

        The original code indexed ``[:, self.accu_cls_num]`` (a single column,
        out of range once every class has been seen); every other call site
        slices ``[:, :self.accu_cls_num]``, which is what is done here.
        """
        return self.network(x)[:, :self.accu_cls_num]

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Update the class counters before training on a new task."""
        if self.cur_task_id == 0:
            self.accu_cls_num = self.init_cls_num
        else:
            self.accu_cls_num += self.inc_cls_num

        self.cur_cls_indexes = np.arange(self.prev_cls_num, self.accu_cls_num)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Snapshot the network, refresh the buffer and recompute class means."""
        # Keep a frozen copy of the network as the distillation teacher.
        self.old_network = copy.deepcopy(self.network)
        self.old_network.eval()
        for param in self.old_network.parameters():
            param.requires_grad_(False)

        self.prev_cls_num = self.accu_cls_num

        # Shrink what is stored for earlier classes, then add the new classes.
        buffer.reduce_old_data(self.cur_task_id, self.accu_cls_num)

        val_transform = test_loaders[0].dataset.trfms
        buffer.update(self.network, train_loader, val_transform,
                      self.cur_task_id, self.accu_cls_num, self.cur_cls_indexes,
                      self.device)

        # Class means are computed from the updated buffer contents.
        self.class_means = self.calc_class_mean(buffer,
                                                train_loader,
                                                val_transform,
                                                self.device).to(self.device)
        self.cur_task_id += 1

    @staticmethod
    def _kd_loss(pred, soft, T=2):
        """Knowledge-distillation loss between student logits and teacher logits.

        Code reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/icarl.py
        """
        log_student = torch.log_softmax(pred / T, dim=1)
        teacher = torch.softmax(soft / T, dim=1)
        return -1 * torch.mul(teacher, log_student).sum() / log_student.shape[0]

    def criterion(self, x, y):
        """Compute the training loss for a batch.

        Returns:
            ``(cur_logits, loss)``: logits over the seen classes and the loss
            (cross-entropy, plus distillation from the second task onwards).
        """
        cur_logits = self.network(x)[:, :self.accu_cls_num]
        loss = F.cross_entropy(cur_logits, y)

        if self.cur_task_id > 0:
            # The teacher only provides targets; running it without autograd
            # avoids building a graph and accumulating gradients on its
            # parameters, and leaves the student's gradients unchanged.
            with torch.no_grad():
                old_logits = self.old_network(x)
            loss = loss + self._kd_loss(
                cur_logits[:, :self.prev_cls_num],
                old_logits[:, :self.prev_cls_num],
            )

        return cur_logits, loss

    def calc_class_mean(self, buffer, train_loader, val_transform, device):
        """Compute one unit-norm mean feature vector per class in the buffer.

        Images listed in ``buffer.images`` are loaded with ``val_transform``,
        passed through the backbone, L2-normalised, averaged per label and
        normalised again.

        Note:
            Puts ``self.network`` in eval mode and does not restore the
            previous mode.

        Returns:
            CPU tensor with one row per distinct label, rows ordered by
            ascending label value.
        """
        dataset = _BufferImageDataset(train_loader.dataset.data_root,
                                      train_loader.dataset.mode,
                                      buffer.images,
                                      buffer.labels,
                                      val_transform)

        icarl_loader = DataLoader(dataset,
                                  batch_size=train_loader.batch_size,
                                  shuffle=False,
                                  num_workers=train_loader.num_workers,
                                  pin_memory=train_loader.pin_memory)

        feature_chunks = []
        target_chunks = []
        self.network.eval()
        with torch.no_grad():
            for batch in icarl_loader:
                images = batch['image'].to(device)
                feats = self.network.backbone(images)['features']
                # Flatten as NCM_classify does, so both use the same layout.
                feats = feats.view(feats.size(0), -1)
                # F.normalize guards against division by a zero norm.
                feature_chunks.append(F.normalize(feats, dim=1).cpu())
                target_chunks.append(torch.as_tensor(batch['label']).view(-1).cpu())

        extracted_features = torch.cat(feature_chunks)
        extracted_targets = torch.cat(target_chunks)

        all_class_means = []
        # torch.unique returns the labels sorted in ascending order.
        for curr_cls in torch.unique(extracted_targets):
            cls_mean = extracted_features[extracted_targets == curr_cls].mean(0)
            all_class_means.append(F.normalize(cls_mean, dim=0))

        return torch.stack(all_class_means)
| |
|