# -*- coding: utf-8 -*-
"""
TODO: citation
Adapted from TODO: source
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from .backbone.transformer import ResidualAttentionBlock
from .backbone.clip import tokenize, CLIP
from .backbone.vit import ViTZoo
VIT = ViTZoo  # alias for the ViT backbone used in the isinstance checks below
class DMNSP(nn.Module):
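    """Continual-learning model that trains only adapter ('adapt') parameters
    and, from the second task on, rescales their gradients by projecting them
    onto per-layer activation bases (`visual_U`) stored from earlier tasks.

    A minimal sketch of the expected driver loop (hypothetical; the actual
    trainer lives elsewhere in this repo):

        model = DMNSP(backbone, device, init_cls_num=10, inc_cls_num=10, ...)
        for task_idx in range(task_num):
            model.before_task(task_idx, buffer, train_loader, test_loaders)
            for batch in train_loader:
                optimizer.zero_grad()
                preds, acc, loss = model.observe(batch)  # observe() calls loss.backward()
                optimizer.step()
            model.after_task(task_idx, buffer, train_loader, test_loaders)
    """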
def __init__(self, backbone, device, **kwargs):
super().__init__()
self.device = device
self.init_cls_num = kwargs['init_cls_num']
self.inc_cls_num = kwargs['inc_cls_num']
self.label_smoothing = kwargs['label_smoothing']
        self._cur_task_id = -1
        self._known_classes = 0
        self.visual_U = []  # per-layer bases of previous tasks' activation subspaces
        self.lamda = [[0 for _ in range(12)] for _ in range(12)]  # lamda[j][k]: scale for layer j's gradient w.r.t. stored basis k
self.lamda_scale = kwargs['lamda_scale']
self.accm_class_names = []
self.curr_class_names = []
self.accm_text_tokens = None
self.curr_text_tokens = None
self.prompt_template = kwargs['prompt_template']
self._network = backbone
        # Freeze the backbone; only adapter ("adapt") parameters remain trainable.
        for name, param in self._network.named_parameters():
            if 'adapt' not in name:
                param.requires_grad = False
        if isinstance(self._network, VIT):
            self.visual_transformer_blocks = [
                module for module in self._network.modules()
                if isinstance(module, ResidualAttentionBlock)
            ]
            # One linear head per task; the first covers init_cls_num classes.
            self.classifier_pool = nn.ModuleList(
                [nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] +
                [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True)
                 for _ in range(kwargs['task_num'] - 1)]
            )
        elif isinstance(self._network, CLIP):
            # Only the visual tower's attention blocks are tracked for projection.
            self.visual_transformer_blocks = [
                module for name, module in self._network.named_modules()
                if isinstance(module, ResidualAttentionBlock) and 'visual' in name
            ]
        else:
            raise TypeError(f"Unsupported backbone type: {type(backbone)}")
def observe(self, data):
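        """Run one training step on a batch and return (preds, acc, loss).

        Note: backward() is called here; from the second task on, adapter
        gradients are rescaled by projection onto the stored subspaces before
        the caller's optimizer step.
        """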
        # Labels are shifted so each task's classes start at 0 for its own head.
        x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes
        if isinstance(self._network, CLIP):
            _, _, logits_per_img, _ = self._network(x, self.curr_text_tokens)
        elif isinstance(self._network, VIT):
            features = self._network(x)
            logits_per_img = self.classifier_pool[self._cur_task_id](features)
        loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)
        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)
        loss.backward()
        if self._cur_task_id > 0:
            # Rescale adapter gradients: project each 'down'/'up' weight gradient
            # onto every stored per-layer basis in visual_U and scale it by the
            # corresponding lamda[layer][i] computed in before_task().
            for name, param in self._network.named_parameters():
                if 'adapt' not in name or 'weight' not in name:
                    continue
                if isinstance(self._network, CLIP) and 'visual' not in name:
                    continue
                layer = int(name.split(".")[3])
                for i in range(12):
                    v = self.visual_U[i].to(self.device)
                    if 'down' in name:
                        param.grad.data = torch.mm(torch.mm(param.grad.data, v), v.T) * self.lamda[layer][i]
                    elif 'up' in name:
                        param.grad.data = torch.mm(v, torch.mm(v.T, param.grad.data)) * self.lamda[layer][i]
return preds, acc, loss
def inference(self, data, task_id = -1):
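        """Evaluate a batch and return (preds, acc).

        task_id > -1 selects task-incremental evaluation (logits restricted to
        that task's classes); task_id == -1 evaluates over all seen classes.
        """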
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        if isinstance(self._network, CLIP):
            if task_id > -1:
                # Task-incremental: score only against this task's class tokens.
                assert self.init_cls_num == self.inc_cls_num
                tokens = self.accm_text_tokens[task_id * self.inc_cls_num : (task_id + 1) * self.inc_cls_num]
                _, _, logits_per_img, _ = self._network(x, tokens)
            else:
                _, _, logits_per_img, _ = self._network(x, self.accm_text_tokens)
        elif isinstance(self._network, VIT):
            if task_id > -1:
                raise NotImplementedError('Task-incremental inference is not implemented for ViT.')
            # Concatenate logits from all heads seen so far.
            features = self._network(x)
            logits_per_img = torch.cat(
                [head(features) for head in self.classifier_pool[:self._cur_task_id + 1]], dim=1
            )
        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        if task_id > -1:
            assert self.init_cls_num == self.inc_cls_num
            # Map within-task predictions back to global class indices.
            preds += task_id * self.inc_cls_num
        acc = preds.eq(y).sum().item() / y.size(0)
return preds, acc
@torch.no_grad()
def before_task(self, task_idx, buffer, train_loader, test_loaders):
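        """Prepare for task `task_idx`: update class bookkeeping and text
        tokens, and (for task_idx > 0) estimate the lamda scaling matrix from
        one training batch.
        """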
self._cur_task_id = task_idx
if task_idx == 1:
self._known_classes = self.init_cls_num
elif task_idx > 1:
self._known_classes += self.inc_cls_num
self.curr_class_names = train_loader.dataset.get_class_names()
self.accm_class_names += self.curr_class_names
self.curr_text_tokens = tokenize(
[self.prompt_template.format(c) for c in self.curr_class_names]
).to(self.device)
self.accm_text_tokens = tokenize(
[self.prompt_template.format(c) for c in self.accm_class_names]
).to(self.device)
        if task_idx > 0:
            # Estimate lamda from the first training batch: for each layer j,
            # compare the top singular direction of its current adapter input
            # against each stored basis visual_U[k].
            for data in train_loader:
                x = data['image'].to(self.device)
                self._network(x, self.curr_text_tokens, compute_lora_feat=True)  # refreshes each block's lora_feature
                for j in range(12):  # number of transformer layers, hardcoded
                    activation = self.visual_transformer_blocks[j].lora_feature
                    # Uncentered covariance over the batch: sum_n A_n^T A_n.
                    activation = torch.bmm(activation.permute(1, 2, 0),
                                           activation.permute(1, 0, 2)).sum(dim=0)
                    U, _, _ = torch.linalg.svd(activation, full_matrices=False)
                    u = U[:, 0:1] / torch.norm(U[:, 0:1])  # top singular direction
                    for k in range(12):
                        basis = self.visual_U[k]
                        # Cosine similarity of u with every stored basis column.
                        similarities = [
                            torch.dot(u.squeeze(), (col / torch.norm(col)).squeeze())
                            for col in basis.t()
                        ]
                        # Mean over the top 10% of similarities (at least one,
                        # to avoid an empty topk producing NaN).
                        top_k = max(1, int(len(similarities) * 0.1))
                        mean_sim = torch.mean(torch.topk(torch.stack(similarities), top_k)[0])
                        self.lamda[j][k] = torch.exp(-mean_sim) * self.lamda_scale
                break  # first batch only
@torch.no_grad()
def after_task(self, task_idx, buffer, train_loader, test_loaders):
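        """After finishing task `task_idx`, grow each layer's stored activation
        basis (visual_U) from one batch's adapter inputs.
        """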
        for data in train_loader:
            x = data['image'].to(self.device)
            self._network(x, self.curr_text_tokens, compute_lora_feat=True)  # refreshes each block's lora_feature
            for i in range(12):
                activation = self.visual_transformer_blocks[i].lora_feature
                # Uncentered covariance over the batch: sum_n A_n^T A_n.
                activation = torch.bmm(activation.permute(1, 2, 0),
                                       activation.permute(1, 0, 2)).sum(dim=0)
                U, _, _ = torch.linalg.svd(activation, full_matrices=False)
                # Drop the leading singular direction and keep the rest of the basis.
                if task_idx == 0:
                    self.visual_U.append(U[:, 1:])
                else:
                    self.visual_U[i] = torch.cat((self.visual_U[i], U[:, 1:]), dim=1)
            break  # first batch only
def get_parameters(self, config):
return self._network.parameters()