File size: 8,496 Bytes
48ecd01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python3
"""Migrate checkpoint from separate Q/K/V projections to fused QKV.

Usage:
    python3 scripts/migrate_qkv_checkpoint.py <checkpoint_dir>

Migrates both model.pt AND optimizer.pt:
  - model.pt:     q_proj/k_proj/v_proj weights β†’ qkv_proj weight
  - optimizer.pt: exp_avg/exp_avg_sq states fused, param indices re-mapped

The concatenation order is [Q ; K ; V] along the output (dim-0) axis,
which matches the split in MultiHeadAttention.forward:
    q, k, v = qkv.split([_q_dim, _kv_dim, _kv_dim], dim=-1)

Optimizer layout (group 0 = weight_decay, per layer Γ— 28):
  [i*6+0] q_proj.weight  [3072, 3072]
  [i*6+1] k_proj.weight  [1024, 3072]
  [i*6+2] v_proj.weight  [1024, 3072]
  [i*6+3] out_proj.weight [3072, 3072]
  [i*6+4] fc1_weight     [16384, 3072]
  [i*6+5] fc2_weight     [3072, 8192]
After fusion: indices 0,1,2 β†’ single qkv_proj β†’ 4 params per layer.
"""
import sys
import torch
from pathlib import Path

# Architecture constants for the target checkpoint layout (see module docstring).
N_LAYERS = 28
OLD_PARAMS_PER_LAYER = 6  # q, k, v, out, fc1, fc2
NEW_PARAMS_PER_LAYER = 4  # qkv, out, fc1, fc2


def migrate_model(state: dict) -> dict:
    """Fuse Q/K/V projection weights into QKV in model state dict.

    Keys that do not belong to a q/k/v projection are copied through
    unchanged.  For every ``q_proj`` key, the sibling ``k_proj`` and
    ``v_proj`` tensors are looked up and the three are concatenated along
    dim 0 in [Q ; K ; V] order under a new ``qkv_proj`` key.

    Raises:
        KeyError: if a q_proj key has no matching k_proj/v_proj siblings.
        RuntimeError: if any old projection key survives into the output.
    """
    fused_state: dict = {}
    handled: set = set()

    for key, tensor in state.items():
        touches_proj = (
            ".q_proj." in key or ".k_proj." in key or ".v_proj." in key
        )
        if not touches_proj:
            fused_state[key] = tensor
            continue

        # k_proj / v_proj tensors are consumed when their q_proj sibling is hit.
        if ".q_proj." not in key:
            continue

        module_prefix = key.rsplit(".", 2)[0]
        leaf = key.rsplit(".", 1)[-1]

        marker = (module_prefix, leaf)
        if marker in handled:
            continue
        handled.add(marker)

        triple = [
            f"{module_prefix}.{name}.{leaf}"
            for name in ("q_proj", "k_proj", "v_proj")
        ]
        absent = [k for k in triple if k not in state]
        if absent:
            raise KeyError(f"Expected keys not found in checkpoint: {absent}")

        q_w, k_w, v_w = (state[k] for k in triple)
        fused = torch.cat([q_w, k_w, v_w], dim=0)
        new_key = f"{module_prefix}.qkv_proj.{leaf}"
        fused_state[new_key] = fused
        print(f"  Fused  {new_key}: {list(fused.shape)}"
              f"  (q={list(q_w.shape)}, k={list(k_w.shape)}, v={list(v_w.shape)})")

    stale = [
        k for k in fused_state
        if ".q_proj." in k or ".k_proj." in k or ".v_proj." in k
    ]
    if stale:
        raise RuntimeError(f"BUG: old projection keys still present: {stale}")

    return fused_state


