import json
import os

import torch
from transformers import AutoModelForCausalLM


def _compute_residual(base_state_dict, instruction_state_dict):
    """Return ``{key: instruction - base}`` in float32 for every comparable key.

    A key is comparable when it exists in both state dicts and the two
    tensors have identical shapes. Anything else is reported with a warning
    and left out of the result.
    """
    residual_state_dict = {}
    for key, base_tensor in base_state_dict.items():
        if key not in instruction_state_dict:
            print(f"Warning: Key {key} not found in instruction model state dict")
            continue

        instruction_tensor = instruction_state_dict[key]
        # Guard against silent broadcasting (e.g. (1, d) vs (n, d)) or an
        # opaque RuntimeError when the two checkpoints disagree on a shape.
        if instruction_tensor.shape != base_tensor.shape:
            print(
                f"Warning: Shape mismatch for key {key}: "
                f"instruction {tuple(instruction_tensor.shape)} vs "
                f"base {tuple(base_tensor.shape)}; skipping"
            )
            continue

        residual_state_dict[key] = (instruction_tensor - base_tensor).to(torch.float32)

    # Parameters that exist only in the instruction model have no base
    # counterpart to subtract, so they cannot be part of the residual.
    for key in instruction_state_dict:
        if key not in base_state_dict:
            print(f"Warning: Key {key} not found in base model state dict")

    return residual_state_dict


def extract_and_merge_instruction_residual(
    instruction_model_dir,
    base_model_dir,
    output_dir,
):
    """Extract the instruction residual (instruction - base) and save it.

    Both checkpoints are loaded on CPU as float32, the per-parameter
    difference is computed in float32, and the result is written to
    ``<output_dir>/instruction_residual_adapter`` as ``adapter_model.bin``
    (via ``torch.save``) next to an ``adapter_config.json`` describing it.

    Despite the name, this function only extracts and saves the residual;
    it does not merge it into any model.

    Args:
        instruction_model_dir: Path or identifier of the instruction-tuned
            model, passed to ``AutoModelForCausalLM.from_pretrained``.
        base_model_dir: Path or identifier of the base model; also recorded
            in the adapter config as ``base_model_name_or_path``.
        output_dir: Directory under which the adapter folder is created.
    """
    # Load models.
    # NOTE(review): trust_remote_code=True lets the checkpoint directory
    # supply custom modelling code -- only use with trusted checkpoints.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    instruction_model = AutoModelForCausalLM.from_pretrained(
        instruction_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )

    base_state_dict = base_model.state_dict()
    instruction_state_dict = instruction_model.state_dict()

    # Compute the float32 residual for every parameter the two models share.
    residual_state_dict = _compute_residual(base_state_dict, instruction_state_dict)

    os.makedirs(output_dir, exist_ok=True)
    adapter_path = os.path.join(output_dir, "instruction_residual_adapter")
    os.makedirs(adapter_path, exist_ok=True)
    torch.save(residual_state_dict, os.path.join(adapter_path, "adapter_model.bin"))

    # Adapter config.
    # NOTE(review): the lora_* fields appear to be placeholders, since the
    # saved tensors are full weight deltas rather than low-rank factors --
    # confirm against whatever consumes this file.
    adapter_config = {
        "adapter_type": "instruction_residual",
        "base_model_name_or_path": base_model_dir,
        "target_modules": ["all"],
        "lora_alpha": 1.0,
        "lora_dropout": 0.0,
        "task_type": "CAUSAL_LM",
    }

    config_file = os.path.join(adapter_path, "adapter_config.json")
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=4)

    print(f"✅ Full-precision (float32) instruction residual adapter saved to {adapter_path}")


if __name__ == "__main__":
    instruction_model_file = "/workspace/meta-llama/Llama-3.2-3B-Instruct"
    base_model_file = "/workspace/meta-llama/Llama-3.2-3B"
    residual_output_file = "/workspace/Llama-3.2-3B-Lr"

    extract_and_merge_instruction_residual(
        instruction_model_file,
        base_model_file,
        residual_output_file,
    )