File size: 3,670 Bytes
9b5d8a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import re
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffLoRALinear(nn.Module):
    """
    Differential low-rank adapter around a frozen, bias-free linear layer.

    The adapter weight delta combines a "positive" and a "negative" factor
    pair, with a learnable scalar τ (initialised to 1.0):

      ΔW = (α/r) * [A_pos @ B_pos - τ * (A_neg @ B_neg)]
      y  = linear(x) + dropout(x) @ ΔW

    Both branches are fused into one pair of matmuls along the rank dim:

      update = x_dropped @ concat(A_pos, A_neg) @ concat(B_pos, -τ * B_neg)

    This version explicitly moves τ to the same device as the input.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension.
        r: adapter rank shared by both branches; must be positive.
        lora_alpha: scaling numerator; the update is scaled by lora_alpha / r.
        dropout: dropout probability applied to the adapter input only
            (0.0 uses an identity instead of nn.Dropout).
        merge_weights: stored flag. forward() returns only the base layer
            output when both this and ``self.merged`` are true; nothing in
            this class sets ``merged`` to True.
        init_method: "kaiming" or "xavier" initialisation for A_pos / A_neg.

    Raises:
        ValueError: if ``r`` is not positive or ``init_method`` is unknown.
    """

    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 lora_alpha: float = 16.0, dropout: float = 0.0,
                 merge_weights: bool = False, init_method: str = "kaiming"):
        super().__init__()
        if r <= 0:
            raise ValueError(f"r must be a positive integer, got {r}")
        # Base linear layer (frozen)
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad = False
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.scaling = lora_alpha / r
        self.merge_weights = merge_weights
        self.merged = False
        self.lora_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        # Low-rank factors, stored as (in_features, r) and (r, out_features)
        # so the update is a plain right-multiplication of x.
        self.A_pos = nn.Parameter(torch.zeros(in_features, r))
        self.B_pos = nn.Parameter(torch.zeros(r, out_features))
        self.A_neg = nn.Parameter(torch.zeros(in_features, r))
        self.B_neg = nn.Parameter(torch.zeros(r, out_features))
        self.tau = nn.Parameter(torch.tensor(1.0))  # Learnable scalar weight of the negative branch
        self.reset_parameters(init_method)

    def reset_parameters(self, init_method: str = "kaiming"):
        """(Re)initialise the adapter factors.

        A_pos / A_neg are drawn randomly. B_pos / B_neg are zeroed so the
        adapter initially contributes nothing to the output. τ is not reset.

        Raises:
            ValueError: if ``init_method`` is not "kaiming" or "xavier".
        """
        if init_method == "kaiming":
            # A is stored transposed relative to an nn.Linear weight, so
            # kaiming_uniform_ would take fan_in = r and give a much wider
            # spread than intended. kaiming_uniform_(a=sqrt(5)) reduces to
            # U(-1/sqrt(fan_in), 1/sqrt(fan_in)); apply it with the true
            # fan_in, which is the input dimension (rows of A).
            bound = 1.0 / math.sqrt(self.A_pos.shape[0])
            nn.init.uniform_(self.A_pos, -bound, bound)
            nn.init.uniform_(self.A_neg, -bound, bound)
        elif init_method == "xavier":
            # Xavier is symmetric in fan_in / fan_out, so the transposed
            # storage does not affect it.
            nn.init.xavier_uniform_(self.A_pos)
            nn.init.xavier_uniform_(self.A_neg)
        else:
            raise ValueError(f"Unknown init_method: {init_method}")
        nn.init.zeros_(self.B_pos)
        nn.init.zeros_(self.B_neg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return linear(x) plus the scaled differential low-rank update."""
        if self.merge_weights and self.merged:
            return self.linear(x)
        base_out = self.linear(x)
        x_dropped = self.lora_dropout(x)
        # tau is a registered parameter, so it normally already lives on the
        # module's device; .to() is then a no-op.
        tau = self.tau.to(x.device)
        # Concatenate positive and negative parameters along the rank dimension
        combined_A = torch.cat([self.A_pos, self.A_neg], dim=1)  # (in_features, 2*r)
        combined_B = torch.cat([self.B_pos, -tau * self.B_neg], dim=0)  # (2*r, out_features)
        # Left-to-right evaluation keeps the intermediate at width 2*r.
        update = x_dropped @ combined_A @ combined_B
        delta = self.scaling * update
        return base_out + delta

def replace_linear_with_diff_lora(module: nn.Module, target_regex: str, r: int,
                                  lora_alpha: float = 16.0, dropout: float = 0.1):
    """
    Recursively replace nn.Linear modules whose names match target_regex
    with DiffLoRALinear modules using rank r. ``module`` is modified in place.

    The regex is applied with ``re.search`` (case-insensitive) to each
    child's attribute name within its parent, not to the full dotted path.
    The replaced layer's weight and, if present, its bias are copied into the
    frozen base layer. The new layer is placed on the replaced layer's device
    and dtype, so it works in a model that has already been moved or cast.
    Newly inserted layers are not searched further.

    Args:
        module: root module to modify in place.
        target_regex: pattern searched for in child attribute names.
        r: adapter rank.
        lora_alpha: adapter scaling numerator (default matches previous 16.0).
        dropout: adapter input dropout probability (default matches previous 0.1).
    """
    # Snapshot the children so replacing entries cannot disturb iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear) and re.search(target_regex, name, re.IGNORECASE):
            weight = child.weight
            new_layer = DiffLoRALinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=r,
                lora_alpha=lora_alpha,
                dropout=dropout,
                merge_weights=False,
            ).to(device=weight.device, dtype=weight.dtype)
            with torch.no_grad():
                new_layer.linear.weight.copy_(weight)
            if child.bias is not None:
                # DiffLoRALinear builds its base layer without a bias; re-attach
                # the original one (frozen, like the weight) so outputs match.
                new_layer.linear.bias = nn.Parameter(
                    child.bias.detach().clone(), requires_grad=False
                )
            setattr(module, name, new_layer)
        else:
            replace_linear_with_diff_lora(child, target_regex, r, lora_alpha, dropout)