# mypy: ignore-errors
# A single seq of representative cross-attention tokens is added at the beginning only.
# the next layer re-uses output from the previous layer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable, Optional, Union
import functools
from dataclasses import asdict
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaModel,
    LlamaForCausalLM
)
from transformers import AutoConfig, PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaConfig as HFLlamaConfig
from transformers.processing_utils import Unpack
from transformers.masking_utils import create_causal_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging

from .Xslora import LoraXSLinear, HyperNetXSexp
from .configIBA import MainConfig, HyperXSConfig, TrainingConfig, from_dict


class IbaXs_LlamaAttention(LlamaAttention):
    """Llama attention whose q/k/v/o projections are replaced with LoRA-XS linear layers.

    The LoRA rank and training config are read from ``config.main_cfg`` (a dict
    that is rebuilt into a ``MainConfig`` dataclass). Everything else is inherited
    from ``LlamaAttention``.
    """

    def __init__(self, config: HFLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Get main_config as a dataclass object
        main_cfg = from_dict(MainConfig, config.main_cfg)
        lora_attn_dim = main_cfg.hyperxs.lora_attn_dim
        train_cfg = main_cfg.training
        # Replace the stock nn.Linear projections created by LlamaAttention.__init__
        # with LoRA-XS wrappers of the same shapes, so the rest of the attention
        # forward pass is unchanged.
        self.q_proj = LoraXSLinear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            train_cfg=train_cfg,
            rank = lora_attn_dim,
            bias=config.attention_bias
        )
        self.k_proj = LoraXSLinear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            train_cfg=train_cfg,
            rank = lora_attn_dim,
            bias=config.attention_bias
        )
        self.v_proj = LoraXSLinear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            train_cfg=train_cfg,
            rank = lora_attn_dim,
            bias=config.attention_bias
        )
        self.o_proj = LoraXSLinear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            train_cfg=train_cfg,
            rank = lora_attn_dim,
            bias=config.attention_bias
        )


class IbaXs_LlamaMLP(LlamaMLP):
    """Llama MLP whose gate/up/down projections are replaced with LoRA-XS linear layers."""

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)
        # Get main_config as a dataclass object
        main_cfg = from_dict(MainConfig, config.main_cfg)
        lora_attn_dim = main_cfg.hyperxs.lora_attn_dim
        train_cfg = main_cfg.training
        self.gate_proj = LoraXSLinear(self.hidden_size, self.intermediate_size,
                                      train_cfg=train_cfg, rank = lora_attn_dim, bias=config.mlp_bias)
        self.up_proj = LoraXSLinear(self.hidden_size, self.intermediate_size,
                                    train_cfg=train_cfg, rank = lora_attn_dim, bias=config.mlp_bias)
        self.down_proj = LoraXSLinear(self.intermediate_size, self.hidden_size,
                                      train_cfg=train_cfg, rank = lora_attn_dim, bias=config.mlp_bias)


