# File size: 10,330 Bytes
# 5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# -*- coding: utf-8 -*-
"""
@inproceedings{DBLP:conf/cvpr/HouPLWL19,
  title        = {Learning a Unified Classifier Incrementally via Rebalancing},
  author       = {Saihui Hou and
                  Xinyu Pan and
                  Chen Change Loy and
                  Zilei Wang and
                  Dahua Lin},
  booktitle    = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}
                  2019, Long Beach, CA, USA, June 16-20, 2019},
  pages        = {831--839},
  publisher    = {Computer Vision Foundation / {IEEE}},
  year         = {2019}
}
https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html

Adapted from https://github.com/hshustc/CVPR19_Incremental_Learning
"""

import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader


# Module-level slots overwritten by the forward hooks defined below.
# Each hook rebinds one of these names on every forward pass of the module it
# is attached to; the empty lists are only placeholders until the first pass.
cur_features, ref_features = [], []
old_scores, new_scores = [], []
def get_ref_features(self, inputs, outputs):
    """Forward hook: store the hooked module's first input in ``ref_features``.

    Args:
        self: the module the hook is attached to (unused).
        inputs: positional inputs of the module's forward call.
        outputs: the module's output (unused).
    """
    global ref_features
    first_input = inputs[0]
    ref_features = first_input

def get_cur_features(self, inputs, outputs):
    """Forward hook: store the hooked module's first input in ``cur_features``.

    Args:
        self: the module the hook is attached to (unused).
        inputs: positional inputs of the module's forward call.
        outputs: the module's output (unused).
    """
    global cur_features
    first_input = inputs[0]
    cur_features = first_input

def get_old_scores_before_scale(self, inputs, outputs):
    """Forward hook: store the hooked module's output in ``old_scores``.

    Args:
        self: the module the hook is attached to (unused).
        inputs: positional inputs of the module's forward call (unused).
        outputs: the module's output, kept as-is.
    """
    global old_scores
    raw_scores = outputs
    old_scores = raw_scores

def get_new_scores_before_scale(self, inputs, outputs):
    """Forward hook: store the hooked module's output in ``new_scores``.

    Args:
        self: the module the hook is attached to (unused).
        inputs: positional inputs of the module's forward call (unused).
        outputs: the module's output, kept as-is.
    """
    global new_scores
    raw_scores = outputs
    new_scores = raw_scores



