Bturtel committed
Commit db9ac85 · verified · 1 Parent(s): b92728f

Upload folder using huggingface_hub

Files changed (3)
  1. adapter_config.json +31 -0
  2. adapter_model.safetensors +3 -0
  3. merge.py +134 -0
adapter_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "openai/gpt-oss-120b",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": false,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 32,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": "all-linear",
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
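
Side note on this config: with r = 32 and lora_alpha = 32, the merge scaling lora_alpha / r works out to 1.0, which is exactly what merge.py computes below. The adapter's tensor keys follow Tinker-style naming rather than stock HF/PEFT module names (see the mapping comment in merge.py), which is why merging is done with the custom script instead of attaching the adapter via vanilla PEFT. A minimal sketch for inspecting the adapter's tensors before merging (assumes the file has been downloaded locally; uses the standard safetensors safe_open API):

    from safetensors import safe_open

    # Lazily open the adapter and list tensor names/shapes without
    # loading the full multi-GB file into memory.
    with safe_open("adapter_model.safetensors", framework="pt") as f:
        for name in f.keys():
            print(name, f.get_slice(name).get_shape())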
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4411909ab174c93964b34b4a7a26e224b4e97645d607b9fd1d5d1591bc1a4b3
+ size 5257620504
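
The weights ship as a Git LFS pointer, so the oid above doubles as an integrity check for a downloaded copy. A small standard-library sketch (the local filename is an assumption):

    import hashlib

    EXPECTED = "b4411909ab174c93964b34b4a7a26e224b4e97645d607b9fd1d5d1591bc1a4b3"

    # Stream in 1 MiB chunks so the ~5.3 GB file never sits fully in memory.
    h = hashlib.sha256()
    with open("adapter_model.safetensors", "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == EXPECTED, "checksum mismatch"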
merge.py ADDED
@@ -0,0 +1,134 @@
+ """Merge LoRA adapter into base model weights.
+
+ Usage:
+     pip install torch transformers safetensors tqdm
+     python merge.py --output ./merged_model
+
+ Loads the MXFP4 base model, dequantizes to bf16, applies LoRA deltas, saves merged model.
+ Requires ~300GB RAM. No GPU needed.
+ """
+
+ import argparse
+ import json
+ import shutil
+ from pathlib import Path
+
+ import torch
+ from huggingface_hub import snapshot_download
+ from safetensors.torch import load_file, save_file
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM
+
+ BASE_MODEL = "openai/gpt-oss-120b"
+ ADAPTER_REPO = "LightningRodLabs/Trump-Forecaster"
+
+
+ def merge(output_dir: str):
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Download adapter
+     print("Downloading adapter...")
+     adapter_dir = Path(snapshot_download(ADAPTER_REPO))
+     adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
+     scaling = adapter_config["lora_alpha"] / adapter_config["r"]
+     adapter_weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
+     print(f"Adapter: {len(adapter_weights)} keys, scaling={scaling}")
+
+     # Load base model (dequantizes MXFP4 → bf16)
+     print("Loading base model (this takes a while — ~240GB bf16)...")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True,
+     )
+     state_dict = base_model.state_dict()
+     del base_model
+
+     # Group LoRA A/B pairs
+     lora_pairs = {}
+     for key, tensor in adapter_weights.items():
+         clean = key.replace("base_model.model.", "", 1)
+         if ".lora_A.weight" in clean:
+             lora_pairs.setdefault(clean.replace(".lora_A.weight", ""), {})["A"] = tensor
+         elif ".lora_B.weight" in clean:
+             lora_pairs.setdefault(clean.replace(".lora_B.weight", ""), {})["B"] = tensor
+
+     # Map adapter keys → base model keys + merge operation
+     # Adapter uses Tinker naming; HF transformers uses different names for this model:
+     #   attn → self_attn, w1(gate)/w3(up) → gate_up_proj (interleaved), w2 → down_proj
+     base_key_ops = {}
+     for adapter_path in lora_pairs:
+         if "unembed_tokens" in adapter_path:
+             base_key_ops.setdefault("lm_head.weight", []).append(("add", adapter_path))
+         elif ".attn." in adapter_path:
+             base_key = adapter_path.replace(".attn.", ".self_attn.") + ".weight"
+             base_key_ops.setdefault(base_key, []).append(("add", adapter_path))
+         elif ".mlp.experts.w1" in adapter_path:
+             prefix = adapter_path.split(".mlp.experts.w1")[0]
+             base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("even_t", adapter_path))
+         elif ".mlp.experts.w3" in adapter_path:
+             prefix = adapter_path.split(".mlp.experts.w3")[0]
+             base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("odd_t", adapter_path))
+         elif ".mlp.experts.w2" in adapter_path:
+             prefix = adapter_path.split(".mlp.experts.w2")[0]
+             base_key_ops.setdefault(prefix + ".mlp.experts.down_proj", []).append(("add_t", adapter_path))
+
+     # Apply LoRA deltas
+     for base_key, ops in tqdm(sorted(base_key_ops.items()), desc="Merging LoRA"):
+         w = state_dict[base_key].float()
+         for op_type, adapter_path in ops:
+             A = lora_pairs[adapter_path]["A"].float()
+             B = lora_pairs[adapter_path]["B"].float()
+             delta = torch.matmul(B, A) * scaling
+             if op_type == "add":
+                 w += delta
+             elif op_type == "even_t":
+                 w[:, :, ::2] += delta.transpose(1, 2)
+             elif op_type == "odd_t":
+                 w[:, :, 1::2] += delta.transpose(1, 2)
+             elif op_type == "add_t":
+                 w += delta.transpose(1, 2)
+         state_dict[base_key] = w.to(torch.bfloat16)
+
+     # Save sharded safetensors
+     print(f"Saving to {output_dir}...")
+     max_shard = 5 * 1024**3
+     shards, current, size = [], {}, 0
+     for k, v in state_dict.items():
+         nbytes = v.numel() * v.element_size()
+         if size + nbytes > max_shard and current:
+             shards.append(current)
+             current, size = {}, 0
+         current[k] = v
+         size += nbytes
+     if current:
+         shards.append(current)
+
+     weight_map, total = {}, 0
+     for i, shard in enumerate(shards):
+         fname = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
+         save_file(shard, str(output_dir / fname))
+         for k, v in shard.items():
+             weight_map[k] = fname
+             total += v.numel() * v.element_size()
+
+     (output_dir / "model.safetensors.index.json").write_text(
+         json.dumps({"metadata": {"total_size": total}, "weight_map": weight_map}, indent=2)
+     )
+
+     # Copy config + tokenizer from base model (remove quantization_config)
+     base_cache = Path(snapshot_download(BASE_MODEL, allow_patterns=["*.py", "*.json", "tokenizer*", "*.model"]))
+     for f in base_cache.iterdir():
+         if f.is_file() and f.name != "model.safetensors.index.json":
+             shutil.copy2(f, output_dir / f.name)
+     cfg = json.loads((output_dir / "config.json").read_text())
+     cfg.pop("quantization_config", None)
+     cfg["torch_dtype"] = "bfloat16"
+     (output_dir / "config.json").write_text(json.dumps(cfg, indent=2))
+
+     print(f"Done! Merged model saved to {output_dir} ({len(shards)} shards)")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--output", required=True, help="Output directory for merged model")
+     merge(parser.parse_args().output)
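
Once the script finishes, the output directory is a plain bf16 transformers checkpoint: sharded safetensors plus the config and tokenizer copied from the base model, with quantization_config stripped. A hedged sketch of loading it for inference; the path matches the usage example in the docstring, while device placement and the sample prompt are assumptions, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # The merged folder is self-contained, so transformers loads it
    # as an ordinary bf16 model with no dequantization step.
    tok = AutoTokenizer.from_pretrained("./merged_model")
    model = AutoModelForCausalLM.from_pretrained(
        "./merged_model", torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True,
    )
    inputs = tok("Hello", return_tensors="pt").to(model.device)
    print(tok.decode(model.generate(**inputs, max_new_tokens=32)[0]))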