File size: 1,767 Bytes
b8d4c08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import math

import torch
import torch.nn as nn
from ndlinear import NdLinear

# This file contains the custom building blocks for the NdLinear-LoRA architecture.
# It should be in the same directory as the model when loading it.

def find_factor(n):
    """Return the most balanced integer factor pair ``(a, b)`` of ``n``.

    The pair satisfies ``a * b == n`` and ``a <= b``, where ``a`` is the
    largest divisor of ``n`` not exceeding its square root. A prime ``n``
    therefore yields ``(1, n)``.

    Args:
        n: Positive integer to factor.

    Returns:
        A tuple ``(a, b)`` of ints with ``a * b == n``.

    Raises:
        ValueError: If ``n`` is less than 1 (no meaningful factorisation).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # math.isqrt is exact for arbitrarily large ints; int(n ** 0.5) goes
    # through a float and can land below the true root for very large n.
    # i == 1 always divides n, so this loop always returns.
    for i in range(math.isqrt(n), 0, -1):
        if n % i == 0:
            return (i, n // i)

class NdLinearLoRA(nn.Module):
    """The NdLinear-LoRA adapter layer.

    Maps the last dimension of the input from ``d_in`` to ``d_out``. The
    feature axis is reshaped into the balanced factor pair returned by
    ``find_factor(d_in)``, passed through an ``NdLinear`` transform sized by
    ``find_factor(d_out)``, then flattened back to ``d_out``.

    Args:
        d_in: Size of the input feature (last) dimension.
        d_out: Size of the output feature (last) dimension.
        alpha: Multiplier applied to the adapter output. It is used as-is
            (no division by a rank).
        dropout: Dropout probability applied to the input before the adapter.
    """

    def __init__(self, d_in, d_out, alpha=1.0, dropout=0.0):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.in_factors = find_factor(d_in)
        self.out_factors = find_factor(d_out)
        # NOTE(review): presumably NdLinear maps (batch, *input_dims) to
        # (batch, *hidden_size); confirm against the ndlinear package.
        self.adapter = NdLinear(
            input_dims=self.in_factors,
            hidden_size=self.out_factors,
            transform_outer=False,
            bias=False,
        )
        self.scaling = alpha
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the adapter to ``x`` of shape ``(..., d_in)``.

        Returns:
            A tensor of shape ``(..., d_out)`` multiplied by ``alpha``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``d_in``.
        """
        if x.shape[-1] != self.d_in:
            raise ValueError(
                f"expected input last dimension {self.d_in}, got {x.shape[-1]}"
            )
        lead_shape = x.shape[:-1]
        # reshape (not view): x may be non-contiguous (e.g. transposed or
        # sliced activations), and view() raises on incompatible strides.
        h = self.drop(x).reshape(-1, *self.in_factors)
        y = self.adapter(h).reshape(*lead_shape, self.d_out)
        return y * self.scaling

class LinearWithNdLinearLoRA(nn.Module):
    """A nn.Linear layer wrapped with the NdLinear-LoRA adapter.

    Every parameter of the wrapped layer has ``requires_grad`` set to
    ``False``. The output is ``base_layer(x) + adapter(x)``, where
    ``adapter`` is an ``NdLinearLoRA``. Its parameters remain trainable.

    Args:
        base_layer: Layer to wrap. Its ``in_features`` and ``out_features``
            attributes size the adapter.
        alpha: Output multiplier forwarded to ``NdLinearLoRA``.
        dropout: Input dropout probability forwarded to ``NdLinearLoRA``.
    """

    def __init__(self, base_layer, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer
        # Freeze the wrapped layer so only the adapter is trained.
        for weight in base_layer.parameters():
            weight.requires_grad = False
        fan_in = self.base_layer.in_features
        fan_out = self.base_layer.out_features
        self.adapter = NdLinearLoRA(
            d_in=fan_in,
            d_out=fan_out,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x):
        """Return the frozen base output plus the adapter's correction."""
        frozen_out = self.base_layer(x)
        correction = self.adapter(x)
        return frozen_out + correction