def migrate_optimizer(opt_state: dict) -> dict:
    """Fuse optimizer states for Q/K/V β†’ QKV and re-index parameters.

    The optimizer has 2 param groups:
      Group 0 (weight_decay): 168 = 28 layers Γ— 6 (q,k,v,out,fc1,fc2)
      Group 1 (no weight_decay): 58 = norms + embedding

    We fuse q,k,v entries in group 0 (indices i*6+0,1,2 β†’ one entry per layer).
    Group 0 shrinks from 168 to 112 (28 layers Γ— 4 params).
    Group 1 stays at 58. Total: 170.

    Args:
        opt_state: a torch optimizer state dict with ``"state"`` (param-index
            β†’ per-param state) and ``"param_groups"`` keys.

    Returns:
        A new optimizer state dict with fused entries and densely re-numbered
        param indices (group 0 first, then group 1).

    Raises:
        ValueError: if group 0's param count or layer-0 shapes don't match the
            hard-coded layout above (refuses to guess at an unknown layout).
    """
    old_state = opt_state["state"]
    old_groups = opt_state["param_groups"]

    # Refuse to proceed unless group 0 has exactly the expected 28*6 params.
    group0_count = len(old_groups[0]["params"])
    expected_g0 = N_LAYERS * OLD_PARAMS_PER_LAYER
    if group0_count != expected_g0:
        raise ValueError(
            f"Group 0 has {group0_count} params, expected {expected_g0}. "
            f"Cannot auto-detect QKV layout."
        )

    # Validate shapes for first layer
    # (uses exp_avg shapes as a proxy for the parameter shapes; see module
    # docstring for the per-layer layout these must match).
    shapes = []
    for j in range(OLD_PARAMS_PER_LAYER):
        idx = old_groups[0]["params"][j]
        shapes.append(list(old_state[idx]["exp_avg"].shape))
    expected_shapes = [[3072, 3072], [1024, 3072], [1024, 3072],
                       [3072, 3072], [16384, 3072], [3072, 8192]]
    if shapes != expected_shapes:
        raise ValueError(
            f"Layer 0 shapes {shapes} don't match expected {expected_shapes}. "
            f"Cannot auto-detect QKV layout."
        )
    print(f"  Shape validation passed for layer 0.")

    # New state entries keyed by the *new* dense param index; new_idx tracks
    # the next index to assign and must advance in lockstep with the order
    # params are listed in the rebuilt param_groups below.
    new_state_entries = {}
    new_idx = 0

    # --- Group 0: fuse q/k/v per layer ---
    for layer_i in range(N_LAYERS):
        base = layer_i * OLD_PARAMS_PER_LAYER
        q_opt_idx = old_groups[0]["params"][base + 0]
        k_opt_idx = old_groups[0]["params"][base + 1]
        v_opt_idx = old_groups[0]["params"][base + 2]

        q_entry = old_state[q_opt_idx]
        k_entry = old_state[k_opt_idx]
        v_entry = old_state[v_opt_idx]

        # Fuse QKV
        # Concatenation along dim 0 in [Q ; K ; V] order matches the weight
        # fusion in migrate_model, so moments stay aligned with their rows.
        # The shared Adam step counter is taken from the q entry.
        fused_entry = {"step": q_entry["step"]}
        for field in ["exp_avg", "exp_avg_sq"]:
            if field in q_entry:
                fused_entry[field] = torch.cat(
                    [q_entry[field], k_entry[field], v_entry[field]], dim=0
                )
        new_state_entries[new_idx] = fused_entry
        if layer_i == 0:
            print(f"  Layer 0 QKV fused: exp_avg {list(fused_entry['exp_avg'].shape)}")
        new_idx += 1

        # Copy remaining params (out, fc1, fc2)
        for offset in [3, 4, 5]:
            opt_idx = old_groups[0]["params"][base + offset]
            new_state_entries[new_idx] = old_state[opt_idx]
            new_idx += 1

    new_group0_count = new_idx  # should be N_LAYERS * NEW_PARAMS_PER_LAYER = 112
    print(f"  Group 0: {group0_count} β†’ {new_group0_count} params")

    # --- Group 1: copy as-is (norms, embedding β€” no QKV) ---
    group1_count = len(old_groups[1]["params"])
    for j in range(group1_count):
        opt_idx = old_groups[1]["params"][j]
        # A param may legitimately have no state yet (never stepped); skip the
        # copy but still advance new_idx so index gaps are preserved.
        if opt_idx in old_state:
            new_state_entries[new_idx] = old_state[opt_idx]
        new_idx += 1
    print(f"  Group 1: {group1_count} β†’ {group1_count} params (unchanged)")

    # Build new param_groups
    # Hyperparameters (lr, betas, weight_decay, ...) are carried over from the
    # old groups; only the "params" index lists are re-generated densely.
    new_groups = []
    g0 = {k: v for k, v in old_groups[0].items() if k != "params"}
    g0["params"] = list(range(0, new_group0_count))
    new_groups.append(g0)

    g1 = {k: v for k, v in old_groups[1].items() if k != "params"}
    g1["params"] = list(range(new_group0_count, new_group0_count + group1_count))
    new_groups.append(g1)

    total = new_group0_count + group1_count
    print(f"  Total: {len(old_state)} β†’ {total} optimizer params")

    return {"state": new_state_entries, "param_groups": new_groups}


