| | |
| | """Migrate checkpoint from separate Q/K/V projections to fused QKV. |
| | |
| | Usage: |
| | python3 scripts/migrate_qkv_checkpoint.py <checkpoint_dir> |
| | |
| | Migrates both model.pt AND optimizer.pt: |
| | - model.pt: q_proj/k_proj/v_proj weights → qkv_proj weight |
| | - optimizer.pt: exp_avg/exp_avg_sq states fused, param indices re-mapped |
| | |
| | The concatenation order is [Q ; K ; V] along the output (dim-0) axis, |
| | which matches the split in MultiHeadAttention.forward: |
| | q, k, v = qkv.split([_q_dim, _kv_dim, _kv_dim], dim=-1) |
| | |
| | Optimizer layout (group 0 = weight_decay, per layer × 28): |
| | [i*6+0] q_proj.weight [3072, 3072] |
| | [i*6+1] k_proj.weight [1024, 3072] |
| | [i*6+2] v_proj.weight [1024, 3072] |
| | [i*6+3] out_proj.weight [3072, 3072] |
| | [i*6+4] fc1_weight [16384, 3072] |
| | [i*6+5] fc2_weight [3072, 8192] |
| | After fusion: indices 0,1,2 → single qkv_proj → 4 params per layer. |
| | """ |
| | import sys |
| | import torch |
| | from pathlib import Path |
| |
|
# Number of transformer layers the checkpoint is expected to contain.
N_LAYERS = 28
# Optimizer group-0 entries per layer before fusion: q, k, v, out_proj, fc1, fc2.
OLD_PARAMS_PER_LAYER = 6
# Entries per layer after q/k/v are fused into a single qkv_proj param.
NEW_PARAMS_PER_LAYER = 4
| |
|
| |
|
def migrate_model(state: dict) -> dict:
    """Return a new model state dict with q/k/v projections fused into qkv.

    Entries that are not attention projections are copied through untouched.
    For each (layer prefix, suffix) pair the q/k/v tensors are concatenated
    along dim 0 in [Q; K; V] order under ``<prefix>.qkv_proj.<suffix>``.
    """
    proj_markers = (".q_proj.", ".k_proj.", ".v_proj.")
    fused_state: dict = {}
    fused_tags: set = set()

    for name, tensor in state.items():
        if not any(m in name for m in proj_markers):
            # Not a projection weight/bias — passes through unchanged.
            fused_state[name] = tensor
            continue

        # k_proj/v_proj entries are consumed when their q_proj twin is seen.
        if ".q_proj." not in name:
            continue

        prefix, _, suffix = name.rsplit(".", 2)
        tag = (prefix, suffix)
        if tag in fused_tags:
            continue
        fused_tags.add(tag)

        triple = [f"{prefix}.{p}_proj.{suffix}" for p in ("q", "k", "v")]
        missing = [k for k in triple if k not in state]
        if missing:
            raise KeyError(f"Expected keys not found in checkpoint: {missing}")

        q_w, k_w, v_w = (state[k] for k in triple)
        fused = torch.cat([q_w, k_w, v_w], dim=0)
        fused_key = f"{prefix}.qkv_proj.{suffix}"
        fused_state[fused_key] = fused
        print(f" Fused {fused_key}: {list(fused.shape)}"
              f" (q={list(q_w.shape)}, k={list(k_w.shape)}, v={list(v_w.shape)})")

    # Sanity check: no stale projection keys may survive the rewrite.
    leaked = [k for k in fused_state if any(m in k for m in proj_markers)]
    if leaked:
        raise RuntimeError(f"BUG: old projection keys still present: {leaked}")

    return fused_state
