File size: 6,312 Bytes
29fc577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#!/usr/bin/env python3
"""
scripts/merge_checkpoints.py β€” Slerp (Spherical Linear Interpolation) checkpoint merge.

Merges two model checkpoints (e.g., SFT + DPO) using SLERP interpolation
to balance knowledge retention and alignment improvement.

Reference: Nemotron-H paper β€” SLERP merging reduces alignment tax.

Usage:
    python scripts/merge_checkpoints.py \
        --ckpt_a checkpoints/3b_sft_v2/checkpoint-best \
        --ckpt_b checkpoints/3b_dpo/checkpoint-merged \
        --output checkpoints/3b_dpo/checkpoint-slerp \
        --alpha 0.5

    alpha=0.0 β†’ pure ckpt_a (SFT)
    alpha=1.0 β†’ pure ckpt_b (DPO)
    alpha=0.5 β†’ equal blend (recommended starting point)
"""

from __future__ import annotations

import argparse
import math
import shutil
from pathlib import Path

import torch
import yaml


def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    Args:
        t: Blend factor in [0, 1]; 0 returns v0, 1 returns v1.
        v0: First tensor.
        v1: Second tensor (must have the same shape as v0).
        eps: Guard added to norms to avoid dividing by zero.

    Returns:
        Interpolated tensor with v0's shape and dtype.
    """
    shape = v0.shape
    a = v0.reshape(-1).float()
    b = v1.reshape(-1).float()

    # Unit directions (eps protects against an all-zero tensor).
    a_hat = a / (a.norm() + eps)
    b_hat = b / (b.norm() + eps)

    # Cosine of the angle between the two directions, clamped for acos.
    cos_omega = torch.clamp(torch.dot(a_hat, b_hat), -1.0, 1.0)

    # Near-(anti)parallel vectors make acos/sin numerically unstable, so
    # fall back to plain linear interpolation there.
    if abs(cos_omega.item()) > 0.9995:
        blended = a * (1.0 - t) + b * t
        return blended.reshape(shape).to(v0.dtype)

    omega = torch.acos(cos_omega)
    denom = torch.sin(omega)

    # Standard slerp weights, applied to the *unnormalized* vectors so the
    # merged tensor keeps a magnitude between the two inputs.
    w0 = torch.sin(omega * (1.0 - t)) / denom
    w1 = torch.sin(omega * t) / denom
    return (w0 * a + w1 * b).reshape(shape).to(v0.dtype)


def lerp(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """Linear interpolation: (1 - t) * v0 + t * v1, cast back to v0's dtype."""
    blended = v0.float().mul(1.0 - t) + v1.float().mul(t)
    return blended.to(v0.dtype)


def merge_state_dicts(
    sd_a: dict[str, torch.Tensor],
    sd_b: dict[str, torch.Tensor],
    alpha: float = 0.5,
    method: str = "slerp",
) -> dict[str, torch.Tensor]:
    """Blend two state dicts key-by-key with SLERP or LERP.

    Args:
        sd_a: State dict A (e.g., SFT model).
        sd_b: State dict B (e.g., DPO model).
        alpha: Interpolation factor. 0 β†’ A, 1 β†’ B.
        method: "slerp" or "lerp".

    Returns:
        Merged state dict covering the union of both key sets.
    """
    blend = slerp if method == "slerp" else lerp

    a_keys = set(sd_a)
    b_keys = set(sd_b)
    shared = a_keys & b_keys
    a_only = a_keys - b_keys
    b_only = b_keys - a_keys

    if a_only:
        print(f"[WARN] {len(a_only)} keys only in ckpt_a (kept as-is)")
    if b_only:
        print(f"[WARN] {len(b_only)} keys only in ckpt_b (kept as-is)")

    merged: dict[str, torch.Tensor] = {}
    for key in sorted(shared):
        ta = sd_a[key]
        tb = sd_b[key]

        if ta.shape != tb.shape:
            # Incompatible parameters can't be blended; prefer A's copy.
            print(f"[WARN] Shape mismatch for {key}: {ta.shape} vs {tb.shape}, keeping ckpt_a")
            merged[key] = ta
        elif ta.is_floating_point() and ta.numel() > 1:
            merged[key] = blend(alpha, ta, tb)
        else:
            # Integer buffers and scalar tensors are taken from A untouched.
            merged[key] = ta

    # Keys present in only one checkpoint pass through unchanged.
    merged.update({k: sd_a[k] for k in a_only})
    merged.update({k: sd_b[k] for k in b_only})

    return merged


def main():
    """CLI entry point: parse args, load both checkpoints, merge, and save."""
    ap = argparse.ArgumentParser(description="SLERP checkpoint merge")
    ap.add_argument("--ckpt_a", type=Path, required=True,
                    help="Path to checkpoint A (e.g., SFT)")
    ap.add_argument("--ckpt_b", type=Path, required=True,
                    help="Path to checkpoint B (e.g., DPO)")
    ap.add_argument("--output", type=Path, required=True,
                    help="Output checkpoint directory")
    ap.add_argument("--alpha", type=float, default=0.5,
                    help="Interpolation factor (0=A, 1=B, default 0.5)")
    ap.add_argument("--method", choices=["slerp", "lerp"], default="slerp",
                    help="Interpolation method (default: slerp)")
    args = ap.parse_args()

    print(f"Merge: {args.ckpt_a.name} ←({1-args.alpha:.1%})— ({args.alpha:.1%})→ {args.ckpt_b.name}")
    print(f"Method: {args.method}, alpha={args.alpha}")

    # Load both state dicts onto CPU; weights_only=True refuses arbitrary
    # pickled objects in the checkpoint file.
    print("Loading checkpoint A...")
    sd_a = torch.load(args.ckpt_a / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_a)} keys loaded")

    print("Loading checkpoint B...")
    sd_b = torch.load(args.ckpt_b / "model.pt", map_location="cpu", weights_only=True)
    print(f"  {len(sd_b)} keys loaded")

    print("Merging...")
    merged_sd = merge_state_dicts(sd_a, sd_b, alpha=args.alpha, method=args.method)
    print(f"  {len(merged_sd)} keys in merged state dict")

    # Persist the merged weights.
    args.output.mkdir(parents=True, exist_ok=True)
    torch.save(merged_sd, args.output / "model.pt")

    # Carry over config and tokenizer files from checkpoint A so the output
    # directory is directly loadable.
    for aux_name in ["config.yaml", "tokenizer.json", "tokenizer.model"]:
        aux_src = args.ckpt_a / aux_name
        if aux_src.exists():
            shutil.copy2(str(aux_src), str(args.output / aux_name))

    # Record how this checkpoint was produced, for reproducibility.
    meta = {
        "ckpt_a": str(args.ckpt_a),
        "ckpt_b": str(args.ckpt_b),
        "alpha": args.alpha,
        "method": args.method,
    }
    with open(args.output / "merge_info.yaml", "w") as f:
        yaml.safe_dump(meta, f)

    size_mb = (args.output / "model.pt").stat().st_size / 1e6
    print(f"\nMerged checkpoint saved → {args.output} ({size_mb:.0f} MB)")


# Script entry point: only runs the merge when executed directly, not on import.
if __name__ == "__main__":
    main()