# workspace_Feb13 / mergeLr.py
# (Uploaded by Linksome via the upload-large-folder tool, commit 7c31071, verified)
import os
import shutil
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
def _load_model_fp32(model_dir: str):
    """Load a causal-LM checkpoint on CPU with fp32 weights.

    transformers releases disagree on the dtype keyword: newer ones accept
    ``dtype=`` (and warn on ``torch_dtype=``), older ones only know
    ``torch_dtype=``.  Try the modern spelling first and fall back on
    TypeError so both versions work.
    """
    shared_kwargs = dict(device_map="cpu", trust_remote_code=True)
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir, dtype=torch.float32, **shared_kwargs
        )
    except TypeError:
        return AutoModelForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch.float32, **shared_kwargs
        )
def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    """
    Merge an instruction residual into a (possibly vocab-resized) CPT model.

    Loads the residual state dict from ``lr_dir/adapter_model.bin``, adds it
    to the base model's weights using fp32 arithmetic, then saves the merged
    model in bf16 (plus config and tokenizer files) to ``output_dir``.

    If the vocab was resized after the residual was computed, the residual is
    added only for the overlapping token rows and extra rows (new tokens) are
    kept unchanged.  Any other shape mismatch raises instead of silently
    corrupting weights.

    Args:
        lr_dir: directory containing ``adapter_model.bin`` (the residual).
        base_model_dir: directory of the base (CPT) model to merge into.
        output_dir: destination directory for the merged model.

    Raises:
        FileNotFoundError: if the adapter checkpoint is missing.
        RuntimeError: on a shape mismatch that is not a simple vocab resize.
    """
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    # weights_only=True blocks arbitrary-code-execution via pickle when
    # loading an untrusted checkpoint.  Older torch versions don't accept the
    # kwarg, so fall back on TypeError (same idiom as _load_model_fp32).
    try:
        residual_state_dict = torch.load(
            adapter_file, map_location="cpu", weights_only=True
        )
    except TypeError:
        residual_state_dict = torch.load(adapter_file, map_location="cpu")

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)
    base_state_dict = base_model.state_dict()

    merged_state_dict = {}
    mismatched = []
    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            # No residual for this weight → keep the base tensor untouched.
            merged_state_dict[key] = base_tensor
            continue
        res_tensor = residual_state_dict[key]
        # Exact match → normal add.
        if base_tensor.shape == res_tensor.shape:
            merged_state_dict[key] = (base_tensor + res_tensor).to(torch.float32)
            continue
        # Common case: vocab resized → dim0 differs, rest matches.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged_state_dict[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue
        # Anything else is suspicious → don’t silently corrupt.
        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")

    # Residual keys the base model doesn't have would otherwise be dropped
    # silently — surface them so a bad adapter/base pairing is noticed.
    unused_keys = sorted(set(residual_state_dict) - set(base_state_dict))
    if unused_keys:
        print(f"\n⚠️ Skipped {len(unused_keys)} residual key(s) absent from the base model:")
        for k in unused_keys[:20]:
            print(f" - {k}")
        if len(unused_keys) > 20:
            print(f" ... and {len(unused_keys) - 20} more")

    # Load merged weights back.
    base_model.load_state_dict(merged_state_dict, strict=True)

    # Save as bf16.
    base_model = base_model.to(torch.bfloat16)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    # Save config (optional; save_pretrained usually does it, but keeping the intent).
    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)

    # Best way to keep tokenizer consistent (incl. added tokens).
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        # Fallback: copy the raw tokenizer files from the base model dir.
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)

    print("\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved bf16 at: {output_dir}")
if __name__ == "__main__":
    # All three paths are directories, not files.
    residual_dir = "/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter"
    cpt_checkpoint_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31"
    merged_output_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31/residued"
    merge_instruction_residual(residual_dir, cpt_checkpoint_dir, merged_output_dir)