"""Training entry point for the IBA hypernetwork Llama model.

Loads a base Llama checkpoint into the custom ``IbaXs_LlamaForCausalLM``,
marks only the IBA/hypernetwork parameters as trainable, tokenizes an
instruction-tuning dataset with the shared PROMPT_TEMPLATE, and runs a
HuggingFace ``Trainer`` fine-tuning loop.

NOTE(review): this file was recovered from a whitespace-mangled source; the
reconstruction preserves the original statements, with four concrete fixes
marked ``FIX:`` inside ``trainIBA``.
"""

import torch
import yaml
from dataclasses import asdict
import draccus
from datasets import load_dataset
import os
import transformers
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer,
                          AutoModel, AutoConfig, TrainingArguments)
import inspect
from transformers import logging as hf_logging
import random
import numpy as np
from datetime import datetime
# from XS_llama import IbaXs_LlamaModel, IbaXs_LlamaForCausalLM
# from utils import count_parameters
# from .configIBA import MainConfig
from iba import (IbaXs_LlamaModel, IbaXs_LlamaForCausalLM, HyperNetXSexp,
                 count_parameters, MainConfig, mark_iba_as_trainable_only)
from transformers.models.llama.modeling_llama import (
    LlamaMLP, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM
)

# Alpaca-style instruction template shared by training and evaluation.
# {input_section} is either empty or a pre-formatted "### Input:\n...\n\n" chunk.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n{input_section}"
    "### Response:\n"
)

# Register 'TrainConfig' as the schema for the config named 'config'
DEVICE = 'cuda'
# torch.compile = lambda model, *args, **kwargs: model


