import re
import math

import torch
import torch.nn as nn


class DiffLoRALinear(nn.Module):
    """
    Differential low-rank adapter around a frozen linear layer:

        Δy = (α/r) * x @ (A_pos @ B_pos - τ * (A_neg @ B_neg))

    The forward pass fuses the two rank-r branches into a single pair of
    matmuls:

        Δy = (α/r) * x @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)

    τ is moved to the input's device explicitly, so the module also works
    when parameters and activations live on different devices.
    """
    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 lora_alpha: float = 16.0, dropout: float = 0.0,
                 merge_weights: bool = False, init_method: str = "kaiming",
                 bias: bool = False):
        super().__init__()
        if r <= 0:
            raise ValueError(f"r must be a positive integer, got {r}")

        # Frozen base projection; only the adapter parameters are trained.
        # The bias (if any) stays frozen alongside the weight.
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        for p in self.linear.parameters():
            p.requires_grad = False
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.scaling = lora_alpha / r
        self.merge_weights = merge_weights
        self.merged = False
        self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Positive and negative low-rank branches, plus a learnable mixing
        # coefficient τ between them.
        self.A_pos = nn.Parameter(torch.zeros(in_features, r))
        self.B_pos = nn.Parameter(torch.zeros(r, out_features))
        self.A_neg = nn.Parameter(torch.zeros(in_features, r))
        self.B_neg = nn.Parameter(torch.zeros(r, out_features))
        self.tau = nn.Parameter(torch.tensor(1.0))
        self.reset_parameters(init_method)

    def reset_parameters(self, init_method: str = "kaiming"):
        # Both B matrices start at zero, so the adapter is a no-op at init.
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(self.A_pos, a=math.sqrt(5))
            nn.init.zeros_(self.B_pos)
            nn.init.kaiming_uniform_(self.A_neg, a=math.sqrt(5))
            nn.init.zeros_(self.B_neg)
        elif init_method == "xavier":
            nn.init.xavier_uniform_(self.A_pos)
            nn.init.zeros_(self.B_pos)
            nn.init.xavier_uniform_(self.A_neg)
            nn.init.zeros_(self.B_neg)
        else:
            raise ValueError(f"Unknown init_method: {init_method}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Once merged, the adapter lives inside self.linear.weight.
        if self.merge_weights and self.merged:
            return self.linear(x)
        base_out = self.linear(x)
        x_dropped = self.lora_dropout(x)

        tau = self.tau.to(x.device)

        # Fused update: one (in, 2r) and one (2r, out) matmul instead of
        # two separate rank-r branches.
        combined_A = torch.cat([self.A_pos, self.A_neg], dim=1)
        combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0)
        update = x_dropped @ combined_A @ combined_B
        return base_out + self.scaling * update
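
    # Hypothetical helper (a minimal sketch, not part of the original
    # module): nothing above ever sets self.merged, so this fills in the
    # missing merge step that the merge_weights/merged flags refer to.
    # It folds the adapter into the frozen weight so that forward() with
    # merge_weights=True reduces to a single matmul. nn.Linear stores its
    # weight as (out_features, in_features), hence the transpose.
    @torch.no_grad()
    def merge(self):
        if self.merged:
            return
        delta_w = self.scaling * (
            self.A_pos @ self.B_pos - self.tau * (self.A_neg @ self.B_neg)
        )
        self.linear.weight += delta_w.T
        self.merged = True

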
def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int):
    """
    Recursively replace nn.Linear modules whose attribute names match
    target_regex with DiffLoRALinear modules of rank r, preserving the
    pretrained weight (and bias, if any).
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
            new_layer = DiffLoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=16.0,
                dropout=0.1,
                merge_weights=False,
                bias=child.bias is not None,
            )
            # Match the original layer's device/dtype before copying weights,
            # so the adapter parameters end up where the model lives.
            new_layer = new_layer.to(device=child.weight.device, dtype=child.weight.dtype)
            new_layer.linear.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                new_layer.linear.bias.data.copy_(child.bias.data)
            setattr(module, name, new_layer)
        else:
            replace_linear_with_diff_lora(child, target_regex, r)