| | """
|
| | @inproceedings{DBLP:conf/iccv/ShiY23,
|
| | title = {Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning},
|
| | author = {Shi, Wuxuan and Ye, Mang},
|
| | booktitle = {2023 IEEE/CVF International Conference on Computer Vision (ICCV)},
|
| | pages = {1772-1781},
|
| | publisher = {Computer Vision Foundation / {IEEE}},
|
| | year = {2023}
|
| | }
|
| |
|
| | https://openaccess.thecvf.com/content/ICCV2023/papers/Shi_Prototype_Reminiscence_and_Augmented_Asymmetric_Knowledge_Aggregation_for_Non-Exemplar_Class-Incremental_ICCV_2023_paper.pdf
|
| |
|
| | Adapted from https://github.com/ShiWuxuan/PRAKA
|
| | """
|
| |
|
| | from torch.nn import functional as F
|
| | import os
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | import math
|
| | import copy
|
| | from core.model import Finetune
|
| |
|
class joint_network(nn.Module):
    """Backbone plus two classification heads used by PRAKA.

    ``fc`` is the joint self-supervised head over (class, rotation) pairs
    (4 rotations per class, hence ``numclass * 4`` outputs); ``classifier``
    is the plain class head.

    Code Reference:
    https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
    """

    def __init__(self, numclass, feature_extractor, feat_dim=512):
        """
        Parameters:
        - numclass (int): number of classes in the initial task.
        - feature_extractor (nn.Module): backbone mapping inputs to
          ``feat_dim``-dimensional feature vectors.
        - feat_dim (int): dimensionality of the backbone's output features.
          Defaults to 512 (the original PRAKA ResNet backbone), keeping
          existing callers backward compatible.
        """
        super().__init__()
        self.feature = feature_extractor
        # Joint head: one logit per (class, rotation) pair.
        self.fc = nn.Linear(feat_dim, numclass * 4, bias=True)
        # Single head: one logit per plain class.
        self.classifier = nn.Linear(feat_dim, numclass, bias=True)

    def forward(self, input):
        """Return plain-class logits for ``input``.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        """
        x = self.feature(input)
        x = self.classifier(x)
        return x

    def Incremental_learning(self, numclass):
        """Grow both heads to accommodate ``numclass`` total classes.

        Re-creates ``fc`` (output dim ``numclass * 4``) and ``classifier``
        (output dim ``numclass``), copying the previously learned weights and
        biases into the leading rows so old-class knowledge is retained.

        Parameters:
        - numclass (int): total number of classes after the current task,
          including both old and new classes.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        """
        weight = self.fc.weight.data
        bias = self.fc.bias.data
        in_feature = self.fc.in_features
        out_feature = self.fc.out_features

        self.fc = nn.Linear(in_feature, numclass * 4, bias=True)
        # New rows keep their fresh initialization; old rows are restored.
        self.fc.weight.data[:out_feature] = weight[:out_feature]
        self.fc.bias.data[:out_feature] = bias[:out_feature]

        weight = self.classifier.weight.data
        bias = self.classifier.bias.data
        in_feature = self.classifier.in_features
        out_feature = self.classifier.out_features

        self.classifier = nn.Linear(in_feature, numclass, bias=True)
        self.classifier.weight.data[:out_feature] = weight[:out_feature]
        self.classifier.bias.data[:out_feature] = bias[:out_feature]

    def feature_extractor(self, inputs):
        """Return backbone features for ``inputs`` (no classification head).

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        """
        return self.feature(inputs)
