brianling16 committed
Commit 0b84fe0 · verified · 1 Parent(s): 3fd15e7

Upload shared_attention.py with huggingface_hub

Files changed (1)
  1. shared_attention.py +142 -0
shared_attention.py ADDED
@@ -0,0 +1,142 @@
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional

from transformer import MultiheadSelfAttention, MLP, TransformerLayer
from lora_layer import LoRALinear, LoRAAdapter, LoRAConv1D


class SharedAttention(nn.Module):
    """Self-attention with base projections shared across repeats and a per-repeat LoRA adapter."""

    def __init__(self, base_attn, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.n_heads = base_attn.n_heads
        self.d_head = base_attn.d_head
        self.d_model = base_attn.d_model

        self.q_proj = LoRALinear(base_attn.q_proj, lora_rank, lora_alpha, num_repeats)
        self.k_proj = LoRALinear(base_attn.k_proj, lora_rank, lora_alpha, num_repeats)
        self.v_proj = LoRALinear(base_attn.v_proj, lora_rank, lora_alpha, num_repeats)
        self.out_proj = LoRALinear(base_attn.out_proj, lora_rank, lora_alpha, num_repeats)

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        B, T, C = x.shape
        H, D = self.n_heads, self.d_head

        q = self.q_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
        k = self.k_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
        v = self.v_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)

        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y, repeat_idx)


class SharedMLP(nn.Module):
    """Feed-forward block with shared base weights and per-repeat LoRA adapters."""

    def __init__(self, base_mlp, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.fc1 = LoRALinear(base_mlp.fc1, lora_rank, lora_alpha, num_repeats)
        self.fc2 = LoRALinear(base_mlp.fc2, lora_rank, lora_alpha, num_repeats)
        self.act = base_mlp.act

    def forward(self, x, repeat_idx: int):
        return self.fc2(self.act(self.fc1(x, repeat_idx)), repeat_idx)


class SharedTransformerLayer(nn.Module):
    """Pre-norm transformer layer applied repeatedly, selecting a LoRA adapter by repeat index."""

    def __init__(self, base_layer, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.ln1 = base_layer.ln1
        self.ln2 = base_layer.ln2
        self.dropout1 = base_layer.dropout1
        self.dropout2 = base_layer.dropout2
        self.attn = SharedAttention(base_layer.attn, num_repeats, lora_rank, lora_alpha)
        self.mlp = SharedMLP(base_layer.mlp, num_repeats, lora_rank, lora_alpha)

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        y = self.attn(self.ln1(x), repeat_idx, attn_mask)
        x = x + self.dropout1(y)
        y = self.mlp(self.ln2(x), repeat_idx)
        x = x + self.dropout2(y)
        return x


# ---- Conversion Utilities ----

def average_weights(layers, attr):
    """Return the element-wise mean of getattr(layer, attr).weight over the given layers."""
    weights = [getattr(layer, attr).weight.data for layer in layers]
    return torch.stack(weights, dim=0).mean(dim=0)


def initialize_lora_with_svd(lora_layer, original_weights, repeat_indices, rank):
    """Initialize per-repeat LoRA factors from a truncated SVD of each layer's residual.

    original_weights: list of original weights, one per repeat index
    repeat_indices: which repeat indices these weights correspond to
    """
    shared_weight = lora_layer.base_layer.weight.data.clone()

    for idx, orig_weight in zip(repeat_indices, original_weights):
        # Residual between the original layer and the shared (averaged) weight
        residual = orig_weight - shared_weight
        U, S, Vh = torch.linalg.svd(residual, full_matrices=False)

        # Truncate to rank
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]

        # Initialize LoRA weights so that B @ A approximates the residual
        lora_layer.lora_A[idx].weight.data = Vh                   # A = Vᵣᵀ
        lora_layer.lora_B[idx].weight.data = U @ torch.diag(S)    # B = UᵣΣᵣ


def convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0):
    """Collapse every K consecutive transformer blocks into one shared block with per-repeat LoRA adapters."""
    n_layers = len(model.transformer.h)
    new_blocks = []

    for b in range(n_layers // K):
        block_layers = model.transformer.h[b * K:(b + 1) * K]
        base_layer = copy.deepcopy(block_layers[0])

        # Average weights across the block for shared parameters
        with torch.no_grad():
            if hasattr(base_layer.attn, 'c_attn'):
                shared_weight = average_weights([l.attn for l in block_layers], 'c_attn')
                base_layer.attn.c_attn.weight.data = shared_weight

            if hasattr(base_layer.attn, 'c_proj'):
                shared_weight = average_weights([l.attn for l in block_layers], 'c_proj')
                base_layer.attn.c_proj.weight.data = shared_weight

            if hasattr(base_layer.mlp, 'c_fc'):
                shared_weight = average_weights([l.mlp for l in block_layers], 'c_fc')
                base_layer.mlp.c_fc.weight.data = shared_weight

            if hasattr(base_layer.mlp, 'c_proj'):
                shared_weight = average_weights([l.mlp for l in block_layers], 'c_proj')
                base_layer.mlp.c_proj.weight.data = shared_weight

        # Wrap the shared projections with per-repeat LoRA adapters
        if hasattr(base_layer.attn, 'c_attn'):
            base_layer.attn.c_attn = LoRAConv1D(
                base_layer.attn.c_attn, rank, lora_alpha, K
            )

        if hasattr(base_layer.attn, 'c_proj'):
            base_layer.attn.c_proj = LoRAConv1D(
                base_layer.attn.c_proj, rank, lora_alpha, K
            )

        if hasattr(base_layer.mlp, 'c_fc'):
            base_layer.mlp.c_fc = LoRAConv1D(
                base_layer.mlp.c_fc, rank, lora_alpha, K
            )

        if hasattr(base_layer.mlp, 'c_proj'):
            base_layer.mlp.c_proj = LoRAConv1D(
                base_layer.mlp.c_proj, rank, lora_alpha, K
            )

        new_blocks.append(base_layer)

    model.transformer.h = nn.ModuleList(new_blocks)
    return model
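A minimal usage sketch, assuming the conversion targets a Hugging Face GPT-2 checkpoint (the `model.transformer.h` / `c_attn` / `c_fc` attribute names suggest this) and that `LoRAConv1D` from `lora_layer.py` exposes `base_layer`, `lora_A`, and `lora_B` as used by `initialize_lora_with_svd` above. The model name, the `K`/`rank` values, and the bookkeeping of the original blocks are illustrative, not part of the uploaded file:

```python
import copy
import torch
from transformers import GPT2LMHeadModel  # assumption: HF GPT-2 is the intended base model

from shared_attention import convert_to_recursive, initialize_lora_with_svd

K, RANK = 2, 8

model = GPT2LMHeadModel.from_pretrained("gpt2")
orig_blocks = copy.deepcopy(model.transformer.h)           # keep originals for SVD init
orig_params = sum(p.numel() for p in model.parameters())

model = convert_to_recursive(model, K=K, rank=RANK, lora_alpha=1.0)
print(len(model.transformer.h))                            # 12 GPT-2 blocks -> 6 shared blocks

# Optionally seed each repeat's adapter from the layer it replaced
# (assumes LoRAConv1D keeps the wrapped module as `base_layer`).
with torch.no_grad():
    for b, block in enumerate(model.transformer.h):
        originals = [orig_blocks[b * K + i].attn.c_attn.weight.data for i in range(K)]
        initialize_lora_with_svd(block.attn.c_attn, originals, list(range(K)), rank=RANK)

new_params = sum(p.numel() for p in model.parameters())
print(f"{orig_params:,} -> {new_params:,} parameters")
```

Note that actually running the converted model still requires a forward pass that unrolls each shared block K times and routes the current repeat index into the LoRA modules; that logic lives in the accompanying training/inference code rather than in this file.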