import torch
import torch.nn as nn
import copy
import numpy as np


class MHE_LoRA(nn.Module):
    """Merge LoRA up/down projections into the base attention weights and
    report the minimum hyperspherical energy (MHE) of the merged weights."""

    def __init__(self, model):
        super(MHE_LoRA, self).__init__()
        self.model = self.copy_without_grad(model)
        self.extracted_params = {}
        keys_to_delete = []
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name and 'weight' in name:
                if 'to_q' in name:
                    lora_down = name.replace('to_q', 'processor.to_q_lora.down')
                    lora_up = name.replace('to_q', 'processor.to_q_lora.up')
                elif 'to_k' in name:
                    lora_down = name.replace('to_k', 'processor.to_k_lora.down')
                    lora_up = name.replace('to_k', 'processor.to_k_lora.up')
                elif 'to_v' in name:
                    lora_down = name.replace('to_v', 'processor.to_v_lora.down')
                    lora_up = name.replace('to_v', 'processor.to_v_lora.up')
                elif 'to_out' in name:
                    lora_down = name.replace('to_out.0', 'processor.to_out_lora.down')
                    lora_up = name.replace('to_out.0', 'processor.to_out_lora.up')
                else:
                    # Not a q/k/v/out projection: skip it rather than fall
                    # through with stale lora_up/lora_down from an earlier
                    # iteration.
                    continue

                # Fold the low-rank update into the base weight: W <- W + BA.
                with torch.no_grad():
                    device = self.extracted_params[name].device
                    self.extracted_params[name] += (
                        self.extracted_params[lora_up].to(device)
                        @ self.extracted_params[lora_down].to(device))
                keys_to_delete.append(lora_up)
                keys_to_delete.append(lora_down)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def copy_without_grad(self, model):
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    @staticmethod
    def mhe_loss(filt):
        # Flatten conv filters so linear (2-D) and conv (4-D) weights share
        # one code path: one row per output filter.
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = torch.transpose(filt, 0, 1)
        # Half-space MHE: score the filters together with their negations.
        filt = torch.cat((filt, filt * -1), dim=1)
        n_filt *= 2
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Pairwise squared chord distances; the added identity keeps the
        # diagonal finite, and the tril subtraction below removes it.
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device, dtype=filt.dtype))
        final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
        final -= torch.tril(final)  # keep each unordered pair exactly once
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                # Only linear (2-D) and conv (4-D) weights contribute.
                if weight.dim() in (2, 4):
                    mhe_loss.append(self.mhe_loss(weight).item())
        return np.array(mhe_loss).sum()


def project(R, eps):
    """Project R onto the Frobenius-norm ball of radius eps around zero
    (the COFT constraint; R parametrizes a rotation via the Cayley map,
    so R = 0 corresponds to the identity rotation)."""
    zero = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    diff = R - zero
    norm_diff = torch.norm(diff)
    if norm_diff <= eps:
        return R
    return zero + eps * (diff / norm_diff)


def project_batch(R, eps=1e-5):
    """Batched `project` for a stack of block matrices."""
    # Scale the radius down for each of the smaller block matrices.
    eps = eps / (R.shape[0] ** 0.5)
    zero = (torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype)
            .unsqueeze(0).expand_as(R))
    diff = R - zero
    norm_diff = torch.norm(diff, dim=(1, 2), keepdim=True)
    mask = norm_diff <= eps
    return torch.where(mask, R, zero + eps * (diff / norm_diff))
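# ---------------------------------------------------------------------------
# Minimal sanity check for `mhe_loss` (an illustrative helper, not part of the
# original pipeline). MHE is lower when filter directions are spread uniformly
# over the hypersphere, so an orthonormal matrix should score well below a
# rank-1 matrix whose rows are all nearly parallel.
def _mhe_sanity_check():
    torch.manual_seed(0)
    ortho = torch.linalg.qr(torch.randn(8, 8)).Q       # well-spread rows
    collapsed = torch.randn(8, 1) @ torch.randn(1, 8)  # almost parallel rows
    assert MHE_LoRA.mhe_loss(ortho) < MHE_LoRA.mhe_loss(collapsed)
# ---------------------------------------------------------------------------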
class MHE_OFT(nn.Module):
    """Fold OFT rotation parameters into the base attention weights and
    report the MHE of the resulting weights."""

    def __init__(self, model, eps=6e-5, r=4):
        super(MHE_OFT, self).__init__()
        self.r = r
        self.extracted_params = {}
        keys_to_delete = []
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name and 'weight' in name:
                if 'to_q' in name:
                    oft_R = name.replace('to_q.weight', 'processor.to_q_oft.R')
                elif 'to_k' in name:
                    oft_R = name.replace('to_k.weight', 'processor.to_k_oft.R')
                elif 'to_v' in name:
                    oft_R = name.replace('to_v.weight', 'processor.to_v_oft.R')
                elif 'to_out' in name:
                    oft_R = name.replace('to_out.0.weight', 'processor.to_out_oft.R')
                else:
                    # Not a q/k/v/out projection: skip it rather than reuse a
                    # stale oft_R from an earlier iteration.
                    continue

                device = self.extracted_params[name].device
                R = self.extracted_params[oft_R].to(device)
                with torch.no_grad():
                    if len(R.shape) == 2:
                        # A single block shared across the r positions.
                        self.eps = eps * R.shape[0] * R.shape[0]
                        R = project(R, eps=self.eps)
                        orth_rotate = self.cayley(R)
                    else:
                        # One block per slice of the batch dimension.
                        self.eps = eps * R.shape[1] * R.shape[0]
                        R = project_batch(R, eps=self.eps)
                        orth_rotate = self.cayley_batch(R)
                    # Fold the block-diagonal rotation into the weight.
                    self.extracted_params[name] = (
                        self.extracted_params[name]
                        @ self.block_diagonal(orth_rotate))
                keys_to_delete.append(oft_R)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def is_orthogonal(self, R, eps=1e-5):
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype,
                                             device=R.device))
        return torch.all(diff < eps)

    def block_diagonal(self, R):
        if len(R.shape) == 2:
            # A single block repeated r times.
            blocks = [R] * self.r
        else:
            # One block per slice along the batch dimension.
            blocks = [R[i, ...] for i in range(R.shape[0])]
        # Assemble the block-diagonal rotation matrix.
        return torch.block_diag(*blocks)

    def copy_without_grad(self, model):
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    def cayley(self, data):
        r, c = list(data.shape)
        # Take the skew-symmetric part of the parameter ...
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # ... then apply the Cayley transform, which is always orthogonal.
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        b, r, c = data.shape
        # Take the skew-symmetric part of each block ...
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)
        # ... then apply the Cayley transform block-wise.
        Q = torch.bmm(I + skew, torch.inverse(I - skew))
        return Q

    @staticmethod
    def mhe_loss(filt):
        # Flatten conv filters so linear (2-D) and conv (4-D) weights share
        # one code path: one row per output filter.
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = torch.transpose(filt, 0, 1)
        # Half-space MHE: score the filters together with their negations.
        filt = torch.cat((filt, filt * -1), dim=1)
        n_filt *= 2
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Pairwise squared chord distances; the added identity keeps the
        # diagonal finite, and the tril subtraction below removes it.
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device, dtype=filt.dtype))
        final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
        final -= torch.tril(final)  # keep each unordered pair exactly once
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                # Only linear (2-D) and conv (4-D) weights contribute.
                if weight.dim() in (2, 4):
                    mhe_loss.append(self.mhe_loss(weight).item())
        return np.array(mhe_loss).sum()

    def is_identity_matrix(self, tensor):
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
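# ---------------------------------------------------------------------------
# Illustrative check (a hypothetical helper, not in the original code): the
# Cayley transform of a skew-symmetric matrix is orthogonal, and a
# block-diagonal stack of orthogonal blocks is itself orthogonal, so folding
# it into W only rotates the weight rows. The dummy nn.Linear has no 'attn'
# weights, so the constructor performs no merging.
def _cayley_orthogonality_check():
    torch.manual_seed(0)
    scorer = MHE_OFT(nn.Linear(4, 4))
    Q = scorer.cayley(torch.randn(6, 6, dtype=torch.float64))
    assert scorer.is_orthogonal(Q)
    B = scorer.block_diagonal(Q)  # (24, 24): Q repeated r=4 times
    assert scorer.is_orthogonal(B)
# ---------------------------------------------------------------------------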
class MHE_db:
    """MHE of a model whose weights are used as-is (no adapter merging),
    e.g. a fully fine-tuned DreamBooth model."""

    def __init__(self, model):
        self.extracted_params = {}
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

    @staticmethod
    def mhe_loss(filt):
        # Flatten conv filters so linear (2-D) and conv (4-D) weights share
        # one code path: one row per output filter.
        if filt.dim() == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        filt = torch.transpose(filt, 0, 1)
        # Half-space MHE: score the filters together with their negations.
        filt = torch.cat((filt, filt * -1), dim=1)
        n_filt *= 2
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Pairwise squared chord distances; the added identity keeps the
        # diagonal finite, and the tril subtraction below removes it.
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device, dtype=filt.dtype))
        final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
        final -= torch.tril(final)  # keep each unordered pair exactly once
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                # Only linear (2-D) and conv (4-D) weights contribute.
                if weight.dim() in (2, 4):
                    mhe_loss.append(self.mhe_loss(weight).item())
        return np.array(mhe_loss).sum()
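# ---------------------------------------------------------------------------
# Usage sketch (assumptions: `unet_lora`, `unet_oft`, and `unet_db` are
# diffusers UNets fine-tuned with LoRA, OFT, and full fine-tuning
# respectively, following the attention/processor naming used above):
#
#   print("LoRA MHE:", MHE_LoRA(unet_lora).calculate_mhe())
#   print("OFT  MHE:", MHE_OFT(unet_oft, eps=6e-5, r=4).calculate_mhe())
#   print("DB   MHE:", MHE_db(unet_db).calculate_mhe())
#
# Lower totals mean the attention weight rows stay more uniformly spread on
# the hypersphere after fine-tuning.

if __name__ == "__main__":
    # Run the illustrative self-checks defined above.
    _mhe_sanity_check()
    _cayley_orthogonality_check()
    print("sanity checks passed")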