CompressedGemma committed on
Commit
baf2188
·
verified ·
1 Parent(s): 944a442

Apply weights to Gemma 4

Browse files

This is set to target the 31B model; expect the weight source files to be huge if you use it.

Files changed (1) hide show
  1. applyweights.py +410 -0
applyweights.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Apply generated SwiGLU MLP weights to a Gemma 4 31B safetensors model.
4
+ Layer files contain gate_proj.weight / up_proj.weight / down_proj.weight
5
+ as pre-computed delta tensors — fused via Shape-Contoured Fusion (SCF).
6
+
7
+ SCF replaces the old naive additive delta approach:
8
+ - down_proj : contoured multiplicative delta (dynamic_alpha * delta * W_existing)
9
+ - gate_proj : multiplicative gamma scaling (W * (1 + clamp(delta, +/-gamma_cap)))
10
+ - up_proj : intentionally unchanged (linear path, as in fuzer.py)
11
+
12
+ Gemma 4 31B interleaved attention: 5 SWA + 1 global per period (60 layers total).
13
+ Global layers (5, 11, 17, 23, 29, 35, 41, 47, 53, 59) may carry double-wide MLP tensors;
14
+ partial coverage is handled transparently via row/col clamping.
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import shutil
20
+ from pathlib import Path
21
+
22
+ import numpy as np
23
+ import torch
24
+ from safetensors.torch import load, save_file
25
+
26
# Tensor names looked up in every layer delta file and used as the trailing
# component of the matching model keys ("...layers.<idx>.mlp.<name>").
PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")

# Attention interleave pattern: starting at GLOBAL_LAYER_OFFSET, every
# INTERLEAVE_PERIOD-th layer is classified as a global-attention layer.
INTERLEAVE_PERIOD = 6
GLOBAL_LAYER_OFFSET = 5
30
+
31
+
32
def is_global_attention_layer(layer_idx: int) -> bool:
    """Return True when ``layer_idx`` is classified as a global-attention layer.

    A layer is global when it sits at ``GLOBAL_LAYER_OFFSET`` or a whole number
    of ``INTERLEAVE_PERIOD`` steps after it. Indices below the offset are never
    global.
    """
    if layer_idx < GLOBAL_LAYER_OFFSET:
        return False
    steps_past_first_global = layer_idx - GLOBAL_LAYER_OFFSET
    return steps_past_first_global % INTERLEAVE_PERIOD == 0
37
+
38
+
39
def detect_key_prefix(tensor_keys, layer_idx: int, proj: str) -> str:
    """Locate the key prefix that precedes ``layers.<idx>.mlp.<proj>``.

    Gemma 4 is a VLM: a key containing ``language_model`` is always preferred
    over any other match (e.g. a vision tower).

    Args:
        tensor_keys: Iterable of tensor key strings from the target checkpoint.
        layer_idx: Index of the decoder layer being looked up.
        proj: Projection tensor name, e.g. ``"gate_proj.weight"``.

    Returns:
        Everything that precedes ``layers.<idx>.mlp.<proj>`` in the chosen key
        (possibly the empty string), or ``"model.language_model.model."`` when
        no key matches.
    """
    suffix = f"layers.{layer_idx}.mlp.{proj}"
    # Only accept the suffix on a path-component boundary. A bare endswith()
    # would also accept keys such as "...sublayers.2.mlp.<proj>" and return a
    # prefix that ends in the middle of a component name.
    dotted_suffix = "." + suffix
    matches = [
        key for key in tensor_keys
        if key == suffix or key.endswith(dotted_suffix)
    ]

    for key in matches:
        if "language_model" in key:
            return key[: -len(suffix)]
    if matches:
        return matches[0][: -len(suffix)]
    return "model.language_model.model."
52
+
53
+
54
def discover_generated_layers(weights_dir: Path) -> dict:
    """Map layer index -> delta file for every ``layer_<N>*.safetensors`` found.

    The index is the integer between the first and second underscore of the
    file stem. Files whose stem does not yield an integer there are ignored.
    Files are visited in sorted path order, so a later file with the same index
    replaces an earlier one.
    """
    found = {}
    for candidate in sorted(weights_dir.glob("layer_*.safetensors")):
        stem_parts = candidate.stem.split("_")
        try:
            index = int(stem_parts[1])
        except (IndexError, ValueError):
            # Not a numbered layer file; skip it.
            continue
        found[index] = candidate
    return found
