CompressedGemma committed on
Commit
eb88431
·
verified ·
1 Parent(s): 144b8fe

Generate functional neurons from source neuron

Browse files
Files changed (1) hide show
  1. genneuron.py +456 -0
genneuron.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate new neurons by sampling in functional parameter space.
4
+
5
+ Each neuron is a piecewise-linear function fully described by 6 values:
6
+ (boundary_x1, boundary_x2, left_slope, mid_slope, right_slope, y_boundary2)
7
+
8
+ We extract these from your existing neurons, fit a distribution over them,
9
+ sample new combinations, and reconstruct valid W1/b1/W2/b2 for each.
10
+ """
11
+
12
+ import numpy as np
13
+ import torch
14
+ from safetensors.torch import load_file, save_file
15
+ from pathlib import Path
16
+ import json
17
+ import argparse
18
+ import os
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Config
22
+ # ---------------------------------------------------------------------------
23
+
24
# Where the source neurons are read from.
NEURON_SOURCE = "multi"  # either "single" or "multi"
SINGLE_FILE = "test_mlp_hf/model.safetensors"  # checkpoint used for "single"
MULTI_DIR = "source_llm_neurons"  # directory scanned for "multi"

# What gets produced.
SINGLE_BOUNDARY_MODE = True  # True: single-boundary neurons (2 active units); False: double-boundary (3 active units)
N_GENERATE = 500  # neurons requested per strategy
OUTPUT_DIR = "generated_neurons"
RANDOM_SEED = 42

# How new functional parameters are sampled:
#   "gaussian"    - fit mean/covariance to the existing neurons, sample from N(mu, sigma)
#   "interpolate" - convex combinations of pairs of existing neurons
#   "grid"        - systematic grid over the observed parameter ranges
#   "all"         - run all three of the above
STRATEGY = "all"
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # 1. Load existing neurons
43
+ # ---------------------------------------------------------------------------
44
+
45
def load_neurons(source, single_file, multi_dir):
    """Load source neurons as dicts of float32 NumPy arrays.

    Args:
        source: ``"single"`` to read one checkpoint from ``single_file``, or
            ``"multi"`` to read every ``neuron_*.safetensors`` file found in
            ``multi_dir`` (in sorted path order).
        single_file: Path of the checkpoint used when ``source == "single"``.
        multi_dir: Directory scanned when ``source == "multi"``.

    Returns:
        A list of dicts with keys ``"W1"``, ``"b1"``, ``"W2"``, ``"b2"``,
        taken from the ``layer1``/``layer2`` weight and bias tensors. The list
        is empty when ``multi_dir`` contains no matching file.

    Raises:
        ValueError: If ``source`` is neither ``"single"`` nor ``"multi"``.
            (Previously an unknown value silently produced an empty list.)
    """
    tensor_names = (
        ("W1", "layer1.weight"),
        ("b1", "layer1.bias"),
        ("W2", "layer2.weight"),
        ("b2", "layer2.bias"),
    )

    def read_checkpoint(path):
        # Shared by both modes so the key mapping lives in exactly one place.
        tensors = load_file(str(path))
        return {
            short: tensors[full].float().numpy()
            for short, full in tensor_names
        }

    if source == "single":
        return [read_checkpoint(single_file)]
    if source == "multi":
        return [
            read_checkpoint(path)
            for path in sorted(Path(multi_dir).glob("neuron_*.safetensors"))
        ]
    raise ValueError(
        f"Unknown neuron source: {source!r} (expected 'single' or 'multi')"
    )
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # 2. Extract functional parameters from raw weights
69
+ # ---------------------------------------------------------------------------
70
+
71
def weights_to_functional(W1, b1, W2, b2, x_probe_range=(-2.0, 2.0), n_probe=200000):
    """Estimate the six functional parameters of a 1-in/1-out ReLU MLP.

    The network ``y = relu(x @ W1.T + b1) @ W2.T + b2`` is sampled on a
    uniform grid, and the two strongest kinks (peaks of the absolute second
    derivative) are taken as the boundaries.

    Args:
        W1, b1, W2, b2: Layer weights/biases; the network must map one scalar
            input to one scalar output.
        x_probe_range: ``(low, high)`` interval that is probed.
        n_probe: Number of evenly spaced probe points.

    Returns:
        Dict with ``boundary_x1 <= boundary_x2``, the mean slopes left of,
        between and right of the boundaries (``left_slope``, ``mid_slope``,
        ``right_slope``) and ``y_boundary2``, the network output at the probe
        point chosen as ``boundary_x2``.

    Raises:
        ValueError: If the network output is not a single scalar per input
            (the batched result cannot be stored in the 1-D output buffer).
    """
    xs = np.linspace(x_probe_range[0], x_probe_range[1], n_probe)

    # Evaluate the probe grid in batches instead of one Python call per point.
    # Inputs are cast to float32 exactly like the scalar path did; chunking
    # keeps the hidden activation buffer small even for wide hidden layers.
    ys = np.empty(n_probe, dtype=np.float64)
    chunk = 16384
    for start in range(0, n_probe, chunk):
        stop = start + chunk
        x_col = xs[start:stop].astype(np.float32).reshape(-1, 1)
        hidden = np.maximum(0, x_col @ W1.T + b1)
        ys[start:stop] = (hidden @ W2.T + b2).reshape(-1)

    slopes = np.gradient(ys, xs)
    slope_changes = np.abs(np.gradient(slopes, xs))

    # Strongest kink first; then blank out +/-10% of the grid around it so the
    # second search cannot land on the shoulder of the same kink.
    peak_window = int(n_probe * 0.1)
    idx1 = int(np.argmax(slope_changes))

    masked_changes = slope_changes.copy()
    l_mask = max(0, idx1 - peak_window)
    r_mask = min(n_probe, idx1 + peak_window)
    masked_changes[l_mask:r_mask] = 0.0

    idx2 = int(np.argmax(masked_changes))

    # Report the boundaries in ascending x order.
    if idx1 > idx2:
        idx1, idx2 = idx2, idx1

    boundary_x1 = float(xs[idx1])
    boundary_x2 = float(xs[idx2])

    # Keep a 3% safety margin around each kink when averaging slopes, because
    # the finite-difference slope is smeared right at a boundary.
    margin = int(n_probe * 0.03)

    idx_l = max(0, idx1 - margin)
    idx_m1 = min(n_probe - 1, idx1 + margin)
    idx_m2 = max(0, idx2 - margin)
    idx_r = min(n_probe - 1, idx2 + margin)

    left_slope = float(np.mean(slopes[:idx_l])) if idx_l > 0 else float(slopes[0])

    if idx_m2 > idx_m1:
        mid_slope = float(np.mean(slopes[idx_m1:idx_m2]))
    else:
        # Boundaries too close for a margin-trimmed window: sample the midpoint.
        mid_slope = float(slopes[(idx1 + idx2) // 2])

    right_slope = float(np.mean(slopes[idx_r:])) if idx_r < n_probe - 1 else float(slopes[-1])
    y_boundary2 = float(ys[idx2])

    return {
        "boundary_x1": boundary_x1,
        "boundary_x2": boundary_x2,
        "left_slope": left_slope,
        "mid_slope": mid_slope,
        "right_slope": right_slope,
        "y_boundary2": y_boundary2,
    }
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # 3. Reconstruct weights from functional parameters
130
+ # ---------------------------------------------------------------------------
131
+
132
def functional_to_weights(boundary_x1, boundary_x2, left_slope, mid_slope, right_slope, y_boundary2,
                          n_hidden=8):
    """Build W1/b1/W2/b2 for a two-boundary piecewise-linear neuron.

    Three hidden units are used; the remaining ``n_hidden - 3`` stay zero.
    The boundaries are ordered ascending before use, and the output bias is
    chosen so the carrier unit alone yields ``y_boundary2`` at the upper
    boundary (the two transition units are inactive there).

    Returns:
        ``(W1, b1, W2, b2)`` as float32 arrays of shapes ``(n_hidden, 1)``,
        ``(n_hidden,)``, ``(1, n_hidden)`` and ``(1,)``.
    """
    lower_x, upper_x = boundary_x1, boundary_x2
    if lower_x > upper_x:
        lower_x, upper_x = upper_x, lower_x

    W1 = np.zeros((n_hidden, 1), dtype=np.float32)
    b1 = np.zeros(n_hidden, dtype=np.float32)
    W2 = np.zeros((1, n_hidden), dtype=np.float32)
    b2 = np.zeros(1, dtype=np.float32)

    # One row per active hidden unit: (input weight, bias, output weight).
    active_units = (
        # Carrier: the +100 bias keeps it active far into negative x, so it
        # contributes right_slope everywhere in practice.
        (1.0, 100.0, right_slope),
        # Active left of the lower boundary: turns mid_slope into left_slope.
        (-1.0, lower_x, -(left_slope - mid_slope)),
        # Active left of the upper boundary: turns right_slope into mid_slope.
        (-1.0, upper_x, -(mid_slope - right_slope)),
    )
    for unit, (w_in, bias, w_out) in enumerate(active_units):
        W1[unit, 0] = w_in
        b1[unit] = bias
        W2[0, unit] = w_out

    # Anchor the curve: subtract the carrier's contribution at the upper boundary.
    carrier_at_upper = W2[0, 0] * (W1[0, 0] * upper_x + b1[0])
    b2[0] = y_boundary2 - carrier_at_upper

    return W1, b1, W2, b2
162
+
163
+
164
def functional_to_weights_single(boundary_x, left_slope, right_slope, y_at_boundary,
                                 n_hidden=8):
    """Build W1/b1/W2/b2 for a single-boundary piecewise-linear neuron.

    Only two hidden units are active (carrier + one transition); the other
    ``n_hidden - 2`` stay zero. The output bias is set so the carrier alone
    yields ``y_at_boundary`` at ``boundary_x``.

    Returns:
        ``(W1, b1, W2, b2)`` as float32 arrays of shapes ``(n_hidden, 1)``,
        ``(n_hidden,)``, ``(1, n_hidden)`` and ``(1,)``.
    """
    W1 = np.zeros((n_hidden, 1), dtype=np.float32)
    b1 = np.zeros(n_hidden, dtype=np.float32)
    W2 = np.zeros((1, n_hidden), dtype=np.float32)
    b2 = np.zeros(1, dtype=np.float32)

    # One row per active hidden unit: (input weight, bias, output weight).
    active_units = (
        # Carrier with a large positive bias; supplies right_slope.
        (1.0, 100.0, right_slope),
        # Active left of the boundary; adds (left_slope - right_slope) there.
        (-1.0, boundary_x, -(left_slope - right_slope)),
    )
    for unit, (w_in, bias, w_out) in enumerate(active_units):
        W1[unit, 0] = w_in
        b1[unit] = bias
        W2[0, unit] = w_out

    # The transition unit is zero at the boundary, so only the carrier needs
    # to be compensated to hit y_at_boundary.
    carrier_at_boundary = W2[0, 0] * (W1[0, 0] * boundary_x + b1[0])
    b2[0] = y_at_boundary - carrier_at_boundary

    return W1, b1, W2, b2
188
+
189
+
190
+ # ---------------------------------------------------------------------------
191
+ # 4. Validate a generated neuron (finite-difference probes at the known boundaries)
192
+ # ---------------------------------------------------------------------------
193
+
194
def _mlp_forward(x_scalar, W1, b1, W2, b2):
    """Evaluate the 1-in/1-out ReLU MLP at a single scalar input (float32 input)."""
    x = np.array([[x_scalar]], dtype=np.float32)
    h = np.maximum(0.0, x @ W1.T + b1)
    return float((h @ W2.T + b2).squeeze())


def validate_neuron(W1, b1, W2, b2, params, tol=0.05):
    """Check that a two-boundary neuron reproduces its functional parameters.

    Slopes are measured with one-sided finite differences just left of the
    lower boundary, at the midpoint between the boundaries, and just right of
    the upper boundary; the output is sampled at the upper boundary.

    The boundaries are ordered ascending before probing, matching what
    ``functional_to_weights`` does when it builds the weights. Without this,
    a neuron built from ``boundary_x1 > boundary_x2`` would be probed in the
    wrong regions and be rejected although it is correct.

    Args:
        W1, b1, W2, b2: Weights of the neuron under test.
        params: Dict with ``boundary_x1``, ``boundary_x2``, ``left_slope``,
            ``mid_slope``, ``right_slope`` and ``y_boundary2``.
        tol: Absolute tolerance for slopes; ``y_boundary2`` uses ``5 * tol``.

    Returns:
        ``(valid, checks, recovered)``: overall verdict, per-quantity pass
        flags, and the measured values (with the boundaries in ascending order).
    """
    bx1, bx2 = params["boundary_x1"], params["boundary_x2"]
    if bx1 > bx2:
        bx1, bx2 = bx2, bx1

    # Dynamically scale probes so we don't accidentally step over boundaries
    # when random generation places bx1 and bx2 extremely close together.
    dist = max(abs(bx2 - bx1), 1e-6)
    eps = min(1e-3, dist / 10.0)
    gap = min(0.05, dist / 4.0)

    def f(x):
        return _mlp_forward(x, W1, b1, W2, b2)

    y_at_bx2 = f(bx2)

    slope_left = (f(bx1 - gap) - f(bx1 - gap - eps)) / eps

    x_mid = (bx1 + bx2) / 2
    slope_mid = (f(x_mid + eps) - f(x_mid)) / eps

    slope_right = (f(bx2 + gap + eps) - f(bx2 + gap)) / eps

    recovered = {
        "boundary_x1": bx1,
        "boundary_x2": bx2,
        "left_slope": slope_left,
        "mid_slope": slope_mid,
        "right_slope": slope_right,
        "y_boundary2": y_at_bx2,
    }

    checks = {
        "left_slope": abs(slope_left - params["left_slope"]) < tol,
        "mid_slope": abs(slope_mid - params["mid_slope"]) < tol,
        "right_slope": abs(slope_right - params["right_slope"]) < tol,
        "y_boundary2": abs(y_at_bx2 - params["y_boundary2"]) < tol * 5,
    }
    return all(checks.values()), checks, recovered
238
+
239
+
240
def validate_neuron_single(W1, b1, W2, b2, params, tol=0.05):
    """Check a single-boundary neuron against its two slopes and anchor value.

    Slopes are one-sided finite differences taken a fixed distance away from
    the boundary on each side; the output is sampled at the boundary itself.

    Returns:
        ``(valid, checks, recovered)``: overall verdict, per-quantity pass
        flags, and the measured values.
    """
    step = 1e-3      # finite-difference step
    offset = 0.05    # distance kept from the boundary when probing slopes
    boundary = params["boundary_x"]

    def f(x):
        return _mlp_forward(x, W1, b1, W2, b2)

    measured_y = f(boundary)

    left_probe = boundary - offset
    measured_left = (f(left_probe) - f(left_probe - step)) / step

    right_probe = boundary + offset
    measured_right = (f(right_probe + step) - f(right_probe)) / step

    recovered = {
        "boundary_x": boundary,
        "left_slope": measured_left,
        "right_slope": measured_right,
        "y_at_boundary": measured_y,
    }

    # The anchor value gets a 5x looser tolerance than the slopes.
    checks = {
        "left_slope": abs(measured_left - params["left_slope"]) < tol,
        "right_slope": abs(measured_right - params["right_slope"]) < tol,
        "y_at_boundary": abs(measured_y - params["y_at_boundary"]) < tol * 5,
    }
    verdict = all(checks.values())
    return verdict, checks, recovered
267
+
268
+
269
+ # ---------------------------------------------------------------------------
270
+ # 5. Generation strategies
271
+ # ---------------------------------------------------------------------------
272
+
273
def strategy_gaussian(functional_params, n, rng):
    """Sample ``n`` parameter sets from a Gaussian fitted to the source neurons.

    The mean and covariance of the six functional parameters are estimated
    from ``functional_params``; with a single source neuron a fixed
    ``0.1 * I`` covariance is used instead. A small ridge (``1e-4 * I``) is
    always added to the covariance before sampling.

    Returns:
        List of ``n`` dicts keyed like the input parameter dicts.
    """
    names = (
        "boundary_x1", "boundary_x2", "left_slope",
        "mid_slope", "right_slope", "y_boundary2",
    )
    observed = np.array([[entry[name] for name in names] for entry in functional_params])

    center = observed.mean(axis=0)
    if len(observed) > 1:
        spread = np.cov(observed.T)
    else:
        spread = np.eye(6) * 0.1
    spread += np.eye(6) * 1e-4

    draws = rng.multivariate_normal(center, spread, size=n)
    return [dict(zip(names, row)) for row in draws]
289
+
290
+
291
def strategy_interpolate(functional_params, n, rng):
    """Create ``n`` parameter sets as convex blends of two source neurons.

    For every output, two source indices are drawn with replacement and a
    blend weight is drawn uniformly from [0, 1); every parameter is then
    linearly interpolated between the two chosen neurons.
    """
    pool = functional_params

    def blend_once():
        # Draw order matters for reproducibility: indices first, then weight.
        first_idx, second_idx = rng.choice(len(pool), size=2, replace=True)
        weight = rng.uniform(0, 1)
        first, second = pool[first_idx], pool[second_idx]
        return {
            name: (1 - weight) * first[name] + weight * second[name]
            for name in first
        }

    return [blend_once() for _ in range(n)]
302
+
303
+
304
def strategy_grid(functional_params, n, rng):
    """Return ``n`` parameter sets drawn from a shuffled regular grid.

    Each of the six parameters is swept over its observed [min, max] range
    (widened artificially when all source values coincide) using the same
    number of grid points per axis. The full Cartesian grid is shuffled with
    ``rng`` and the first ``n`` entries are returned.
    """
    names = (
        "boundary_x1", "boundary_x2", "left_slope",
        "mid_slope", "right_slope", "y_boundary2",
    )

    def span(values, margin=0.2):
        lowest, highest = min(values), max(values)
        if lowest != highest:
            return lowest, highest
        # Degenerate range (e.g. a single source neuron): open it up so the
        # axis does not collapse to one repeated value.
        pad = margin if lowest == 0 else abs(lowest) * margin
        return lowest - pad, highest + pad

    bounds = [span([entry[name] for entry in functional_params]) for name in names]

    # Smallest per-axis resolution (at least 2) whose 6-D grid holds n points.
    side = max(2, int(n ** (1.0/6.0)) + 1)

    axes = [np.linspace(low, high, side) for low, high in bounds]
    # indexing="ij" + C-order flattening enumerates points with the first
    # parameter varying slowest and the last one fastest.
    mesh = np.meshgrid(*axes, indexing="ij")
    table = np.stack(mesh, axis=-1).reshape(-1, len(names))
    grid = [dict(zip(names, row)) for row in table]

    rng.shuffle(grid)
    # Safety net: repeat the grid if it ever ends up shorter than requested.
    while len(grid) < n:
        grid += grid
    return grid[:n]
339
+
340
+
341
+ # ---------------------------------------------------------------------------
342
+ # 6. Main
343
+ # ---------------------------------------------------------------------------
344
+
345
def main():
    """Run the pipeline: load, extract, generate, validate, save.

    Driven entirely by the module-level configuration constants. For every
    selected strategy, each generated neuron that passes validation is
    written to ``OUTPUT_DIR/<strategy>/neuron_XXXXXX.safetensors``; a
    ``generation_meta.json`` summary is written to ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``STRATEGY`` names an unknown strategy (checked before
            any file is touched).
        SystemExit: If no source neurons could be loaded.
    """
    generators = {
        "gaussian": strategy_gaussian,
        "interpolate": strategy_interpolate,
        "grid": strategy_grid,
    }
    strategies = list(generators) if STRATEGY == "all" else [STRATEGY]
    for strat in strategies:
        if strat not in generators:
            raise ValueError(f"Unknown strategy: {strat}")

    rng = np.random.default_rng(RANDOM_SEED)
    out = Path(OUTPUT_DIR)
    # parents=True so a nested OUTPUT_DIR works on a fresh checkout.
    out.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("Generating new neurons from existing ones (Multi-Boundary)")
    print("=" * 60)

    print("\n[1] Loading existing neurons...")
    neurons = load_neurons(NEURON_SOURCE, SINGLE_FILE, MULTI_DIR)
    print(f" {len(neurons)} source neuron(s)")
    if not neurons:
        # Every strategy needs at least one source neuron; fail with a clear
        # message instead of an obscure NumPy error further down.
        raise SystemExit(
            "No source neurons found - check NEURON_SOURCE, SINGLE_FILE and MULTI_DIR."
        )

    print("\n[2] Extracting functional parameters...")
    functional_params = []
    for k, neuron in enumerate(neurons):
        p = weights_to_functional(neuron["W1"], neuron["b1"], neuron["W2"], neuron["b2"])
        functional_params.append(p)
        print(f" Neuron {k}: boundary1={p['boundary_x1']:+.4f} "
              f"boundary2={p['boundary_x2']:+.4f} "
              f"left_slope={p['left_slope']:+.4f} "
              f"mid_slope={p['mid_slope']:+.4f} "
              f"right_slope={p['right_slope']:+.4f} "
              f"y@boundary2={p['y_boundary2']:+.4f}")

    total_saved = 0
    summary = {}

    for strat in strategies:
        print(f"\n[3] Generating {N_GENERATE} neurons via '{strat}'...")
        new_params = generators[strat](functional_params, N_GENERATE, rng)

        strat_dir = out / strat
        strat_dir.mkdir(exist_ok=True)

        n_valid = 0
        for idx, p in enumerate(new_params):
            if SINGLE_BOUNDARY_MODE:
                # Collapse the double-boundary parameters to a single boundary:
                # keep boundary_x1 with left_slope/right_slope, drop
                # boundary_x2 and mid_slope.
                # NOTE(review): y_boundary2 was measured at boundary_x2 but is
                # used here as the value at boundary_x1 - an approximation the
                # original code labels as an estimate; confirm it is intended.
                p_single = {
                    "boundary_x": p["boundary_x1"],
                    "left_slope": p["left_slope"],
                    "right_slope": p["right_slope"],
                    "y_at_boundary": p["y_boundary2"],
                }
                W1, b1, W2, b2 = functional_to_weights_single(
                    p_single["boundary_x"], p_single["left_slope"],
                    p_single["right_slope"], p_single["y_at_boundary"],
                )
                valid, checks, recovered = validate_neuron_single(W1, b1, W2, b2, p_single)
            else:
                W1, b1, W2, b2 = functional_to_weights(
                    p["boundary_x1"], p["boundary_x2"], p["left_slope"],
                    p["mid_slope"], p["right_slope"], p["y_boundary2"],
                )
                valid, checks, recovered = validate_neuron(W1, b1, W2, b2, p)

            if valid:
                save_file(
                    {
                        "layer1.weight": torch.tensor(W1),
                        "layer1.bias": torch.tensor(b1),
                        "layer2.weight": torch.tensor(W2),
                        "layer2.bias": torch.tensor(b2),
                    },
                    # Zero-padded to 6 digits so lexicographic sorting of the
                    # file names matches numeric order downstream.
                    str(strat_dir / f"neuron_{idx:06d}.safetensors"),
                )
                n_valid += 1
            else:
                failed = [name for name, passed in checks.items() if not passed]
                # Only report the first few failures (plus a sparse sample).
                if idx < 10 or idx % 50000 == 0:
                    print(f" [skip] neuron_{idx:06d}: failed checks {failed}")

        # Guard the percentage so N_GENERATE == 0 does not divide by zero.
        pct = 100 * n_valid / N_GENERATE if N_GENERATE else 0.0
        print(f" Saved {n_valid}/{N_GENERATE} valid neurons ({pct:.0f}%) to {strat_dir}/")
        summary[strat] = {"generated": N_GENERATE, "valid": n_valid, "path": str(strat_dir)}
        total_saved += n_valid

    meta = {
        "source_neurons": len(neurons),
        "source_functional_params": functional_params,
        "strategies": summary,
        "total_saved": total_saved,
    }
    with open(out / "generation_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    print(f"\n{'=' * 60}")
    print(f"Total neurons generated: {total_saved}")
    print(f"Metadata saved to {out}/generation_meta.json")
    print("\nTo use generated neurons in append_neurons_to_t5.py:")
    print(" NEURON_SOURCE = 'multi'")
    print(f" MULTI_DIR = '{out}/gaussian' # or interpolate / grid")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    main()