def set_seed(seed: int):
    """Seed python, numpy, torch (CPU+CUDA) and transformers for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)


def test_generate(config, main_cfg):
    """Smoke-test generation with a randomly initialized IbaXs model.

    Builds the tokenizer matching the base model family, instantiates
    ``IbaXs_LlamaForCausalLM`` from ``config`` (untrained weights) and prints
    a short sampled completion for each hard-coded prompt.
    """
    base_model_name = main_cfg.model.base_model_name
    if config.model_type == 'llama':
        # Due to the name of transformers' LlamaTokenizer, we have to do this
        # need to handle llama 3 separately
        if "lama-3" in base_model_name:
            print("load llama-3 tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        else:
            tokenizer = LlamaTokenizer.from_pretrained(base_model_name, legacy=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    model = IbaXs_LlamaForCausalLM(config=config).to(DEVICE)
    model.eval()
    prompts = [
        "The capital of France is",
        # "Here is a simple Python function to add two numbers:"
    ]
    for i, prompt in enumerate(prompts):
        print(f"\n--- Prompt {i+1} ---")
        print(f"Input: {prompt}")
        # 4.1. Tokenize the Input: convert the prompt string to PyTorch tensors
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        # 4.2. Generate Text (no gradients needed for inference)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=4,
                do_sample=True,
                temperature=0.7,
                top_k=50
                # Note: We don't need 'add_generation_prompt' here
            )
        # 4.3. Decode the Output. The output includes the prompt, so slice it off.
        output_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(f"Output: {generated_text}")


def get_hyper_model(config, base_model_name):
    """Build an IbaXs model initialized from the base Llama checkpoint.

    Loads the reference ``LlamaForCausalLM`` weights on CPU (forcing real fp32
    memory to avoid meta-tensor issues), copies them into a fresh
    ``IbaXs_LlamaForCausalLM`` with ``strict=False``, and reports which keys
    were missing/unexpected. Hypernetwork-specific keys are expected to be
    missing and are filtered out of the report.
    """
    # Avoid to init on cpu
    with torch.no_grad():
        torch.set_default_device('cpu')
        model = IbaXs_LlamaForCausalLM(config=config)  # test
        torch.set_default_device('cpu')  # Workaround to meta tensor on cuda issue.
        transformers.logging.set_verbosity_error()
        base_model_temp = LlamaForCausalLM.from_pretrained(
            base_model_name,
            config=config,
            device_map=None,          # Strictly None
            low_cpu_mem_usage=False,  # Force real memory
            torch_dtype=torch.float32,
        )
        missing_keys, unexpected_keys = model.load_state_dict(
            base_model_temp.state_dict(), strict=False
        )
        base_model_temp = base_model_temp.to(DEVICE)
        ## Test REMEMBER: SET VALID SIZE = 1. Comment out when normal running
        ## compare_models(model, base_model_temp, base_model_name)
        del base_model_temp
        torch.cuda.empty_cache()
        # model, loading_info = IbaXs_LlamaForCausalLM.from_pretrained(
        #     base_model_name, config=config, output_loading_info=True,
        #     dtype=torch.float32, low_cpu_mem_usage=False, device_map=None)
        # model = model.to('cuda')
        # missing_keys = loading_info.get("missing_keys", [])
        # unexpected_keys = loading_info.get("unexpected_keys", [])

    if missing_keys:
        print('missing_keys:')
        for key in missing_keys:
            # Hypernet keys are new modules and are expected to be missing;
            # only report genuinely missing base-model layer weights.
            if 'layers' in key and 'hypernetxs' not in key and 'layer_idx_hyperxs' not in key:
                print(f"  missing: [x] {key}")
    else:
        print("\n>>> No missing keys.")
    if unexpected_keys:
        for key in unexpected_keys:
            print(f"  [?] {key}")
    else:
        print("\n>>> No unexpected keys.")
    return model


def compare_models(custom_model, ref_model, base_model_name, device="cuda"):
    """
    Compares logits between the custom IbaXs model and the original Llama 2.
    REMEMBER: SET VALID SIZE = 1
    """
    def setup_precise_gpu_environment():
        """
        Configures PyTorch to prioritize numerical precision over speed on GPU.
        This helps in matching GPU results with CPU results for debugging purposes.
        """
        # 1. DISABLE TensorFloat-32 (TF32)
        # By default, newer NVIDIA GPUs (Ampere+) use TF32 for matmul/conv,
        # which sacrifices precision for speed.
        # We disable it to force true Float32 calculations.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        # 2. ENFORCE Deterministic Algorithms (Optional but Recommended)
        # Some CUDA operations are non-deterministic (e.g., atomic additions).
        # This forces PyTorch to use deterministic algorithms where possible.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Note: If you face errors like "deterministic algorithm not found",
        # you might need to set the environment variable: CUBLAS_WORKSPACE_CONFIG=:4096:8
        # torch.use_deterministic_algorithms(True)
        print(">> GPU Precision Setup: TF32 Disabled. Deterministic Mode set (partial).")

    setup_precise_gpu_environment()
    print(f"\n--- Starting Comparison on {device} {custom_model.dtype} {ref_model.dtype}---")
    # ref_model = ref_model.to(device)
    # custom_model = custom_model.to(device)
    ref_model.eval()
    custom_model.eval()  # Set your model to eval mode

    # 2. Prepare dummy input
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    text = "Hello, this is a test for model comparison."
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Ensure inputs are on the same device as the reference model's first layer
    ref_inputs = inputs.to(ref_model.device)

    # 3. Forward pass (No gradients needed)
    with torch.no_grad():
        print("Running inference on Custom Model...")
        logits_custom = custom_model(**inputs).logits
        print("Running inference on Reference Model...")
        logits_ref = ref_model(**ref_inputs).logits

    # 4. Compare results
    # Move both to CPU for comparison to avoid device mismatch errors
    diff = (logits_custom.cpu() - logits_ref.cpu()).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    print("\n--- Comparison Results ---")
    print(f"Max Absolute Difference: {max_diff:.6f}")
    print(f"Mean Absolute Difference: {mean_diff:.6f}")

    # Check first few logits of the last token
    print("\nFirst 5 logits (Last Token):")
    print(f"Custom: {logits_custom[0, -1, :5].cpu().tolist()}")
    print(f"Ref   : {logits_ref[0, -1, :5].cpu().tolist()}")

    if max_diff < 1e-3:
        print(">> VERDICT: Models are effectively IDENTICAL.")
    else:
        print(">> VERDICT: Models are DIFFERENT (Expected if custom layers are random initialized).")

    # Clean up reference model to free memory
    del ref_model
    torch.cuda.empty_cache()


class GradientInspector:
    """
    A debugging tool to attach hooks to PyTorch modules.
    It prints the gradient norm flowing through specific layers during backward pass.
    """

    def __init__(self):
        # Handles returned by register_full_backward_hook, so they can be removed.
        self.hooks = []

    def print_grad_stats(self, module, grad_input, grad_output):
        """
        Callback function triggered during backward pass.
        """
        from tqdm import tqdm
        # module_name is stored in the module object for identification
        name = getattr(module, 'debug_name', 'Unknown Layer')
        # Check Output Gradients (Gradients coming from the Loss towards this layer)
        if grad_output[0] is not None:
            out_norm = grad_output[0].norm().item()
            tqdm.write(f"[DEBUG-BACKWARD] {name} | Output Grad Norm (from upstream): {out_norm:.6f}")
        else:
            tqdm.write(f"[DEBUG-BACKWARD] {name} | Output Grad is None!")
        # Check Input Gradients (Gradients passing through this layer to the next)
        # Note: In backward pass, "input" usually refers to the gradients
        # w.r.t weights or previous layer outputs
        if grad_input[0] is not None:
            in_norm = grad_input[0].norm().item()
            msg = (f"[DEBUG-BACKWARD] {name} | Input Grad Norm (passing downstream): {in_norm:.6f}")
            tqdm.write(msg)
            if in_norm == 0:
                tqdm.write(f"  >>> ALARM: Gradient died at {name}!")
        else:
            # Some layers (like input embeddings) might have None grad_input at the very end
            pass

    def register_hooks(self, model):
        from tqdm import tqdm
        """
        Recursively attach hooks to important modules.
        """
        tqdm.write("Registering debug hooks...")
        # 1. Hook into the Hypernetwork Output (The most critical bridge)
        # Assuming model.hypernet is your hypernetwork instance
        if hasattr(model.model, 'hypernetxs'):
            model.model.hypernetxs.debug_name = "HyperNetwork_Top"
            # Hook the whole hypernet module
            handle = model.model.hypernetxs.register_full_backward_hook(self.print_grad_stats)
            self.hooks.append(handle)
            # Hook specifically the last linear layer of hypernet to see if weights get update
            if hasattr(model.model.hypernetxs, 'c_proj'):
                last_layer = model.model.hypernetxs.c_proj
                last_layer.debug_name = "HyperNetwork_Last_Linear"
                handle = last_layer.register_full_backward_hook(self.print_grad_stats)
                self.hooks.append(handle)
        # 2. Hook into a few Dynamic Layers (e.g., the first and last one)
        # Assuming you used the wrapper or replaced layers in base_model
        count = 0
        for name, module in model.named_modules():
            # Adjust 'DynamicSVDLinear' to match your actual class name
            if "Linear" in str(type(module)):
                if count == 0:
                    # First dynamic layer
                    module.debug_name = f"DynamicLayer_First_{name}"
                    handle = module.register_full_backward_hook(self.print_grad_stats)
                    self.hooks.append(handle)
                # You can add logic to hook the last one too
                count += 1
        print(f"Registered {len(self.hooks)} hooks.")

    def clear_hooks(self):
        # Detach every registered hook.
        for h in self.hooks:
            h.remove()


def reset_trainable_modules(model):
    """Re-initialize the trainable hypernet modules (calls reset_parameters)."""
    for name, module in model.named_modules():
        if isinstance(module, HyperNetXSexp) or isinstance(module, IbaXs_LlamaModel):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
                print('reset: ', name)
    return model


def trainIBA(config, main_cfg):
    """Run IBA fine-tuning.

    Loads (or, via the disabled ``if False`` branch, builds and exports) the
    SVD-initialized IbaXs model, tokenizes the instruction dataset, and trains
    only the IBA/hypernet parameters with ``transformers.Trainer``.
    """
    training_cfg = main_cfg.training
    data_cfg = main_cfg.data
    # TrainingArguments only accepts its own fields; drop anything extra from our config.
    valid_hf_arg_names = set(inspect.signature(TrainingArguments).parameters.keys())
    training_config_dict = asdict(training_cfg)
    filtered_trainer_args_dict = {
        key: value for key, value in training_config_dict.items()
        if key in valid_hf_arg_names
    }
    trainer_args = TrainingArguments(**filtered_trainer_args_dict)

    gradient_accumulation_steps = training_cfg.gradient_accumulation_steps
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # One device per local rank; split accumulation across ranks.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    base_model_name = main_cfg.model.base_model_name

    # A ramdom model to debug
    # with torch.no_grad():
    #     torch.set_default_device('cuda')
    #     model = IbaXs_LlamaForCausalLM(config=config)  # test
    #     torch.set_default_device('cpu')

    # SVD caluation for each rank.
    # One-time export path: flip to True to build the SVD-initialized checkpoint.
    if False:
        model = get_hyper_model(config=config, base_model_name=base_model_name)
        # print('device', model.device)
        mark_iba_as_trainable_only(model)
        count_parameters(model)
        model.reset_BA_xslora()
        model.save_pretrained('./SVD64_llama2', safe_serialization=False)
        exit()
    else:
        hf_logging.set_verbosity_error()
        model = IbaXs_LlamaForCausalLM.from_pretrained(
            './SVD64_llama2',
            device_map="auto",
            dtype=torch.bfloat16,
            config=config,
            local_files_only=True,  # Strictly force loading from local, no internet check for config
            ignore_mismatched_sizes=True,
        )
        hf_logging.set_verbosity_warning()
        # reset trainable hypernets
        model = reset_trainable_modules(model)
        mark_iba_as_trainable_only(model)
        count_parameters(model)
        # for n, p in model.named_parameters():
        #     if 'hypernetxs' not in n:
        #         print(f'n = {n}, shape {p.shape}')
        # print(model)

    if config.model_type == 'llama':
        # Due to the name of transformers' LlamaTokenizer, we have to do this
        # need to handle llama 3 separately
        if "lama-3" in base_model_name:
            print("load llama-3 tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        else:
            tokenizer = LlamaTokenizer.from_pretrained(base_model_name, legacy=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, max_length=main_cfg.model.cutoff_len, add_eos_token=True):
        """Tokenize a full prompt string and build causal-LM labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            # FIX: was hard-coded to main_cfg.model.cutoff_len, silently ignoring
            # the max_length parameter (the default keeps behavior identical).
            max_length=max_length,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < max_length
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            if "chatglm" not in base_model_name:
                result["attention_mask"].append(1)
        result["labels"] = result["input_ids"].copy()
        if "chatglm" in base_model_name:
            return {"input_ids": result["input_ids"], "labels": result["labels"]}
        else:
            return result

    def generate_and_tokenize_prompt(data_point):
        """Format one dataset row with PROMPT_TEMPLATE and tokenize it.

        When ``train_on_inputs`` is False, the prompt tokens are masked with
        -100 so loss is only computed on the response.
        """
        instruction = data_point.get("instruction", "")
        inp = data_point.get("input", "")
        target_output = data_point.get("output", "")  # "the correct answer is true"
        # Match your EVAL template exactly
        input_section = f"### Input:\n{inp}\n\n" if inp and str(inp).strip() else ""
        source_text = PROMPT_TEMPLATE.format(
            instruction=instruction, input_section=input_section
        )
        full_text = source_text + target_output + tokenizer.eos_token
        tokenized_full = tokenizer(full_text, truncation=True,
                                   max_length=main_cfg.model.cutoff_len, padding=False)
        if not main_cfg.model.train_on_inputs:
            tokenized_source = tokenizer(source_text, truncation=True,
                                         max_length=main_cfg.model.cutoff_len, padding=False)
            source_len = len(tokenized_source["input_ids"])
            # Ensure we don't mask the entire sequence
            labels = [-100] * source_len + tokenized_full["input_ids"][source_len:]
            tokenized_full["labels"] = labels
        else:
            # FIX: labels were never set when train_on_inputs is True, leaving
            # the batch without a "labels" column; train on the full sequence.
            tokenized_full["labels"] = tokenized_full["input_ids"].copy()
        return tokenized_full

    # outdated
    def generate_and_tokenize_prompt3(data_point):
        """
        Standardizes training data to match Eval template and handles label masking.
        """
        instruction = data_point.get("instruction", "")
        inp = data_point.get("input", "")
        output = data_point.get("output", "")  # The target we want to train on
        # 1. Format Input Section
        if inp and str(inp).strip():
            input_section = f"### Input:\n{inp}\n\n"
        else:
            input_section = ""
        # 2. Build Source (Prompt) and Full Text
        source_text = PROMPT_TEMPLATE.format(
            instruction=instruction, input_section=input_section
        )
        full_text = source_text + output + tokenizer.eos_token
        # 3. Tokenize
        tokenized_full = tokenizer(
            full_text,
            truncation=True,
            max_length=main_cfg.model.cutoff_len,
            padding=False,
        )
        # 4. Handle Labels (Masking the Instruction part)
        # Only calculate loss on the 'output' part
        if not training_cfg.train_on_inputs:
            tokenized_source = tokenizer(
                source_text,
                truncation=True,
                max_length=main_cfg.model.cutoff_len,
                padding=False,
            )
            source_len = len(tokenized_source["input_ids"])
            # Mask prompt tokens with -100 so they are ignored by CrossEntropyLoss
            tokenized_full["labels"] = [
                -100 if i < source_len else token_id
                for i, token_id in enumerate(tokenized_full["input_ids"])
            ]
        else:
            tokenized_full["labels"] = tokenized_full["input_ids"].copy()
        return tokenized_full

    if data_cfg.data_path.endswith(".json"):
        data = load_dataset("json", data_files=data_cfg.data_path)
    else:
        data = load_dataset(data_cfg.data_path)

    ### Check later
    if training_cfg.resume_from_checkpoint:
        # FIX: 'resume_from_checkpoint' was referenced without ever being
        # assigned (NameError); bind it from the training config first.
        resume_from_checkpoint = training_cfg.resume_from_checkpoint
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were
        # saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            model = IbaXs_LlamaModel.from_pretrained("./my-saved-model")
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if main_cfg.data.val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=main_cfg.data.val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].map(generate_and_tokenize_prompt, num_proc=8)
        )
        val_data = (
            train_val["test"].map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt, num_proc=8)
        val_data = None
    # FIX: len(None) raised TypeError when val_set_size == 0.
    print('data size', len(train_data), len(val_data) if val_data is not None else 0)
    # print('val data', type(val_data), val_data)
    # for k, v in val_data[0].items():
    #     print('kv', k, ': ', v)
    # exit()
    # count_parameters(model)

    # Gradient debug
    # inspector = GradientInspector()
    # inspector.register_hooks(model)

    start_time = datetime.now()
    date_str = start_time.strftime("%dd%Hh%Mm%S")
    # Encode the run's hyperparameters into the output directory name.
    output_dir = f'{trainer_args.output_dir}/{main_cfg.data.dataset_name}/'\
                 f't={date_str},' \
                 f'mlr{trainer_args.learning_rate:.1e},'\
                 f'b{trainer_args.per_device_train_batch_size},'\
                 f'r{main_cfg.hyperxs.lora_attn_dim},n_ct{main_cfg.hyperxs.n_cross_attn_tokens},'\
                 f't{date_str},' \
                 f'init{main_cfg.run_text},dr{main_cfg.hyperxs.drop_out},'\
                 f'ep{trainer_args.num_train_epochs},' \
                 f'ds{len(train_data)}'
    trainer_args.output_dir = output_dir
    print(f'Current output_dir: {output_dir}')
    # trainer_args.run_name = f'[{next_run_num}]'\
    #     f't={date_str}', \
    #     f'mlr{trainer_args.learning_rate:.1e},'\
    #     f'b{trainer_args.per_device_train_batch_size},'\
    #     f'r{main_cfg.hyperxs.lora_attn_dim},n_ct{main_cfg.hyperxs.n_cross_attn_tokens},'\
    #     f't{date_str},' \
    #     f'init={main_cfg.run_text},dr{main_cfg.hyperxs.drop_out},'\
    #     f'ep{trainer_args.num_train_epochs},' \
    #     f'ds={len(train_data)}'
    # print('Run nume: ', trainer_args.run_name)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=trainer_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    # Cache is incompatible with gradient checkpointing / training.
    model.config.use_cache = False
    # trainer.train(resume_from_checkpoint=training_cfg.resume_from_checkpoint)
    trainer.train()
    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"),
          '| duration: ', end_time - start_time)
    tokenizer.save_pretrained(os.path.join(trainer_args.output_dir, 'ft'))
    trainer.save_state()
    config.save_pretrained(os.path.join(trainer_args.output_dir, 'ft'))
    model.save_pretrained(os.path.join(trainer_args.output_dir, 'ft2'), safe_serialization=False)
    # inspector.clear_hooks()


@draccus.wrap(config_path="./config_draccus/config.yaml")
def main(main_cfg: MainConfig):
    """Entry point: load base-model config, attach run config, seed, train."""
    # print('Hello\n', main_cfg)
    main_cfg_dict = asdict(main_cfg)
    # print(yaml.dump(main_cfg_dict, indent=2, default_flow_style=False))
    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name,
        # attn_implementation="eager",
    )
    # Tiny-model debug overrides:
    # config.hidden_size = 128
    # config.intermediate_size = 290
    # config.num_hidden_layers = 3
    # config._attn_implementation = "eager"
    # config.head_dim = config.hidden_size // config.num_attention_heads
    # main_cfg_dict = asdict(main_cfg)
    config.main_cfg = main_cfg_dict
    set_seed(main_cfg.seed)
    trainIBA(config, main_cfg)


if __name__ == "__main__":
    main()