# Upload metadata (not code): nvan13 — "Add files using upload-large-folder tool" — commit f4dcc30 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import numpy as np
class MHE_LoRA(nn.Module):
    """Minimum Hyperspherical Energy (MHE) of a LoRA fine-tuned model.

    On construction, every LoRA pair found in the model's state dict is
    merged into its base attention projection weight (``W += up @ down``)
    and the adapter tensors are dropped, so ``calculate_mhe`` measures the
    energy of the *effective* weights.
    """

    def __init__(self, model):
        """Snapshot *model*'s parameters and fold LoRA adapters into them.

        Args:
            model: module whose attention processors expose
                ``processor.to_{q,k,v,out}_lora.{up,down}`` parameters
                (diffusers-style naming — assumed from the key patterns
                below; TODO confirm against the caller).
        """
        super(MHE_LoRA, self).__init__()
        # Keep a frozen copy so nothing here can backprop into the original.
        self.model = self.copy_without_grad(model)
        # Detached clones: the in-place merge below must not mutate the live model.
        self.extracted_params = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        keys_to_delete = []
        for name in self.extracted_params:
            # Only base attention projection weights (not the processor's own params).
            if 'attn' not in name or 'processor' in name or 'weight' not in name:
                continue
            if 'to_q' in name:
                lora_down = name.replace('to_q', 'processor.to_q_lora.down')
                lora_up = name.replace('to_q', 'processor.to_q_lora.up')
            elif 'to_k' in name:
                lora_down = name.replace('to_k', 'processor.to_k_lora.down')
                lora_up = name.replace('to_k', 'processor.to_k_lora.up')
            elif 'to_v' in name:
                lora_down = name.replace('to_v', 'processor.to_v_lora.down')
                lora_up = name.replace('to_v', 'processor.to_v_lora.up')
            elif 'to_out' in name:
                lora_down = name.replace('to_out.0', 'processor.to_out_lora.down')
                lora_up = name.replace('to_out.0', 'processor.to_out_lora.up')
            else:
                # BUG FIX: the original `pass` fell through and reused the
                # previous iteration's lora_up/lora_down (or raised NameError
                # on the first hit). Skip weights with no matching adapter.
                continue
            with torch.no_grad():
                base = self.extracted_params[name]
                # Merge W_eff = W + up @ down on W's own device (the original
                # hard-coded .cuda(), which broke CPU-only runs).
                up = self.extracted_params[lora_up].to(base.device)
                down = self.extracted_params[lora_down].to(base.device)
                base += up @ down
            keys_to_delete.append(lora_up)
            keys_to_delete.append(lora_down)
        for key in keys_to_delete:
            del self.extracted_params[key]

    def copy_without_grad(self, model):
        """Return a deep copy of *model* with all parameters frozen/detached."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy (s = 1 power) of one weight tensor.

        Args:
            filt: 2-D linear weight ``(out, in)`` or 4-D conv weight
                ``(out, in, kH, kW)``; conv filters are flattened per
                output channel.

        Returns:
            Scalar tensor: mean ``1 / ||w_i - w_j||`` over all distinct
            normalized filter pairs, with each filter's negation included
            (the "half-space" variant).
        """
        if filt.dim() != 2:
            # Conv weights: one row per output filter.
            filt = filt.reshape(filt.shape[0], -1)
        # Filters as columns; append negated copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, -filt), dim=1)
        n_filt = filt.shape[1]
        # Cosine-similarity matrix; 1e-4 guards against zero-norm filters.
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Squared chord distance 2 - 2cos; adding I keeps the diagonal away
        # from zero before the inverse square root. Device-aware eye replaces
        # the original hard-coded .cuda().
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device))
        final = torch.pow(cross_terms, -0.5)
        # Count each unordered pair once: keep the strict upper triangle.
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        """Sum of per-layer MHE over all 2-D (linear) and 4-D (conv) weights."""
        with torch.no_grad():
            losses = [
                self.mhe_loss(weight).item()
                for weight in self.extracted_params.values()
                if weight.dim() in (2, 4)
            ]
        return np.array(losses).sum()
