File size: 8,233 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# -*- coding: utf-8 -*-
"""
@article{DBLP:journals/corr/KirkpatrickPRVD16,
  author       = {James Kirkpatrick and
                  Razvan Pascanu and
                  Neil C. Rabinowitz and
                  Joel Veness and
                  Guillaume Desjardins and
                  Andrei A. Rusu and
                  Kieran Milan and
                  John Quan and
                  Tiago Ramalho and
                  Agnieszka Grabska{-}Barwinska and
                  Demis Hassabis and
                  Claudia Clopath and
                  Dharshan Kumaran and
                  Raia Hadsell},
  title        = {Overcoming catastrophic forgetting in neural networks},
  journal      = {CoRR},
  volume       = {abs/1612.00796},
  year         = {2016}
}

https://arxiv.org/abs/1612.00796

Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
"""


import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader
from torch import optim


class Model(nn.Module):
    """A network made of a feature-extracting backbone and a linear head."""

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x):
        # Single code path: forward simply delegates to get_logits.
        return self.get_logits(x)

    def get_logits(self, x):
        """Run the backbone on ``x`` and map its features to class logits."""
        features = self.backbone(x)['features']
        return self.classifier(features)

class EWC(Finetune):
    """Elastic Weight Consolidation (Kirkpatrick et al., 2016).

    After finishing each task, a diagonal Fisher information matrix is
    estimated on that task's training data and the current parameters are
    stored as a reference.  While learning later tasks, a quadratic penalty
    (weighted by the Fisher values and ``lamda``) keeps important parameters
    close to the reference, mitigating catastrophic forgetting.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])

        # Snapshot of the parameters after the previous task (penalty anchor).
        self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters()
                          if p.requires_grad}
        # Diagonal Fisher information; all zeros before the first task so
        # task 0 effectively has no penalty.
        self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.network.named_parameters()
                       if p.requires_grad}
        self.lamda = self.kwargs['lamda']

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the classifier to cover every class seen up to ``task_idx``."""
        self.task_idx = task_idx
        in_features = self.network.classifier.in_features
        out_features = self.network.classifier.out_features

        new_fc = nn.Linear(in_features, self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num'])
        # Copy the old head's weights so previous-task logits are preserved.
        new_fc.weight.data[:out_features] = self.network.classifier.weight.data
        new_fc.bias.data[:out_features] = self.network.classifier.bias.data
        self.network.classifier = new_fc
        self.network.to(self.device)

    def observe(self, data):
        """One training step.

        Returns:
            tuple: (predictions, batch accuracy, loss). For tasks after the
            first, the loss is cross-entropy restricted to the new task's
            logits plus ``lamda`` times the EWC penalty.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logit = self.network(x)

        if self.task_idx == 0:
            loss = F.cross_entropy(logit, y)
        else:
            # Train only on the current task's logits; labels are shifted so
            # they index from 0 within the new classes.
            old_classes = self.network.classifier.out_features - self.kwargs['inc_cls_num']
            loss = F.cross_entropy(logit[:, old_classes:], y - old_classes)
            loss += self.lamda * self.compute_ewc()

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Snapshot parameters and refresh the Fisher estimate after a task.

        Args:
            task_idx (int): The index of the current task.
            buffer: Buffer object used in previous tasks.
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.
            test_loaders (list of DataLoader): List of dataloaders for the test datasets.

        Code Reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        # Record the parameters reached at the end of this task.
        self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters()
                          if p.requires_grad}
        # The new Fisher tensors may be larger (the classifier grew), so the
        # old estimate is blended only into the overlapping leading slice.
        new_fisher = self.getFisher(train_loader)
        # Growing alpha: the old estimate's weight is the fraction of classes
        # that are old, so it shrinks as new classes arrive.
        alpha = 1 - self.kwargs['inc_cls_num'] / self.network.classifier.out_features
        for n, p in self.fisher.items():
            new_fisher[n][:len(self.fisher[n])] = alpha * p + (1 - alpha) * new_fisher[n][:len(self.fisher[n])]

        self.fisher = new_fisher

    def inference(self, data):
        """Predict labels for a batch and return (predictions, accuracy)."""
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        logit = self.network(x)
        pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def getFisher(self, train_loader):
        """Estimate the diagonal Fisher information of the network parameters.

        The squared gradients of the mean cross-entropy loss are accumulated
        over the dataset (weighted by batch size, since the loss is a batch
        mean) and normalized by the true number of samples.

        Args:
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.

        Returns:
            dict: Fisher Information tensor for each trainable parameter.

        Code Reference:
        https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
        https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        # Initialize Fisher Information matrices with zeros.
        fisher = {
            n: torch.zeros_like(p).to(self.device) for n, p in self.network.named_parameters()
            if p.requires_grad
        }

        self.network.train()
        loss_fn = torch.nn.CrossEntropyLoss()

        num_samples = 0
        for data in train_loader:
            x, y = data['image'], data['label']
            x = x.to(self.device)
            y = y.to(self.device)

            logits = self.network(x)
            loss = loss_fn(logits, y)

            # No optimizer is needed just to clear gradients.
            self.network.zero_grad()
            loss.backward()

            # Accumulate squared gradients, weighted by the batch size
            # because the loss (and hence the gradient) is a batch mean.
            for n, p in self.network.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.pow(2).clone() * len(y)
            num_samples += len(y)

        # Normalize by the actual sample count; using
        # batch_size * len(train_loader) would over-count when the final
        # batch is smaller than batch_size.
        for n, p in fisher.items():
            fisher[n] = p / max(num_samples, 1)
        return fisher

    def compute_ewc(self):
        """Compute the Elastic Weight Consolidation (EWC) penalty.

        The penalty is 0.5 * sum_i F_i * (theta_i - theta_i^ref)^2, where the
        sum runs over the slice of each parameter that existed when the
        reference was recorded (the classifier may have grown since).

        References:
        - https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
        - https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py

        Returns:
            torch.Tensor: The computed EWC loss.
        """
        loss = 0
        for n, p in self.network.named_parameters():
            if n in self.fisher.keys():
                loss += torch.sum(self.fisher[n] * (p[:len(self.ref_param[n])] - self.ref_param[n]).pow(2)) / 2
        return loss

    def get_parameters(self, config):
        """Return the parameter groups to optimize (the whole network)."""
        return [{"params": self.network.parameters()}]