File size: 14,338 Bytes
a0d95b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import torch
import torch.nn as nn
from torch import Tensor
import math
import torch.nn.functional as F

from transformers import AutoConfig, PretrainedConfig
from jaxtyping import Float
from dataclasses import asdict, dataclass
from typing import List, Optional, Tuple, Dict
import einops


from .configIBA import MainConfig, HyperXSConfig, TrainingConfig



def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, otherwise ``weight`` itself.

    Lets layers that store their weight as (fan_in, fan_out) feed it to
    ``F.linear``, which expects (out_features, in_features).
    """
    if not fan_in_fan_out:
        return weight
    return weight.T

class LoraLayer:
    """Mixin that stores the LoRA hyper-parameters shared by adapted layers.

    Holds the rank, ``lora_alpha``, the per-device train/eval batch sizes read
    from the training config, an input-dropout callable for the adapter branch
    and the ``disable_adapters`` switch. When mixed into an ``nn.Module`` whose
    own ``__init__`` already ran (as ``LoraXSLinear`` does), the ``nn.Dropout``
    assigned here is registered as a submodule.
    """

    def __init__(
        self,
        rank: int,
        train_cfg: TrainingConfig,
        lora_alpha: int,
        lora_dropout: float,
    ):
        """
        Args:
            rank: LoRA rank r.
            train_cfg: training config; only ``per_device_train_batch_size`` and
                ``per_device_eval_batch_size`` are read from it.
            lora_alpha: numerator of the LoRA scaling factor (subclasses divide it
                by ``rank``).
            lora_dropout: dropout probability for the adapter input; any value not
                greater than 0.0 disables dropout (identity function instead).
        """
        self.rank = rank
        self.batch_train, self.batch_valid = (
            train_cfg.per_device_train_batch_size,
            train_cfg.per_device_eval_batch_size,
        )
        self.lora_alpha = lora_alpha
        # Dropout is applied only on the adapter branch; identity when disabled.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else (lambda x: x)
        )
        # Adapter on/off switch: when True the concrete layer's forward skips the
        # LoRA update and returns only the base projection.
        self.disable_adapters = False

class LoraXSLinear(nn.Linear, LoraLayer):
    """Dense layer with a frozen base weight plus a LoRA-XS style update.

    The output is ``F.linear(x, W, bias)`` plus, when the adapter is active,
    ``scaling * (dropout(x) @ lora_A) @ (lora_R @ lora_B)``, where ``lora_A`` has
    shape (in_features, rank), ``lora_B`` has shape (rank, out_features) and
    ``scaling = lora_alpha / rank``. ``lora_A`` / ``lora_B`` are persistent
    buffers (not trainable parameters). ``lora_R`` is supplied from outside via
    :meth:`set_R`; presumably a (rank, rank) or (batch, rank, rank) tensor
    produced by a hypernetwork — TODO confirm against the caller.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        train_cfg: TrainingConfig,
        rank: int = 64,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, rank=rank, train_cfg=train_cfg, lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout)

        self.fan_in_fan_out = fan_in_fan_out
        if rank > 0:
            # Fixed (non-trainable) projections; persistent so they land in state_dict.
            self.register_buffer("lora_A", torch.zeros([in_features, rank]), persistent=True)
            self.register_buffer("lora_B", torch.zeros([rank, out_features]), persistent=True)

            self.scaling = self.lora_alpha / self.rank
            # Freeze the pre-trained weight matrix.
            self.weight.requires_grad = False
            # Supplied later through set_R(); while None, forward() adds no adapter term.
            self.lora_R = None

        if fan_in_fan_out:
            # Store the weight as (in_features, out_features); forward() transposes back.
            self.weight.data = self.weight.data.T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base linear layer and, if registered, the LoRA buffers.

        ``nn.Linear.__init__`` invokes this before the LoRA buffers exist, hence the
        ``hasattr`` guard; ``__init__`` calls it again once they are registered.
        """
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # Both A and B get nn.Linear-style Kaiming-uniform init (B is NOT zero);
            # the adapter still starts inactive because lora_R defaults to None.
            nn.init.kaiming_uniform_(self.lora_A, mode='fan_out', a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_B, mode='fan_in', a=math.sqrt(5))

    def set_R(self, R: torch.Tensor):
        """Set the tensor inserted between ``lora_A`` and ``lora_B`` in forward()."""
        self.lora_R = R

    def decompose_weight_svd(self, rank):
        """Re-initialise ``lora_A`` / ``lora_B`` from a truncated SVD of the frozen weight.

        With ``W = U diag(S) Vt`` taken in (out_features, in_features) orientation
        (regardless of ``fan_in_fan_out``), the top ``rank`` singular triplets are
        kept and ``sqrt(S)`` is split evenly between the factors:
        ``lora_B = (U_r diag(sqrt S_r)).T`` with shape (rank, out_features) and
        ``lora_A = (diag(sqrt S_r) Vt_r).T`` with shape (in_features, rank), so
        ``lora_A @ lora_B`` approximates ``W.T``. ``rank`` should normally equal
        ``self.rank`` so the buffers line up with ``lora_R`` in forward().

        The SVD runs in float32 when the weight is fp16/bf16 (``torch.linalg.svd``
        does not support half precision); results are cast back to the weight's
        device and dtype.

        Returns:
            ``(None, None)`` if the SVD fails to converge (a message is printed);
            otherwise ``None`` — the buffers are updated in place.

        Raises:
            ValueError: if ``rank`` exceeds ``min(out_features, in_features)``.
        """
        # Factor the (out_features, in_features) view so buffer shapes match forward().
        W = self.weight.data.T if self.fan_in_fan_out else self.weight.data
        device, dtype = W.device, W.dtype

        max_rank = min(W.shape)
        if rank > max_rank:
            raise ValueError(
                f"decompose_weight_svd: rank={rank} exceeds min(weight.shape)={max_rank}"
            )

        # torch.linalg.svd only supports float32/float64 (and complex) inputs.
        W_work = W.to(torch.float32) if W.dtype in (torch.float16, torch.bfloat16) else W
        try:
            U, S, Vt = torch.linalg.svd(W_work, full_matrices=False)
        except torch.linalg.LinAlgError as e:
            print(f"SVD computation failed: {e}")
            return None, None

        sqrt_s = torch.sqrt(S[:rank])              # (r,)
        B = U[:, :rank] * sqrt_s                   # == U_r @ diag(sqrt_s), (out, r)
        A = sqrt_s[:, None] * Vt[:rank, :]         # == diag(sqrt_s) @ Vt_r, (r, in)

        # Assigning to a registered buffer name replaces the buffer tensor.
        self.lora_A = A.T.to(device, dtype)        # (in, r)
        self.lora_B = B.T.to(device, dtype)        # (r, out)

    def forward(self, x: torch.Tensor):
        """Apply the frozen linear layer and, when active, the LoRA-XS update.

        The adapter term is added only if ``disable_adapters`` is False,
        ``rank > 0`` and ``set_R`` has supplied a non-None ``lora_R``. The result
        is cast back to the weight's dtype if it differs.
        """
        previous_dtype = self.weight.dtype
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        # Short-circuit order matters: __init__ only creates lora_R when rank > 0.
        if not self.disable_adapters and self.rank > 0 and self.lora_R is not None:
            adapter_in = self.lora_dropout(x) @ self.lora_A
            result = result + adapter_in @ (self.lora_R @ self.lora_B) * self.scaling

        if result.dtype != previous_dtype:
            result = result.to(previous_dtype)

        return result
    

class HyperNetXSexp(nn.Module):
    """Hypernetwork that generates one ``rank x rank`` LoRA-XS matrix per module.

    For every sample, a sequence of base-model features is combined with a learned
    embedding of the target layer and of each adapted module, and mapped by a
    three-layer MLP to ``modules_per_layer`` matrices of shape (rank, rank).

    Pipeline per sample and module::

        x (s, hidden) -> ln_latent -> latent_proj -> flatten (s * latent)
        concat [layer_emb, module_emb, flat] -> ln_1 -> mixture -> GELU
        -> ln_2 -> c_fc -> GELU -> c_proj -> reshape (rank, rank)

    ``c_proj`` is zero-initialised, so the generated matrices start at exactly 0.
    """

    def __init__(
        self,
        hyperxs_cfg: HyperXSConfig,
        hf_model_cfg: PretrainedConfig,
    ):
        super().__init__()
        self.n_modules = hyperxs_cfg.modules_per_layer  # qkvo attn, up down gate mlp
        self.rank = hyperxs_cfg.lora_attn_dim  # LoRA-XS rank r
        self.latent_feature_dim = hyperxs_cfg.latent_feature_dim  # per-token size after latent_proj

        self.module_embed_dim = hyperxs_cfg.module_embed_dim
        self.layer_embed_dim = hyperxs_cfg.layer_embed_dim
        self.hyper_out = hyperxs_cfg.lora_attn_dim ** 2  # flattened (r, r) output

        # MLP input width: flattened latent tokens + module embedding + layer embedding.
        n_flat_indim = (self.latent_feature_dim * hyperxs_cfg.n_cross_attn_tokens
                        + self.module_embed_dim + self.layer_embed_dim)
        n_flat_outdim = hyperxs_cfg.out_proj_dim * hyperxs_cfg.n_cross_attn_tokens
        n_proj = 4 * n_flat_outdim

        self.latent_proj = nn.Linear(hf_model_cfg.hidden_size, self.latent_feature_dim)  # rescale the embedding first
        self.mixture = nn.Linear(n_flat_indim, n_flat_outdim)
        self.c_fc = nn.Linear(n_flat_outdim, n_proj)
        self.c_proj = nn.Linear(n_proj, self.hyper_out)
        self.act = nn.GELU()

        # Pre-LayerNorms: each normalises the input of the linear layer that follows it.
        self.ln_latent = nn.LayerNorm(hf_model_cfg.hidden_size, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_1 = nn.LayerNorm(n_flat_indim, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(n_flat_outdim, eps=hyperxs_cfg.layer_norm_epsilon)

        # A lookup table for each transformer layer of the base model.
        self.layer_embedding = nn.Embedding(hf_model_cfg.num_hidden_layers, self.layer_embed_dim)
        # A lookup table for each adapted module within a layer.
        self.module_embedding = nn.Embedding(self.n_modules, self.module_embed_dim)
        self.hyperxs_cfg = hyperxs_cfg
        self.hf_model_cfg = hf_model_cfg

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the MLP weights, the layer embedding and the zeroed output head."""
        INIT_STD = 1e-3
        nn.init.kaiming_normal_(self.latent_proj.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.latent_proj.bias, 0)

        nn.init.kaiming_normal_(self.mixture.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.mixture.bias, 0)

        nn.init.kaiming_normal_(self.c_fc.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.c_fc.bias, 0)

        nn.init.normal_(self.layer_embedding.weight, mean=0.0, std=INIT_STD)

        # Output head fully zeroed: the generated R (hence the LoRA update) starts at 0.
        nn.init.constant_(self.c_proj.weight, 0)
        nn.init.constant_(self.c_proj.bias, 0)

    def forward(self, x: Float[Tensor, 'b s f'], layer_idx) -> Float[Tensor, 'b n r r']:
        """Generate the (rank, rank) matrices for every adapted module of one layer.

        Args:
            x: features of shape (batch, seq, hidden_size). ``seq`` must equal
                ``hyperxs_cfg.n_cross_attn_tokens`` so the flattened width matches
                ``mixture`` (checked by the assert below).
            layer_idx: index of the base-model layer (python int or integer tensor);
                selects a row of ``layer_embedding``.

        Returns:
            Tensor of shape (batch, n_modules, rank, rank) in the dtype of ``x``.
        """
        batch_size = x.shape[0]
        dtype_in = x.dtype
        x = x.to(self.latent_proj.weight.dtype)

        # Normalise, project every token into the latent space, then flatten tokens.
        x = self.ln_latent(x)
        x = self.latent_proj(x)
        x = einops.rearrange(x, 'batch seq fea -> batch (seq fea)')

        # Pair each sample's features with every module embedding.
        module_embedding = self.module_embedding.weight.expand(batch_size, -1, -1)
        x = x[:, None, :].expand(-1, self.n_modules, -1)
        x = torch.cat((module_embedding, x), dim=-1)
        x = einops.rearrange(x, 'batch n_modules in_dim -> (batch n_modules) in_dim')

        # Prepend the layer embedding so adapters of different layers are distinguishable.
        if self.layer_embed_dim > 0:
            # as_tensor avoids the copy-construct warning when layer_idx is already a tensor.
            idx = torch.as_tensor(layer_idx, dtype=torch.long, device=x.device)
            layer_embedding = self.layer_embedding(idx)
            # expand is a view, so no per-module copy of the embedding is materialised.
            layer_embedding = layer_embedding.expand(batch_size * self.n_modules, -1)
            x = torch.cat((layer_embedding, x), dim=-1)

        assert x.shape == (batch_size*self.n_modules, self.mixture.weight.data.shape[1]), 'Wrong at hypernetMLP.forward.x'

        # 1st layer (pre-LN). NOTE: earlier revisions computed ln_1(x) but fed the raw
        # x to mixture, leaving ln_1 unused; checkpoints trained that way will differ.
        h = self.act(self.mixture(self.ln_1(x)))
        # 2nd layer (pre-LN).
        h = self.act(self.c_fc(self.ln_2(h)))
        # 3rd layer: flattened (rank * rank) output per (sample, module).
        h = self.c_proj(h)

        h = einops.rearrange(h, '(batch n_modules) (rank r) -> batch n_modules rank r',
                             batch=batch_size, n_modules=self.n_modules,
                             rank=self.rank, r=self.rank)
        return h.to(dtype_in)

