| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| | import math |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoConfig, PretrainedConfig |
| | from jaxtyping import Float |
| | from dataclasses import asdict, dataclass |
| | from typing import List, Optional, Tuple, Dict |
| | import einops |
| |
|
| |
|
| | from .configIBA import MainConfig, HyperXSConfig, TrainingConfig |
| |
|
| |
|
| |
|
def transpose(weight, fan_in_fan_out):
    """Return ``weight`` transposed when ``fan_in_fan_out`` is set, else unchanged."""
    if fan_in_fan_out:
        return weight.T
    return weight
| |
|
class LoraLayer:
    """Mixin that stores the LoRA hyper-parameters shared by adapter layers.

    Not an ``nn.Module`` itself; concrete layers (e.g. a LoRA linear layer)
    inherit from both ``nn.Linear`` and this class.
    """

    def __init__(
        self,
        rank: int,
        train_cfg: "TrainingConfig",
        lora_alpha: int,
        lora_dropout: float,
    ):
        """
        Args:
            rank: LoRA rank (inner dimension of the low-rank update).
            train_cfg: training config; only the per-device batch sizes are read.
            lora_alpha: scaling numerator (layers use ``lora_alpha / rank``).
            lora_dropout: dropout probability applied to the adapter input;
                ``0.0`` disables dropout entirely.
        """
        self.rank = rank
        self.batch_train = train_cfg.per_device_train_batch_size
        self.batch_valid = train_cfg.per_device_eval_batch_size
        self.lora_alpha = lora_alpha

        # nn.Identity instead of a lambda so instances stay picklable and
        # deep-copyable (lambdas cannot be pickled); call behavior is identical.
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        # When True, forward() should skip the low-rank update.
        self.disable_adapters = False
| |
|
class LoraXSLinear(nn.Linear, LoraLayer):
    """Linear layer with a LoRA-XS adapter.

    The frozen base weight is augmented with a low-rank update
    ``(dropout(x) @ lora_A) @ (lora_R @ lora_B) * scaling`` where ``lora_A``
    (in_features x rank) and ``lora_B`` (rank x out_features) are frozen
    buffers obtained from an SVD of the base weight, and ``lora_R`` is a small
    (rank x rank) matrix supplied externally via :meth:`set_R` (e.g. emitted
    by a hypernetwork). Until ``set_R`` is called, forward() is the plain
    frozen linear projection.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        train_cfg: TrainingConfig,
        rank: int = 64,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, rank=rank, train_cfg=train_cfg,
                           lora_alpha=lora_alpha, lora_dropout=lora_dropout)

        # True when the pretrained weight is stored (in, out) rather than
        # (out, in), e.g. GPT-2-style Conv1D checkpoints.
        self.fan_in_fan_out = fan_in_fan_out

        if rank > 0:
            # Buffers, not Parameters: A and B stay frozen; only lora_R adapts.
            self.register_buffer("lora_A", torch.zeros([in_features, rank]), persistent=True)
            self.register_buffer("lora_B", torch.zeros([rank, out_features]), persistent=True)
            self.scaling = self.lora_alpha / self.rank
            # Base weight is frozen; the adapter path carries all adaptation.
            self.weight.requires_grad = False
            self.lora_R = None

        if fan_in_fan_out:
            self.weight.data = self.weight.data.T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the base weight and (if present) the A/B buffers.

        NOTE: nn.Linear.__init__ calls this before the buffers exist, hence
        the hasattr guard.
        """
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A, mode='fan_out', a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_B, mode='fan_in', a=math.sqrt(5))

    def set_R(self, R: torch.Tensor):
        """Install the externally produced (rank x rank) mixing matrix."""
        self.lora_R = R

    def decompose_weight_svd(self, rank):
        """Fill lora_A / lora_B from a truncated SVD of the frozen weight.

        W ~= B @ A with B = U_r sqrt(S_r) and A = sqrt(S_r) V_r^T; the buffers
        store the transposes to match their (in, rank) / (rank, out) layout.

        Returns:
            (None, None) when the SVD fails; implicitly None on success.
            NOTE(review): this asymmetric return is kept for compatibility —
            callers should rely on the side effect, not the return value.
        """
        W = self.weight.data
        device, dtype = W.device, W.dtype

        try:
            U, S, Vt = torch.linalg.svd(W, full_matrices=False)
        except torch.linalg.LinAlgError as e:
            print(f"SVD computation failed: {e}")
            return None, None

        U_r = U[:, :rank]
        S_r_values = S[:rank]
        sqrt_S_r_diag = torch.diag(torch.sqrt(S_r_values))
        Vt_r = Vt[:rank, :]

        B = U_r @ sqrt_S_r_diag
        A = sqrt_S_r_diag @ Vt_r

        # Assigning to a registered buffer name replaces the buffer in-place.
        self.lora_A = A.T.to(device, dtype)
        self.lora_B = B.T.to(device, dtype)

    def forward(self, x: torch.Tensor):
        previous_dtype = self.weight.dtype

        # The frozen base projection was duplicated identically in every
        # branch of the original; compute it exactly once.
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        # Add the low-rank update only when adapters are enabled, the layer
        # actually has one (rank > 0), and R has been provided.
        if not self.disable_adapters and self.rank > 0 and self.lora_R is not None:
            result = result + (self.lora_dropout(x) @ self.lora_A) @ (self.lora_R @ self.lora_B) * self.scaling

        if result.dtype != previous_dtype:
            result = result.to(previous_dtype)

        return result
