# Instance-based-FT / iba / XS_llama.py
# nvan13's picture
# Upload folder using huggingface_hub
# a0d95b0 verified
# mypy: ignore-errors
# A single sequence of representative cross-attention tokens is added at the beginning only.
# The next layer reuses the output from the previous layer.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable, Optional, Union
import functools
from dataclasses import asdict
from transformers.models.llama.modeling_llama import (
LlamaMLP,
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaForCausalLM
)
from transformers import AutoConfig, PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaConfig as HFLlamaConfig
from transformers.processing_utils import Unpack
from transformers.masking_utils import create_causal_mask
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from .Xslora import LoraXSLinear, HyperNetXSexp
from .configIBA import MainConfig, HyperXSConfig, TrainingConfig, from_dict
class IbaXs_LlamaAttention(LlamaAttention):
    """Llama attention whose q/k/v/o projections are ``LoraXSLinear`` layers.

    The parent constructor builds the stock attention module first; the four
    projection attributes are then overwritten with ``LoraXSLinear`` instances
    that use the same input/output widths.
    """

    def __init__(self, config: HFLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        # Rebuild the MainConfig dataclass from the copy stored on the HF config.
        main_cfg = from_dict(MainConfig, config.main_cfg)

        # Keyword arguments shared by all four projections.
        shared_kwargs = dict(
            train_cfg=main_cfg.training,
            rank=main_cfg.hyperxs.lora_attn_dim,
            bias=config.attention_bias,
        )
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim

        self.q_proj = LoraXSLinear(config.hidden_size, query_width, **shared_kwargs)
        self.k_proj = LoraXSLinear(config.hidden_size, kv_width, **shared_kwargs)
        self.v_proj = LoraXSLinear(config.hidden_size, kv_width, **shared_kwargs)
        self.o_proj = LoraXSLinear(query_width, config.hidden_size, **shared_kwargs)
class IbaXs_LlamaMLP(LlamaMLP):
    """Llama MLP whose gate/up/down projections are ``LoraXSLinear`` layers."""

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)

        # Rebuild the MainConfig dataclass from the copy stored on the HF config.
        main_cfg = from_dict(MainConfig, config.main_cfg)

        # Keyword arguments shared by the three projections.
        lora_kwargs = dict(
            train_cfg=main_cfg.training,
            rank=main_cfg.hyperxs.lora_attn_dim,
            bias=config.mlp_bias,
        )
        model_width = self.hidden_size
        inner_width = self.intermediate_size

        self.gate_proj = LoraXSLinear(model_width, inner_width, **lora_kwargs)
        self.up_proj = LoraXSLinear(model_width, inner_width, **lora_kwargs)
        self.down_proj = LoraXSLinear(inner_width, model_width, **lora_kwargs)