def project(R, eps):
    """Project R onto the Frobenius-norm ball of radius *eps* around zero.

    R is the (skew-symmetric) Cayley parameter; zero maps to the identity
    rotation, so this bounds how far the rotation may drift from identity.
    Returns R unchanged when it is already inside the ball, otherwise the
    closest point on the ball's surface.
    """
    center = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    delta = R - center
    dist = torch.norm(delta)
    if dist <= eps:
        return R
    # Rescale onto the sphere of radius eps.
    return center + eps * (delta / dist)
def project_batch(R, eps=1e-5):
    """Batched version of :func:`project` for a stack of square blocks.

    R has shape ``(blocks, n, n)``. The radius is shrunk by
    ``1/sqrt(blocks)`` so the total budget is shared across blocks. Each
    block inside its ball is kept; the rest are rescaled onto the surface.
    """
    # Per-block radius: split eps across the stack.
    radius = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    center = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    delta = R - center
    dist = torch.norm(R - center, dim=(1, 2), keepdim=True)
    inside = (dist <= radius).bool()
    # where() picks per block: keep R when inside, else project to the sphere.
    return torch.where(inside, R, center + radius * (delta / dist))
class MHE_OFT(nn.Module):
    """Minimum Hyperspherical Energy (MHE) of an OFT fine-tuned model.

    Each stored rotation parameter ``R`` is projected onto the COFT
    epsilon-ball, mapped to an orthogonal matrix via the Cayley transform,
    and folded into its base attention weight as ``W @ blockdiag(Q)``;
    ``calculate_mhe`` then measures the energy of the effective weights.
    """

    def __init__(self, model, eps=6e-5, r=4):
        """Snapshot *model* and merge the OFT rotations into its weights.

        Args:
            model: module whose attention processors expose
                ``processor.to_{q,k,v,out}_oft.R`` parameters
                (diffusers-style naming — assumed from the key patterns
                below; TODO confirm against the caller).
            eps: base COFT radius, scaled per layer by the block area.
            r: number of diagonal blocks when R is a single shared block.
        """
        super(MHE_OFT, self).__init__()
        self.r = r
        # Detached clones: the merge below must not mutate the live model.
        self.extracted_params = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        keys_to_delete = []
        for name in self.extracted_params:
            # Only base attention projection weights (not processor params).
            if 'attn' not in name or 'processor' in name or 'weight' not in name:
                continue
            if 'to_q' in name:
                oft_R = name.replace('to_q.weight', 'processor.to_q_oft.R')
            elif 'to_k' in name:
                oft_R = name.replace('to_k.weight', 'processor.to_k_oft.R')
            elif 'to_v' in name:
                oft_R = name.replace('to_v.weight', 'processor.to_v_oft.R')
            elif 'to_out' in name:
                oft_R = name.replace('to_out.0.weight', 'processor.to_out_oft.R')
            else:
                # BUG FIX: the original `pass` fell through and reused the
                # previous iteration's oft_R (or raised NameError on the
                # first hit). Skip weights with no matching OFT rotation.
                continue
            # Stay on the tensor's own device (the original forced .cuda(),
            # which broke CPU-only runs and mismatched CPU base weights).
            R = self.extracted_params[oft_R]
            with torch.no_grad():
                if len(R.shape) == 2:
                    # Single shared block: radius scales with the block area.
                    self.eps = eps * R.shape[0] * R.shape[0]
                    R.copy_(project(R, eps=self.eps))
                    orth_rotate = self.cayley(R)
                else:
                    # One block per slice along dim 0.
                    self.eps = eps * R.shape[1] * R.shape[0]
                    R.copy_(project_batch(R, eps=self.eps))
                    orth_rotate = self.cayley_batch(R)
                # Fold the orthogonal rotation into the base weight.
                self.extracted_params[name] = self.extracted_params[name] @ self.block_diagonal(orth_rotate)
            keys_to_delete.append(oft_R)
        for key in keys_to_delete:
            del self.extracted_params[key]

    def is_orthogonal(self, R, eps=1e-5):
        """Return True if R^T R is the identity to within *eps* elementwise."""
        # NOTE: the original defined this method twice, identically; the
        # dead duplicate has been removed.
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def block_diagonal(self, R):
        """Assemble a block-diagonal matrix from R.

        2-D R is repeated ``self.r`` times (shared block); 3-D R contributes
        one block per slice along dim 0.
        """
        if len(R.shape) == 2:
            blocks = [R] * self.r
        else:
            blocks = [R[i, ...] for i in range(R.shape[0])]
        return torch.block_diag(*blocks)

    def copy_without_grad(self, model):
        """Return a deep copy of *model* with all parameters frozen/detached."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    def cayley(self, data):
        """Cayley transform: skew-symmetrize *data* and map to an orthogonal Q."""
        r, c = list(data.shape)
        # Skew-symmetric A guarantees (I - A) is invertible and Q orthogonal.
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # Q = (I + A)(I - A)^-1
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        """Batched Cayley transform over a stack of square matrices."""
        b, r, c = data.shape
        skew = 0.5 * (data - data.transpose(1, 2))
        # expand() avoids materializing b copies of the identity.
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)
        Q = torch.bmm(I + skew, torch.inverse(I - skew))
        return Q

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy (s = 1 power) of one weight tensor.

        Args:
            filt: 2-D linear weight ``(out, in)`` or 4-D conv weight
                ``(out, in, kH, kW)``; conv filters are flattened per
                output channel.

        Returns:
            Scalar tensor: mean ``1 / ||w_i - w_j||`` over all distinct
            normalized filter pairs, each filter's negation included.
        """
        if filt.dim() != 2:
            # Conv weights: one row per output filter.
            filt = filt.reshape(filt.shape[0], -1)
        # Filters as columns; append negated copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, -filt), dim=1)
        n_filt = filt.shape[1]
        # Cosine-similarity matrix; 1e-4 guards against zero-norm filters.
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Squared chord distance 2 - 2cos; adding I keeps the diagonal away
        # from zero before the inverse square root. Device-aware eye replaces
        # the original hard-coded .cuda().
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device))
        final = torch.pow(cross_terms, -0.5)
        # Count each unordered pair once: keep the strict upper triangle.
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        """Sum of per-layer MHE over all 2-D (linear) and 4-D (conv) weights."""
        with torch.no_grad():
            losses = [
                self.mhe_loss(weight).item()
                for weight in self.extracted_params.values()
                if weight.dim() in (2, 4)
            ]
        return np.array(losses).sum()

    def is_identity_matrix(self, tensor):
        """Return True iff *tensor* is exactly a square identity matrix."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
class MHE_db:
    """Minimum Hyperspherical Energy of a plain (no-adapter) model snapshot.

    Unlike MHE_LoRA / MHE_OFT, no merging is performed: the state dict is
    cloned as-is and ``calculate_mhe`` scores the raw weights.
    """

    def __init__(self, model):
        """Clone *model*'s state dict into detached CPU/GPU tensors as stored."""
        self.extracted_params = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy (s = 1 power) of one weight tensor.

        Args:
            filt: 2-D linear weight ``(out, in)`` or 4-D conv weight
                ``(out, in, kH, kW)``; conv filters are flattened per
                output channel.

        Returns:
            Scalar tensor: mean ``1 / ||w_i - w_j||`` over all distinct
            normalized filter pairs, each filter's negation included.
        """
        if filt.dim() != 2:
            # Conv weights: one row per output filter.
            filt = filt.reshape(filt.shape[0], -1)
        # Filters as columns; append negated copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, -filt), dim=1)
        n_filt = filt.shape[1]
        # Cosine-similarity matrix; 1e-4 guards against zero-norm filters.
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat
        # Squared chord distance 2 - 2cos; adding I keeps the diagonal away
        # from zero before the inverse square root. Device-aware eye replaces
        # the original hard-coded .cuda().
        cross_terms = (2.0 - 2.0 * inner_pro
                       + torch.eye(n_filt, device=filt.device))
        final = torch.pow(cross_terms, -0.5)
        # Count each unordered pair once: keep the strict upper triangle.
        final -= torch.tril(final)
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        """Sum of per-layer MHE over all 2-D (linear) and 4-D (conv) weights."""
        with torch.no_grad():
            losses = [
                self.mhe_loss(weight).item()
                for weight in self.extracted_params.values()
                if weight.dim() in (2, 4)
            ]
        return np.array(losses).sum()