| | import torch |
| | import os |
| | import json |
| | from transformers import AutoModelForCausalLM |
| |
|
| |
|
def extract_and_merge_instruction_residual(
    instruction_model_dir,
    base_model_dir,
    output_dir,
):
    """Extract the instruction-tuning residual (instruct - base) at full precision.

    Loads both models on CPU in float32, computes the per-tensor difference
    ``instruction - base`` for every parameter the two models share, and saves
    the result as a pseudo-adapter (``adapter_model.bin`` + ``adapter_config.json``)
    under ``output_dir``. Despite the name, this function only extracts and
    saves the residual; it does not merge it back into any model.

    Args:
        instruction_model_dir: Path or hub id of the instruction-tuned model.
        base_model_dir: Path or hub id of the base (pre-trained) model.
        output_dir: Directory in which ``instruction_residual_adapter/`` is created.

    Returns:
        None. Side effects: creates ``output_dir`` (and the adapter subdirectory)
        and writes the residual weights and a JSON config there.
    """
    # Load both models in float32 on CPU so the subtraction is lossless.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True
    )

    instruction_model = AutoModelForCausalLM.from_pretrained(
        instruction_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True
    )

    base_state_dict = base_model.state_dict()
    instruction_state_dict = instruction_model.state_dict()

    # The state dicts hold references to the parameter tensors, so the model
    # wrappers themselves are no longer needed; drop them to reduce live objects
    # before building the residual dict.
    del base_model
    del instruction_model

    # Residual = instruction - base, computed per shared parameter.
    residual_state_dict = {}
    for key in base_state_dict:
        if key in instruction_state_dict:
            residual_state_dict[key] = (instruction_state_dict[key] - base_state_dict[key]).to(torch.float32)
        else:
            print(f"Warning: Key {key} not found in instruction model state dict")

    # Symmetric check: warn about parameters that exist only in the instruction
    # model — previously these were silently dropped from the residual.
    for key in instruction_state_dict:
        if key not in base_state_dict:
            print(f"Warning: Key {key} not found in base model state dict")

    os.makedirs(output_dir, exist_ok=True)

    adapter_path = os.path.join(output_dir, "instruction_residual_adapter")
    os.makedirs(adapter_path, exist_ok=True)
    torch.save(residual_state_dict, os.path.join(adapter_path, "adapter_model.bin"))

    # Minimal adapter-style config. NOTE(review): this is not a real LoRA adapter
    # (the residual is a full-rank weight delta); the lora_* fields are kept only
    # for loader compatibility with the original output format.
    adapter_config = {
        "adapter_type": "instruction_residual",
        "base_model_name_or_path": base_model_dir,
        "target_modules": ["all"],
        "lora_alpha": 1.0,
        "lora_dropout": 0.0,
        "task_type": "CAUSAL_LM"
    }

    with open(os.path.join(adapter_path, "adapter_config.json"), "w") as f:
        json.dump(adapter_config, f, indent=4)

    print(f"✅ Full-precision (float32) instruction residual adapter saved to {adapter_path}")
|
if __name__ == "__main__":
    # Local paths: the values are model/output directories, named accordingly.
    instruction_dir = "/workspace/meta-llama/Llama-3.2-3B-Instruct"
    base_dir = "/workspace/meta-llama/Llama-3.2-3B"
    residual_dir = "/workspace/Llama-3.2-3B-Lr"

    extract_and_merge_instruction_residual(instruction_dir, base_dir, residual_dir)
|