import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class LoRALinear(nn.Module):
"""
LoRA layer: Low-Rank Adaptation.
This layer consists of a low-rank decomposition of weight updates.
"""
    def __init__(self, in_features, out_features, r=8, alpha=1.0, dropout=0.1, **kwargs):
        super(LoRALinear, self).__init__()
        self.use_si = False
        self.multi_domain = 0
        if kwargs.pop('use_si', False):
            # Pop kwargs that LoRALinearSI does not accept (e.g. num_task forwarded by lora_wrapper)
            kwargs.pop('num_task', None)
            self.model = LoRALinearSI(
                in_features, out_features, r, alpha, dropout=dropout, **kwargs
            )
            self.use_si = True
        elif 'multi_domain' in kwargs:
            self.r = r
            self.alpha = alpha
            self.multi_domain = kwargs['multi_domain']
a_list, b_list, drop_list = [], [], []
for i in range(self.multi_domain):
a_list.append(nn.Linear(in_features, r, bias=False))
b_list.append(nn.Linear(r, out_features, bias=False))
drop_list.append(nn.Dropout(dropout))
self.A = nn.ModuleList(a_list)
self.B = nn.ModuleList(b_list)
            self.drop = nn.ModuleList(drop_list)
self.scaling = alpha / r
self._init_weights()
else:
self.r = r
self.alpha = alpha
# Low-rank decomposition matrices
self.A = nn.Linear(in_features, r, bias=False) # Down-projection
self.drop = nn.Dropout(dropout)
self.B = nn.Linear(r, out_features, bias=False) # Up-projection
nn.init.zeros_(self.B.weight)
nn.init.normal_(self.A.weight, std=1 / r)
self.lora_name = "lora_layer" # Unique name
# Scaling factor for LoRA
self.scaling = alpha / r
def _init_weights(self):
for layer in self.A:
nn.init.normal_(layer.weight, std=1 / self.r)
for layer in self.B:
nn.init.zeros_(layer.weight)
    def forward(self, x, task_mask=None, i=None, task_idx=None):
        # Apply the low-rank update: scaling * B(drop(A(x)))
        if self.use_si:
            return self.model(x)
        if self.multi_domain:
            # A/B are ModuleLists here; dispatch on the domain index in i (or task_idx)
            idx = i if i is not None else task_idx
            assert idx is not None, "multi-domain LoRALinear requires a domain index"
            return self.scaling * self.B[idx](self.drop[idx](self.A[idx](x)))
        return self.scaling * self.B(self.drop(self.A(x)))
def update_si_information(self):
if self.use_si:
self.model.update_si_information()
def finalize_si_importance(self):
if self.use_si:
self.model.finalize_si_importance()
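# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
# LoRA is applied additively next to a frozen base layer: y = base(x) + lora(x),
# where lora(x) = (alpha / r) * B(drop(A(x))). Since B is zero-initialized, the
# combined model starts out exactly equal to the base model.
def _example_lora_linear():
    base = frozen_grad(nn.Linear(64, 64))      # stands in for a frozen pretrained layer
    lora = LoRALinear(64, 64, r=8, alpha=1.0)  # trainable low-rank update
    x = torch.randn(2, 64)
    return base(x) + lora(x)                   # at init this equals base(x) because B == 0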
class BayesianLinear(nn.Module):
def __init__(self, in_features, out_features, r=8, prior_std=0.1, dropout=0.1, **kwargs):
"""
Bayesian LoRA Layer: Instead of deterministic weights,
it learns a distribution over LoRA parameters using Bayesian inference.
Args:
in_features (int): Input dimension.
out_features (int): Output dimension.
rank (int): LoRA rank.
prior_std (float): Standard deviation of the Gaussian prior.
"""
super(BayesianLinear, self).__init__()
        # Rank-dependent scaling, applied in forward
        self.scaling = 1 / r
        # Learnable means and log-variances (log-variance parameterization for stability)
self.A_mu = nn.Parameter(torch.randn(in_features, r) * (1 / r))
self.A_logvar = nn.Parameter(torch.randn(in_features, r) * (1 / r))
self.B_mu = nn.Parameter(torch.randn(r, out_features) * (1 / r))
self.B_logvar = nn.Parameter(torch.randn(r, out_features) * (1 / r))
self.drop = nn.Dropout(dropout)
# Gaussian prior (zero mean)
self.prior_std = prior_std
def sample_weights(self):
"""
Reparameterization Trick: Sample weights from Gaussian distribution.
"""
A_std = torch.exp(0.5 * self.A_logvar)
B_std = torch.exp(0.5 * self.B_logvar)
# Sample weights using reparameterization
B_sample = self.B_mu + B_std * torch.randn_like(B_std)
A_sample = self.A_mu + A_std * torch.randn_like(A_std)
return A_sample, B_sample
    def kl_divergence(self):
        """
        KL divergence between the learned weight posteriors and a standard
        normal prior, KL(N(mu, sigma^2) || N(0, 1)).
        (The same expression is computed externally in test_kl_lora below.)
        """
        A_std = torch.exp(0.5 * self.A_logvar)
        B_std = torch.exp(0.5 * self.B_logvar)
        kl_A = (self.A_mu**2 + A_std**2 - 2 * torch.log(A_std) - 1).sum()
        kl_B = (self.B_mu**2 + B_std**2 - 2 * torch.log(B_std) - 1).sum()
        return 0.5 * (kl_A + kl_B)
def forward(self, x):
"""
Forward pass with Bayesian weight sampling.
"""
if self.training:
A, B = self.sample_weights()
else:
            A, B = self.A_mu, self.B_mu  # Use the posterior means at eval time
        out = self.drop(x @ A)
        return self.scaling * (out @ B)  # LoRA forward pass, with the 1/r scaling from __init__
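# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
# Keeping the layer in train() mode draws fresh weights on every call, so averaging
# several stochastic forward passes gives a simple Monte-Carlo predictive estimate.
def _example_bayesian_mc(layer, x, n_samples=8):
    layer.train()  # enable weight sampling (and dropout)
    samples = torch.stack([layer(x) for _ in range(n_samples)])
    return samples.mean(0), samples.std(0)  # predictive mean and spread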
class LoRALinearSI(nn.Module):
def __init__(self, in_features, out_features, r=8,
alpha=1.0, lambda_si=0.1, si_decay=0.99, dropout=0.1,
plasticity_base=0.5, sparsity_threshold=1e-3):
super().__init__()
self.r = r
self.alpha = alpha # Base scaling factor for LoRA updates
self.lambda_si = lambda_si # Strength of SI regularization
self.si_decay = si_decay # Decay factor for importance updates
self.plasticity_base = plasticity_base # Base plasticity level
self.sparsity_threshold = sparsity_threshold # Threshold for detecting sparse weights
# LoRA trainable parameters
self.lora_A = nn.Parameter(torch.randn(in_features, r))
self.lora_B = nn.Parameter(torch.randn(r, out_features))
self.drop = nn.Dropout(dropout)
nn.init.zeros_(self.lora_B)
nn.init.normal_(self.lora_A, std=1 / r)
# Synaptic Intelligence (SI) buffers
self.register_buffer("omega_A", torch.zeros_like(self.lora_A)) # Importance of lora_A
self.register_buffer("omega_B", torch.zeros_like(self.lora_B)) # Importance of lora_B
self.register_buffer("prev_params_A", self.lora_A.clone().detach())
self.register_buffer("prev_params_B", self.lora_B.clone().detach())
self.register_buffer("trajectory_A", torch.zeros_like(self.lora_A)) # Tracks updates for lora_A
self.register_buffer("trajectory_B", torch.zeros_like(self.lora_B)) # Tracks updates for lora_B
# self.register_buffer("plasticity", torch.ones_like(self.lora_A) * self.plasticity_base) # Dynamic plasticity control
def forward(self, x):
        adaptive_alpha = self.alpha  # could be scaled by the (disabled) plasticity buffer
lora_update = torch.matmul(x, self.lora_A)
lora_update = self.drop(lora_update)
lora_update = torch.matmul(lora_update, self.lora_B)
return adaptive_alpha * lora_update # Dynamic scaling
    def update_si_information(self):
        """Update Synaptic Intelligence importance online (call after each optimizer step)."""
        if self.lora_A.grad is not None:
            delta_theta_A = self.lora_A.detach() - self.prev_params_A
            # SI path integral: accumulate -grad * delta_theta, the per-step contribution to the loss decrease
            self.trajectory_A += -delta_theta_A * self.lora_A.grad
            self.prev_params_A.copy_(self.lora_A.detach())
        if self.lora_B.grad is not None:
            delta_theta_B = self.lora_B.detach() - self.prev_params_B
            self.trajectory_B += -delta_theta_B * self.lora_B.grad
            self.prev_params_B.copy_(self.lora_B.detach())
def compute_sparsity(self, param):
"""Compute the sparsity score: fraction of near-zero values."""
return torch.mean((torch.abs(param) < self.sparsity_threshold).float())
    def finalize_si_importance(self):
        """Compute final importance after training a task and adjust plasticity.
        (Note: canonical SI divides the path integral by the squared total parameter
        displacement over the task plus a damping term; this variant divides by the
        squared current parameter values.)"""
        self.omega_A = self.si_decay * self.omega_A + (1 - self.si_decay) * (self.trajectory_A / (self.lora_A**2 + 1e-6)).detach()
        self.omega_B = self.si_decay * self.omega_B + (1 - self.si_decay) * (self.trajectory_B / (self.lora_B**2 + 1e-6)).detach()
        self.trajectory_A.zero_()
        self.trajectory_B.zero_()
# Compute sparsity scores
# sparsity_A = self.compute_sparsity(self.lora_A)
# sparsity_B = self.compute_sparsity(self.lora_B)
# Adjust plasticity dynamically based on sparsity
# self.plasticity = torch.exp(-self.omega_A) * (1 - sparsity_A)
def si_loss(self):
"""Compute the SI loss term for both LoRA parameters."""
loss_A = torch.sum(self.omega_A * (self.lora_A - self.prev_params_A) ** 2)
loss_B = torch.sum(self.omega_B * (self.lora_B - self.prev_params_B) ** 2)
return self.lambda_si * (loss_A + loss_B)
    def set_plasticity(self, value: float):
        """Manually set a global plasticity value if needed."""
        if not hasattr(self, "plasticity"):
            # The plasticity buffer is commented out above; create it on demand
            self.register_buffer("plasticity", torch.ones_like(self.lora_A))
        self.plasticity.fill_(value)
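# --- Usage sketch (illustrative; names below are assumptions, not part of the original API) ---
# A typical SI loop: per step, optimize the task loss plus si_loss(), then call
# update_si_information() to accumulate the path integral; at each task boundary,
# finalize_si_importance() consolidates the importances.
def _example_si_loop(layer, data_iter, optimizer, steps_per_task=100):
    for step, (x, target) in enumerate(data_iter):
        optimizer.zero_grad()
        loss = F.mse_loss(layer(x), target) + layer.si_loss()
        loss.backward()
        optimizer.step()
        layer.update_si_information()  # grads are still populated here
        if (step + 1) % steps_per_task == 0:
            layer.finalize_si_importance()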
class MOELoRALinear(nn.Module):
"""
LoRA layer: Low-Rank Adaptation.
This layer consists of a low-rank decomposition of weight updates.
"""
def __init__(self, in_features, out_features, r=8, alpha=1.0, dropout=0.1, num_task=3, **kwargs):
super(MOELoRALinear, self).__init__()
self.loras = nn.ModuleList([
LoRALinear(
in_features,
out_features,
r, alpha, dropout, **kwargs) for _ in range(num_task)
])
        self.num_task = num_task
    def forward(self, x, i):
        if isinstance(i, int):
            return self.loras[i](x)
        elif torch.is_floating_point(i):
            # Soft routing: i holds per-task weights; the batch is laid out as
            # interleaved groups of size b // num_task, one row per task.
            orig_shape = x.shape
            b = orig_shape[0]
            new_shape = (b // self.num_task, self.num_task) + orig_shape[1:]
            x = x.reshape(new_shape)
            mask_shape = i.shape + (1,) * len(orig_shape[1:])
            i = i.reshape(mask_shape)
            res_list = torch.stack([
                self.loras[t](x[:, t]) for t in range(self.num_task)
            ], dim=1)  # [b // num_task, task, ..., dim]
            res_list = res_list * i
            res_list = res_list.reshape(orig_shape[:-1] + (-1,))
            return res_list
        # Hard routing: i is a LongTensor of per-sample task indices
        res_list = torch.stack([
            self.loras[t](x) for t in range(self.num_task)
        ], dim=1)  # [b, task, ..., dim]
        b = res_list.shape[0]
        res = res_list[torch.arange(b), i]
        return res
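# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
def _example_moe_routing():
    moe = MOELoRALinear(32, 32, r=4, num_task=3)
    x = torch.randn(6, 32)
    y_int = moe(x, 1)                                   # one expert for the whole batch
    y_hard = moe(x, torch.tensor([0, 1, 2, 0, 1, 2]))   # per-sample expert index
    mask = torch.rand(2, 3)                             # soft weights: 6 samples = 2 groups x 3 tasks
    y_soft = moe(x, mask)
    return y_int, y_hard, y_soft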
class ZeroAdapter(nn.Module):
"""
LoRA layer: Low-Rank Adaptation.
This layer consists of multiple LoRA mitigating catastrophic forgetting
"""
def __init__(self, in_features, out_feature, dropout=0.1, **kwargs):
super(ZeroAdapter, self).__init__()
mid_feature = in_features // 2
self.down_linear = nn.Linear(in_features, mid_feature)
self.up_linear = nn.Linear(mid_feature, out_feature)
        # Zero-init only the up-projection: the adapter output starts at zero, but
        # gradients can still reach down_linear (zero-initializing both layers would
        # make the adapter's gradients vanish permanently under ReLU).
        nn.init.zeros_(self.up_linear.weight)
        nn.init.zeros_(self.up_linear.bias)
self.act = nn.ReLU()
self.drop = nn.Dropout(dropout)
self.lora_name = "lora_layer" # Unique name
def forward(self, x):
x = self.down_linear(x)
x = self.drop(self.act(x))
x = self.up_linear(x)
return x
class LoRAMoECLAdapter(nn.Module):
def __init__(self, in_features, mid_feature, out_feature,
num_task=6, r=8, alpha=1.0, dropout=0.1, **kwargs):
super(LoRAMoECLAdapter, self).__init__()
self.r = r
self.alpha = alpha
self.num_task = num_task
self.adapters = nn.ModuleList([
nn.Sequential(
LoRALinear(in_features, mid_feature, r, alpha, dropout),
nn.Dropout(dropout),
nn.ReLU(),
LoRALinear(mid_feature, out_feature, r, alpha, dropout),
)
for _ in range(num_task)
])
self.router = nn.Linear(in_features, num_task)
self.out_drop = nn.Dropout(dropout)
self.lora_name = "lora_layer" # Unique name
    def forward(self, x, i=None):
        outputs = []
        logits = self.router(x)
        route_prob = logits.softmax(-1)
        for t in range(self.num_task):  # loop variable renamed so it does not shadow the argument i
            outputs.append(self.adapters[t](x))
        outputs = torch.stack(outputs, dim=-2)
        outputs = torch.sum(outputs * route_prob[..., None], dim=-2)
        outputs = self.out_drop(outputs)
        return outputs
class LoRACLAdapter(nn.Module):
"""
LoRA layer: Low-Rank Adaptation.
This layer consists of multiple LoRA mitigating catastrophic forgetting
"""
def __init__(self, in_features, out_feature,
num_task=6, r=8, alpha=1.0, dropout=0.1, **kwargs):
super(LoRACLAdapter, self).__init__()
self.r = r
self.alpha = alpha
self.loras = nn.ModuleList([
LoRALinear(in_features, out_feature, r, alpha, dropout) for _ in range(num_task)
])
self.attn_weights = nn.ModuleList([nn.Linear(out_feature, 1) for _ in range(num_task)])
self.attn_drop = nn.Dropout(dropout)
self.num_task = num_task
# Scaling factor for LoRA
self.scaling = alpha / r
self.lora_name = "lora_layer" # Unique name
    def forward(self, x, task_mask=None):
        # x: [b, 1, d]; each LoRA branch produces an output and an attention logit
        assert task_mask is not None
        outputs = []
        output_weights = []
        for i in range(self.num_task):
            out = self.loras[i](x)
            weight_out = self.attn_weights[i](out)
            outputs.append(out)
            output_weights.append(weight_out)
        outputs = torch.cat(outputs, dim=1)
        output_weights = torch.cat(output_weights, dim=1)
        output_weights = output_weights.softmax(1)
        outputs = outputs * self.attn_drop(output_weights)
        # Detach branches masked out by task_mask so no gradient reaches them
        task_mask = task_mask[0]
        task_mask = task_mask.unsqueeze(-1).expand(outputs.shape[0], -1, outputs.shape[2])
        outputs[task_mask == 0] = outputs[task_mask == 0].detach()
        outputs = outputs.sum(1)
        return outputs[:, None]
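# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
def _example_lora_cl():
    adapter = LoRACLAdapter(16, 16, num_task=4, r=4)
    x = torch.randn(2, 1, 16)
    task_mask = torch.tensor([[1, 1, 0, 0]])  # only the first two task branches receive gradients
    return adapter(x, task_mask=task_mask)    # shape [2, 1, 16]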
valid_lora_list = (LoRALinear, LoRACLAdapter, ZeroAdapter, LoRAMoECLAdapter, MOELoRALinear)
def lora_wrapper(
module,
LoraLayer=LoRALinear,
rank=8, alpha=1.0, dropout=0.1,
num_task=6,
**kwargs):
"""
Creates a separate LoRA module that mirrors the Linear layers in the original model.
"""
    if isinstance(module, nn.ModuleList):
        lora_module = nn.ModuleList()
        for m in module:
            lora_module.append(lora_wrapper(
                m, LoraLayer,
                rank=rank, alpha=alpha, dropout=dropout, num_task=num_task
            ))
        return lora_module
    if isinstance(module, nn.ModuleDict):
        lora_module = nn.ModuleDict()
        for k, v in module.items():
            lora_module[f'lora_{k}'] = lora_wrapper(
                v, LoraLayer,
                rank=rank, alpha=alpha, dropout=dropout, num_task=num_task
            )
        return lora_module
    if len(list(module.named_modules())) == 1:
        if not isinstance(module, nn.Linear):
            print(f'Wrapping non-nn.Linear unit {type(module)}, skipping with Identity')
            return nn.Identity()
        lora_module = LoraLayer(module.in_features, module.out_features,
                                r=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs)
        return lora_module
    # sequential case
    lora_module = nn.Sequential()
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            lora_layer = LoraLayer(child.in_features, child.out_features,
                                   r=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs)
            lora_module.add_module(f'lora_{name}', lora_layer)
        elif isinstance(child, nn.Sequential):
            lora_module.add_module(f'lora_{name}',
                lora_wrapper(child,
                             LoraLayer,
                             rank=rank, alpha=alpha, dropout=dropout, num_task=num_task,
                             )
            )
        else:
            lora_module.add_module(f'lora_{name}', nn.Identity())
    return lora_module
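# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
def _example_lora_wrapper():
    model = frozen_grad(nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)))
    lora = lora_wrapper(model, LoRALinear, rank=4, alpha=1.0)
    x = torch.randn(2, 16)
    # peft_wrapper_forward walks both trees in lockstep, adding each LoRA branch
    # to its matching Linear and skipping the nn.Identity placeholders
    return peft_wrapper_forward(x, model, lora)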
def single_peft_forward(x, model, lora_model, lora_only=False, idx=None):
    # Only forward an index when one is given; adapters such as ZeroAdapter take no index argument
    lora_out = lora_model(x) if idx is None else lora_model(x, i=idx)
    if lora_only:
        return lora_out
    return model(x) + lora_out
def peft_wrapper_forward(x, model, lora_model, use_lora=True,
layer_idx=-1, layer_name="", lora_only=False, task_idx=None):
"""
Custom forward function to combine original model output with LoRA output.
layer_idx: can be specified for (nn.ModuleList) model; Default: running sequentially through whole ModuleList
layer_name: can be specified for (nn.ModuleDict) model; Default:running sequentially through whole ModuleDict
lora_only: if lora_only=True, forward function will only pass through the lora layer when meet with matched Linear
"""
if isinstance(model, nn.ModuleList):
if layer_idx > -1:
return single_peft_forward(x, model[layer_idx], lora_model[layer_idx], lora_only, task_idx)
if isinstance(model, nn.ModuleDict):
if layer_name != "":
return single_peft_forward(x, model[layer_name], lora_model[layer_name], lora_only, task_idx)
if len(list(model.named_modules())) == 1:
return single_peft_forward(x, model, lora_model, lora_only, task_idx)
def process_layer(orig_layer, lora_layer, x):
""" Recursively process nested nn.Sequential layers """
if isinstance(orig_layer, nn.Sequential) and isinstance(lora_layer, nn.Sequential):
for o_layer, l_layer in zip(orig_layer.children(), lora_layer.children()):
x = process_layer(o_layer, l_layer, x)
return x
else:
if use_lora and not isinstance(lora_layer, nn.Identity):
return single_peft_forward(x, orig_layer, lora_layer, lora_only, task_idx)
else:
return orig_layer(x)
for orig_layer, lora_layer in zip(model.children(), lora_model.children()):
x = process_layer(orig_layer, lora_layer, x)
return x
def finetuning_detach(model):
    '''
    Freeze everything except LoRA modules for fine-tuning.
    Modules whose qualified name contains 'lora' keep requires_grad=True; all
    other parameters are frozen, and Dropout modules are set to eval mode.
    Custom sublayers that contain such a LoRA layer or adapter should expose the
    same "lora_name" attribute so they are picked up here.
    '''
for name, module in model.named_modules():
if 'lora' in name:
for param in module.parameters():
param.requires_grad = True
else:
for param in module.parameters():
param.requires_grad = False # disable param
if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
module.eval()
def frozen_grad(model):
for param in model.parameters():
param.requires_grad = False
return model
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
# self.model = nn.Sequential(
# nn.Linear(10, 20),
# nn.ReLU(),
# nn.Sequential(
# nn.Linear(20, 30),
# nn.ReLU(),
# nn.Linear(30, 40)
# )
# )
# self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
self.model = nn.ModuleDict()
for i in range(3):
self.model[f'{i}'] = nn.Linear(10,10)
self.lora_layer = lora_wrapper(
self.model,
ZeroAdapter,
rank=4, alpha=1.0)
def forward(self, x):
x = peft_wrapper_forward(x, self.model, self.lora_layer)
return x
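# --- Usage sketch (illustrative; the helper name below is not part of the original API) ---
def _example_finetuning_detach():
    model = TestModule()
    finetuning_detach(model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f'trainable: {trainable} / {total}')  # only the lora_* parameters stay trainable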
def retreive_bayesian_lora_param(module):
    '''
    Input: any nn.Module. Searches for all Bayesian LoRA parameters.
    Returns: lora_dict: Dict[sub_name: Dict['A_mu', 'B_mu', 'A_logvar', 'B_logvar']]
    '''
    lora_dict = {}
    lora_list = set(['A_mu', 'B_mu', 'A_logvar', 'B_logvar'])
    if isinstance(module, BayesianLinear):
        lora_dict['.'] = dict()
        for name, m in module.named_parameters():
            lora_dict['.'][name] = m
        return lora_dict
    for name, m in module.named_parameters():
        name_list = name.split('.')
        # The parameter itself is named e.g. 'A_mu' (the last component); the prefix is its module path
        if name_list[-1] in lora_list:
            m_prefix = ".".join(name_list[:-1])
            if m_prefix not in lora_dict:
                lora_dict[m_prefix] = dict()
            lora_dict[m_prefix][name_list[-1]] = m
    return lora_dict
def test_lora_si():
    from time import time
    lora_model = LoRALinearSI(
        256, 256, 16
    )
    t = []
    for _ in range(10):
        s = time()
        x = torch.randn(2, 256)
        y = lora_model(x)
        loss = lora_model.si_loss()
        t.append(time() - s)
    print(loss, np.mean(t))
def test_kl_lora():
lora_layer = BayesianLinear(
32, 32, r=8
)
inputs = torch.randn(4, 10, 32)
out = lora_layer(inputs)
bayesian_params = retreive_bayesian_lora_param(lora_layer)
loss = 0.
for v_dict in bayesian_params.values():
print(v_dict.keys())
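        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1); assumes a unit-variance prior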
B_std = torch.exp(0.5 * v_dict['B_logvar'])
A_std = torch.exp(0.5 * v_dict['A_logvar'])
kl_B = (v_dict['B_mu']**2 + B_std**2 - 2 * torch.log(B_std) - 1).sum()
kl_A = (v_dict['A_mu']**2 + A_std**2 - 2 * torch.log(A_std) - 1).sum()
module_loss = 0.5 * (kl_B + kl_A)
loss += module_loss
print(out.shape, loss)
# Example usage
if __name__ == "__main__":
# Define a nested Sequential model
# model = TestModule()
# finetuning_detach(model)
# x = torch.randn(4, 10)
# print(model(x).shape)
# # Print the model structure after attaching LoRA layers
# print("Model structure after attaching LoRA layers:\n", model)
# for name, param in model.named_parameters():
# print(name, param.shape, param.requires_grad)
# test_lora_si()
test_kl_lora()