import copy

import numpy as np
import torch
import torch.nn as nn
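
# ---------------------------------------------------------------------------
# Minimum hyperspherical energy (MHE) probes for fine-tuned models.
# Each class snapshots a model's state dict, folds any adapter weights
# (LoRA deltas or OFT rotations) back into the base attention weights, and
# reports the summed pairwise hyperspherical energy of the resulting
# 2-D/4-D weight tensors. The key handling assumes diffusers-style
# attention processor names ('attn', 'processor.to_q_lora.down',
# 'processor.to_q_oft.R', ...).
# ---------------------------------------------------------------------------
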
class MHE_LoRA(nn.Module):
    """MHE of a LoRA-tuned model after folding the low-rank updates back
    into the base attention weights."""

    def __init__(self, model):
        super(MHE_LoRA, self).__init__()
        self.model = self.copy_without_grad(model)

        self.extracted_params = {}
        keys_to_delete = []

        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name and 'weight' in name:
                if 'to_q' in name:
                    lora_down = name.replace('to_q', 'processor.to_q_lora.down')
                    lora_up = name.replace('to_q', 'processor.to_q_lora.up')
                elif 'to_k' in name:
                    lora_down = name.replace('to_k', 'processor.to_k_lora.down')
                    lora_up = name.replace('to_k', 'processor.to_k_lora.up')
                elif 'to_v' in name:
                    lora_down = name.replace('to_v', 'processor.to_v_lora.down')
                    lora_up = name.replace('to_v', 'processor.to_v_lora.up')
                elif 'to_out' in name:
                    lora_down = name.replace('to_out.0', 'processor.to_out_lora.down')
                    lora_up = name.replace('to_out.0', 'processor.to_out_lora.up')
                else:
                    # Attention weight without a LoRA counterpart
                    # (e.g. a norm layer): nothing to merge.
                    continue
                # Fold the low-rank update into the base weight:
                # W <- W + up @ down. Both tensors come from the same state
                # dict, so they already share a device.
                with torch.no_grad():
                    self.extracted_params[name] += (
                        self.extracted_params[lora_up] @ self.extracted_params[lora_down]
                    )
                keys_to_delete.append(lora_up)
                keys_to_delete.append(lora_down)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def copy_without_grad(self, model):
        # Deep copy with every parameter frozen, so MHE evaluation can
        # never perturb the original model.
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    @staticmethod
    def mhe_loss(filt):
        # Flatten conv kernels (n, c, h, w) into rows; 2-D weights pass through.
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = filt.t()
        # Full-space variant: append the negated filters as virtual neurons.
        filt = torch.cat((filt, -filt), dim=1)
        n_filt *= 2

        # Cosine similarity between filters, with a small epsilon for stability.
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat

        # ||w_i - w_j||^2 = 2 - 2 cos(w_i, w_j); adding the identity keeps
        # the diagonal away from zero before the inverse square root.
        eye = torch.eye(n_filt, dtype=inner_pro.dtype, device=inner_pro.device)
        cross_terms = 2.0 - 2.0 * inner_pro + eye
        final = torch.pow(cross_terms, -0.5)
        # Zero the diagonal and lower triangle so each pair is counted once.
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt
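
    # The quantity above is the mean pairwise inverse distance between the
    # unit-normalized filters (and their negations) on the hypersphere:
    #     E = sum_{i<j} 1 / ||w_i/||w_i|| - w_j/||w_j|| ||  /  (N(N-1)/2),
    # so a lower value means the filters are spread more uniformly.
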
    def calculate_mhe(self):
        # Sum the MHE over every 2-D (linear) and 4-D (conv) weight tensor.
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                if weight.dim() in (2, 4):
                    loss = self.mhe_loss(weight)
                    mhe_loss.append(loss.cpu().detach().item())
        return np.array(mhe_loss).sum()


def project(R, eps):
    # Project R onto a Frobenius-norm ball of radius eps. Note that the
    # centre (named I here) is the zero matrix, not the identity.
    I = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    diff = R - I
    norm_diff = torch.norm(diff)
    if norm_diff <= eps:
        return R
    else:
        return I + eps * (diff / norm_diff)


def project_batch(R, eps=1e-5):
    # Batched variant: shrink the radius by 1/sqrt(batch) and project each
    # block of R independently.
    eps = eps / (R.shape[0] ** 0.5)
    I = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    diff = R - I
    norm_diff = torch.norm(diff, dim=(1, 2), keepdim=True)
    mask = norm_diff <= eps
    return torch.where(mask, R, I + eps * (diff / norm_diff))
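
# Since the centre is the zero matrix, project() reduces to a norm clip:
#     project(R, eps) = R                    if ||R||_F <= eps
#                       eps * R / ||R||_F    otherwise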


class MHE_OFT(nn.Module):
    """MHE of an OFT-tuned model after applying the learned block-diagonal
    orthogonal rotations to the base attention weights."""

    def __init__(self, model, eps=6e-5, r=4):
        super(MHE_OFT, self).__init__()
        self.r = r

        self.extracted_params = {}
        keys_to_delete = []

        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name and 'weight' in name:
                if 'to_q' in name:
                    oft_R = name.replace('to_q.weight', 'processor.to_q_oft.R')
                elif 'to_k' in name:
                    oft_R = name.replace('to_k.weight', 'processor.to_k_oft.R')
                elif 'to_v' in name:
                    oft_R = name.replace('to_v.weight', 'processor.to_v_oft.R')
                elif 'to_out' in name:
                    oft_R = name.replace('to_out.0.weight', 'processor.to_out_oft.R')
                else:
                    # Attention weight without an OFT rotation: nothing to merge.
                    continue

                R = self.extracted_params[oft_R]

                # Constrain R to an eps-ball, then map it to an orthogonal
                # matrix via the Cayley transform.
                with torch.no_grad():
                    if R.dim() == 2:
                        self.eps = eps * R.shape[0] * R.shape[0]
                        R.copy_(project(R, eps=self.eps))
                        orth_rotate = self.cayley(R)
                    else:
                        self.eps = eps * R.shape[1] * R.shape[0]
                        R.copy_(project_batch(R, eps=self.eps))
                        orth_rotate = self.cayley_batch(R)

                # Apply the block-diagonal rotation to the snapshotted weight.
                self.extracted_params[name] = self.extracted_params[name] @ self.block_diagonal(orth_rotate)
                keys_to_delete.append(oft_R)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def is_orthogonal(self, R, eps=1e-5):
        # Check R^T R = I up to tolerance eps.
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def block_diagonal(self, R):
        if R.dim() == 2:
            # A single shared block, repeated r times along the diagonal.
            blocks = [R] * self.r
        else:
            # One block per batch entry.
            blocks = [R[i, ...] for i in range(R.shape[0])]
        return torch.block_diag(*blocks)

    def copy_without_grad(self, model):
        # Same frozen deep copy as MHE_LoRA.copy_without_grad.
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    def cayley(self, data):
        r, c = list(data.shape)
        # Skew-symmetric part of the parameter matrix.
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # Cayley transform: Q = (I + A)(I - A)^{-1}.
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        b, r, c = data.shape
        # Batched skew-symmetric part.
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)
        # Batched Cayley transform.
        Q = torch.bmm(I + skew, torch.inverse(I - skew))
        return Q
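
    # For skew-symmetric A (A^T = -A) the Cayley transform is orthogonal
    # (Q^T Q = I), so the merged weights are pure rotations of the
    # originals. is_orthogonal() above can serve as a numeric check,
    # e.g. assert self.is_orthogonal(self.cayley(R)).
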
    @staticmethod
    def mhe_loss(filt):
        # Identical energy to MHE_LoRA.mhe_loss: mean pairwise inverse
        # distance between unit-normalized filters (and their negations).
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = filt.t()
        filt = torch.cat((filt, -filt), dim=1)
        n_filt *= 2

        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat

        eye = torch.eye(n_filt, dtype=inner_pro.dtype, device=inner_pro.device)
        cross_terms = 2.0 - 2.0 * inner_pro + eye
        final = torch.pow(cross_terms, -0.5)
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        # Identical reduction to MHE_LoRA.calculate_mhe.
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                if weight.dim() in (2, 4):
                    loss = self.mhe_loss(weight)
                    mhe_loss.append(loss.cpu().detach().item())
        return np.array(mhe_loss).sum()

    def is_identity_matrix(self, tensor):
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))


class MHE_db:
    """Baseline MHE probe: snapshots a model's state dict as-is, with no
    adapter merging."""

    def __init__(self, model):
        self.extracted_params = {}
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

    @staticmethod
    def mhe_loss(filt):
        # Identical energy to MHE_LoRA.mhe_loss.
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = filt.t()
        filt = torch.cat((filt, -filt), dim=1)
        n_filt *= 2

        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat

        eye = torch.eye(n_filt, dtype=inner_pro.dtype, device=inner_pro.device)
        cross_terms = 2.0 - 2.0 * inner_pro + eye
        final = torch.pow(cross_terms, -0.5)
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        # Identical reduction to MHE_LoRA.calculate_mhe.
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                if weight.dim() in (2, 4):
                    loss = self.mhe_loss(weight)
                    mhe_loss.append(loss.cpu().detach().item())
        return np.array(mhe_loss).sum()
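

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original pipeline).
# MHE_db accepts any nn.Module, so a toy MLP suffices here; MHE_LoRA and
# MHE_OFT additionally expect diffusers-style attention/adapter keys in the
# state dict, which a toy model does not have.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    print("summed MHE over 2-D/4-D weights:", MHE_db(toy).calculate_mhe())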