63
+
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Shape-Contoured Fusion applied to pre-computed delta tensors
67
+ # ---------------------------------------------------------------------------
68
+
69
def fuse_layer_deltas(
    layer_idx: int,
    gate_w: torch.Tensor,   # float32, modified in-place
    up_w: torch.Tensor,     # float32, intentionally NOT modified
    down_w: torch.Tensor,   # float32, modified in-place
    new_weights: dict,
    args: argparse.Namespace,
) -> None:
    """
    Apply SCF to one layer using pre-computed delta tensors.

    down_proj -- contoured update:
        W += dynamic_alpha * delta * W over the region covered by the delta.
        The delta is scaled element-wise by the existing weight, and
        dynamic_alpha is args.alpha scaled by (variance of the covered region
        / (1 / fan_in)), clipped to [0.1 * alpha, 10 * alpha].

    gate_proj -- multiplicative gamma:
        W *= 1 + clamp(delta, -gamma_cap, +gamma_cap) over the covered region.

    up_proj -- unchanged:
        up_w is accepted for signature symmetry but never read or written.

    Args:
        layer_idx: Layer index, used only in warning messages.
        gate_w: 2-D gate_proj weight, updated in place when a gate delta exists.
        up_w: up_proj weight; ignored.
        down_w: 2-D down_proj weight, updated in place when a down delta exists.
        new_weights: Mapping of projection name -> delta tensor. Projections
            missing from the mapping are left untouched.
        args: Namespace providing ``alpha`` and ``gamma_cap``.

    A delta smaller than the target weight only updates the top-left
    ``rows x cols`` region it covers, and a warning is printed.
    """

    # down_proj: contoured update
    if "down_proj.weight" in new_weights:
        delta_down = new_weights["down_proj.weight"].float()
        nr = min(delta_down.shape[0], down_w.shape[0])
        nc = min(delta_down.shape[1], down_w.shape[1])
        region = down_w[:nr, :nc]  # view: in-place ops below write into down_w

        dynamic_alpha = float(args.alpha)
        scaled = float("nan")
        if region.numel() > 1:
            fan_in = down_w.shape[1]
            expected_var = 1.0 / fan_in
            down_var = region.var().item()
            scaled = args.alpha * (down_var / (expected_var + 1e-8))

        if np.isnan(scaled):
            # The variance is undefined for fewer than two elements (or when the
            # weights already contain NaN). Clipping NaN would hand a NaN alpha
            # to the update and corrupt the weights, so use the base alpha.
            if region.numel() > 0:
                print(f"  [warn] Layer {layer_idx}: down_proj variance undefined -- "
                      f"using base alpha {args.alpha}")
        else:
            dynamic_alpha = float(np.clip(
                scaled,
                args.alpha * 0.1,
                args.alpha * 10.0,
            ))

        # The right-hand side is fully evaluated before the in-place add, so
        # this equals W = W + dynamic_alpha * delta * W.
        region.add_(dynamic_alpha * delta_down[:nr, :nc] * region)

        if nr < down_w.shape[0] or nc < down_w.shape[1]:
            print(f"  [warn] Layer {layer_idx}: down_proj delta covers "
                  f"{nr}x{nc} of {down_w.shape[0]}x{down_w.shape[1]} -- partial fusion")

    # gate_proj: multiplicative gamma
    if "gate_proj.weight" in new_weights:
        delta_gate = new_weights["gate_proj.weight"].float()
        nr = min(delta_gate.shape[0], gate_w.shape[0])
        nc = min(delta_gate.shape[1], gate_w.shape[1])

        gamma = 1.0 + delta_gate[:nr, :nc].clamp(-args.gamma_cap, args.gamma_cap)
        gate_w[:nr, :nc].mul_(gamma)

        # Same partial-coverage report as down_proj, so a truncated gate delta
        # does not go unnoticed.
        if nr < gate_w.shape[0] or nc < gate_w.shape[1]:
            print(f"  [warn] Layer {layer_idx}: gate_proj delta covers "
                  f"{nr}x{nc} of {gate_w.shape[0]}x{gate_w.shape[1]} -- partial fusion")

    # up_proj: intentionally untouched -- linear path must stay unchanged
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # Single-file apply
130
+ # ---------------------------------------------------------------------------
131
+
132
def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse layer deltas into a single-file safetensors checkpoint.

    Args:
        model_path: Path to the ``.safetensors`` file holding the whole model.
        output_dir: Directory that receives the fused file (same file name).
            It must already exist unless ``args.dry_run`` is set.
        layer_files: Mapping of layer index -> delta file path.
        args: Namespace providing ``dry_run``, ``alpha`` and ``gamma_cap``.

    Returns:
        Number of layers fused (or that would be fused, in dry-run mode).

    Raises:
        RuntimeError: If at least one layer was skipped and none was fused.
    """
    dry_run = args.dry_run
    print(f"\n[model] Processing file: {model_path.name}")

    with open(model_path, "rb") as f:
        tensors = load(f.read())

    fused = 0
    skipped = 0

    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"

        with open(layer_path, "rb") as f:
            new_weights = load(f.read())

        if not any(k in new_weights for k in PROJ_KEYS):
            print(f"  [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
                  f"Got: {list(new_weights.keys())}")
            skipped += 1
            continue

        # All three projections must be present in the model for this layer,
        # otherwise the layer is skipped as a whole.
        proj_model_keys = {}
        for proj in PROJ_KEYS:
            prefix = detect_key_prefix(tensors.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in tensors:
                print(f"  [skip] Key not found in model: {model_key!r}")
                break
            proj_model_keys[proj] = model_key

        if len(proj_model_keys) != len(PROJ_KEYS):
            skipped += 1
            continue

        gate_key = proj_model_keys["gate_proj.weight"]
        up_key = proj_model_keys["up_proj.weight"]
        down_key = proj_model_keys["down_proj.weight"]

        if not dry_run:
            orig_gate = tensors[gate_key]
            orig_down = tensors[down_key]

            # One float32 working copy per modified tensor. up_proj is never
            # modified by the fusion, so it is passed through without copying.
            # Nothing is copied at all in dry-run mode.
            gate_w = orig_gate.to(torch.float32, copy=True)
            down_w = orig_down.to(torch.float32, copy=True)

            fuse_layer_deltas(layer_idx, gate_w, tensors[up_key], down_w, new_weights, args)
            tensors[gate_key] = gate_w.to(orig_gate.dtype)
            # up_proj unchanged by SCF -- no write-back needed
            tensors[down_key] = down_w.to(orig_down.dtype)

        fused += 1
        print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
              f"  gate*gamma + down contoured (up unchanged)")

    if skipped > 0 and fused == 0:
        raise RuntimeError(
            f"No layers were fused -- all {skipped} layer(s) were skipped.\n"
            f"Sample model keys: {list(tensors.keys())[:4]}"
        )
    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) skipped, {fused} fused.")

    if not dry_run:
        out_path = output_dir / model_path.name
        # Write the "format" header entry that Hugging Face tooling stores on
        # save; without it some loader versions reject the file.
        save_file(tensors, str(out_path), metadata={"format": "pt"})
        print(f"  Saved -> {out_path.resolve()}")

    return fused
204
+
205
+
206
+ # ---------------------------------------------------------------------------
207
+ # Sharded apply
208
+ # ---------------------------------------------------------------------------
209
+
210
def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse layer deltas into a sharded safetensors checkpoint.

    The shard index (``model.safetensors.index.json``) is used to find which
    shard holds each projection. Unless ``args.dry_run`` is set, ``output_dir``
    is replaced by a full copy of ``model_dir`` and every shard that holds a
    targeted projection is then rewritten in the copy.

    Args:
        model_dir: Directory containing the shards and the index file.
        output_dir: Destination directory. Any existing content is deleted.
        layer_files: Mapping of layer index -> delta file path.
        args: Namespace providing ``dry_run``, ``alpha`` and ``gamma_cap``.

    Returns:
        Number of distinct layers with at least one projection processed.

    Raises:
        FileNotFoundError: If the shard index is missing.
        ValueError: If ``output_dir`` overlaps ``model_dir`` (non-dry-run only).
        RuntimeError: If no projection of any layer matched the index.
    """
    dry_run = args.dry_run
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        raise FileNotFoundError(f"Sharded index missing: {index_path}")

    if not dry_run:
        # output_dir is wiped and re-created from model_dir further down. If the
        # two overlap, that would delete the source model (same dir or a parent
        # of it) or copy a directory into itself (child of it). Fail before
        # doing any work.
        src_root = model_dir.resolve()
        dst_root = output_dir.resolve()
        if (
            dst_root == src_root
            or dst_root in src_root.parents
            or src_root in dst_root.parents
        ):
            raise ValueError(
                f"--output ({dst_root}) overlaps the source model directory "
                f"({src_root}); choose a separate output directory."
            )

    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    # Per-projection fusion plan keyed by shard.
    # Each entry: (layer_idx, proj, model_key, delta_tensor, layer_type).
    # A layer whose projections span multiple shards will appear in several
    # shard buckets -- one entry per projection -- instead of being skipped.
    fusion_plan: dict = {}
    skipped = 0

    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"

        with open(layer_path, "rb") as f:
            new_weights = load(f.read())

        if not any(k in new_weights for k in PROJ_KEYS):
            print(f"  [skip] Layer {layer_idx}: none of {PROJ_KEYS} found. "
                  f"Got: {list(new_weights.keys())}")
            skipped += 1
            continue

        proj_registered = 0
        for proj in PROJ_KEYS:
            if proj not in new_weights:
                continue
            prefix = detect_key_prefix(weight_map.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in weight_map:
                print(f"  [skip] Layer {layer_idx}: {model_key!r} not in weight_map")
                continue
            shard_name = weight_map[model_key]
            fusion_plan.setdefault(shard_name, []).append(
                (layer_idx, proj, model_key, new_weights[proj], layer_type)
            )
            proj_registered += 1

        if proj_registered == 0:
            skipped += 1

    if not fusion_plan:
        sample = list(weight_map.keys())[:6]
        raise RuntimeError(
            f"No layers matched in weight_map. Sample keys: {sample}"
        )

    if not dry_run:
        if output_dir.exists():
            shutil.rmtree(output_dir)
        shutil.copytree(model_dir, output_dir)

    fused_layer_idxs: set = set()
    # Only these projections are ever modified by fuse_layer_deltas.
    mutable_projs = ("gate_proj.weight", "down_proj.weight")

    for shard_name, ops in sorted(fusion_plan.items()):
        shard_src = model_dir / shard_name
        shard_dst = output_dir / shard_name

        with open(shard_src, "rb") as f:
            tensors = load(f.read())

        # Re-group by layer so fuse_layer_deltas is called once per layer per shard.
        by_layer: dict = {}
        for layer_idx, proj, model_key, delta, layer_type in ops:
            by_layer.setdefault(layer_idx, []).append((proj, model_key, delta, layer_type))

        for layer_idx, proj_ops in sorted(by_layer.items()):
            layer_type = proj_ops[0][3]

            if not dry_run:
                # Deltas restricted to projections whose tensors live in this
                # shard. fuse_layer_deltas gates every block on presence in
                # this mapping, so absent projections are never touched.
                partial_new_weights = {proj: delta for proj, _, delta, _ in proj_ops}
                key_of = {proj: model_key for proj, model_key, _, _ in proj_ops}

                # float32 working copies only for projections that can change;
                # up_proj and absent slots get an empty placeholder that the
                # fusion never reads.
                working = {
                    proj: tensors[key_of[proj]].to(torch.float32, copy=True)
                    for proj in mutable_projs
                    if proj in key_of
                }
                placeholder = torch.empty(0)

                fuse_layer_deltas(
                    layer_idx,
                    working.get("gate_proj.weight", placeholder),
                    placeholder,
                    working.get("down_proj.weight", placeholder),
                    partial_new_weights,
                    args,
                )
                for proj, fused_w in working.items():
                    model_key = key_of[proj]
                    # Cast back to the dtype the shard stored originally.
                    tensors[model_key] = fused_w.to(tensors[model_key].dtype)
                # up_proj: SCF intentionally leaves it unchanged

            fused_layer_idxs.add(layer_idx)
            proj_names = [p.split(".")[0] for p, *_ in proj_ops]
            print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
                  f"  ({', '.join(proj_names)} in this shard)")

        if not dry_run:
            # Write the "format" header entry that Hugging Face tooling stores
            # on save; without it some loader versions reject the shard.
            save_file(tensors, str(shard_dst), metadata={"format": "pt"})
            print(f"  [ok] Saved shard {shard_name}  ({len(by_layer)} layer(s))")

    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) fully skipped, "
              f"{len(fused_layer_idxs)} unique layer(s) fused.")

    return len(fused_layer_idxs)