# block layer
class IbaXs_LlamaDecoderLayer(LlamaDecoderLayer):
    """Decoder layer that (a) uses the LoRA-XS attention/MLP above and (b), during
    prefill, feeds the leading ``n_cross_attn_tokens`` hidden states into a shared
    hypernetwork to produce per-batch LoRA-XS adapter tensors for the NEXT layer.

    The produced adapter tensor is cached on the instance (``__loraxsTensor``) and
    read back by the parent model via ``get_cache_loraxs``.
    """

    def __init__(self, config: HFLlamaConfig, layer_idx: int,
                 hypernetxs: HyperNetXSexp = None, ):
        super().__init__(config, layer_idx)
        # Shared hypernetwork instance (one object passed to every layer).
        self.hypernetxs = hypernetxs
        self.hfconfig = config
        # Get main_config as a dataclass object
        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs
        self.n_cross_attn_tokens = main_cfg.hyperxs.n_cross_attn_tokens
        # Replace
        self.self_attn = IbaXs_LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = IbaXs_LlamaMLP(config)
        #self.cross_attn_tokens = nn.Parameter(torch.empty(main_cfg.hyperxs.n_cross_attn_tokens,
        #                                                  hf_model_cfg.hidden_size))
        # In case of to(device) -> do not use self.layer_idx = LongTensor(layer_idx)
        self.register_buffer('layer_idx_hyperxs', torch.tensor(layer_idx, dtype=torch.long))
        # self.flag_hyper = True
        # Name-mangled cache for the adapter tensor produced by this layer's forward.
        self.__loraxsTensor = None
        self.layer_idx = layer_idx
        # self.reset_parameters()

    def get_cache_loraxs(self):
        """Return the adapter tensor computed by the most recent forward pass
        (``None`` if ``forward`` has not yet run with ``flag_hyper=True``)."""
        loraxsTensor = self.__loraxsTensor
        # self.loraxsTensor = None
        return loraxsTensor

    def reset_parameters(self):
        # NOTE(review): ``self.hypernetxs_cross_attn_tokens`` is defined on
        # IbaXs_LlamaModel, not on this layer — calling this method would raise
        # AttributeError. It is currently dead code (the call in __init__ is
        # commented out); confirm intent before enabling it.
        INIT_STD = 0.01
        #
        nn.init.normal_(self.hypernetxs_cross_attn_tokens, mean=0.0, std=INIT_STD)

    def set_loraxs_adapters(self, loraXsTensor: Tensor):
        """Distribute a (batch, modules, rank, rank) adapter tensor to this layer's
        seven LoRA-XS modules, in the fixed order q/k/v/o/gate/up/down.

        Raises:
            ModuleNotFoundError: if ``loraXsTensor`` is None.
            NotImplementedError: if a matching submodule is not a LoraXSLinear.
        """
        # (batch, modules, rank, rank)
        if loraXsTensor is None:
            raise ModuleNotFoundError
        applied_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj',
                           'gate_proj', 'up_proj', 'down_proj']
        idx = 0
        # Outer loop over module names keeps slot ``idx`` aligned with the fixed
        # adapter ordering regardless of named_modules() traversal order.
        for key in applied_modules:
            for name, module in self.named_modules():
                # print('name', name, type(name))
                if name.endswith(key):
                    if isinstance(module, LoraXSLinear):
                        module.set_R(loraXsTensor[:, idx, : , :].contiguous())
                        idx = idx + 1
                        # print(f'name: {name}. R: {module.lora_train_R.shape, module.lora_train_R[1,1,1]}')
                    else:
                        raise NotImplementedError

    # def set_flag_hyper(self, flag: bool = False):
    #     self.flag_hyper = flag

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        ### additional arg
        flag_hyper: Optional[bool] = True,
        **kwargs  #: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Standard Llama decoder-layer forward, plus: when ``flag_hyper`` is True,
        the first ``n_cross_attn_tokens`` positions (post-attention) are passed to
        the hypernetwork and the resulting adapter tensor is cached on ``self``.
        The prefix tokens are NOT stripped here — they flow on to the next layer.
        """
        # if self.flag_hyper:
        #     batch_size = hidden_states.shape[0]
        #     hypernetxs_cross_attn_tokens = self.hypernetxs_cross_attn_tokens.expand(int(batch_size), -1, -1)
        #     print('batch', batch_size, hypernetxs_cross_attn_tokens.shape)
        #     hidden_states = torch.concat((hypernetxs_cross_attn_tokens, hidden_states), dim=1)

        # Copy paste modify from modeling_llama.py
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Extract representative tokens
        if flag_hyper:
            cross_attention = hidden_states[:, 0:self.n_cross_attn_tokens, :]
            # Still push cross_attention to the next layer
            # hidden_states = hidden_states[:, self.n_cross_attn_tokens:, :]
            # save all lora adapters as a attribute
            self.__loraxsTensor = self.hypernetxs(cross_attention, self.layer_idx)  # (batch, n_modules, r, r)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


#back bone models
class IbaXs_LlamaModel(LlamaModel):
    """Llama backbone with:

    - learnable prefix tokens (``hypernetxs_cross_attn_tokens``) prepended to the
      input embeddings during prefill only;
    - one shared ``HyperNetXSexp`` whose per-layer adapter output is chained:
      layer ``i``'s output configures layer ``i+1``'s LoRA-XS modules;
    - a hand-built additive attention mask that lets prefix queries see a
      "safe" span (prompt only, when labels are given) while text queries are
      forbidden from attending back to the prefix.
    """

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)
        # Get main_config as a dataclass object
        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs
        # Single hypernetwork shared by every decoder layer.
        self.hypernetxs = HyperNetXSexp(main_cfg.hyperxs, config)
        self.layers = nn.ModuleList(
            [IbaXs_LlamaDecoderLayer(config, layer_idx, self.hypernetxs) \
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.flag_hyper = True
        # Learnable prefix of shape (n_cross_attn_tokens, hidden_size), expanded
        # per batch in forward(). Initialized to zeros; see reset_parameters().
        self.hypernetxs_cross_attn_tokens = nn.Parameter(torch.zeros(main_cfg.hyperxs.n_cross_attn_tokens,
                                                                     config.hidden_size))
        self.main_cfg = main_cfg
        # self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the prefix tokens with small Gaussian noise."""
        INIT_STD = 0.01
        nn.init.normal_(self.hypernetxs_cross_attn_tokens, mean=0.0, std=INIT_STD)

    def _create_prefix_or_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,  #
        kv_idx: torch.Tensor,  #
    ) -> torch.Tensor:
        """
        Creates a mask to UNLOCK specific regions.
        Boolean values will be process data inside create_causal_mask
        1. Prefix-sees-Prefix (bidirectional)
        2. Prefix-sees-Text (all)
        """
        # NOTE(review): reads ``self.safe_boundaries``, which is only assigned in
        # forward() — this helper must not be called before the first forward.
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]  # K (int)
        # 1. Query is Prefix?
        is_query_prefix = q_idx < prefix_len
        # kv_idx [1, 1, 1, K] compared with safe_boundaries [Batch, 1, 1, 1]
        is_key_safe = kv_idx < self.safe_boundaries
        return is_query_prefix & is_key_safe

    def _create_prefix_and_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Creates a mask to LOCK specific regions.
        1. Text-sees-Prefix
        """
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]  # K (int)
        # The "forbidden" zone is:
        # Query is Text (q_idx >= prefix_len)
        # AND
        # Key is Prefix (kv_idx < prefix_len)
        is_forbidden = (q_idx >= prefix_len) & (kv_idx < prefix_len)
        # Return True if *not* in the forbidden zone.
        # ~ is the vmap-safe "NOT" operator for boolean tensors.
        # if q_idx.item() <= 10 and kv_idx.item() <= 10:
        #     print('is_forbidden', ~is_forbidden)
        return ~is_forbidden

    # @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs  #: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Backbone forward.

        Prefill (empty/absent cache): prepends the learnable prefix tokens,
        builds a custom float attention mask, runs the layers while chaining
        hypernetwork adapters layer-to-layer, then strips the prefix positions
        before the final norm.
        Decode (non-empty cache): no prefix, ``causal_mask=None`` (mask handling
        delegated to the attention backend), ``flag_hyper=False``.

        ``labels`` (optional) is used only to locate the prompt/completion
        boundary per row (first index where label != -100) so prefix queries
        only see the prompt span.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            # NOTE(review): DynamicCache(config=...) is version-sensitive in
            # transformers — confirm against the pinned version.
            past_key_values = DynamicCache(config=self.config)

        # "Prefill" = no cache yet, or a cache that has seen zero tokens.
        is_prefill = (past_key_values is None) or \
            (hasattr(past_key_values, 'get_seq_length') and past_key_values.get_seq_length() == 0)
        # Prefix only participates during prefill; afterwards prefix_len == 0.
        prefix_len = self.main_cfg.hyperxs.n_cross_attn_tokens \
            if self.main_cfg.hyperxs.n_cross_attn_tokens is not None and is_prefill else 0

        # if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position: torch.Tensor = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + prefix_len,
            device=inputs_embeds.device
        )
        #if position_ids is None:
        # Count from K (prefix) + S (sequence)
        position_ids = cache_position.unsqueeze(0).expand(inputs_embeds.shape[0], -1)

        ####
        # prefix mask boundary from labels
        batch_size , seq_len_input = inputs_embeds.shape[:2]
        # Default boundary: the whole (prefix + input) sequence is "safe".
        safe_boundaries = torch.full(
            (batch_size, 1, 1, 1),
            seq_len_input + prefix_len,
            device=inputs_embeds.device,
            dtype=torch.long
        )
        if labels is not None and is_prefill:
            # labels shape: [Batch, Seq_Len]
            # Find the FIRST index where label != -100 for EACH row in the batch.
            is_real_label = (labels != -100)
            # all False (all -100) -> 0
            prompt_lens = is_real_label.int().argmax(dim=1)
            has_label = is_real_label.any(dim=1)
            # If a row has no labels (all -100), prompt_len should be the full sequence length
            prompt_lens = torch.where(
                has_label,
                prompt_lens,
                torch.tensor(seq_len_input, device=inputs_embeds.device)
            )
            # Calculate safe boundary: Prefix Length + Prompt Length
            # Reshape to [Batch, 1, 1, 1] to allow comparison with kv_idx [1, 1, 1, Seq]
            safe_boundaries = (prefix_len + prompt_lens).view(batch_size, 1, 1, 1)
        # Stashed on self for the (currently unused) create_causal_mask helpers.
        self.safe_boundaries = safe_boundaries
        ####

        hidden_states = inputs_embeds
        # concat A SINGLE seq of tokens:
        active_or_mask_func = None
        active_and_mask_func = None
        # No cache or empty cache
        # if past_key_values is None or past_key_values.get_seq_length() == 0:
        if is_prefill:
            self.flag_hyper = True
            # Use functools.partial to pass `self`
            active_or_mask_func = self._create_prefix_or_mask
            active_and_mask_func = self._create_prefix_and_mask
            if self.hypernetxs_cross_attn_tokens is not None:
                batch_size = hidden_states.shape[0]
                # prefix cross-attention tokens
                prefix_embeds = self.hypernetxs_cross_attn_tokens.expand(int(batch_size), -1, -1)
                hidden_states = torch.concat((prefix_embeds, hidden_states), dim=1)
                # modify causal_mask
                ## NEED to check carefully later
                if attention_mask is not None:
                    # Prefix positions are always attendable (all-ones padding mask).
                    prefix_attention_mask = torch.ones((batch_size, prefix_len),
                                                       dtype=attention_mask.dtype,
                                                       device=attention_mask.device)
                    attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)
        else:
            # generating mode
            self.flag_hyper = False
            # position_ids = text_position_ids # cache_position.unsqueeze(0)
            ###
        # Need to check at generate()
        # print('attention_mask', attention_mask, attention_mask.shape, input_ids.shape)
        # causal_mask = create_causal_mask(
        #     config=self.config,
        #     input_embeds=hidden_states,
        #     # attention_mask=attention_mask,
        #     attention_mask = None,
        #     cache_position=cache_position,
        #     past_key_values=past_key_values,
        #     position_ids=position_ids,
        #     # Pass custom logic. Not work.
        #     or_mask_function=active_or_mask_func,
        #     and_mask_function=active_and_mask_func
        # )

        causal_mask = None
        if is_prefill:
            # Build the additive (float) attention bias by hand:
            # start from causal, OR-in the prefix "unlock" region, AND-out the
            # text→prefix "forbidden" region, then add the padding bias.
            current_seq_len = hidden_states.shape[1]
            dtype = inputs_embeds.dtype
            min_dtype = torch.finfo(dtype).min
            # Grid
            q_idx = torch.arange(current_seq_len, device=inputs_embeds.device).view(1, 1, current_seq_len, 1)
            k_idx = torch.arange(current_seq_len, device=inputs_embeds.device).view(1, 1, 1, current_seq_len)
            # Basic Causal Mask
            mask_bool = q_idx >= k_idx
            # C. Logic Custom (Prefill)
            # Logic 1: Prefix Unlock
            prefix_unlock = (q_idx < prefix_len) & (k_idx < safe_boundaries)
            mask_bool = mask_bool | prefix_unlock
            # Logic 2: Text Forbidden
            text_forbidden_prefix = (q_idx >= prefix_len) & (k_idx < prefix_len)
            mask_bool = mask_bool & (~text_forbidden_prefix)
            #Float Mask (Bias)
            causal_mask = torch.full_like(mask_bool, min_dtype, dtype=dtype)
            causal_mask = causal_mask.masked_fill(mask_bool, 0.0)
            # Add Padding Mask
            if attention_mask is not None:
                padding_mask_float = (1.0 - attention_mask.to(dtype)) * min_dtype
                padding_mask_float = padding_mask_float[:, None, None, :]
                causal_mask = causal_mask + padding_mask_float
            # (Prevent 8D & SDPA Compatibility)
            causal_mask = causal_mask.contiguous()
        else:
            # --- GENERATE (DECODING) ---
            self.flag_hyper = False
            # Leave causal_mask = None. FLASH ATTENTION
            pass
        # print('causal_mask', type(causal_mask), causal_mask.dtype, causal_mask.shape)
        ####

        ####
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                flag_hyper = self.flag_hyper,
                **kwargs,
            )
            ### Modification
            # Chain adapters: layer idx's hypernetwork output configures layer idx+1.
            # NOTE(review): layer 0 never has its adapters set before its first
            # forward — presumably LoraXSLinear tolerates an unset R; verify.
            if idx < self.config.num_hidden_layers - 1 and self.flag_hyper:
                self.layers[idx+1].set_loraxs_adapters(decoder_layer.get_cache_loraxs())

        ### Apply previous output to the next stage
        ### remove the representative cross-attention tokens.
        if self.flag_hyper:
            hidden_states = hidden_states[:, self.main_cfg.hyperxs.n_cross_attn_tokens:, :]
        ###
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


class IbaXs_LlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head over :class:`IbaXs_LlamaModel`.

    Differs from the stock Llama head only by (a) the custom backbone and
    (b) forwarding ``labels`` into the backbone so the prefix mask boundary
    can be derived from them.
    """

    def __init__(self, config: HFLlamaConfig, ):
        super().__init__(config)
        self.model = IbaXs_LlamaModel(config)

    def reset_BA_xslora(self):
        """Re-derive the frozen B/A factors of every LoRA-XS module via SVD of
        its base weight (delegates to ``LoraXSLinear.decompose_weight_svd``)."""
        for name, module in self.named_modules():
            if isinstance(module, LoraXSLinear):
                module.decompose_weight_svd(module.rank)
                # print('Reset BA for', name)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            labels=labels,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        # NOTE: outputs.hidden_states / outputs.attentions are None — the backbone
        # does not populate them.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def test_set_loraxs_adapters():
    """Smoke test: push a synthetic adapter tensor through set_loraxs_adapters."""
    main_cfg=MainConfig()
    # print(mainCfg)
    hf_model_cfg = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    #hypernetxs = HyperNetXSexp(hf_model_cfg = hf_model_cfg, hyperxs_cfg=mainCfg.hyperxs)
    # NOTE(review): IbaXs_LlamaDecoderLayer.__init__ has no ``main_cfg`` kwarg
    # (signature is (config, layer_idx, hypernetxs)) and hf_model_cfg lacks the
    # ``main_cfg`` attribute the layer reads — this call as written would raise.
    layer = IbaXs_LlamaDecoderLayer(hf_model_cfg, main_cfg=main_cfg, layer_idx=1)
    rank = main_cfg.hyperxs.lora_attn_dim
    batch_size = main_cfg.training.batch_train
    modules = 7
    values = torch.arange(1, modules + 1)
    values_reshaped = values.view(modules, 1, 1)
    loraTensor = values_reshaped.expand(batch_size, modules, rank, rank)
    layer.set_loraxs_adapters(loraTensor)


