"""Merge LoRA adapter into base model weights.
Usage:
pip install torch transformers safetensors tqdm
python merge.py --output ./merged_model
Loads the MXFP4 base model, dequantizes to bf16, applies LoRA deltas, saves merged model.
Requires ~300GB RAM. No GPU needed.
"""
import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM
BASE_MODEL = "openai/gpt-oss-120b"
ADAPTER_REPO = "LightningRodLabs/Trump-Forecaster"


def merge(output_dir: str):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Download adapter
    print("Downloading adapter...")
    adapter_dir = Path(snapshot_download(ADAPTER_REPO))
    adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
    scaling = adapter_config["lora_alpha"] / adapter_config["r"]
    adapter_weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
    print(f"Adapter: {len(adapter_weights)} keys, scaling={scaling}")

    # Load base model (dequantizes MXFP4 → bf16)
    print("Loading base model (this takes a while — ~240GB bf16)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True,
    )
    state_dict = base_model.state_dict()
    del base_model

    # Group LoRA A/B pairs
    lora_pairs = {}
    for key, tensor in adapter_weights.items():
        clean = key.replace("base_model.model.", "", 1)
        if ".lora_A.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_A.weight", ""), {})["A"] = tensor
        elif ".lora_B.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_B.weight", ""), {})["B"] = tensor

    # Map adapter keys → base model keys + merge operation
    # Adapter uses Tinker naming; HF transformers uses different names for this model:
    #   attn → self_attn, w1(gate)/w3(up) → gate_up_proj (interleaved), w2 → down_proj
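    # Illustrative examples of the mapping (hypothetical key names, derived
    # from the branch logic below rather than from an actual checkpoint):
    #   "model.layers.0.attn.q_proj"     → "model.layers.0.self_attn.q_proj.weight"
    #   "model.layers.0.mlp.experts.w1"  → even columns of "...mlp.experts.gate_up_proj"
    #   "model.layers.0.mlp.experts.w3"  → odd columns of "...mlp.experts.gate_up_proj"
    #   "model.layers.0.mlp.experts.w2"  → "model.layers.0.mlp.experts.down_proj" (transposed)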
    base_key_ops = {}
    for adapter_path in lora_pairs:
        if "unembed_tokens" in adapter_path:
            base_key_ops.setdefault("lm_head.weight", []).append(("add", adapter_path))
        elif ".attn." in adapter_path:
            base_key = adapter_path.replace(".attn.", ".self_attn.") + ".weight"
            base_key_ops.setdefault(base_key, []).append(("add", adapter_path))
        elif ".mlp.experts.w1" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w1")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("even_t", adapter_path))
        elif ".mlp.experts.w3" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w3")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("odd_t", adapter_path))
        elif ".mlp.experts.w2" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w2")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.down_proj", []).append(("add_t", adapter_path))

    # Apply LoRA deltas
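    # Shape note (inferred from the ops below, not stated in the original):
    # for expert weights, A is [n_experts, r, d_in] and B is [n_experts, d_out, r],
    # so torch.matmul(B, A) batches over experts and yields [n_experts, d_out, d_in].
    # The HF gpt-oss checkpoint stores experts as [n_experts, d_in, d_out], hence
    # the transpose(1, 2) in the *_t branches; gate and up are interleaved along
    # the last axis (gate = even columns, up = odd columns).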
    for base_key, ops in tqdm(sorted(base_key_ops.items()), desc="Merging LoRA"):
        w = state_dict[base_key].float()
        for op_type, adapter_path in ops:
            A = lora_pairs[adapter_path]["A"].float()
            B = lora_pairs[adapter_path]["B"].float()
            delta = torch.matmul(B, A) * scaling  # standard LoRA update: (alpha / r) * B @ A
            if op_type == "add":
                w += delta
            elif op_type == "even_t":
                w[:, :, ::2] += delta.transpose(1, 2)
            elif op_type == "odd_t":
                w[:, :, 1::2] += delta.transpose(1, 2)
            elif op_type == "add_t":
                w += delta.transpose(1, 2)
        state_dict[base_key] = w.to(torch.bfloat16)

    # Save sharded safetensors
    print(f"Saving to {output_dir}...")
    max_shard = 5 * 1024**3  # ~5 GiB per shard
    shards, current, size = [], {}, 0
    for k, v in state_dict.items():
        nbytes = v.numel() * v.element_size()
        if size + nbytes > max_shard and current:
            shards.append(current)
            current, size = {}, 0
        current[k] = v
        size += nbytes
    if current:
        shards.append(current)
    weight_map, total = {}, 0
    for i, shard in enumerate(shards):
        fname = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
        # metadata added here: transformers' loader checks for format="pt" on safetensors files
        save_file(shard, str(output_dir / fname), metadata={"format": "pt"})
        for k, v in shard.items():
            weight_map[k] = fname
            total += v.numel() * v.element_size()
    (output_dir / "model.safetensors.index.json").write_text(
        json.dumps({"metadata": {"total_size": total}, "weight_map": weight_map}, indent=2)
    )
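    # The index file written above is what lets transformers' from_pretrained
    # map each weight name to its shard when loading the merged model.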

    # Copy config + tokenizer from base model (remove quantization_config)
    base_cache = Path(snapshot_download(BASE_MODEL, allow_patterns=["*.py", "*.json", "tokenizer*", "*.model"]))
    for f in base_cache.iterdir():
        if f.is_file() and f.name != "model.safetensors.index.json":
            shutil.copy2(f, output_dir / f.name)
    cfg = json.loads((output_dir / "config.json").read_text())
    cfg.pop("quantization_config", None)
    cfg["torch_dtype"] = "bfloat16"
    (output_dir / "config.json").write_text(json.dumps(cfg, indent=2))
    print(f"Done! Merged model saved to {output_dir} ({len(shards)} shards)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", required=True, help="Output directory for merged model")
    merge(parser.parse_args().output)
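
# Loading the merged model afterwards (a sketch using the standard HF API,
# not part of the original script; the path matches the --output example above):
#   import torch
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("./merged_model", torch_dtype=torch.bfloat16)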