329
+
330
+
331
+ # ---------------------------------------------------------------------------
332
+ # Entry point
333
+ # ---------------------------------------------------------------------------
334
+
335
def main():
    """Command-line entry point.

    Parses the CLI options, collects the layer delta files (optionally filtered
    by ``--layers``), dispatches to the single-file or sharded fusion routine
    depending on what ``--model`` points at, and finally copies ``config.json``
    next to the output when it exists and this is not a dry run.
    """
    cli = argparse.ArgumentParser(
        description="Apply delta weights to a model via Shape-Contoured Fusion."
    )
    cli.add_argument("--model", required=True)
    cli.add_argument("--weights", required=True)
    cli.add_argument("--output", required=True)
    cli.add_argument("--layers", type=int, nargs="+", default=None)
    cli.add_argument("--dry-run", action="store_true")
    cli.add_argument("--alpha", type=float, default=0.02,
                     help="down-proj variance scale multiplier (default: 0.02)")
    cli.add_argument("--gamma-cap", type=float, default=0.05,
                     help="max fractional gate_proj adjustment (default: 0.05)")
    args = cli.parse_args()

    model_path = Path(args.model)
    weights_dir = Path(args.weights)
    output_dir = Path(args.output)

    layer_files = discover_generated_layers(weights_dir)
    if not layer_files:
        raise FileNotFoundError(
            f"No layer_*.safetensors files found in: {weights_dir.resolve()}"
        )

    if args.layers is not None:
        # Keep only requested indices that actually have a delta file.
        layer_files = {
            wanted: layer_files[wanted]
            for wanted in args.layers
            if wanted in layer_files
        }
        if not layer_files:
            available = sorted(discover_generated_layers(weights_dir).keys())
            raise ValueError(f"--layers filter empty. Available: {available}")

    print(f"[info] Found {len(layer_files)} layer file(s): indices {sorted(layer_files.keys())}")
    print(f"[info] SCF params: alpha={args.alpha}, gamma_cap={args.gamma_cap}")

    writing = not args.dry_run
    if writing:
        output_dir.mkdir(parents=True, exist_ok=True)

    if model_path.is_file() and model_path.suffix == ".safetensors":
        # --model points directly at one checkpoint file.
        apply_single_file(model_path, output_dir, layer_files, args)
    elif model_path.is_dir():
        single = model_path / "model.safetensors"
        index = model_path / "model.safetensors.index.json"

        if index.exists():
            # A shard index takes precedence over a lone model.safetensors.
            apply_sharded(model_path, output_dir, layer_files, args)
        elif single.exists():
            if writing:
                # Mirror every sibling of the checkpoint into the output dir.
                for entry in model_path.iterdir():
                    if entry.name == "model.safetensors":
                        continue
                    target = output_dir / entry.name
                    if entry.is_dir():
                        shutil.copytree(entry, target, dirs_exist_ok=True)
                    else:
                        shutil.copy2(entry, target)
            apply_single_file(single, output_dir, layer_files, args)
        else:
            raise FileNotFoundError(
                f"No model.safetensors or model.safetensors.index.json in {model_path}"
            )
    else:
        raise FileNotFoundError(f"--model not found: {model_path}")

    config_home = model_path if model_path.is_dir() else model_path.parent
    config_path = config_home / "config.json"
    if writing and config_path.exists():
        shutil.copy2(config_path, output_dir / "config.json")
        print("  [ok] Copied config.json (activation unchanged).")
407
+
408
+
409
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()