File size: 671 Bytes
a090db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha/rank

        self.loraA = nn.Linear(in_features, rank, bias=False) 
        self.loraB = nn.Linear(rank, out_features, bias=False)

        nn.init.kaiming_uniform_(self.loraA.weight, a=5**0.5)
        nn.init.zeros_(self.loraB.weight)

    def forward(self, x):
        delta = self.loraB(self.loraA(x))   # (x*A)*B -->  ((B, S, D) * (B, D, R)) * (B, R, D) --> (B, S, R) * (B, R, D) --> (B, S, D)
        x = self.scaling * delta
        return x