# -*- coding: utf-8 -*- """ @inproceedings{zhao2020maintaining, title={Maintaining discrimination and fairness in class incremental learning}, author={Zhao, Bowen and Xiao, Xi and Gan, Guojun and Zhang, Bin and Xia, Shu-Tao}, booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, pages={13208--13217}, year={2020} } https://arxiv.org/abs/1911.07053 Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/wa.py, https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py. """ import torch from torch import nn import copy from torch.nn import functional as F import numpy as np from .finetune import Finetune def KD_loss(pred, soft, T=2): ''' Code Reference: https://github.com/G-U-N/PyCIL/blob/master/models/wa.py Compute the knowledge distillation loss. Args: pred (torch.Tensor): Predictions of the model. soft (torch.Tensor): Soft targets. T (float): Temperature parameter for softening the predictions. Default is 2. Returns: torch.Tensor: Knowledge distillation loss. ''' pred = torch.log_softmax(pred / T, dim=1) soft = torch.softmax(soft / T, dim=1) return -1 * torch.mul(soft, pred).sum() / pred.shape[0] class IncrementalModel(nn.Module): ''' Code Reference: https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py A model consists with a backbone and a classifier. Args: backbone (nn.Module): Backbone network. feat_dim (int): Dimension of the extracted features. num_class (int): Number of classes in the dataset. ''' def __init__(self, backbone, feat_dim, num_class): super().__init__() self.backbone = backbone self.feat_dim = feat_dim self.num_class = num_class self.classifier = None def forward(self, x): return self.get_logits(x) def get_logits(self, x): ''' Compute logits for the input data. Args: x (torch.Tensor): Input data. Returns: torch.Tensor: Logits of the input data. ''' logits = self.classifier(self.backbone(x)['features']) return logits def update_classifier(self, number_classes): ''' Incrementally update the classifier with deepcopy. Args: number_classes (int): Number of classes after update. ''' classifier = nn.Linear(self.feat_dim, number_classes) if self.classifier is not None: number_output = self.classifier.out_features weight = copy.deepcopy(self.classifier.weight.data) bias = copy.deepcopy(self.classifier.bias.data) classifier.weight.data[:number_output] = weight classifier.bias.data[:number_output] = bias del self.classifier self.classifier = classifier def classifier_weight_align(self, incremental_number): ''' Align the weight of the classifier after every task. Args: incremental_number (int): Number of classes added in the current task. ''' weights = self.classifier.weight.data new_norm = torch.norm(weights[-incremental_number:, :], p=2, dim=1) old_norm = torch.norm(weights[:-incremental_number, :], p=2, dim=1) new_mean = torch.mean(new_norm) old_mean = torch.mean(old_norm) gamma = old_mean / new_mean self.classifier.weight.data[-incremental_number:, :] *= gamma def forward(self, x): return self.get_logits(x) def get_logits(self, x): logits = self.classifier(self.backbone(x)['features']) return logits def freeze(self): ''' Freeze the model parameters. ''' for param in self.parameters(): param.requires_grad = False self.eval() return self def extract_vector(self, x): ''' Extract features from the backbone network. Args: x (torch.Tensor): Input data. Returns: torch.Tensor: Extracted features. ''' return self.backbone(x)["features"] class WA(Finetune): def __init__(self, backbone, feat_dim, num_class, **kwargs): super().__init__(backbone, feat_dim, num_class, **kwargs) self.network = IncrementalModel(self.backbone, feat_dim, kwargs['init_cls_num']) self.device = kwargs['device'] self.old_network = None self.known_classes = 0 self.total_classes = 0 self.task_idx = 0 # For buffer update self.total_classes_indexes = 0 def observe(self, data): ''' Do every current task. Args: data (dict): Dictionary containing input data and labels. Returns: tuple: Tuple containing predictions, accuracy, and loss. ''' x, y = data['image'].to(self.device), data['label'].to(self.device) self.network.to(self.device) if self.old_network: self.old_network.to(self.device) logits = self.network(x) loss = F.cross_entropy(logits, y) if self.task_idx > 0: kd_lambda = self.known_classes / self.total_classes loss_kd = KD_loss( logits[:, : self.known_classes], self.old_network(x), ) loss = (1 - kd_lambda) * loss + kd_lambda * loss_kd pred = torch.argmax(logits, dim=1) acc = torch.sum(pred == y).item() return pred, acc / x.size(0), loss def inference(self, data): ''' Perform inference on the input data. Args: data (dict): Dictionary containing input data and labels. Returns: tuple: Tuple containing predictions and accuracy. ''' x, y = data['image'].to(self.device), data['label'].to(self.device) logits = self.network(x) pred = torch.argmax(logits, dim=1) acc = torch.sum(pred == y).item() return pred, acc / x.size(0) def forward(self, x): return self.network(x) def before_task(self, task_idx, buffer, train_loader, test_loaders): ''' Do before every task for task initialization. Args: task_idx (int): Index of the current task. buffer (Buffer): Buffer object. train_loader (DataLoader): DataLoader for training data. test_loaders (list): List of DataLoaders for test data. ''' self.total_classes += self.kwargs['init_cls_num'] self.network.update_classifier(self.total_classes) self.total_classes_indexes = np.arange(self.known_classes, self.total_classes) def after_task(self, task_idx, buffer, train_loader, test_loaders): ''' Do after every task for updating the model. Args: task_idx (int): Index of the current task. buffer (Buffer): Buffer object. train_loader (DataLoader): DataLoader for training data. test_loaders (list): List of DataLoaders for test data. ''' if self.task_idx > 0: self.network.classifier_weight_align(self.total_classes - self.known_classes) self.old_network = copy.deepcopy(self.network).freeze() self.known_classes = self.total_classes # update buffer buffer.reduce_old_data(self.task_idx, self.total_classes) val_transform = test_loaders[0].dataset.trfms buffer.update(self.network, train_loader, val_transform, self.task_idx, self.total_classes, self.total_classes_indexes, self.device) self.task_idx += 1