File size: 4,076 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
@misc{caccia2022new,
    title={New Insights on Reducing Abrupt Representation Change in Online Continual Learning}, 
    author={Lucas Caccia and Rahaf Aljundi and Nader Asadi and Tinne Tuytelaars and Joelle Pineau and Eugene Belilovsky},
    year={2022},
    eprint={2104.05025},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Adapted from https://github.com/pclucas14/AML
"""

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class distLinear(nn.Module):
    """Cosine-similarity classifier head.

    Produces class scores as a fixed temperature times the cosine
    similarity between L2-normalized input features and L2-normalized
    class weight vectors (instead of a plain dot product).

    Args:
        indim: input feature dimension.
        outdim: number of output classes.
        weight: optional ``(outdim, indim)`` tensor used to initialize
            the class weight matrix.
    """

    def __init__(self, indim, outdim, weight=None):
        super().__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        if weight is not None:
            # BUG FIX: the original wrapped `weight` in `Variable(...)`, but
            # `Variable` was never imported (NameError at runtime) and
            # torch.autograd.Variable is deprecated anyway — copy the
            # provided tensor into the weight directly.
            self.L.weight.data = weight.clone().detach()

        # fixed temperature applied to the cosine scores
        self.scale_factor = 10

    def forward(self, x):
        """Return ``(batch, outdim)`` scores = scale * cos(x, class weights)."""
        # row-wise L2-normalize the inputs; epsilon guards zero vectors
        x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        x_normalized = x.div(x_norm + 0.00001)

        # L2-normalize each class weight vector the same way
        L_norm = torch.norm(self.L.weight, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
        cos_dist = torch.mm(x_normalized, self.L.weight.div(L_norm + 0.00001).transpose(0, 1))

        scores = self.scale_factor * cos_dist

        return scores

class Model(nn.Module):
    """Feature backbone followed by a cosine-similarity classifier head."""

    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        # head dimension is taken from the backbone's declared output size
        self.classifier = distLinear(backbone.out_dim, num_classes)

    def forward(self, data):
        features = self.backbone(data)
        return self.classifier(features)
    
class ERACE(nn.Module):
    """ER-ACE: Experience Replay with Asymmetric Cross-Entropy.

    From Caccia et al., "New Insights on Reducing Abrupt Representation
    Change in Online Continual Learning" (arXiv:2104.05025). Incoming-batch
    logits are masked so new samples are only contrasted against classes
    not yet seen, while rehearsal samples use the full cross-entropy —
    reducing abrupt drift of old-class representations.
    """

    def __init__(self, backbone, device, **kwargs):
        """Wrap `backbone` in a cosine-classifier model and read config.

        Required kwargs: ``num_classes``, ``init_cls_num``, ``inc_cls_num``,
        ``use_augs``, ``task_free`` (must be truthy — ER-ACE is task-free).
        Raises ValueError if ``task_free`` is falsy.
        """
        super().__init__()

        self.model = Model(backbone, kwargs['num_classes'])
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.use_augs = kwargs['use_augs']
        self.device = device
        # number of classes belonging to already-completed tasks;
        # advanced in after_task()
        self.seen_so_far = 0

        self.task_free = kwargs['task_free']
        # FIX: was `assert self.task_free, ...` — asserts are stripped under
        # `python -O`, so a config error must raise explicitly instead.
        if not self.task_free:
            raise ValueError('ER-ACE must be task free')

        # arguments forwarded to buffer.sample_random() on each rehearsal draw
        self.sample_kwargs = {
            'amt':          10,
            'exclude_task': None
        }

        self.model.to(self.device)

    def observe(self, data):
        """One training step on an incoming batch (plus a rehearsal batch).

        Returns (pred, acc, loss) where `pred` covers only the incoming
        data and `acc`/`loss` aggregate incoming + rehearsal samples.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        # stash the incoming batch so add_reservoir() can push it later
        self.inc_data = {'x': x, 'y': y, 't': self.cur_task_idx}

        logits = self.model(x)

        # asymmetric CE: for incoming data, mask out the logits of classes
        # seen in previous tasks so they are not pushed down by new samples
        mask = torch.zeros_like(logits)
        mask[:, self.seen_so_far:] = 1

        if self.cur_task_idx > 0 or self.task_free:
            logits = logits.masked_fill(mask == 0, -1e9)

        loss = F.cross_entropy(logits, y)
        pred = logits.max(1)[1]
        correct_count = (pred == y).sum().item()
        total_count = y.shape[0]

        # rehearsal branch: replayed samples use the UNmasked full CE
        if len(self.buffer) > 0 and (self.task_free or self.cur_task_idx > 0):
            re_data = self.buffer.sample_random(**self.sample_kwargs)

            re_logits = self.model(re_data['x'])
            loss += F.cross_entropy(re_logits, re_data['y'])
            re_pred = re_logits.max(1)[1]
            correct_count += (re_pred == re_data['y']).sum().item()
            total_count += re_data['y'].shape[0]

        acc = correct_count / total_count

        # only return output of incoming data, not including output of rehearsal data
        return pred, acc, loss

    def inference(self, data):
        """Evaluate one batch; returns (pred, acc) with no masking."""
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        logits = self.model(x)
        pred   = logits.max(1)[1]
        correct_count = pred.eq(y).sum().item()
        acc = correct_count / y.size(0)

        return pred, acc

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Bind the replay buffer and task index before training a task."""
        if not self.use_augs:
            # fall back to the (deterministic) test-time transforms
            train_loader.dataset.trfms = test_loaders[0].dataset.trfms

        self.buffer = buffer
        self.buffer.device = self.device

        self.cur_task_idx = task_idx

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        # classes seen = first-task classes + increment per completed task
        self.seen_so_far = self.init_cls_num + self.inc_cls_num * task_idx

    def add_reservoir(self):
        """Push the most recent incoming batch into the reservoir buffer."""
        self.buffer.add_reservoir(self.inc_data)

    def get_parameters(self, config):
        """Return trainable parameters for the optimizer (config unused)."""
        return self.model.parameters()