# Viewer metadata: file size 2,237 bytes; 7c31071 (presumably the commit hash).
import torch
import os
import json
from transformers import AutoModelForCausalLM
def _compute_residual_state_dict(base_state_dict, instruction_state_dict):
    """Return ``instruction - base`` (as float32) for every tensor both dicts share.

    Keys that exist in only one of the two state dicts, and keys whose tensors
    have different shapes, are reported with a warning and left out of the
    result.
    """
    residual_state_dict = {}
    for key, base_tensor in base_state_dict.items():
        if key not in instruction_state_dict:
            print(f"Warning: Key {key} not found in instruction model state dict")
            continue

        instruction_tensor = instruction_state_dict[key]
        # Differing shapes would either raise deep inside torch or, worse,
        # broadcast silently into a residual of the wrong shape.
        if instruction_tensor.shape != base_tensor.shape:
            print(
                f"Warning: Shape mismatch for key {key}: "
                f"instruction {tuple(instruction_tensor.shape)} vs "
                f"base {tuple(base_tensor.shape)}; skipping"
            )
            continue

        # Cast the operands (not the result) so the subtraction itself is
        # carried out in float32 whatever dtype the tensor is stored in.
        residual_state_dict[key] = (
            instruction_tensor.to(torch.float32) - base_tensor.to(torch.float32)
        )

    # Report the opposite direction too, so dropped tensors are never silent.
    for key in instruction_state_dict:
        if key not in base_state_dict:
            print(f"Warning: Key {key} not found in base model state dict")

    return residual_state_dict


def extract_and_merge_instruction_residual(
    instruction_model_dir: str,
    base_model_dir: str,
    output_dir: str,
) -> str:
    """Extract the instruction residual (instruct weights minus base weights).

    Both checkpoints are loaded on CPU in float32, the per-tensor difference
    is computed, and the result is written to
    ``<output_dir>/instruction_residual_adapter`` as ``adapter_model.bin``
    (via ``torch.save``) together with an ``adapter_config.json``.

    Despite the name, nothing is merged here: this function only extracts and
    saves the residual.

    Args:
        instruction_model_dir: Path or model id of the instruction-tuned model.
        base_model_dir: Path or model id of the base model. It is also written
            verbatim into the adapter config as ``base_model_name_or_path``.
        output_dir: Directory under which the adapter folder is created.

    Returns:
        The path of the adapter directory that was written.
    """
    # NOTE(security): trust_remote_code=True lets transformers execute Python
    # code shipped with the checkpoint; only point this at trusted models.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    instruction_model = AutoModelForCausalLM.from_pretrained(
        instruction_model_dir,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )

    residual_state_dict = _compute_residual_state_dict(
        base_model.state_dict(),
        instruction_model.state_dict(),
    )

    # makedirs creates output_dir as well when it does not exist yet.
    adapter_path = os.path.join(output_dir, "instruction_residual_adapter")
    os.makedirs(adapter_path, exist_ok=True)
    torch.save(residual_state_dict, os.path.join(adapter_path, "adapter_model.bin"))

    # Adapter config
    adapter_config = {
        "adapter_type": "instruction_residual",
        "base_model_name_or_path": base_model_dir,
        "target_modules": ["all"],
        "lora_alpha": 1.0,
        "lora_dropout": 0.0,
        "task_type": "CAUSAL_LM",
    }
    config_path = os.path.join(adapter_path, "adapter_config.json")
    # Explicit encoding keeps the file identical across platforms/locales.
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=4)

    print(f"✅ Full-precision (float32) instruction residual adapter saved to {adapter_path}")
    return adapter_path
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    # Every option defaults to the path that used to be hard-coded, so running
    # the script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description=(
            "Extract the instruction residual (instruction-tuned weights minus "
            "base weights) and save it as an adapter directory."
        )
    )
    parser.add_argument(
        "--instruction-model-dir",
        default="/workspace/meta-llama/Llama-3.2-3B-Instruct",
        help="Instruction-tuned model to load (default: %(default)s)",
    )
    parser.add_argument(
        "--base-model-dir",
        default="/workspace/meta-llama/Llama-3.2-3B",
        help="Base model to load (default: %(default)s)",
    )
    parser.add_argument(
        "--output-dir",
        default="/workspace/Llama-3.2-3B-Lr",
        help="Directory that receives the adapter folder (default: %(default)s)",
    )
    args = parser.parse_args()

    extract_and_merge_instruction_residual(
        args.instruction_model_dir,
        args.base_model_dir,
        args.output_dir,
    )