| | |
| | """ |
| | @inproceedings{DBLP:conf/cvpr/HouPLWL19, |
| | title = {Learning a Unified Classifier Incrementally via Rebalancing}, |
| | author = {Saihui Hou and |
| | Xinyu Pan and |
| | Chen Change Loy and |
| | Zilei Wang and |
| | Dahua Lin}, |
| | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR} |
| | 2019, Long Beach, CA, USA, June 16-20, 2019}, |
| | pages = {831--839}, |
| | publisher = {Computer Vision Foundation / {IEEE}}, |
| | year = {2019} |
| | } |
| | https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html |
| | |
| | Adapted from https://github.com/hshustc/CVPR19_Incremental_Learning |
| | """ |
| |
|
| | import math |
| | import copy |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import Parameter |
| | import torch.nn.functional as F |
| | from .finetune import Finetune |
| | from core.model.backbone.resnet import * |
| | import numpy as np |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
# Module-level slots shared between the forward hooks defined in this file
# and LUCIR.observe. Each one starts out as an empty list and is rebound to
# whatever the corresponding hook captures on a forward pass.
cur_features = []  # first input of the current network's classifier
ref_features = []  # first input of the reference network's classifier
old_scores = []    # output of the current classifier's fc1 sub-module
new_scores = []    # output of the current classifier's fc2 sub-module
def get_ref_features(self, inputs, outputs):
    """Forward hook: remember the first positional input of the hooked module.

    Registered by ``LUCIR.before_task`` on the reference model's classifier.
    The value is stored in the module-level ``ref_features``.

    Args:
        self: The module the hook is attached to (unused).
        inputs: Indexable collection of the module's forward inputs; only
            ``inputs[0]`` is kept.
        outputs: The module's forward output (unused).
    """
    global ref_features
    ref_features = inputs[0]
| |
|
def get_cur_features(self, inputs, outputs):
    """Forward hook: remember the first positional input of the hooked module.

    Registered by ``LUCIR.before_task`` on the current network's classifier.
    The value is stored in the module-level ``cur_features``.

    Args:
        self: The module the hook is attached to (unused).
        inputs: Indexable collection of the module's forward inputs; only
            ``inputs[0]`` is kept.
        outputs: The module's forward output (unused).
    """
    global cur_features
    cur_features = inputs[0]
| |
|
def get_old_scores_before_scale(self, inputs, outputs):
    """Forward hook: remember the output of the hooked module.

    Registered by ``LUCIR.before_task`` on the ``fc1`` sub-module of the
    current network's classifier. The value is stored in the module-level
    ``old_scores``.

    Args:
        self: The module the hook is attached to (unused).
        inputs: The module's forward inputs (unused).
        outputs: The module's forward output; stored as-is.
    """
    global old_scores
    old_scores = outputs
| |
|
def get_new_scores_before_scale(self, inputs, outputs):
    """Forward hook: remember the output of the hooked module.

    Registered by ``LUCIR.before_task`` on the ``fc2`` sub-module of the
    current network's classifier. The value is stored in the module-level
    ``new_scores``.

    Args:
        self: The module the hook is attached to (unused).
        inputs: The module's forward inputs (unused).
        outputs: The module's forward output; stored as-is.
    """
    global new_scores
    new_scores = outputs
| |
|
| |
|
| |
|
class Model(nn.Module):
    """A backbone followed by a ``CosineLinear`` classification head.

    Attributes are assigned in the order backbone -> classifier, which is
    also the order in which the sub-modules are registered.
    """

    def __init__(self, backbone, feat_dim, num_class):
        """Build the network.

        Args:
            backbone: Feature extractor. Its forward output is indexed with
                the key ``'features'`` in ``get_logits``.
            feat_dim: Input width passed to ``CosineLinear``.
            num_class: Output width passed to ``CosineLinear``.
        """
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        # NOTE(review): CosineLinear is not defined in this file; presumably it
        # arrives through the star import of the resnet backbone module.
        self.classifier = CosineLinear(feat_dim, num_class)

    def forward(self, x):
        """Return the classifier logits for ``x`` (delegates to ``get_logits``)."""
        return self.get_logits(x)

    def get_logits(self, x):
        """Run ``x`` through the backbone and classify its ``'features'`` entry."""
        backbone_out = self.backbone(x)
        features = backbone_out['features']
        return self.classifier(features)
