# -*- coding: utf-8 -*-
"""
@inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23,
author = {James Seale Smith and
Leonid Karlinsky and
Vyshnavi Gutta and
Paola Cascante{-}Bonilla and
Donghyun Kim and
Assaf Arbelle and
Rameswar Panda and
Rog{\'{e}}rio Feris and
Zsolt Kira},
title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free
Continual Learning},
booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
{CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023},
pages = {11909--11919},
publisher = {{IEEE}},
year = {2023}
}
https://arxiv.org/abs/2211.13218
Adapted from https://github.com/GT-RIPL/CODA-Prompt
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision.models as models
from torch.autograd import Variable
import numpy as np
import copy
# source code from https://github.com/GT-RIPL/CODA-Prompt
class CodaPrompt(nn.Module):
    """CODA-Prompt: a pool of prompt components combined by attention weights.

    For every layer ``e`` in ``self.e_layers`` the module holds three parameters:

    * ``e_p_{e}`` -- prompt components, shape ``(e_pool_size, e_p_length, emb_d)``
    * ``e_k_{e}`` -- component keys, shape ``(e_pool_size, key_dim)``
    * ``e_a_{e}`` -- component attention vectors, shape ``(e_pool_size, key_dim)``

    The pool is split across ``n_tasks``: task ``t`` owns rows
    ``[t * pt, (t + 1) * pt)`` with ``pt = e_pool_size // n_tasks``.

    Args:
        emb_d: width of each prompt token.
        n_tasks: number of tasks the pool is split across.
        prompt_param: ``(pool_size, prompt_length, ortho_mu)``.
        key_dim: width of the query, key and attention vectors.
    """

    def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
        super().__init__()
        self.task_count = 0
        self.emb_d = emb_d
        self.key_d = key_dim
        self.n_tasks = n_tasks
        self._init_smart(emb_d, prompt_param)

        # e-prompt init.
        # For model saving/loading simplicity the full pool is allocated here,
        # but gram_schmidt() only fills the rows of the current task (later rows
        # are zero); process_task_count() fills the next task's rows when that
        # task starts. This is "in the spirit of continual learning", since the
        # number of tasks is not known at the start of the task sequence.
        # The original paper used orthogonal init up front; this modification is
        # fairer in the continual setting and has little effect on performance.
        for e in self.e_layers:
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            k = tensor_prompt(self.e_pool_size, self.key_d)
            a = tensor_prompt(self.e_pool_size, self.key_d)
            setattr(self, f'e_p_{e}', self.gram_schmidt(p))
            setattr(self, f'e_k_{e}', self.gram_schmidt(k))
            setattr(self, f'e_a_{e}', self.gram_schmidt(a))

    def _init_smart(self, emb_d, prompt_param):
        """Unpack ``prompt_param = (pool_size, prompt_length, ortho_mu)``."""
        self.e_pool_size = int(prompt_param[0])
        self.e_p_length = int(prompt_param[1])
        self.e_layers = [0, 1, 2, 3, 4]
        # strength of the orthogonality penalty
        self.ortho_mu = prompt_param[2]

    def _task_range(self):
        """Return ``(s, f)``: the half-open pool-row range owned by the current task."""
        pt = int(self.e_pool_size // self.n_tasks)
        return self.task_count * pt, (self.task_count + 1) * pt

    def process_task_count(self):
        """Advance to the next task and Gram-Schmidt-initialise its pool rows.

        Rows belonging to earlier tasks are copied unchanged. Every ``e_p``,
        ``e_k`` and ``e_a`` parameter is replaced by a new ``nn.Parameter``
        object. The original paper used orthogonal init at the start; this
        per-task re-initialisation is fairer in the continual setting and has
        little effect on performance.
        """
        self.task_count += 1
        for e in self.e_layers:
            K = getattr(self, f'e_k_{e}')
            A = getattr(self, f'e_a_{e}')
            P = getattr(self, f'e_p_{e}')
            setattr(self, f'e_p_{e}', self.gram_schmidt(P))
            setattr(self, f'e_k_{e}', self.gram_schmidt(K))
            setattr(self, f'e_a_{e}', self.gram_schmidt(A))

    # Adapted from:
    # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
    def gram_schmidt(self, vv):
        """Re-initialise the current task's pool rows as an orthonormal set.

        With ``(s, f) = self._task_range()``:

        * rows ``[0, s)`` (earlier tasks) are copied from ``vv`` unchanged;
        * rows ``[s, f)`` (current task) are fresh ``randn`` vectors made
          orthogonal to every preceding row and normalised to unit length;
        * rows ``[f, pool)`` are zero.

        Earlier rows that are numerically zero are skipped during projection.
        The previous implementation "restarted" with a new random vector in
        that case, but the check depends only on the earlier row, so it looped
        forever.

        Assumes ``self.task_count < self.n_tasks``; otherwise ``f`` runs past
        the pool and indexing fails.

        Args:
            vv: ``(pool, d)`` tensor, or ``(pool, l, d)`` tensor that is
                flattened to ``(pool, l * d)`` and restored afterwards.

        Returns:
            A new ``nn.Parameter`` with the same shape as ``vv``.
        """
        def projection(u, v):
            # Projection of v onto u; None when u is (numerically) the zero vector.
            denominator = (u * u).sum()
            if denominator < 1e-8:
                return None
            return (v * u).sum() / denominator * u

        orig_shape = vv.shape
        s, f = self._task_range()
        with torch.no_grad():
            # one pool component per column
            cols = vv.reshape(vv.shape[0], -1).T
            uu = torch.zeros_like(cols)
            if s > 0:
                uu[:, :s] = cols[:, :s]
            for k in range(s, f):
                vk = torch.randn_like(cols[:, k])
                uk = torch.zeros_like(vk)
                for j in range(k):
                    proj = projection(uu[:, j], vk)
                    # a degenerate earlier row spans nothing: nothing to remove
                    if proj is not None:
                        uk = uk + proj
                uu[:, k] = vk - uk
            for k in range(s, f):
                uu[:, k] = uu[:, k] / uu[:, k].norm()
            # back to one component per row, original rank
            uu = uu.T.reshape(orig_shape)
        return torch.nn.Parameter(uu)

    def forward(self, x_querry, l, x_block, train=False, task_id=None):
        """Build the CODA prefix prompt for layer ``l``.

        Args:
            x_querry: query features of shape ``(B, key_dim)``.
            l: layer index; layers not in ``self.e_layers`` receive no prompt.
            x_block: returned unchanged.
            train: when True, components of earlier tasks are detached, so only
                the current task's rows receive gradients. If
                ``ortho_mu > 0``, an orthogonality penalty is returned as the loss.
            task_id: unused; kept for interface parity with the other prompt
                modules.

        Returns:
            ``(p_return, loss, x_block)``. ``p_return`` is ``[Ek, Ev]``, the
            first and second halves of the ``e_p_length`` prompt tokens, each of
            shape ``(B, n, emb_d)``; it is None for layers without prompts.
            ``loss`` is the penalty or 0.
        """
        if l not in self.e_layers:
            return None, 0, x_block

        K = getattr(self, f'e_k_{l}')
        A = getattr(self, f'e_a_{l}')
        p = getattr(self, f'e_p_{l}')
        s, f = self._task_range()

        if not train:
            # inference: every component learned so far
            K, A, p = K[:f], A[:f], p[:f]
        elif self.task_count > 0:
            # freeze earlier tasks' components, train only the current slice
            K = torch.cat((K[:s].detach().clone(), K[s:f]), dim=0)
            A = torch.cat((A[:s].detach().clone(), A[s:f]), dim=0)
            p = torch.cat((p[:s].detach().clone(), p[s:f]), dim=0)
        else:
            K, A, p = K[s:f], A[s:f], p[s:f]

        # attention-weighted queries: (B, d) x (k, d) -> (B, k, d)
        a_querry = torch.einsum('bd,kd->bkd', x_querry, A)
        # cosine similarity of each weighted query with its key: (B, k)
        n_K = nn.functional.normalize(K, dim=1)
        q = nn.functional.normalize(a_querry, dim=2)
        aq_k = torch.einsum('bkd,kd->bk', q, n_K)
        # weighted sum of components: (B, k) x (k, plen, d) -> (B, plen, d)
        P_ = torch.einsum('bk,kld->bld', aq_k, p)

        # split prompt tokens into key / value halves for prefix tuning
        i = self.e_p_length // 2
        Ek = P_[:, :i, :]
        Ev = P_[:, i:, :]

        loss = 0
        if train and self.ortho_mu > 0:
            loss = ortho_penalty(K) * self.ortho_mu
            loss += ortho_penalty(A) * self.ortho_mu
            loss += ortho_penalty(p.reshape(p.shape[0], -1)) * self.ortho_mu

        return [Ek, Ev], loss, x_block
def ortho_penalty(t):
    """Return the mean squared deviation of ``t @ t.T`` from the identity.

    The result is zero exactly when the rows of the 2-D tensor ``t`` are
    orthonormal. The identity matrix is created on ``t``'s own device and with
    its dtype, so the penalty works on CPU as well as GPU tensors instead of
    always allocating on CUDA.
    """
    eye = torch.eye(t.shape[0], device=t.device, dtype=t.dtype)
    return ((t @ t.T - eye) ** 2).mean()
# @article{wang2022dualprompt,
# title={DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning},
# author={Wang, Zifeng and Zhang, Zizhao and Ebrahimi, Sayna and Sun, Ruoxi and Zhang, Han and Lee, Chen-Yu and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and others},
# journal={European Conference on Computer Vision},
# year={2022}
# }
class DualPrompt(nn.Module):
    """DualPrompt: shared G-prompts plus a keyed pool of E-prompts.

    G-prompts ``g_p_{l}`` of shape ``(g_p_length, emb_d)`` live at
    ``self.g_layers``. E-prompts ``e_p_{l}`` of shape
    ``(e_pool_size, e_p_length, emb_d)`` and their keys ``e_k_{l}`` of shape
    ``(e_pool_size, key_dim)`` live at ``self.e_layers``.

    Args:
        emb_d: width of each prompt token.
        n_tasks: number of tasks (stored; not otherwise used in this class).
        prompt_param: ``(e_pool_size, e_p_length, g_p_length)``.
        key_dim: width of the query/key vectors.
    """

    def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
        super().__init__()
        self.task_count = 0
        self.emb_d = emb_d
        self.key_d = key_dim
        self.n_tasks = n_tasks
        self._init_smart(emb_d, prompt_param)

        # g prompt init
        for g in self.g_layers:
            p = tensor_prompt(self.g_p_length, emb_d)
            setattr(self, f'g_p_{g}', p)

        # e prompt init
        for e in self.e_layers:
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            k = tensor_prompt(self.e_pool_size, self.key_d)
            setattr(self, f'e_p_{e}', p)
            setattr(self, f'e_k_{e}', k)

    def _init_smart(self, emb_d, prompt_param):
        """Set prompt locations and unpack ``prompt_param``."""
        self.top_k = 1
        # when True, training selects E-prompts by the given task id
        self.task_id_bootstrap = True

        # prompt locations
        self.g_layers = [0, 1]
        self.e_layers = [2, 3, 4]

        # prompt pool size / lengths
        self.g_p_length = int(prompt_param[2])
        self.e_p_length = int(prompt_param[1])
        self.e_pool_size = int(prompt_param[0])

    def process_task_count(self):
        """Advance the task counter."""
        self.task_count += 1

    def forward(self, x_querry, l, x_block, train=False, task_id=None):
        """Build the DualPrompt prefix for layer ``l``.

        Args:
            x_querry: query features of shape ``(B, key_dim)``.
            l: layer index (0-based).
            x_block: returned unchanged.
            train: training mode. With ``task_id_bootstrap`` the E-prompt of
                ``task_id`` is used; otherwise the top-k keys by cosine
                similarity are used. In both cases a key-pull loss is returned.
            task_id: index of the current task's E-prompt; required when
                ``train`` and ``task_id_bootstrap`` are both True.

        Returns:
            ``(p_return, loss, x_block)``. ``p_return`` is ``[Pk, Pv]``, the
            key/value prefix halves with E-prompts before G-prompts, or None
            when ``l`` has no prompts. ``loss`` is the pull loss in training on
            E-layers, else 0.
        """
        e_valid = l in self.e_layers
        g_valid = l in self.g_layers
        loss = 0

        if e_valid:
            B, _ = x_querry.shape
            K = getattr(self, f'e_k_{l}')  # 0 based indexing here
            p = getattr(self, f'e_p_{l}')  # 0 based indexing here

            # cosine similarity between (detached) queries and keys: (B, pool)
            n_K = nn.functional.normalize(K, dim=1)
            q = nn.functional.normalize(x_querry, dim=1).detach()
            cos_sim = torch.einsum('bj,kj->bk', q, n_K)

            if train and self.task_id_bootstrap:
                # known task id: pull that task's key toward every query
                loss = (1.0 - cos_sim[:, task_id]).sum()
                # (B, 1, e_p_length, emb_d)
                P_ = p[task_id].expand(len(x_querry), -1, -1).unsqueeze(1)
            else:
                k_idx = torch.topk(cos_sim, self.top_k, dim=1).indices  # (B, top_k)
                if train:
                    # each sample's own top-k similarities; cos_sim[:, k_idx]
                    # would pair every sample with every other sample's picks
                    loss = (1.0 - torch.gather(cos_sim, 1, k_idx)).sum()
                P_ = p[k_idx]  # (B, top_k, e_p_length, emb_d)

            # key / value halves of each selected prompt, flattened over top_k
            i = self.e_p_length // 2
            Ek = P_[:, :, :i, :].reshape((B, -1, self.emb_d))
            Ev = P_[:, :, i:, :].reshape((B, -1, self.emb_d))

        if g_valid:
            j = self.g_p_length // 2
            P_ = getattr(self, f'g_p_{l}').expand(len(x_querry), -1, -1)
            Gk = P_[:, :j, :]
            Gv = P_[:, j:, :]

        # combine prompts for prefix tuning
        if e_valid and g_valid:
            p_return = [torch.cat((Ek, Gk), dim=1), torch.cat((Ev, Gv), dim=1)]
        elif e_valid:
            p_return = [Ek, Ev]
        elif g_valid:
            p_return = [Gk, Gv]
        else:
            p_return = None

        # loss is only non-zero in training on E-layers
        return p_return, loss, x_block
# @inproceedings{wang2022learning,
# title={Learning to prompt for continual learning},
# author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
# booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
# pages={139--149},
# year={2022}
# }
class L2P(nn.Module):
    """L2P prompt pool with batch-wise key-based prompt selection.

    Parameters:
        ``prompt``: shape ``(num_layers, pool_size, length, embed_dim)``.
        ``prompt_key``: shape ``(pool_size, embed_dim)``.
    Both are initialised in place by ``prompt_init``.

    Args:
        length: tokens per prompt.
        prompt_init: in-place initialiser applied to both parameters.
        prompt_key: accepted but not read; keys are always created.
        pool_size: number of prompts in the pool.
        top_k: prompts selected per forward pass.
        num_layers: leading layer dimension of ``prompt``.
        embed_dim: token / key width.
    """

    def __init__(self, length, prompt_init=nn.init.uniform_, prompt_key=False,
                 pool_size=None, top_k=None, num_layers=1, embed_dim=768):
        super().__init__()
        self.length = length
        self.prompt_init = prompt_init
        self.pool_size = pool_size
        self.top_k = top_k
        self.num_layers = num_layers
        self.embed_dim = embed_dim

        prompt_shape = (num_layers, pool_size, length, embed_dim)
        key_shape = (pool_size, embed_dim)
        self.prompt = nn.Parameter(torch.empty(prompt_shape))
        self.prompt_key = nn.Parameter(torch.empty(key_shape))
        for tensor in (self.prompt, self.prompt_key):
            prompt_init(tensor)

    def forward(self, x_embed, cls_features=None):
        """Select prompts for a batch.

        Args:
            x_embed: ``(B, N, embed_dim)``; only its shape is used.
            cls_features: ``(B, embed_dim)`` query features. They are required
                despite the None default.

        Returns:
            ``(batched_prompt, reduce_sim)``. ``batched_prompt`` has shape
            ``(num_layers, B, top_k * length, embed_dim)``. ``reduce_sim`` is
            the summed query/key similarity of the chosen keys divided by B.
        """
        batch, _, width = x_embed.shape
        assert width == self.embed_dim

        keys = F.normalize(self.prompt_key, p=2, dim=-1, eps=1e-12)
        queries = F.normalize(cls_features, p=2, dim=-1, eps=1e-12)
        per_sample_idx = torch.topk(queries @ keys.T, self.top_k, dim=1)[1]

        # Batch-wise vote: every sample nominates its top_k keys, and the
        # top_k most-nominated prompts are given to the whole batch.
        ids, counts = torch.unique(per_sample_idx, return_counts=True, sorted=True)
        missing = self.pool_size - len(ids)
        # pad to pool_size, mirroring jnp.unique(size=...); padding counts are 0
        ids = F.pad(ids, (0, missing), "constant", ids[0])
        counts = F.pad(counts, (0, missing), "constant", 0)
        winners = ids[torch.topk(counts, self.top_k)[1]]
        selected = winners.unsqueeze(0).repeat(batch, 1)  # (B, top_k)

        # (num_layers, B, top_k, length, D) -> (num_layers, B, top_k * length, D)
        raw = self.prompt[:, selected]
        batched_prompt = raw.reshape(raw.shape[0], raw.shape[1], -1, raw.shape[-1])

        # pull constraint: similarity between the chosen keys and the queries
        chosen_keys = keys[selected]
        reduce_sim = (chosen_keys * queries.unsqueeze(1)).sum() / batch
        return batched_prompt, reduce_sim
# note - ortho init has not been found to help l2p/dual prompt
def tensor_prompt(a, b, c=None, ortho=False):
    """Create a trainable float32 parameter of shape ``(a, b)`` or ``(a, b, c)``.

    The values are initialised in place with ``nn.init.orthogonal_`` when
    ``ortho`` is True, and with ``nn.init.uniform_`` (default range [0, 1))
    otherwise.
    """
    dims = (a, b) if c is None else (a, b, c)
    param = torch.nn.Parameter(torch.FloatTensor(*dims), requires_grad=True)
    initializer = nn.init.orthogonal_ if ortho else nn.init.uniform_
    initializer(param)
    return param
# @inproceedings{10.24963/ijcai.2024/456,
# author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi},
# title = {Dynamically anchored prompting for task-imbalanced continual learning},
# booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence},
# year = {2025},
# }
class DAP(nn.Module):
    """Prompt holder for Dynamically Anchored Prompting (DAP).

    When ``prompt_pool`` is True and ``prompt_init`` is ``'zero'`` or
    ``'uniform'``, the module creates:

    * ``prompt``: ``(pool_size, length, embed_dim)``
    * ``taskprompt``: ``tasklength`` parameters of shape ``(top_k, length, embed_dim)``
    * ``generalprompt``: ``(top_k, length, embed_dim)``

    Keys are learnable (``(pool_size, embed_dim)``) when ``prompt_key`` is
    truthy. Otherwise ``prompt_key`` is the mean of ``prompt`` over its token
    axis, which requires the pool to exist.
    """

    def __init__(self, length=5, embed_dim=768, embedding_key='mean', prompt_init='uniform', prompt_pool=False,
                 prompt_key=False, pool_size=None, top_k=None, batchwise_prompt=False, prompt_key_init='uniform',
                 tasklength=10):
        super().__init__()
        # configuration, stored as given
        self.length = length
        self.embed_dim = embed_dim
        self.prompt_pool = prompt_pool
        self.embedding_key = embedding_key
        self.prompt_init = prompt_init
        self.prompt_key = prompt_key
        self.pool_size = pool_size
        self.top_k = top_k
        self.batchwise_prompt = batchwise_prompt
        self.tasklength = tasklength

        if self.prompt_pool and prompt_init in ('zero', 'uniform'):
            pool_shape = (pool_size, length, embed_dim)
            slot_shape = (top_k, length, embed_dim)
            randomize = prompt_init == 'uniform'
            # 'uniform' allocates with randn (advancing the RNG) and then
            # overwrites the values with U(-1, 1); 'zero' keeps zeros.
            alloc = torch.randn if randomize else torch.zeros

            def finish(param):
                if randomize:
                    nn.init.uniform_(param, -1, 1)
                return param

            self.prompt = finish(nn.Parameter(alloc(pool_shape)))
            # one prompt per task id; always allocated as zeros
            self.taskprompt = nn.ParameterList(
                [nn.Parameter(torch.zeros(slot_shape)) for _ in range(tasklength)])
            for task_prompt in self.taskprompt:
                finish(task_prompt)
            self.generalprompt = finish(nn.Parameter(alloc(slot_shape)))

        if not prompt_key:
            # no learnable keys: each pool entry's key is its mean token
            self.prompt_key = torch.mean(self.prompt, dim=1)
        elif prompt_key_init == 'zero':
            self.prompt_key = nn.Parameter(torch.zeros((pool_size, embed_dim)))
        elif prompt_key_init == 'uniform':
            self.prompt_key = nn.Parameter(torch.randn((pool_size, embed_dim)))
            nn.init.uniform_(self.prompt_key, -1, 1)

    def l2_normalize(self, x, dim=None, epsilon=1e-12):
        """Scale ``x`` to unit L2 norm along ``dim``; the squared norm is floored at ``epsilon``."""
        sq_norm = torch.sum(x ** 2, dim=dim, keepdim=True)
        floor = torch.tensor(epsilon, device=x.device)
        return x * torch.rsqrt(torch.maximum(sq_norm, floor))

    def forward(self, x_embed, prompt_mask=None, cls_features=None, taskid=None):
        """Prepend the task prompt and the general prompt to ``x_embed``.

        ``prompt_mask`` and ``cls_features`` are accepted but not used.
        ``taskid`` indexes ``self.taskprompt``.

        Returns:
            dict with ``'total_prompt_len'`` / ``'prompted_embedding'`` (task
            prompt prepended along dim 1) and ``'gen_total_prompt_len'`` /
            ``'gen_prompted_embedding'`` (general prompt prepended along dim 1).
        """
        batch = x_embed.shape[0]
        slots, length, width = self.taskprompt[taskid].shape

        def broadcast(prompt):
            # (top_k, length, C) -> (B, top_k * length, C), shared by the batch
            return prompt.reshape(slots * length, width).unsqueeze(0).expand(batch, -1, -1)

        task_prompt = broadcast(self.taskprompt[taskid])
        general_prompt = broadcast(self.generalprompt)
        return {
            'total_prompt_len': task_prompt.shape[1],
            'prompted_embedding': torch.cat([task_prompt, x_embed], dim=1),
            'gen_total_prompt_len': general_prompt.shape[1],
            'gen_prompted_embedding': torch.cat([general_prompt, x_embed], dim=1),
        }