boringKey's picture
Upload 236 files
5fee096 verified
# -*- coding: utf-8 -*-
"""
@article{DBLP:journals/corr/KirkpatrickPRVD16,
author = {James Kirkpatrick and
Razvan Pascanu and
Neil C. Rabinowitz and
Joel Veness and
Guillaume Desjardins and
Andrei A. Rusu and
Kieran Milan and
John Quan and
Tiago Ramalho and
Agnieszka Grabska{-}Barwinska and
Demis Hassabis and
Claudia Clopath and
Dharshan Kumaran and
Raia Hadsell},
title = {Overcoming catastrophic forgetting in neural networks},
journal = {CoRR},
volume = {abs/1612.00796},
year = {2016}
}
https://arxiv.org/abs/1612.00796
Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
"""
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader
from torch import optim
class Model(nn.Module):
# A model consists with a backbone and a classifier
def __init__(self, backbone, feat_dim, num_class):
super().__init__()
self.backbone = backbone
self.feat_dim = feat_dim
self.num_class = num_class
self.classifier = nn.Linear(feat_dim, num_class)
def forward(self, x):
return self.get_logits(x)
def get_logits(self, x):
logits = self.classifier(self.backbone(x)['features'])
return logits
class EWC(Finetune):
def __init__(self, backbone, feat_dim, num_class, **kwargs):
super().__init__(backbone, feat_dim, num_class, **kwargs)
self.kwargs = kwargs
self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters()
if p.requires_grad}
self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.network.named_parameters()
if p.requires_grad}
self.lamda = self.kwargs['lamda']
def before_task(self, task_idx, buffer, train_loader, test_loaders):
self.task_idx = task_idx
in_features = self.network.classifier.in_features
out_features = self.network.classifier.out_features
new_fc = nn.Linear(in_features, self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num'])
new_fc.weight.data[:out_features] = self.network.classifier.weight.data
new_fc.bias.data[:out_features] = self.network.classifier.bias.data
self.network.classifier = new_fc
self.network.to(self.device)
def observe(self, data):
x, y = data['image'].to(self.device), data['label'].to(self.device)
logit = self.network(x)
if self.task_idx == 0:
loss = F.cross_entropy(logit, y)
else:
old_classes = self.network.classifier.out_features - self.kwargs['inc_cls_num']
#print(old_classes)
#print(logit[:, old_classes:].shape)
#print(y)
#print(y-old_classes)
loss = F.cross_entropy(logit[:, old_classes:], y - old_classes)
loss += self.lamda * self.compute_ewc()
pred = torch.argmax(logit, dim=1)
#print(pred)
#print(y)
acc = torch.sum(pred == y).item()
return pred, acc / x.size(0), loss
def after_task(self, task_idx, buffer, train_loader, test_loaders):
"""
Args:
task_idx (int): The index of the current task.
buffer: Buffer object used in previous tasks.
train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.
test_loaders (list of DataLoader): List of dataloaders for the test datasets.
Code Reference:
https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
"""
# record the parameters
self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters()
if p.requires_grad}
# the shape of new fisher is changed
new_fisher = self.getFisher(train_loader)
# using growing alpha
alpha = 1 - self.kwargs['inc_cls_num']/self.network.classifier.out_features
for n, p in self.fisher.items():
new_fisher[n][:len(self.fisher[n])] = alpha * p + (1 - alpha) * new_fisher[n][:len(self.fisher[n])]
self.fisher = new_fisher
def inference(self, data):
x, y = data['image'], data['label']
x = x.to(self.device)
y = y.to(self.device)
logit = self.network(x)
pred = torch.argmax(logit, dim=1)
acc = torch.sum(pred == y).item()
return pred, acc / x.size(0)
def getFisher(self, train_loader):
"""
Compute the Fisher Information Matrix for the parameters of the network.
Args:
train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.
Returns:
dict: Dictionary of Fisher Information Matrices for each parameter.
Code Reference:
https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
"""
def accumulate(fisher):
"""
Accumulate the squared gradients for the Fisher Information Matrix.
Args:
fisher (dict): Dictionary containing the current Fisher Information matrices.
Returns:
dict: Updated Fisher Information matrices.
"""
for n, p in self.network.named_parameters():
if p.grad is not None and n in fisher.keys():
fisher[n] += p.grad.pow(2).clone() * len(y)
return fisher
# Initialize Fisher Information matrices with zeros
fisher = {
n: torch.zeros_like(p).to(self.device) for n, p in self.network.named_parameters()
if p.requires_grad
}
self.network.train()
optimizer = optim.SGD(self.network.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
# Iterate over the training data
for data in train_loader:
x, y = data['image'], data['label']
x = x.to(self.device)
y = y.to(self.device)
logits = self.network(x)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
# Accumulate Fisher Information
fisher = accumulate(fisher)
# Normalize Fisher Information matrices by the number of samples
num_samples = train_loader.batch_size * len(train_loader)
for n, p in fisher.items():
fisher[n] = p / num_samples
return fisher
def compute_ewc(self):
"""
Compute the Elastic Weight Consolidation (EWC) loss.
This function calculates the EWC loss based on the stored Fisher Information matrices
and reference parameters from a previous task.
References:
- https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
- https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
Returns:
torch.Tensor: The computed EWC loss.
"""
loss = 0
for n, p in self.network.named_parameters():
if n in self.fisher.keys():
loss += torch.sum(self.fisher[n] * (p[:len(self.ref_param[n])] - self.ref_param[n]).pow(2)) / 2
return loss
def get_parameters(self, config):
train_parameters = []
train_parameters.append({"params": self.network.parameters()})
return train_parameters