| |
|
| |
|
| |
|
class LUCIR(Finetune):
    """Incremental learner after "Learning a Unified Classifier Incrementally
    via Rebalancing" (Hou et al., CVPR 2019).

    Task 0 is trained with plain cross-entropy. For every later task,
    ``before_task`` snapshots the current network as ``ref_model`` (kept in
    eval mode), swaps the classifier for a ``SplitCosineLinear`` whose ``fc1``
    holds the weights of all previously seen classes and whose ``fc2`` is
    initialised from class-mean features, and ``observe`` optimises the sum of:

    * ``CosineEmbeddingLoss`` between the current and reference classifier
      inputs, weighted by ``cur_lamda``;
    * ``CrossEntropyLoss`` on the logits;
    * ``MarginRankingLoss`` between the ground-truth score of old-class samples
      and their top-``K`` new-class scores, weighted by ``lw_mr``.

    Keys read from ``kwargs``: ``init_cls_num``, ``inc_cls_num``, ``K``,
    ``lw_mr``, ``lamda`` and ``dist``.

    NOTE(review): ``self.backbone`` and ``self.device`` are not assigned in
    this class; presumably ``Finetune.__init__`` provides them - confirm.
    NOTE(review): ``SplitCosineLinear`` is not defined in this file;
    presumably it comes from the star import of the resnet backbone module.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Create the learner with a classifier sized for the first task.

        Args:
            backbone: Forwarded to ``Finetune.__init__``.
            feat_dim: Feature width fed to the classifier.
            num_class: Forwarded to ``Finetune.__init__``; the initial
                classifier is sized by ``kwargs['init_cls_num']`` instead.
            **kwargs: Method configuration (see class docstring); also
                forwarded to ``Finetune.__init__``.
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        self.K = kwargs['K']
        self.lw_mr = kwargs['lw_mr']
        self.ref_model = None
        self.task_idx = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare the network, losses and hooks for task ``task_idx``.

        Args:
            task_idx: Index of the task about to be trained (0 = first task).
            buffer: Passed through to ``_init_new_fc`` (unused there).
            train_loader: Loader for the new task; its dataset is used to
                initialise the new-class weights.
            test_loaders: Unused.
        """
        self.task_idx = task_idx

        if task_idx > 0:
            # The distillation weight grows with sqrt(#old / #new classes).
            lamda_mult = self._grow_classifier(task_idx)
            self.cur_lamda = self.kwargs['lamda'] * math.sqrt(lamda_mult)
        else:
            self.cur_lamda = self.kwargs['lamda']

        self._init_new_fc(task_idx, buffer, train_loader)

        if task_idx == 0:
            self.loss_fn = nn.CrossEntropyLoss()
        else:
            self.loss_fn1 = nn.CosineEmbeddingLoss()
            self.loss_fn2 = nn.CrossEntropyLoss()
            self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist'])
            self._register_feature_hooks()

        self.network = self.network.to(self.device)
        if self.ref_model is not None:
            self.ref_model = self.ref_model.to(self.device)

    def _grow_classifier(self, task_idx):
        """Snapshot the network and replace its classifier with a larger one.

        The current network is deep-copied into ``self.ref_model`` first. The
        new ``SplitCosineLinear`` gets all previously learned class weights in
        ``fc1`` and ``inc_cls_num`` fresh outputs in ``fc2``; ``sigma`` is
        carried over.

        Args:
            task_idx: Must be >= 1. For task 1 the old classifier is a single
                layer; afterwards it is already split into ``fc1``/``fc2``.

        Returns:
            float: number of old classes divided by ``inc_cls_num``.
        """
        self.ref_model = copy.deepcopy(self.network)

        old_fc = self.network.classifier
        inc_cls_num = self.kwargs['inc_cls_num']
        in_features = old_fc.in_features

        if task_idx == 1:
            num_old = old_fc.out_features
            # Not moved to the device here; ``self.network.to(...)`` is called
            # later in ``_init_new_fc`` and ``before_task``.
            new_fc = SplitCosineLinear(in_features, num_old, inc_cls_num)
            new_fc.fc1.weight.data = old_fc.weight.data
        else:
            num_fc1 = old_fc.fc1.out_features
            num_fc2 = old_fc.fc2.out_features
            num_old = num_fc1 + num_fc2
            new_fc = SplitCosineLinear(in_features, num_old, inc_cls_num).to(self.device)
            # Old fc1 and fc2 rows are concatenated into the new fc1.
            new_fc.fc1.weight.data[:num_fc1] = old_fc.fc1.weight.data
            new_fc.fc1.weight.data[num_fc1:] = old_fc.fc2.weight.data

        new_fc.sigma.data = old_fc.sigma.data
        self.network.classifier = new_fc
        return num_old * 1.0 / inc_cls_num

    def _register_feature_hooks(self):
        """Put the reference model in eval mode and attach the forward hooks.

        The hooks write to the module-level ``ref_features``, ``cur_features``,
        ``old_scores`` and ``new_scores``; their handles are kept so that
        ``after_task`` can remove them.
        """
        self.ref_model.eval()
        self.num_old_classes = self.ref_model.classifier.out_features

        cur_fc = self.network.classifier
        self.handle_ref_features = self.ref_model.classifier.register_forward_hook(get_ref_features)
        self.handle_cur_features = cur_fc.register_forward_hook(get_cur_features)
        self.handle_old_scores_bs = cur_fc.fc1.register_forward_hook(get_old_scores_before_scale)
        self.handle_new_scores_bs = cur_fc.fc2.register_forward_hook(get_new_scores_before_scale)

    def _init_new_fc(self, task_idx, buffer, train_loader):
        """Initialise ``fc2`` with rescaled, normalised class-mean features.

        For each new class, the backbone features of its training samples are
        L2-normalised and averaged; the mean is normalised again and scaled by
        the average L2 norm of the ``fc1`` weight rows. Does nothing for
        task 0.

        Args:
            task_idx: Index of the task being prepared.
            buffer: Unused.
            train_loader: Its ``dataset`` must expose ``images`` and
                ``labels``; a deep copy of the dataset is re-pointed at one
                class at a time.
        """
        if task_idx == 0:
            return

        classifier = self.network.classifier
        num_old = classifier.fc1.out_features
        num_new = classifier.fc2.out_features
        num_features = classifier.in_features
        feature_model = self.network.backbone

        old_embedding_norm = classifier.fc1.weight.data.norm(dim=1, keepdim=True)
        average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).cpu().double()
        novel_embedding = torch.zeros((self.kwargs['inc_cls_num'], num_features))

        # Loop-invariant: read the dataset and build the label array once
        # instead of once per class.
        source_dataset = train_loader.dataset
        task_data, task_target = source_dataset.images, source_dataset.labels
        all_labels = np.array(task_target)
        tmp_dataset = copy.deepcopy(source_dataset)

        for cls_idx in range(num_old, num_old + num_new):
            cls_indices = np.where(all_labels == cls_idx)[0]
            cls_data = np.array([task_data[i] for i in cls_indices])
            cls_target = np.array([task_target[i] for i in cls_indices])

            tmp_dataset.images = cls_data
            tmp_dataset.labels = cls_target
            tmp_loader = DataLoader(tmp_dataset, batch_size=128, shuffle=False, num_workers=2)

            num_samples = cls_data.shape[0]
            cls_features = self._compute_feature(feature_model, tmp_loader, num_samples, num_features)
            norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1)
            cls_embedding = torch.mean(norm_features, dim=0)
            novel_embedding[cls_idx - num_old] = (
                F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm
            )

        self.network.to(self.device)
        self.network.classifier.fc2.weight.data = novel_embedding.to(self.device)

    def _compute_feature(self, feature_model, loader, num_samples, num_features):
        """Extract features for every sample yielded by ``loader``.

        ``feature_model`` is switched to eval mode and is not switched back.

        Args:
            feature_model: Module exposing a ``feature(inputs)`` method.
            loader: Yields dict batches with an ``'image'`` entry.
            num_samples: Total number of samples the loader will yield.
            num_features: Feature width per sample.

        Returns:
            numpy.ndarray of shape ``(num_samples, num_features)`` (float64).

        Raises:
            AssertionError: if the loader yields a different number of samples.
        """
        feature_model.eval()
        features = np.zeros([num_samples, num_features])
        start_idx = 0
        with torch.no_grad():
            for batch in loader:
                inputs = batch['image'].to(self.device)
                batch_size = inputs.shape[0]
                batch_features = feature_model.feature(inputs).cpu().numpy()
                # Flatten explicitly per sample; a blanket squeeze would also
                # drop the batch axis when the last batch holds one sample.
                features[start_idx:start_idx + batch_size, :] = batch_features.reshape(batch_size, -1)
                start_idx += batch_size
        assert start_idx == num_samples
        return features

    def observe(self, data):
        """Run one training forward pass and compute the loss.

        Args:
            data: Dict with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: ``(pred, accuracy, loss)`` where ``pred`` is the argmax
            over the logits, ``accuracy`` is the fraction of correct
            predictions in the batch and ``loss`` is the scalar to optimise.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        # On tasks > 0 this forward pass also fills cur_features, old_scores
        # and new_scores through the registered hooks.
        logit = self.network(x)

        if self.task_idx == 0:
            loss = self.loss_fn(logit, y)
        else:
            # The reference model is only run so that its hook fills
            # ref_features; those are detached below, so no autograd graph is
            # needed for this pass.
            with torch.no_grad():
                self.ref_model(x)

            same_target = torch.ones(x.size(0), device=self.device)
            loss = self.loss_fn1(cur_features, ref_features.detach(), same_target) * self.cur_lamda
            loss = loss + self.loss_fn2(logit, y)

            margin_term = self._margin_ranking_term(logit, y)
            if margin_term is not None:
                loss = loss + margin_term

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def _margin_ranking_term(self, logit, y):
        """Margin-ranking loss on old-class samples, or ``None`` if there are none.

        For each sample whose label is below ``num_old_classes``, its
        ground-truth score (taken from the hooked ``fc1``/``fc2`` outputs) is
        ranked against its ``K`` highest new-class scores.

        Args:
            logit: Logits of the current batch; only used for a shape check.
            y: Integer class labels of the current batch.

        Returns:
            The loss already multiplied by ``lw_mr``, or ``None`` when the
            batch contains no old-class sample.
        """
        outputs_bs = torch.cat((old_scores, new_scores), dim=1)
        assert outputs_bs.size() == logit.size()

        hard_index = y.lt(self.num_old_classes)
        hard_num = int(hard_index.sum().item())
        if hard_num == 0:
            return None

        # Score of the ground-truth class for every sample, then keep only the
        # old-class samples and repeat K times to pair with the top-K scores.
        gt_scores = outputs_bs.gather(1, y.view(-1, 1)).squeeze(1)
        gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, self.K)

        max_novel_scores = outputs_bs[:, self.num_old_classes:].topk(self.K, dim=1)[0]
        max_novel_scores = max_novel_scores[hard_index]

        assert gt_scores.size() == max_novel_scores.size()
        assert gt_scores.size(0) == hard_num

        rank_target = torch.ones(hard_num * self.K, 1, device=self.device)
        ranking_loss = self.loss_fn3(
            gt_scores.reshape(-1, 1), max_novel_scores.reshape(-1, 1), rank_target
        )
        return ranking_loss * self.lw_mr

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Detach the forward hooks registered for the task that just ended.

        All arguments are unused; the decision is based on ``self.task_idx``
        as set by ``before_task``.
        """
        if self.task_idx > 0:
            self.handle_ref_features.remove()
            self.handle_cur_features.remove()
            self.handle_old_scores_bs.remove()
            self.handle_new_scores_bs.remove()

    def inference(self, data):
        """Predict labels for a batch and report the batch accuracy.

        Args:
            data: Dict with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: ``(pred, accuracy)``.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        logit = self.network(x)
        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def get_parameters(self, config):
        """Return the parameters (or parameter groups) to optimise.

        For task 0 this is ``self.network.parameters()``. For later tasks two
        groups are returned: every parameter except the classifier's ``fc1``
        with ``lr=0.1`` / ``weight_decay=5e-4``, and the ``fc1`` parameters
        with ``lr=0`` / ``weight_decay=0`` so the old-class weights stay fixed.

        Args:
            config: Unused; the learning rate and weight decay above are
                hard-coded.
        """
        if self.task_idx > 0:
            # Materialise as lists so the groups can be iterated more than once.
            fc1_params = list(self.network.classifier.fc1.parameters())
            fc1_ids = {id(p) for p in fc1_params}
            base_params = [p for p in self.network.parameters() if id(p) not in fc1_ids]
            tg_params = [
                {'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4},
                {'params': fc1_params, 'lr': 0, 'weight_decay': 0},
            ]
        else:
            tg_params = self.network.parameters()

        return tg_params