class Model(nn.Module):
    """A backbone followed by a ``CosineLinear`` classification head.

    Args:
        backbone: module whose forward returns a mapping with a
            ``'features'`` entry.
        feat_dim: input size handed to the classifier head.
        num_class: output size handed to the classifier head.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        # Registration order (backbone first, then classifier) is preserved.
        self.backbone = backbone
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.classifier = CosineLinear(feat_dim, num_class)

    def get_logits(self, x):
        """Run ``x`` through the backbone and classify its ``'features'``."""
        features = self.backbone(x)['features']
        return self.classifier(features)

    def forward(self, x):
        """Return the classifier logits for ``x`` (same as ``get_logits``)."""
        return self.get_logits(x)



class LUCIR(Finetune):
    """LUCIR ("Learning a Unified Classifier Incrementally via Rebalancing").

    The first task (``task_idx == 0``) is trained with plain cross-entropy on a
    ``Model`` with a ``CosineLinear`` head. For every later task the head is
    replaced by a ``SplitCosineLinear`` (``fc1`` = old classes, ``fc2`` = new
    classes), a frozen copy of the previous network is kept as ``ref_model``,
    and the loss becomes

        lamda * CosineEmbeddingLoss(current features, reference features)
        + CrossEntropyLoss(logits, labels)
        + lw_mr * MarginRankingLoss(ground-truth score, top-K new-class scores)

    where the margin term is only applied to samples whose label is an old
    class.

    Keys read from ``kwargs``: ``init_cls_num``, ``inc_cls_num``, ``K``,
    ``lw_mr``, ``lamda`` and ``dist``.
    """

    # Attribute names under which the forward-hook handles are stored.
    _HOOK_HANDLE_ATTRS = (
        'handle_ref_features',
        'handle_cur_features',
        'handle_old_scores_bs',
        'handle_new_scores_bs',
    )

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        self.K = kwargs['K']            # number of top new-class scores used by the margin loss
        self.lw_mr = kwargs['lw_mr']    # weight of the margin ranking loss
        self.ref_model = None           # copy of the network from the previous task
        self.task_idx = 0

    def _remove_hooks(self):
        """Remove every forward hook registered by ``before_task`` (idempotent)."""
        for attr in self._HOOK_HANDLE_ATTRS:
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.remove()
                setattr(self, attr, None)

    def _expand_classifier(self, task_idx):
        """Swap the classifier for a ``SplitCosineLinear`` with room for new classes.

        All existing class weights go into ``fc1`` of the new head and the
        ``sigma`` value is carried over; ``fc2`` is left for ``_init_new_fc``.

        Returns:
            float: (number of old classes) / ``inc_cls_num``.
        """
        inc_cls_num = self.kwargs['inc_cls_num']
        old_fc = self.network.classifier
        if task_idx == 1:
            # After the first task the head still exposes a single weight matrix.
            old_weights = [old_fc.weight.data]
            num_old = old_fc.out_features
        else:
            # Later heads are split: merge old (fc1) and previously-new (fc2) rows.
            old_weights = [old_fc.fc1.weight.data, old_fc.fc2.weight.data]
            num_old = old_fc.fc1.out_features + old_fc.fc2.out_features

        new_fc = SplitCosineLinear(old_fc.in_features, num_old, inc_cls_num).to(self.device)
        new_fc.fc1.weight.data = torch.cat(old_weights, dim=0)
        new_fc.sigma.data = old_fc.sigma.data
        self.network.classifier = new_fc
        return num_old * 1.0 / inc_cls_num

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare network, reference model, losses and hooks for ``task_idx``.

        Args:
            task_idx: index of the task about to be trained (0 = first task).
            buffer: forwarded to ``_init_new_fc`` (not used there).
            train_loader: loader of the task's training data; its ``dataset``
                is used to initialise the new-class weights.
            test_loaders: unused.
        """
        self.task_idx = task_idx

        # Drop hooks that may still be attached (e.g. if ``after_task`` was
        # skipped) before the network is deep-copied into the reference model.
        self._remove_hooks()

        if task_idx > 0:
            self.ref_model = copy.deepcopy(self.network)
            lamda_mult = self._expand_classifier(task_idx)
            # The distillation weight grows with sqrt(#old classes / #new classes).
            self.cur_lamda = self.kwargs['lamda'] * math.sqrt(lamda_mult)
        else:
            self.cur_lamda = self.kwargs['lamda']

        self._init_new_fc(task_idx, buffer, train_loader)

        if task_idx == 0:
            self.loss_fn = nn.CrossEntropyLoss()
        else:
            self.loss_fn1 = nn.CosineEmbeddingLoss()
            self.loss_fn2 = nn.CrossEntropyLoss()
            self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist'])

            self.ref_model.eval()
            self.num_old_classes = self.ref_model.classifier.out_features
            # The hooks publish intermediate tensors through the module-level
            # globals that ``observe`` reads.
            self.handle_ref_features = self.ref_model.classifier.register_forward_hook(get_ref_features)
            self.handle_cur_features = self.network.classifier.register_forward_hook(get_cur_features)
            self.handle_old_scores_bs = self.network.classifier.fc1.register_forward_hook(get_old_scores_before_scale)
            self.handle_new_scores_bs = self.network.classifier.fc2.register_forward_hook(get_new_scores_before_scale)

        self.network = self.network.to(self.device)
        if self.ref_model is not None:
            self.ref_model = self.ref_model.to(self.device)

    def _init_new_fc(self, task_idx, buffer, train_loader):
        """Initialise ``fc2`` with normalised class-mean features of the new classes.

        For every new class the backbone features of its training samples are
        L2-normalised and averaged; the mean is normalised again and scaled by
        the average norm of the old class weights (``fc1``). Nothing is done
        for ``task_idx == 0``.

        Args:
            task_idx: index of the current task.
            buffer: unused.
            train_loader: loader whose ``dataset`` exposes ``images`` and
                ``labels``.

        Raises:
            ValueError: if a new class has no sample in ``train_loader.dataset``
                (its weight row would otherwise silently become NaN).
        """
        if task_idx == 0:
            return

        classifier = self.network.classifier
        num_old = classifier.fc1.out_features
        num_new = classifier.fc2.out_features
        old_embedding_norm = classifier.fc1.weight.data.norm(dim=1, keepdim=True)
        average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor)
        feature_model = self.network.backbone
        num_features = classifier.in_features
        novel_embedding = torch.zeros((self.kwargs['inc_cls_num'], num_features))

        dataset = train_loader.dataset
        task_data, task_target = dataset.images, dataset.labels
        all_labels = np.array(task_target)
        # Work on a copy so the training dataset itself is never modified.
        tmp_dataset = copy.deepcopy(dataset)

        for offset in range(num_new):
            cls_idx = num_old + offset
            sample_ids = np.where(all_labels == cls_idx)[0]
            if sample_ids.size == 0:
                raise ValueError(
                    f"no training samples found for new class {cls_idx}; "
                    "cannot initialise its classifier weight"
                )
            cls_data = np.array([task_data[i] for i in sample_ids])
            cls_target = np.array([task_target[i] for i in sample_ids])
            tmp_dataset.images = cls_data
            tmp_dataset.labels = cls_target
            tmp_loader = DataLoader(tmp_dataset, batch_size=128, shuffle=False, num_workers=2)
            num_samples = cls_data.shape[0]
            cls_features = self._compute_feature(feature_model, tmp_loader, num_samples, num_features)
            norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1)
            cls_embedding = torch.mean(norm_features, dim=0)
            novel_embedding[offset] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm

        self.network.to(self.device)
        self.network.classifier.fc2.weight.data = novel_embedding.to(self.device)

    def _compute_feature(self, feature_model, loader, num_samples, num_features):
        """Return a ``(num_samples, num_features)`` NumPy array of backbone features.

        ``feature_model`` is switched to eval mode and is left in that mode.

        Args:
            feature_model: module providing a ``feature(inputs)`` method.
            loader: iterable of batches that are dicts with an ``'image'`` entry.
            num_samples: total number of samples the loader yields.
            num_features: feature size per sample.
        """
        feature_model.eval()
        features = np.zeros([num_samples, num_features])
        start_idx = 0
        with torch.no_grad():
            for batch in loader:
                inputs = batch['image'].to(self.device)
                end_idx = start_idx + inputs.shape[0]
                features[start_idx:end_idx, :] = np.squeeze(feature_model.feature(inputs).cpu())
                start_idx = end_idx
        # Internal invariant: the loader must have yielded exactly num_samples.
        assert start_idx == num_samples
        return features

    def observe(self, data):
        """Compute predictions, accuracy and training loss for one batch.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: ``(pred, accuracy, loss)`` where ``pred`` is the argmax of
            the logits, ``accuracy`` the fraction of correct predictions in
            the batch and ``loss`` the (differentiable) training loss.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        # With hooks attached this forward pass also refreshes the globals
        # cur_features / old_scores / new_scores.
        logit = self.network(x)

        if self.task_idx == 0:
            loss = self.loss_fn(logit, y)
        else:
            # The reference pass is only run for its hook side effect
            # (ref_features), and that tensor is used detached, so no autograd
            # graph is needed for it.
            # NOTE(review): eval mode is re-asserted in case an enclosing
            # ``.train()`` call reached ref_model - confirm against the trainer.
            self.ref_model.eval()
            with torch.no_grad():
                self.ref_model(x)

            # Less-forget constraint: keep current features aligned with the
            # reference features.
            same_target = torch.ones(x.size(0), device=self.device)
            loss = self.loss_fn1(cur_features, ref_features.detach(), same_target) * self.cur_lamda

            loss = loss + self.loss_fn2(logit, y)

            # Margin ranking on the scores captured at fc1/fc2 by the hooks.
            outputs_bs = torch.cat((old_scores, new_scores), dim=1)
            assert outputs_bs.size() == logit.size()
            # Score of the ground-truth class for every sample.
            gt_scores = outputs_bs.gather(1, y.view(-1, 1)).squeeze(1)
            max_novel_scores = outputs_bs[:, self.num_old_classes:].topk(self.K, dim=1)[0]
            # Only samples of old classes take part in the margin loss.
            hard_index = y.lt(self.num_old_classes)
            hard_num = int(hard_index.sum().item())

            if hard_num > 0:
                gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, self.K)
                max_novel_scores = max_novel_scores[hard_index]
                assert gt_scores.size() == max_novel_scores.size()
                assert gt_scores.size(0) == hard_num
                rank_target = torch.ones(hard_num * self.K, 1, device=self.device)
                loss = loss + self.loss_fn3(
                    gt_scores.reshape(-1, 1),
                    max_novel_scores.reshape(-1, 1),
                    rank_target,
                ) * self.lw_mr

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()

        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Detach the forward hooks registered for the finished task.

        All arguments are unused.
        """
        self._remove_hooks()

    def inference(self, data):
        """Return ``(pred, accuracy)`` for one batch of ``data``.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        logit = self.network(x)
        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def get_parameters(self, config):
        """Return the parameters (or parameter groups) to hand to the optimizer.

        For the first task all network parameters are returned. Afterwards two
        groups are returned: every parameter except ``classifier.fc1`` with
        ``lr=0.1`` and ``weight_decay=5e-4``, and the ``fc1`` parameters with
        ``lr=0`` / ``weight_decay=0`` so the old-class embeddings stay fixed.

        Args:
            config: unused; the group hyper-parameters are hard-coded here.
        """
        if self.task_idx == 0:
            return self.network.parameters()

        # Fix the embedding of old classes.
        fc1_params = list(self.network.classifier.fc1.parameters())
        frozen_ids = {id(p) for p in fc1_params}
        base_params = [p for p in self.network.parameters() if id(p) not in frozen_ids]
        return [
            {'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4},
            {'params': fc1_params, 'lr': 0, 'weight_decay': 0},
        ]