| | |
| | """ |
| | TODO: citation |
| | |
| | Adapted from TODO: source |
| | """ |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| |
|
| | from torch import optim |
| | from torch.nn import functional as F |
| | from torch.nn.parameter import Parameter |
| | from tqdm import tqdm |
| |
|
| | from .backbone.transformer import ResidualAttentionBlock |
| | from .backbone.clip import tokenize, CLIP |
| | from .backbone.vit import ViTZoo |
| |
|
# Module-level aliases used by the isinstance() dispatch inside DMNSP.
VIT = ViTZoo
# NOTE(review): self-assignment is a no-op — CLIP is already bound by the
# import above; kept only so both backbone types read symmetrically here.
CLIP = CLIP
| |
|
class DMNSP(nn.Module):
    """Continual-learning wrapper applying null-space gradient projection
    to the adapter ('adapt') parameters of a ViT or CLIP backbone.

    Only parameters whose name contains 'adapt' stay trainable. After each
    task, an SVD basis of the adapter input features is stored per
    transformer block (``visual_U``); during later tasks, adapter gradients
    are projected onto those bases, scaled by per-layer factors ``lamda``
    computed in ``before_task`` from feature-similarity statistics.
    """

    def __init__(self, backbone, device, **kwargs):
        """Freeze the backbone (except adapters), collect its attention
        blocks, and initialise task bookkeeping.

        Args:
            backbone: a ``ViTZoo`` or ``CLIP`` instance; anything else fails
                the assertion below.
            device: torch device that inputs/bases are moved to.
            **kwargs: requires 'init_cls_num', 'inc_cls_num',
                'label_smoothing', 'lamda_scale', 'prompt_template'; for a
                ViT backbone also 'embd_dim' and 'task_num'.
        """
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs['init_cls_num']    # classes in task 0
        self.inc_cls_num = kwargs['inc_cls_num']      # classes per later task
        self.label_smoothing = kwargs['label_smoothing']

        self._cur_task_id = -1     # no task yet; advanced in before_task
        self._known_classes = 0    # total classes from *previous* tasks only
        # Per-block accumulated SVD bases of adapter input features; one
        # entry per transformer block, filled/extended in after_task.
        self.visual_U = []
        # lamda[block][basis] projection scales; 12x12 hard-codes a
        # 12-block transformer (matches the range(12) loops below).
        self.lamda = [[0 for _ in range(12)] for _ in range(12)]
        self.lamda_scale = kwargs['lamda_scale']

        self.accm_class_names = []   # class names accumulated over all tasks
        self.curr_class_names = []   # class names of the current task only
        self.accm_text_tokens = None
        self.curr_text_tokens = None

        self.prompt_template = kwargs['prompt_template']

        self._network = backbone

        # Freeze everything except adapter parameters.
        for name, param in self._network.named_parameters():
            if 'adapt' not in name:
                param.requires_grad = False

        if isinstance(self._network, VIT):
            self.visual_transformer_blocks = [module for module in self._network.modules() if isinstance(module, ResidualAttentionBlock)]

            # One linear head per task: head 0 sized init_cls_num, the
            # remaining task_num-1 heads sized inc_cls_num.
            self.classifier_pool = nn.ModuleList([
                nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] +
                [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)]
            )

        elif isinstance(self._network, CLIP):
            # Track only the visual tower's attention blocks.
            self.visual_transformer_blocks = [module for name, module in self._network.named_modules() if isinstance(module, ResidualAttentionBlock) and 'visual' in name]
        else:
            # Unsupported backbone type.
            # NOTE(review): `assert 0` is stripped under `python -O`; a
            # raised TypeError would be safer — left as-is here.
            assert 0

    def observe(self, data):
        """One training step: forward, cross-entropy loss, backward, then
        (from task 1 onward) project adapter gradients onto the stored
        bases in place.

        Calls ``loss.backward()`` itself but does not step an optimizer —
        the caller is expected to do that after this returns.

        Args:
            data: dict with 'image' and 'label' tensors; labels are global
                indices and are shifted down by the previously seen class
                count so they index the current task's head.

        Returns:
            (preds, acc, loss) for the batch.
        """
        # Remap global labels to the current task's local index range.
        x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes

        if isinstance(self._network, CLIP):
            features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.curr_text_tokens)
        elif isinstance(self._network, ViTZoo):
            features = self._network(x)
            # Training uses only the current task's classifier head.
            logits_per_img = []
            for prompts in [self.classifier_pool[self._cur_task_id]]:
                logits_per_img.append(prompts(features))
            logits_per_img = torch.cat(logits_per_img, dim=1)
        # NOTE(review): logits_per_img is unbound if the backbone is neither
        # type; __init__ already rejects that case, so no else branch here.

        loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        loss.backward()

        if self._cur_task_id > 0:
            # Null-space projection of adapter gradients. 'down' weights
            # are projected on the right (input/column space), 'up'
            # weights on the left (output/row space).
            if isinstance(self._network, VIT):

                for name, param in self._network.named_parameters():
                    for i in range(12):
                        # NOTE(review): the grad is re-projected for every
                        # basis i in 0..11 sequentially, so the lamda
                        # scalings compound across iterations — confirm
                        # this compounding is the intended algorithm.
                        if 'adapt' in name and 'down' in name and 'weight' in name:

                            v = self.visual_U[i].to(self.device)
                            v_ = torch.mm(param.grad.data, v)
                            # name.split(".")[3] is assumed to be the block
                            # index within the dotted parameter path —
                            # TODO confirm against the backbone's naming.
                            param.grad.data = torch.mm(v_, v.T) * self.lamda[int(name.split(".")[3])][i]

                        elif 'adapt' in name and 'up' in name and 'weight' in name:

                            v = self.visual_U[i].to(self.device)
                            v_ = torch.mm(v.T, param.grad.data)
                            param.grad.data = torch.mm(v, v_) * self.lamda[int(name.split(".")[3])][i]

            elif isinstance(self._network, CLIP):
                # Same projection, restricted to visual-tower adapters.
                for name, param in self._network.named_parameters():
                    for i in range(12):
                        if 'visual' in name and 'adapt' in name and 'down' in name and 'weight' in name:

                            v = self.visual_U[i].to(self.device)
                            v_ = torch.mm(param.grad.data, v)
                            param.grad.data = torch.mm(v_, v.T) * self.lamda[int(name.split(".")[3])][i]

                        elif 'visual' in name and 'adapt' in name and 'up' in name and 'weight' in name:

                            v = self.visual_U[i].to(self.device)
                            v_ = torch.mm(v.T, param.grad.data)
                            param.grad.data = torch.mm(v, v_) * self.lamda[int(name.split(".")[3])][i]



        return preds, acc, loss

    def inference(self, data, task_id = -1):
        """Evaluate one batch over all classes seen so far.

        Args:
            data: dict with 'image' and 'label' tensors (global labels).
            task_id: if > -1, restrict scoring to that task's class slice
                (CLIP only; requires equal task sizes). Default -1 scores
                over all accumulated classes.

        Returns:
            (preds, acc) with predictions in the global label space.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        if isinstance(self._network, CLIP):
            if task_id > -1:
                # Slicing by inc_cls_num only works when every task has
                # the same number of classes.
                assert self.init_cls_num == self.inc_cls_num
                features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens[task_id * self.inc_cls_num : (task_id + 1) * self.inc_cls_num])
            else:
                features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens)
        elif isinstance(self._network, VIT):
            if task_id > -1:
                assert 0, 'Not Implemented'
            else:
                features = self._network(x)
                # Concatenate logits from every head trained so far.
                logits_per_img = []
                for prompts in self.classifier_pool[:self._cur_task_id + 1]:
                    logits_per_img.append(prompts(features))
                logits_per_img = torch.cat(logits_per_img, dim=1)

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)

        if task_id > -1:
            # Shift task-local predictions back to global label indices.
            assert self.init_cls_num == self.inc_cls_num
            preds += task_id * self.inc_cls_num

        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc

    @torch.no_grad()
    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for a new task: advance counters, rebuild class-name
        prompts/tokens, and (from task 1 on) compute the per-layer
        projection scales ``lamda`` from one batch of adapter features.
        """
        self._cur_task_id = task_idx
        if task_idx == 1:
            self._known_classes = self.init_cls_num
        elif task_idx > 1:
            self._known_classes += self.inc_cls_num

        self.curr_class_names = train_loader.dataset.get_class_names()
        self.accm_class_names += self.curr_class_names

        self.curr_text_tokens = tokenize(
            [self.prompt_template.format(c) for c in self.curr_class_names]
        ).to(self.device)

        self.accm_text_tokens = tokenize(
            [self.prompt_template.format(c) for c in self.accm_class_names]
        ).to(self.device)

        if task_idx > 0:
            for data in train_loader:
                x = data['image'].to(self.device)
                # Forward pass only to populate each block's lora_feature.
                self._network(x, self.curr_text_tokens, compute_lora_feat=True)

                for j in range(12):
                    activation_visual = self.visual_transformer_blocks[j].lora_feature
                    # Gram matrix of the block's adapter features; assumes
                    # lora_feature is (seq, batch, dim) — TODO confirm.
                    activation_visual = torch.bmm(activation_visual.permute(1, 2, 0),
                                                  activation_visual.permute(1, 0, 2)).sum(dim=0)
                    U_visual, _, _ = torch.linalg.svd(activation_visual, full_matrices=False)
                    # Keep only the top singular direction of this batch.
                    U_visual = U_visual[:, 0:1]

                    # Compare it against every stored basis; higher overlap
                    # yields a smaller lamda via exp(-sim) below.
                    for k in range(12):
                        v_visual = self.visual_U[k]
                        normalized_vector_visual = U_visual / torch.norm(U_visual)
                        similarities_visual = []

                        for column_visual in v_visual.t():
                            normalized_column_visual = column_visual / torch.norm(column_visual)
                            cos_sim_visual = torch.dot(normalized_vector_visual.squeeze(),
                                                       normalized_column_visual.squeeze())
                            similarities_visual.append(cos_sim_visual)

                        # Mean of the top 10% similarities (00.1 == 0.1).
                        # NOTE(review): with fewer than 10 basis columns the
                        # topk k is 0 and the mean is NaN — verify bases are
                        # always wide enough.
                        dot_products_visual = torch.mean(torch.topk(torch.stack(similarities_visual), int(len(similarities_visual) * 00.1))[0])
                        self.lamda[j][k] = torch.exp(-dot_products_visual) * self.lamda_scale

                # Only the first batch is used for the lamda estimate.
                break

    @torch.no_grad()
    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """After finishing a task, extend each block's SVD basis with the
        feature directions of one batch from that task's training data.
        """
        for data in train_loader:
            x = data['image'].to(self.device)
            # Forward pass only to populate each block's lora_feature.
            self._network(x, self.curr_text_tokens, compute_lora_feat=True)

            for i in range(12):

                activation = self.visual_transformer_blocks[i].lora_feature
                # Gram matrix (see before_task for the assumed layout).
                activation = torch.bmm(activation.permute(1, 2, 0),
                                       activation.permute(1, 0, 2)).sum(dim=0)

                U, _, _ = torch.linalg.svd(activation, full_matrices=False)

                # In both branches the leading singular vector is dropped
                # (max(r,1) == 1 when r == 0), keeping columns 1..end.
                if task_idx == 0:
                    r = 0
                    self.visual_U.append(U[:,max(r,1):])
                else:
                    r = 1
                    Ui = torch.cat((self.visual_U[i], U[:, r:]), dim=1)
                    self.visual_U[i] = Ui

            # Only the first batch contributes to the basis.
            break

    def get_parameters(self, config):
        """Return the backbone's parameters for the optimizer.

        NOTE(review): for ViT backbones this excludes ``classifier_pool``
        (it lives on this wrapper, not on ``self._network``) — confirm the
        heads are optimized elsewhere, otherwise they never train.
        """
        return self._network.parameters()