"""
@misc{caccia2022new,
title={New Insights on Reducing Abrupt Representation Change in Online Continual Learning},
author={Lucas Caccia and Rahaf Aljundi and Nader Asadi and Tinne Tuytelaars and Joelle Pineau and Eugene Belilovsky},
year={2022},
eprint={2104.05025},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Adapted from https://github.com/pclucas14/AML
"""
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class distLinear(nn.Module):
    """Cosine-similarity classification head.

    Each output score is the cosine similarity between an input row and one
    row of a bias-free linear layer's weight matrix, multiplied by
    ``scale_factor``.

    Args:
        indim: Size of each input feature vector.
        outdim: Number of output scores (rows of the weight matrix).
        weight: Optional tensor installed as the initial weight of the
            linear layer. When ``None`` the default ``nn.Linear``
            initialisation is kept.
    """

    def __init__(self, indim, outdim, weight=None):
        super().__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        if weight is not None:
            # The previous code wrapped the tensor in ``Variable``, a name
            # that the imports of this module do not provide, so supplying a
            # weight raised NameError. A plain (detached) tensor is all that
            # ``.data`` assignment needs.
            self.L.weight.data = weight.detach()
        self.scale_factor = 10

    def forward(self, x):
        """Return scaled cosine similarities between ``x`` and the weights.

        Args:
            x: 2-D tensor whose rows are feature vectors (``torch.mm`` below
                requires a matrix).

        Returns:
            Tensor with one row per input row and one column per weight row.
        """
        # Small constant added to the norms to avoid division by zero.
        eps = 0.00001

        x_norm = torch.norm(x, p=2, dim=1, keepdim=True)
        x_normalized = x.div(x_norm + eps)

        weight = self.L.weight
        w_norm = torch.norm(weight, p=2, dim=1, keepdim=True)
        w_normalized = weight.div(w_norm + eps)

        cos_dist = torch.mm(x_normalized, w_normalized.transpose(0, 1))
        return self.scale_factor * cos_dist
class Model(nn.Module):
    """Backbone feature extractor followed by a ``distLinear`` classifier.

    Args:
        backbone: Module producing features; it must expose an ``out_dim``
            attribute, which is used as the classifier's input size.
        num_classes: Number of scores the classifier produces.
    """

    def __init__(self, backbone, num_classes):
        super().__init__()
        feature_dim = backbone.out_dim
        self.backbone = backbone
        self.classifier = distLinear(feature_dim, num_classes)

    def forward(self, data):
        """Run ``data`` through the backbone, then through the classifier."""
        features = self.backbone(data)
        scores = self.classifier(features)
        return scores
class ERACE(nn.Module):
    """ER-ACE learner: replay with a masked loss on the incoming batch.

    The incoming batch is scored with the logits of the first
    ``seen_so_far`` classes suppressed, while samples drawn from the replay
    buffer are scored with all logits.

    ``before_task`` must be called before ``observe``: it is what installs
    ``self.buffer`` and ``self.cur_task_idx``.

    Args:
        backbone: Feature extractor handed to ``Model``.
        device: Device the model and the input batches are moved to.
        **kwargs: Must contain ``num_classes``, ``init_cls_num``,
            ``inc_cls_num``, ``use_augs`` and ``task_free``.

    Raises:
        AssertionError: If ``task_free`` is falsy.
        KeyError: If one of the required keyword arguments is missing.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()
        self.model = Model(backbone, kwargs['num_classes'])
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.use_augs = kwargs['use_augs']
        self.device = device
        # Number of leading class indices masked out for incoming data.
        self.seen_so_far = 0
        self.task_free = kwargs['task_free']
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not self.task_free:
            raise AssertionError('ER-ACE must be task free')
        # Keyword arguments forwarded to ``buffer.sample_random``.
        # NOTE(review): 'amt' is presumably the replay batch size - confirm
        # against the buffer implementation.
        self.sample_kwargs = {
            'amt': 10,
            'exclude_task': None,
        }
        self.model.to(self.device)

    def observe(self, data):
        """Compute the training loss for one incoming batch.

        Args:
            data: Mapping with an ``'image'`` tensor and a ``'label'`` tensor.

        Returns:
            Tuple ``(pred, acc, loss)``. ``pred`` holds the predictions for
            the incoming batch only; ``acc`` and ``loss`` also include the
            replayed samples when the buffer is non-empty.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        # Kept so that ``add_reservoir`` can store this batch afterwards.
        self.inc_data = {'x': x, 'y': y, 't': self.cur_task_idx}

        replay_enabled = self.task_free or self.cur_task_idx > 0

        logits = self.model(x)
        if replay_enabled:
            # Suppress the logits of the first ``seen_so_far`` classes so the
            # incoming batch is only scored against the remaining ones.
            seen_mask = torch.zeros_like(logits, dtype=torch.bool)
            seen_mask[:, :self.seen_so_far] = True
            logits = logits.masked_fill(seen_mask, -1e9)

        loss = F.cross_entropy(logits, y)
        pred = logits.max(1)[1]
        correct_count = (pred == y).sum().item()
        total_count = y.shape[0]

        if len(self.buffer) > 0 and replay_enabled:
            re_data = self.buffer.sample_random(**self.sample_kwargs)
            re_x, re_y = re_data['x'], re_data['y']
            # Replayed samples are scored with the unmasked logits.
            re_logits = self.model(re_x)
            loss = loss + F.cross_entropy(re_logits, re_y)
            re_pred = re_logits.max(1)[1]
            correct_count += (re_pred == re_y).sum().item()
            total_count += re_y.shape[0]

        acc = correct_count / total_count
        # only return output of incoming data, not including output of
        # rehearsal data
        return pred, acc, loss

    def inference(self, data):
        """Predict labels for a batch and report its accuracy.

        Args:
            data: Mapping with an ``'image'`` tensor and a ``'label'`` tensor.

        Returns:
            Tuple ``(pred, acc)`` with the predicted class indices and the
            fraction of them equal to the labels.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)
        # Only argmax indices and a float leave this method, so there is no
        # reason to record an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self.model(x)
            pred = logits.max(1)[1]
        correct_count = pred.eq(y).sum().item()
        acc = correct_count / y.size(0)
        return pred, acc

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Install the buffer and task index before a task starts.

        When ``use_augs`` is falsy, the training dataset's ``trfms`` is
        replaced with that of the first test loader's dataset.
        """
        if not self.use_augs:
            train_loader.dataset.trfms = test_loaders[0].dataset.trfms
        self.buffer = buffer
        self.buffer.device = self.device
        self.cur_task_idx = task_idx

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Update ``seen_so_far`` once task ``task_idx`` is finished."""
        self.seen_so_far = self.init_cls_num + self.inc_cls_num * task_idx

    def add_reservoir(self):
        """Hand the most recently observed batch to the buffer."""
        self.buffer.add_reservoir(self.inc_data)

    def get_parameters(self, config):
        """Return the parameters of the wrapped model; ``config`` is unused."""
        return self.model.parameters()
|