#!/usr/bin/env python3
"""
scripts/merge_checkpoints.py β Slerp (Spherical Linear Interpolation) checkpoint merge.
Merges two model checkpoints (e.g., SFT + DPO) using SLERP interpolation
to balance knowledge retention and alignment improvement.
Reference: Nemotron-H paper β SLERP merging reduces alignment tax.
Usage:
python scripts/merge_checkpoints.py \
--ckpt_a checkpoints/3b_sft_v2/checkpoint-best \
--ckpt_b checkpoints/3b_dpo/checkpoint-merged \
--output checkpoints/3b_dpo/checkpoint-slerp \
--alpha 0.5
alpha=0.0 β pure ckpt_a (SFT)
alpha=1.0 β pure ckpt_b (DPO)
alpha=0.5 β equal blend (recommended starting point)
"""
from __future__ import annotations
import argparse
import math
import shutil
from pathlib import Path
import torch
import yaml
def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    Args:
        t: Interpolation factor in [0, 1]; 0 yields v0, 1 yields v1.
        v0: First tensor (flattened internally).
        v1: Second tensor, same shape as v0.
        eps: Guard against division by zero when normalizing.

    Returns:
        Interpolated tensor with v0's shape and dtype.
    """
    shape = v0.shape
    a = v0.flatten().float()
    b = v1.flatten().float()
    # Unit directions; eps guards zero-norm tensors.
    a_hat = a / (a.norm() + eps)
    b_hat = b / (b.norm() + eps)
    # Cosine of the angle between the two directions, clamped for acos safety.
    cos_omega = torch.clamp(torch.dot(a_hat, b_hat), min=-1.0, max=1.0)
    if abs(cos_omega.item()) > 0.9995:
        # Nearly (anti-)parallel: sin(omega) ~ 0 makes slerp unstable, use lerp.
        mixed = (1.0 - t) * a + t * b
    else:
        omega = torch.acos(cos_omega)
        sin_omega = torch.sin(omega)
        w0 = torch.sin((1.0 - t) * omega) / sin_omega
        w1 = torch.sin(t * omega) / sin_omega
        # Blend the raw (unnormalized) tensors so magnitude is preserved.
        mixed = w0 * a + w1 * b
    return mixed.reshape(shape).to(v0.dtype)
def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """Linear interpolation (1-t)*v0 + t*v1, computed in float32 and cast back to v0's dtype."""
    blended = v0.float() * (1.0 - t) + v1.float() * t
    return blended.to(v0.dtype)
def merge_state_dicts(
    sd_a: dict[str, torch.Tensor],
    sd_b: dict[str, torch.Tensor],
    alpha: float = 0.5,
    method: str = "slerp",
) -> dict[str, torch.Tensor]:
    """Merge two state dicts using SLERP or LERP.

    Args:
        sd_a: State dict A (e.g., SFT model).
        sd_b: State dict B (e.g., DPO model).
        alpha: Interpolation factor. 0 -> A, 1 -> B.
        method: "slerp" or "lerp".

    Returns:
        Merged state dict containing the union of both key sets.

    Raises:
        ValueError: If ``method`` is neither "slerp" nor "lerp".
    """
    # Fail fast on unknown methods — the original silently fell back to lerp,
    # which could mask a typo like method="selrp".
    if method == "slerp":
        interp_fn = slerp
    elif method == "lerp":
        interp_fn = lerp
    else:
        raise ValueError(f"Unknown merge method: {method!r} (expected 'slerp' or 'lerp')")
    merged = {}
    keys_a = set(sd_a.keys())
    keys_b = set(sd_b.keys())
    common = keys_a & keys_b
    only_a = keys_a - keys_b
    only_b = keys_b - keys_a
    if only_a:
        print(f"[WARN] {len(only_a)} keys only in ckpt_a (kept as-is)")
    if only_b:
        print(f"[WARN] {len(only_b)} keys only in ckpt_b (kept as-is)")
    for key in sorted(common):
        va = sd_a[key]
        vb = sd_b[key]
        if va.shape != vb.shape:
            # Shapes differ (e.g., resized embeddings): interpolation is undefined.
            print(f"[WARN] Shape mismatch for {key}: {va.shape} vs {vb.shape}, keeping ckpt_a")
            merged[key] = va
            continue
        # Only interpolate float parameters (skip int buffers, scalar counters, etc.)
        if va.is_floating_point() and va.numel() > 1:
            merged[key] = interp_fn(alpha, va, vb)
        else:
            merged[key] = va  # Keep from ckpt_a for non-float/scalar
    # Include keys unique to each checkpoint unchanged.
    for key in only_a:
        merged[key] = sd_a[key]
    for key in only_b:
        merged[key] = sd_b[key]
    return merged
