| | |
| | """ |
| | @article{DBLP:journals/corr/KirkpatrickPRVD16, |
| | author = {James Kirkpatrick and |
| | Razvan Pascanu and |
| | Neil C. Rabinowitz and |
| | Joel Veness and |
| | Guillaume Desjardins and |
| | Andrei A. Rusu and |
| | Kieran Milan and |
| | John Quan and |
| | Tiago Ramalho and |
| | Agnieszka Grabska{-}Barwinska and |
| | Demis Hassabis and |
| | Claudia Clopath and |
| | Dharshan Kumaran and |
| | Raia Hadsell}, |
| | title = {Overcoming catastrophic forgetting in neural networks}, |
| | journal = {CoRR}, |
| | volume = {abs/1612.00796}, |
| | year = {2016} |
| | } |
| | |
| | https://arxiv.org/abs/1612.00796 |
| | |
| | Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py |
| | """ |
| |
|
| |
|
| | import math |
| | import copy |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import Parameter |
| | import torch.nn.functional as F |
| | from .finetune import Finetune |
| | from core.model.backbone.resnet import * |
| | import numpy as np |
| | from torch.utils.data import DataLoader |
| | from torch import optim |
| |
|
| |
|
class Model(nn.Module):
    """Feature extractor followed by a single linear classification head.

    Args:
        backbone: Callable that maps an input batch to a mapping containing
            a ``'features'`` entry.
        feat_dim (int): Size of the feature vector fed to the classifier.
        num_class (int): Number of output logits of the classifier.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_class = num_class
        # The backbone is registered before the classifier so that the
        # parameter ordering is backbone first, classifier second.
        self.backbone = backbone
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x):
        """Return the classification logits for ``x``."""
        return self.get_logits(x)

    def get_logits(self, x):
        """Run the backbone and apply the classifier to its ``'features'``."""
        features = self.backbone(x)['features']
        return self.classifier(features)
| |
|
class EWC(Finetune):
    """Elastic Weight Consolidation learner.

    After every task the current parameters are stored as a reference and a
    diagonal Fisher Information estimate is computed.  While training later
    tasks, a quadratic penalty weighted by that Fisher estimate discourages
    the parameters from moving away from the stored reference.

    Expected ``kwargs`` entries read by this class:
        init_cls_num (int): Number of classes in the first task.
        inc_cls_num (int): Number of classes added by each later task.
        lamda (float): Weight of the EWC penalty in the training loss.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Build the network and the (initially empty) EWC statistics."""
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])

        # Snapshot of the trainable parameters the penalty pulls towards.
        self.ref_param = {
            n: p.clone().detach()
            for n, p in self.network.named_parameters()
            if p.requires_grad
        }
        # Zero placeholder; replaced by a real estimate in ``after_task``.
        self.fisher = {
            n: torch.zeros(p.shape).to(self.device)
            for n, p in self.network.named_parameters()
            if p.requires_grad
        }
        self.lamda = self.kwargs['lamda']

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Resize the classifier so it covers every class seen up to this task.

        Args:
            task_idx (int): The index of the task about to be trained.
            buffer: Unused here.
            train_loader: Unused here.
            test_loaders: Unused here.
        """
        self.task_idx = task_idx
        old_fc = self.network.classifier
        in_features = old_fc.in_features
        out_features = old_fc.out_features

        total_classes = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num']
        new_fc = nn.Linear(in_features, total_classes)
        # Keep the rows already learned for the old classes; only the rows of
        # the new classes keep their fresh initialisation.
        new_fc.weight.data[:out_features] = old_fc.weight.data
        new_fc.bias.data[:out_features] = old_fc.bias.data
        self.network.classifier = new_fc
        self.network.to(self.device)

    def observe(self, data):
        """Run one training forward pass and compute the loss.

        On the first task the loss is plain cross-entropy.  On later tasks
        cross-entropy is computed only over the logits of the newly added
        classes (labels are shifted by the number of old classes) and the
        EWC penalty, scaled by ``lamda``, is added.

        Args:
            data (dict): Batch with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: ``(pred, accuracy, loss)`` where ``pred`` holds the argmax
            class per sample over all logits, ``accuracy`` is the fraction of
            correct predictions in the batch and ``loss`` is the loss tensor.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logit = self.network(x)

        if self.task_idx == 0:
            loss = F.cross_entropy(logit, y)
        else:
            # NOTE(review): assumes every label in the batch belongs to the
            # current task (label >= old_classes) - confirm against the loader.
            old_classes = self.network.classifier.out_features - self.kwargs['inc_cls_num']
            loss = F.cross_entropy(logit[:, old_classes:], y - old_classes)
            loss += self.lamda * self.compute_ewc()

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Store the reference parameters and update the Fisher estimate.

        Args:
            task_idx (int): The index of the current task.
            buffer: Buffer object used in previous tasks.
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.
            test_loaders (list of DataLoader): List of dataloaders for the test datasets.

        Code Reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        self.ref_param = {
            n: p.clone().detach()
            for n, p in self.network.named_parameters()
            if p.requires_grad
        }

        new_fisher = self.getFisher(train_loader)

        # After the first task there is no earlier estimate to merge with;
        # blending with the zero placeholder from ``__init__`` would only
        # shrink the estimate, so it is stored unchanged.
        if task_idx > 0:
            # alpha is the share of previously known classes among all classes.
            alpha = 1 - self.kwargs['inc_cls_num'] / self.network.classifier.out_features
            for n, old in self.fisher.items():
                if n not in new_fisher:
                    continue
                # Only the leading rows exist in the old estimate (the
                # classifier has grown since it was computed).
                rows = len(old)
                new_fisher[n][:rows] = alpha * old + (1 - alpha) * new_fisher[n][:rows]

        self.fisher = new_fisher

    def inference(self, data):
        """Predict the classes of a batch and report the batch accuracy.

        Args:
            data (dict): Batch with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: ``(pred, accuracy)`` where ``pred`` holds the argmax class
            per sample and ``accuracy`` is the fraction of correct predictions.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        # Only the argmax and a Python float leave this method, so no
        # autograd graph is needed.
        with torch.no_grad():
            logit = self.network(x)
            pred = torch.argmax(logit, dim=1)

        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def getFisher(self, train_loader):
        """Estimate the diagonal Fisher Information of the trainable parameters.

        For every batch the squared gradient of the cross-entropy loss is
        accumulated, weighted by the number of samples in the batch, and the
        sum is finally divided by the total number of samples processed.

        Args:
            train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset.

        Returns:
            dict: Maps each trainable parameter name to a tensor of the same
            shape holding its Fisher estimate.  All zeros if the loader
            yields no samples.

        Code Reference:
            https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py
        """
        fisher = {
            n: torch.zeros_like(p).to(self.device)
            for n, p in self.network.named_parameters()
            if p.requires_grad
        }

        self.network.train()

        num_samples = 0
        for data in train_loader:
            x = data['image'].to(self.device)
            y = data['label'].to(self.device)

            loss = F.cross_entropy(self.network(x), y)

            self.network.zero_grad()
            loss.backward()

            # Count what was actually seen: the last batch may be smaller
            # than ``train_loader.batch_size``, which may itself be ``None``.
            batch_size = len(y)
            num_samples += batch_size
            for n, p in self.network.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach().pow(2) * batch_size

        if num_samples == 0:
            # Nothing was processed; avoid dividing by zero.
            return fisher

        return {n: f / num_samples for n, f in fisher.items()}

    def compute_ewc(self):
        """Compute the Elastic Weight Consolidation (EWC) penalty.

        For every parameter with a stored Fisher estimate, the squared
        distance to its reference value is weighted by the Fisher estimate,
        summed and halved.  Only the leading rows that existed when the
        reference was stored are compared.

        References:
            - https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py
            - https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py

        Returns:
            torch.Tensor: The computed EWC loss (the integer ``0`` if no
            parameter has a stored Fisher estimate).
        """
        loss = 0
        for n, p in self.network.named_parameters():
            if n in self.fisher:
                ref = self.ref_param[n]
                drift = p[:len(ref)] - ref
                loss += torch.sum(self.fisher[n] * drift.pow(2)) / 2
        return loss

    def get_parameters(self, config):
        """Return the optimizer parameter groups: all network parameters.

        Args:
            config: Unused here.

        Returns:
            list: A single parameter group covering ``self.network``.
        """
        return [{"params": self.network.parameters()}]