File size: 4,373 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
@inproceedings{yu2024boosting,
  title={Boosting continual learning of vision-language models via mixture-of-experts adapters},
  author={Yu, Jiazuo and Zhuge, Yunzhi and Zhang, Lu and Hu, Ping and Wang, Dong and Lu, Huchuan and He, You},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={23219--23230},
  year={2024}
}

Adapted from https://github.com/JiazuoYu/MoE-Adapters4CL
"""

import math
import torch
import torch.nn as nn
import numpy as np

from torch import optim
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from tqdm import tqdm

from .backbone.clip import tokenize, CLIP
from .backbone.vit import ViTZoo

# Module-level aliases for the supported backbone classes; the model below
# dispatches on them with isinstance().
VIT, CLIP = ViTZoo, CLIP

class MOE_ADAPTER4CL(nn.Module):
    """Continual-learning wrapper around a CLIP or ViT backbone with MoE adapters.

    Only backbone parameters whose names contain ``adaptmlp``, ``router`` or
    ``noise`` stay trainable; every other backbone parameter is frozen.
    With a ``VIT`` backbone, logits come from one linear head per task
    (``classifier_pool``). With a ``CLIP`` backbone, logits come from the
    backbone itself, given tokenized class-name prompts.
    """

    def __init__(self, backbone, device, **kwargs):
        """
        Args:
            backbone: the wrapped network; ``observe``/``inference`` support
                ``CLIP`` and ``VIT`` instances only.
            device: device that input batches and text tokens are moved to.
            **kwargs: must provide ``init_cls_num``, ``inc_cls_num``,
                ``prompt_template`` (a ``str.format`` template filled with one
                class name), ``embd_dim`` and ``task_num``. ``label_smoothing``
                is optional and defaults to 0.0 (no smoothing).
        """
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        # 0.0 is F.cross_entropy's own default, so omitting the key means no smoothing.
        self.label_smoothing = kwargs.get('label_smoothing', 0.0)

        # Number of classes belonging to tasks before the current one.
        self._known_classes = 0
        # Index of the current task; set by before_task.
        self._cur_task_id = -1

        self.accm_class_names = []   # class names of every task seen so far, in task order
        self.curr_class_names = []   # class names of the current task only
        self.accm_text_tokens = None
        self.curr_text_tokens = None

        self.prompt_template = kwargs['prompt_template']

        self._network = backbone

        # One linear head per task: the first covers init_cls_num classes, each
        # later one inc_cls_num classes. Only used with a VIT backbone.
        self.classifier_pool = nn.ModuleList(
            [nn.Linear(kwargs['embd_dim'], kwargs['init_cls_num'], bias=True)]
            + [nn.Linear(kwargs['embd_dim'], kwargs['inc_cls_num'], bias=True)
               for _ in range(kwargs['task_num'] - 1)]
        )

        # Freeze the backbone except for adapter, router and noise parameters.
        for name, param in self._network.named_parameters():
            if 'adaptmlp' not in name and 'router' not in name and 'noise' not in name:
                param.requires_grad = False

    def observe(self, data):
        '''
        Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss.

        Labels are shifted down by ``_known_classes`` so they index the current
        task's classes only, matching logits that cover just those classes.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.

        Returns:
            (preds, acc, loss): task-local predicted labels, batch accuracy as a
            Python float, and the cross-entropy loss tensor.
        '''
        x = data['image'].to(self.device)
        y = data['label'].to(self.device) - self._known_classes

        if isinstance(self._network, CLIP):
            _, _, logits_per_img, _ = self._network(x, self.curr_text_tokens)
        elif isinstance(self._network, VIT):
            features = self._network(x)
            # Only the current task's head is scored against task-local labels.
            logits_per_img = self.classifier_pool[self._cur_task_id](features)
        else:
            raise NotImplementedError

        loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc, loss

    def inference(self, data):
        """Predict over every class seen so far.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors; labels are used
                as-is (not shifted by ``_known_classes``).

        Returns:
            (preds, acc): predicted labels over all accumulated classes and the
            batch accuracy.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device)

        if isinstance(self._network, CLIP):
            _, _, logits_per_img, _ = self._network(x, self.accm_text_tokens)
        elif isinstance(self._network, VIT):
            features = self._network(x)
            # Heads of tasks 0..current, concatenated in task order; this assumes
            # global labels are assigned task by task, as the _known_classes
            # offsets in before_task imply.
            logits_per_img = torch.cat(
                [head(features) for head in self.classifier_pool[:self._cur_task_id + 1]],
                dim=1,
            )
        else:
            raise NotImplementedError

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for task ``task_idx``: update the label offset and text prompts.

        ``buffer`` and ``test_loaders`` are accepted for interface compatibility
        but are not used here.
        """
        self._cur_task_id = task_idx
        # Offset of the new task's first label; assumes tasks are visited in order.
        if task_idx == 1:
            self._known_classes = self.init_cls_num
        elif task_idx > 1:
            self._known_classes += self.inc_cls_num

        self.curr_class_names = train_loader.dataset.get_class_names()
        self.accm_class_names += self.curr_class_names

        self.curr_text_tokens = tokenize([self.prompt_template.format(c) for c in self.curr_class_names]).to(self.device)
        self.accm_text_tokens = tokenize([self.prompt_template.format(c) for c in self.accm_class_names]).to(self.device)

    def get_parameters(self, config):
        """Return the parameters to hand to an optimizer.

        This covers the backbone *and* the per-task ``classifier_pool`` heads.
        The heads are registered on this wrapper rather than on the backbone.
        An optimizer built from ``self._network.parameters()`` alone therefore
        never saw them, and with a VIT backbone they produce every logit.
        Frozen backbone parameters are still included, as before. They have
        ``requires_grad=False``, so they never receive a gradient.

        Args:
            config: unused; kept for interface compatibility.
        """
        return self.parameters()