# -*- coding: utf-8 -*-
"""
@article{DBLP:journals/corr/KirkpatrickPRVD16,
author = {James Kirkpatrick and
Razvan Pascanu and
Neil C. Rabinowitz and
Joel Veness and
Guillaume Desjardins and
Andrei A. Rusu and
Kieran Milan and
John Quan and
Tiago Ramalho and
Agnieszka Grabska{-}Barwinska and
Demis Hassabis and
Claudia Clopath and
Dharshan Kumaran and
Raia Hadsell},
title = {Overcoming catastrophic forgetting in neural networks},
journal = {CoRR},
volume = {abs/1612.00796},
year = {2016}
}
https://arxiv.org/abs/1612.00796
Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
"""
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
class Model(nn.Module):
    """A backbone paired with a linear classification head.

    The backbone is expected to return a dict with a ``'features'`` entry
    of shape ``(batch, feat_dim)`` (see ``get_logits``).
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)

    def get_logits(self, x):
        """Extract backbone features and map them to class logits."""
        features = self.backbone(x)['features']
        return self.classifier(features)

    def forward(self, x):
        # The forward pass simply delegates to get_logits.
        return self.get_logits(x)
class EWC(Finetune):
    """Elastic Weight Consolidation (Kirkpatrick et al., 2016).

    Alongside the task loss, a quadratic penalty anchors the parameters to
    the values they had after the previous task, weighted per-parameter by
    a diagonal estimate of the Fisher information.

    Code References:
        https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
        https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        # Parameter snapshot taken after the previous task; the EWC penalty
        # pulls the current parameters towards these values.
        self.ref_param = {n: p.clone().detach()
                          for n, p in self.network.named_parameters()
                          if p.requires_grad}
        # Running diagonal Fisher information estimate, keyed like ref_param.
        self.fisher = {n: torch.zeros_like(p).to(self.device)
                       for n, p in self.network.named_parameters()
                       if p.requires_grad}
        # Strength of the EWC regularization term.
        self.lamda = self.kwargs['lamda']

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the classifier head to cover the new task's classes.

        The old classifier's weights and biases are copied into the leading
        rows of the enlarged head so previously learned classes are kept.
        """
        self.task_idx = task_idx
        in_features = self.network.classifier.in_features
        out_features = self.network.classifier.out_features
        new_fc = nn.Linear(
            in_features,
            self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num'],
        )
        new_fc.weight.data[:out_features] = self.network.classifier.weight.data
        new_fc.bias.data[:out_features] = self.network.classifier.bias.data
        self.network.classifier = new_fc
        self.network.to(self.device)

    def observe(self, data):
        """Run one training step on a batch.

        For the first task this is plain cross-entropy; for later tasks the
        loss is computed only over the new task's logit slice (labels shifted
        accordingly) plus the EWC penalty.

        Args:
            data (dict): Batch with 'image' and 'label' tensors.

        Returns:
            tuple: (predictions, batch accuracy, loss)
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logit = self.network(x)
        if self.task_idx == 0:
            loss = F.cross_entropy(logit, y)
        else:
            # Number of classes seen before the current task.
            old_classes = self.network.classifier.out_features - self.kwargs['inc_cls_num']
            loss = F.cross_entropy(logit[:, old_classes:], y - old_classes)
            loss += self.lamda * self.compute_ewc()
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Refresh the reference parameters and merge the Fisher estimate.

        Args:
            task_idx (int): The index of the current task.
            buffer: Buffer object used in previous tasks.
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.
            test_loaders (list of DataLoader): List of dataloaders for the test datasets.

        Code Reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        # Anchor for the next task's EWC penalty.
        self.ref_param = {n: p.clone().detach()
                          for n, p in self.network.named_parameters()
                          if p.requires_grad}
        # The classifier has grown, so the new Fisher tensors may be larger
        # than the stored ones along dim 0.
        new_fisher = self.getFisher(train_loader)
        # Growing alpha: the more classes already learned, the more weight
        # the old Fisher estimate keeps in the merge.
        alpha = 1 - self.kwargs['inc_cls_num'] / self.network.classifier.out_features
        for n, p in self.fisher.items():
            # Blend only the leading slice that both estimates share.
            new_fisher[n][:len(p)] = alpha * p + (1 - alpha) * new_fisher[n][:len(p)]
        self.fisher = new_fisher

    def inference(self, data):
        """Predict on a batch.

        Returns:
            tuple: (predictions, batch accuracy)
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        logit = self.network(x)
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def getFisher(self, train_loader):
        """Estimate the diagonal Fisher Information Matrix on a dataset.

        Squared gradients of the cross-entropy loss are accumulated per
        batch (weighted by the batch size) and averaged over the number of
        samples actually seen.

        Args:
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.

        Returns:
            dict: Fisher Information tensor for each trainable parameter.

        Code Reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        fisher = {n: torch.zeros_like(p).to(self.device)
                  for n, p in self.network.named_parameters()
                  if p.requires_grad}
        self.network.train()
        loss_fn = torch.nn.CrossEntropyLoss()
        num_samples = 0
        for data in train_loader:
            x = data['image'].to(self.device)
            y = data['label'].to(self.device)
            batch_size = y.size(0)
            loss = loss_fn(self.network(x), y)
            # No optimizer is needed: we only want gradients, never a step.
            self.network.zero_grad()
            loss.backward()
            # Weight each batch's squared gradients by its actual size so
            # the final average is correct even for a smaller last batch.
            for n, p in self.network.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach().pow(2) * batch_size
            num_samples += batch_size
        # BUGFIX: previously normalized by batch_size * len(train_loader),
        # which overcounts when the final batch is partial (and fails when
        # batch_size is None); divide by the true number of samples instead.
        if num_samples > 0:
            for n in fisher:
                fisher[n] = fisher[n] / num_samples
        return fisher

    def compute_ewc(self):
        """Compute the EWC loss  sum_i F_i * (theta_i - theta*_i)^2 / 2.

        Only the leading slice of each current parameter (matching the shape
        stored at the end of the previous task) is penalized, because the
        classifier head grows between tasks.

        References:
            - https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            - https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py

        Returns:
            torch.Tensor: The computed EWC loss.
        """
        loss = 0
        for n, p in self.network.named_parameters():
            if n in self.fisher:
                ref = self.ref_param[n]
                loss += torch.sum(self.fisher[n] * (p[:len(ref)] - ref).pow(2)) / 2
        return loss

    def get_parameters(self, config):
        """Return the optimizer parameter groups (the whole network)."""
        return [{"params": self.network.parameters()}]