phanerozoic commited on
Commit
9ce428c
·
verified ·
1 Parent(s): a24bb80

Upload test_gate_reconstruction.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_gate_reconstruction.py +469 -0
test_gate_reconstruction.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TEST #3: Gate-Level Reconstruction
3
+ ===================================
4
+ Given ONLY the weights, reconstruct the Boolean function each neuron computes.
5
+ Verify it matches the claimed gate type via truth table exhaustion.
6
+
7
+ A skeptic would demand: "Prove your weights actually implement the gates you
8
+ claim. Derive the function from weights alone, then verify."
9
+ """
10
+
11
+ import torch
12
+ from safetensors.torch import load_file
13
+ from itertools import product
14
+
15
+ # Load circuits
16
+ model = load_file('neural_computer.safetensors')
17
+
18
+ def heaviside(x):
19
+ return (x >= 0).float()
20
+
21
+ # Known Boolean functions (truth tables as tuples)
22
+ KNOWN_FUNCTIONS = {
23
+ # 1-input functions
24
+ 'IDENTITY': ((0,), (1,)), # f(0)=0, f(1)=1
25
+ 'NOT': ((1,), (0,)), # f(0)=1, f(1)=0
26
+ 'CONST_0': ((0,), (0,)),
27
+ 'CONST_1': ((1,), (1,)),
28
+
29
+ # 2-input functions (indexed by (0,0), (0,1), (1,0), (1,1))
30
+ 'AND': (0, 0, 0, 1),
31
+ 'OR': (0, 1, 1, 1),
32
+ 'NAND': (1, 1, 1, 0),
33
+ 'NOR': (1, 0, 0, 0),
34
+ 'XOR': (0, 1, 1, 0),
35
+ 'XNOR': (1, 0, 0, 1),
36
+ 'IMPLIES': (1, 1, 0, 1), # a -> b = ~a | b
37
+ 'NIMPLIES': (0, 0, 1, 0), # a & ~b
38
+ 'PROJ_A': (0, 0, 1, 1), # output = a
39
+ 'PROJ_B': (0, 1, 0, 1), # output = b
40
+ 'NOT_A': (1, 1, 0, 0),
41
+ 'NOT_B': (1, 0, 1, 0),
42
+ }
43
+
44
+ def identify_1input_function(w, b):
45
+ """
46
+ Given weights w and bias b for a 1-input gate,
47
+ reconstruct and identify the Boolean function.
48
+ """
49
+ truth_table = []
50
+ for x in [0, 1]:
51
+ inp = torch.tensor([float(x)])
52
+ out = int(heaviside(inp @ w + b).item())
53
+ truth_table.append(out)
54
+
55
+ truth_table = tuple(truth_table)
56
+
57
+ # Match against known functions
58
+ for name, tt in KNOWN_FUNCTIONS.items():
59
+ if len(tt) == 2 and isinstance(tt[0], tuple):
60
+ # 1-input format
61
+ if (tt[0][0], tt[1][0]) == truth_table:
62
+ return name, truth_table
63
+
64
+ return 'UNKNOWN', truth_table
65
+
66
+ def identify_2input_function(w, b):
67
+ """
68
+ Given weights w and bias b for a 2-input gate,
69
+ reconstruct and identify the Boolean function.
70
+ """
71
+ truth_table = []
72
+ for a, b_in in [(0, 0), (0, 1), (1, 0), (1, 1)]:
73
+ inp = torch.tensor([float(a), float(b_in)])
74
+ out = int(heaviside(inp @ w + b).item())
75
+ truth_table.append(out)
76
+
77
+ truth_table = tuple(truth_table)
78
+
79
+ # Match against known functions
80
+ for name, tt in KNOWN_FUNCTIONS.items():
81
+ if isinstance(tt, tuple) and len(tt) == 4 and isinstance(tt[0], int):
82
+ if tt == truth_table:
83
+ return name, truth_table
84
+
85
+ return 'UNKNOWN', truth_table
86
+
87
+ def identify_2layer_function(w1_n1, b1_n1, w1_n2, b1_n2, w2, b2):
88
+ """
89
+ Given weights for a 2-layer network (2 hidden neurons),
90
+ reconstruct and identify the Boolean function.
91
+ """
92
+ truth_table = []
93
+ for a, b_in in [(0, 0), (0, 1), (1, 0), (1, 1)]:
94
+ inp = torch.tensor([float(a), float(b_in)])
95
+
96
+ # Layer 1
97
+ h1 = heaviside(inp @ w1_n1 + b1_n1).item()
98
+ h2 = heaviside(inp @ w1_n2 + b1_n2).item()
99
+ hidden = torch.tensor([h1, h2])
100
+
101
+ # Layer 2
102
+ out = int(heaviside(hidden @ w2 + b2).item())
103
+ truth_table.append(out)
104
+
105
+ truth_table = tuple(truth_table)
106
+
107
+ for name, tt in KNOWN_FUNCTIONS.items():
108
+ if isinstance(tt, tuple) and len(tt) == 4 and isinstance(tt[0], int):
109
+ if tt == truth_table:
110
+ return name, truth_table
111
+
112
+ return 'UNKNOWN', truth_table
113
+
114
+ def analyze_threshold_gate(w, b, n_inputs):
115
+ """
116
+ Analyze a threshold gate: compute threshold and effective function.
117
+ For a gate with weights w and bias b:
118
+ - Fires when sum(w_i * x_i) + b >= 0
119
+ - Threshold = -b (fires when weighted sum >= -b)
120
+ """
121
+ w_list = w.tolist() if hasattr(w, 'tolist') else [w]
122
+ b_val = b.item() if hasattr(b, 'item') else b
123
+
124
+ threshold = -b_val
125
+
126
+ # Compute min and max weighted sums
127
+ min_sum = sum(min(0, wi) for wi in w_list)
128
+ max_sum = sum(max(0, wi) for wi in w_list)
129
+
130
+ return {
131
+ 'weights': w_list,
132
+ 'bias': b_val,
133
+ 'threshold': threshold,
134
+ 'min_sum': min_sum,
135
+ 'max_sum': max_sum,
136
+ }
137
+
138
+ # =============================================================================
139
+ # RECONSTRUCTION TESTS
140
+ # =============================================================================
141
+
142
+ def test_single_layer_gates():
143
+ """
144
+ Reconstruct all single-layer Boolean gates from weights.
145
+ """
146
+ print("\n[TEST 1] Single-Layer Gate Reconstruction")
147
+ print("-" * 60)
148
+
149
+ gates_to_test = [
150
+ ('boolean.and', 'AND', 2),
151
+ ('boolean.or', 'OR', 2),
152
+ ('boolean.nand', 'NAND', 2),
153
+ ('boolean.nor', 'NOR', 2),
154
+ ('boolean.not', 'NOT', 1),
155
+ ('boolean.implies', 'IMPLIES', 2),
156
+ ]
157
+
158
+ errors = []
159
+ results = []
160
+
161
+ for prefix, expected_name, n_inputs in gates_to_test:
162
+ w = model[f'{prefix}.weight']
163
+ b = model[f'{prefix}.bias']
164
+
165
+ if n_inputs == 1:
166
+ identified, tt = identify_1input_function(w, b)
167
+ else:
168
+ identified, tt = identify_2input_function(w, b)
169
+
170
+ analysis = analyze_threshold_gate(w, b, n_inputs)
171
+
172
+ match = identified == expected_name
173
+ status = "OK" if match else "MISMATCH"
174
+
175
+ results.append({
176
+ 'gate': prefix,
177
+ 'expected': expected_name,
178
+ 'identified': identified,
179
+ 'truth_table': tt,
180
+ 'weights': analysis['weights'],
181
+ 'bias': analysis['bias'],
182
+ 'threshold': analysis['threshold'],
183
+ 'match': match
184
+ })
185
+
186
+ if not match:
187
+ errors.append((prefix, expected_name, identified))
188
+
189
+ print(f" {prefix:25s} -> {identified:10s} [w={analysis['weights']}, b={analysis['bias']:.0f}] [{status}]")
190
+
191
+ print()
192
+ if errors:
193
+ print(f" FAILED: {len(errors)} mismatches")
194
+ for e in errors:
195
+ print(f" {e[0]}: expected {e[1]}, got {e[2]}")
196
+ else:
197
+ print(f" PASSED: {len(gates_to_test)} single-layer gates reconstructed correctly")
198
+
199
+ return len(errors) == 0, results
200
+
201
+ def test_two_layer_gates():
202
+ """
203
+ Reconstruct all two-layer Boolean gates from weights.
204
+ """
205
+ print("\n[TEST 2] Two-Layer Gate Reconstruction")
206
+ print("-" * 60)
207
+
208
+ gates_to_test = [
209
+ ('boolean.xor', 'XOR'),
210
+ ('boolean.xnor', 'XNOR'),
211
+ ('boolean.biimplies', 'XNOR'), # biimplies = XNOR
212
+ ]
213
+
214
+ errors = []
215
+ results = []
216
+
217
+ for prefix, expected_name in gates_to_test:
218
+ w1_n1 = model[f'{prefix}.layer1.neuron1.weight']
219
+ b1_n1 = model[f'{prefix}.layer1.neuron1.bias']
220
+ w1_n2 = model[f'{prefix}.layer1.neuron2.weight']
221
+ b1_n2 = model[f'{prefix}.layer1.neuron2.bias']
222
+ w2 = model[f'{prefix}.layer2.weight']
223
+ b2 = model[f'{prefix}.layer2.bias']
224
+
225
+ identified, tt = identify_2layer_function(w1_n1, b1_n1, w1_n2, b1_n2, w2, b2)
226
+
227
+ # Also identify the hidden neurons
228
+ hidden1_id, _ = identify_2input_function(w1_n1, b1_n1)
229
+ hidden2_id, _ = identify_2input_function(w1_n2, b1_n2)
230
+
231
+ match = identified == expected_name
232
+ status = "OK" if match else "MISMATCH"
233
+
234
+ results.append({
235
+ 'gate': prefix,
236
+ 'expected': expected_name,
237
+ 'identified': identified,
238
+ 'truth_table': tt,
239
+ 'hidden1': hidden1_id,
240
+ 'hidden2': hidden2_id,
241
+ 'match': match
242
+ })
243
+
244
+ if not match:
245
+ errors.append((prefix, expected_name, identified))
246
+
247
+ print(f" {prefix:25s} -> {identified:10s} [hidden: {hidden1_id} + {hidden2_id}] [{status}]")
248
+
249
+ print()
250
+ if errors:
251
+ print(f" FAILED: {len(errors)} mismatches")
252
+ else:
253
+ print(f" PASSED: {len(gates_to_test)} two-layer gates reconstructed correctly")
254
+
255
+ return len(errors) == 0, results
256
+
257
+ def test_adder_components():
258
+ """
259
+ Reconstruct and verify adder component gates.
260
+ """
261
+ print("\n[TEST 3] Adder Component Reconstruction")
262
+ print("-" * 60)
263
+
264
+ errors = []
265
+
266
+ # Half adder carry = AND
267
+ w = model['arithmetic.halfadder.carry.weight']
268
+ b = model['arithmetic.halfadder.carry.bias']
269
+ identified, tt = identify_2input_function(w, b)
270
+ status = "OK" if identified == 'AND' else "MISMATCH"
271
+ print(f" halfadder.carry -> {identified:10s} [{status}]")
272
+ if identified != 'AND':
273
+ errors.append(('halfadder.carry', 'AND', identified))
274
+
275
+ # Full adder carry_or = OR
276
+ w = model['arithmetic.fulladder.carry_or.weight']
277
+ b = model['arithmetic.fulladder.carry_or.bias']
278
+ identified, tt = identify_2input_function(w, b)
279
+ status = "OK" if identified == 'OR' else "MISMATCH"
280
+ print(f" fulladder.carry_or -> {identified:10s} [{status}]")
281
+ if identified != 'OR':
282
+ errors.append(('fulladder.carry_or', 'OR', identified))
283
+
284
+ # Ripple carry FA0 carry_or = OR
285
+ w = model['arithmetic.ripplecarry8bit.fa0.carry_or.weight']
286
+ b = model['arithmetic.ripplecarry8bit.fa0.carry_or.bias']
287
+ identified, tt = identify_2input_function(w, b)
288
+ status = "OK" if identified == 'OR' else "MISMATCH"
289
+ print(f" rc8.fa0.carry_or -> {identified:10s} [{status}]")
290
+ if identified != 'OR':
291
+ errors.append(('rc8.fa0.carry_or', 'OR', identified))
292
+
293
+ # Verify all 8 FA carry gates in ripple carry
294
+ print("\n Verifying all 8 FA carry_or gates in 8-bit ripple carry...")
295
+ for i in range(8):
296
+ w = model[f'arithmetic.ripplecarry8bit.fa{i}.carry_or.weight']
297
+ b = model[f'arithmetic.ripplecarry8bit.fa{i}.carry_or.bias']
298
+ identified, _ = identify_2input_function(w, b)
299
+ if identified != 'OR':
300
+ errors.append((f'rc8.fa{i}.carry_or', 'OR', identified))
301
+ print(f" fa{i}.carry_or: MISMATCH (got {identified})")
302
+
303
+ if not any('rc8.fa' in e[0] and 'carry_or' in e[0] for e in errors):
304
+ print(f" All 8 carry_or gates verified as OR")
305
+
306
+ print()
307
+ if errors:
308
+ print(f" FAILED: {len(errors)} mismatches")
309
+ else:
310
+ print(f" PASSED: All adder components match expected gate types")
311
+
312
+ return len(errors) == 0
313
+
314
+ def test_threshold_analysis():
315
+ """
316
+ Analyze threshold characteristics of various gates.
317
+ """
318
+ print("\n[TEST 4] Threshold Analysis")
319
+ print("-" * 60)
320
+
321
+ print(" Gate Weights Bias Threshold Function")
322
+ print(" " + "-" * 56)
323
+
324
+ gates = [
325
+ ('boolean.and', 'AND'),
326
+ ('boolean.or', 'OR'),
327
+ ('boolean.nand', 'NAND'),
328
+ ('boolean.nor', 'NOR'),
329
+ ]
330
+
331
+ for prefix, name in gates:
332
+ w = model[f'{prefix}.weight']
333
+ b = model[f'{prefix}.bias']
334
+ analysis = analyze_threshold_gate(w, b, 2)
335
+
336
+ # Verify the threshold makes sense
337
+ # AND: w=[1,1], b=-2, threshold=2 (both inputs needed)
338
+ # OR: w=[1,1], b=-1, threshold=1 (one input needed)
339
+ # NAND: w=[-1,-1], b=1, threshold=-1 (inverted AND)
340
+ # NOR: w=[-1,-1], b=0, threshold=0 (inverted OR)
341
+
342
+ w_str = str(analysis['weights'])
343
+ print(f" {prefix:18s} {w_str:15s} {analysis['bias']:6.0f} {analysis['threshold']:10.0f} {name}")
344
+
345
+ print()
346
+ print(" Interpretation:")
347
+ print(" AND: fires when sum >= 2 (both inputs must be 1)")
348
+ print(" OR: fires when sum >= 1 (at least one input is 1)")
349
+ print(" NAND: fires when sum >= -1 (always, unless both inputs are 1)")
350
+ print(" NOR: fires when sum >= 0 (only when both inputs are 0)")
351
+
352
+ return True
353
+
354
+ def test_weight_uniqueness():
355
+ """
356
+ Verify that different gate types have different weight configurations.
357
+ """
358
+ print("\n[TEST 5] Weight Configuration Uniqueness")
359
+ print("-" * 60)
360
+
361
+ configs = {}
362
+
363
+ gates = ['and', 'or', 'nand', 'nor']
364
+ for gate in gates:
365
+ w = model[f'boolean.{gate}.weight']
366
+ b = model[f'boolean.{gate}.bias']
367
+ config = (tuple(w.tolist()), b.item())
368
+ configs[gate] = config
369
+
370
+ # Check all configs are unique
371
+ unique_configs = set(configs.values())
372
+
373
+ print(f" Configurations found:")
374
+ for gate, config in configs.items():
375
+ print(f" {gate:6s}: w={config[0]}, b={config[1]}")
376
+
377
+ print()
378
+ if len(unique_configs) == len(configs):
379
+ print(f" PASSED: All {len(gates)} gate types have unique weight configurations")
380
+ return True
381
+ else:
382
+ print(f" FAILED: Some gates share weight configurations")
383
+ return False
384
+
385
+ def test_reconstruction_from_scratch():
386
+ """
387
+ Ultimate test: Given arbitrary weights, derive the Boolean function
388
+ without knowing what gate it's supposed to be.
389
+ """
390
+ print("\n[TEST 6] Blind Reconstruction (No Prior Knowledge)")
391
+ print("-" * 60)
392
+
393
+ # Pick some gates without looking at names
394
+ test_tensors = [
395
+ 'boolean.and.weight',
396
+ 'boolean.or.weight',
397
+ 'boolean.nand.weight',
398
+ 'boolean.xor.layer1.neuron1.weight',
399
+ 'arithmetic.halfadder.carry.weight',
400
+ ]
401
+
402
+ print(" Given only weights and bias, reconstructing functions...\n")
403
+
404
+ for w_name in test_tensors:
405
+ b_name = w_name.replace('.weight', '.bias')
406
+ w = model[w_name]
407
+ b = model[b_name]
408
+
409
+ identified, tt = identify_2input_function(w, b)
410
+
411
+ print(f" {w_name}")
412
+ print(f" Weights: {w.tolist()}")
413
+ print(f" Bias: {b.item()}")
414
+ print(f" Truth table: {tt}")
415
+ print(f" Identified: {identified}")
416
+ print()
417
+
418
+ print(" (All identifications derived purely from weight enumeration)")
419
+ return True
420
+
421
+ # =============================================================================
422
+ # MAIN
423
+ # =============================================================================
424
+
425
+ if __name__ == "__main__":
426
+ print("=" * 70)
427
+ print(" TEST #3: GATE-LEVEL RECONSTRUCTION")
428
+ print(" Deriving Boolean functions from weights alone")
429
+ print("=" * 70)
430
+
431
+ results = []
432
+
433
+ r1, _ = test_single_layer_gates()
434
+ results.append(("Single-layer gates", r1))
435
+
436
+ r2, _ = test_two_layer_gates()
437
+ results.append(("Two-layer gates", r2))
438
+
439
+ r3 = test_adder_components()
440
+ results.append(("Adder components", r3))
441
+
442
+ r4 = test_threshold_analysis()
443
+ results.append(("Threshold analysis", r4))
444
+
445
+ r5 = test_weight_uniqueness()
446
+ results.append(("Weight uniqueness", r5))
447
+
448
+ r6 = test_reconstruction_from_scratch()
449
+ results.append(("Blind reconstruction", r6))
450
+
451
+ print("\n" + "=" * 70)
452
+ print(" SUMMARY")
453
+ print("=" * 70)
454
+
455
+ passed = sum(1 for _, r in results if r)
456
+ total = len(results)
457
+
458
+ for name, r in results:
459
+ status = "PASS" if r else "FAIL"
460
+ print(f" {name:25s} [{status}]")
461
+
462
+ print(f"\n Total: {passed}/{total} test categories passed")
463
+
464
+ if passed == total:
465
+ print("\n STATUS: ALL GATES RECONSTRUCTED AND VERIFIED")
466
+ else:
467
+ print("\n STATUS: RECONSTRUCTION FAILURES DETECTED")
468
+
469
+ print("=" * 70)