| | |
| |
|
| |
|
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from typing import Callable, Optional, Union |
| | import functools |
| | from dataclasses import asdict |
| |
|
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaMLP, |
| | LlamaAttention, |
| | LlamaDecoderLayer, |
| | LlamaModel, |
| | LlamaForCausalLM |
| | ) |
| |
|
| | from transformers import AutoConfig, PretrainedConfig |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPast, |
| | CausalLMOutputWithPast, |
| | ) |
| | from transformers.models.llama.modeling_llama import LlamaConfig as HFLlamaConfig |
| | from transformers.processing_utils import Unpack |
| | from transformers.masking_utils import create_causal_mask |
| | from transformers.cache_utils import Cache, DynamicCache |
| | from transformers.utils.deprecation import deprecate_kwarg |
| | from transformers.utils.generic import check_model_inputs |
| | from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
| |
|
| | from .Xslora import LoraXSLinear, HyperNetXSexp |
| | from .configIBA import MainConfig, HyperXSConfig, TrainingConfig, from_dict |
| |
|
| |
|
| |
|
class IbaXs_LlamaAttention(LlamaAttention):
    """Llama self-attention whose q/k/v/o projections are LoRA-XS linears.

    The LoRA rank and training options are read from ``config.main_cfg``
    (a plain dict serialized from ``MainConfig`` by the caller).
    """

    def __init__(self, config: HFLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        # Rebuild the dataclass view of the serialized main configuration.
        main_cfg = from_dict(MainConfig, config.main_cfg)
        rank = main_cfg.hyperxs.lora_attn_dim
        train_cfg = main_cfg.training

        q_out = config.num_attention_heads * self.head_dim
        kv_out = config.num_key_value_heads * self.head_dim

        def make_proj(in_features: int, out_features: int) -> LoraXSLinear:
            # All four projections share the same rank, training config and
            # bias flag; only the feature dimensions differ.
            return LoraXSLinear(
                in_features, out_features,
                train_cfg=train_cfg, rank=rank,
                bias=config.attention_bias,
            )

        self.q_proj = make_proj(config.hidden_size, q_out)
        self.k_proj = make_proj(config.hidden_size, kv_out)
        self.v_proj = make_proj(config.hidden_size, kv_out)
        self.o_proj = make_proj(q_out, config.hidden_size)
| |
|
| |
|
class IbaXs_LlamaMLP(LlamaMLP):
    """Llama MLP with gate/up/down projections replaced by LoRA-XS linears."""

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)

        # Recover the nested dataclass config from the serialized dict.
        main_cfg = from_dict(MainConfig, config.main_cfg)
        rank = main_cfg.hyperxs.lora_attn_dim
        train_cfg = main_cfg.training

        def make_proj(in_features: int, out_features: int) -> LoraXSLinear:
            # Shared rank / training config / bias flag for all three projections.
            return LoraXSLinear(in_features, out_features,
                                train_cfg=train_cfg, rank=rank,
                                bias=config.mlp_bias)

        self.gate_proj = make_proj(self.hidden_size, self.intermediate_size)
        self.up_proj = make_proj(self.hidden_size, self.intermediate_size)
        self.down_proj = make_proj(self.intermediate_size, self.hidden_size)
