"""Merge LoRA adapter into base model weights.
Usage:
pip install torch transformers safetensors tqdm
python merge.py --output ./merged_model
Loads the MXFP4 base model, dequantizes to bf16, applies LoRA deltas, saves merged model.
Requires ~300GB RAM. No GPU needed.
"""
import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM

BASE_MODEL = "openai/gpt-oss-120b"
ADAPTER_REPO = "LightningRodLabs/Trump-Forecaster"


def merge(output_dir: str):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Download adapter
    print("Downloading adapter...")
    adapter_dir = Path(snapshot_download(ADAPTER_REPO))
    adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
    scaling = adapter_config["lora_alpha"] / adapter_config["r"]
    adapter_weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
    print(f"Adapter: {len(adapter_weights)} keys, scaling={scaling}")

    # Load base model (dequantizes MXFP4 → bf16)
    print("Loading base model (this takes a while — ~240GB bf16)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True,
    )
    state_dict = base_model.state_dict()
    del base_model
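    # state_dict() returns references to the model's tensors rather than copies,
    # so this doesn't duplicate the ~240GB of weights; `del base_model` only
    # drops the nn.Module wrapper while the tensors stay alive via state_dict.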

    # Group LoRA A/B pairs
    lora_pairs = {}
    for key, tensor in adapter_weights.items():
        clean = key.replace("base_model.model.", "", 1)
        if ".lora_A.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_A.weight", ""), {})["A"] = tensor
        elif ".lora_B.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_B.weight", ""), {})["B"] = tensor
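    # Optional sanity check: every LoRA module must contribute both halves.
    for path, pair in lora_pairs.items():
        assert set(pair) == {"A", "B"}, f"Incomplete LoRA pair for {path}"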

    # Map adapter keys → base model keys + merge operation
    # Adapter uses Tinker naming; HF transformers uses different names for this model:
    #   attn → self_attn, w1(gate)/w3(up) → gate_up_proj (interleaved), w2 → down_proj
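    # In the HF gpt-oss layout, expert weights are batched 3-D tensors applied as
    # right-multiplications: gate_up_proj is roughly (num_experts, hidden, 2*expert_dim)
    # with gate in the even and up in the odd columns of the last dim, and down_proj
    # is (num_experts, expert_dim, hidden). The per-expert delta B @ A comes out in
    # (out, in) orientation, hence the transpose(1, 2) and strided writes below.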
    base_key_ops = {}
    for adapter_path in lora_pairs:
        if "unembed_tokens" in adapter_path:
            base_key_ops.setdefault("lm_head.weight", []).append(("add", adapter_path))
        elif ".attn." in adapter_path:
            base_key = adapter_path.replace(".attn.", ".self_attn.") + ".weight"
            base_key_ops.setdefault(base_key, []).append(("add", adapter_path))
        elif ".mlp.experts.w1" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w1")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("even_t", adapter_path))
        elif ".mlp.experts.w3" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w3")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("odd_t", adapter_path))
        elif ".mlp.experts.w2" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w2")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.down_proj", []).append(("add_t", adapter_path))
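    # Optional sanity check: every mapped key should exist in the base state dict.
    missing = [k for k in base_key_ops if k not in state_dict]
    assert not missing, f"Keys not found in base model: {missing[:5]}"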

    # Apply LoRA deltas
    for base_key, ops in tqdm(sorted(base_key_ops.items()), desc="Merging LoRA"):
        w = state_dict[base_key].float()
        for op_type, adapter_path in ops:
            A = lora_pairs[adapter_path]["A"].float()
            B = lora_pairs[adapter_path]["B"].float()
            delta = torch.matmul(B, A) * scaling
            if op_type == "add":
                w += delta
            elif op_type == "even_t":
                w[:, :, ::2] += delta.transpose(1, 2)
            elif op_type == "odd_t":
                w[:, :, 1::2] += delta.transpose(1, 2)
            elif op_type == "add_t":
                w += delta.transpose(1, 2)
        state_dict[base_key] = w.to(torch.bfloat16)
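    # Deltas are accumulated in fp32 (bf16 carries only ~8 significant bits, so
    # adding small deltas directly in bf16 would lose precision), then the
    # result is cast back to bf16 for storage.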

    # Save sharded safetensors
    print(f"Saving to {output_dir}...")
    max_shard = 5 * 1024**3
    shards, current, size = [], {}, 0
    for k, v in state_dict.items():
        nbytes = v.numel() * v.element_size()
        if size + nbytes > max_shard and current:
            shards.append(current)
            current, size = {}, 0
        current[k] = v
        size += nbytes
    if current:
        shards.append(current)
    weight_map, total = {}, 0
    for i, shard in enumerate(shards):
        fname = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
        save_file(shard, str(output_dir / fname))
        for k, v in shard.items():
            weight_map[k] = fname
            total += v.numel() * v.element_size()
    (output_dir / "model.safetensors.index.json").write_text(
        json.dumps({"metadata": {"total_size": total}, "weight_map": weight_map}, indent=2)
    )
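    # model.safetensors.index.json is the standard HF sharded-checkpoint index;
    # from_pretrained uses weight_map to locate each tensor's shard.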

    # Copy config + tokenizer from base model (remove quantization_config)
    base_cache = Path(snapshot_download(BASE_MODEL, allow_patterns=["*.py", "*.json", "tokenizer*", "*.model"]))
    for f in base_cache.iterdir():
        if f.is_file() and f.name != "model.safetensors.index.json":
            shutil.copy2(f, output_dir / f.name)
    cfg = json.loads((output_dir / "config.json").read_text())
    cfg.pop("quantization_config", None)
    cfg["torch_dtype"] = "bfloat16"
    (output_dir / "config.json").write_text(json.dumps(cfg, indent=2))
print(f"Done! Merged model saved to {output_dir} ({len(shards)} shards)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", required=True, help="Output directory for merged model")
    merge(parser.parse_args().output)