dcostenco commited on
Commit
51a8f2c
Β·
verified Β·
1 Parent(s): bf958b1

Add training/merge_4b_v43.py

Browse files
Files changed (1) hide show
  1. training/merge_4b_v43.py +202 -0
training/merge_4b_v43.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ merge_4b_v43.py β€” Merge mlx_lm LoRA adapter into Qwen3-4B base weights.
4
+
5
+ Does NOT use mlx_lm.fuse (broken β€” silently loses LoRA during GGUF conversion).
6
+ Instead: loads safetensors directly, computes delta = (alpha/rank) * B @ A per layer,
7
+ writes merged BF16 safetensors compatible with llama.cpp convert_hf_to_gguf.py.
8
+
9
+ Usage:
10
+ python3 merge_4b_v43.py \
11
+ --base Qwen/Qwen3-4B \
12
+ --adapter /tmp/4b_v43_adapter \
13
+ --out /tmp/4b_v43_merged
14
+
15
+ Requires: transformers, safetensors, torch (or mlx)
16
+ """
17
+ import argparse
18
+ import json
19
+ import shutil
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import torch
24
+ from safetensors.torch import load_file, save_file
25
+ from transformers import AutoTokenizer, AutoConfig
26
+
27
+
28
+ def load_adapter_config(adapter_dir: Path) -> float:
29
+ """Returns the final scale factor from mlx_lm adapter_config.json.
30
+
31
+ mlx_lm stores lora_parameters.scale = alpha/rank pre-computed.
32
+ Falls back to 20.0 (default for r=8, alpha=160) if not found.
33
+ """
34
+ cfg_path = adapter_dir / "adapter_config.json"
35
+ if cfg_path.exists():
36
+ cfg = json.loads(cfg_path.read_text())
37
+ lora_params = cfg.get("lora_parameters", {})
38
+ if "scale" in lora_params:
39
+ return float(lora_params["scale"])
40
+ print("WARN: lora_parameters.scale not found β€” defaulting to 20.0")
41
+ return 20.0
42
+
43
+
44
+ def find_safetensors(directory: Path) -> list[Path]:
45
+ files = sorted(directory.glob("*.safetensors"))
46
+ if not files:
47
+ print(f"ERROR: no .safetensors files in {directory}", file=sys.stderr)
48
+ sys.exit(1)
49
+ return files
50
+
51
+
52
+ def load_all_safetensors(directory: Path) -> dict[str, torch.Tensor]:
53
+ tensors = {}
54
+ for f in find_safetensors(directory):
55
+ tensors.update(load_file(str(f), device="cpu"))
56
+ return tensors
57
+
58
+
59
+ def merge(base_dir: Path, adapter_dir: Path, out_dir: Path) -> None:
60
+ scale = load_adapter_config(adapter_dir)
61
+ print(f" LoRA scale: {scale:.4f} (from adapter_config.json lora_parameters.scale)")
62
+
63
+ print("\nLoading base model weights...")
64
+ base = load_all_safetensors(base_dir)
65
+ print(f" {len(base)} tensors loaded from base")
66
+
67
+ print("Loading adapter weights...")
68
+ adapter = load_all_safetensors(adapter_dir)
69
+ # mlx_lm adapter keys look like: model.layers.0.self_attn.q_proj.lora_a
70
+ lora_keys = [k for k in adapter if k.endswith(".lora_a")]
71
+ print(f" {len(lora_keys)} LoRA A matrices found")
72
+
73
+ if not lora_keys:
74
+ print("ERROR: no lora_a keys found in adapter β€” wrong adapter format?", file=sys.stderr)
75
+ sys.exit(1)
76
+
77
+ merged = {k: v.clone() for k, v in base.items()}
78
+ applied = 0
79
+
80
+ for a_key in lora_keys:
81
+ b_key = a_key.replace(".lora_a", ".lora_b")
82
+ # Derive base weight key: strip .lora_a suffix, map to base weight name
83
+ # mlx_lm uses e.g. "model.layers.0.self_attn.q_proj.lora_a"
84
+ # base weight is "model.layers.0.self_attn.q_proj.weight"
85
+ base_key = a_key.replace(".lora_a", ".weight")
86
+
87
+ if b_key not in adapter:
88
+ print(f" WARN: missing lora_b for {a_key} β€” skipping")
89
+ continue
90
+ if base_key not in merged:
91
+ print(f" WARN: base key {base_key} not found β€” skipping")
92
+ continue
93
+
94
+ A = adapter[a_key].float() # mlx_lm: (in_features, rank)
95
+ B = adapter[b_key].float() # mlx_lm: (rank, out_features)
96
+ W = merged[base_key].float()
97
+
98
+ # mlx_lm stores weights as (in, out) β€” delta = scale * A @ B β†’ (in, out)
99
+ delta = scale * (A @ B)
100
+
101
+ if delta.shape != W.shape:
102
+ # Fallback: try transposed orientation
103
+ if delta.T.shape == W.shape:
104
+ delta = delta.T
105
+ else:
106
+ print(f" WARN: shape mismatch {delta.shape} vs {W.shape} for {base_key} β€” skipping")
107
+ continue
108
+
109
+ merged[base_key] = (W + delta).to(torch.bfloat16)
110
+ applied += 1
111
+
112
+ print(f"\n Applied {applied}/{len(lora_keys)} LoRA deltas")
113
+ if applied == 0:
114
+ print("ERROR: zero deltas applied β€” check adapter key format", file=sys.stderr)
115
+ sys.exit(1)
116
+ if applied < len(lora_keys) * 0.9:
117
+ print(f"ERROR: only {applied}/{len(lora_keys)} deltas applied (<90%) β€” likely key mismatch", file=sys.stderr)
118
+ sys.exit(1)
119
+
120
+ # Cast all to bfloat16 for GGUF conversion
121
+ merged = {k: v.to(torch.bfloat16) if v.is_floating_point() else v for k, v in merged.items()}
122
+
123
+ print(f"\nSaving merged model to {out_dir}...")
124
+ out_dir.mkdir(parents=True, exist_ok=True)
125
+
126
+ # Split into shards of ~4GB each (llama.cpp prefers <5GB shards)
127
+ SHARD_BYTES = 4 * 1024 ** 3
128
+ shard, shard_bytes, shard_idx = {}, 0, 0
129
+ for k, v in merged.items():
130
+ size = v.numel() * v.element_size()
131
+ if shard and shard_bytes + size > SHARD_BYTES:
132
+ fname = out_dir / f"model-{shard_idx:05d}-of-XXXXX.safetensors"
133
+ save_file(shard, str(fname))
134
+ print(f" Shard {shard_idx}: {fname.name} ({shard_bytes / 1e9:.2f} GB)")
135
+ shard, shard_bytes, shard_idx = {}, 0, shard_idx + 1
136
+ shard[k] = v
137
+ shard_bytes += size
138
+ if shard:
139
+ # Single-file model: use standard name
140
+ fname = out_dir / ("model.safetensors" if shard_idx == 0 else f"model-{shard_idx:05d}-of-XXXXX.safetensors")
141
+ save_file(shard, str(fname))
142
+ print(f" Shard {shard_idx}: {fname.name} ({shard_bytes / 1e9:.2f} GB)")
143
+
144
+ # Rename shards with correct total count
145
+ shards = sorted(out_dir.glob("model-*-of-XXXXX.safetensors"))
146
+ n = len(shards)
147
+ if n > 0:
148
+ index = {"metadata": {"total_size": sum(v.numel() * v.element_size() for v in merged.values())}, "weight_map": {}}
149
+ for i, p in enumerate(shards):
150
+ new_name = f"model-{i:05d}-of-{n:05d}.safetensors"
151
+ p.rename(out_dir / new_name)
152
+ tensors = load_file(str(out_dir / new_name), device="cpu")
153
+ for k in tensors:
154
+ index["weight_map"][k] = new_name
155
+ (out_dir / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))
156
+
157
+ # Copy tokenizer + config from base
158
+ for fname in ["config.json", "tokenizer.json", "tokenizer_config.json",
159
+ "special_tokens_map.json", "generation_config.json", "chat_template.jinja"]:
160
+ src = base_dir / fname
161
+ if src.exists():
162
+ shutil.copy(src, out_dir / fname)
163
+
164
+ print(f"\nβœ… Merge complete β†’ {out_dir}")
165
+ print(f" Applied {applied} LoRA deltas at scale {scale:.4f}")
166
+ print(f"\nNext: bash export_4b_v43_gguf.sh")
167
+
168
+
169
+ def main():
170
+ p = argparse.ArgumentParser()
171
+ p.add_argument("--base", type=Path, default=None,
172
+ help="Path to HF base model dir (or HF hub id β€” will be downloaded)")
173
+ p.add_argument("--adapter", type=Path, default=Path("/tmp/4b_v43_adapter"))
174
+ p.add_argument("--out", type=Path, default=Path("/tmp/4b_v43_merged"))
175
+ args = p.parse_args()
176
+
177
+ # Resolve base: try local cache first
178
+ if args.base is None:
179
+ from transformers.utils import cached_file
180
+ try:
181
+ # Trigger download/cache of config to locate cache dir
182
+ cfg_path = cached_file("Qwen/Qwen3-4B", "config.json")
183
+ args.base = Path(cfg_path).parent
184
+ print(f"Using cached base: {args.base}")
185
+ except Exception:
186
+ print("ERROR: --base not specified and Qwen/Qwen3-4B not in cache.", file=sys.stderr)
187
+ print("Run: python3 -c \"from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-4B')\"", file=sys.stderr)
188
+ sys.exit(1)
189
+ elif not args.base.exists():
190
+ print(f"ERROR: --base path not found: {args.base}", file=sys.stderr)
191
+ sys.exit(1)
192
+
193
+ if not args.adapter.exists():
194
+ print(f"ERROR: --adapter path not found: {args.adapter}", file=sys.stderr)
195
+ print("Run training first: bash train_4b_v43_local.sh", file=sys.stderr)
196
+ sys.exit(1)
197
+
198
+ merge(args.base, args.adapter, args.out)
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()