| | import torch |
| | import yaml |
| | from dataclasses import asdict |
| | import draccus |
| |
|
| | from datasets import load_dataset |
| |
|
| | import os |
| | import transformers |
| | from transformers import (AutoModelForCausalLM, AutoTokenizer, |
| | LlamaTokenizer, AutoModel, AutoConfig, |
| | TrainingArguments) |
| | import inspect |
| | from transformers import logging as hf_logging |
| |
|
| | import random |
| | import numpy as np |
| | from datetime import datetime |
| |
|
| | |
| | |
| | |
| | from iba import (IbaXs_LlamaModel, IbaXs_LlamaForCausalLM, |
| | HyperNetXSexp, |
| | count_parameters, MainConfig, mark_iba_as_trainable_only |
| | ) |
| |
|
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaMLP, |
| | LlamaAttention, |
| | LlamaDecoderLayer, |
| | LlamaModel, |
| | LlamaForCausalLM |
| | ) |
| |
|
# Alpaca-style prompt template. `{input_section}` is either an empty string or a
# pre-formatted "### Input:\n...\n\n" chunk supplied by the caller, so prompts
# with and without an input share a single template.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n{input_section}"
    "### Response:\n"
)

# Target device used by the helpers in this script.
DEVICE = 'cuda'
| | |
| |
|
def set_seed(seed: int):
    """Seed every RNG source used in this script for reproducible runs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    ):
        seeder(seed)
| |
|
def test_generate(config, main_cfg):
    """Smoke-test generation: build an IbaXs model from `config` and sample a few tokens."""
    base_model_name = main_cfg.model.base_model_name

    # Tokenizer selection mirrors training: Llama-3 ships a fast AutoTokenizer,
    # older Llama checkpoints need the legacy slow LlamaTokenizer.
    if config.model_type == 'llama':
        if "lama-3" in base_model_name:
            print("load llama-3 tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        else:
            tokenizer = LlamaTokenizer.from_pretrained(base_model_name, legacy=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    # NOTE(review): the model is constructed from config only (custom parts are
    # randomly initialized) — this exercises the generate() plumbing, not quality.
    model = IbaXs_LlamaForCausalLM(config=config).to(DEVICE)
    model.eval()

    prompts = [
        "The capital of France is",
    ]
    for idx, text in enumerate(prompts, start=1):
        print(f"\n--- Prompt {idx} ---")
        print(f"Input: {text}")

        encoded = tokenizer(text, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=4,
                do_sample=True,
                temperature=0.7,
                top_k=50
            )

        # Strip the prompt tokens and decode only the newly generated tail.
        new_tokens = generated[0][encoded["input_ids"].shape[1]:]
        decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
        print(f"Output: {decoded}")
| |
|
def get_hyper_model(config, base_model_name):
    """Build an IbaXs Llama and initialize its shared weights from the base checkpoint.

    Loads the reference LlamaForCausalLM on CPU, copies every matching tensor
    into the custom model (strict=False so IBA-specific modules keep their fresh
    init), reports key mismatches, and returns the custom model (still on CPU).
    """
    with torch.no_grad():
        # Construct everything on CPU so the temporary second copy of the
        # weights never touches GPU memory.
        torch.set_default_device('cpu')
        model = IbaXs_LlamaForCausalLM(config=config)

        transformers.logging.set_verbosity_error()  # silence benign key/size warnings
        base_model_temp = LlamaForCausalLM.from_pretrained(
            base_model_name,
            config=config,
            device_map=None,
            low_cpu_mem_usage=False,
            torch_dtype=torch.float32
        )
        transformers.logging.set_verbosity_warning()

        missing_keys, unexpected_keys = model.load_state_dict(
            base_model_temp.state_dict(), strict=False
        )
        # Fix: the temporary reference model was previously moved to DEVICE
        # right before deletion — a needless GPU transfer and memory spike.
        del base_model_temp
        torch.cuda.empty_cache()

    if missing_keys:
        print('missing_keys:')
        for key in missing_keys:
            # Keys belonging to the IBA hypernetwork are expected to be absent
            # from the base checkpoint; report only genuinely missing base weights.
            if 'layers' in key and 'hypernetxs' not in key and 'layer_idx_hyperxs' not in key:
                print(f" missing: [x] {key}")
    else:
        print("\n>>> No missing keys.")

    if unexpected_keys:
        print('unexpected_keys:')  # fix: header was missing (unlike missing_keys branch)
        for key in unexpected_keys:
            print(f" [?] {key}")
    else:
        print("\n>>> No unexpected keys.")
    return model
def compare_models(custom_model, ref_model, base_model_name, device="cuda"):
    """
    Compares logits between the custom IbaXs model and the original Llama 2.

    Runs one fixed prompt through both models under no_grad, prints max/mean
    absolute logit differences and a sample of last-token logits, then drops
    the local reference to `ref_model` and empties the CUDA cache.
    REMEMBER: SET VALID SIZE = 1
    """
    def setup_precise_gpu_environment():
        """
        Configures PyTorch to prioritize numerical precision over speed on GPU.
        This helps in matching GPU results with CPU results for debugging purposes.
        """
        # Disable TF32 so matmuls use full fp32 accumulation.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

        # Deterministic cuDNN kernel selection (no autotuning).
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        print(">> GPU Precision Setup: TF32 Disabled. Deterministic Mode set (partial).")
    setup_precise_gpu_environment()

    print(f"\n--- Starting Comparison on {device} {custom_model.dtype} {ref_model.dtype}---")

    # Both models must be in eval mode so dropout/norm statistics are frozen.
    ref_model.eval()
    custom_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text = "Hello, this is a test for model comparison."
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # The two models may live on different devices; move inputs for the reference.
    ref_inputs = inputs.to(ref_model.device)

    with torch.no_grad():
        print("Running inference on Custom Model...")
        logits_custom = custom_model(**inputs).logits

        print("Running inference on Reference Model...")
        logits_ref = ref_model(**ref_inputs).logits

    # Compare on CPU so the diff itself is not affected by device placement.
    diff = (logits_custom.cpu() - logits_ref.cpu()).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()

    print("\n--- Comparison Results ---")
    print(f"Max Absolute Difference: {max_diff:.6f}")
    print(f"Mean Absolute Difference: {mean_diff:.6f}")

    print("\nFirst 5 logits (Last Token):")
    print(f"Custom: {logits_custom[0, -1, :5].cpu().tolist()}")
    print(f"Ref : {logits_ref[0, -1, :5].cpu().tolist()}")

    if max_diff < 1e-3:
        print(">> VERDICT: Models are effectively IDENTICAL.")
    else:
        print(">> VERDICT: Models are DIFFERENT (Expected if custom layers are random initialized).")

    # NOTE(review): `del ref_model` only drops the local name; the caller's
    # reference (if any) keeps the model alive — confirm intent.
    del ref_model
    torch.cuda.empty_cache()
| |
|
class GradientInspector:
    """
    A debugging tool to attach hooks to PyTorch modules.
    It prints the gradient norm flowing through specific layers during backward pass.
    """

    def __init__(self):
        # Handles returned by register_full_backward_hook, kept so
        # clear_hooks() can detach them later.
        self.hooks = []

    def print_grad_stats(self, module, grad_input, grad_output):
        """
        Callback function triggered during backward pass.

        Prints the norm of the gradient arriving from upstream (grad_output)
        and of the gradient passed downstream (grad_input); flags the layer
        when the downstream gradient norm is exactly zero.
        """
        # Local import: tqdm.write prints without mangling active progress bars.
        from tqdm import tqdm

        name = getattr(module, 'debug_name', 'Unknown Layer')

        if grad_output[0] is not None:
            out_norm = grad_output[0].norm().item()
            tqdm.write(f"[DEBUG-BACKWARD] {name} | Output Grad Norm (from upstream): {out_norm:.6f}")
        else:
            tqdm.write(f"[DEBUG-BACKWARD] {name} | Output Grad is None!")

        if grad_input[0] is not None:
            in_norm = grad_input[0].norm().item()
            msg = (f"[DEBUG-BACKWARD] {name} | Input Grad Norm (passing downstream): {in_norm:.6f}")

            tqdm.write(msg)

            if in_norm == 0:
                tqdm.write(f" >>> ALARM: Gradient died at {name}!")
        else:
            # Some modules legitimately have no input gradient; stay silent.
            pass

    def register_hooks(self, model):
        from tqdm import tqdm
        """
        Recursively attach hooks to important modules.
        """
        tqdm.write("Registering debug hooks...")

        # Hook the hypernetwork head (and its last linear projection, if any).
        if hasattr(model.model, 'hypernetxs'):
            model.model.hypernetxs.debug_name = "HyperNetwork_Top"

            handle = model.model.hypernetxs.register_full_backward_hook(self.print_grad_stats)
            self.hooks.append(handle)

            if hasattr(model.model.hypernetxs, 'c_proj'):
                last_layer = model.model.hypernetxs.c_proj
                last_layer.debug_name = "HyperNetwork_Last_Linear"
                handle = last_layer.register_full_backward_hook(self.print_grad_stats)
                self.hooks.append(handle)

        # Hook only the FIRST Linear module encountered in traversal order;
        # `count` keeps counting so the condition stays a one-shot.
        count = 0
        for name, module in model.named_modules():
            if "Linear" in str(type(module)):
                if count == 0:
                    module.debug_name = f"DynamicLayer_First_{name}"
                    handle = module.register_full_backward_hook(self.print_grad_stats)
                    self.hooks.append(handle)

                count += 1

        print(f"Registered {len(self.hooks)} hooks.")

    def clear_hooks(self):
        # Detach every registered hook.
        for h in self.hooks:
            h.remove()
| |
|
def reset_trainable_modules(model):
    """Re-initialize every trainable IBA module in-place and return the model."""
    for module_name, sub_module in model.named_modules():
        if isinstance(sub_module, (HyperNetXSexp, IbaXs_LlamaModel)):
            if hasattr(sub_module, 'reset_parameters'):
                sub_module.reset_parameters()
                print('reset: ', module_name)
    return model
| | |
| | |
def trainIBA(config, main_cfg):
    """Fine-tune the IBA hypernetwork modules of a Llama model on an
    instruction dataset.

    Only IBA modules are left trainable; the base weights stay frozen.

    Args:
        config: the (augmented) HF model config for the IbaXs model.
        main_cfg: the draccus MainConfig for this run (model/data/training/hyperxs).
    """
    training_cfg = main_cfg.training
    data_cfg = main_cfg.data

    # Keep only the fields of our training config that TrainingArguments accepts.
    valid_hf_arg_names = set(inspect.signature(TrainingArguments).parameters.keys())
    training_config_dict = asdict(training_cfg)
    filtered_trainer_args_dict = {
        key: value for key, value in training_config_dict.items()
        if key in valid_hf_arg_names
    }
    trainer_args = TrainingArguments(**filtered_trainer_args_dict)

    gradient_accumulation_steps = training_cfg.gradient_accumulation_steps

    # Distributed setup: under DDP each rank gets its share of accumulation steps.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    base_model_name = main_cfg.model.base_model_name

    # One-time export path (intentionally disabled): build the hyper-model from
    # the base checkpoint, save the SVD-initialized weights, and exit.
    if False:
        model = get_hyper_model(config=config, base_model_name=base_model_name)
        mark_iba_as_trainable_only(model)
        count_parameters(model)
        model.reset_BA_xslora()
        model.save_pretrained('./SVD64_llama2', safe_serialization=False)
        exit()
    else:
        # Normal path: load the pre-exported SVD-initialized checkpoint.
        hf_logging.set_verbosity_error()
        model = IbaXs_LlamaForCausalLM.from_pretrained(
            './SVD64_llama2',
            device_map="auto",
            dtype=torch.bfloat16,
            config=config,
            local_files_only=True,
            ignore_mismatched_sizes=True
        )
        hf_logging.set_verbosity_warning()

        model = reset_trainable_modules(model)
        mark_iba_as_trainable_only(model)
        count_parameters(model)

    # Tokenizer selection: Llama-3 ships a fast AutoTokenizer, older Llama
    # checkpoints need the legacy slow LlamaTokenizer.
    if config.model_type == 'llama':
        if "lama-3" in base_model_name:
            print("load llama-3 tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        else:
            tokenizer = LlamaTokenizer.from_pretrained(base_model_name, legacy=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    # Pad with token id 0 so padding is distinguishable from eos.
    tokenizer.pad_token_id = (
        0
    )
    tokenizer.padding_side = "left"

    def tokenize(prompt, max_length=main_cfg.model.cutoff_len, add_eos_token=True):
        """Tokenize `prompt`, appending eos when it fits, and mirror ids into labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,  # fix: previously hard-coded, ignoring the parameter
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < max_length
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            if "chatglm" not in base_model_name:
                result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        if "chatglm" in base_model_name:
            # chatglm's collator rejects extra keys such as attention_mask.
            return {"input_ids": result["input_ids"], "labels": result["labels"]}
        else:
            return result

    def generate_and_tokenize_prompt(data_point):
        """Build the Alpaca-style prompt for one sample and tokenize it.

        Prompt tokens are masked from the loss (-100) unless train_on_inputs is set.
        """
        instruction = data_point.get("instruction", "")
        inp = data_point.get("input", "")
        target_output = data_point.get("output", "")

        input_section = f"### Input:\n{inp}\n\n" if inp and str(inp).strip() else ""

        source_text = PROMPT_TEMPLATE.format(
            instruction=instruction,
            input_section=input_section
        )
        full_text = source_text + target_output + tokenizer.eos_token

        tokenized_full = tokenizer(full_text, truncation=True, max_length=main_cfg.model.cutoff_len, padding=False)

        if not main_cfg.model.train_on_inputs:
            tokenized_source = tokenizer(source_text, truncation=True, max_length=main_cfg.model.cutoff_len, padding=False)
            source_len = len(tokenized_source["input_ids"])
            # Mask the prompt part so the loss covers only the response.
            labels = [-100] * source_len + tokenized_full["input_ids"][source_len:]
            tokenized_full["labels"] = labels
        else:
            # fix: labels were never set when training on inputs, which would
            # break the Trainer's loss computation.
            tokenized_full["labels"] = tokenized_full["input_ids"].copy()

        return tokenized_full

    # NOTE(review): currently unused alternative that reads train_on_inputs from
    # training_cfg instead of main_cfg.model — kept for reference.
    def generate_and_tokenize_prompt3(data_point):
        """
        Standardizes training data to match Eval template and handles label masking.
        """
        instruction = data_point.get("instruction", "")
        inp = data_point.get("input", "")
        output = data_point.get("output", "")

        if inp and str(inp).strip():
            input_section = f"### Input:\n{inp}\n\n"
        else:
            input_section = ""

        source_text = PROMPT_TEMPLATE.format(
            instruction=instruction,
            input_section=input_section
        )
        full_text = source_text + output + tokenizer.eos_token

        tokenized_full = tokenizer(
            full_text,
            truncation=True,
            max_length=main_cfg.model.cutoff_len,
            padding=False,
        )

        if not training_cfg.train_on_inputs:
            tokenized_source = tokenizer(
                source_text,
                truncation=True,
                max_length=main_cfg.model.cutoff_len,
                padding=False,
            )
            source_len = len(tokenized_source["input_ids"])

            tokenized_full["labels"] = [
                -100 if i < source_len else token_id
                for i, token_id in enumerate(tokenized_full["input_ids"])
            ]
        else:
            tokenized_full["labels"] = tokenized_full["input_ids"].copy()

        return tokenized_full

    if data_cfg.data_path.endswith(".json"):
        data = load_dataset("json", data_files=data_cfg.data_path)
    else:
        data = load_dataset(data_cfg.data_path)

    if training_cfg.resume_from_checkpoint:
        # fix: this block previously referenced an undefined local
        # `resume_from_checkpoint` (NameError); bind it from the config first.
        resume_from_checkpoint = training_cfg.resume_from_checkpoint
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )
            # Adapter-only checkpoints do not support Trainer resume here.
            resume_from_checkpoint = (
                False
            )

        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            model = IbaXs_LlamaModel.from_pretrained("./my-saved-model")
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if main_cfg.data.val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=main_cfg.data.val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].map(generate_and_tokenize_prompt, num_proc=8)
        )
        val_data = (
            train_val["test"].map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt, num_proc=8)
        val_data = None
    # fix: len(None) raised TypeError when no validation split was configured.
    print('data size', len(train_data), len(val_data) if val_data is not None else 0)

    # Encode run hyper-parameters into the output directory name.
    start_time = datetime.now()
    date_str = start_time.strftime("%dd%Hh%Mm%S")
    output_dir = f'{trainer_args.output_dir}/{main_cfg.data.dataset_name}/'\
                 f't={date_str},' \
                 f'mlr{trainer_args.learning_rate:.1e},'\
                 f'b{trainer_args.per_device_train_batch_size},'\
                 f'r{main_cfg.hyperxs.lora_attn_dim},n_ct{main_cfg.hyperxs.n_cross_attn_tokens},'\
                 f't{date_str},' \
                 f'init{main_cfg.run_text},dr{main_cfg.hyperxs.drop_out},'\
                 f'ep{trainer_args.num_train_epochs},' \
                 f'ds{len(train_data)}'

    trainer_args.output_dir = output_dir
    print(f'Current output_dir: {output_dir}')

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=trainer_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    # KV-cache is incompatible with training/gradient checkpointing.
    model.config.use_cache = False

    trainer.train()
    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)

    # Persist tokenizer/config/trainer state under 'ft' and weights under 'ft2'.
    tokenizer.save_pretrained(os.path.join(trainer_args.output_dir, 'ft'))
    trainer.save_state()
    config.save_pretrained(os.path.join(trainer_args.output_dir, 'ft'))
    model.save_pretrained(os.path.join(trainer_args.output_dir, 'ft2'), safe_serialization=False)
| | |
| |
|
| | |
@draccus.wrap(config_path="./config_draccus/config.yaml")
def main(main_cfg: MainConfig):
    """Entry point: load the HF config, attach the run config, seed, and train."""
    cfg_as_dict = asdict(main_cfg)

    config = AutoConfig.from_pretrained(
        main_cfg.model.base_model_name,
    )

    # Stash the full run configuration on the model config so it gets saved
    # alongside the checkpoint by config.save_pretrained.
    config.main_cfg = cfg_as_dict
    set_seed(main_cfg.seed)
    trainIBA(config, main_cfg)
| |
|
| | |
| |
|
# Standard script entry guard; draccus parses CLI/YAML config inside main().
if __name__ == "__main__":
    main()
| |
|