# block layer
class IbaXs_LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with LoraXS projections and a hyper-network hook.

    The stock attention and MLP sub-modules are replaced by their ``IbaXs_*``
    counterparts. When ``forward`` runs with ``flag_hyper`` set, the first
    ``n_cross_attn_tokens`` positions of the post-attention hidden states are
    passed to ``hypernetxs`` and the result is cached on the layer; the caller
    retrieves it with ``get_cache_loraxs``.

    NOTE(review): ``hypernetxs`` is assigned as a regular attribute, so when it
    is an ``nn.Module`` it is registered as a sub-module of this layer and is
    also visited by ``named_modules()`` in ``set_loraxs_adapters`` — confirm
    that none of its sub-module names end with a projection name.
    """

    def __init__(self, config: HFLlamaConfig,
                 layer_idx: int,
                 hypernetxs: Optional[HyperNetXSexp] = None,
                 ):
        """Build the layer.

        Args:
            config: HF Llama config carrying the IBA settings in ``config.main_cfg``.
            layer_idx: Index of this layer in the decoder stack.
            hypernetxs: Hyper-network called in ``forward`` when ``flag_hyper``
                is true. May be ``None`` if ``forward`` is never called with
                ``flag_hyper`` enabled.
        """
        super().__init__(config, layer_idx)
        self.hypernetxs = hypernetxs
        self.hfconfig = config

        # Rebuild the MainConfig dataclass from the copy stored on the HF config.
        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs
        self.n_cross_attn_tokens = main_cfg.hyperxs.n_cross_attn_tokens

        # Replace the stock sub-modules with their LoraXS variants.
        self.self_attn = IbaXs_LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = IbaXs_LlamaMLP(config)

        # Registered as a buffer so the index tensor follows ``.to(device)``.
        self.register_buffer('layer_idx_hyperxs', torch.tensor(layer_idx, dtype=torch.long))

        # Output of the most recent hyper-network call (None until one happens).
        self.__loraxsTensor = None
        self.layer_idx = layer_idx

    def get_cache_loraxs(self):
        """Return the tensor cached by the last ``forward`` run with ``flag_hyper``.

        Returns ``None`` if no such call has happened yet. The cache is not
        cleared by reading it.
        """
        return self.__loraxsTensor

    def reset_parameters(self):
        """No-op: this layer currently has no parameters of its own to initialise."""

    def set_loraxs_adapters(self, loraXsTensor: Tensor):
        """Distribute per-module adapter matrices to the ``LoraXSLinear`` sub-modules.

        Args:
            loraXsTensor: Tensor indexed as ``[:, module, :, :]``; expected
                layout is (batch, modules, rank, rank). Slices are consumed in
                the order q, k, v, o, gate, up, down.

        Raises:
            ModuleNotFoundError: If ``loraXsTensor`` is ``None``.
            NotImplementedError: If a sub-module whose name ends with one of
                the projection names is not a ``LoraXSLinear``.
        """
        if loraXsTensor is None:
            raise ModuleNotFoundError(
                "set_loraxs_adapters received None instead of an adapter tensor"
            )

        applied_modules = ('q_proj', 'k_proj', 'v_proj', 'o_proj',
                           'gate_proj', 'up_proj', 'down_proj')
        # Walk the module tree once instead of once per projection name.
        named_children = list(self.named_modules())

        idx = 0
        for key in applied_modules:
            for name, module in named_children:
                if not name.endswith(key):
                    continue
                if not isinstance(module, LoraXSLinear):
                    raise NotImplementedError(
                        f"module '{name}' matches '{key}' but is not a LoraXSLinear"
                    )
                module.set_R(loraXsTensor[:, idx, :, :].contiguous())
                idx = idx + 1

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        ### additional arg
        flag_hyper: Optional[bool] = True,
        **kwargs  #: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run self-attention and the MLP, optionally calling the hyper-network.

        Adapted from ``modeling_llama.py``. The only addition is the
        ``flag_hyper`` branch between the attention and MLP sub-blocks.

        Args:
            hidden_states: Layer input; indexed as (batch, sequence, hidden).
            attention_mask, position_ids, past_key_values, use_cache,
                cache_position, position_embeddings: Forwarded unchanged to
                ``self.self_attn``.
            flag_hyper: When true, slice the first ``n_cross_attn_tokens``
                positions after the attention residual, pass them to
                ``self.hypernetxs`` together with ``self.layer_idx`` and cache
                the result. ``self.hypernetxs`` must not be ``None`` then.
            **kwargs: Extra keyword arguments forwarded to ``self.self_attn``.

        Returns:
            The hidden states after both residual sub-blocks. The leading
            cross-attention positions are NOT removed here.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Extract the representative tokens; they stay in hidden_states and are
        # still pushed through the MLP and on to the next layer.
        if flag_hyper:
            cross_attention = hidden_states[:, 0:self.n_cross_attn_tokens, :]
            # Cache the adapters; expected layout (batch, n_modules, r, r) — TODO confirm
            self.__loraxsTensor = self.hypernetxs(cross_attention, self.layer_idx)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
#back bone models
class IbaXs_LlamaModel(LlamaModel):
    """Llama backbone with a learned prefix and a shared hyper-network.

    On a prefill pass (no cache, or an empty cache) a single sequence of
    learned prefix embeddings is prepended to the input. Each decoder layer
    feeds its prefix positions to the shared ``HyperNetXSexp``; the result is
    installed as the adapters of the NEXT layer. The prefix positions are
    removed again before the final norm. On later (decoding) passes no prefix
    is added, the hyper-network is not called, and no mask is passed.
    """

    def __init__(self, config: HFLlamaConfig):
        super().__init__(config)

        # Rebuild the MainConfig dataclass from the copy stored on the HF config.
        main_cfg = from_dict(MainConfig, config.main_cfg)
        self.hyperxs_cfg = main_cfg.hyperxs

        # One hyper-network instance is handed to every decoder layer.
        self.hypernetxs = HyperNetXSexp(main_cfg.hyperxs, config)
        self.layers = nn.ModuleList(
            [IbaXs_LlamaDecoderLayer(config, layer_idx, self.hypernetxs)
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.flag_hyper = True

        # Learned prefix embeddings, shape (n_cross_attn_tokens, hidden_size).
        self.hypernetxs_cross_attn_tokens = nn.Parameter(
            torch.zeros(main_cfg.hyperxs.n_cross_attn_tokens, config.hidden_size)
        )
        self.main_cfg = main_cfg

    def reset_parameters(self):
        """Re-initialise the prefix embeddings from N(0, 0.01^2)."""
        INIT_STD = 0.01
        nn.init.normal_(self.hypernetxs_cross_attn_tokens, mean=0.0, std=INIT_STD)

    def _create_prefix_or_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Creates a mask to UNLOCK specific regions.
        The boolean result is meant to be consumed by create_causal_mask.
        1. Prefix-sees-Prefix (bidirectional)
        2. Prefix-sees-Text (every key below ``self.safe_boundaries``)

        Reads ``self.safe_boundaries``, which ``forward`` sets on every call.
        ``batch_idx`` and ``head_idx`` are unused.
        """
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]  # K (int)
        # 1. Query is Prefix?
        is_query_prefix = q_idx < prefix_len
        # kv_idx [1, 1, 1, K] compared with safe_boundaries [Batch, 1, 1, 1]
        is_key_safe = kv_idx < self.safe_boundaries
        return is_query_prefix & is_key_safe

    def _create_prefix_and_mask(
        self,
        batch_idx: torch.Tensor,
        head_idx: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Creates a mask to LOCK specific regions.
        1. Text-sees-Prefix

        Returns True everywhere except where a text query (q_idx >= prefix
        length) meets a prefix key (kv_idx < prefix length).
        ``batch_idx`` and ``head_idx`` are unused.
        """
        prefix_len = self.hypernetxs_cross_attn_tokens.shape[0]  # K (int)
        is_forbidden = (q_idx >= prefix_len) & (kv_idx < prefix_len)
        # ~ is the vmap-safe "NOT" operator for boolean tensors.
        return ~is_forbidden

    @staticmethod
    def _build_prefill_mask(
        seq_len: int,
        prefix_len: int,
        safe_boundaries: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """Build the additive attention bias used on the prefill pass.

        A (query, key) pair is allowed when it is causal OR the query is a
        prefix position and the key lies below ``safe_boundaries``; pairs where
        a text query meets a prefix key are then forbidden, and so are keys
        whose ``attention_mask`` entry is zero.

        All conditions are combined as booleans and converted to a bias once,
        so every masked entry is exactly ``finfo(dtype).min``. Adding a padding
        bias to a causal bias instead would give ``min + min = -inf`` for
        entries masked by both.

        Returns:
            Contiguous tensor of shape (batch, 1, seq_len, seq_len) holding 0
            for allowed pairs and ``finfo(dtype).min`` otherwise.
        """
        q_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len, 1)
        k_idx = torch.arange(seq_len, device=device).view(1, 1, 1, seq_len)

        # Basic causal mask.
        allowed = q_idx >= k_idx
        # Logic 1: prefix unlock.
        allowed = allowed | ((q_idx < prefix_len) & (k_idx < safe_boundaries))
        # Logic 2: text is forbidden from seeing the prefix.
        allowed = allowed & ~((q_idx >= prefix_len) & (k_idx < prefix_len))
        # Padding: drop keys whose attention_mask entry is zero.
        if attention_mask is not None:
            key_is_real = attention_mask[:, None, None, :].to(device=device, dtype=torch.bool)
            allowed = allowed & key_is_real

        min_dtype = torch.finfo(dtype).min
        bias = torch.zeros(allowed.shape, dtype=dtype, device=device)
        # Contiguous output (prevents 8D masks, SDPA compatibility).
        return bias.masked_fill(~allowed, min_dtype).contiguous()

    # @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs  #: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack, adding the learned prefix on prefill passes.

        Args:
            input_ids / inputs_embeds: Exactly one must be given.
            attention_mask: Optional 2-D padding mask over the text tokens
                (1 = keep). Only used on the prefill pass.
            position_ids, cache_position: Accepted for interface
                compatibility but always recomputed here, because the prefix
                shifts every position.
            past_key_values: Cache; a prefill pass is one where it is ``None``
                or reports a sequence length of 0.
            use_cache: When true and no cache is given, a ``DynamicCache`` is created.
            labels: When given on a prefill pass, the index of the first label
                different from -100 in each row bounds which text keys the
                prefix may attend to.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` no longer
            contains the prefix positions.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        is_prefill = (past_key_values is None) or \
            (hasattr(past_key_values, 'get_seq_length') and past_key_values.get_seq_length() == 0)
        n_prefix_cfg = self.main_cfg.hyperxs.n_cross_attn_tokens
        prefix_len = n_prefix_cfg if (n_prefix_cfg is not None and is_prefill) else 0

        # Positions count the prefix (K) plus the text (S) on prefill.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        device = inputs_embeds.device
        cache_position: torch.Tensor = torch.arange(
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1] + prefix_len,
            device=device,
        )
        position_ids = cache_position.unsqueeze(0).expand(inputs_embeds.shape[0], -1)

        # Per-row boundary below which the prefix may attend to keys.
        # Default: the whole (prefix + text) sequence.
        batch_size, seq_len_input = inputs_embeds.shape[:2]
        safe_boundaries = torch.full(
            (batch_size, 1, 1, 1),
            seq_len_input + prefix_len,
            device=device,
            dtype=torch.long,
        )
        if labels is not None and is_prefill:
            # labels: [Batch, Seq_Len]. Find the FIRST index with label != -100 per row.
            is_real_label = (labels != -100)
            prompt_lens = is_real_label.int().argmax(dim=1)  # all -100 -> 0
            has_label = is_real_label.any(dim=1)
            # Rows without any label use the full text length. The fallback is
            # built on the labels' device so torch.where never mixes devices.
            prompt_lens = torch.where(
                has_label,
                prompt_lens,
                torch.full_like(prompt_lens, seq_len_input),
            )
            # Prefix length + prompt length, reshaped to [Batch, 1, 1, 1] for
            # comparison with key indices [1, 1, 1, Seq].
            safe_boundaries = (prefix_len + prompt_lens).view(batch_size, 1, 1, 1).to(device)
        # Kept as an attribute because _create_prefix_or_mask reads it.
        self.safe_boundaries = safe_boundaries

        hidden_states = inputs_embeds
        causal_mask = None
        # The hyper-network only runs on the prefill pass.
        self.flag_hyper = is_prefill

        if is_prefill:
            if self.hypernetxs_cross_attn_tokens is not None:
                # Prepend A SINGLE sequence of prefix cross-attention tokens.
                prefix_embeds = self.hypernetxs_cross_attn_tokens.expand(int(batch_size), -1, -1)
                hidden_states = torch.concat((prefix_embeds, hidden_states), dim=1)
                # Prefix positions are never padding.
                if attention_mask is not None:
                    prefix_attention_mask = torch.ones(
                        (batch_size, prefix_len),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=1)

            causal_mask = self._build_prefill_mask(
                seq_len=hidden_states.shape[1],
                prefix_len=prefix_len,
                safe_boundaries=safe_boundaries,
                attention_mask=attention_mask,
                dtype=inputs_embeds.dtype,
                device=device,
            )
        # else: generate (decoding) — leave causal_mask = None (FLASH ATTENTION).
        # NOTE(review): with no mask on decode steps, the padding mask is not
        # applied and new tokens can attend to the cached prefix keys, which
        # text queries cannot do during prefill — confirm this is intended.

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        num_layers = self.config.num_hidden_layers
        for idx, decoder_layer in enumerate(self.layers[:num_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                flag_hyper=self.flag_hyper,
                **kwargs,
            )
            # Install this layer's hyper-network output as the NEXT layer's adapters.
            if idx < num_layers - 1 and self.flag_hyper:
                self.layers[idx + 1].set_loraxs_adapters(decoder_layer.get_cache_loraxs())

        # Remove the representative cross-attention tokens.
        if self.flag_hyper:
            hidden_states = hidden_states[:, self.main_cfg.hyperxs.n_cross_attn_tokens:, :]

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
class IbaXs_LlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head on top of ``IbaXs_LlamaModel``."""

    def __init__(self, config: HFLlamaConfig,
                 ):
        super().__init__(config)
        # Swap the stock backbone for the IBA/LoraXS one.
        self.model = IbaXs_LlamaModel(config)

    def reset_BA_xslora(self):
        """Call ``decompose_weight_svd(rank)`` on every ``LoraXSLinear`` sub-module."""
        xs_layers = (m for m in self.modules() if isinstance(m, LoraXSLinear))
        for xs_layer in xs_layers:
            xs_layer.decompose_weight_svd(xs_layer.rank)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        # Run the backbone; ``labels`` is forwarded as well because the
        # backbone's forward accepts it.
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            labels=labels,
            **kwargs,
        )
        hidden_states = backbone_out.last_hidden_state

        # Only compute the logits that were asked for: an int keeps the last
        # ``logits_to_keep`` positions (0 keeps all); a tensor is used as an index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(hidden_states[:, keep, :])

        loss = (
            None
            if labels is None
            else self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
def test_set_loraxs_adapters():
    """Smoke-test ``IbaXs_LlamaDecoderLayer.set_loraxs_adapters`` on one layer.

    Builds a single decoder layer without a hyper-network and hands it an
    adapter tensor in which module ``i`` is filled with the value ``i + 1``.
    Loads the base model's config via ``AutoConfig.from_pretrained``.
    """
    main_cfg = MainConfig()
    hf_model_cfg = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # The layer constructors read the IBA settings from ``config.main_cfg``;
    # the layer itself takes no ``main_cfg`` argument.
    hf_model_cfg.main_cfg = asdict(main_cfg)
    layer = IbaXs_LlamaDecoderLayer(hf_model_cfg, layer_idx=1)

    rank = main_cfg.hyperxs.lora_attn_dim
    # NOTE(review): test_llm reads ``per_device_train_batch_size`` instead —
    # confirm which attribute TrainingConfig actually defines.
    batch_size = main_cfg.training.batch_train
    modules = 7

    # (modules, 1, 1) broadcast to (batch, modules, rank, rank).
    values = torch.arange(1, modules + 1)
    values_reshaped = values.view(modules, 1, 1)
    lora_tensor = values_reshaped.expand(batch_size, modules, rank, rank)
    layer.set_loraxs_adapters(lora_tensor)
def test_llm():
    """Smoke-test ``IbaXs_LlamaForCausalLM``: one forward pass, then generation.

    Builds a shrunken model (hidden size 128, 6 layers) from the base model's
    config, runs a forward pass on dummy token ids, then samples up to 50 new
    tokens for two prompts. Downloads the config and the
    ``huggyllama/llama-7b`` tokenizer.
    """
    main_cfg = MainConfig()
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name
    )
    # Shrink the architecture so the test runs quickly.
    config.hidden_size = 128
    config.intermediate_size = 256
    config.num_hidden_layers = 6
    config.head_dim = config.hidden_size // config.num_attention_heads
    config.main_cfg = asdict(main_cfg)

    model_bb = IbaXs_LlamaForCausalLM(config=config)
    model_bb.reset_BA_xslora()

    batch_size = main_cfg.training.per_device_train_batch_size
    dummy_ids = torch.ones(batch_size, 11, dtype=torch.long)
    total_params = sum(p.numel() for p in model_bb.parameters())
    print('input llm', dummy_ids.shape, total_params)

    # inference
    output = model_bb(dummy_ids, logits_to_keep=1)
    print('output llm', output.logits.shape)

    ### generate
    # Keep the prompts on the model's own device. The model is never moved,
    # so a hard-coded device would not match its weights.
    device = next(model_bb.parameters()).device
    from transformers import LlamaTokenizer
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=True)
    model_bb.eval()

    prompts = [
        "The capital of France is",
        "Here is a simple Python function to add two numbers:"
    ]
    for i, prompt in enumerate(prompts):
        print(f"\n--- Prompt {i+1} ---")
        print(f"Input: {prompt}")

        # Tokenize the prompt into PyTorch tensors.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate without tracking gradients.
        with torch.no_grad():
            outputs = model_bb.generate(
                **inputs,
                max_new_tokens=50,  # Generate up to 50 new tokens
                do_sample=True,
                temperature=0.7,
                top_k=50
            )

        # The output includes the prompt, so slice it off before decoding.
        output_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(f"Output: {generated_text}")
def test_backbone():
    """Smoke-test ``IbaXs_LlamaModel`` with one forward pass on dummy token ids.

    Uses a shrunken copy of the base model's config (hidden size 128,
    6 layers) and prints the input and output shapes.
    """
    main_cfg = MainConfig()
    config = AutoConfig.from_pretrained(main_cfg.model.base_model_name)

    # Shrink the architecture so the test runs quickly.
    config.hidden_size = 128
    config.intermediate_size = 256
    config.num_hidden_layers = 6
    config.head_dim = config.hidden_size // config.num_attention_heads
    config.main_cfg = asdict(main_cfg)

    backbone = IbaXs_LlamaModel(config=config)

    n_rows = main_cfg.training.batch_train
    token_ids = torch.ones(n_rows, 11, dtype=torch.long)
    total_params = sum(param.numel() for param in backbone.parameters())
    print('input bb', token_ids.shape, total_params)

    result = backbone(token_ids)
    print('output bb', result.last_hidden_state.shape)
if __name__ == "__main__":
    # Manual smoke-test entry point: only the causal-LM check is enabled.
    print("Hello world from XS_llama.py")
    # test_backbone()
    test_llm()