import os
import shutil

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def _load_model_fp32(model_dir: str):
    """Load a causal LM on CPU in float32.

    transformers renamed ``torch_dtype`` to ``dtype`` across releases; try the
    new keyword first and fall back on ``TypeError`` for versions that only
    accept the old one.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
    except TypeError:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )


def _merge_state_dicts(base_state_dict, residual_state_dict):
    """Add residual tensors onto base tensors, tolerating a resized vocab.

    Returns ``(merged, mismatched, unused)``:
      * ``merged`` — maps every base key to its tensor, with the residual
        added (in float32) where a residual entry exists;
      * ``mismatched`` — list of ``(key, base_shape, res_shape, n)`` for keys
        where only the leading (row) dimension differed and the residual was
        added to the first ``n`` rows only;
      * ``unused`` — sorted residual keys absent from the base model. These
        were previously dropped silently, which could hide a broken adapter.

    Raises:
        RuntimeError: for any shape mismatch that is not a simple difference
            in the leading dimension — refusing to silently corrupt weights.
    """
    merged = {}
    mismatched = []
    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            merged[key] = base_tensor
            continue
        res_tensor = residual_state_dict[key]

        # Exact shape match -> plain elementwise add, accumulated in fp32.
        if base_tensor.shape == res_tensor.shape:
            merged[key] = (base_tensor + res_tensor).to(torch.float32)
            continue

        # Common case: vocab resized after the residual was computed — only
        # dim0 differs, the trailing dims match. Add the residual on the
        # overlapping rows; keep the extra (new-token) rows unchanged.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged[key] = out
            mismatched.append(
                (key, tuple(base_tensor.shape), tuple(res_tensor.shape), n)
            )
            continue

        # Anything else is suspicious -> don't silently corrupt.
        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    unused = sorted(set(residual_state_dict) - set(base_state_dict))
    return merged, mismatched, unused


def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    """Merge an instruction residual into a (possibly vocab-resized) CPT model.

    Loads ``adapter_model.bin`` from *lr_dir*, adds it onto the base model's
    weights in float32 (partial row-wise add when the vocab was resized after
    the residual was computed), and saves the result as bfloat16 safetensors
    in *output_dir* together with the base config and tokenizer.

    Args:
        lr_dir: Directory containing ``adapter_model.bin`` (the residual).
        base_model_dir: Directory of the base/CPT model to merge into.
        output_dir: Destination directory for the merged model (created if
            missing).

    Raises:
        FileNotFoundError: if the adapter checkpoint does not exist.
        RuntimeError: on a shape mismatch that is not a vocab-resize pattern.
    """
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    # weights_only=True avoids arbitrary-code execution from a pickled
    # checkpoint; older torch versions don't accept the keyword, so fall back
    # (same compat pattern as _load_model_fp32).
    try:
        residual_state_dict = torch.load(
            adapter_file, map_location="cpu", weights_only=True
        )
    except TypeError:
        residual_state_dict = torch.load(adapter_file, map_location="cpu")

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)

    merged_state_dict, mismatched, unused = _merge_state_dicts(
        base_model.state_dict(), residual_state_dict
    )

    # Previously these were dropped with no trace — surface them so a
    # mismatched adapter/model pair is noticed.
    if unused:
        print(f"\n⚠️ {len(unused)} residual key(s) not present in base model (skipped):")
        for k in unused[:20]:
            print(f"  - {k}")
        if len(unused) > 20:
            print(f"  ... and {len(unused) - 20} more")

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f"  - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f"  ... and {len(mismatched) - 20} more")

    # Load merged weights back; strict=True guarantees full coverage.
    base_model.load_state_dict(merged_state_dict, strict=True)

    # fp32 math above → save compactly as bf16.
    base_model = base_model.to(torch.bfloat16)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    # NOTE(review): save_pretrained already writes a config; this overwrites
    # it with the base model's on-disk config (kept to preserve the original
    # intent — confirm the base config's torch_dtype field is what you want).
    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)

    # Best way to keep the tokenizer consistent (incl. added tokens).
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        # Fallback: copy the raw tokenizer files verbatim.
        for file_name in ("tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"):
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)

    print("\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved bf16 at: {output_dir}")


if __name__ == "__main__":
    lr_dir = "/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter"
    base_model_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31"
    output_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31/residued"
    merge_instruction_residual(lr_dir, base_model_dir, output_dir)