| |
| """ |
| scripts/merge_checkpoints.py — Slerp (Spherical Linear Interpolation) checkpoint merge. |
| |
| Merges two model checkpoints (e.g., SFT + DPO) using SLERP interpolation |
| to balance knowledge retention and alignment improvement. |
| |
| Reference: Nemotron-H paper — SLERP merging reduces alignment tax. |
| |
| Usage: |
| python scripts/merge_checkpoints.py \ |
| --ckpt_a checkpoints/3b_sft_v2/checkpoint-best \ |
| --ckpt_b checkpoints/3b_dpo/checkpoint-merged \ |
| --output checkpoints/3b_dpo/checkpoint-slerp \ |
| --alpha 0.5 |
| |
| alpha=0.0 → pure ckpt_a (SFT) |
| alpha=1.0 → pure ckpt_b (DPO) |
| alpha=0.5 → equal blend (recommended starting point) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
|
|
|
|
def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    Both tensors are flattened to vectors, the angle between their unit
    directions is measured, and sine-weighted coefficients are applied to
    the *unnormalized* inputs. Nearly colinear inputs fall back to plain
    linear interpolation because sin(omega) approaches zero there.

    Args:
        t: Interpolation factor in [0, 1]. 0 -> v0, 1 -> v1.
        v0: First tensor (flattened internally).
        v1: Second tensor (same shape as v0).
        eps: Small value to avoid division by zero.

    Returns:
        Interpolated tensor with the same shape and dtype as v0.
    """
    shape = v0.shape
    a = v0.flatten().float()
    b = v1.flatten().float()

    # Angle comes from unit vectors; eps guards against zero-norm inputs.
    a_unit = a / (a.norm() + eps)
    b_unit = b / (b.norm() + eps)
    cos_omega = torch.dot(a_unit, b_unit).clamp(-1.0, 1.0)

    if abs(cos_omega.item()) > 0.9995:
        # Nearly (anti)parallel: SLERP is numerically unstable, use LERP.
        blended = (1.0 - t) * a + t * b
    else:
        omega = torch.acos(cos_omega)
        sin_omega = torch.sin(omega)
        w0 = torch.sin((1.0 - t) * omega) / sin_omega
        w1 = torch.sin(t * omega) / sin_omega
        blended = w0 * a + w1 * b

    return blended.reshape(shape).to(v0.dtype)
|
|
|
|
def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """Linear interpolation: (1 - t) * v0 + t * v1, cast back to v0's dtype."""
    mixed = v0.float() * (1.0 - t) + v1.float() * t
    return mixed.to(v0.dtype)
|
|
|
|
def merge_state_dicts(
    sd_a: dict[str, torch.Tensor],
    sd_b: dict[str, torch.Tensor],
    alpha: float = 0.5,
    method: str = "slerp",
) -> dict[str, torch.Tensor]:
    """Merge two state dicts using SLERP or LERP.

    Args:
        sd_a: State dict A (e.g., SFT model).
        sd_b: State dict B (e.g., DPO model).
        alpha: Interpolation factor. 0 -> A, 1 -> B.
        method: "slerp" or "lerp".

    Returns:
        Merged state dict.
    """
    blend = slerp if method == "slerp" else lerp

    a_keys = set(sd_a)
    b_keys = set(sd_b)
    shared = a_keys & b_keys
    a_only = a_keys - b_keys
    b_only = b_keys - a_keys

    if a_only:
        print(f"[WARN] {len(a_only)} keys only in ckpt_a (kept as-is)")
    if b_only:
        print(f"[WARN] {len(b_only)} keys only in ckpt_b (kept as-is)")

    result: dict[str, torch.Tensor] = {}
    for key in sorted(shared):
        ta = sd_a[key]
        tb = sd_b[key]

        if ta.shape != tb.shape:
            # Incompatible tensors can't be interpolated; prefer A's copy.
            print(f"[WARN] Shape mismatch for {key}: {ta.shape} vs {tb.shape}, keeping ckpt_a")
            result[key] = ta
        elif ta.is_floating_point() and ta.numel() > 1:
            result[key] = blend(alpha, ta, tb)
        else:
            # Integer buffers and scalars are taken verbatim from A.
            result[key] = ta

    # Carry over the keys unique to either checkpoint unchanged.
    result.update({k: sd_a[k] for k in a_only})
    result.update({k: sd_b[k] for k in b_only})

    return result
|
|
|
|
def main() -> None:
    """CLI entry point: parse args, load both checkpoints, merge, and save."""
    parser = argparse.ArgumentParser(description="SLERP checkpoint merge")
    parser.add_argument("--ckpt_a", type=Path, required=True,
                        help="Path to checkpoint A (e.g., SFT)")
    parser.add_argument("--ckpt_b", type=Path, required=True,
                        help="Path to checkpoint B (e.g., DPO)")
    parser.add_argument("--output", type=Path, required=True,
                        help="Output checkpoint directory")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Interpolation factor (0=A, 1=B, default 0.5)")
    parser.add_argument("--method", choices=["slerp", "lerp"], default="slerp",
                        help="Interpolation method (default: slerp)")
    args = parser.parse_args()

    print(f"Merge: {args.ckpt_a.name} ←({1-args.alpha:.1%})— ({args.alpha:.1%})→ {args.ckpt_b.name}")
    print(f"Method: {args.method}, alpha={args.alpha}")

    # Load both state dicts on CPU; weights_only avoids arbitrary unpickling.
    loaded = []
    for label, ckpt_dir in (("A", args.ckpt_a), ("B", args.ckpt_b)):
        print(f"Loading checkpoint {label}...")
        sd = torch.load(ckpt_dir / "model.pt", map_location="cpu", weights_only=True)
        print(f" {len(sd)} keys loaded")
        loaded.append(sd)
    sd_a, sd_b = loaded

    print("Merging...")
    merged_sd = merge_state_dicts(sd_a, sd_b, alpha=args.alpha, method=args.method)
    print(f" {len(merged_sd)} keys in merged state dict")

    args.output.mkdir(parents=True, exist_ok=True)
    torch.save(merged_sd, args.output / "model.pt")

    # Carry over config and tokenizer files from A so the merged
    # directory is loadable on its own (copied only if present).
    for artifact in ("config.yaml", "tokenizer.json", "tokenizer.model"):
        src = args.ckpt_a / artifact
        if src.exists():
            shutil.copy2(str(src), str(args.output / artifact))

    # Record provenance so the merge is reproducible later.
    meta = {
        "ckpt_a": str(args.ckpt_a),
        "ckpt_b": str(args.ckpt_b),
        "alpha": args.alpha,
        "method": args.method,
    }
    with open(args.output / "merge_info.yaml", "w") as f:
        yaml.safe_dump(meta, f)

    size_mb = (args.output / "model.pt").stat().st_size / 1e6
    print(f"\nMerged checkpoint saved → {args.output} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    main()
|
|