| | |
| | """ |
| | @inproceedings{arXiv:2404.00228v3, |
| | title = {InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning}, |
| | author = {Yan-Shuo Liang and |
| | Wu-Jun Li}, |
| | booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, {CVPR} 2024, Seattle, Washington}, |
| | publisher = {Computer Vision Foundation / {IEEE}}, |
| | year = {2024}, |
| | url = {https://arxiv.org/abs/2404.00228v3}, |
| | } |
| | https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html |
| | |
| | Adapted from https://github.com/liangyanshuo/InfLoRA?utm_source=catalyzex.com |
| | """ |
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import optim |
| | from torch.nn import functional as F |
| | from torch.nn.parameter import Parameter |
| | from torch.utils.data import DataLoader |
| |
|
| | import logging |
| | import numpy as np |
| | from tqdm import tqdm |
| | from sklearn.cluster import KMeans |
| |
|
| | from .backbone.vit_inflora import Attention_LoRA |
| | from copy import deepcopy |
| | import math |
| | from .finetune import Finetune |
| |
|
| |
|
class InfLoRA(Finetune):
    """InfLoRA: Interference-Free Low-Rank Adaptation for continual learning
    (CVPR 2024), adapted from https://github.com/liangyanshuo/InfLoRA.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Set up the LoRA-augmented backbone and per-task bookkeeping.

        Args:
            backbone: network whose attention blocks are ``Attention_LoRA`` modules.
            feat_dim: feature dimension (forwarded to the base class).
            num_class: total number of classes over all tasks.
            **kwargs: must contain ``inc_cls_num``, ``device``, ``lame``,
                ``lamb`` and ``total_sessions``.
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self._network = backbone

        # Initialise every LoRA branch in the backbone's attention blocks.
        for module in self._network.modules():
            if isinstance(module, Attention_LoRA):
                module.init_param()

        self.num_class = num_class
        self._total_classes = 0   # classes seen up to and including the current task
        self._known_classes = 0   # classes seen before the current task
        self._cur_task = -1       # advanced to 0 by the first before_task()

        self.inc_cls_num = kwargs["inc_cls_num"]
        self.device = kwargs["device"]

        # DualGPM state: one feature basis and one projection mode per layer.
        self.feature_list = []
        self.project_type = []
        self.lame = kwargs["lame"]
        self.lamb = kwargs["lamb"]
        self.total_sessions = kwargs["total_sessions"]
| | |
| | def observe(self, data): |
| | ''' |
| | Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss. |
| | |
| | Code Reference: |
| | https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py |
| | ''' |
| | x, y = data['image'], data['label'] |
| | x = x.to(self.device) |
| | y = y.to(self.device) |
| | |
| | y = y-self._known_classes |
| | logits = self._network(x)['logits'] |
| | loss = F.cross_entropy(logits, y) |
| | _, preds = torch.max(logits, dim=1) |
| | correct = preds.eq(y.expand_as(preds)).cpu().sum() |
| | total = len(y) |
| | acc = correct/total |
| | acc = acc.item() |
| | return preds, acc, loss |
| | |
| | def inference(self, data): |
| | ''' |
| | It is called in the inference phase to input a batch of test samples and return the classification result and accuracy. |
| | Calling the interface function of _network returns the value batchsize*_total_classes. |
| | |
| | Code Reference: |
| | https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py |
| | ''' |
| | x, y = data['image'], data['label'] |
| | x = x.to(self.device) |
| | y = y.to(self.device) |
| | logits = self._network.interface(x) |
| | _, preds = torch.max(logits, dim=1) |
| | correct = preds.eq(y.expand_as(preds)).cpu().sum() |
| | total = len(y) |
| | acc = correct/total |
| | acc = acc.item() |
| | return preds, acc |
| | |
| | def before_task(self, task_idx, buffer, train_loader, test_loaders): |
| | ''' |
| | It is called before the training of each task to update the parameters, select the branch for training, and update the lora_A matrix of the corresponding branch |
| | |
| | Code Reference: |
| | https://github.com/gydpku/OCM/blob/main/test_cifar10.py |
| | ''' |
| |
|
| | |
| | self._known_classes = self._total_classes |
| | self._cur_task += 1 |
| | self._total_classes = self._known_classes + self.inc_cls_num |
| | self._network.update_fc(self._total_classes) |
| |
|
| | self._network.to(self.device) |
| | |
| | |
| | for name, param in self._network.named_parameters(): |
| | param.requires_grad_(False) |
| | try: |
| | if "classifier_pool" + "." + str(self._network.module.numtask - 1) in name: |
| | param.requires_grad_(True) |
| | if "lora_B_k" + "." + str(self._network.module.numtask - 1) in name: |
| | param.requires_grad_(True) |
| | if "lora_B_v" + "." + str(self._network.module.numtask - 1) in name: |
| | param.requires_grad_(True) |
| | except: |
| | if "classifier_pool" + "." + str(self._network.numtask - 1) in name: |
| | param.requires_grad_(True) |
| | if "lora_B_k" + "." + str(self._network.numtask - 1) in name: |
| | param.requires_grad_(True) |
| | if "lora_B_v" + "." + str(self._network.numtask - 1) in name: |
| | param.requires_grad_(True) |
| |
|
| | |
| | enabled = set() |
| | for name, param in self._network.named_parameters(): |
| | if param.requires_grad: |
| | enabled.add(name) |
| |
|
| | with torch.no_grad(): |
| | |
| | for batch_idx, batch in enumerate(train_loader): |
| | inputs = batch["image"] |
| | targets = batch["label"] |
| | inputs, targets = inputs.to(self.device), targets.to(self.device) |
| | inputs=F.interpolate(inputs, size=224, mode='bilinear', align_corners=False) |
| | self._network(inputs, get_cur_feat=True) |
| | |
| | if self._cur_task == 0: |
| | |
| | for module in self._network.modules(): |
| | if isinstance(module, Attention_LoRA): |
| | cur_matrix = module.cur_matrix |
| | U, S, V = torch.linalg.svd(cur_matrix) |
| | module.lora_A_k[self._cur_task].weight.data.copy_(U[:,:module.rank].T/math.sqrt(3)) |
| | module.lora_A_v[self._cur_task].weight.data.copy_(U[:,:module.rank].T/math.sqrt(3)) |
| | module.cur_matrix.zero_() |
| | module.n_cur_matrix = 0 |
| | else: |
| | |
| | kk = 0 |
| | for module in self._network.modules(): |
| | if isinstance(module, Attention_LoRA): |
| | cur_matrix = module.cur_matrix |
| | if self.project_type[kk] == 'remove': |
| | cur_matrix = cur_matrix - torch.mm(self.feature_mat[kk],cur_matrix) |
| | else: |
| | assert self.project_type[kk] == 'retain' |
| | cur_matrix = torch.mm(self.feature_mat[kk],cur_matrix) |
| | cU, cS, cV = torch.linalg.svd(cur_matrix, full_matrices=False) |
| | module.lora_A_k[self._cur_task].weight.data.copy_(cU[:,:module.rank].T/math.sqrt(3)) |
| | module.lora_A_v[self._cur_task].weight.data.copy_(cU[:,:module.rank].T/math.sqrt(3)) |
| | module.cur_matrix.zero_() |
| | module.n_cur_matrix = 0 |
| | kk += 1 |
| | |
| | def after_task(self, task_idx, buffer, train_loader, test_loaders): |
| | ''' |
| | Called after each task starts training, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task |
| | ''' |
| | with torch.no_grad(): |
| | |
| | for batch_idx, batch in enumerate(train_loader): |
| | inputs = batch["image"] |
| | targets = batch["label"] |
| | inputs, targets = inputs.to(self.device), targets.to(self.device) |
| | inputs=F.interpolate(inputs, size=224, mode='bilinear', align_corners=False) |
| | self._network(inputs, get_cur_feat=True) |
| | |
| | mat_list = [] |
| | for module in self._network.modules(): |
| | if isinstance(module, Attention_LoRA): |
| | mat_list.append(deepcopy(module.cur_matrix)) |
| | module.cur_matrix.zero_() |
| | module.n_cur_matrix = 0 |
| | self.update_DualGPM(mat_list) |
| | self.feature_mat = [] |
| | for p in range(len(self.feature_list)): |
| | Uf=torch.Tensor(np.dot(self.feature_list[p],self.feature_list[p].transpose())) |
| | print('Layer {} - Projection Matrix shape: {}'.format(p+1,Uf.shape)) |
| | self.feature_mat.append(Uf) |
| | |
| | return |
| |
|
    def update_DualGPM (self, mat_list):
        '''
        Update the per-layer DualGPM feature bases from the activation
        matrices collected for the task just trained.

        Each layer keeps a basis `feature_list[i]` plus a mode
        `project_type[i]`: 'remove' means gradients are projected *away*
        from the basis, 'retain' means they are projected *onto* it.
        Whichever representation is smaller (at most half the feature
        dimension) is kept.

        Args:
            mat_list: one activation matrix (numpy array or tensor
                convertible by np.linalg.svd) per Attention_LoRA layer,
                in module-iteration order.

        Code Reference:
        https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py
        '''
        # Energy threshold grows linearly from lamb to lame over the sessions.
        threshold = (self.lame - self.lamb)*self._cur_task/self.total_sessions + self.lamb
        print ('Threshold: ', threshold)
        if len(self.feature_list) == 0:
            # First task: build each layer's basis directly from the SVD.
            for i in range(len(mat_list)):
                activation = mat_list[i]
                U,S,Vh = np.linalg.svd(activation, full_matrices=False)
                # Rank r = number of directions needed to reach `threshold`
                # of the total squared singular-value energy.
                sval_total = (S**2).sum()
                sval_ratio = (S**2)/sval_total
                r = np.sum(np.cumsum(sval_ratio)<threshold)
                if r < (activation.shape[0]/2):
                    # Small basis: cheaper to store it and *remove* it.
                    self.feature_list.append(U[:,0:max(r,1)])
                    self.project_type.append('remove')
                else:
                    # Large basis: store it and *retain* (project onto) it.
                    self.feature_list.append(U[:,0:max(r,1)])
                    self.project_type.append('retain')
        else:
            for i in range(len(mat_list)):
                if self.project_type[i] == 'remove':
                    activation = mat_list[i]
                    U1,S1,Vh1=np.linalg.svd(activation, full_matrices=False)
                    sval_total = (S1**2).sum()
                    # Part of the activation not already covered by the basis.
                    act_hat = activation - np.dot(np.dot(self.feature_list[i],self.feature_list[i].transpose()),activation)
                    U,S,Vh = np.linalg.svd(act_hat, full_matrices=False)
                    # Energy already captured by the existing basis ...
                    sval_hat = (S**2).sum()
                    sval_ratio = (S**2)/sval_total
                    accumulated_sval = (sval_total-sval_hat)/sval_total
                    # ... then add residual directions until `threshold` is met.
                    r = 0
                    for ii in range (sval_ratio.shape[0]):
                        if accumulated_sval < threshold:
                            accumulated_sval += sval_ratio[ii]
                            r += 1
                        else:
                            break
                    if r == 0:
                        # Existing basis already explains enough energy.
                        print ('Skip Updating DualGPM for layer: {}'.format(i+1))
                        continue
                    # Grow the basis with the r new directions (capped at the
                    # full feature dimension).
                    Ui=np.hstack((self.feature_list[i],U[:,0:r]))
                    if Ui.shape[1] > Ui.shape[0] :
                        self.feature_list[i]=Ui[:,0:Ui.shape[0]]
                    else:
                        self.feature_list[i]=Ui
                else:
                    assert self.project_type[i] == 'retain'
                    activation = mat_list[i]
                    U1,S1,Vh1=np.linalg.svd(activation, full_matrices=False)
                    sval_total = (S1**2).sum()
                    # Part of the activation that lies inside the kept basis.
                    act_hat = np.dot(np.dot(self.feature_list[i],self.feature_list[i].transpose()),activation)
                    U,S,Vh = np.linalg.svd(act_hat, full_matrices=False)
                    sval_hat = (S**2).sum()
                    sval_ratio = (S**2)/sval_total
                    accumulated_sval = sval_hat/sval_total

                    # Mirror image of the 'remove' case: drop r directions
                    # while at least (1 - threshold) of the energy remains.
                    r = 0
                    for ii in range (sval_ratio.shape[0]):
                        if accumulated_sval >= (1-threshold):
                            accumulated_sval -= sval_ratio[ii]
                            r += 1
                        else:
                            break
                    if r == 0:
                        print ('Skip Updating DualGPM for layer: {}'.format(i+1))
                        continue

                    # Shrink the retained basis: remove the span of the r
                    # dropped directions, then re-orthonormalise via SVD.
                    act_feature = self.feature_list[i] - np.dot(np.dot(U[:,0:r],U[:,0:r].transpose()),self.feature_list[i])
                    Ui, Si, Vi = np.linalg.svd(act_feature)
                    self.feature_list[i]=Ui[:,:self.feature_list[i].shape[1]-r]

        print('-'*40)
        print('Gradient Constraints Summary')
        print('-'*40)
        for i in range(len(self.feature_list)):
            if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)):
                # A 'remove' basis that grew past half the dimension is
                # swapped for its orthogonal complement and re-tagged
                # 'retain' — smaller storage, same constraint.
                feature = self.feature_list[i]
                U, S, V = np.linalg.svd(feature)
                new_feature = U[:,feature.shape[1]:]
                self.feature_list[i] = new_feature
                self.project_type[i] = 'retain'
            elif self.project_type[i]=='retain':
                # Invariant: a 'retain' basis never exceeds half the dimension.
                assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2)
            print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i]))
        print('-'*40)
| | |
| | def _set_random(self,args): |
| | ''' |
| | Set random values on various devices to ensure repeatable results |
| | ''' |
| | torch.manual_seed(args['seed']) |
| | torch.cuda.manual_seed(args['seed']) |
| | torch.cuda.manual_seed_all(args['seed']) |
| | torch.backends.cudnn.deterministic = True |
| | torch.backends.cudnn.benchmark = False |