| | |
| | """ |
| | @inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23, |
| | author = {James Seale Smith and |
| | Leonid Karlinsky and |
| | Vyshnavi Gutta and |
| | Paola Cascante{-}Bonilla and |
| | Donghyun Kim and |
| | Assaf Arbelle and |
| | Rameswar Panda and |
| | Rog{\'{e}}rio Feris and |
| | Zsolt Kira}, |
| | title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free |
| | Continual Learning}, |
| | booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, |
| | {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023}, |
| | pages = {11909--11919}, |
| | publisher = {{IEEE}}, |
| | year = {2023} |
| | } |
| | |
| | https://arxiv.org/abs/2211.13218 |
| | |
| | Adapted from https://github.com/GT-RIPL/CODA-Prompt |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.nn.init as init |
| | import torchvision.models as models |
| | from torch.autograd import Variable |
| | import numpy as np |
| | import copy |
| |
|
| | |
class CodaPrompt(nn.Module):
    """CODA-Prompt pool (Smith et al., CVPR 2023): attention-decomposed prompts.

    For each prompted layer ``l`` this module keeps three parameter banks:
    prompt components ``e_p_<l>``, keys ``e_k_<l>`` and attention vectors
    ``e_a_<l>``.  A prompt is composed as a similarity-weighted sum of the
    components.  Each task owns a disjoint, contiguous slice of the pool;
    earlier tasks' slices are frozen (detached) during later training.
    """

    def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
        # emb_d:        embedding dim of the prompts that are emitted
        # n_tasks:      number of sequential tasks sharing the pool
        # prompt_param: (pool size, prompt length, ortho penalty weight)
        # key_dim:      dim of the keys/attention vectors (matches the query)
        super().__init__()
        self.task_count = 0   # current task index; advanced via process_task_count()
        self.emb_d = emb_d
        self.key_d = key_dim
        self.n_tasks = n_tasks
        self._init_smart(emb_d, prompt_param)

        # Create one (prompt, key, attention) triple per prompted layer and
        # orthogonalize the slice belonging to task 0 against itself.
        for e in self.e_layers:
            e_l = self.e_p_length
            p = tensor_prompt(self.e_pool_size, e_l, emb_d)
            k = tensor_prompt(self.e_pool_size, self.key_d)
            a = tensor_prompt(self.e_pool_size, self.key_d)
            p = self.gram_schmidt(p)
            k = self.gram_schmidt(k)
            a = self.gram_schmidt(a)
            setattr(self, f'e_p_{e}',p)
            setattr(self, f'e_k_{e}',k)
            setattr(self, f'e_a_{e}',a)

    def _init_smart(self, emb_d, prompt_param):
        """Unpack prompt hyperparameters from ``prompt_param``."""
        # prompt basic param
        self.e_pool_size = int(prompt_param[0])   # total number of components in the pool
        self.e_p_length = int(prompt_param[1])    # prompt length; split in half into key/value parts
        self.e_layers = [0,1,2,3,4]               # transformer layers that receive prompts

        # strength of ortho penalty
        self.ortho_mu = prompt_param[2]

    def process_task_count(self):
        """Advance to the next task and re-orthogonalize the new task's slice.

        Re-running Gram-Schmidt initializes the incoming task's components to
        be orthogonal to the (kept, already-trained) components of all
        previous tasks.
        """
        self.task_count += 1

        for e in self.e_layers:
            K = getattr(self,f'e_k_{e}')
            A = getattr(self,f'e_a_{e}')
            P = getattr(self,f'e_p_{e}')
            k = self.gram_schmidt(K)
            a = self.gram_schmidt(A)
            p = self.gram_schmidt(P)
            setattr(self, f'e_p_{e}',p)
            setattr(self, f'e_k_{e}',k)
            setattr(self, f'e_a_{e}',a)

    def gram_schmidt(self, vv):
        """Return ``vv`` with the current task's rows replaced by random
        vectors orthonormalized against all earlier rows (Gram-Schmidt).

        Rows before the current task's slice are copied through unchanged;
        only rows ``[task_count*pt, (task_count+1)*pt)`` are regenerated.
        3-D inputs (pool, length, dim) are flattened to 2-D and restored.
        Always returns a fresh ``nn.Parameter``.
        """

        def projection(u, v):
            # Project v onto u; returns None when u is (numerically) zero so
            # the caller can restart with a new random vector.
            denominator = (u * u).sum()

            if denominator < 1e-8:
                return None
            else:
                return (v * u).sum() / denominator * u

        # Flatten a 3-D prompt bank to (pool, length*dim) for the procedure.
        # NOTE(review): 'shape_2d' actually stores the original 3-D shape.
        is_3d = len(vv.shape) == 3
        if is_3d:
            shape_2d = copy.deepcopy(vv.shape)
            vv = vv.view(vv.shape[0],-1)

        # Work column-wise: each pool entry becomes a column.
        vv = vv.T

        nk = vv.size(1)
        uu = torch.zeros_like(vv, device=vv.device)

        # Slice [s, f) of the pool belongs to the current task.
        pt = int(self.e_pool_size / (self.n_tasks))
        s = int(self.task_count * pt)
        f = int((self.task_count + 1) * pt)
        if s > 0:
            # Keep previous tasks' vectors exactly as they are.
            uu[:, 0:s] = vv[:, 0:s].clone()
        for k in range(s, f):
            redo = True
            while redo:
                redo = False
                # Fresh random vector, then subtract its projections onto all
                # previously accepted columns; restart on degenerate columns.
                vk = torch.randn_like(vv[:,k]).to(vv.device)
                uk = 0
                for j in range(0, k):
                    if not redo:
                        uj = uu[:, j].clone()
                        proj = projection(uj, vk)
                        if proj is None:
                            redo = True
                            print('restarting!!!')
                        else:
                            uk = uk + proj
                if not redo: uu[:, k] = vk - uk
        # Normalize only the newly generated columns.
        for k in range(s, f):
            uk = uu[:, k].clone()
            uu[:, k] = uk / (uk.norm())

        uu = uu.T

        if is_3d:
            uu = uu.view(shape_2d)

        return torch.nn.Parameter(uu)

    def forward(self, x_querry, l, x_block, train=False, task_id=None):
        """Compose the prompt for layer ``l`` from the pool.

        x_querry: (B, key_d) query features used to weight the components.
        l:        layer index; layers outside ``self.e_layers`` get no prompt.
        x_block:  passed through untouched.
        Returns ([Ek, Ev], loss, x_block) for prompted layers, else
        (None, 0, x_block).  ``loss`` is the orthogonality penalty during
        training (0 otherwise).
        """

        e_valid = False
        if l in self.e_layers:
            e_valid = True
            B, C = x_querry.shape

            K = getattr(self,f'e_k_{l}')
            A = getattr(self,f'e_a_{l}')
            p = getattr(self,f'e_p_{l}')
            pt = int(self.e_pool_size / (self.n_tasks))
            s = int(self.task_count * pt)
            f = int((self.task_count + 1) * pt)

            # Training: freeze earlier tasks' slices (detach) but keep them in
            # the composition; eval: use everything learned so far [0, f).
            if train:
                if self.task_count > 0:
                    K = torch.cat((K[:s].detach().clone(),K[s:f]), dim=0)
                    A = torch.cat((A[:s].detach().clone(),A[s:f]), dim=0)
                    p = torch.cat((p[:s].detach().clone(),p[s:f]), dim=0)
                else:
                    K = K[s:f]
                    A = A[s:f]
                    p = p[s:f]
            else:
                K = K[0:f]
                A = A[0:f]
                p = p[0:f]

            # Attend the query with each component's attention vector, then
            # score against the keys: aq_k[b, k] = cos-sim weight of component k.
            a_querry = torch.einsum('bd,kd->bkd', x_querry, A)

            n_K = nn.functional.normalize(K, dim=1)
            q = nn.functional.normalize(a_querry, dim=2)
            aq_k = torch.einsum('bkd,kd->bk', q, n_K)

            # Weighted sum of prompt components -> (B, e_p_length, emb_d).
            P_ = torch.einsum('bk,kld->bld', aq_k, p)

            # Split the prompt in half: first half prepended to keys, second to values.
            i = int(self.e_p_length/2)
            Ek = P_[:,:i,:]
            Ev = P_[:,i:,:]

            # Orthogonality penalty keeps components/keys/attentions decorrelated.
            if train and self.ortho_mu > 0:
                loss = ortho_penalty(K) * self.ortho_mu
                loss += ortho_penalty(A) * self.ortho_mu
                loss += ortho_penalty(p.view(p.shape[0], -1)) * self.ortho_mu
            else:
                loss = 0
        else:
            loss = 0

        if e_valid:
            p_return = [Ek, Ev]
        else:
            p_return = None

        return p_return, loss, x_block
| |
|
def ortho_penalty(t):
    """Mean squared deviation of ``t @ t.T`` from the identity.

    Encourages the rows of ``t`` to form an orthonormal set.  Returns a
    scalar tensor living on the same device/dtype as ``t``.
    """
    # Build the identity on t's device instead of hard-coding .cuda(): the
    # original crashed on CPU-only runs and used the wrong device index in
    # multi-GPU setups.
    eye = torch.eye(t.shape[0], device=t.device, dtype=t.dtype)
    return ((t @ t.T - eye) ** 2).mean()
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
class DualPrompt(nn.Module):
    """DualPrompt-style module: shared 'general' prompts (g) on early layers
    plus a pool of task-specific 'expert' prompts (e) on later layers,
    selected by cosine similarity between the query and learned keys.
    """

    def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768):
        # emb_d:        embedding dim of the emitted prompts
        # n_tasks:      number of sequential tasks
        # prompt_param: (e pool size, e prompt length, g prompt length)
        # key_dim:      dim of the expert-prompt keys (matches the query)
        super().__init__()
        self.task_count = 0   # current task index; advanced via process_task_count()
        self.emb_d = emb_d
        self.key_d = key_dim
        self.n_tasks = n_tasks
        self._init_smart(emb_d, prompt_param)

        # One task-shared general prompt per g-layer.
        for g in self.g_layers:
            p = tensor_prompt(self.g_p_length, emb_d)
            setattr(self, f'g_p_{g}',p)

        # One (prompt pool, key pool) pair per e-layer.
        for e in self.e_layers:
            p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d)
            k = tensor_prompt(self.e_pool_size, self.key_d)
            setattr(self, f'e_p_{e}',p)
            setattr(self, f'e_k_{e}',k)

    def _init_smart(self, emb_d, prompt_param):
        """Unpack prompt hyperparameters from ``prompt_param``."""
        self.top_k = 1                 # prompts retrieved per sample when not bootstrapping
        self.task_id_bootstrap = True  # during training, index the pool by task_id directly

        # g-layers get the shared prompt, e-layers the task-specific pool
        # (the two lists are disjoint, so a layer is never both).
        self.g_layers = [0,1]
        self.e_layers = [2,3,4]

        self.g_p_length = int(prompt_param[2])   # general prompt length
        self.e_p_length = int(prompt_param[1])   # expert prompt length
        self.e_pool_size = int(prompt_param[0])  # number of expert prompts in the pool

    def process_task_count(self):
        """Advance the task counter (no re-initialization needed here)."""
        self.task_count += 1

    def forward(self, x_querry, l, x_block, train=False, task_id=None):
        """Return the prompt pair for layer ``l``.

        x_querry: (B, key_d) query features for key matching.
        Returns ([Pk, Pv], loss, x_block); loss is the key-matching loss when
        training on an e-layer, 0 otherwise.  ``task_id`` is required when
        training with ``task_id_bootstrap``.
        """

        # Expert prompts: select from the pool by key similarity.
        e_valid = False
        if l in self.e_layers:
            e_valid = True
            B, C = x_querry.shape
            K = getattr(self,f'e_k_{l}')
            p = getattr(self,f'e_p_{l}')

            # Cosine similarity between (detached) query and normalized keys.
            n_K = nn.functional.normalize(K, dim=1)
            q = nn.functional.normalize(x_querry, dim=1).detach()
            cos_sim = torch.einsum('bj,kj->bk', q, n_K)

            if train:
                # Pull the matching key toward the query (1 - cos similarity).
                if self.task_id_bootstrap:
                    # Ground-truth task id picks the prompt; shared across the batch.
                    loss = (1.0 - cos_sim[:,task_id]).sum()
                    P_ = p[task_id].expand(len(x_querry),-1,-1)
                else:
                    top_k = torch.topk(cos_sim, self.top_k, dim=1)
                    k_idx = top_k.indices
                    loss = (1.0 - cos_sim[:,k_idx]).sum()
                    P_ = p[k_idx]
            else:
                # Inference: retrieve top-k prompts per sample; no loss.
                top_k = torch.topk(cos_sim, self.top_k, dim=1)
                k_idx = top_k.indices
                P_ = p[k_idx]

            # Split each prompt in half into key-part (Ek) and value-part (Ev).
            # Bootstrapped P_ is (B, length, d); the top-k path adds a k axis,
            # hence the extra leading dimension in the else branch.
            if train and self.task_id_bootstrap:
                i = int(self.e_p_length/2)
                Ek = P_[:,:i,:].reshape((B,-1,self.emb_d))
                Ev = P_[:,i:,:].reshape((B,-1,self.emb_d))
            else:
                i = int(self.e_p_length/2)
                Ek = P_[:,:,:i,:].reshape((B,-1,self.emb_d))
                Ev = P_[:,:,i:,:].reshape((B,-1,self.emb_d))

        # General prompts: the same learned prompt for every sample.
        g_valid = False
        if l in self.g_layers:
            g_valid = True
            j = int(self.g_p_length/2)
            p = getattr(self,f'g_p_{l}')
            P_ = p.expand(len(x_querry),-1,-1)
            Gk = P_[:,:j,:]
            Gv = P_[:,j:,:]

        # Combine whatever is valid for this layer.  With the layer lists
        # above, e_valid and g_valid are never both true; the first branch is
        # kept for configurations that overlap them.
        if e_valid and g_valid:
            Pk = torch.cat((Ek, Gk), dim=1)
            Pv = torch.cat((Ev, Gv), dim=1)
            p_return = [Pk, Pv]
        elif e_valid:
            p_return = [Ek, Ev]
        elif g_valid:
            p_return = [Gk, Gv]
            loss = 0
        else:
            p_return = None
            loss = 0

        # Evaluation never reports a loss.
        if train:
            return p_return, loss, x_block
        else:
            return p_return, 0, x_block
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
class L2P(nn.Module):
    """Learning-to-Prompt pool.

    Holds a pool of prompts plus one key vector per pool slot.  At forward
    time the batch votes for the top-k prompts by key/query cosine
    similarity; the majority winners are shared by the whole batch and
    returned together with a key-query pull score.
    """

    def __init__(self, length, prompt_init=nn.init.uniform_, prompt_key=False,
                 pool_size=None, top_k=None, num_layers=1, embed_dim=768):
        super().__init__()
        self.length = length
        self.prompt_init = prompt_init
        self.pool_size = pool_size
        self.top_k = top_k
        self.num_layers = num_layers
        self.embed_dim = embed_dim

        # Prompt pool and matching keys, both filled in-place by prompt_init.
        pool_shape = (self.num_layers, self.pool_size, self.length, embed_dim)
        self.prompt = nn.Parameter(torch.empty(pool_shape))
        self.prompt_key = nn.Parameter(torch.empty((self.pool_size, embed_dim)))
        self.prompt_init(self.prompt)
        self.prompt_init(self.prompt_key)

    def forward(self, x_embed, cls_features=None):
        """Select the batch-majority prompts.

        x_embed:      (B, N, embed_dim) token embeddings (shape check only).
        cls_features: (B, embed_dim) query features matched against the keys.
        Returns (batched_prompt, reduce_sim) with batched_prompt of shape
        (num_layers, B, top_k * length, embed_dim).
        """
        B, N, C = x_embed.shape
        assert C == self.embed_dim

        # Cosine similarity of each query against every prompt key.
        keys_n = F.normalize(self.prompt_key, p=2, dim=-1, eps=1e-12)
        query_n = F.normalize(cls_features, p=2, dim=-1, eps=1e-12)
        scores = query_n @ keys_n.T
        picked = torch.topk(scores, self.top_k, dim=1).indices

        # Batch-wise majority vote so every sample uses the same prompt ids.
        ids, votes = torch.unique(picked, return_counts=True, sorted=True)
        ids = F.pad(ids, (0, self.pool_size - len(ids)), "constant", ids[0])
        votes = F.pad(votes, (0, self.pool_size - len(votes)), "constant", 0)
        winners = ids[torch.topk(votes, self.top_k).indices]
        idx = winners.unsqueeze(0).repeat(B, 1)

        # Gather winning prompts and fold (top_k, length) into one axis.
        raw = self.prompt[:, idx]
        batched_prompt = raw.reshape(raw.shape[0], raw.shape[1], -1, raw.shape[-1])

        # Pull score: per-batch mean of agreement between chosen keys and queries.
        matched_keys = keys_n[idx]
        reduce_sim = torch.sum(matched_keys * query_n.unsqueeze(1)) / B

        return batched_prompt, reduce_sim
| |
|
| | |
def tensor_prompt(a, b, c=None, ortho=False):
    """Create a trainable float32 prompt parameter of shape (a, b) or (a, b, c).

    Initialized with ``nn.init.orthogonal_`` when ``ortho`` is True, otherwise
    uniformly in [0, 1).  Always returns an ``nn.Parameter`` with gradients
    enabled.
    """
    # torch.empty replaces the legacy torch.FloatTensor constructor; both
    # allocate uninitialized float32 storage that the init call below fills.
    shape = (a, b) if c is None else (a, b, c)
    p = torch.nn.Parameter(torch.empty(shape), requires_grad=True)
    if ortho:
        nn.init.orthogonal_(p)
    else:
        nn.init.uniform_(p)
    return p
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
class DAP(nn.Module):
    """Prompt module with a shared pool, per-task prompts, and a general prompt.

    ``forward`` prepends the prompt of the requested task and the general
    prompt to the input embeddings and returns both prompted sequences.
    """

    def __init__(self, length=5, embed_dim=768, embedding_key='mean', prompt_init='uniform', prompt_pool=False,
                 prompt_key=False, pool_size=None, top_k=None, batchwise_prompt=False, prompt_key_init='uniform',tasklength=10):
        super().__init__()

        self.length = length
        self.embed_dim = embed_dim
        self.prompt_pool = prompt_pool
        self.embedding_key = embedding_key
        self.prompt_init = prompt_init
        self.prompt_key = prompt_key
        self.pool_size = pool_size
        self.top_k = top_k
        self.batchwise_prompt = batchwise_prompt
        self.tasklength = tasklength

        if self.prompt_pool:
            pool_shape = (pool_size, length, embed_dim)
            general_shape = (top_k, length, embed_dim)

            if prompt_init == 'zero':
                # All prompt banks start at zero.
                self.prompt = nn.Parameter(torch.zeros(pool_shape))
                self.taskprompt = nn.ParameterList(
                    [nn.Parameter(torch.zeros(top_k, length, embed_dim)) for _ in range(tasklength)])
                self.generalprompt = nn.Parameter(torch.zeros(general_shape))
            elif prompt_init == 'uniform':
                # All prompt banks drawn uniformly from [-1, 1].
                self.prompt = nn.Parameter(torch.randn(pool_shape))
                nn.init.uniform_(self.prompt, -1, 1)
                self.taskprompt = nn.ParameterList(
                    [nn.Parameter(torch.zeros(top_k, length, embed_dim)) for _ in range(tasklength)])
                for tp in self.taskprompt:
                    nn.init.uniform_(tp, -1, 1)
                self.generalprompt = nn.Parameter(torch.randn(general_shape))
                nn.init.uniform_(self.generalprompt, -1, 1)

        if prompt_key:
            # Dedicated learnable key per pool slot.
            key_shape = (pool_size, embed_dim)
            if prompt_key_init == 'zero':
                self.prompt_key = nn.Parameter(torch.zeros(key_shape))
            elif prompt_key_init == 'uniform':
                self.prompt_key = nn.Parameter(torch.randn(key_shape))
                nn.init.uniform_(self.prompt_key, -1, 1)
        else:
            # Fall back to the mean of each pooled prompt as its (non-learnable) key.
            self.prompt_key = torch.mean(self.prompt, dim=1)

    def l2_normalize(self, x, dim=None, epsilon=1e-12):
        """Normalizes a given vector or matrix."""
        sq_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
        inv_norm = torch.rsqrt(torch.maximum(sq_sum, torch.tensor(epsilon, device=x.device)))
        return x * inv_norm

    def forward(self, x_embed, prompt_mask=None, cls_features=None,taskid=None):
        """Prepend the task-``taskid`` prompt and the general prompt to x_embed.

        Returns a dict with the prompted embeddings and prompt lengths for
        both the task-specific and the general prompt.
        """
        out = dict()
        batch = x_embed.shape[0]

        # Flatten (top_k, length, dim) -> (top_k*length, dim) and broadcast
        # the same prompt over the batch.
        top_k, length, c = self.taskprompt[taskid].shape
        task_flat = self.taskprompt[taskid].reshape(top_k * length, c)
        task_prompt = task_flat.unsqueeze(0).expand(batch, -1, -1)

        general_flat = self.generalprompt.reshape(top_k * length, c)
        general_prompt = general_flat.unsqueeze(0).expand(batch, -1, -1)

        out['total_prompt_len'] = task_prompt.shape[1]
        out['prompted_embedding'] = torch.cat([task_prompt, x_embed], dim=1)

        out['gen_total_prompt_len'] = general_prompt.shape[1]
        out['gen_prompted_embedding'] = torch.cat([general_prompt, x_embed], dim=1)
        return out