# UniBioTransfer / lora_layers.py
# (uploaded via huggingface_hub, commit 2b534de)
"""
LoRA (Low-Rank Adaptation) implementation for MLP layers.
Replaces qkv projections in attention and the FFN MLP layers.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# --- LoRA configuration flags ---
# FIX: these flags were annotated `bool` but assigned ints (0/1); True/False
# keeps identical truthiness while matching the declared types.
USE_LORA: bool = True  # enable LoRA replacement for MLP layers (and Conv2d)
if USE_LORA:
    LORA_dropout: float = 0.0          # LoRA dropout rate
    LORA_apply_to_conv: bool = True    # also apply LoRA to Conv2d layers
    LORA_freeze_base: bool = False     # freeze wrapped base weights when replacing
    LORA_DEBUG: bool = False           # print SVD-approximation diagnostics in init_from_diff
    FORCE_SAME_RANK_ACROSS_TASKS: bool = False  # share one adaptive rank across all tasks
    DONT_lora_if_dim_lt: int = 90      # 0: disable. increase for low-dim layers (e.g., in/out conv dim < 32)
    DONT_lora_if_rankFrac_gt: float = 0.3  # skip LoRA when rank/min-dim fraction exceeds this
class LoRALinear(nn.Module):
    """Wrap an (optionally frozen) nn.Linear with a low-rank LoRA update.

    Effective weight is W_0 + scaling * (B @ A), with A: rank x in_features
    and B: out_features x rank. B is initialized to zero, so the wrapper is
    initially an exact no-op around the base layer.

    Args:
        original_linear: base nn.Linear layer (frozen when freeze_base is True)
        rank: LoRA rank (r)
        dropout: dropout probability applied to the LoRA branch input
        freeze_base: whether to disable gradients on the base layer
    """

    def __init__(
        self,
        original_linear: nn.Linear,
        rank: int = 4,
        dropout: float = 0.0,
        freeze_base: bool = True,
    ):
        super().__init__()
        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.rank = rank
        self.scaling = 2.0

        # Keep a reference to the base layer; optionally stop its gradients
        # so only the LoRA factors train.
        self.original_linear = original_linear
        if freeze_base:
            for p in self.original_linear.parameters():
                p.requires_grad = False

        # Low-rank factors of the weight delta: delta-W = B @ A.
        self.lora_A = nn.Parameter(torch.zeros(rank, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)  # zero B => no initial LoRA effect

        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the scaled low-rank delta:
        # (..., in) @ (in, r) @ (r, out).
        base_out = self.original_linear(x)
        delta = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return base_out + delta * self.scaling

    def __repr__(self):
        return f"LoRALinear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, scaling={self.scaling})"
def replace_linear_with_lora(
    module: nn.Module,
    rank: int = 4,
    dropout: float = 0.0,
    target_modules: list = None,
    verbose: bool = True,
):
    """Recursively replace nn.Linear layers within a module with LoRALinear wrappers.

    Args:
        module: module whose linear layers should be replaced (modified in place)
        rank: LoRA rank
        dropout: dropout probability
        target_modules: specific module names to replace; None means all linears
            e.g.: ['to_q', 'to_k', 'to_v', 'to_out'] for attention,
            ['net.0', 'net.2'] for FeedForward
        verbose: whether to log replacements

    Returns:
        the module with replacements applied
    """
    n_replaced = 0
    for child_name, child in module.named_children():
        # When filtering is active, descend into non-target children instead
        # of touching them directly.
        if target_modules is not None and child_name not in target_modules:
            replace_linear_with_lora(child, rank, dropout, target_modules, verbose)
            continue
        if isinstance(child, nn.Linear):
            wrapped = LoRALinear(child, rank=rank, dropout=dropout, freeze_base=LORA_freeze_base)
            setattr(module, child_name, wrapped)
            n_replaced += 1
            if verbose:
                print(f"[LoRA] Replaced {child_name}: {child.in_features} -> {child.out_features} with rank={rank}")
        elif isinstance(child, nn.Sequential):
            # Rebuild Sequential containers (e.g. FeedForward nets), wrapping
            # Linear members and keeping everything else as-is.
            rebuilt = nn.Sequential()
            for idx, sub in enumerate(child):
                if isinstance(sub, nn.Linear):
                    rebuilt.add_module(str(idx), LoRALinear(sub, rank=rank, dropout=dropout, freeze_base=LORA_freeze_base))
                    n_replaced += 1
                    if verbose:
                        print(f"[LoRA] Replaced {child_name}.{idx}: {sub.in_features} -> {sub.out_features} with rank={rank}")
                else:
                    rebuilt.add_module(str(idx), sub)
            setattr(module, child_name, rebuilt)
        else:
            # Recurse into every other kind of submodule.
            replace_linear_with_lora(child, rank, dropout, target_modules, verbose)
    return module
def count_lora_parameters(module: nn.Module):
    """Tally trainable vs. frozen parameter counts for a module.

    Returns:
        dict: {'trainable': trainable params, 'frozen': frozen params,
               'total': total params, 'trainable_ratio': trainable fraction
               (0 when the module has no parameters)}
    """
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    total = trainable + frozen
    return {
        'trainable': trainable,
        'frozen': frozen,
        'total': total,
        'trainable_ratio': trainable / total if total > 0 else 0,
    }
def print_lora_parameters(module: nn.Module, name: str = "Model"):
    """Print LoRA parameter statistics (trainable/frozen counts and ratios).

    Delegates counting to count_lora_parameters; this function only formats.
    """
    stats = count_lora_parameters(module)
    print(f"\n{'='*60}")
    print(f"{name} Parameter Statistics:")
    print(f"{'='*60}")
    # Percentages derive from 'trainable_ratio' (0 when the module has no params).
    print(f"Trainable params: {stats['trainable']:,} ({stats['trainable_ratio']*100:.2f}%)")
    print(f"Frozen params: {stats['frozen']:,} ({(1-stats['trainable_ratio'])*100:.2f}%)")
    print(f"Total params: {stats['total']:,}")
    print(f"{'='*60}\n")
class LoRAConv2d(nn.Module):
    """
    LoRA layer for Conv2d.
    Treat Conv2d as a matrix multiplication:
        - flatten kernel: (out_channels, in_channels, k, k) -> (out_channels, in_channels*k*k)
        - apply low-rank decomposition: W = W_0 + B @ A

    Args:
        original_conv: original nn.Conv2d layer that will be frozen
        rank: LoRA rank (r)
        dropout: dropout probability
        freeze_base: whether to disable gradients on the base conv
    """

    def __init__(
        self,
        original_conv: nn.Conv2d,
        rank: int = 4,
        dropout: float = 0.0,
        freeze_base: bool = True,
    ):
        super().__init__()
        # Mirror the base conv's geometry so the LoRA branch produces the
        # same spatial output map.
        self.out_channels = original_conv.out_channels
        self.in_channels = original_conv.in_channels
        self.kernel_size = original_conv.kernel_size
        self.stride = original_conv.stride
        self.padding = original_conv.padding
        self.dilation = original_conv.dilation
        self.groups = original_conv.groups
        self.rank = rank
        self.scaling = 2.0
        # Freeze the original weights (the LoRA factors remain trainable).
        self.original_conv = original_conv
        if freeze_base:
            for param in self.original_conv.parameters():
                param.requires_grad = False
        # LoRA low-rank decomposition:
        #   lora_A: (rank, in_channels // groups, kH, kW) -- down-projection conv
        #   lora_B: (out_channels, rank, 1, 1)            -- up-projection via 1x1 conv
        self.lora_A = nn.Parameter(torch.zeros(
            rank,
            self.in_channels // self.groups,
            self.kernel_size[0],
            self.kernel_size[1],
        ))
        self.lora_B = nn.Parameter(torch.zeros(self.out_channels, rank, 1, 1))
        # Initialization: A random, B zero so LoRA initially contributes nothing.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        # Dropout applies to the LoRA branch input only.
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
        # FIX: the previous version divided only lora_B.numel() by 1024/1024
        # (missing parentheses), so the reported LoRA size was wildly wrong.
        lora_numel = self.lora_A.numel() + self.lora_B.numel()
        print(f"param orig:lora (M) = {self.original_conv.weight.numel()/1024/1024}:{lora_numel/1024/1024}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output from the original convolution. (No no_grad needed: base
        # gradients are already disabled via requires_grad=False when frozen.)
        result = self.original_conv(x)
        # LoRA update: down-project with lora_A, then up-project with 1x1 lora_B.
        x_dropped = self.dropout(x)
        lora_out = F.conv2d(
            x_dropped,
            self.lora_A,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        lora_out = F.conv2d(lora_out, self.lora_B)
        return result + lora_out * self.scaling

    def __repr__(self):
        return (f"LoRAConv2d(in_channels={self.in_channels}, out_channels={self.out_channels}, "
                f"kernel_size={self.kernel_size}, rank={self.rank}, scaling={self.scaling})")
def _auto_lora_rank(in_features: int, out_features: int) -> int:
    """Derive a LoRA rank from layer dimensions via module-level settings.

    NOTE(review): LORA_rank_min, LORA_rank_ratio and LORA_rank_max are not
    defined anywhere in this file; they must be injected into this module's
    globals before calling — confirm against the caller.
    """
    smaller_dim = min(in_features, out_features)
    candidate = max(LORA_rank_min, int(round(smaller_dim / max(1.0, LORA_rank_ratio))))
    if LORA_rank_max is not None:
        candidate = min(candidate, LORA_rank_max)
    return max(1, candidate)
def _svd_low_rank(M: torch.Tensor, rank: int):
# M: [out, in]
orig_device = M.device
orig_dtype = M.dtype
if 1:
M = M.to(device=torch.device('cuda'), dtype=torch.float32)
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
r = min(rank, U.shape[1], Vh.shape[0])
U_r = U[:, :r]
S_r = S[:r]
Vh_r = Vh[:r, :]
S_root = torch.sqrt(torch.clamp(S_r, min=0))
B = U_r @ torch.diag(S_root) # [out, r]
A = torch.diag(S_root) @ Vh_r # [r, in]
B = B.to(device=orig_device, dtype=orig_dtype)
A = A.to(device=orig_device, dtype=orig_dtype)
S = S.to(device=orig_device, dtype=orig_dtype)
return B, A, S
def _svdvals_squared(M: torch.Tensor) -> torch.Tensor:
# Return squared singular values (energy), sorted in descending order; M: [out, in]
orig_device = M.device
orig_dtype = M.dtype
if 1:
M = M.to(device=torch.device('cuda'), dtype=torch.float32)
S = torch.linalg.svdvals(M)
S2 = (S.float() ** 2)
return S2.to(device=orig_device, dtype=torch.float32)
def _compute_adaptive_rank_from_S2_list(
    list_S2: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
) -> int:
    """Pick one shared rank so every matrix captures enough SVD energy.

    Args:
        list_S2: squared singular value vectors (descending) for each matrix
        avg_threshold: required mean captured-energy ratio across matrices
            (defaults to module-level ADAPTIVE_RANK_AVG_ENERGY_THRESH)
        min_threshold: required minimum captured-energy ratio for any matrix
            (defaults to ADAPTIVE_RANK_MIN_ENERGY_THRESH)
        max_rank: optional hard cap on the returned rank

    Returns:
        the smallest rank r meeting both thresholds, clamped into
        [LORA_rank_min, R_cap].

    Raises:
        AssertionError: on empty input, zero-energy matrices, or when no
            rank up to R_cap satisfies both thresholds (fail fast instead
            of silently degrading).

    NOTE(review): ADAPTIVE_RANK_AVG_ENERGY_THRESH, ADAPTIVE_RANK_MIN_ENERGY_THRESH,
    LORA_rank_min and LORA_rank_max are not defined in this file; they must be
    injected into this module's globals before calling — confirm with the caller.
    """
    # list_S2: squared singular value vectors (descending) for each matrix
    assert len(list_S2) > 0
    if avg_threshold is None:
        avg_threshold = ADAPTIVE_RANK_AVG_ENERGY_THRESH
    if min_threshold is None:
        min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH
    totals = []
    lengths = []
    for s2 in list_S2:
        assert s2.numel() > 0
        total = s2.sum()
        # Quick fail: zero ΔW has zero energy, so thresholds can't be evaluated
        assert float(total.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank"
        totals.append(total)
        lengths.append(int(s2.shape[0]))
    # R_cap: the largest rank permitted by every spectrum and the configured caps.
    R_cap = min(lengths)
    if LORA_rank_max is not None:
        R_cap = min(R_cap, int(LORA_rank_max))
    if max_rank is not None:
        R_cap = min(R_cap, int(max_rank))
    R_cap = max(1, R_cap)
    # Iterate ranks r to see if both average and minimum energy ratios meet thresholds
    for r in range(1, R_cap + 1):
        ratios = []
        for s2, total in zip(list_S2, totals):
            captured = s2[:r].sum()
            ratios.append(float((captured / total).item()))
        avg_ratio = sum(ratios) / len(ratios)
        min_ratio = min(ratios)
        if (avg_ratio >= avg_threshold) and (min_ratio >= min_threshold):
            # Clamp to the configured floor, but never exceed R_cap.
            ret = min(int(R_cap), max(int(LORA_rank_min), int(r)))
            return ret
    # If no rank satisfies both thresholds, fail fast instead of silently degrading
    raise AssertionError(f"No rank satisfies avg>={avg_threshold} and min>={min_threshold} up to R_cap={R_cap}")
def _compute_per_task_ranks_from_S2_list(
    list_S2: list,
    min_threshold: float = None,
    max_rank: int = None,
) -> list:
    """Choose an independent rank per matrix so each captures enough energy.

    Only a minimum-energy threshold is applied per matrix. When ranks are
    allowed to differ across tasks (FORCE_SAME_RANK_ACROSS_TASKS falsy),
    TASK_2_adaptive_rank_min_energy_thresh[i] overrides `min_threshold`.

    NOTE(review): ADAPTIVE_RANK_MIN_ENERGY_THRESH, LORA_rank_min/max and
    TASK_2_adaptive_rank_min_energy_thresh are not defined in this file and
    must be injected into this module's globals — confirm with the caller.
    """
    assert len(list_S2) > 0
    if min_threshold is None:
        min_threshold = ADAPTIVE_RANK_MIN_ENERGY_THRESH
    ranks = []
    for task_idx, s2 in enumerate(list_S2):
        assert s2.numel() > 0
        total = s2.sum()
        assert float(total.item()) > 0.0, "Zero energy in weight_diff; cannot determine adaptive rank"
        # Cap the search range by the spectrum length and configured maxima.
        cap = int(s2.shape[0])
        if LORA_rank_max is not None:
            cap = min(cap, int(LORA_rank_max))
        if max_rank is not None:
            cap = min(cap, int(max_rank))
        cap = max(1, cap)
        # Task-level threshold: when ranks may differ, use the per-task value.
        if FORCE_SAME_RANK_ACROSS_TASKS:
            threshold = min_threshold
        else:
            threshold = TASK_2_adaptive_rank_min_energy_thresh[task_idx]
        chosen = cap  # fall back to the cap if the threshold is never reached
        for r in range(1, cap + 1):
            if float((s2[:r].sum() / total).item()) >= float(threshold):
                chosen = r
                break
        ranks.append(int(max(int(LORA_rank_min), int(chosen))))
    return ranks
def compute_adaptive_rank_for_linear_diffs(
    weight_diffs: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
    per_task: bool = None,
):
    """Adaptive LoRA rank(s) for a list of Linear weight deltas.

    Args:
        weight_diffs: List[Tensor [out, in]] of weight differences
        avg_threshold: mean-energy threshold for the shared-rank path
        min_threshold: minimum-energy threshold
        max_rank: optional hard cap on the rank
        per_task: return one rank per matrix; defaults to the inverse of
            FORCE_SAME_RANK_ACROSS_TASKS

    Returns:
        list of per-task ranks when per_task, otherwise a single shared rank.
    """
    assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0
    if per_task is None:
        per_task = not FORCE_SAME_RANK_ACROSS_TASKS
    list_S2 = [_svdvals_squared(M) for M in weight_diffs]
    out0, in0 = weight_diffs[0].shape
    if per_task:
        ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank)
        print(f"[AdaptiveRank-Linear per-task] in={in0} out={out0} ranks={ranks}")
        return ranks
    else:
        # FIX: avg_threshold was previously discarded (literal None was passed),
        # silently ignoring a caller-supplied average-energy threshold.
        ret = _compute_adaptive_rank_from_S2_list(list_S2, avg_threshold, min_threshold, max_rank)
        print(f"[AdaptiveRank-Linear] in={in0} out={out0} rank={ret}")
        return ret
def compute_adaptive_rank_for_conv_diffs(
    weight_diffs: list,
    avg_threshold: float = None,
    min_threshold: float = None,
    max_rank: int = None,
    per_task: bool = None,
):
    """Adaptive LoRA rank(s) for a list of Conv2d weight deltas.

    Args:
        weight_diffs: List[Tensor [out, in, kH, kW]]; each is flattened to
            [out, in*kH*kW] before the SVD energy analysis
        avg_threshold: mean-energy threshold for the shared-rank path
        min_threshold: minimum-energy threshold
        max_rank: optional hard cap on the rank
        per_task: return one rank per matrix; defaults to the inverse of
            FORCE_SAME_RANK_ACROSS_TASKS

    Returns:
        list of per-task ranks when per_task, otherwise a single shared rank.
    """
    assert isinstance(weight_diffs, (list, tuple)) and len(weight_diffs) > 0
    if per_task is None:
        per_task = not FORCE_SAME_RANK_ACROSS_TASKS
    list_S2 = []
    for W in weight_diffs:
        out_c, in_c, kH, kW = W.shape
        M = W.reshape(out_c, in_c * kH * kW)
        list_S2.append(_svdvals_squared(M))
    out0, in0, kH0, kW0 = weight_diffs[0].shape
    if per_task:
        ranks = _compute_per_task_ranks_from_S2_list(list_S2, min_threshold, max_rank)
        print(f"[AdaptiveRank-Conv per-task] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) ranks={ranks}")
        return ranks
    else:
        # FIX: avg_threshold was previously discarded (literal None was passed),
        # silently ignoring a caller-supplied average-energy threshold.
        ret = _compute_adaptive_rank_from_S2_list(list_S2, avg_threshold, min_threshold, max_rank)
        print(f"[AdaptiveRank-Conv] in_ch={in0} out_ch={out0} kernel=({kH0},{kW0}) rank={ret}")
        return ret
class LoRAAdapterLinearOnly(nn.Module):
    """
    Incremental LoRA (no base Linear) that returns (x @ A^T @ B^T + bias_delta) * scaling.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()
        if rank is None:
            rank = _auto_lora_rank(in_features, out_features)
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = scaling
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(p=dropout)
        # Low-rank factors of the weight delta.
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.use_bias_delta = use_bias_delta
        if use_bias_delta:
            self.lora_bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('lora_bias', None)
        # A random, B zero => the adapter starts as an exact zero delta.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Load factors from an SVD of a full weight delta [out, in]."""
        B, A, S = _svd_low_rank(weight_diff.float(), self.rank)
        self.lora_A.copy_(A.to(self.lora_A.dtype).to(self.lora_A.device))
        self.lora_B.copy_(B.to(self.lora_B.dtype).to(self.lora_B.device))
        if self.use_bias_delta and (bias_diff is not None):
            self.lora_bias.copy_(bias_diff)
        if LORA_DEBUG:
            # Report how much spectral energy the truncation keeps and the
            # relative Frobenius reconstruction error.
            total_energy = (S.float() ** 2).sum().item()
            kept_energy = (S[: self.rank].float() ** 2).sum().item()
            energy_ratio = kept_energy / max(1e-12, total_energy)
            approx = (B @ A).to(weight_diff.device).to(weight_diff.dtype)
            abs_err = torch.linalg.norm((approx - weight_diff).float()).item()
            ref_norm = torch.linalg.norm(weight_diff.float()).item()
            rel_err = abs_err / max(1e-12, ref_norm)
            bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item())
            print(f"[LoRA-Linear init] shape={tuple(weight_diff.shape)} rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        if self.lora_bias is not None:
            delta = delta + self.lora_bias
        return delta * self.scaling
class LoRAAdapterConv2dOnly(nn.Module):
    """
    Incremental LoRA for Conv2d: convolve with A, then 1x1 B, and return the delta.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: tuple, stride: tuple, padding: tuple, dilation: tuple, groups: int = 1, rank: int = None, dropout: float = 0.0, scaling: float = 1.0, use_bias_delta: bool = True):
        super().__init__()

        # Accept scalars for any geometry argument, like nn.Conv2d does.
        def _pair(v):
            return (v, v) if isinstance(v, int) else v

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        kH, kW = kernel_size
        if rank is None:
            # Estimate rank from the flattened in/out dimensions
            rank = _auto_lora_rank(in_channels * kH * kW, out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.rank = rank
        self.scaling = scaling
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(p=dropout)
        # A: [rank, in/groups, kH, kW] -- down projection, B: [out, rank, 1, 1] -- up projection.
        self.lora_A = nn.Parameter(torch.zeros(rank, in_channels // groups, kH, kW))
        self.lora_B = nn.Parameter(torch.zeros(out_channels, rank, 1, 1))
        self.use_bias_delta = use_bias_delta
        if use_bias_delta:
            self.lora_bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('lora_bias', None)
        # A random, B zero => the adapter starts as an exact zero delta.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    @torch.no_grad()
    def init_from_diff(self, weight_diff: torch.Tensor, bias_diff: torch.Tensor = None):
        """Load factors from an SVD of a flattened conv weight delta [out, in, kH, kW]."""
        out_c, in_c, kH, kW = weight_diff.shape
        flat = weight_diff.reshape(out_c, in_c * kH * kW)
        B, A, S = _svd_low_rank(flat.float(), self.rank)  # B:[out,r], A:[r,in*k*k]
        self.lora_A.copy_(A.view(self.rank, in_c, kH, kW))
        self.lora_B.copy_(B.view(out_c, self.rank, 1, 1))
        if self.lora_bias is not None and (bias_diff is not None):
            self.lora_bias.copy_(bias_diff)
        if LORA_DEBUG:
            # Report kept spectral energy and relative reconstruction error.
            total_energy = (S.float() ** 2).sum().item()
            kept_energy = (S[: self.rank].float() ** 2).sum().item()
            energy_ratio = kept_energy / max(1e-12, total_energy)
            approx = (B @ A).to(flat.device).to(flat.dtype)
            abs_err = torch.linalg.norm((approx - flat).float()).item()
            ref_norm = torch.linalg.norm(flat.float()).item()
            rel_err = abs_err / max(1e-12, ref_norm)
            bias_norm = 0.0 if (bias_diff is None) else float(torch.linalg.norm(bias_diff.float()).item())
            print(f"[LoRA-Conv init] out_in_k=({out_c},{in_c},{kH}x{kW}) rank={self.rank} energy={energy_ratio:.4f} rel_err={rel_err:.6f} bias_norm={bias_norm:.6f}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dropped = self.dropout(x)
        delta = F.conv2d(dropped, self.lora_A, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        delta = F.conv2d(delta, self.lora_B)
        if self.lora_bias is not None:
            delta = delta + self.lora_bias.view(1, -1, 1, 1)
        return delta * self.scaling