| | |
| |
|
class HyperNetXSexp(nn.Module):
    """Hypernetwork producing per-module (rank x rank) LoRA-XS ``R`` matrices.

    Given a batch of latent feature tokens and a layer index, emits one
    flattened rank*rank matrix per target module of that layer through a
    LayerNorm'd two-layer GELU MLP conditioned on module and layer embeddings.
    The output head is zero-initialized, so at init the network emits all-zero
    R matrices (the adapters start inert).
    """

    def __init__(
        self,
        hyperxs_cfg: HyperXSConfig,
        hf_model_cfg: PretrainedConfig,
    ):
        super().__init__()
        self.n_modules = hyperxs_cfg.modules_per_layer
        self.rank = hyperxs_cfg.lora_attn_dim
        self.latent_feature_dim = hyperxs_cfg.latent_feature_dim

        self.module_embed_dim = hyperxs_cfg.module_embed_dim
        self.layer_embed_dim = hyperxs_cfg.layer_embed_dim
        # One flattened (rank x rank) matrix per module.
        self.hyper_out = hyperxs_cfg.lora_attn_dim ** 2

        # MLP input: flattened latent tokens + module embedding + layer embedding.
        n_flat_indim = self.latent_feature_dim * hyperxs_cfg.n_cross_attn_tokens + self.module_embed_dim + self.layer_embed_dim
        n_flat_outdim = hyperxs_cfg.out_proj_dim * hyperxs_cfg.n_cross_attn_tokens
        n_proj = 4 * n_flat_outdim  # transformer-style 4x hidden expansion

        self.latent_proj = nn.Linear(hf_model_cfg.hidden_size, self.latent_feature_dim)
        self.mixture = nn.Linear(n_flat_indim, n_flat_outdim)
        self.c_fc = nn.Linear(n_flat_outdim, n_proj)
        self.c_proj = nn.Linear(n_proj, self.hyper_out)
        self.act = nn.GELU()

        self.ln_latent = nn.LayerNorm(hf_model_cfg.hidden_size, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_1 = nn.LayerNorm(n_flat_indim, eps=hyperxs_cfg.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(n_flat_outdim, eps=hyperxs_cfg.layer_norm_epsilon)

        self.layer_embedding = nn.Embedding(hf_model_cfg.num_hidden_layers, self.layer_embed_dim)
        self.module_embedding = nn.Embedding(self.n_modules, self.module_embed_dim)
        self.hyperxs_cfg = hyperxs_cfg
        self.hf_model_cfg = hf_model_cfg

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the projections; zero-init the output head so the
        hypernet initially emits all-zero R matrices."""
        INIT_STD = 1e-3
        nn.init.kaiming_normal_(self.latent_proj.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.latent_proj.bias, 0)

        nn.init.kaiming_normal_(self.mixture.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.mixture.bias, 0)

        nn.init.kaiming_normal_(self.c_fc.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.constant_(self.c_fc.bias, 0)

        nn.init.normal_(self.layer_embedding.weight, mean=0.0, std=INIT_STD)
        # NOTE(review): module_embedding keeps nn.Embedding's default init —
        # confirm whether it should also use INIT_STD like layer_embedding.

        # Zero output head => forward() returns exact zeros at init.
        nn.init.constant_(self.c_proj.weight, 0)
        nn.init.constant_(self.c_proj.bias, 0)

    def forward(self, x: Float[Tensor, 'b s f'], layer_idx) -> Float[Tensor, 'b r in out']:
        """Map latent tokens + layer index to (batch, n_modules, rank, rank)."""
        batch_size = x.shape[0]
        dtype_in = x.dtype
        x = x.to(self.latent_proj.weight.dtype)

        x = self.ln_latent(x)
        x = self.latent_proj(x)

        # Flatten the token dimension into the feature dimension.
        x = einops.rearrange(x, 'batch seq fea -> batch (seq fea)')

        # Broadcast the features over modules and prepend each module's embedding.
        module_embedding = self.module_embedding.weight
        module_embedding = module_embedding.expand(batch_size, -1, -1)
        x = x[:, None, ...]
        x = x.expand(-1, self.n_modules, -1)

        x = torch.cat((module_embedding, x), dim=-1)
        x = einops.rearrange(x, 'batch n_modules in_dim -> (batch n_modules) in_dim')

        if self.layer_embed_dim > 0:
            # as_tensor avoids the copy (and UserWarning) torch.tensor() emits
            # when layer_idx is already a tensor.
            layer_embedding = self.layer_embedding(torch.as_tensor(layer_idx, device=x.device))
            layer_embedding = layer_embedding.expand(batch_size, self.n_modules, -1)
            layer_embedding = einops.rearrange(layer_embedding, 'batch n_modules in_dim -> (batch n_modules) in_dim')

            x = torch.cat((layer_embedding, x), dim=-1)

        assert x.shape == (batch_size*self.n_modules, self.mixture.weight.data.shape[1]), 'Wrong at hypernetMLP.forward.x'

        h = self.ln_1(x)
        # BUG FIX: the original passed the un-normalized `x` to mixture,
        # silently discarding ln_1's output; pre-norm was clearly intended
        # (ln_2 -> c_fc below follows the same pattern).
        h = self.mixture(h)
        h = self.act(h)

        h = self.ln_2(h)
        h = self.c_fc(h)
        h = self.act(h)

        h = self.c_proj(h)

        h = einops.rearrange(h, '(batch n_modules) (rank r) -> batch n_modules rank r',
                             batch=batch_size, n_modules=self.n_modules,
                             rank=self.rank, r=self.rank)
        h = h.to(dtype_in)
        return h
| |
|
def test_hypernet():
    """Smoke-test HyperNetXSexp.

    With the zero-initialized output head, the hypernet should emit an
    all-zero tensor of shape (batch, n_modules, rank, rank) at init.
    """
    mainCfg=MainConfig()
    print(mainCfg)
    hf_model_cfg = AutoConfig.from_pretrained(
        mainCfg.model.base_model_name
    )
    print(hf_model_cfg)

    print("--- Starting HyperNetMLP Test ---")

    in_features = hf_model_cfg.hidden_size
    batch_size = 27

    # NOTE(review): these hard-coded values must match the hyperxs config
    # (lora_attn_dim == 30, modules_per_layer == 2) for the shape check to
    # pass — confirm against configIBA. (Removed unused locals reduced_dim,
    # out_features, outW from the original.)
    rank = 30
    n_mlp=2
    input_tensor = torch.randn(batch_size, mainCfg.hyperxs.n_cross_attn_tokens, in_features)

    model = HyperNetXSexp(mainCfg.hyperxs, hf_model_cfg)
    count_parameters(model)

    output = model(input_tensor, layer_idx=torch.tensor(1, dtype=torch.long))
    print('output shape', output.shape)
    # Slice :768 is a no-op here (last dim is rank) but kept from the original.
    B = output[:,1,:,:768]
    print('input shape', input_tensor.shape)
    print('output shape and sum of B', output.shape, output.sum(), B.sum())
    if output.shape == (batch_size, n_mlp, rank, rank) and B.sum().item()==0:
        print("\n--- HyperNetMLP Test Passed Successfully! ✅ ---")
| |
|
def count_parameters(model: nn.Module) -> int:
    """Print a per-parameter breakdown and return the trainable-parameter count.

    Frozen parameters (``requires_grad == False``) are skipped, and a
    parameter object reachable under several names (tied weights) is counted
    only once, keyed by ``id()``.

    Args:
        model: the module whose parameters are tallied.

    Returns:
        Total number of trainable parameters.
    """
    print(f'Counting params in {model.__class__.__name__}')
    total_params = 0
    counted_param_ids = set()

    # Header widths now match the row format below (60 / 25 / 20); the
    # original used ^20 for Shape while rows used <25, misaligning columns.
    print(f"{'Parameter Name':^60} | {'Shape':^25} | {'Num Params':^20}")
    print("-" * 110)

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue

        # Count shared (tied) parameter objects only once.
        param_id = id(parameter)
        if param_id in counted_param_ids:
            print(f"Skipping shared parameter: {name}")
            continue
        counted_param_ids.add(param_id)

        shape = list(parameter.shape)
        num_params = parameter.numel()
        print(f"{name:<60} | {str(shape):<25} | {num_params:,}")

        total_params += num_params
    print(f"Model: {model.__class__.__name__} Total Trainable Params: {total_params:,}")
    return total_params
| |
|
if __name__ == "__main__":
    # Manual entry point: load configs, then run the hypernet smoke test.
    print("Hello world from iba_lora")

    main_cfg = MainConfig()
    hf_model_cfg = AutoConfig.from_pretrained(main_cfg.model.base_model_name)

    print('-' * 50)
    test_hypernet()