File size: 2,898 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune

class LWF(Finetune):
    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.feat_dim = feat_dim
        self.classifier = nn.Linear(self.feat_dim, kwargs['init_cls_num'])
        self.old_fc = None
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.known_cls_num = 0
        self.total_cls_num = 0
        self.old_backbone = None        

    def freeze(self,nn):
        for param in nn.parameters():
            param.requires_grad = False
        nn.eval()
        return nn
    
    def update_fc(self):
        fc = nn.Linear(self.feat_dim, self.total_cls_num).to(self.device)
        if self.classifier is not None:
            # del self.old_fc
            self.old_fc = self.freeze(copy.deepcopy(self.classifier))
            old_out = self.classifier.out_features
            weight = copy.deepcopy(self.classifier.weight.data)
            bias = copy.deepcopy(self.classifier.bias.data)
            fc.weight.data[:old_out] = weight
            fc.bias.data[:old_out] = bias

        # del self.classifier
        self.classifier = fc

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        self.task_idx = task_idx
        self.known_cls_num = self.total_cls_num
        self.total_cls_num = self.init_cls_num + self.task_idx*self.inc_cls_num
        self.update_fc()
        self.loss_fn = nn.CrossEntropyLoss()
        if task_idx != 0:
            self.old_backbone = self.freeze(copy.deepcopy(self.backbone)).to(self.device)


    def observe(self, data):
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logit = self.classifier(self.backbone(x)['features'])    

        if self.task_idx == 0:
            loss = self.loss_fn(logit, y)
        else:
            fake_targets = y - self.known_cls_num
            loss_clf = self.loss_fn(logit[:,self.known_cls_num:],fake_targets)
            loss_kd = self._KD_loss(logit[:,:self.known_cls_num],self.old_fc(self.old_backbone(x)['features']),T=2)
            lamda = 3
            loss = lamda*loss_kd + loss_clf

        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        pass

    def _KD_loss(self, pred, soft, T):
        pred = torch.log_softmax(pred / T, dim=1)
        soft = torch.softmax(soft / T, dim=1)
        return -1 * torch.mul(soft, pred).sum() / pred.shape[0]
    def _cross_entropy(self, pre, logit):
        loss = None
        return loss