|
| |
|
class PRAKA(nn.Module):
    """Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation
    (PRAKA, ICCV 2023) for non-exemplar class-incremental learning.

    Instead of storing exemplars, the method keeps one mean-feature prototype
    per old class (see ``protoSave``) and trains a joint self-supervised head
    over four image rotations alongside the plain classifier
    (see ``joint_network`` and ``_compute_loss``).

    Adapted from https://github.com/ShiWuxuan/PRAKA
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """
        Parameters:
        - backbone: feature-extractor module, wrapped into a ``joint_network``.
        - feat_dim: feature dimensionality (accepted but not used here;
          ``joint_network`` is constructed with its default feature size).
        - num_class: total class count (accepted but not used here; class
          counts are read from ``kwargs``).
        - kwargs: must provide 'device', 'init_cls_num' and 'inc_cls_num';
          the keys 'temp', 'protoAug_weight' and 'kd_weight' are read later
          in ``_compute_loss``.
        """

        super().__init__()
        self.device = kwargs['device']
        self.kwargs = kwargs
        # Input image side length; assumes 32x32 images (CIFAR-style) -- TODO confirm.
        self.size = 32

        encoder = backbone
        self.model = joint_network(kwargs["init_cls_num"], encoder)
        self.radius = 0
        # dict: class label -> mean feature vector; populated in protoSave.
        self.prototype = None
        # dict: class label -> number of samples used for the prototype.
        self.numsamples = None
        # Number of classes seen up to and including the current task.
        self.numclass = kwargs["init_cls_num"]
        # Number of classes added by each incremental task.
        self.task_size = kwargs["inc_cls_num"]
        # Frozen deep copy of the model after the previous task (set in after_task).
        self.old_model = None

        self.task_idx = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare the model for a new task: grow both heads (for incremental
        tasks) and move the model to the configured device.

        ``self.numclass`` was already advanced at the end of the previous task
        in ``after_task``, so it holds the new total here.
        """
        self.task_idx = task_idx
        if task_idx > 0:
            self.model.Incremental_learning(self.numclass)
        self.model.to(self.device)

    def observe(self, data):
        '''
        Processes a batch of training data to compute predictions, accuracy, and loss.

        Parameters:
        - data: Dictionary containing the batch of training samples
            - 'image': Tensor of input images
            - 'label': Tensor of ground truth labels

        Returns:
        - predictions: Tensor of predicted class labels for the input images
        - accuracy: Float value representing the accuracy of the model on the current batch
        - loss: Scalar tensor with the computed loss for the batch

        Description:
        This function is called during the training phase. It performs the following steps:
        1. Extracts the images and labels from the provided data dictionary and transfers them to the device.
        2. Augments the images by rotating them by 0, 90, 180, and 270 degrees, and creates corresponding joint labels.
        3. Computes the loss using the augmented images and labels.
        4. Evaluates the model's performance on the current batch by calculating the accuracy and loss.
        5. Returns the predictions, accuracy, and loss for the batch.

        Example Usage:
        predictions, accuracy, loss = observe(data)
        '''
        images, labels = data['image'].to(self.device), data['label'].to(self.device)

        # Interleave the four rotations so that rows 4i..4i+3 are the four
        # views of sample i (this layout is relied on by _compute_loss).
        images = torch.stack([torch.rot90(images, k, (2, 3)) for k in range(4)], 1)
        images = images.view(-1, 3, self.size, self.size)

        # Joint label of (class c, rotation k) is c * 4 + k.
        joint_labels = torch.stack([labels * 4 + k for k in range(4)], 1).view(-1)
        if self.task_idx == 0:
            old_class = 0
        else:
            # Number of classes learned before the current task.
            old_class = self.kwargs['init_cls_num'] + self.kwargs['inc_cls_num'] * (self.task_idx - 1)

        loss, single_preds = self._compute_loss(images, joint_labels, labels, old_class)

        preds = torch.argmax(single_preds, dim=-1)
        return preds, (preds == labels).sum().item() / len(labels), loss

    def inference(self, data):
        '''
        Performs inference on a batch of test samples and computes the classification results and accuracy.

        Parameters:
        - data: Dictionary containing the batch of test samples
            - 'image': Tensor of input images
            - 'label': Tensor of ground truth labels

        Returns:
        - predictions: Tensor of predicted class labels for the input images
        - accuracy: Float value representing the accuracy of the model on the current batch

        Example Usage:
        predictions, accuracy = inference(data)
        '''

        imgs, labels = data['image'].to(self.device), data['label'].to(self.device)

        # Plain forward pass uses only the single classifier head.
        preds = torch.argmax(self.model(imgs), dim=-1)

        return preds, (preds == labels).sum().item() / len(labels)

    def _compute_loss(self, imgs, joint_labels, labels, old_class=0):
        '''
        Computes the loss for a batch of rotation-augmented images.

        Parameters:
        - imgs: Tensor of input images, rotation-interleaved (rows 4i..4i+3
          are the 0/90/180/270-degree views of sample i, as built in observe)
        - joint_labels: Tensor of (class, rotation) labels for the augmented images
        - labels: Tensor of ground truth labels for the original images
        - old_class: Integer count of previously learned classes (0 on the first task)

        Returns:
        - loss: Scalar tensor representing the total computed loss
        - preds: Tensor of single-head predictions for the un-rotated images

        Example Usage:
        loss, preds = self._compute_loss(imgs, joint_labels, labels, old_class)

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        '''

        feature = self.model.feature(imgs)

        joint_preds = self.model.fc(feature)
        # Rows 0, 4, 8, ... are the un-rotated views; only they feed the single head.
        single_preds = self.model.classifier(feature)[::4]
        joint_preds, joint_labels, single_preds, labels = joint_preds.to(self.device), joint_labels.to(self.device), single_preds.to(self.device), labels.to(self.device)
        # "temp" scales the logits before cross-entropy (temperature).
        joint_loss = nn.CrossEntropyLoss()(joint_preds/self.kwargs["temp"], joint_labels)
        single_loss = nn.CrossEntropyLoss()(single_preds/self.kwargs["temp"], labels)

        # Aggregate the joint head over the 4 rotations: rows i::4 are the
        # rotation-i views and columns i::4 are the (class, rotation-i)
        # logits, so averaging the four aligned slices yields per-class logits.
        agg_preds = 0
        for i in range(4):
            agg_preds = agg_preds + joint_preds[i::4, i::4] / 4

        # Distill the (detached) aggregated joint predictions into the single head.
        distillation_loss = F.kl_div(F.log_softmax(single_preds, 1),
                                     F.softmax(agg_preds.detach(), 1),
                                     reduction='batchmean')
        if old_class == 0:
            # First task: no old model or prototypes to preserve yet.
            return joint_loss + single_loss + distillation_loss, single_preds
        else:
            feature_old = self.old_model.feature(imgs)

            # Feature-space knowledge distillation: L2 distance to old features.
            loss_kd = torch.dist(feature, feature_old, 2)

            # Prototype reminiscence: synthesize pseudo-features of old classes
            # by mixing a stored prototype with a randomly chosen current feature.
            proto_aug = []
            proto_aug_label = []
            old_class_list = list(self.prototype.keys())
            for _ in range(feature.shape[0] // 4):
                i = np.random.randint(0, feature.shape[0])
                # Shuffle so old_class_list[0] is a random old class each round.
                np.random.shuffle(old_class_list)
                lam = np.random.beta(0.5, 0.5)
                if lam > 0.6:
                    # Cap the mixing weight so the prototype stays dominant.
                    lam = lam * 0.6

                if np.random.random() >= 0.5:
                    # Extrapolate away from the current feature...
                    temp = (1 + lam) * self.prototype[old_class_list[0]] - lam * feature.detach().cpu().numpy()[i]
                else:
                    # ...or interpolate towards it.
                    temp = (1 - lam) * self.prototype[old_class_list[0]] + lam * feature.detach().cpu().numpy()[i]

                proto_aug.append(temp)
                proto_aug_label.append(old_class_list[0])

            proto_aug = torch.from_numpy(np.float32(np.asarray(proto_aug))).float().to(self.device)
            proto_aug_label = torch.from_numpy(np.asarray(proto_aug_label)).to(self.device)
            aug_preds = self.model.classifier(proto_aug)
            joint_aug_preds = self.model.fc(proto_aug)
            # Columns ::4 are the rotation-0 joint logits of each class.
            agg_preds = joint_aug_preds[:, ::4]
            aug_distillation_loss = F.kl_div(F.log_softmax(aug_preds, 1),
                                             F.softmax(agg_preds.detach(), 1),
                                             reduction='batchmean')

            # Joint label of (class c, rotation 0) is c * 4, hence label * 4.
            loss_protoAug = nn.CrossEntropyLoss()(aug_preds/self.kwargs["temp"], proto_aug_label) + nn.CrossEntropyLoss()(joint_aug_preds/self.kwargs["temp"], proto_aug_label*4) + aug_distillation_loss
            return joint_loss + single_loss + distillation_loss + self.kwargs["protoAug_weight"]*loss_protoAug + self.kwargs["kd_weight"]*loss_kd, single_preds

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        '''
        Perform operations after completing the training for a specific task.
        1. Save the class prototypes computed from the current model's features.
        2. Advance the class counter for the next task.
        3. Keep a frozen deep copy of the current model as the old model for
           feature distillation in subsequent tasks.

        Parameters:
        - task_idx (int): The index of the current task.
        - buffer: Data buffer for storing samples (not used in this function).
        - train_loader (DataLoader): DataLoader for the training dataset of the current task.
        - test_loaders (list of DataLoader): List of DataLoaders for test datasets of different tasks.

        Example Usage:
        self.after_task(task_idx, buffer, train_loader, test_loaders)
        '''

        self.protoSave(self.model, train_loader, self.task_idx)
        self.numclass += self.task_size

        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

    def protoSave(self, model, loader, current_task):
        '''
        Save the class prototypes for the current task.

        This function extracts features from the data using the provided model and computes
        class prototypes (per-class feature means) based on these features. The prototypes
        are then saved to the class attributes. If it's the first task, the prototypes are
        initialized. For subsequent tasks, the prototypes are updated with new class information.

        Parameters:
        - model: The model used for feature extraction.
        - loader: DataLoader providing the dataset for the current task.
        - current_task (int): The index of the current task.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        '''

        features = []
        labels = []
        model.eval()

        with torch.no_grad():
            for i, batch in enumerate(loader):
                images, target = batch['image'], batch['label']
                feature = model.feature(images.to(self.device))
                # Keep only full batches so the reshape below (which assumes a
                # uniform batch size) is valid; a trailing partial batch is dropped.
                if feature.shape[0] == loader.batch_size:
                    labels.append(target.numpy())
                    features.append(feature.cpu().numpy())

        labels_set = np.unique(labels)
        labels = np.array(labels)
        # Flatten (num_batches, batch_size) -> (num_samples,).
        labels = np.reshape(labels, labels.shape[0] * labels.shape[1])
        features = np.array(features)
        # Flatten (num_batches, batch_size, feat_dim) -> (num_samples, feat_dim).
        features = np.reshape(features, (features.shape[0] * features.shape[1], features.shape[2]))

        prototype = {}
        class_label = []
        numsamples = {}

        for item in labels_set:
            index = np.where(item == labels)[0]
            class_label.append(item)
            feature_classwise = features[index]
            # Prototype = mean feature over all samples of this class.
            prototype[item] = np.mean(feature_classwise, axis=0)

            numsamples[item] = feature_classwise.shape[0]
        if current_task == 0:
            self.prototype = prototype
            self.class_label = class_label
            self.numsamples = numsamples
        else:
            # Merge new-class prototypes into the running dictionaries.
            self.prototype.update(prototype)
            self.class_label = np.concatenate((class_label, self.class_label), axis=0)
            self.numsamples.update(numsamples)

    def get_parameters(self, config):
        """Return the trainable parameters of the underlying joint network."""
        return self.model.parameters()
|
| |
|
| |
|