def test_hypernet():
    """Smoke-test ``HyperNetXSexp`` built from the default ``MainConfig``.

    Loads the base model's Hugging Face config with ``AutoConfig.from_pretrained``
    (may read the cache or hit the network) and feeds a random input of shape
    (batch, n_cross_attn_tokens, hidden_size). It then checks that:

    1. the output shape is (batch, modules_per_layer, rank, rank), taken from the
       config rather than hard-coded;
    2. the output is exactly zero, because ``c_proj`` is zero-initialised.

    The result is printed; nothing is raised on failure.
    """
    mainCfg = MainConfig()
    print(mainCfg)
    hf_model_cfg = AutoConfig.from_pretrained(
        mainCfg.model.base_model_name
    )
    print(hf_model_cfg)

    print("--- Starting HyperNetMLP Test ---")
    hyper_cfg = mainCfg.hyperxs
    in_features = hf_model_cfg.hidden_size
    batch_size = 27
    # Derive the expected dims from the config so the check tracks config changes.
    rank = hyper_cfg.lora_attn_dim
    n_modules = hyper_cfg.modules_per_layer
    input_tensor = torch.randn(batch_size, hyper_cfg.n_cross_attn_tokens, in_features)

    model = HyperNetXSexp(hyper_cfg, hf_model_cfg)
    count_parameters(model)
    output = model(input_tensor, layer_idx=torch.tensor(1, dtype=torch.long))
    print('output shape', output.shape)
    print('input shape', input_tensor.shape)
    print('output shape and sum', output.shape, output.sum())

    expected_shape = (batch_size, n_modules, rank, rank)
    shape_ok = tuple(output.shape) == expected_shape
    zero_ok = output.abs().sum().item() == 0
    if shape_ok and zero_ok:
        print("\n--- HyperNetMLP Test Passed Successfully! ✅ ---")
    else:
        print(f"\n--- HyperNetMLP Test FAILED: shape {tuple(output.shape)} "
              f"(expected {expected_shape}), all-zero output: {zero_ok} ---")

