| | import os |
| | import shutil |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer |
| |
|
| |
|
def _load_model_fp32(model_dir: str):
    """Load a causal-LM checkpoint on CPU with float32 weights.

    Recent ``transformers`` releases renamed the ``torch_dtype`` keyword to
    ``dtype``. Try the new spelling first; if this installation rejects it
    with a ``TypeError``, retry with the legacy keyword.
    """
    shared_kwargs = {
        "device_map": "cpu",
        "trust_remote_code": True,
    }
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir, dtype=torch.float32, **shared_kwargs
        )
    except TypeError:
        # Older transformers: only the legacy keyword is understood.
        return AutoModelForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch.float32, **shared_kwargs
        )
| |
|
| |
|
def _merge_state_dicts(base_state_dict, residual_state_dict):
    """Add residual tensors onto base tensors key-by-key, in fp32.

    Keys missing from the residual are passed through unchanged. When a key's
    shapes differ only in the leading (row / vocab) dimension, the residual is
    added to the first ``min(rows)`` rows and the remaining rows are kept as-is
    — this handles embeddings / lm_head resized after the residual was made.

    Returns:
        (merged_state_dict, mismatched) where ``mismatched`` is a list of
        ``(key, base_shape, residual_shape, rows_added)`` for every partial add.

    Raises:
        RuntimeError: for any shape mismatch that is not a simple
            leading-dimension (vocab-resize) difference.
    """
    merged_state_dict = {}
    mismatched = []

    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            merged_state_dict[key] = base_tensor
            continue

        res_tensor = residual_state_dict[key]

        # Exact shape match: plain elementwise add, accumulated in fp32.
        if base_tensor.shape == res_tensor.shape:
            merged_state_dict[key] = (base_tensor + res_tensor).to(torch.float32)
            continue

        # Same rank, identical trailing dims, different row count:
        # treat as a vocab-resize and add only the overlapping rows.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged_state_dict[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue

        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    return merged_state_dict, mismatched


def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    """
    Merge instruction residual into a (possibly vocab-resized) CPT model.

    If vocab was resized after the residual was computed, we add residual only
    for the overlapping token rows and keep extra rows (new tokens) unchanged.

    Args:
        lr_dir: directory containing ``adapter_model.bin`` (the residual).
        base_model_dir: directory of the base model to merge into.
        output_dir: destination; merged bf16 weights, config, and tokenizer
            files are written here (created if missing).

    Raises:
        FileNotFoundError: if ``adapter_model.bin`` is absent from ``lr_dir``.
        RuntimeError: on a shape mismatch that is not a vocab resize.
    """
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    # weights_only=True refuses arbitrary pickled objects in the checkpoint
    # (safer against tampered files). Older torch versions don't accept the
    # kwarg, so fall back to a plain load there.
    try:
        residual_state_dict = torch.load(adapter_file, map_location="cpu", weights_only=True)
    except TypeError:
        residual_state_dict = torch.load(adapter_file, map_location="cpu")

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)
    base_state_dict = base_model.state_dict()

    merged_state_dict, mismatched = _merge_state_dicts(base_state_dict, residual_state_dict)

    # Residual keys the base model doesn't have would otherwise vanish
    # silently — report them so a bad key mapping is noticed.
    unused_keys = [k for k in residual_state_dict if k not in base_state_dict]
    if unused_keys:
        print(f"\n⚠️ Ignored {len(unused_keys)} residual key(s) absent from base model, "
              f"e.g. {unused_keys[:5]}")

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")

    # strict=True guarantees the merged dict covers every parameter.
    base_model.load_state_dict(merged_state_dict, strict=True)

    # Math was done in fp32; ship the result in bf16 to halve disk/RAM size.
    base_model = base_model.to(torch.bfloat16)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    # Re-save the base config alongside the weights. trust_remote_code matches
    # the model/tokenizer loads above (remote-code configs fail without it).
    base_config = AutoConfig.from_pretrained(base_model_dir, trust_remote_code=True)
    base_config.save_pretrained(output_dir)

    # Best effort: carry the tokenizer over so output_dir is self-contained.
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        # Tokenizer class failed to load — fall back to copying raw files.
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)

    print("\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved bf16 at: {output_dir}")
| |
|
| |
|
if __name__ == "__main__":
    # One-off merge run with hard-coded workspace paths.
    adapter_dir = "/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter"
    checkpoint_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31"
    merged_output_dir = "/workspace/v126rc_exp3/F_r10000/checkpoint-31/residued"

    merge_instruction_residual(adapter_dir, checkpoint_dir, merged_output_dir)
| |
|