boringKey's picture
Upload 236 files
5fee096 verified
"""
@article{lin2022trgp,
title={TRGP: Trust Region Gradient Projection for Continual Learning},
author={Lin, Sen and Yang, Li and Fan, Deliang and Zhang, Junshan},
journal={arXiv preprint arXiv:2202.02931},
year={2022}
}
Code Reference:
https://github.com/LYang-666/TRGP
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .backbone.alexnet import Conv2d_TRGP, Linear_TRGP, AlexNet_TRGP
from .backbone.clip import tokenize, CLIP
Epsilon = 0.5
AlexNet = AlexNet_TRGP
Clip = CLIP
class TopK:
'''
A class to maintain a collection of the top K items based on a specified attribute.
This class allows for the dynamic addition of items, each represented as a dictionary,
where each dictionary must have a key 'proj_norm' that represents the value used
to determine the ranking. The class keeps track of the top K items with the highest
'proj_norm' values.
'''
def __init__(self, k):
self.k = k
self.top_k_list = []
def add(self, dict):
if len(self.top_k_list) < self.k:
self.top_k_list.append(dict)
elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']:
self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm']))
self.top_k_list.append(dict)
def get_top_k(self):
return self.top_k_list
class Network(nn.Module):
def __init__(self, backbone, **kwargs):
super().__init__()
self.backbone = backbone
self.classifiers = nn.ModuleList([
nn.Linear(backbone.feat_dim, kwargs['init_cls_num'], bias = False)] +
[nn.Linear(backbone.feat_dim, kwargs['inc_cls_num'], bias = False) for _ in range(kwargs['task_num'] - 1)]
)
def return_hidden(self, data):
return self.backbone(data)
def forward(self, data, compute_input_matrix = False):
logits = []
image_features = self.backbone(data, compute_input_matrix)
for classifier in self.classifiers:
logits.append(classifier(image_features))
return logits
class TRGP(nn.Module):
def __init__(self, backbone, device, **kwargs):
super().__init__()
self.backbone = backbone
self.device = device
self.task_num = kwargs["task_num"]
self.init_cls_num = kwargs["init_cls_num"]
self.inc_cls_num = kwargs["inc_cls_num"]
self.label_smoothing = kwargs['label_smoothing']
self._known_classes = 0
self.feature_list = []
self.feature_mat = []
self.layers = []
if isinstance(backbone, Clip):
self.network = backbone
self.visual_U = []
self.lamda = [[0 for _ in range(12)] for _ in range(12)]
self.accm_class_names = []
self.curr_class_names = []
self.accm_text_tokens = None
self.curr_text_tokens = None
self.prompt_template = kwargs['prompt_template']
# 12 Visual Transformer's Adapter * 2 (Down & Up)
for name, module in self.network.named_modules():
if 'visual' in name and isinstance(module, Linear_TRGP):
self.layers.append(module)
for name, param in self.network.named_parameters():
param.requires_grad = False
if 'adaptmlp' in name:
param.requires_grad = True
elif isinstance(backbone, AlexNet):
self.network = Network(backbone, **kwargs)
# # 3 Conv2d and 2 Linear
for module in self.network.modules(): #
if isinstance(module, Conv2d_TRGP) or isinstance(module, Linear_TRGP):
self.layers.append(module)
else:
raise NotImplementedError
self.feature_list_each_tasks = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)]
self.scale_param_each_tasks_each_layers = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)]
self.all_space = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)]
self.network.to(self.device)
def observe(self, data):
x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes
if len(y) == 1: # Ignore batch_size == 1
return None, 0., torch.zeros(1, requires_grad=True)
if isinstance(self.backbone, Clip):
features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.curr_text_tokens)
loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)
preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
loss.backward()
if self.cur_task > 0:
for i, module in enumerate(self.layers):
sz = module.weight.grad.data.shape[0]
module.weight.grad.data = module.weight.grad.data - (module.weight.grad.data.view(sz,-1) @ self.feature_mat[i]).view(module.weight.shape)
elif isinstance(self.backbone, AlexNet):
logits = self.network(x)
loss = F.cross_entropy(logits[self.cur_task], y, label_smoothing=self.label_smoothing)
preds = logits[self.cur_task].max(1)[1]
loss.backward()
if self.cur_task > 0:
for i, module in enumerate(self.layers):
sz = module.weight.grad.data.shape[0]
module.weight.grad.data = module.weight.grad.data - (module.weight.grad.data.view(sz,-1) @ self.feature_mat[i]).view(module.weight.shape)
else:
raise NotImplementedError
acc = preds.eq(y).sum().item() / y.size(0)
return preds, acc, loss
def inference(self, data, task_id = -1):
x, y = data['image'].to(self.device), data['label'].to(self.device)
# Add dummy, to prevert batch_size == 1
dummy_x = torch.randn_like(x[:1]) # only one dummy sample
x = torch.cat([x, dummy_x], dim=0)
# Task-Aware (Task-Incremetanl Scenario)
if task_id > -1:
if task_id == 0:
bias_classes = 0
elif task_id == 1:
bias_classes = self.init_cls_num
else:
bias_classes = self.init_cls_num + (task_id - 1) * self.inc_cls_num
if isinstance(self.backbone, Clip):
for i, module in enumerate(self.layers):
module.space = self.all_space[task_id][i]
module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[task_id][i]])
features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[bias_classes : self.init_cls_num + task_id * self.inc_cls_num])
preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + bias_classes
elif isinstance(self.backbone, AlexNet):
for i, module in enumerate(self.layers):
module.space = self.all_space[task_id][i]
module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[task_id][i]])
logits = self.network(x)
preds = logits[task_id].softmax(dim=-1).argmax(dim=1) + bias_classes
else:
raise NotImplementedError
# Task-Agnostic (Class-Incremetanl Scenario)
else:
logits = []
if isinstance(self.backbone, Clip):
for t in range(self.cur_task + 1):
for i, module in enumerate(self.layers):
module.space = self.all_space[t][i]
module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[t][i]])
if t == 0:
features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[: self.init_cls_num])
else:
features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[self.init_cls_num + (t-1) * self.inc_cls_num : self.init_cls_num + t * self.inc_cls_num])
logits.append(logits_per_img)
elif isinstance(self.backbone, AlexNet):
for t in range(self.cur_task + 1):
for i, module in enumerate(self.layers):
module.space = self.all_space[t][i]
module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[t][i]])
logits.append(self.network(x)[t])
else:
raise NotImplementedError
preds = torch.cat(logits, dim=-1).softmax(dim=-1).argmax(dim=1)
# Remove dummy
preds = preds[:-1]
correct_count = preds.eq(y).sum().item()
acc = correct_count / y.size(0)
return preds, acc
def before_task(self, task_idx, buffer, train_loader, test_loaders):
# Last task have scale_param and space, need to init again
for module in self.layers:
module.disable_scale()
self.cur_task = task_idx
if isinstance(self.backbone, Clip):
self.curr_class_names = train_loader.dataset.get_class_names()
self.accm_class_names += self.curr_class_names
self.curr_text_tokens = tokenize([self.prompt_template.format(c) for c in self.curr_class_names]).to(self.device)
self.accm_text_tokens = tokenize([self.prompt_template.format(c) for c in self.accm_class_names]).to(self.device)
if task_idx > 0:
self.feature_mat = [torch.tensor(feat @ feat.T, dtype=torch.float32, device=self.device) for feat in self.feature_list]
optimizer = torch.optim.SGD(self.network.parameters(), lr = 0.01) # lr hardcoded
x, y = [], []
for batch in train_loader:
x.append(batch['image'].to(self.device))
y.append(batch['label'].to(self.device) - self._known_classes)
x, y = torch.cat(x, dim = 0), torch.cat(y, dim = 0)
indices = torch.randperm(x.size(0))
selected_indices = indices[:125]
x, y = x[selected_indices], y[selected_indices]
optimizer.zero_grad()
if isinstance(self.backbone, Clip):
features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.curr_text_tokens)
loss = F.cross_entropy(logits_per_img, y)
elif isinstance(self.backbone, AlexNet):
logits = self.network(x)
loss = F.cross_entropy(logits[self.cur_task], y)
loss.backward()
for i, module in enumerate(self.layers):
topk = TopK(2)
grad = module.weight.grad.data.detach().cpu().numpy()
if isinstance(self.backbone, AlexNet) and isinstance(module, Conv2d_TRGP):
grad = grad.reshape(grad.shape[0], -1)
for task_id in range(task_idx):
proj = grad @ self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T
proj_norm = np.linalg.norm(proj)
print(f'Layer {i} of {task_idx} to {task_id} : {proj_norm:.4f}/{np.linalg.norm(grad):.4f} ({proj_norm > Epsilon * np.linalg.norm(grad)})')
if proj_norm > Epsilon * np.linalg.norm(grad):
topk.add({'proj_norm':proj_norm, 'task_id': task_id})
final_decision = [dic['task_id'] for dic in topk.get_top_k()]
module.enable_scale([
torch.tensor(self.feature_list_each_tasks[task_id][i], dtype=torch.float32).to(self.device) for task_id in final_decision
])
print(f'Layer {i} of {task_idx} consider {final_decision} as trust region')
def after_task(self, task_idx, buffer, train_loader, test_loaders):
self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num
# Save the scale param
for i, module in enumerate(self.layers):
self.scale_param_each_tasks_each_layers[task_idx][i] = [scale_param.data for scale_param in module.scale_param] # top2
self.all_space[task_idx][i] = module.space
module.disable_scale()
x = torch.cat([batch['image'].to(self.device) for batch in train_loader], dim = 0)
# hardcoded, choose 125 input from it
indices = torch.randperm(x.size(0))
selected_indices = indices[:125]
x = x[selected_indices]
self.network.eval()
mat_list = [] # Representation / Activation of each layer
threshold = 0.97 + task_idx * 0.003
if isinstance(self.backbone, Clip):
self.network(x, self.curr_text_tokens, compute_input_matrix = True)
for module in self.layers:
assert module.input_matrix.shape[0] == 125
mat_list.append(module.input_matrix.view(-1, module.input_matrix.shape[-1]).detach().cpu().numpy().T)
elif isinstance(self.backbone, AlexNet):
self.network(x, compute_input_matrix = True)
batch_list = [2*12,100,100]
ksize = [4, 3, 2] # kernel size of each conv layer
conv_output_size = [29, 12, 5] # output size of each conv layer
in_channel = [3, 64, 128] # input channel of each conv layer
for i, module in enumerate(self.layers):
if isinstance(module, Conv2d_TRGP):
bsz, ksz, s, inc = batch_list[i], ksize[i], conv_output_size[i], in_channel[i]
mat = np.zeros((ksz * ksz * inc, s * s * bsz))
act = module.input_matrix.detach().cpu().numpy()
k = 0
for kk in range(bsz):
for ii in range(s):
for jj in range(s):
mat[:,k]=act[kk, :, ii:ksz+ii, jj:ksz+jj].reshape(-1)
k += 1
mat_list.append(mat)
elif isinstance(module, Linear_TRGP):
mat_list.append(module.input_matrix.detach().cpu().numpy().T)
# get the space for each layer
if task_idx == 0:
for i, activation in enumerate(mat_list):
U, S, _ = np.linalg.svd(activation, full_matrices = False)
# criteria (Eq-5)
sval_total = (S**2).sum()
sval_ratio = (S**2)/sval_total
r = np.sum(np.cumsum(sval_ratio) < threshold)
self.feature_list_each_tasks[task_idx][i] = U[:, :r]
self.feature_list.append(U[:, :r])
else:
for i, activation in enumerate(mat_list):
_, S, _ = np.linalg.svd(activation, full_matrices = False)
sval_total = (S**2).sum()
delta = (self.feature_list[i].T @ activation @ activation.T @ self.feature_list[i]).diagonal()
# following the GPM to get the sigma (S**2)
act_hat = activation - self.feature_list[i] @ self.feature_list[i].T @ activation
U, S, _ = np.linalg.svd(act_hat, full_matrices=False)
sigma = S**2
# stack delta and sigma, then sort in descending order
stack = np.hstack((delta, sigma))
stack_index = np.argsort(stack)[::-1] # the index of each element in descending sorted array
stack = np.sort(stack)[::-1] # descending sorted array
if threshold * sval_total <= 0:
r = 0
else:
r = min(np.sum(np.cumsum(stack) < threshold * sval_total) + 1, activation.shape[0])
Ui = np.hstack((self.feature_list[i], U))
sel_each = stack_index[:r]
sel_overall = sel_each[sel_each >= len(delta)] # without overlap
self.feature_list[i] = np.hstack((self.feature_list[i], Ui[:, sel_overall]))
self.feature_list_each_tasks[task_idx][i] = Ui[:, sel_each]
if sel_overall.shape[0] == 0:
print(f'Skip Updating Space for layer: {i+1}')
def get_parameters(self, config):
return self.network.parameters()