File size: 8,020 Bytes
5fee096 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 | # -*- coding: utf-8 -*-
"""
@inproceedings{zhao2020maintaining,
title={Maintaining discrimination and fairness in class incremental learning},
author={Zhao, Bowen and Xiao, Xi and Gan, Guojun and Zhang, Bin and Xia, Shu-Tao},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)},
pages={13208--13217},
year={2020}
}
https://arxiv.org/abs/1911.07053
Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/wa.py, https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py.
"""
import torch
from torch import nn
import copy
from torch.nn import functional as F
import numpy as np
from .finetune import Finetune
def KD_loss(pred, soft, T=2):
    '''
    Code Reference:
    https://github.com/G-U-N/PyCIL/blob/master/models/wa.py

    Temperature-scaled knowledge-distillation loss.

    Both logit tensors are divided by ``T``. The reference logits ``soft`` are
    converted to probabilities with a softmax over dim 1, and the logits ``pred``
    are converted to log-probabilities with a log-softmax over dim 1. The loss is
    the negative sum of their element-wise product over all elements, divided by
    the size of the first dimension.

    Args:
        pred (torch.Tensor): Logits being trained (e.g. the current network's
            logits restricted to old classes).
        soft (torch.Tensor): Reference logits (e.g. the old network's output);
            they are softened here, so pass raw logits, not probabilities.
        T (float): Softmax temperature. Default is 2.

    Returns:
        torch.Tensor: Scalar distillation loss.
    '''
    batch_size = pred.shape[0]
    student_log_prob = (pred / T).log_softmax(dim=1)
    teacher_prob = (soft / T).softmax(dim=1)
    summed = (teacher_prob * student_log_prob).sum()
    return -summed / batch_size
