"""Merge LoRA adapter into base model weights.

Usage:
    pip install torch transformers accelerate safetensors huggingface_hub tqdm
    python merge.py --output ./merged_model

Loads the MXFP4 base model, dequantizes it to bf16, applies the LoRA deltas, and saves the
merged model as sharded bf16 safetensors. Requires ~300GB of RAM; no GPU is needed.
"""

import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM

BASE_MODEL = "openai/gpt-oss-120b"
ADAPTER_REPO = "LightningRodLabs/Trump-Forecaster"


def merge(output_dir: str):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Download adapter
    print("Downloading adapter...")
    adapter_dir = Path(snapshot_download(ADAPTER_REPO))
    adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
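    # lora_alpha / r is the standard LoRA scaling applied to each B @ A delta below.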
    scaling = adapter_config["lora_alpha"] / adapter_config["r"]
    adapter_weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
    print(f"Adapter: {len(adapter_weights)} keys, scaling={scaling}")

    # Load base model (dequantizes MXFP4 → bf16)
    print("Loading base model (this takes a while — ~240GB bf16)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True,
    )
    state_dict = base_model.state_dict()
    del base_model

    # Group LoRA A/B pairs
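    # Each adapter module contributes a pair: A of shape [..., r, in_features] and
    # B of shape [..., out_features, r]; B @ A reconstructs the full weight delta.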
    lora_pairs = {}
    for key, tensor in adapter_weights.items():
        clean = key.replace("base_model.model.", "", 1)
        if ".lora_A.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_A.weight", ""), {})["A"] = tensor
        elif ".lora_B.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_B.weight", ""), {})["B"] = tensor

    # Map adapter keys → base model keys + merge operation
    # Adapter uses Tinker naming; HF transformers uses different names for this model:
    #   attn → self_attn, w1(gate)/w3(up) → gate_up_proj (interleaved), w2 → down_proj
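    # Illustrative mapping for layer 0 (exact adapter key prefixes may vary):
    #   model.layers.0.attn.q_proj    → model.layers.0.self_attn.q_proj.weight   (add)
    #   model.layers.0.mlp.experts.w1 → model.layers.0.mlp.experts.gate_up_proj  (even columns)
    #   model.layers.0.mlp.experts.w3 → model.layers.0.mlp.experts.gate_up_proj  (odd columns)
    #   model.layers.0.mlp.experts.w2 → model.layers.0.mlp.experts.down_proj     (transposed add)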
    base_key_ops = {}
    for adapter_path in lora_pairs:
        if "unembed_tokens" in adapter_path:
            base_key_ops.setdefault("lm_head.weight", []).append(("add", adapter_path))
        elif ".attn." in adapter_path:
            base_key = adapter_path.replace(".attn.", ".self_attn.") + ".weight"
            base_key_ops.setdefault(base_key, []).append(("add", adapter_path))
        elif ".mlp.experts.w1" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w1")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("even_t", adapter_path))
        elif ".mlp.experts.w3" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w3")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("odd_t", adapter_path))
        elif ".mlp.experts.w2" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w2")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.down_proj", []).append(("add_t", adapter_path))

    # Apply LoRA deltas
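    # W' = W + (lora_alpha / r) * (B @ A). Attention and lm_head weights are 2D [out, in],
    # so the delta adds directly; the expert parameters are 3D and stored transposed relative
    # to B @ A, hence the transpose(1, 2), with gate (even) and up (odd) columns
    # interleaved inside gate_up_proj.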
    for base_key, ops in tqdm(sorted(base_key_ops.items()), desc="Merging LoRA"):
        w = state_dict[base_key].float()
        for op_type, adapter_path in ops:
            A = lora_pairs[adapter_path]["A"].float()
            B = lora_pairs[adapter_path]["B"].float()
            delta = torch.matmul(B, A) * scaling
            if op_type == "add":
                w += delta
            elif op_type == "even_t":
                w[:, :, ::2] += delta.transpose(1, 2)
            elif op_type == "odd_t":
                w[:, :, 1::2] += delta.transpose(1, 2)
            elif op_type == "add_t":
                w += delta.transpose(1, 2)
        state_dict[base_key] = w.to(torch.bfloat16)

    # Save sharded safetensors
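    # Greedily pack tensors into ~5 GB shards using the HF model-XXXXX-of-XXXXX.safetensors
    # naming, then write model.safetensors.index.json so from_pretrained can locate each tensor.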
    print(f"Saving to {output_dir}...")
    max_shard = 5 * 1024**3
    shards, current, size = [], {}, 0
    for k, v in state_dict.items():
        nbytes = v.numel() * v.element_size()
        if size + nbytes > max_shard and current:
            shards.append(current)
            current, size = {}, 0
        current[k] = v
        size += nbytes
    if current:
        shards.append(current)

    weight_map, total = {}, 0
    for i, shard in enumerate(shards):
        fname = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
        save_file(shard, str(output_dir / fname))
        for k, v in shard.items():
            weight_map[k] = fname
            total += v.numel() * v.element_size()

    (output_dir / "model.safetensors.index.json").write_text(
        json.dumps({"metadata": {"total_size": total}, "weight_map": weight_map}, indent=2)
    )

    # Copy config + tokenizer from base model (remove quantization_config)
    base_cache = Path(snapshot_download(BASE_MODEL, allow_patterns=["*.py", "*.json", "tokenizer*", "*.model"]))
    for f in base_cache.iterdir():
        if f.is_file() and f.name != "model.safetensors.index.json":
            shutil.copy2(f, output_dir / f.name)
    cfg = json.loads((output_dir / "config.json").read_text())
    cfg.pop("quantization_config", None)
    cfg["torch_dtype"] = "bfloat16"
    (output_dir / "config.json").write_text(json.dumps(cfg, indent=2))

    print(f"Done! Merged model saved to {output_dir} ({len(shards)} shards)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", required=True, help="Output directory for merged model")
    merge(parser.parse_args().output)
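
# The merged checkpoint loads like any standard HF model, e.g. (sketch; path is illustrative):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("./merged_model", torch_dtype="bfloat16", device_map="auto")
#   tokenizer = AutoTokenizer.from_pretrained("./merged_model")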