phanerozoic commited on
Commit
a24bb80
·
verified ·
1 Parent(s): 41334a3

Upload test_equivalence.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_equivalence.py +477 -0
test_equivalence.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TEST #2: Formal Equivalence Checking
3
+ =====================================
4
+ Run 8-bit adder against Python's arithmetic for ALL 2^16 input pairs.
5
+ Bit-for-bit comparison of every result and carry flag.
6
+
7
+ A skeptic would demand: "Prove exhaustive correctness, not just sampling."
8
+ """
9
+
10
+ import torch
11
+ from safetensors.torch import load_file
12
+ import time
13
+
14
+ # Load circuits
15
+ model = load_file('neural_computer.safetensors')
16
+
17
+ def heaviside(x):
18
+ return (x >= 0).float()
19
+
20
+ def eval_xor_arith(inp, prefix):
21
+ """Evaluate XOR for arithmetic circuits."""
22
+ w1_or = model[f'{prefix}.layer1.or.weight']
23
+ b1_or = model[f'{prefix}.layer1.or.bias']
24
+ w1_nand = model[f'{prefix}.layer1.nand.weight']
25
+ b1_nand = model[f'{prefix}.layer1.nand.bias']
26
+ w2 = model[f'{prefix}.layer2.weight']
27
+ b2 = model[f'{prefix}.layer2.bias']
28
+ h_or = heaviside(inp @ w1_or + b1_or)
29
+ h_nand = heaviside(inp @ w1_nand + b1_nand)
30
+ hidden = torch.tensor([h_or.item(), h_nand.item()])
31
+ return heaviside(hidden @ w2 + b2).item()
32
+
33
+ def eval_full_adder(a, b, cin, prefix):
34
+ """Evaluate full adder, return (sum, carry_out)."""
35
+ inp_ab = torch.tensor([a, b], dtype=torch.float32)
36
+ ha1_sum = eval_xor_arith(inp_ab, f'{prefix}.ha1.sum')
37
+ w_c1 = model[f'{prefix}.ha1.carry.weight']
38
+ b_c1 = model[f'{prefix}.ha1.carry.bias']
39
+ ha1_carry = heaviside(inp_ab @ w_c1 + b_c1).item()
40
+ inp_ha2 = torch.tensor([ha1_sum, cin], dtype=torch.float32)
41
+ ha2_sum = eval_xor_arith(inp_ha2, f'{prefix}.ha2.sum')
42
+ w_c2 = model[f'{prefix}.ha2.carry.weight']
43
+ b_c2 = model[f'{prefix}.ha2.carry.bias']
44
+ ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2).item()
45
+ inp_cout = torch.tensor([ha1_carry, ha2_carry], dtype=torch.float32)
46
+ w_or = model[f'{prefix}.carry_or.weight']
47
+ b_or = model[f'{prefix}.carry_or.bias']
48
+ cout = heaviside(inp_cout @ w_or + b_or).item()
49
+ return int(ha2_sum), int(cout)
50
+
51
+ def add_8bit(a, b):
52
+ """8-bit addition using ripple carry adder."""
53
+ carry = 0.0
54
+ result_bits = []
55
+ for i in range(8):
56
+ a_bit = (a >> i) & 1
57
+ b_bit = (b >> i) & 1
58
+ s, carry = eval_full_adder(float(a_bit), float(b_bit), carry,
59
+ f'arithmetic.ripplecarry8bit.fa{i}')
60
+ result_bits.append(s)
61
+ result = sum(result_bits[i] * (2**i) for i in range(8))
62
+ return result, int(carry)
63
+
64
+ def compare_8bit(a, b):
65
+ """8-bit comparators."""
66
+ a_bits = torch.tensor([(a >> (7-i)) & 1 for i in range(8)], dtype=torch.float32)
67
+ b_bits = torch.tensor([(b >> (7-i)) & 1 for i in range(8)], dtype=torch.float32)
68
+
69
+ # Greater than
70
+ w_gt = model['arithmetic.greaterthan8bit.comparator']
71
+ gt = 1 if ((a_bits - b_bits) @ w_gt).item() > 0 else 0
72
+
73
+ # Less than
74
+ w_lt = model['arithmetic.lessthan8bit.comparator']
75
+ lt = 1 if ((b_bits - a_bits) @ w_lt).item() > 0 else 0
76
+
77
+ # Equal (neither gt nor lt)
78
+ eq = 1 if (gt == 0 and lt == 0) else 0
79
+
80
+ return gt, lt, eq
81
+
82
+ # =============================================================================
83
+ # EXHAUSTIVE TESTS
84
+ # =============================================================================
85
+
86
+ def test_addition_exhaustive():
87
+ """
88
+ Test ALL 65,536 addition combinations.
89
+ """
90
+ print("\n[TEST 1] Exhaustive 8-bit Addition: 256 x 256 = 65,536 cases")
91
+ print("-" * 60)
92
+
93
+ errors = []
94
+ start = time.perf_counter()
95
+
96
+ for a in range(256):
97
+ for b in range(256):
98
+ # Circuit result
99
+ result, carry = add_8bit(a, b)
100
+
101
+ # Python reference
102
+ full_sum = a + b
103
+ expected_result = full_sum % 256
104
+ expected_carry = 1 if full_sum > 255 else 0
105
+
106
+ # Compare
107
+ if result != expected_result:
108
+ errors.append(('result', a, b, expected_result, result))
109
+ if carry != expected_carry:
110
+ errors.append(('carry', a, b, expected_carry, carry))
111
+
112
+ # Progress every 32 rows
113
+ if (a + 1) % 32 == 0:
114
+ elapsed = time.perf_counter() - start
115
+ rate = ((a + 1) * 256) / elapsed
116
+ eta = (256 - a - 1) * 256 / rate
117
+ print(f" Progress: {a+1}/256 rows ({(a+1)*256:,} tests) "
118
+ f"| {rate:.0f} tests/sec | ETA: {eta:.1f}s")
119
+
120
+ elapsed = time.perf_counter() - start
121
+
122
+ print()
123
+ if errors:
124
+ print(f" FAILED: {len(errors)} mismatches")
125
+ for e in errors[:10]:
126
+ print(f" {e[0]}: {e[1]} + {e[2]} = {e[4]}, expected {e[3]}")
127
+ else:
128
+ print(f" PASSED: 65,536 additions verified")
129
+ print(f" Time: {elapsed:.2f}s ({65536/elapsed:.0f} tests/sec)")
130
+
131
+ return len(errors) == 0
132
+
133
+ def test_comparators_exhaustive():
134
+ """
135
+ Test ALL 65,536 comparator combinations for GT, LT, EQ.
136
+ """
137
+ print("\n[TEST 2] Exhaustive 8-bit Comparators: 256 x 256 x 3 = 196,608 checks")
138
+ print("-" * 60)
139
+
140
+ errors = []
141
+ start = time.perf_counter()
142
+
143
+ for a in range(256):
144
+ for b in range(256):
145
+ gt, lt, eq = compare_8bit(a, b)
146
+
147
+ # Python reference
148
+ exp_gt = 1 if a > b else 0
149
+ exp_lt = 1 if a < b else 0
150
+ exp_eq = 1 if a == b else 0
151
+
152
+ if gt != exp_gt:
153
+ errors.append(('GT', a, b, exp_gt, gt))
154
+ if lt != exp_lt:
155
+ errors.append(('LT', a, b, exp_lt, lt))
156
+ if eq != exp_eq:
157
+ errors.append(('EQ', a, b, exp_eq, eq))
158
+
159
+ if (a + 1) % 32 == 0:
160
+ elapsed = time.perf_counter() - start
161
+ rate = ((a + 1) * 256) / elapsed
162
+ eta = (256 - a - 1) * 256 / rate
163
+ print(f" Progress: {a+1}/256 rows | {rate:.0f} pairs/sec | ETA: {eta:.1f}s")
164
+
165
+ elapsed = time.perf_counter() - start
166
+
167
+ print()
168
+ if errors:
169
+ print(f" FAILED: {len(errors)} mismatches")
170
+ for e in errors[:10]:
171
+ print(f" {e[0]}({e[1]}, {e[2]}) = {e[4]}, expected {e[3]}")
172
+ else:
173
+ print(f" PASSED: 196,608 comparisons verified (GT, LT, EQ for each pair)")
174
+ print(f" Time: {elapsed:.2f}s")
175
+
176
+ return len(errors) == 0
177
+
178
+ def test_boolean_exhaustive():
179
+ """
180
+ Exhaustive test of all 2-input Boolean gates (4 cases each).
181
+ """
182
+ print("\n[TEST 3] Exhaustive Boolean Gates: AND, OR, NAND, NOR, XOR, XNOR")
183
+ print("-" * 60)
184
+
185
+ gates = {
186
+ 'and': lambda a, b: a & b,
187
+ 'or': lambda a, b: a | b,
188
+ 'nand': lambda a, b: 1 - (a & b),
189
+ 'nor': lambda a, b: 1 - (a | b),
190
+ }
191
+
192
+ errors = []
193
+
194
+ # Simple gates (single layer)
195
+ for gate_name, expected_fn in gates.items():
196
+ w = model[f'boolean.{gate_name}.weight']
197
+ bias = model[f'boolean.{gate_name}.bias']
198
+
199
+ for a in [0, 1]:
200
+ for b in [0, 1]:
201
+ inp = torch.tensor([float(a), float(b)])
202
+ result = int(heaviside(inp @ w + bias).item())
203
+ expected = expected_fn(a, b)
204
+
205
+ if result != expected:
206
+ errors.append((gate_name.upper(), a, b, expected, result))
207
+
208
+ # XOR (two-layer)
209
+ for a in [0, 1]:
210
+ for b in [0, 1]:
211
+ inp = torch.tensor([float(a), float(b)])
212
+
213
+ w1_n1 = model['boolean.xor.layer1.neuron1.weight']
214
+ b1_n1 = model['boolean.xor.layer1.neuron1.bias']
215
+ w1_n2 = model['boolean.xor.layer1.neuron2.weight']
216
+ b1_n2 = model['boolean.xor.layer1.neuron2.bias']
217
+ w2 = model['boolean.xor.layer2.weight']
218
+ b2 = model['boolean.xor.layer2.bias']
219
+
220
+ h1 = heaviside(inp @ w1_n1 + b1_n1)
221
+ h2 = heaviside(inp @ w1_n2 + b1_n2)
222
+ hidden = torch.tensor([h1.item(), h2.item()])
223
+ result = int(heaviside(hidden @ w2 + b2).item())
224
+ expected = a ^ b
225
+
226
+ if result != expected:
227
+ errors.append(('XOR', a, b, expected, result))
228
+
229
+ # XNOR (two-layer)
230
+ for a in [0, 1]:
231
+ for b in [0, 1]:
232
+ inp = torch.tensor([float(a), float(b)])
233
+
234
+ w1_n1 = model['boolean.xnor.layer1.neuron1.weight']
235
+ b1_n1 = model['boolean.xnor.layer1.neuron1.bias']
236
+ w1_n2 = model['boolean.xnor.layer1.neuron2.weight']
237
+ b1_n2 = model['boolean.xnor.layer1.neuron2.bias']
238
+ w2 = model['boolean.xnor.layer2.weight']
239
+ b2 = model['boolean.xnor.layer2.bias']
240
+
241
+ h1 = heaviside(inp @ w1_n1 + b1_n1)
242
+ h2 = heaviside(inp @ w1_n2 + b1_n2)
243
+ hidden = torch.tensor([h1.item(), h2.item()])
244
+ result = int(heaviside(hidden @ w2 + b2).item())
245
+ expected = 1 - (a ^ b) # XNOR = NOT XOR
246
+
247
+ if result != expected:
248
+ errors.append(('XNOR', a, b, expected, result))
249
+
250
+ # NOT (single input)
251
+ w = model['boolean.not.weight']
252
+ bias = model['boolean.not.bias']
253
+ for a in [0, 1]:
254
+ inp = torch.tensor([float(a)])
255
+ result = int(heaviside(inp @ w + bias).item())
256
+ expected = 1 - a
257
+ if result != expected:
258
+ errors.append(('NOT', a, '-', expected, result))
259
+
260
+ if errors:
261
+ print(f" FAILED: {len(errors)} mismatches")
262
+ for e in errors:
263
+ print(f" {e[0]}({e[1]}, {e[2]}) = {e[4]}, expected {e[3]}")
264
+ else:
265
+ print(f" PASSED: All Boolean gates verified (AND, OR, NAND, NOR, XOR, XNOR, NOT)")
266
+ print(f" Total: 26 truth table entries")
267
+
268
+ return len(errors) == 0
269
+
270
+ def test_half_adder_exhaustive():
271
+ """
272
+ Exhaustive test of half adder (4 cases).
273
+ """
274
+ print("\n[TEST 4] Exhaustive Half Adder: 4 cases")
275
+ print("-" * 60)
276
+
277
+ errors = []
278
+
279
+ for a in [0, 1]:
280
+ for b in [0, 1]:
281
+ inp = torch.tensor([float(a), float(b)])
282
+
283
+ # Sum (XOR)
284
+ w1_or = model['arithmetic.halfadder.sum.layer1.or.weight']
285
+ b1_or = model['arithmetic.halfadder.sum.layer1.or.bias']
286
+ w1_nand = model['arithmetic.halfadder.sum.layer1.nand.weight']
287
+ b1_nand = model['arithmetic.halfadder.sum.layer1.nand.bias']
288
+ w2 = model['arithmetic.halfadder.sum.layer2.weight']
289
+ b2_sum = model['arithmetic.halfadder.sum.layer2.bias']
290
+
291
+ h_or = heaviside(inp @ w1_or + b1_or)
292
+ h_nand = heaviside(inp @ w1_nand + b1_nand)
293
+ hidden = torch.tensor([h_or.item(), h_nand.item()])
294
+ sum_bit = int(heaviside(hidden @ w2 + b2_sum).item())
295
+
296
+ # Carry (AND)
297
+ w_c = model['arithmetic.halfadder.carry.weight']
298
+ b_c = model['arithmetic.halfadder.carry.bias']
299
+ carry = int(heaviside(inp @ w_c + b_c).item())
300
+
301
+ # Expected
302
+ exp_sum = a ^ b
303
+ exp_carry = a & b
304
+
305
+ if sum_bit != exp_sum:
306
+ errors.append(('SUM', a, b, exp_sum, sum_bit))
307
+ if carry != exp_carry:
308
+ errors.append(('CARRY', a, b, exp_carry, carry))
309
+
310
+ if errors:
311
+ print(f" FAILED: {len(errors)} mismatches")
312
+ for e in errors:
313
+ print(f" HA.{e[0]}({e[1]}, {e[2]}) = {e[4]}, expected {e[3]}")
314
+ else:
315
+ print(f" PASSED: Half adder verified (4 sum + 4 carry = 8 checks)")
316
+
317
+ return len(errors) == 0
318
+
319
+ def test_full_adder_exhaustive():
320
+ """
321
+ Exhaustive test of full adder (8 cases).
322
+ """
323
+ print("\n[TEST 5] Exhaustive Full Adder: 8 cases")
324
+ print("-" * 60)
325
+
326
+ errors = []
327
+
328
+ for a in [0, 1]:
329
+ for b in [0, 1]:
330
+ for cin in [0, 1]:
331
+ sum_bit, cout = eval_full_adder(float(a), float(b), float(cin),
332
+ 'arithmetic.fulladder')
333
+
334
+ # Expected
335
+ total = a + b + cin
336
+ exp_sum = total % 2
337
+ exp_cout = total // 2
338
+
339
+ if sum_bit != exp_sum:
340
+ errors.append(('SUM', a, b, cin, exp_sum, sum_bit))
341
+ if cout != exp_cout:
342
+ errors.append(('COUT', a, b, cin, exp_cout, cout))
343
+
344
+ if errors:
345
+ print(f" FAILED: {len(errors)} mismatches")
346
+ for e in errors:
347
+ print(f" FA.{e[0]}({e[1]}, {e[2]}, {e[3]}) = {e[5]}, expected {e[4]}")
348
+ else:
349
+ print(f" PASSED: Full adder verified (8 sum + 8 carry = 16 checks)")
350
+
351
+ return len(errors) == 0
352
+
353
+ def test_2bit_adder_exhaustive():
354
+ """
355
+ Exhaustive test of 2-bit ripple carry adder (16 cases).
356
+ """
357
+ print("\n[TEST 6] Exhaustive 2-bit Adder: 4 x 4 = 16 cases")
358
+ print("-" * 60)
359
+
360
+ errors = []
361
+
362
+ for a in range(4):
363
+ for b in range(4):
364
+ # Use 2-bit ripple carry
365
+ carry = 0.0
366
+ result_bits = []
367
+
368
+ for i in range(2):
369
+ a_bit = (a >> i) & 1
370
+ b_bit = (b >> i) & 1
371
+ s, carry = eval_full_adder(float(a_bit), float(b_bit), carry,
372
+ f'arithmetic.ripplecarry2bit.fa{i}')
373
+ result_bits.append(s)
374
+
375
+ result = result_bits[0] + 2 * result_bits[1]
376
+ cout = int(carry)
377
+
378
+ exp_result = (a + b) % 4
379
+ exp_carry = 1 if (a + b) >= 4 else 0
380
+
381
+ if result != exp_result:
382
+ errors.append(('result', a, b, exp_result, result))
383
+ if cout != exp_carry:
384
+ errors.append(('carry', a, b, exp_carry, cout))
385
+
386
+ if errors:
387
+ print(f" FAILED: {len(errors)} mismatches")
388
+ for e in errors:
389
+ print(f" {e[0]}: {e[1]} + {e[2]} = {e[4]}, expected {e[3]}")
390
+ else:
391
+ print(f" PASSED: 2-bit adder verified (16 results + 16 carries)")
392
+
393
+ return len(errors) == 0
394
+
395
+ def test_4bit_adder_exhaustive():
396
+ """
397
+ Exhaustive test of 4-bit ripple carry adder (256 cases).
398
+ """
399
+ print("\n[TEST 7] Exhaustive 4-bit Adder: 16 x 16 = 256 cases")
400
+ print("-" * 60)
401
+
402
+ errors = []
403
+
404
+ for a in range(16):
405
+ for b in range(16):
406
+ carry = 0.0
407
+ result_bits = []
408
+
409
+ for i in range(4):
410
+ a_bit = (a >> i) & 1
411
+ b_bit = (b >> i) & 1
412
+ s, carry = eval_full_adder(float(a_bit), float(b_bit), carry,
413
+ f'arithmetic.ripplecarry4bit.fa{i}')
414
+ result_bits.append(s)
415
+
416
+ result = sum(result_bits[i] * (2**i) for i in range(4))
417
+ cout = int(carry)
418
+
419
+ exp_result = (a + b) % 16
420
+ exp_carry = 1 if (a + b) >= 16 else 0
421
+
422
+ if result != exp_result:
423
+ errors.append(('result', a, b, exp_result, result))
424
+ if cout != exp_carry:
425
+ errors.append(('carry', a, b, exp_carry, cout))
426
+
427
+ if errors:
428
+ print(f" FAILED: {len(errors)} mismatches")
429
+ for e in errors[:10]:
430
+ print(f" {e[0]}: {e[1]} + {e[2]} = {e[4]}, expected {e[3]}")
431
+ else:
432
+ print(f" PASSED: 4-bit adder verified (256 results + 256 carries)")
433
+
434
+ return len(errors) == 0
435
+
436
+ # =============================================================================
437
+ # MAIN
438
+ # =============================================================================
439
+
440
+ if __name__ == "__main__":
441
+ print("=" * 70)
442
+ print(" TEST #2: FORMAL EQUIVALENCE CHECKING")
443
+ print(" Exhaustive verification against Python's arithmetic")
444
+ print("=" * 70)
445
+
446
+ results = []
447
+
448
+ results.append(("Boolean gates", test_boolean_exhaustive()))
449
+ results.append(("Half adder", test_half_adder_exhaustive()))
450
+ results.append(("Full adder", test_full_adder_exhaustive()))
451
+ results.append(("2-bit adder", test_2bit_adder_exhaustive()))
452
+ results.append(("4-bit adder", test_4bit_adder_exhaustive()))
453
+ results.append(("8-bit adder", test_addition_exhaustive()))
454
+ results.append(("Comparators", test_comparators_exhaustive()))
455
+
456
+ print("\n" + "=" * 70)
457
+ print(" SUMMARY")
458
+ print("=" * 70)
459
+
460
+ passed = sum(1 for _, r in results if r)
461
+ total = len(results)
462
+
463
+ for name, r in results:
464
+ status = "PASS" if r else "FAIL"
465
+ print(f" {name:20s} [{status}]")
466
+
467
+ print(f"\n Total: {passed}/{total} test categories passed")
468
+
469
+ total_checks = 26 + 8 + 16 + 32 + 512 + 65536*2 + 65536*3
470
+ print(f" Individual checks: ~{total_checks:,}")
471
+
472
+ if passed == total:
473
+ print("\n STATUS: EXHAUSTIVE EQUIVALENCE VERIFIED")
474
+ else:
475
+ print("\n STATUS: EQUIVALENCE FAILURES DETECTED")
476
+
477
+ print("=" * 70)