def main():
    """CLI entry point: parse args, load both checkpoints, merge, and save.

    Loads ``model.pt`` from each checkpoint directory, merges the state dicts,
    then writes the merged weights plus copied config/tokenizer files and a
    ``merge_info.yaml`` provenance record to the output directory.
    """
    parser = argparse.ArgumentParser(description="SLERP checkpoint merge")
    parser.add_argument("--ckpt_a", type=Path, required=True,
                        help="Path to checkpoint A (e.g., SFT)")
    parser.add_argument("--ckpt_b", type=Path, required=True,
                        help="Path to checkpoint B (e.g., DPO)")
    parser.add_argument("--output", type=Path, required=True,
                        help="Output checkpoint directory")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Interpolation factor (0=A, 1=B, default 0.5)")
    parser.add_argument("--method", choices=["slerp", "lerp"], default="slerp",
                        help="Interpolation method (default: slerp)")
    args = parser.parse_args()
    # slerp() assumes t in [0, 1]; reject out-of-range values up front.
    if not 0.0 <= args.alpha <= 1.0:
        parser.error(f"--alpha must be in [0, 1], got {args.alpha}")
    # (Fixed mojibake in the original banner: the arrows had been corrupted.)
    print(f"Merge: {args.ckpt_a.name} ({1 - args.alpha:.1%}) + {args.ckpt_b.name} ({args.alpha:.1%})")
    print(f"Method: {args.method}, alpha={args.alpha}")
    # Load both state dicts on CPU; weights_only guards against pickle payloads.
    print("Loading checkpoint A...")
    sd_a = torch.load(args.ckpt_a / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_a)} keys loaded")
    print("Loading checkpoint B...")
    sd_b = torch.load(args.ckpt_b / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_b)} keys loaded")
    # Merge
    print("Merging...")
    merged_sd = merge_state_dicts(sd_a, sd_b, alpha=args.alpha, method=args.method)
    print(f"  {len(merged_sd)} keys in merged state dict")
    # Save merged weights
    args.output.mkdir(parents=True, exist_ok=True)
    torch.save(merged_sd, args.output / "model.pt")
    # Copy config from ckpt_a so the merged checkpoint is directly loadable
    config_src = args.ckpt_a / "config.yaml"
    if config_src.exists():
        shutil.copy2(str(config_src), str(args.output / "config.yaml"))
    # Copy tokenizer if available
    for tok_name in ["tokenizer.json", "tokenizer.model"]:
        tok_src = args.ckpt_a / tok_name
        if tok_src.exists():
            shutil.copy2(str(tok_src), str(args.output / tok_name))
    # Write merge metadata so the merge is reproducible later
    meta = {
        "ckpt_a": str(args.ckpt_a),
        "ckpt_b": str(args.ckpt_b),
        "alpha": args.alpha,
        "method": args.method,
    }
    with open(args.output / "merge_info.yaml", "w") as f:
        yaml.safe_dump(meta, f)
    size_mb = (args.output / "model.pt").stat().st_size / 1e6
    print(f"\nMerged checkpoint saved → {args.output} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    main()