| |
|
| | """
|
| | @inproceedings{DBLP:conf/cvpr/YanX021,
|
| | author = {Shipeng Yan and
|
| | Jiangwei Xie and
|
| | Xuming He},
|
| | title = {{DER:} Dynamically Expandable Representation for Class Incremental
|
| | Learning},
|
| | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}
|
| | 2021, virtual, June 19-25, 2021},
|
| | pages = {3014--3023},
|
| | year = {2021},
|
| | }
|
| |
|
| | https://openaccess.thecvf.com/content/CVPR2021/papers/Yan_DER_Dynamically_Expandable_Representation_for_Class_Incremental_Learning_CVPR_2021_paper.pdf
|
| |
|
| | Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/der.py
|
| | """
|
| | import math
|
| | import copy
|
| | import torch
|
| | import torch.nn as nn
|
| | from torch.nn import Parameter
|
| | import torch.nn.functional as F
|
| | from .finetune import Finetune
|
| | from core.model.backbone import resnet18, resnet34, resnet50
|
| | from core.utils import get_instance
|
| |
|
def get_convnet(convnet_type, pretrained=False):
    """Instantiate a backbone network by (case-insensitive) name.

    Args:
        convnet_type (str): one of "resnet18", "resnet34", "resnet50".
        pretrained (bool): accepted for interface compatibility; no branch
            below consumes it — TODO confirm whether it should be forwarded
            to the backbone constructors.

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: if ``convnet_type`` is not recognized.
    """
    key = convnet_type.lower()
    if key == "resnet18":
        # resnet18 is built with this project's cifar100 configuration.
        kwargs = {"num_classes": 10, "args": {'dataset': 'cifar100'}}
        return resnet18(**kwargs)
    if key == "resnet34":
        return resnet34()
    if key == "resnet50":
        return resnet50()
    raise NotImplementedError("Unknown type {}".format(convnet_type))
