import re
import math

import torch
import torch.nn as nn


class DiffLoRALinear(nn.Module):
    """
    Differential low-rank adapter (DiffLoRA). The adapter update is

        Δy = (α/r) * x @ (A_pos @ B_pos - τ * (A_neg @ B_neg))

    The fused version computes this with a single pair of matmuls:

        update = x_dropped @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)

    τ is a learnable scalar and is explicitly moved to the same device as the
    input in forward().
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 r: int = 8,
                 lora_alpha: float = 16.0,
                 dropout: float = 0.0,
                 merge_weights: bool = False,
                 init_method: str = "kaiming"):
        super().__init__()
        # Base linear layer (frozen)
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad = False

        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.scaling = lora_alpha / r
        self.merge_weights = merge_weights
        self.merged = False
        self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Low-rank parameters for the positive and negative components
        self.A_pos = nn.Parameter(torch.zeros(in_features, r))
        self.B_pos = nn.Parameter(torch.zeros(r, out_features))
        self.A_neg = nn.Parameter(torch.zeros(in_features, r))
        self.B_neg = nn.Parameter(torch.zeros(r, out_features))
        self.tau = nn.Parameter(torch.tensor(1.0))  # Learnable scalar τ

        self.reset_parameters(init_method)

    def reset_parameters(self, init_method: str = "kaiming"):
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(self.A_pos, a=math.sqrt(5))
            nn.init.zeros_(self.B_pos)
            nn.init.kaiming_uniform_(self.A_neg, a=math.sqrt(5))
            nn.init.zeros_(self.B_neg)
        elif init_method == "xavier":
            nn.init.xavier_uniform_(self.A_pos)
            nn.init.zeros_(self.B_pos)
            nn.init.xavier_uniform_(self.A_neg)
            nn.init.zeros_(self.B_neg)
        else:
            raise ValueError(f"Unknown init_method: {init_method}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If the adapter has been merged into the base weight, the base layer
        # already contains the update.
        if self.merge_weights and self.merged:
            return self.linear(x)

        base_out = self.linear(x)
        x_dropped = self.lora_dropout(x)

        # Ensure tau is on the same device as the input
        tau = self.tau.to(x.device)

        # Concatenate positive and negative parameters along the rank dimension
        combined_A = torch.cat([self.A_pos, self.A_neg], dim=1)         # (in_features, 2*r)
        combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0)  # (2*r, out_features)

        update = x_dropped @ combined_A @ combined_B
        delta = self.scaling * update
        return base_out + delta


def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int):
    """
    Recursively replace nn.Linear modules whose names match target_regex
    with DiffLoRALinear modules of rank r, copying over the frozen base weight.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
            new_layer = DiffLoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=16.0,
                dropout=0.1,
                merge_weights=False,
            )
            new_layer.linear.weight.data.copy_(child.weight.data)
            setattr(module, name, new_layer)
        else:
            replace_linear_with_diff_lora(child, target_regex, r)
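
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapter itself): wraps the
# projection layers of a hypothetical toy block in DiffLoRALinear and runs a
# forward pass. The module names ("q_proj", "k_proj", ...), dimensions, and
# the regex below are assumptions chosen only for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class ToyBlock(nn.Module):
        def __init__(self, d_model: int = 64):
            super().__init__()
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)
            self.out_proj = nn.Linear(d_model, d_model, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.out_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x))

    block = ToyBlock()
    # Replace every nn.Linear whose name contains "proj" with a DiffLoRA adapter.
    replace_linear_with_diff_lora(block, target_regex=r"proj", r=4)

    x = torch.randn(2, 10, 64)
    y = block(x)
    print(y.shape)  # torch.Size([2, 10, 64])

    # Only the adapter parameters (A_pos, B_pos, A_neg, B_neg, tau) remain
    # trainable; the copied base weights are frozen.
    trainable = [n for n, p in block.named_parameters() if p.requires_grad]
    print(trainable)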