def test_llm():
    """End-to-end smoke test of IbaXs_LlamaForCausalLM: forward pass + generate."""
    # print(mainCfg)
    main_cfg=MainConfig()
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # Shrink the model so the test runs quickly on CPU.
    config.hidden_size=128
    config.intermediate_size=256
    config.num_hidden_layers=6
    config.head_dim = config.hidden_size // config.num_attention_heads
    main_cfg_dict = asdict(main_cfg)
    config.main_cfg = main_cfg_dict
    model_bb = IbaXs_LlamaForCausalLM(config=config)
    model_bb.reset_BA_xslora()
    batch_size = main_cfg.training.per_device_train_batch_size
    input = torch.ones(batch_size, 11, dtype=torch.long)
    total_params = sum(p.numel() for p in model_bb.parameters())
    print('input llm', input.shape, total_params)
    # inference
    output = model_bb(input,logits_to_keep=1)
    print('output llm', output.logits.shape)

    # Assuming 'model' is your instantiated IbaXs_LlamaModel
    # model = model_bb.model
    # if hasattr(model, 'layers') and len(model.layers) > 1:
    #     # Get the hypernet object from layer 0 and layer 1
    #     hypernet_0 = model.layers[0].hypernetxs
    #     hypernet_1 = model.layers[1].hypernetxs
    #     # Check if they are the same object in memory
    #     is_same_object = (hypernet_0 is hypernet_1)
    #     print(f"Hypernet from Layer 0 ID: {id(hypernet_0)}")
    #     print(f"Hypernet from Layer 1 ID: {id(hypernet_1)}")
    #     print(f"Are they the same shared object? {is_same_object}")
    #     # You can even check the parameter tensors directly
    #     param_0 = hypernet_0.c_dim.weight
    #     param_1 = hypernet_1.c_dim.weight
    #     is_same_tensor = (param_0 is param_1)
    #     print(f"Are their 'c_dim.weight' tensors the same object? {is_same_tensor}")
    #     print('-'*50)

    ### generate
    # NOTE(review): inputs are moved to 'mps' below but model_bb is never moved
    # off CPU — generate() would hit a device mismatch as written; verify.
    device = 'mps'
    from transformers import LlamaTokenizer
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=True)
    model_bb.eval()
    prompts = [
        "The capital of France is",
        "Here is a simple Python function to add two numbers:"
    ]
    for i, prompt in enumerate(prompts):
        print(f"\n--- Prompt {i+1} ---")
        print(f"Input: {prompt}")
        # 4.1. Tokenize the Input
        # Convert the prompt string to PyTorch tensors
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # 4.2. Generate Text
        # Use torch.no_grad() for inference
        with torch.no_grad():
            outputs = model_bb.generate(
                **inputs,
                max_new_tokens=50,  # Generate up to 50 new tokens
                do_sample=True,
                temperature=0.7,
                top_k=50
                # Note: We don't need 'add_generation_prompt' here
            )
        # 4.3. Decode the Output
        # The output includes the prompt, so we slice it
        output_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(f"Output: {generated_text}")


def test_backbone():
    """Smoke test of the bare IbaXs_LlamaModel backbone (no LM head)."""
    # print(mainCfg)
    main_cfg=MainConfig()
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # Shrink the model so the test runs quickly on CPU.
    config.hidden_size=128
    config.intermediate_size=256
    config.num_hidden_layers=6
    config.head_dim = config.hidden_size // config.num_attention_heads
    main_cfg_dict = asdict(main_cfg)
    config.main_cfg = main_cfg_dict
    mode_bb = IbaXs_LlamaModel(config=config)
    batch_size = main_cfg.training.batch_train
    input = torch.ones(batch_size, 11, dtype=torch.long)
    total_params = sum(p.numel() for p in mode_bb.parameters())
    print('input bb', input.shape, total_params)
    output = mode_bb(input)
    print('output bb', output.last_hidden_state.shape)


if __name__ == "__main__":
    print("Hello world from XS_llama.py")
    # test_backbone()
    test_llm()