import torch
import torch.nn as nn
from ndlinear import NdLinear
# This file contains the custom building blocks for the NdLinear-LoRA architecture.
# It should be in the same directory as the model when loading it.

def find_factor(n):
    """Finds the most balanced integer factor pair for n, e.g. find_factor(768) -> (24, 32)."""
    # Search downward from sqrt(n) so the first divisor found gives the most balanced split.
    for i in range(int(n ** 0.5), 0, -1):
        if n % i == 0:
            return (i, n // i)
    return (1, n)

class NdLinearLoRA(nn.Module):
    """The NdLinear-LoRA adapter layer."""
    def __init__(self, d_in, d_out, alpha=1.0, dropout=0.0):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Factorize the input/output widths so NdLinear can transform each axis separately.
        self.in_factors = find_factor(d_in)
        self.out_factors = find_factor(d_out)
        self.adapter = NdLinear(
            input_dims=self.in_factors,
            hidden_size=self.out_factors,
            transform_outer=False,
            bias=False
        )
        self.scaling = alpha
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        orig_shape = x.shape
        # Fold the last dimension into its factor pair, apply the adapter, then restore the shape.
        x = self.drop(x).view(-1, *self.in_factors)
        y = self.adapter(x).view(*orig_shape[:-1], self.d_out)
        return y * self.scaling

class LinearWithNdLinearLoRA(nn.Module):
    """A nn.Linear layer wrapped with the NdLinear-LoRA adapter."""
    def __init__(self, base_layer, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer
        # Freeze the base weights; only the adapter is trained.
        for param in self.base_layer.parameters():
            param.requires_grad = False
        self.adapter = NdLinearLoRA(
            d_in=self.base_layer.in_features,
            d_out=self.base_layer.out_features,
            alpha=alpha,
            dropout=dropout
        )

    def forward(self, x):
        # Frozen base output plus the scaled adapter update.
        return self.base_layer(x) + self.adapter(x)
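
# Usage sketch (illustrative, not part of the original building blocks): wraps a plain
# nn.Linear with the adapter and checks that only the adapter's parameters stay trainable.
# The layer sizes and hyperparameters below are arbitrary example values.
if __name__ == "__main__":
    base = nn.Linear(768, 3072)
    wrapped = LinearWithNdLinearLoRA(base, alpha=2.0, dropout=0.05)

    x = torch.randn(4, 16, 768)  # (batch, seq_len, hidden)
    y = wrapped(x)               # frozen base output + scaled NdLinear-LoRA update
    print(y.shape)               # expected: torch.Size([4, 16, 3072])

    trainable = [name for name, p in wrapped.named_parameters() if p.requires_grad]
    print(trainable)             # only the adapter's NdLinear parameters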