def count_parameters(model: nn.Module):
    """Print a table of the trainable parameters of ``model`` and return their total.

    Parameters with ``requires_grad=False`` are skipped, and a parameter object
    reached under more than one name is counted only once.

    Args:
        model: module whose ``named_parameters()`` are inspected.

    Returns:
        int: total number of elements across the unique trainable parameters.
    """
    model_name = model.__class__.__name__
    print(f'Counting params in {model_name}')
    print(f"{'Parameter Name':^60} | {'Shape':^20} | {'Num Params':^20}")
    print("-" * 110)

    seen_ids = set()
    running_total = 0
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        key = id(param)
        if key in seen_ids:
            # Shared tensor already counted under another name.
            print(f"Skipping shared parameter: {param_name}")
            continue
        seen_ids.add(key)

        n_elem = param.numel()
        print(f"{param_name:<60} | {str(list(param.shape)):<25} | {n_elem:,}")
        running_total += n_elem

    print(f"Model: {model_name} Total Trainable Params: {running_total:,}")
    return running_total

if __name__ == "__main__":
    print("Hello world from iba_lora")
    # test_hypernet() builds its own MainConfig and AutoConfig, so nothing is
    # loaded here (previously both were loaded into unused locals first).
    print('-'*50)
    test_hypernet()