"""
@inproceedings{yu2024boosting,
  title={Boosting continual learning of vision-language models via mixture-of-experts adapters},
  author={Yu, Jiazuo and Zhuge, Yunzhi and Zhang, Lu and Hu, Ping and Wang, Dong and Lu, Huchuan and He, You},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={23219--23230},
  year={2024}
}

Adapted from https://github.com/JiazuoYu/MoE-Adapters4CL
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
| from torch import optim |
| from torch.nn import functional as F |
| from torch.nn.parameter import Parameter |
| from tqdm import tqdm |
|
|
| from .backbone.clip import tokenize, CLIP |
| from .backbone.vit import ViTZoo |
|
|
# Local aliases used by the isinstance() dispatch in observe()/inference().
# NOTE(review): `CLIP = CLIP` rebinds the imported name to itself — a no-op
# kept only for symmetry with the VIT alias.
VIT = ViTZoo
CLIP = CLIP
|
|
class MOE_ADAPTER4CL(nn.Module):
    """Continual-learning wrapper around a CLIP or ViT backbone with
    Mixture-of-Experts adapters (Yu et al., CVPR 2024).

    Only the backbone parameters whose names contain ``adaptmlp``,
    ``router`` or ``noise`` stay trainable; everything else is frozen in
    ``__init__``. For a ViT backbone, one linear classification head per
    task is kept in ``classifier_pool``; for a CLIP backbone, logits come
    from image/text similarity against tokenized class-name prompts.
    """

    def __init__(self, backbone, device, **kwargs):
        """
        Args:
            backbone: a ``CLIP`` or ``ViTZoo`` feature extractor.
            device: torch device that inputs, labels and text tokens are
                moved to.
            **kwargs: must provide ``init_cls_num``, ``inc_cls_num``,
                ``label_smoothing``, ``prompt_template`` (a ``str.format``
                template with one ``{}`` slot for the class name),
                ``embd_dim`` and ``task_num``.
        """
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.label_smoothing = kwargs['label_smoothing']

        # Classes seen in all *previous* tasks; training labels are shifted
        # down by this amount so each task is locally 0-indexed (see observe).
        self._known_classes = 0
        self._cur_task_id = -1

        # Class names and their tokenized CLIP prompts, kept both for the
        # current task only ("curr") and accumulated over all tasks ("accm").
        self.accm_class_names = []
        self.curr_class_names = []
        self.accm_text_tokens = None
        self.curr_text_tokens = None

        self.prompt_template = kwargs['prompt_template']

        self._network = backbone

        # One classification head per task (used by the ViT branch): the
        # first head covers init_cls_num classes, every later head
        # inc_cls_num classes.
        self.classifier_pool = nn.ModuleList(
            [nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] +
            [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True)
             for _ in range(kwargs['task_num'] - 1)]
        )

        # Freeze the backbone except for the MoE adapter / router / noise
        # parameters; the classifier heads above are left trainable.
        for name, param in self._network.named_parameters():
            if 'adaptmlp' not in name and 'router' not in name and 'noise' not in name:
                param.requires_grad = False

    def observe(self, data):
        """Run one training step on a batch.

        Args:
            data: dict with ``'image'`` (input tensor) and ``'label'``
                (global class indices); labels are shifted by
                ``_known_classes`` to the current task's local range.

        Returns:
            (preds, acc, loss): predicted local class ids, batch accuracy,
            and the cross-entropy loss (with label smoothing).

        Raises:
            NotImplementedError: if the backbone is neither CLIP nor ViT.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes

        if isinstance(self._network, CLIP):
            # CLIP scores images against the *current* task's text prompts.
            features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.curr_text_tokens)
        elif isinstance(self._network, VIT):
            # ViT: only the current task's head is used during training.
            features = self._network(x)
            logits_per_img = []
            for prompts in [self.classifier_pool[self._cur_task_id]]:
                logits_per_img.append(prompts(features))
            logits_per_img = torch.cat(logits_per_img, dim=1)
        else:
            raise NotImplementedError

        loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc, loss

    def inference(self, data):
        """Predict over *all* classes seen so far.

        Unlike observe(), labels are kept global (no shift): CLIP scores
        against the accumulated prompts, and the ViT branch concatenates
        the logits of every head up to the current task.

        Args:
            data: dict with ``'image'`` and global ``'label'`` tensors.

        Returns:
            (preds, acc): predicted global class ids and batch accuracy.

        Raises:
            NotImplementedError: if the backbone is neither CLIP nor ViT.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        if isinstance(self._network, CLIP):
            features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens)
        elif isinstance(self._network, VIT):
            features = self._network(x)
            logits_per_img = []
            for prompts in self.classifier_pool[:self._cur_task_id + 1]:
                logits_per_img.append(prompts(features))
            logits_per_img = torch.cat(logits_per_img, dim=1)
        else:
            raise NotImplementedError

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare state for a new task.

        Updates the task counter and the count of previously-known classes,
        then refreshes the current/accumulated class names and their
        tokenized prompt texts from the new task's training set.
        """
        self._cur_task_id = task_idx
        # The first increment adds init_cls_num classes, later ones inc_cls_num.
        if task_idx == 1:
            self._known_classes = self.init_cls_num
        elif task_idx > 1:
            self._known_classes += self.inc_cls_num

        self.curr_class_names = train_loader.dataset.get_class_names()
        self.accm_class_names += self.curr_class_names

        self.curr_text_tokens = tokenize([self.prompt_template.format(c) for c in self.curr_class_names]).to(self.device)
        self.accm_text_tokens = tokenize([self.prompt_template.format(c) for c in self.accm_class_names]).to(self.device)

    def get_parameters(self, config):
        """Return the parameters to hand to the optimizer.

        Fix: include the whole module (backbone AND classifier_pool).
        Returning only ``self._network.parameters()`` left the per-task
        ViT classifier heads out of the optimizer, so they never trained.
        Frozen backbone weights are harmless to include — they receive no
        gradients, so the optimizer skips them.
        """
        return self.parameters()