class IncrementalModel(nn.Module):
'''
Code Reference:
https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py
A model consists with a backbone and a classifier.
Args:
backbone (nn.Module): Backbone network.
feat_dim (int): Dimension of the extracted features.
num_class (int): Number of classes in the dataset.
'''
def __init__(self, backbone, feat_dim, num_class):
super().__init__()
self.backbone = backbone
self.feat_dim = feat_dim
self.num_class = num_class
self.classifier = None
def forward(self, x):
return self.get_logits(x)
def get_logits(self, x):
'''
Compute logits for the input data.
Args:
x (torch.Tensor): Input data.
Returns:
torch.Tensor: Logits of the input data.
'''
logits = self.classifier(self.backbone(x)['features'])
return logits
def update_classifier(self, number_classes):
'''
Incrementally update the classifier with deepcopy.
Args:
number_classes (int): Number of classes after update.
'''
classifier = nn.Linear(self.feat_dim, number_classes)
if self.classifier is not None:
number_output = self.classifier.out_features
weight = copy.deepcopy(self.classifier.weight.data)
bias = copy.deepcopy(self.classifier.bias.data)
classifier.weight.data[:number_output] = weight
classifier.bias.data[:number_output] = bias
del self.classifier
self.classifier = classifier
def classifier_weight_align(self, incremental_number):
'''
Align the weight of the classifier after every task.
Args:
incremental_number (int): Number of classes added in the current task.
'''
weights = self.classifier.weight.data
new_norm = torch.norm(weights[-incremental_number:, :], p=2, dim=1)
old_norm = torch.norm(weights[:-incremental_number, :], p=2, dim=1)
new_mean = torch.mean(new_norm)
old_mean = torch.mean(old_norm)
gamma = old_mean / new_mean
self.classifier.weight.data[-incremental_number:, :] *= gamma
def forward(self, x):
return self.get_logits(x)
def get_logits(self, x):
logits = self.classifier(self.backbone(x)['features'])
return logits
def freeze(self):
'''
Freeze the model parameters.
'''
for param in self.parameters():
param.requires_grad = False
self.eval()
return self
def extract_vector(self, x):
'''
Extract features from the backbone network.
Args:
x (torch.Tensor): Input data.
Returns:
torch.Tensor: Extracted features.
'''
return self.backbone(x)["features"]
class WA(Finetune):
    '''
    Weight Aligning (WA) class-incremental learner.

    Trains with cross-entropy plus a knowledge-distillation term against a
    frozen copy of the network from the previous task. After every task except
    the first, it rescales the classifier rows of the newly added classes so
    that their mean norm matches that of the old classes.

    Args:
        backbone (nn.Module): Backbone network, forwarded to ``Finetune``.
        feat_dim (int): Dimension of the backbone features.
        num_class (int): Number of classes in the dataset.
        **kwargs: Must contain 'init_cls_num' and 'device'. 'inc_cls_num' is
            used for tasks after the first when present.
    '''

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.network = IncrementalModel(self.backbone, feat_dim, kwargs['init_cls_num'])
        self.device = kwargs['device']
        # Frozen snapshot of the network from the previous task (None before
        # the first task finishes).
        self.old_network = None
        self.known_classes = 0
        self.total_classes = 0
        self.task_idx = 0
        # For buffer update: label indexes of the classes in the current task.
        self.total_classes_indexes = 0

    def observe(self, data):
        '''
        Run one training step on a batch of the current task.

        Args:
            data (dict): Dictionary with 'image' and 'label' tensors.

        Returns:
            tuple: (predictions, batch accuracy, loss).
        '''
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        self.network.to(self.device)
        if self.old_network is not None:
            self.old_network.to(self.device)
        logits = self.network(x)
        loss = F.cross_entropy(logits, y)
        if self.task_idx > 0:
            # Weight the distillation term by the fraction of old classes.
            kd_lambda = self.known_classes / self.total_classes
            # The old network is frozen; no graph is needed for its forward.
            with torch.no_grad():
                old_logits = self.old_network(x)
            loss_kd = KD_loss(logits[:, : self.known_classes], old_logits)
            loss = (1 - kd_lambda) * loss + kd_lambda * loss_kd
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def inference(self, data):
        '''
        Perform inference on the input data.

        Args:
            data (dict): Dictionary with 'image' and 'label' tensors.

        Returns:
            tuple: (predictions, batch accuracy).
        '''
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logits = self.network(x)
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def forward(self, x):
        return self.network(x)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        '''
        Grow the classifier to cover the classes of the upcoming task.

        The first task contributes 'init_cls_num' classes. Each later task
        contributes 'inc_cls_num', falling back to 'init_cls_num' when that key
        is absent, which reproduces the earlier behaviour.

        Args:
            task_idx (int): Index of the current task.
            buffer (Buffer): Buffer object.
            train_loader (DataLoader): DataLoader for training data.
            test_loaders (list): List of DataLoaders for test data.
        '''
        if self.task_idx == 0:
            new_classes = self.kwargs['init_cls_num']
        else:
            # Later tasks may add a different number of classes than the first;
            # assumes self.kwargs is dict-like (it is indexed by key above).
            new_classes = self.kwargs.get('inc_cls_num', self.kwargs['init_cls_num'])
        self.total_classes += new_classes
        self.network.update_classifier(self.total_classes)
        self.total_classes_indexes = np.arange(self.known_classes, self.total_classes)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        '''
        Finish the current task: align the new-class classifier weights (tasks
        after the first), snapshot a frozen copy as ``old_network``, mark all
        current classes as known, update the replay buffer, and advance the
        internal task counter.

        Args:
            task_idx (int): Index of the current task.
            buffer (Buffer): Buffer object.
            train_loader (DataLoader): DataLoader for training data.
            test_loaders (list): List of DataLoaders for test data.
        '''
        if self.task_idx > 0:
            self.network.classifier_weight_align(self.total_classes - self.known_classes)
        self.old_network = copy.deepcopy(self.network).freeze()
        self.known_classes = self.total_classes
        # update buffer (exact semantics of reduce_old_data/update live in the
        # Buffer class; not visible here)
        buffer.reduce_old_data(self.task_idx, self.total_classes)
        val_transform = test_loaders[0].dataset.trfms
        buffer.update(self.network, train_loader, val_transform,
                      self.task_idx, self.total_classes, self.total_classes_indexes,
                      self.device)
        self.task_idx += 1
|