| |
|
| |
|
| | |
class IbaXs_LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose LoRA-XS adapters come from a hypernetwork.

    A shared ``HyperNetXSexp`` consumes the first ``n_cross_attn_tokens``
    hidden positions after self-attention and produces an adapter tensor.
    That tensor is cached on this layer (``get_cache_loraxs``) and installed
    into the *next* layer's LoRA modules by the enclosing model via
    ``set_loraxs_adapters``.
    """

    def __init__(self, config: HFLlamaConfig,
                 layer_idx: int,
                 hypernetxs: HyperNetXSexp = None,
                 ):
        super().__init__(config, layer_idx)

        # Hypernetwork instance shared across all layers of the model.
        self.hypernetxs = hypernetxs
        self.hfconfig = config

        # config.main_cfg is a plain dict; rebuild the dataclass view.
        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs
        self.n_cross_attn_tokens = main_cfg.hyperxs.n_cross_attn_tokens

        # Replace stock attention/MLP with the LoRA-XS variants defined above.
        self.self_attn = IbaXs_LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = IbaXs_LlamaMLP(config)

        # Buffer so the layer index follows the module across device moves.
        self.register_buffer('layer_idx_hyperxs', torch.tensor(layer_idx, dtype=torch.long))

        # Cache for the adapter tensor produced during the last forward pass.
        # Name-mangled attribute; exposed through get_cache_loraxs().
        self.__loraxsTensor = None
        self.layer_idx = layer_idx

    def get_cache_loraxs(self) -> Optional[Tensor]:
        """Return the adapter tensor cached by the most recent forward pass,
        or None if no hypernet pass has run yet."""
        loraxsTensor = self.__loraxsTensor

        return loraxsTensor

    def reset_parameters(self):
        # NOTE(review): currently a no-op — INIT_STD is defined but nothing
        # is re-initialized here. Confirm whether an init was intended.
        INIT_STD = 0.01

    def set_loraxs_adapters(self, loraXsTensor: Tensor):
        """Install per-module R matrices into this layer's LoRA-XS modules.

        Args:
            loraXsTensor: adapter tensor indexed as
                ``[batch, module, rank, rank]`` where the module axis follows
                the fixed order of ``applied_modules`` below.

        Raises:
            ModuleNotFoundError: if ``loraXsTensor`` is None.
            NotImplementedError: if a matching submodule is not a LoraXSLinear.
        """

        if loraXsTensor is None:
            raise ModuleNotFoundError

        # Fixed consumption order: slice idx of the module axis goes to the
        # idx-th matching projection found below.
        applied_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
        idx = 0
        for key in applied_modules:
            for name, module in self.named_modules():

                if name.endswith(key):
                    if isinstance(module, LoraXSLinear):
                        module.set_R(loraXsTensor[:, idx, : , :].contiguous())
                        idx = idx + 1

                    else:
                        raise NotImplementedError

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,

        flag_hyper: Optional[bool] = True,
        **kwargs
    ) -> torch.Tensor:
        """Standard pre-norm Llama decoder step plus adapter generation.

        When ``flag_hyper`` is True, the first ``n_cross_attn_tokens``
        positions of the post-attention hidden states are fed to the shared
        hypernetwork and the resulting adapter tensor is cached for the next
        layer to install.
        """

        # --- self-attention sub-block (pre-norm residual) ---
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Generate adapters for the next layer from the prefix-token states.
        if flag_hyper:
            cross_attention = hidden_states[:, 0:self.n_cross_attn_tokens, :]

            self.__loraxsTensor = self.hypernetxs(cross_attention, self.layer_idx)

        # --- MLP sub-block (pre-norm residual) ---
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
| |
|
| |
|
| | |
class IbaXs_LlamaModel(LlamaModel):
    """Llama backbone with learned prefix tokens driving a LoRA-XS hypernet.

    During a prefill pass, ``n_cross_attn_tokens`` learned embeddings are
    prepended to the sequence and a custom 4-D attention mask is built so
    that (a) the prefix attends over the "safe" prompt region, and (b)
    ordinary text tokens cannot attend back to the prefix. Each layer's
    hypernet output is handed to the next layer as its LoRA-XS adapter.
    On decode steps (non-empty KV cache) the hypernet/prefix path is off.
    """

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)

        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs
        # One hypernetwork instance shared by every decoder layer.
        self.hypernetxs = HyperNetXSexp(main_cfg.hyperxs, config)
        self.layers = nn.ModuleList(
            [IbaXs_LlamaDecoderLayer(config, layer_idx, self.hypernetxs) \
                for layer_idx in range(config.num_hidden_layers)]
        )

        # Per-forward toggle: True while running the prefill/hypernet path.
        self.flag_hyper = True
        # Learned prefix embeddings, shape (n_cross_attn_tokens, hidden_size).
        self.hypernetxs_cross_attn_tokens = nn.Parameter(torch.zeros(main_cfg.hyperxs.n_cross_attn_tokens,
                                                                     config.hidden_size))
        self.main_cfg = main_cfg

    def reset_parameters(self):
        """Re-initialize the learned prefix embeddings with a small normal."""
        INIT_STD = 0.01
        nn.init.normal_(self.hypernetxs_cross_attn_tokens, mean=0.0, std=INIT_STD)

    def _create_prefix_or_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Creates a mask to UNLOCK specific regions.
        Boolean values will be process data inside create_causal_mask
        1. Prefix-sees-Prefix (bidirectional)
        2. Prefix-sees-Text (all)

        NOTE(review): relies on ``self.safe_boundaries`` being set by a prior
        forward() call. Not invoked by the mask construction in forward(),
        which builds the equivalent mask inline.
        """
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]

        is_query_prefix = q_idx < prefix_len

        is_key_safe = kv_idx < self.safe_boundaries
        return is_query_prefix & is_key_safe

    def _create_prefix_and_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Creates a mask to LOCK specific regions.
        1. Text-sees-Prefix

        NOTE(review): like _create_prefix_or_mask, this is prepared in
        forward() but the actual mask is built inline there.
        """
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]

        # Text queries (q >= prefix) must not see prefix keys (kv < prefix).
        is_forbidden = (q_idx >= prefix_len) & (kv_idx < prefix_len)

        return ~is_forbidden

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> BaseModelOutputWithPast:
        """Backbone forward with prefix injection and hypernet adapter flow.

        ``labels`` (if given, -100 = ignored position) are only used to find
        the per-sample prompt length that bounds what the prefix may attend
        to; they are not used for any loss here.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Prefill = no cache yet (or an empty one); decode steps skip the
        # prefix/hypernet machinery entirely (prefix_len collapses to 0).
        is_prefill = (past_key_values is None) or \
            (hasattr(past_key_values, 'get_seq_length') and past_key_values.get_seq_length() == 0)
        prefix_len = self.main_cfg.hyperxs.n_cross_attn_tokens \
            if self.main_cfg.hyperxs.n_cross_attn_tokens is not None and is_prefill else 0

        # Positions cover past tokens + current tokens + (prefill-only) prefix.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position: torch.Tensor = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] +
            prefix_len, device=inputs_embeds.device
        )

        # Overrides any caller-supplied position_ids with the computed range.
        position_ids = cache_position.unsqueeze(0).expand(inputs_embeds.shape[0], -1)

        # Default "safe boundary": the prefix may see the whole sequence.
        batch_size , seq_len_input = inputs_embeds.shape[:2]
        safe_boundaries = torch.full(
            (batch_size, 1, 1, 1),
            seq_len_input + prefix_len,
            device=inputs_embeds.device,
            dtype=torch.long
        )
        if labels is not None and is_prefill:
            # Prompt length per sample = index of the first supervised label
            # (-100 marks prompt positions). Samples with no label at all
            # keep the full sequence length.
            is_real_label = (labels != -100)

            prompt_lens = is_real_label.int().argmax(dim=1)
            has_label = is_real_label.any(dim=1)

            prompt_lens = torch.where(
                has_label,
                prompt_lens,
                torch.tensor(seq_len_input, device=inputs_embeds.device)
            )

            # Restrict the prefix to prefix + prompt tokens only.
            safe_boundaries = (prefix_len + prompt_lens).view(batch_size, 1, 1, 1)
        # Kept for the (currently unused) _create_prefix_*_mask helpers.
        self.safe_boundaries = safe_boundaries

        hidden_states = inputs_embeds

        active_or_mask_func = None
        active_and_mask_func = None

        if is_prefill:
            self.flag_hyper = True

            # Prepared for create_causal_mask-style APIs; the mask used below
            # is built inline instead.
            active_or_mask_func = self._create_prefix_or_mask
            active_and_mask_func = self._create_prefix_and_mask

            if self.hypernetxs_cross_attn_tokens is not None:
                batch_size = hidden_states.shape[0]

                # Prepend the learned prefix embeddings to every sample.
                prefix_embeds = self.hypernetxs_cross_attn_tokens.expand(int(batch_size), -1, -1)
                hidden_states = torch.concat((prefix_embeds, hidden_states), dim=1)

                if attention_mask is not None:
                    # Prefix tokens are always valid (never padding).
                    prefix_attention_mask = torch.ones((batch_size, prefix_len),
                                                       dtype=attention_mask.dtype, device=attention_mask.device)
                    attention_mask = torch.cat([prefix_attention_mask,
                                                attention_mask], dim=1)
        else:
            # Decode step: hypernet path disabled.
            self.flag_hyper = False

        causal_mask = None

        if is_prefill:
            # Build an additive float mask (0 = attend, -inf = blocked) of
            # shape broadcastable to (batch, heads, q_len, kv_len).
            current_seq_len = hidden_states.shape[1]
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min

            q_idx = torch.arange(current_seq_len, device=inputs_embeds.device).view(1, 1, current_seq_len, 1)
            k_idx = torch.arange(current_seq_len, device=inputs_embeds.device).view(1, 1, 1, current_seq_len)

            # Base: standard causal (lower-triangular) visibility.
            mask_bool = q_idx >= k_idx

            # UNLOCK: prefix queries may see everything up to safe_boundaries
            # (bidirectional within the prefix + prompt region).
            prefix_unlock = (q_idx < prefix_len) & (k_idx < safe_boundaries)
            mask_bool = mask_bool | prefix_unlock

            # LOCK: text queries must never see prefix keys.
            text_forbidden_prefix = (q_idx >= prefix_len) & (k_idx < prefix_len)
            mask_bool = mask_bool & (~text_forbidden_prefix)

            # Convert boolean visibility to the additive float convention.
            causal_mask = torch.full_like(mask_bool, min_dtype, dtype=dtype)
            causal_mask = causal_mask.masked_fill(mask_bool, 0.0)

            # Fold padding into the mask (padded keys get -inf everywhere).
            if attention_mask is not None:
                padding_mask_float = (1.0 - attention_mask.to(dtype)) * min_dtype
                padding_mask_float = padding_mask_float[:, None, None, :]
                causal_mask = causal_mask + padding_mask_float

            causal_mask = causal_mask.contiguous()

        else:
            # Decode: no custom mask; attention falls back to default handling.
            self.flag_hyper = False

            pass

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                flag_hyper = self.flag_hyper,
                **kwargs,
            )

            # Chain: layer i's hypernet output becomes layer i+1's adapters.
            if idx < self.config.num_hidden_layers - 1 and self.flag_hyper:
                self.layers[idx+1].set_loraxs_adapters(decoder_layer.get_cache_loraxs())

        # Strip the prefix positions so callers see only real-token states.
        if self.flag_hyper:
            hidden_states = hidden_states[:, self.main_cfg.hyperxs.n_cross_attn_tokens:, :]

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
| |
|
| |
|
class IbaXs_LlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head over the IbaXs (hypernet + LoRA-XS) Llama backbone."""

    def __init__(self, config: HFLlamaConfig,
                 ):
        super().__init__(config)
        # Swap the stock backbone for the hypernet/LoRA-XS variant.
        self.model = IbaXs_LlamaModel(config)

    def reset_BA_xslora(self):
        """Re-derive the low-rank factors of every LoRA-XS module via SVD."""
        for name, module in self.named_modules():
            if isinstance(module, LoraXSLinear):
                module.decompose_weight_svd(module.rank)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Run the backbone (labels are forwarded so it can size its prefix
        mask), project the kept positions to vocabulary logits, and compute
        the LM loss when labels are provided."""
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            labels=labels,
            **kwargs,
        )

        final_hidden = backbone_out.last_hidden_state

        # An int k means "keep the last k positions" (0 keeps everything);
        # any other value is used directly as an index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(final_hidden[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