|
| |
|
class SimpleLinear(nn.Module):
    '''
    Linear layer that wraps its output in a dict ({'logits': ...}) so it
    composes with backbones in this project that also return dicts.

    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    '''
    def __init__(self, in_features, out_features, bias=True):
        """Allocate weight (out_features, in_features) and optional bias."""
        super(SimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            # Register the missing bias explicitly so state_dict stays consistent.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight init; zero bias when a bias exists."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        # BUG FIX: the original unconditionally called nn.init.constant_ on
        # self.bias, which raised when constructed with bias=False
        # (self.bias is None on that path).
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        # F.linear accepts bias=None transparently.
        return {'logits': F.linear(input, self.weight, self.bias)}
|
| |
|
class DER(Finetune):
    """Dynamically Expandable Representation (DER) for class-incremental learning.

    One backbone is appended to ``self.convnets`` per task; previously added
    backbones are frozen.  The main head ``self.fc`` classifies over the
    concatenation of all backbones' features, while the auxiliary head
    ``self.aux_fc`` sees only the newest backbone's slice and discriminates
    the new classes against a single merged "old" class (index 0).
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Initialize the (initially empty) backbone list and class counters.

        Args:
            backbone: forwarded to ``Finetune``; DER builds its own backbones
                via ``get_convnet`` rather than using this directly here.
            feat_dim: forwarded to ``Finetune``.
            num_class: forwarded to ``Finetune``.
            **kwargs: must contain ``init_cls_num`` and ``inc_cls_num``.
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.convnets = nn.ModuleList()  # one backbone per seen task
        self.pretrained = None
        self.out_dim = None              # per-backbone feature dim, set lazily in update_fc
        self.fc = None                   # main classifier over concatenated features
        self.aux_fc = None               # auxiliary head over the newest backbone only
        self.task_sizes = []             # number of classes added by each task

        self.kwargs = kwargs
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.known_cls_num = 0           # classes seen before the current task
        self.total_cls_num = 0           # classes seen up to and including the current task

        self.convnet_type = 'resnet18'

    @property
    def feature_dim(self):
        """Concatenated feature dimension over all backbones (0 before any exist)."""
        if self.out_dim is None:
            return 0
        return self.out_dim * len(self.convnets)

    def forward(self, x):
        """Run every backbone plus both heads.

        Returns:
            dict with keys:
                'logits'     — main-head logits over all known classes,
                'aux_logits' — auxiliary-head logits (new classes + merged old),
                'features'   — concatenated backbone features.
        """
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)

        out = self.fc(features)

        # The auxiliary head only sees the newest backbone's feature slice.
        aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"]

        out.update({"aux_logits": aux_logits, "features": features})
        return out

    def observe(self, data):
        """One training step.

        Args:
            data: dict with 'image' and 'label' tensors.

        Returns:
            (predictions, batch accuracy, loss) — loss adds the auxiliary
            cross-entropy term for every task after the first.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)

        logit = self.fc(features)['logits']

        if self.task_idx == 0:
            loss = self.loss_fn(logit, y)
        else:
            loss_clf = self.loss_fn(logit, y)
            # Remap labels for the auxiliary head: every old class collapses
            # to 0, new classes become 1..new_task_size.
            aux_targets = y.clone()
            aux_targets = torch.where(
                aux_targets - self.known_cls_num + 1 > 0,
                aux_targets - self.known_cls_num + 1,
                0,
            )
            aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"]
            loss_aux = F.cross_entropy(aux_logits, aux_targets)
            loss = loss_aux + loss_clf

        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def inference(self, data):
        """Evaluation step: returns (predictions, batch accuracy)."""
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        features = [convnet(x)["features"] for convnet in self.convnets]
        features = torch.cat(features, 1)
        logit = self.fc(features)['logits']
        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def update_fc(self, nb_classes):
        """Expand the model for a task covering ``nb_classes`` total classes.

        Appends a new backbone (warm-started from the previous one after the
        first task), rebuilds ``self.fc`` over the enlarged feature space
        (copying the old head's weights into the corresponding slice), and
        creates a fresh auxiliary head of size new_task_size + 1.
        """
        if len(self.convnets) == 0:
            self.convnets.append(get_convnet(self.convnet_type))
        else:
            # Warm-start the new backbone from the most recent one.
            self.convnets.append(get_convnet(self.convnet_type))
            self.convnets[-1].load_state_dict(self.convnets[-2].state_dict())

        if self.out_dim is None:
            self.out_dim = self.convnets[-1].out_dim
        fc = self.generate_fc(self.feature_dim, nb_classes)
        if self.fc is not None:
            # Copy the old head into the top-left block of the new head so
            # previously learned class boundaries are preserved.
            nb_output = self.fc.out_features
            weight = copy.deepcopy(self.fc.weight.data)
            bias = copy.deepcopy(self.fc.bias.data)
            fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight
            fc.bias.data[:nb_output] = bias

        del self.fc
        self.fc = fc

        new_task_size = nb_classes - sum(self.task_sizes)
        self.task_sizes.append(new_task_size)

        # +1 output for the merged "old classes" target used in observe().
        self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1)

    def generate_fc(self, in_dim, out_dim):
        """Build a SimpleLinear head of the requested shape."""
        fc = SimpleLinear(in_dim, out_dim)

        return fc

    def freeze_convnets(self):
        """Freeze and eval all current backbones (called before appending a new one)."""
        for param in self.convnets.parameters():
            param.requires_grad = False
        self.convnets.eval()

    def weight_align(self, increment):
        """Rescale the newest ``increment`` rows of fc so their mean L2 norm
        matches the old rows' mean norm (reduces new-class bias)."""
        weights = self.fc.weight.data
        newnorm = torch.norm(weights[-increment:, :], p=2, dim=1)
        oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1)
        meannew = torch.mean(newnorm)
        meanold = torch.mean(oldnorm)
        gamma = meanold / meannew
        print("alignweights,gamma=", gamma)
        self.fc.weight.data[-increment:, :] *= gamma

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for a new task: update counters, freeze old backbones,
        expand the heads, and move everything to the device."""
        self.task_idx = task_idx
        self.known_cls_num = self.total_cls_num
        self.total_cls_num = self.init_cls_num + self.task_idx * self.inc_cls_num

        # Freeze what exists, then update_fc appends the new trainable backbone.
        self.freeze_convnets()
        self.update_fc(self.total_cls_num)
        self.loss_fn = nn.CrossEntropyLoss()
        self.convnets = self.convnets.to(self.device)
        self.fc = self.fc.to(self.device)
        self.aux_fc = self.aux_fc.to(self.device)

    def _train(self):
        """Set train/eval modes: heads and the newest backbone train, all
        older backbones stay in eval mode (their BN stats must not drift)."""
        self.fc.train()

        self.aux_fc.train()

        # BUG FIX: the original iterated range(self.task_idx - 1), which left
        # convnets[task_idx - 1] in train mode. At task t the list holds t+1
        # backbones and every one except the last must be eval'd.
        for i in range(len(self.convnets) - 1):
            self.convnets[i].eval()
        self.convnets[-1].train()

    def get_parameters(self, config):
        """Return optimizer parameter groups: all backbones (frozen ones have
        requires_grad=False) plus both heads when they exist."""
        train_parameters = []

        train_parameters.append({"params": self.convnets.parameters()})

        if self.fc is not None:
            train_parameters.append({"params": self.fc.parameters()})
        if self.aux_fc is not None:
            train_parameters.append({"params": self.aux_fc.parameters()})
        return train_parameters
|
| |
|
def iffreeze(name, net):
    """Debug helper: print the requires_grad flag of every named parameter
    in ``net``, prefixing each line with ``name``."""
    for param_name, param in net.named_parameters():
        flag = param.requires_grad
        print('{}{}: {}'.format(name, param_name, flag))
|
| |
|