brianling16 committed on
Commit f8f3daa · verified · 1 Parent(s): 9720c2e

Upload lora_layer.py with huggingface_hub

Files changed (1)
  1. lora_layer.py +139 -0
lora_layer.py ADDED
@@ -0,0 +1,139 @@
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List


# ---- LoRA ----
class LoRAAdapter(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0,
                 weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        if rank > 0:
            self.A = nn.Parameter(torch.zeros((rank, in_features)))
            self.B = nn.Parameter(torch.zeros((out_features, rank)))

            # Initialize with SVD if base weight is provided
            if weight is not None:
                U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
                U = U[:, :rank]
                S = S[:rank]
                Vh = Vh[:rank, :]
                self.A.data = Vh                   # (rank, in_features)
                self.B.data = U @ torch.diag(S)    # (out_features, rank)
            else:
                nn.init.normal_(self.A, std=1 / rank)
                nn.init.zeros_(self.B)
        else:
            self.register_parameter('A', None)
            self.register_parameter('B', None)

    def delta(self) -> Optional[torch.Tensor]:
        if self.rank == 0 or self.A is None or self.B is None:
            return None
        return (self.B @ self.A) * (self.alpha / self.rank)  # (out, in)

    def lora_parameters(self):
        if self.A is not None:
            yield self.A
        if self.B is not None:
            yield self.B
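
# Note on the adapter math: delta() returns (alpha / rank) * (B @ A), shaped
# (out_features, in_features). With the default init (A ~ Normal, B = 0) the
# delta starts at exactly zero, so a wrapped layer initially matches its base;
# with the SVD init, B @ A equals the rank-`rank` truncation of the given weight.
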
class LoRALinear(nn.Module):
    def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0, num_repeats: int = 1):
        super().__init__()
        self.linear = linear  # base frozen linear
        self.rank = rank
        self.num_repeats = num_repeats

        if rank > 0:
            self.loras = nn.ModuleList([
                LoRAAdapter(linear.in_features, linear.out_features, rank, alpha)
                for _ in range(num_repeats)
            ])
        else:
            self.loras = nn.ModuleList([])

    def forward(self, x, repeat_idx: int = 0):
        out = self.linear(x)  # [batch, ..., out_features]
        if self.rank == 0:
            return out
        delta = self.loras[repeat_idx].delta()  # (out, in)
        if delta is not None:
            # F.linear expects a weight of shape (out, in), which is what delta() returns
            return out + F.linear(x, delta)
        return out

    def lora_parameters(self):
        for lora in self.loras:
            yield from lora.lora_parameters()
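
# Both wrapper modules keep `num_repeats` independent adapter copies per layer;
# forward()'s `repeat_idx` selects which copy's delta is added on top of the
# shared, frozen base output.
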
class LoRAConv1D(nn.Module):
    """GPT-2 style Conv1D with LoRA support."""
    def __init__(self, conv1d, rank: int, alpha: float = 1.0, num_repeats: int = 1):
        super().__init__()
        self.conv1d = conv1d  # base GPT-2 Conv1D
        self.rank = rank
        self.num_repeats = num_repeats
        in_features, out_features = conv1d.weight.shape  # GPT-2 Conv1D: [in, out]

        # Special handling for the c_attn layer, whose fused Q/K/V projection
        # has 3x output features (out_features == 3 * in_features in GPT-2)
        self.is_c_attn = (out_features == 3 * in_features)
        self.split_size = out_features // 3 if self.is_c_attn else out_features

        if rank > 0:
            if self.is_c_attn:
                # Create separate LoRA adapters for the Q, K, V projections
                self.loras = nn.ModuleList([
                    nn.ModuleList([
                        LoRAAdapter(in_features, self.split_size, rank, alpha)
                        for _ in range(3)  # Q, K, V
                    ]) for _ in range(num_repeats)
                ])
            else:
                self.loras = nn.ModuleList([
                    LoRAAdapter(in_features, out_features, rank, alpha)
                    for _ in range(num_repeats)
                ])
        else:
            self.loras = nn.ModuleList([])

    def forward(self, x, repeat_idx: int = 0):
        """
        x: [batch, seq_len, in_features]
        returns: [batch, seq_len, out_features]
        """
        out = self.conv1d(x)
        if self.rank == 0 or len(self.loras) == 0:
            return out

        if self.is_c_attn:
            # Handle the Q, K, V projections separately
            deltas = []
            for i in range(3):
                delta = self.loras[repeat_idx][i].delta()  # (split_size, in)
                if delta is not None:
                    deltas.append(torch.matmul(x, delta.T))  # delta.T: (in, split_size)
            if deltas:
                return out + torch.cat(deltas, dim=-1)
            return out
        else:
            delta = self.loras[repeat_idx].delta()  # (out, in)
            if delta is not None:
                return out + torch.matmul(x, delta.T)  # delta.T: (in, out)
            return out

    def lora_parameters(self):
        if self.is_c_attn:
            for lora_group in self.loras:
                for lora in lora_group:
                    yield from lora.lora_parameters()
        else:
            for lora in self.loras:
                yield from lora.lora_parameters()
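
For context, here is a minimal usage sketch (not part of the uploaded file): it wraps a plain nn.Linear with LoRALinear and trains only the adapter parameters. The module path lora_layer, the layer sizes, and the hyperparameters are illustrative assumptions; a GPT-2 Conv1D layer such as c_attn would be wrapped the same way with LoRAConv1D.

import torch
import torch.nn as nn
from lora_layer import LoRALinear  # assumes this file is importable as `lora_layer`

base = nn.Linear(768, 768)
for p in base.parameters():        # freeze the base layer
    p.requires_grad = False

wrapped = LoRALinear(base, rank=8, alpha=16.0, num_repeats=2)
optimizer = torch.optim.AdamW(wrapped.lora_parameters(), lr=1e-4)  # adapter params only

x = torch.randn(4, 10, 768)        # [batch, seq_len, in_features]
out = wrapped(x, repeat_idx=1)     # apply the second adapter copy
out.pow(2).mean().backward()       # dummy loss, just to exercise the adapters
optimizer.step()

Because B is zero-initialized when no base weight is passed to LoRAAdapter, the wrapped layer reproduces the frozen base layer exactly at the first step, which is the standard LoRA starting point.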