| |
|
def test_set_loraxs_adapters():
    """Smoke-test that set_loraxs_adapters installs one R slice per module.

    Builds a single decoder layer and feeds it a [batch, 7, rank, rank]
    tensor whose module slices are constant 1..7, matching the seven adapted
    projections (q/k/v/o/gate/up/down).
    """
    main_cfg = MainConfig()

    hf_model_cfg = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # The custom layers read the serialized MainConfig off the HF config
    # (see IbaXs_LlamaDecoderLayer.__init__), so attach it first.
    hf_model_cfg.main_cfg = asdict(main_cfg)

    # BUG FIX: the constructor signature is (config, layer_idx, hypernetxs);
    # the previous call passed a nonexistent main_cfg= kwarg (TypeError).
    layer = IbaXs_LlamaDecoderLayer(hf_model_cfg, layer_idx=1)
    rank = main_cfg.hyperxs.lora_attn_dim

    batch_size = main_cfg.training.batch_train
    modules = 7
    # Module i gets the constant value i+1 so slices are distinguishable.
    values = torch.arange(1, modules + 1)
    values_reshaped = values.view(modules, 1, 1)
    loraTensor = values_reshaped.expand(batch_size, modules, rank, rank)

    layer.set_loraxs_adapters(loraTensor)
| |
|
def test_llm():
    """End-to-end smoke test: build a tiny IbaXs Llama LM, run one forward
    pass, then sample a short continuation for a couple of prompts."""
    main_cfg = MainConfig()
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # Shrink the architecture so the test runs quickly.
    config.hidden_size = 128
    config.intermediate_size = 256
    config.num_hidden_layers = 6
    config.head_dim = config.hidden_size // config.num_attention_heads

    # The custom modules rebuild MainConfig from this dict (see from_dict).
    main_cfg_dict = asdict(main_cfg)
    config.main_cfg = main_cfg_dict

    model_bb = IbaXs_LlamaForCausalLM(config=config)
    model_bb.reset_BA_xslora()
    batch_size = main_cfg.training.per_device_train_batch_size
    # Renamed from `input` to avoid shadowing the builtin.
    input_ids = torch.ones(batch_size, 11, dtype=torch.long)
    total_params = sum(p.numel() for p in model_bb.parameters())
    print('input llm', input_ids.shape, total_params)

    output = model_bb(input_ids, logits_to_keep=1)
    print('output llm', output.logits.shape)

    # BUG FIX: the device was hard-coded to 'mps' and the model was never
    # moved there, so tokenized inputs and weights ended up on different
    # devices inside generate(). Pick an available accelerator (falling back
    # to CPU) and move the model before tokenizing.
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    from transformers import LlamaTokenizer
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=True)
    model_bb.to(device)
    model_bb.eval()
    prompts = [
        "The capital of France is",
        "Here is a simple Python function to add two numbers:"
    ]
    for i, prompt in enumerate(prompts):
        print(f"\n--- Prompt {i+1} ---")
        print(f"Input: {prompt}")

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model_bb.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_k=50
            )

        # Drop the prompt tokens and decode only the newly generated part.
        output_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(output_tokens, skip_special_tokens=True)

        print(f"Output: {generated_text}")
| |
|
| |
|
| |
|
def test_backbone():
    """Smoke-test the bare IbaXs_LlamaModel: tiny config, one forward pass
    on dummy token ids, printing parameter count and output shape."""
    main_cfg = MainConfig()
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )

    # Shrink the architecture so the check is fast.
    config.hidden_size = 128
    config.intermediate_size = 256
    config.num_hidden_layers = 6
    config.head_dim = config.hidden_size // config.num_attention_heads

    # The custom modules rebuild MainConfig from this serialized dict.
    config.main_cfg = asdict(main_cfg)

    mode_bb = IbaXs_LlamaModel(config=config)
    batch_size = main_cfg.training.batch_train
    dummy_ids = torch.ones(batch_size, 11, dtype=torch.long)
    n_params = sum(p.numel() for p in mode_bb.parameters())
    print('input bb', dummy_ids.shape, n_params)
    output = mode_bb(dummy_ids)
    print('output bb', output.last_hidden_state.shape)
| |
|
if __name__ == "__main__":
    # Module smoke-test entry point: runs the full causal-LM check.
    print("Hello world from XS_llama.py")

    test_llm()
| | |
| |
|
| |
|
| |
|
| |
|
| |
|