| | """ |
| | Code Reference: |
| | https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py |
| | """ |
| |
|
| | import os |
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.distributed as dist |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| |
|
| | from tqdm import tqdm |
| | from .backbone.transformer import MultiHeadAttention_MaskedLoRA1 |
| |
|
| | GREEDY=True |
| | APPROX_FEAT=True |
| | |
| | Epsilon = 0.5 |
| |
|
| | def _set_random(seed): |
| | ''' |
| | Set random values on various devices to ensure repeatable results |
| | ''' |
| |
|
| | seed = int(seed) |
| |
|
| | torch.manual_seed(seed) |
| | torch.cuda.manual_seed(seed) |
| | torch.cuda.manual_seed_all(seed) |
| | np.random.seed(seed) |
| | torch.backends.cudnn.deterministic = True |
| | torch.backends.cudnn.benchmark = False |
| |
|
| | def select_probe_greedy_span_unified_normalized(cur_matrixs_list, energy_threshold=0.95, top_r=None): |
| | """ |
| | Greedy span selection across multiple attention blocks, with per-block normalization. |
| | Dynamically select samples that together span a certain percentage of gradient space. |
| | |
| | Args: |
| | cur_matrixs_list (List[torch.Tensor]): list of (Num, 768, 768) tensors for each block. |
| | energy_threshold (float): fraction of gradient space to cover. |
| | top_r (int, optional): number of top singular vectors to use. |
| | |
| | Returns: |
| | selected_indices (torch.Tensor) |
| | """ |
| | N = cur_matrixs_list[0].shape[0] |
| | device = cur_matrixs_list[0].device |
| |
|
| | |
| | normalized_cur_matrices = [] |
| | for matrices in cur_matrixs_list: |
| | frob_norms = torch.norm(matrices.view(N, -1), dim=-1, p=2).view(N, 1, 1) |
| | matrices_normalized = matrices / (frob_norms + 1e-8) |
| | normalized_cur_matrices.append(matrices_normalized) |
| |
|
| | |
| | C_global = sum([matrices.sum(dim=0) for matrices in normalized_cur_matrices]) |
| |
|
| | |
| | U, _, _ = torch.linalg.svd(C_global) |
| | if top_r is not None: |
| | U = U[:, :top_r] |
| |
|
| | |
| | projected_vectors = [] |
| | for i in range(N): |
| | x_cov_sum = sum([matrices[i] for matrices in normalized_cur_matrices]) |
| | proj = U.T @ x_cov_sum @ U |
| | projected_vectors.append(proj.flatten()) |
| | projected_vectors = torch.stack(projected_vectors, dim=0) |
| |
|
| | |
| | selected_indices = [] |
| | remaining_indices = set(range(N)) |
| | selected_vectors = [] |
| |
|
| |
|
| |
|
| | |
| | current_energy = 0.0 |
| |
|
| | while current_energy / total_energy < energy_threshold: |
| | best_idx = -1 |
| | best_gain = -float('inf') |
| |
|
| | for idx in remaining_indices: |
| | vec = projected_vectors[idx] |
| | if selected_vectors: |
| | |
| | selected_mat = torch.stack(selected_vectors, dim=0) |
| |
|
| | Q, _ = torch.linalg.qr(selected_mat.T, mode='reduced') |
| | projection = (Q @ (Q.T @ vec)) |
| |
|
| | |
| | vec_residual = vec - projection |
| | else: |
| | vec_residual = vec |
| |
|
| | gain = vec_residual.norm().item() |
| |
|
| | if gain > best_gain: |
| | best_gain = gain |
| | best_idx = idx |
| |
|
| | selected_indices.append(best_idx) |
| | remaining_indices.remove(best_idx) |
| | selected_vectors.append(projected_vectors[best_idx] / (projected_vectors[best_idx].norm() + 1e-8)) |
| | current_energy += best_gain ** 2 |
| |
|
| | selected_indices = torch.tensor(selected_indices) |
| |
|
| | print(f"Selected {len(selected_indices)} samples covering {current_energy / total_energy * 100:.2f}% of gradient space.") |
| |
|
| | |
| | plt.figure(figsize=(8, 6)) |
| | plt.plot(torch.arange(len(selected_indices))+1, (torch.tensor([current_energy / total_energy for _ in selected_indices])*100).numpy(), label='Cumulative Span Coverage') |
| | plt.xlabel('Number of Samples Selected') |
| | plt.ylabel('Coverage (%)') |
| | plt.title('Greedy Span Selection Coverage') |
| | plt.grid(True) |
| | plt.legend() |
| | plt.savefig('greedy_span_coverage.png', dpi=300) |
| |
|
| | return selected_indices |
| |
|
| | def select_probe_greedy_span_unified_normalized_high_precision( |
| | cur_matrixs_list, |
| | energy_threshold=0.95, |
| | top_r=None |
| | ): |
| | """ |
| | Greedy span selection across multiple attention blocks, with per-block normalization. |
| | Dynamically select samples that together span a certain percentage of gradient space. |
| | |
| | Args: |
| | cur_matrixs_list (List[torch.Tensor]): list of (Num, 768, 768) tensors for each block. |
| | energy_threshold (float): fraction of gradient space to cover. |
| | top_r (int, optional): number of top singular vectors to use. |
| | feature_mode (str): "trace" (default) or "flatten". How to extract projected features. |
| | |
| | Returns: |
| | selected_indices (torch.Tensor) |
| | """ |
| |
|
| | N = cur_matrixs_list[0].shape[0] |
| |
|
| | |
| | normalized_cur_matrices = [] |
| | for matrices in cur_matrixs_list: |
| | frob_norms = torch.norm(matrices.view(N, -1), dim=-1, p=2).view(N, 1, 1) |
| | matrices_normalized = matrices / (frob_norms + 1e-8) |
| | normalized_cur_matrices.append(matrices_normalized) |
| | |
| | |
| | C_global = sum([matrices.sum(dim=0) for matrices in normalized_cur_matrices]) |
| |
|
| | |
| | U, _, _ = torch.linalg.svd(C_global) |
| | if top_r is not None: |
| | U = U[:, :top_r] |
| |
|
| | |
| | projected_vectors = [] |
| | for i in range(N): |
| | x_cov_sum = sum([matrices[i] for matrices in normalized_cur_matrices]) |
| | proj = U.T @ x_cov_sum @ U |
| | proj_feat = proj.flatten() |
| | projected_vectors.append(proj_feat) |
| | projected_vectors = torch.stack(projected_vectors, dim=0) |
| |
|
| | |
| | selected_indices = [] |
| | remaining_indices = set(range(N)) |
| |
|
| | residual_vectors = projected_vectors.clone() |
| | selected_vectors = [] |
| |
|
| | total_energy = projected_vectors.norm(dim=-1).pow(2).sum().item() |
| | current_energy = 0.0 |
| |
|
| | |
| | while current_energy / total_energy < energy_threshold: |
| | assert remaining_indices |
| | best_idx = -1 |
| | best_gain = -float('inf') |
| |
|
| | for idx in remaining_indices: |
| | vec = residual_vectors[idx] |
| | if selected_vectors: |
| | |
| | selected_mat = torch.stack(selected_vectors, dim=0) |
| | Q, _ = torch.linalg.qr(selected_mat.T, mode='reduced') |
| | vec_residual = vec - (Q @ (Q.T @ vec)) |
| | else: |
| | vec_residual = vec |
| |
|
| | |
| | gain = vec_residual.norm().item() ** 2 |
| | |
| | if gain > best_gain: |
| | best_gain = gain |
| | best_idx = idx |
| | if not GREEDY: |
| | break |
| |
|
| | |
| | selected_indices.append(best_idx) |
| | selected_vec = residual_vectors[best_idx] / (residual_vectors[best_idx].norm() + 1e-8) |
| | current_energy += projected_vectors[best_idx].norm().item() ** 2 |
| | print(current_energy, '/', total_energy) |
| |
|
| | |
| | selected_vectors.append(selected_vec) |
| | projection = (residual_vectors @ selected_vec.unsqueeze(-1)).squeeze(-1) |
| | residual_vectors = residual_vectors - projection.unsqueeze(-1) * selected_vec.unsqueeze(0) |
| |
|
| | remaining_indices.discard(best_idx) |
| |
|
| | |
| | selected_indices = torch.tensor(selected_indices) |
| | print(f"Selected {len(selected_indices)} samples covering {current_energy / total_energy * 100:.2f}% of gradient space.") |
| |
|
| | return selected_indices |
| |
|
| | |
| |
|
| | class TopK: |
| |
|
| | ''' |
| | A class to maintain a collection of the top K items based on a specified attribute. |
| | |
| | This class allows for the dynamic addition of items, each represented as a dictionary, |
| | where each dictionary must have a key 'proj_norm' that represents the value used |
| | to determine the ranking. The class keeps track of the top K items with the highest |
| | 'proj_norm' values. |
| | ''' |
| |
|
| | def __init__(self, k): |
| | self.k = k |
| | self.top_k_list = [] |
| |
|
| | def add(self, dict): |
| | if len(self.top_k_list) < self.k: |
| | self.top_k_list.append(dict) |
| | elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: |
| | self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm'])) |
| | self.top_k_list.append(dict) |
| | elif dict['proj_norm'] == min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm'] and \ |
| | dict['proj_norm'] == max(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: |
| | self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['task_id'])) |
| | self.top_k_list.append(dict) |
| |
|
| | def get_top_k(self): |
| | return self.top_k_list |
| |
|
| | class SiNet(nn.Module): |
| | def __init__(self, backbone, **kwargs): |
| | super().__init__() |
| |
|
| | self._cur_task_id = -1 |
| | self.backbone = backbone |
| | self.init_cls_num = kwargs["init_cls_num"] |
| | self.inc_cls_num = kwargs["inc_cls_num"] |
| |
|
| | _set_random(os.environ["PYTHONHASHSEED"]) |
| | self.classifier_pool = nn.ModuleList([ |
| | nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + |
| | [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)]) |
| |
|
| | for name, module in self.backbone.named_modules(): |
| | if 'transformer' in name and 'blocks' not in name: |
| | self.transformer_module = module |
| |
|
| | def update_fc(self): |
| | self._cur_task_id += 1 |
| |
|
| | def forward(self, x, expert_id, inference = False): |
| | logits = [] |
| | features = self.backbone(x, expert_id = expert_id) |
| |
|
| | if inference: |
| |
|
| | |
| | for i, prompts in enumerate(self.classifier_pool[:self._cur_task_id + 1]): |
| | |
| | logits.append(prompts(features)) |
| |
|
| | logits = torch.cat(logits, dim=1) |
| |
|
| | return logits |
| |
|
| | else: |
| | logits.append(self.classifier_pool[self._cur_task_id](features)) |
| | return torch.cat(logits, dim=1) |
| |
|
| | def update_input_matrix(self, x): |
| | self.backbone(x, expert_id = -1, get_input_matrix = True) |
| |
|
| | class MInfLoRA(nn.Module): |
| |
|
| | def __init__(self, backbone, device, **kwargs): |
| | super().__init__() |
| |
|
| | self.device = device |
| | self.init_cls_num = kwargs["init_cls_num"] |
| | self.inc_cls_num = kwargs["inc_cls_num"] |
| | self.task_num = kwargs["task_num"] |
| | self.lame = kwargs["lame"] |
| | self.lamb = kwargs["lamb"] |
| | self.embd_dim = kwargs["embd_dim"] |
| | self.eval_mat = False |
| |
|
| | self._known_classes = 0 |
| | self.feature_list = [] |
| | self.project_type = [] |
| |
|
| | self.distributed = torch.distributed.is_initialized() |
| | assert not self.distributed, 'current not support' |
| | self.local_rank = torch.distributed.get_rank() if self.distributed else 0 |
| | |
| | self._network = SiNet(backbone, **kwargs) |
| |
|
| | self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_MaskedLoRA1)] |
| |
|
| | |
| | self.feature_list_each_tasks = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] |
| | self.final_decision = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] |
| | self.before_mat = [[0 for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] |
| |
|
| | self.experts_distributions = [] |
| |
|
| | |
| | self._use_class_alignment = kwargs['use_ca'] |
| | self._class_means = None |
| | self._class_covs = None |
| | self._dataset = kwargs['dataset'] |
| | if self._dataset == 'cifar': |
| | self.logit_norm = None |
| | else: |
| | self.logit_norm = 0.1 |
| | |
| | self.lll = [] |
| |
|
| | self._network.to(self.device) |
| | |
| | def observe(self, data): |
| |
|
| | with torch.no_grad(): |
| | self._network(self.probe_selection, expert_id = -1) |
| |
|
| | x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes |
| |
|
| | logits = self._network(x, expert_id = self._network._cur_task_id) |
| | loss = F.cross_entropy(logits, y) |
| |
|
| | preds = logits.max(1)[1] |
| | acc = preds.eq(y).sum().item() / y.shape[0] |
| |
|
| | return preds, acc, loss |
| | |
| | def inference(self, data, **kwargs): |
| |
|
| | task_id = kwargs['task_id'] if 'task_id' in kwargs else -1 |
| | x, y = data['image'].to(self.device, non_blocking=True), data['label'].to(self.device, non_blocking=True) |
| |
|
| | logits = self._network(x, expert_id = task_id, inference = True) |
| | preds = logits.max(1)[1] |
| | acc = preds.eq(y).sum().item() / y.shape[0] |
| |
|
| | return preds, acc |
| |
|
| | @torch.no_grad() |
| | def before_task(self, task_idx, buffer, train_loader, test_loaders): |
| |
|
| | print('Greedy', GREEDY) |
| | print('Approx Feature', APPROX_FEAT) |
| | |
| | self._network.update_fc() |
| |
|
| | |
| | _set_random(os.environ["PYTHONHASHSEED"]) |
| | for module in self.attention_modules: |
| | |
| | module.init_param() |
| |
|
| | self._network = self._network.to(self.device) |
| | self._update_input_matrix(train_loader) |
| | |
| | ''' |
| | probe_indices_svd = select_probe_svd_energy_matrix_unified_normalized( |
| | [m.cur_matrixs for m in self.attention_modules] |
| | ,probe_size=512 |
| | ) |
| | ''' |
| | ''' |
| | self.probe_indices_svd = select_probe_svd_energy_matrix_unified_normalized( |
| | [m.cur_matrixs for m in self.attention_modules] |
| | ,energy_threshold=0.15, top_r=64 |
| | ) |
| | ''' |
| | self.probe_indices_svd = select_probe_greedy_span_unified_normalized_high_precision( |
| | [m.cur_matrixs for m in self.attention_modules] |
| | ,energy_threshold=0.01, top_r=128 |
| | |
| | ) |
| | |
| | self.probe_selection = self.dataset[self.probe_indices_svd].to(self.device) |
| |
|
| | if task_idx == 0: |
| | for i, module in enumerate(self.attention_modules): |
| |
|
| | |
| | U, _, _ = torch.linalg.svd(module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512, full_matrices=False) |
| | |
| | module.lora_A_k_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) |
| | module.lora_A_v_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) |
| |
|
| | else: |
| | for i, module in enumerate(self.attention_modules): |
| |
|
| | feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) |
| | module.feature_mat = feature_mat.clone().to(self.device) |
| |
|
| | activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 |
| | activation = activation - feature_mat @ activation |
| |
|
| | U, _, _ = torch.linalg.svd(activation, full_matrices = False) |
| |
|
| | module.lora_A_k_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/(3 ** 0.5)) |
| | module.lora_A_v_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/(3 ** 0.5)) |
| | |
| | ''' |
| | for i, module in enumerate(self.attention_modules): |
| | |
| | topk = TopK(1) |
| | |
| | mat = module.cur_matrix.cpu().numpy() |
| | mat_norm = np.linalg.norm(mat) |
| | |
| | for task_id in range(task_idx): |
| | |
| | if not np.array_equal(self.feature_list_each_tasks[task_id][i], np.zeros((1))): |
| | |
| | proj_norm = np.linalg.norm(self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T @ mat) |
| | print(f'{task_idx} to {task_id} in layer {i} : {proj_norm}') |
| | |
| | if proj_norm > Epsilon * mat_norm: |
| | topk.add({'proj_norm':proj_norm, 'task_id': task_id}) |
| | |
| | self.final_decision[task_idx][i] = [dic['task_id'] for dic in topk.get_top_k()] |
| | print(f'Layer {i} of {task_idx} consider {self.final_decision[task_idx][i]} as trust region') |
| | |
| | self.prev_matrix = [] |
| | if task_idx == 0: |
| | for i, module in enumerate(self.attention_modules): |
| | |
| | U, _, _ = torch.linalg.svd(module.cur_matrix) |
| | U = torch.Tensor(U).to(self.device) |
| | |
| | self.prev_matrix.append(U[:,:module.lora_rank].T.cpu()) |
| | |
| | module.lora_A_k_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) |
| | module.lora_A_v_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) |
| | #module.reset_input_matrix() |
| | else: |
| | for i, module in enumerate(self.attention_modules): |
| | assert self.project_type[i] == 'remove' or self.project_type[i] == 'retain' |
| | |
| | cur_matrix = module.cur_matrix.to(self.device) |
| | |
| | |
| | # TRGP |
| | tr = self.final_decision[task_idx][i][0] |
| | tr = task_idx - 1 |
| | |
| | #feature_mat = torch.Tensor(self.feature_list_each_tasks[tr][i] @ self.feature_list_each_tasks[tr][i].T).to(self.device) |
| | |
| | feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T).to(self.device) |
| | intersect = feature_mat @ cur_matrix |
| | |
| | target_shape = 768 |
| | |
| | U, _, _ = np.linalg.svd(intersect.cpu().numpy(), full_matrices = False) |
| | U = torch.Tensor(U).to(self.device) |
| | module.space_k[tr] = U[:, :target_shape].T/math.sqrt(3) |
| | module.space_v[tr] = U[:, :target_shape].T/math.sqrt(3) |
| | |
| | # InfLoRA |
| | feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T).to(self.device) |
| | |
| | if self.project_type[i] == 'remove': |
| | cur_matrix = cur_matrix - feature_mat @ cur_matrix |
| | else: |
| | cur_matrix = feature_mat @ cur_matrix |
| | |
| | module.feature_mat = feature_mat.clone() |
| | |
| | U, _, _ = np.linalg.svd(cur_matrix.cpu().numpy(), full_matrices = False) |
| | U = U[:, :module.lora_rank] |
| | |
| | alphas = torch.linalg.lstsq(torch.Tensor(module.lora_A_k_list[task_idx-1].weight.data).T.cpu(), torch.Tensor(U) / math.sqrt(3)) |
| | if alphas.residuals.numel() != 0: |
| | print(f'Task {task_idx}, Layer {i}, {alphas.residuals}') |
| | assert 0 |
| | |
| | U = torch.Tensor(U).to(self.device) |
| | |
| | module.lora_A_k_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) # here should have /sqrt3 |
| | module.lora_A_v_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) |
| | ''' |
| |
|
| | for name, param in self._network.named_parameters(): |
| | param.requires_grad_(False) |
| | if f"classifier_pool.{task_idx}" in name or \ |
| | f"lora_B_k_list.{task_idx}" in name or \ |
| | f"lora_B_v_list.{task_idx}" in name: |
| | param.requires_grad_(True) |
| |
|
| | for name, param in self._network.named_parameters(): |
| | if param.requires_grad: |
| | print(name) |
| |
|
| | def after_task(self, task_idx, buffer, train_loader, test_loaders): |
| | ''' |
| | Called after each task before final testing, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task |
| | ''' |
| |
|
| | self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num |
| |
|
| | self._update_feature(task_idx, train_loader, test_loaders) |
| |
|
| | @torch.no_grad() |
| | def _update_feature(self, task_idx, train_loader, test_loaders): |
| | ''' |
| | Update feature lists and the corresponding type |
| | ''' |
| |
|
| | self._update_input_matrix(train_loader) |
| |
|
| | if self.local_rank == 0: |
| |
|
| | threshold = (self.lame - self.lamb)*task_idx/self.task_num + self.lamb |
| |
|
| | if task_idx == 0: |
| | for i, module in enumerate(self.attention_modules): |
| | |
| | activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 |
| | U, S, _ = torch.linalg.svd(activation, full_matrices=False) |
| | true_U = U[:, :module.lora_rank] |
| |
|
| | |
| | alphas = torch.linalg.lstsq(module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3), true_U) |
| | approx2_U = module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3) @ alphas.solution |
| |
|
| | if APPROX_FEAT: |
| | self.feature_list.append(approx2_U) |
| | self.feature_list_each_tasks[task_idx][i] = approx2_U |
| | else: |
| | self.feature_list.append(true_U) |
| | self.feature_list_each_tasks[task_idx][i] = true_U |
| |
|
| | self.project_type.append('remove') |
| |
|
| | else: |
| | for i, module in enumerate(self.attention_modules): |
| |
|
| | activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 |
| | act_hat = activation - torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation |
| |
|
| | U, _, _ = torch.linalg.svd(act_hat, full_matrices = False) |
| | true_U = U[:, :module.lora_rank] |
| |
|
| | alphas = torch.linalg.lstsq(module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3), true_U) |
| | approx2_U = module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3) @ alphas.solution |
| |
|
| | if APPROX_FEAT: |
| | self.feature_list[i] = torch.cat([self.feature_list[i], approx2_U], dim=1) |
| | self.feature_list_each_tasks[task_idx][i] = approx2_U |
| | else: |
| | self.feature_list[i] = torch.cat([self.feature_list[i], true_U], dim=1) |
| | self.feature_list_each_tasks[task_idx][i] = true_U |
| |
|
| | print('-'*40) |
| | print(f'Threshold: {threshold}') |
| | print('-'*40) |
| | for i in range(len(self.feature_list)): |
| | ''' |
| | if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): |
| | feature = self.feature_list[i] |
| | U, S, V = np.linalg.svd(feature) |
| | new_feature = U[:,feature.shape[1]:] |
| | self.feature_list[i] = new_feature |
| | self.project_type[i] = 'retain' |
| | elif self.project_type[i]=='retain': |
| | assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) |
| | ''' |
| | print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) |
| | print('-'*40) |
| |
|
| | @torch.no_grad() |
| | def _update_input_matrix(self, train_loader): |
| |
|
| | for module in self.attention_modules: |
| | module.reset_input_matrix() |
| |
|
| | _set_random(os.environ["PYTHONHASHSEED"]) |
| | self.dataset = [] |
| | for batch in tqdm(train_loader, desc="Forwarding to get input matrix", disable=(self.local_rank != 0)): |
| | self._network.update_input_matrix(batch['image'].to(self.device)) |
| | self.dataset.append(batch['image']) |
| |
|
| | self.dataset = torch.cat(self.dataset, dim=0) |
| |
|
| | for module in self.attention_modules: |
| | module.cur_matrixs = torch.cat(module.cur_matrixs, dim=0) |
| | module.cur_matrixs = torch.bmm( |
| | module.cur_matrixs.permute(0, 2, 1), |
| | module.cur_matrixs |
| | ).cpu() |
| |
|
| | def get_parameters(self, config): |
| | return self._network.parameters() |