def migrate(ckpt_dir: Path) -> None:
    """Migrate ``model.pt`` and ``optimizer.pt`` in *ckpt_dir* in place.

    Both files are rewritten to the fused-QKV layout.  Before either file is
    overwritten, a one-time ``*.backup_pre_qkv`` copy is written next to it so
    a bad migration can be rolled back (previously only the optimizer was
    backed up; the model now gets the same treatment for consistency).
    Already-migrated files are detected and skipped, so the script is safe to
    re-run.

    Args:
        ckpt_dir: checkpoint directory containing model.pt (required) and
            optimizer.pt (optional).

    Raises:
        FileNotFoundError: if model.pt is missing.
        RuntimeError: if the model state has neither q_proj nor qkv_proj keys.
    """
    model_path = ckpt_dir / "model.pt"
    opt_path = ckpt_dir / "optimizer.pt"

    if not model_path.exists():
        raise FileNotFoundError(f"model.pt not found in {ckpt_dir}")

    # --- Model migration ---
    print(f"[1/2] Migrating model weights from {model_path} ...")
    state = torch.load(model_path, map_location="cpu", weights_only=True)

    has_old = any(".q_proj." in k for k in state)
    has_new = any(".qkv_proj." in k for k in state)

    if has_new and not has_old:
        print("  Model already migrated. Skipping.")
    elif has_old:
        # Back up the un-migrated model once, mirroring the optimizer backup
        # below, so the original weights are recoverable.
        model_backup = ckpt_dir / "model.pt.backup_pre_qkv"
        if not model_backup.exists():
            torch.save(state, model_backup)
            print(f"  Backup: {model_backup}")
        new_model_state = migrate_model(state)
        torch.save(new_model_state, model_path)
        print("  Model saved.")
    else:
        raise RuntimeError("Model state has neither q_proj nor qkv_proj keys!")

    # --- Optimizer migration ---
    if opt_path.exists():
        print(f"\n[2/2] Migrating optimizer states from {opt_path} ...")
        opt = torch.load(opt_path, map_location="cpu", weights_only=True)

        # Check if already migrated: distinguish old vs new layout purely by
        # the total param count across groups.
        total_params = sum(len(pg["params"]) for pg in opt["param_groups"])
        expected_old = N_LAYERS * OLD_PARAMS_PER_LAYER + 58  # 168 + 58 = 226
        expected_new = N_LAYERS * NEW_PARAMS_PER_LAYER + 58  # 112 + 58 = 170

        if total_params == expected_old:
            opt_backup = ckpt_dir / "optimizer.pt.backup_pre_qkv"
            if not opt_backup.exists():
                torch.save(opt, opt_backup)
                print(f"  Backup: {opt_backup}")
            new_opt = migrate_optimizer(opt)
            torch.save(new_opt, opt_path)
            print("  Optimizer saved.")
        elif total_params == expected_new:
            print(f"  Optimizer already migrated ({total_params} params). Skipping.")
        else:
            # Unknown layout: refusing to guess. Dropping the optimizer only
            # costs a warm restart of Adam moments, not model weights.
            print(f"  [WARN] Unexpected param count {total_params} "
                  f"(expected old={expected_old} or new={expected_new}). "
                  f"Deleting optimizer.pt β€” optimizer will restart fresh.")
            opt_path.unlink()
    else:
        print("\n[2/2] No optimizer.pt found. Optimizer will restart fresh.")

    print("\nMigration complete!")


if __name__ == "__main__":
    # Exactly one CLI argument is expected: the checkpoint directory.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print(__doc__)
        sys.exit(1)
    migrate(Path(cli_args[0]))