# File size: 9,836 Bytes (5fee096)
'''
@misc{mcdonnell2024ranpacrandomprojectionspretrained,
title={RanPAC: Random Projections and Pre-trained Models for Continual Learning},
author={Mark D. McDonnell and Dong Gong and Amin Parvaneh and Ehsan Abbasnejad and Anton van den Hengel},
year={2024},
eprint={2307.02251},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2307.02251},
}
Code Reference:
https://github.com/RanPAC/RanPAC
'''
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer
from .backbone.clip import CLIP, tokenize
from .backbone.vit import ViTZoo, ViT_in21k_adapter
# Backbone type aliases: the classes below dispatch on these with isinstance().
VIT = ViT_in21k_adapter
# Self-assignment: rebinds the imported CLIP name to itself, so it has no effect.
CLIP = CLIP
class CosineLinear(nn.Module):
    """Classification head with a cosine mode and a random-projection mode.

    While ``use_RP`` is False the layer returns ``sigma`` times the cosine
    similarity between each input row and each row of ``weight``.  Once
    ``use_RP`` is True the input is multiplied by ``W_rand``, passed through a
    ReLU and fed to a plain, un-normalised linear map with ``weight``; the
    result is again scaled by ``sigma``.
    """

    def __init__(self, in_features, out_features):
        """Create the ``(out_features, in_features)`` weight and the scalar ``sigma``."""
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised here and filled by reset_parameters().
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()
        # Random-projection mode is off until a caller enables it and supplies W_rand.
        self.use_RP = False
        self.W_rand = None

    def reset_parameters(self):
        """Draw ``weight`` uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in))`` and set ``sigma`` to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.constant_(self.sigma, 1)

    def forward(self, input):
        """Return the ``sigma``-scaled scores for ``input`` in the currently active mode."""
        if self.use_RP:
            if self.W_rand is None:
                # Random-projection mode without a projection matrix is not supported yet.
                assert 0, 'should not reach here, for now'
                projected = input
            else:
                projected = F.relu(input @ self.W_rand)
            scores = F.linear(projected, self.weight)
        else:
            unit_input = F.normalize(input, p=2, dim=1)
            unit_weight = F.normalize(self.weight, p=2, dim=1)
            scores = F.linear(unit_input, unit_weight)
        return self.sigma * scores
class Network(nn.Module):
    """Backbone plus a ``CosineLinear`` head that is rebuilt for every task.

    Two backbone kinds are handled: ``VIT`` (features are the backbone output)
    and ``CLIP`` (features are the image embedding concatenated with the text
    embedding of the best-scoring class prompt).
    """

    def __init__(self, backbone, device, **kwargs):
        """Store the backbone and derive the feature width fed to the classifier.

        For a CLIP backbone ``kwargs['prompt_template']`` is required; it is
        applied to each class name with ``str.format``.
        """
        super().__init__()
        self._cur_task_id = -1
        self.backbone = backbone
        self.device = device
        self.classifier = None

        if isinstance(self.backbone, VIT):
            self.feature_dim = self.backbone.feat_dim
        elif isinstance(self.backbone, CLIP):
            # Assuming the final feature is the concatenation of image and text features.
            image_width = self.backbone.visual.output_dim
            text_width = self.backbone.transformer.width
            self.feature_dim = image_width + text_width
            # Class names / tokenised prompts for all tasks so far and for the current task.
            self.accm_class_names = []
            self.curr_class_names = []
            self.accm_text_tokens = None
            self.curr_text_tokens = None
            self.prompt_template = kwargs['prompt_template']

    def _tokenize_prompts(self, class_names):
        """Fill the prompt template with each class name and tokenise on ``self.device``."""
        prompts = [self.prompt_template.format(c) for c in class_names]
        return tokenize(prompts).to(self.device)

    def _clip_features(self, x, text_tokens):
        """Concatenate each image embedding with its best-matching text embedding."""
        image_feats, text_feats, logits_per_image, _ = self.backbone(x, text_tokens)
        # One prompt index per image: the prompt with the highest image-to-text score.
        best_prompt = logits_per_image.softmax(dim=-1).argmax(dim=1)
        return torch.cat([image_feats, text_feats[best_prompt]], dim=1)

    def update_classifer(self, num_classes, train_loader):
        """Advance to the next task and replace the head with a fresh ``num_classes``-way one.

        With a CLIP backbone the class names are read from
        ``train_loader.dataset.get_class_names()`` and both prompt sets
        (current task, all tasks so far) are re-tokenised.
        """
        if not isinstance(self.backbone, VIT):
            if isinstance(self.backbone, CLIP):
                self.curr_class_names = train_loader.dataset.get_class_names()
                self.accm_class_names.extend(self.curr_class_names)
                self.curr_text_tokens = self._tokenize_prompts(self.curr_class_names)
                self.accm_text_tokens = self._tokenize_prompts(self.accm_class_names)
            else:
                assert 0

        self._cur_task_id += 1
        del self.classifier
        self.classifier = CosineLinear(self.feature_dim, num_classes).to(self.device)

    def get_feature(self, x):
        """Return backbone features for ``x`` (CLIP uses the current-task prompts)."""
        if isinstance(self.backbone, VIT):
            return self.backbone(x)
        if isinstance(self.backbone, CLIP):
            return self._clip_features(x, self.curr_text_tokens)
        assert 0

    def forward(self, x, inference=False):
        """Return classifier outputs for ``x``.

        With a CLIP backbone, ``inference=True`` scores against the prompts of
        every class seen so far; otherwise only the current task's prompts.
        """
        if isinstance(self.backbone, VIT):
            features = self.backbone(x)
        elif isinstance(self.backbone, CLIP):
            tokens = self.accm_text_tokens if inference else self.curr_text_tokens
            features = self._clip_features(x, tokens)
        else:
            assert 0
        return self.classifier(features)
class RanPAC(nn.Module):
    """Continual learner that fits a ridge-regression head on random projections.

    Gradient training happens at most on the first task (when
    ``first_session_training`` is set).  After every task the features of the
    training data are projected with a fixed random matrix, passed through a
    ReLU, accumulated into the statistics ``G`` and ``Q``, and the classifier
    weights are obtained by solving a ridge-regularised linear system.
    """

    def __init__(self, backbone, device, **kwargs):
        """Build the wrapped ``Network`` and read the task layout from ``kwargs``.

        Required keys: ``first_session_training``, ``init_cls_num``,
        ``inc_cls_num``, ``total_cls_num``, ``task_num`` and ``M`` (the
        random-projection width), plus whatever ``Network`` itself needs.
        """
        super().__init__()
        self._network = Network(backbone, device, **kwargs)
        self.device = device
        self.first_session_training = kwargs["first_session_training"]
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.total_cls_num = kwargs['total_cls_num']
        self.task_num = kwargs["task_num"]
        self.M = kwargs['M']

        self._known_classes = 0
        self._classes_seen_so_far = 0
        self._skip_train = False  # set per task in before_task(); True disables observe()

        self._network.to(self.device)
        if isinstance(backbone, CLIP):
            # With CLIP only parameters whose name contains 'adapt' stay trainable.
            for name, param in self._network.named_parameters():
                if 'adapt' in name:
                    continue
                param.requires_grad = False

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the class count, rebuild the classifier and decide whether to train."""
        if task_idx == 0:
            self._classes_seen_so_far = self.init_cls_num
        elif task_idx > 0:
            self._classes_seen_so_far += self.inc_cls_num
        self._network.update_classifer(self._classes_seen_so_far, train_loader)

        # Only the first task is ever trained, and only when requested.
        self._skip_train = not (task_idx == 0 and self.first_session_training)
        if self._skip_train:
            print(f"Not training on task {task_idx}")

    def observe(self, data):
        """Run one training step's forward pass; return ``(preds, acc, loss)``.

        When training is skipped the result is ``(None, 0., zero_loss)``, where
        ``zero_loss`` requires grad so the caller can still call ``backward()``.
        """
        if self._skip_train:
            return None, 0., torch.tensor(0., device = self.device, requires_grad = True)

        images = data['image'].to(self.device)
        # Shift labels so they index the current task's classifier outputs.
        labels = data['label'].to(self.device) - self._known_classes
        logits = self._network(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.max(logits, dim=1).indices
        num_correct = preds.eq(labels.expand_as(preds)).sum().item()
        return preds, round(num_correct / len(labels), 4), loss

    def inference(self, data):
        """Return ``(logits, acc)`` for a batch; labels are compared on the CPU."""
        images, labels = data['image'].to(self.device), data['label']
        logits = self._network(images, True)
        preds = torch.max(logits, dim=1).indices
        num_correct = preds.cpu().eq(labels.expand_as(preds)).sum().item()
        return logits, round(num_correct / len(labels), 4)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Create or extend the projection statistics, then refit the classifier."""
        self._known_classes = self._classes_seen_so_far
        if task_idx != 0:
            # Append zero columns to Q for the classes introduced by this task.
            new_columns = torch.zeros(self.M, self.inc_cls_num)
            self.Q = torch.cat((self.Q, new_columns), dim=1)
        else:
            # First task: draw the fixed random projection and start empty statistics.
            self.W_rand = torch.randn(self._network.classifier.in_features, self.M)
            self.Q = torch.zeros(self.M, self.init_cls_num)
            self.G = torch.zeros(self.M, self.M)
        self.update_rp_classifier(train_loader, test_loaders[0].dataset.trfms)

    @torch.no_grad()
    def update_rp_classifier(self, train_loader, test_trfms):
        """Accumulate ``G``/``Q`` from ``train_loader`` and solve for the head weights.

        The training dataset's ``trfms`` attribute is replaced by ``test_trfms``
        before features are extracted.  The ridge strength is chosen from
        ``10**-8 .. 10**8`` by fitting on the first 80% of this task's samples
        and scoring the mean-squared error on the remaining 20%.
        """
        self._network.eval()
        train_loader.dataset.trfms = test_trfms

        head = self._network.classifier
        head.use_RP = True
        head.W_rand = self.W_rand.to(self.device)  # feature_dim x M

        feature_chunks, label_chunks = [], []
        for batch in train_loader:
            images = batch['image'].to(self.device)
            feature_chunks.append(self._network.get_feature(images).cpu())
            label_chunks.append(batch['label'])
        features = torch.cat(feature_chunks, dim = 0)
        labels = torch.cat(label_chunks, dim = 0)
        targets = F.one_hot(labels, self._classes_seen_so_far).to(torch.float32)

        projected = F.relu(features @ self.W_rand)
        self.Q += projected.T @ targets
        self.G += projected.T @ projected

        # Ridge search on this task's data only: fit on the head of the split,
        # evaluate on the tail.
        ridges = 10.0**np.arange(-8,9)
        split = int(projected.shape[0] * 0.8)
        fit_x, fit_y = projected[:split, :], targets[:split, :]
        held_x, held_y = projected[split:, :], targets[split:, :]
        Q_fit = fit_x.T @ fit_y
        G_fit = fit_x.T @ fit_x
        identity = torch.eye(self.M)
        losses = []
        for candidate in ridges:
            # solve() has better numerical stability than an explicit inverse
            solution = torch.linalg.solve(G_fit + candidate * identity, Q_fit)
            losses.append(F.mse_loss(held_x @ solution, held_y))
        ridge = ridges[np.argmin(np.array(losses))]
        print(f"Optimal lambda: {ridge}")

        # Final weights come from the statistics accumulated over all tasks.
        Wo = torch.linalg.solve(self.G + ridge * identity, self.Q).T
        head.weight.data = Wo[:head.weight.shape[0], :].to(self.device)  # num_classes x M

    def get_parameters(self, config):
        """Return every parameter of the wrapped network; ``config`` is unused."""
        return self._network.parameters()