| |
|
| |
|
| | def migrate_optimizer(opt_state: dict) -> dict: |
| | """Fuse optimizer states for Q/K/V → QKV and re-index parameters. |
| | |
| | The optimizer has 2 param groups: |
| | Group 0 (weight_decay): 168 = 28 layers × 6 (q,k,v,out,fc1,fc2) |
| | Group 1 (no weight_decay): 58 = norms + embedding |
| | |
| | We fuse q,k,v entries in group 0 (indices i*6+0,1,2 → one entry per layer). |
| | Group 0 shrinks from 168 to 112 (28 layers × 4 params). |
| | Group 1 stays at 58. Total: 170. |
| | """ |
| | old_state = opt_state["state"] |
| | old_groups = opt_state["param_groups"] |
| |
|
| | group0_count = len(old_groups[0]["params"]) |
| | expected_g0 = N_LAYERS * OLD_PARAMS_PER_LAYER |
| | if group0_count != expected_g0: |
| | raise ValueError( |
| | f"Group 0 has {group0_count} params, expected {expected_g0}. " |
| | f"Cannot auto-detect QKV layout." |
| | ) |
| |
|
| | |
| | shapes = [] |
| | for j in range(OLD_PARAMS_PER_LAYER): |
| | idx = old_groups[0]["params"][j] |
| | shapes.append(list(old_state[idx]["exp_avg"].shape)) |
| | expected_shapes = [[3072, 3072], [1024, 3072], [1024, 3072], |
| | [3072, 3072], [16384, 3072], [3072, 8192]] |
| | if shapes != expected_shapes: |
| | raise ValueError( |
| | f"Layer 0 shapes {shapes} don't match expected {expected_shapes}. " |
| | f"Cannot auto-detect QKV layout." |
| | ) |
| | print(f" Shape validation passed for layer 0.") |
| |
|
| | new_state_entries = {} |
| | new_idx = 0 |
| |
|
| | |
| | for layer_i in range(N_LAYERS): |
| | base = layer_i * OLD_PARAMS_PER_LAYER |
| | q_opt_idx = old_groups[0]["params"][base + 0] |
| | k_opt_idx = old_groups[0]["params"][base + 1] |
| | v_opt_idx = old_groups[0]["params"][base + 2] |
| |
|
| | q_entry = old_state[q_opt_idx] |
| | k_entry = old_state[k_opt_idx] |
| | v_entry = old_state[v_opt_idx] |
| |
|
| | |
| | fused_entry = {"step": q_entry["step"]} |
| | for field in ["exp_avg", "exp_avg_sq"]: |
| | if field in q_entry: |
| | fused_entry[field] = torch.cat( |
| | [q_entry[field], k_entry[field], v_entry[field]], dim=0 |
| | ) |
| | new_state_entries[new_idx] = fused_entry |
| | if layer_i == 0: |
| | print(f" Layer 0 QKV fused: exp_avg {list(fused_entry['exp_avg'].shape)}") |
| | new_idx += 1 |
| |
|
| | |
| | for offset in [3, 4, 5]: |
| | opt_idx = old_groups[0]["params"][base + offset] |
| | new_state_entries[new_idx] = old_state[opt_idx] |
| | new_idx += 1 |
| |
|
| | new_group0_count = new_idx |
| | print(f" Group 0: {group0_count} → {new_group0_count} params") |
| |
|
| | |
| | group1_count = len(old_groups[1]["params"]) |
| | for j in range(group1_count): |
| | opt_idx = old_groups[1]["params"][j] |
| | if opt_idx in old_state: |
| | new_state_entries[new_idx] = old_state[opt_idx] |
| | new_idx += 1 |
| | print(f" Group 1: {group1_count} → {group1_count} params (unchanged)") |
| |
|
| | |
| | new_groups = [] |
| | g0 = {k: v for k, v in old_groups[0].items() if k != "params"} |
| | g0["params"] = list(range(0, new_group0_count)) |
| | new_groups.append(g0) |
| |
|
| | g1 = {k: v for k, v in old_groups[1].items() if k != "params"} |
| | g1["params"] = list(range(new_group0_count, new_group0_count + group1_count)) |
| | new_groups.append(g1) |
| |
|
| | total = new_group0_count + group1_count |
| | print(f" Total: {len(old_state)} → {total} optimizer params") |
| |
|
| | return {"state": new_state_entries, "param_groups": new_groups} |
| |
|
| |
|
def migrate(ckpt_dir: Path) -> None:
    """Migrate model.pt (and optimizer.pt, if present) in *ckpt_dir* in place.

    Order matters: model weights are fused first, then the optimizer is
    backed up before its file is rewritten. Both steps are idempotent —
    re-running on an already-migrated checkpoint is a no-op.

    Raises:
        FileNotFoundError: if model.pt is absent from *ckpt_dir*.
        RuntimeError: if the model state has neither old nor new keys.
    """
    model_path = ckpt_dir / "model.pt"
    opt_path = ckpt_dir / "optimizer.pt"

    if not model_path.exists():
        raise FileNotFoundError(f"model.pt not found in {ckpt_dir}")

    # ---- Step 1: model weights ----
    print(f"[1/2] Migrating model weights from {model_path} ...")
    state = torch.load(model_path, map_location="cpu", weights_only=True)

    # Detect migration status from key names so re-runs are safe.
    has_old = any(".q_proj." in k for k in state)
    has_new = any(".qkv_proj." in k for k in state)

    if has_new and not has_old:
        print(" Model already migrated. Skipping.")
    elif has_old:
        new_model_state = migrate_model(state)
        # Overwrites model.pt in place — note: no backup is taken for the
        # model (only the optimizer is backed up below).
        torch.save(new_model_state, model_path)
        print(f" Model saved.")
    else:
        raise RuntimeError("Model state has neither q_proj nor qkv_proj keys!")

    # ---- Step 2: optimizer states ----
    if opt_path.exists():
        print(f"\n[2/2] Migrating optimizer states from {opt_path} ...")
        opt = torch.load(opt_path, map_location="cpu", weights_only=True)

        # Classify the checkpoint by total param count across both groups.
        # NOTE(review): the literal 58 presumably is the group-1 (no-decay)
        # param count — confirm against the training script.
        total_params = sum(len(pg["params"]) for pg in opt["param_groups"])
        expected_old = N_LAYERS * OLD_PARAMS_PER_LAYER + 58
        expected_new = N_LAYERS * NEW_PARAMS_PER_LAYER + 58

        if total_params == expected_old:
            # Back up exactly once; never clobber an existing backup on re-runs.
            opt_backup = ckpt_dir / "optimizer.pt.backup_pre_qkv"
            if not opt_backup.exists():
                torch.save(opt, opt_backup)
                print(f" Backup: {opt_backup}")
            new_opt = migrate_optimizer(opt)
            torch.save(new_opt, opt_path)
            print(f" Optimizer saved.")
        elif total_params == expected_new:
            print(f" Optimizer already migrated ({total_params} params). Skipping.")
        else:
            # Destructive by design: an unrecognized layout is deleted rather
            # than guessed at, so training restarts with a fresh optimizer.
            print(f" [WARN] Unexpected param count {total_params} "
                  f"(expected old={expected_old} or new={expected_new}). "
                  f"Deleting optimizer.pt — optimizer will restart fresh.")
            opt_path.unlink()
    else:
        print("\n[2/2] No optimizer.pt found. Optimizer will restart fresh.")

    print("\nMigration complete!")
| |
|
| |
|
| | if __name__ == "__main__": |
| | if len(sys.argv) != 2: |
| | print(__doc__) |
| | sys.exit(1) |
| | migrate(Path(sys.argv[1])) |
| |
|