File size: 15,698 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""

@inproceedings{DBLP:conf/iccv/ShiY23,

  title        = {Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning},

  author       = {Shi, Wuxuan and Ye, Mang},

  booktitle    = {2023 IEEE/CVF International Conference on Computer Vision (ICCV)},

  pages        = {1772-1781},

  publisher    = {Computer Vision Foundation / {IEEE}},

  year         = {2023}

}



https://openaccess.thecvf.com/content/ICCV2023/papers/Shi_Prototype_Reminiscence_and_Augmented_Asymmetric_Knowledge_Aggregation_for_Non-Exemplar_Class-Incremental_ICCV_2023_paper.pdf



Adapted from https://github.com/ShiWuxuan/PRAKA

"""

from torch.nn import functional as F
import os
import numpy as np
import torch
import torch.nn as nn
import math
import copy
from core.model import Finetune

class joint_network(nn.Module):
    """Backbone plus two linear heads used by PRAKA.

    The ``fc`` ("joint") head classifies (class, rotation) pairs produced by
    the 4-way rotation self-supervision, hence ``numclass * 4`` outputs; the
    ``classifier`` head predicts the plain class label.

    Code Reference:
    https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
    """

    def __init__(self, numclass, feature_extractor):
        """Build the network for the initial task.

        Parameters:
        - numclass (int): number of classes in the first task.
        - feature_extractor (nn.Module): backbone; assumed to emit 512-d
          features (the head input size is hard-coded to 512).
        """
        super().__init__()
        self.feature = feature_extractor
        # One output per (class, rotation) pair: 4 rotations per class.
        self.fc = nn.Linear(512, numclass * 4, bias=True)
        # Plain class-label head.
        self.classifier = nn.Linear(512, numclass, bias=True)

    def forward(self, input):
        """Return plain-class logits for a batch of inputs."""
        return self.classifier(self.feature(input))

    def Incremental_learning(self, numclass):
        """Grow both heads to cover ``numclass`` total classes.

        Replaces ``fc`` (new width ``numclass * 4``) and ``classifier``
        (new width ``numclass``) with wider layers whose leading rows copy
        the previously learned weights and biases, so knowledge about old
        classes is retained while new rows start from fresh initialization.

        Parameters:
        - numclass (int): total class count after the current task.
        """
        self.fc = self._grown_copy(self.fc, numclass * 4)
        self.classifier = self._grown_copy(self.classifier, numclass)

    @staticmethod
    def _grown_copy(old_layer, new_out):
        """Return a wider ``nn.Linear`` whose first rows replicate ``old_layer``."""
        keep = old_layer.out_features
        grown = nn.Linear(old_layer.in_features, new_out, bias=True)
        grown.weight.data[:keep] = old_layer.weight.data[:keep]
        grown.bias.data[:keep] = old_layer.bias.data[:keep]
        return grown

    def feature_extractor(self, inputs):
        """Return backbone features with no head applied."""
        return self.feature(inputs)

class PRAKA(nn.Module):
    """PRAKA trainer (Shi & Ye, ICCV 2023): non-exemplar class-incremental
    learning via prototype reminiscence and augmented asymmetric knowledge
    aggregation, using 4-way rotation self-supervision.

    Adapted from https://github.com/ShiWuxuan/PRAKA
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Set up the joint network and incremental-learning bookkeeping.

        Parameters:
        - backbone: feature-extractor module; assumed to emit 512-d features
          (joint_network hard-codes 512) — TODO confirm for other backbones.
        - feat_dim: feature dimensionality (accepted for interface parity;
          not read in this class).
        - num_class: total class count (accepted for interface parity;
          not read in this class).
        - kwargs: must provide 'device', 'init_cls_num', 'inc_cls_num' and
          the loss hyper-parameters 'temp', 'protoAug_weight', 'kd_weight'.
        """
        #super().__init__(backbone, feat_dim, num_class, **kwargs)
        super().__init__()
        self.device = kwargs['device']
        self.kwargs = kwargs
        # Input spatial size; images are assumed to be 3 x 32 x 32
        # (CIFAR-style) — see the view() in observe().
        self.size = 32
        # Initialize the feature extractor with a custom ResNet18 structure.
        encoder = backbone
        self.model = joint_network(kwargs["init_cls_num"], encoder)
        self.radius = 0           # kept for parity with the reference code; unused here
        self.prototype = None     # dict: class label -> mean feature vector (numpy)
        self.numsamples = None    # dict: class label -> number of samples seen
        self.numclass = kwargs["init_cls_num"]   # classes covered after the current task
        self.task_size = kwargs["inc_cls_num"]   # classes added per incremental task
        self.old_model = None     # frozen deep copy of the model from the previous task
        # save the model and its corresponding task_id
        self.task_idx = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for a new task: grow the heads (incremental tasks only)
        and move the model to the target device.

        Note: self.numclass was already advanced in after_task, so at this
        point it equals the total class count for the upcoming task.
        """
        self.task_idx = task_idx
        if task_idx > 0:
            self.model.Incremental_learning(self.numclass)
        self.model.to(self.device)
        
    def observe(self, data):
        """Run one training step on a batch.

        Parameters:
        - data (dict): 'image' -> tensor of input images (B, 3, 32, 32),
          'label' -> tensor of ground-truth labels (B,).

        Returns:
        - preds: predicted class indices for the original (unrotated) images
        - acc (float): accuracy on this batch
        - loss: scalar training loss for this batch
        """
        images, labels = data['image'].to(self.device), data['label'].to(self.device)

        # Generate four times the number of images by rotating each image
        # 0°, 90°, 180°, and 270°; the 4 rotated copies of each sample end
        # up interleaved (stride 4) after the view below.
        images = torch.stack([torch.rot90(images, k, (2, 3)) for k in range(4)], 1)
        images = images.view(-1, 3, self.size, self.size)
        # Joint label for (class c, rotation k) is c * 4 + k, matching the
        # fc head's numclass * 4 outputs.
        joint_labels = torch.stack([labels * 4 + k for k in range(4)], 1).view(-1)
        if self.task_idx == 0:
            old_class = 0
        else:
            # Number of classes learned before the current task.
            old_class = self.kwargs['init_cls_num'] + self.kwargs['inc_cls_num'] * (self.task_idx - 1)
        # Compute loss and predictions for a batch
        loss, single_preds = self._compute_loss(images, joint_labels, labels, old_class)

        preds = torch.argmax(single_preds, dim=-1)
        return preds, (preds == labels).sum().item() / len(labels), loss

    def inference(self, data):
        """Classify a batch of test samples with the plain classifier head.

        Parameters:
        - data (dict): 'image' -> tensor of input images,
          'label' -> tensor of ground-truth labels.

        Returns:
        - preds: predicted class indices
        - acc (float): accuracy on this batch
        """

        imgs, labels = data['image'].to(self.device), data['label'].to(self.device)

        preds = torch.argmax(self.model(imgs), dim=-1)

        return preds, (preds == labels).sum().item() / len(labels)

    def _compute_loss(self, imgs, joint_labels, labels, old_class=0):
        """Compute the PRAKA training loss for one (rotation-augmented) batch.

        Parameters:
        - imgs: rotation-augmented images; 4 rotated copies per original
          sample, interleaved with stride 4.
        - joint_labels: (class * 4 + rotation) labels for the joint head.
        - labels: ground-truth labels for the original samples.
        - old_class (int): number of previously learned classes; 0 means
          first task (no distillation / prototype terms).

        Returns:
        - loss: scalar total loss
        - single_preds: classifier logits for the unrotated images only

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        """
        # Feature extraction
        feature = self.model.feature(imgs)

        # Classification predictions; [::4] picks the 0°-rotation (original)
        # copy of each sample out of the interleaved batch.
        joint_preds = self.model.fc(feature)
        single_preds = self.model.classifier(feature)[::4]
        joint_preds, joint_labels, single_preds, labels = joint_preds.to(self.device), joint_labels.to(self.device), single_preds.to(self.device), labels.to(self.device)
        joint_loss = nn.CrossEntropyLoss()(joint_preds/self.kwargs["temp"], joint_labels)
        single_loss = nn.CrossEntropyLoss()(single_preds/self.kwargs["temp"], labels)

        # Aggregate joint-head logits over the 4 rotations:
        # joint_preds[i::4, i::4] takes the samples rotated by angle i and the
        # logit columns belonging to rotation i, so the average yields
        # per-class logits aligned with single_preds.
        agg_preds = 0
        for i in range(4):
            agg_preds = agg_preds + joint_preds[i::4, i::4] / 4
        # Self-distillation: pull the plain classifier towards the (detached)
        # aggregated joint-head predictions.
        distillation_loss = F.kl_div(F.log_softmax(single_preds, 1),
                                    F.softmax(agg_preds.detach(), 1),
                                    reduction='batchmean')
        if old_class == 0:
            return joint_loss + single_loss + distillation_loss, single_preds
        else:
            # Feature-level distillation against the frozen previous-task model.
            feature_old = self.old_model.feature(imgs)

            loss_kd = torch.dist(feature, feature_old, 2)

            # Prototype augmentation: synthesize features for old classes by
            # mixing stored class prototypes with current-batch features.
            proto_aug = []
            proto_aug_label = []
            old_class_list = list(self.prototype.keys())
            for _ in range(feature.shape[0] // 4):  # batch_size = feature.shape[0] // 4
                i = np.random.randint(0, feature.shape[0])
                np.random.shuffle(old_class_list)
                # Mixing coefficient from Beta(0.5, 0.5); values above 0.6
                # are damped to keep the prototype dominant.
                lam = np.random.beta(0.5, 0.5)
                if lam > 0.6:
                    lam = lam * 0.6

                if np.random.random() >= 0.5:
                    # Extrapolate away from the current feature ("reminiscence").
                    temp = (1 + lam) * self.prototype[old_class_list[0]] - lam * feature.detach().cpu().numpy()[i]
                else:
                    # Interpolate towards the current feature.
                    temp = (1 - lam) * self.prototype[old_class_list[0]] + lam * feature.detach().cpu().numpy()[i]

                # Append the generated augmented features and corresponding labels to proto_aug and proto_aug_label
                proto_aug.append(temp)
                proto_aug_label.append(old_class_list[0])

            proto_aug = torch.from_numpy(np.float32(np.asarray(proto_aug))).float().to(self.device)
            proto_aug_label = torch.from_numpy(np.asarray(proto_aug_label)).to(self.device)
            aug_preds = self.model.classifier(proto_aug)
            joint_aug_preds = self.model.fc(proto_aug)
            # Columns [::4] of the joint head correspond to the 0° rotation,
            # so agg_preds here are per-class logits for the synthetic features.
            agg_preds = joint_aug_preds[:, ::4]
            aug_distillation_loss = F.kl_div(F.log_softmax(aug_preds, 1),
                                            F.softmax(agg_preds.detach(), 1),
                                            reduction='batchmean')
            # Calculate the weighted sum of cross-entropy loss and distillation loss for augmented data
            # (joint-head targets are proto_aug_label * 4, i.e. rotation 0).
            loss_protoAug = nn.CrossEntropyLoss()(aug_preds/self.kwargs["temp"], proto_aug_label) + nn.CrossEntropyLoss()(joint_aug_preds/self.kwargs["temp"], proto_aug_label*4) + aug_distillation_loss
            return joint_loss + single_loss + distillation_loss + self.kwargs["protoAug_weight"]*loss_protoAug + self.kwargs["kd_weight"]*loss_kd, single_preds

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Finalize the task: save class prototypes, advance the class
        counter, and freeze a deep copy of the model as the old model for
        distillation in the next task.

        Parameters:
        - task_idx (int): index of the task just finished (note: protoSave
          is called with self.task_idx, set in before_task).
        - buffer: sample buffer (unused here — PRAKA is non-exemplar).
        - train_loader (DataLoader): training data of the current task,
          used to compute prototypes.
        - test_loaders: test DataLoaders (unused here).
        """
        # Save the prototype
        self.protoSave(self.model, train_loader, self.task_idx)
        self.numclass += self.task_size

        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

    def protoSave(self, model, loader, current_task):
        """Compute and store per-class mean-feature prototypes for this task.

        Extracts features for the task's training data and stores, per class:
        the mean feature vector (self.prototype) and the sample count
        (self.numsamples). On the first task the dictionaries are created;
        afterwards they are updated with the new classes.

        Parameters:
        - model: model used for feature extraction (its .feature backbone).
        - loader: DataLoader over the current task's training set.
        - current_task (int): index of the current task.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        """

        features = []
        labels = []
        model.eval()
        # Feature extraction
        with torch.no_grad():
            for i, batch in enumerate(loader):
                images, target = batch['image'], batch['label']
                feature = model.feature(images.to(self.device))
                # Only full batches are kept so the arrays below stack into a
                # uniform (num_batches, batch_size, dim) shape; the trailing
                # partial batch is dropped from prototype estimation.
                if feature.shape[0] == loader.batch_size:
                    labels.append(target.numpy())
                    features.append(feature.cpu().numpy())

        labels_set = np.unique(labels)
        labels = np.array(labels)
        labels = np.reshape(labels, labels.shape[0] * labels.shape[1])
        features = np.array(features)
        features = np.reshape(features, (features.shape[0] * features.shape[1], features.shape[2]))

        # Compute class prototypes
        prototype = {}
        class_label = []
        numsamples = {}

        for item in labels_set:
            index = np.where(item == labels)[0]
            class_label.append(item)
            feature_classwise = features[index]
            # Prototype = mean feature of all (kept) samples of this class.
            prototype[item] = np.mean(feature_classwise, axis=0)
            # Record the number of samples for each class.
            numsamples[item] = feature_classwise.shape[0]
        if current_task == 0:
            self.prototype = prototype
            self.class_label = class_label
            self.numsamples = numsamples
        else:
            # Merge new-class entries; new class labels are prepended to the
            # running label list.
            self.prototype.update(prototype)
            self.class_label = np.concatenate((class_label, self.class_label), axis=0)
            self.numsamples.update(numsamples)

    def get_parameters(self, config):
        """Return the trainable parameters of the current model."""
        return self.model.parameters()