phanerozoic commited on
Commit
41334a3
·
verified ·
1 Parent(s): 7a7b45c

Upload test_overflow_chains.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_overflow_chains.py +423 -0
test_overflow_chains.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TEST #1: Arithmetic Overflow Chains
3
+ ====================================
4
+ Chains 1000+ arithmetic operations, verifying every intermediate state.
5
+ Tests carry/borrow propagation across long sequences, not just single ops.
6
+
7
+ A skeptic would demand: "Prove your adder doesn't accumulate errors over
8
+ repeated use. Show me every intermediate value matches Python's arithmetic."
9
+ """
10
+
11
+ import torch
12
+ from safetensors.torch import load_file
13
+ import random
14
+
15
+ # Load circuits
16
+ model = load_file('neural_computer.safetensors')
17
+
18
+ def heaviside(x):
19
+ return (x >= 0).float()
20
+
21
+ def int_to_bits_lsb(val, width=8):
22
+ """Convert int to bits, LSB first (for arithmetic)."""
23
+ return torch.tensor([(val >> i) & 1 for i in range(width)], dtype=torch.float32)
24
+
25
+ def bits_to_int_lsb(bits):
26
+ """Convert bits back to int, LSB first."""
27
+ return sum(int(bits[i].item()) * (2**i) for i in range(len(bits)))
28
+
29
+ def eval_xor(a, b, prefix='boolean.xor'):
30
+ """Evaluate XOR gate."""
31
+ inp = torch.tensor([a, b], dtype=torch.float32)
32
+ w1_n1 = model[f'{prefix}.layer1.neuron1.weight']
33
+ b1_n1 = model[f'{prefix}.layer1.neuron1.bias']
34
+ w1_n2 = model[f'{prefix}.layer1.neuron2.weight']
35
+ b1_n2 = model[f'{prefix}.layer1.neuron2.bias']
36
+ w2 = model[f'{prefix}.layer2.weight']
37
+ b2 = model[f'{prefix}.layer2.bias']
38
+ h1 = heaviside(inp @ w1_n1 + b1_n1)
39
+ h2 = heaviside(inp @ w1_n2 + b1_n2)
40
+ hidden = torch.tensor([h1.item(), h2.item()])
41
+ return heaviside(hidden @ w2 + b2).item()
42
+
43
+ def eval_xor_arith(inp, prefix):
44
+ """Evaluate XOR for arithmetic circuits (different naming)."""
45
+ w1_or = model[f'{prefix}.layer1.or.weight']
46
+ b1_or = model[f'{prefix}.layer1.or.bias']
47
+ w1_nand = model[f'{prefix}.layer1.nand.weight']
48
+ b1_nand = model[f'{prefix}.layer1.nand.bias']
49
+ w2 = model[f'{prefix}.layer2.weight']
50
+ b2 = model[f'{prefix}.layer2.bias']
51
+ h_or = heaviside(inp @ w1_or + b1_or)
52
+ h_nand = heaviside(inp @ w1_nand + b1_nand)
53
+ hidden = torch.tensor([h_or.item(), h_nand.item()])
54
+ return heaviside(hidden @ w2 + b2).item()
55
+
56
+ def eval_full_adder(a, b, cin, prefix):
57
+ """Evaluate full adder, return (sum, carry_out)."""
58
+ inp_ab = torch.tensor([a, b], dtype=torch.float32)
59
+
60
+ # HA1: a XOR b
61
+ ha1_sum = eval_xor_arith(inp_ab, f'{prefix}.ha1.sum')
62
+
63
+ # HA1 carry: a AND b
64
+ w_c1 = model[f'{prefix}.ha1.carry.weight']
65
+ b_c1 = model[f'{prefix}.ha1.carry.bias']
66
+ ha1_carry = heaviside(inp_ab @ w_c1 + b_c1).item()
67
+
68
+ # HA2: ha1_sum XOR cin
69
+ inp_ha2 = torch.tensor([ha1_sum, cin], dtype=torch.float32)
70
+ ha2_sum = eval_xor_arith(inp_ha2, f'{prefix}.ha2.sum')
71
+
72
+ # HA2 carry
73
+ w_c2 = model[f'{prefix}.ha2.carry.weight']
74
+ b_c2 = model[f'{prefix}.ha2.carry.bias']
75
+ ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2).item()
76
+
77
+ # Carry out = ha1_carry OR ha2_carry
78
+ inp_cout = torch.tensor([ha1_carry, ha2_carry], dtype=torch.float32)
79
+ w_or = model[f'{prefix}.carry_or.weight']
80
+ b_or = model[f'{prefix}.carry_or.bias']
81
+ cout = heaviside(inp_cout @ w_or + b_or).item()
82
+
83
+ return int(ha2_sum), int(cout)
84
+
85
+ def add_8bit(a, b):
86
+ """8-bit addition using ripple carry adder. Returns (result, carry)."""
87
+ carry = 0.0
88
+ result_bits = []
89
+
90
+ for i in range(8):
91
+ a_bit = (a >> i) & 1
92
+ b_bit = (b >> i) & 1
93
+ s, carry = eval_full_adder(float(a_bit), float(b_bit), carry,
94
+ f'arithmetic.ripplecarry8bit.fa{i}')
95
+ result_bits.append(s)
96
+
97
+ result = sum(result_bits[i] * (2**i) for i in range(8))
98
+ return result, int(carry)
99
+
100
+ def sub_8bit(a, b):
101
+ """8-bit subtraction via two's complement: a - b = a + (~b) + 1."""
102
+ not_b = (~b) & 0xFF
103
+ temp, c1 = add_8bit(a, not_b)
104
+ result, c2 = add_8bit(temp, 1)
105
+ return result, c1 | c2
106
+
107
+ # =============================================================================
108
+ # TEST CHAINS
109
+ # =============================================================================
110
+
111
+ def test_chain_add_overflow():
112
+ """
113
+ Start at 0, add 1 repeatedly until we wrap around multiple times.
114
+ Verify every single intermediate value.
115
+ """
116
+ print("\n[TEST 1] Add-1 chain: 0 -> 255 -> 0 -> 255 (512 additions)")
117
+ print("-" * 60)
118
+
119
+ value = 0
120
+ errors = []
121
+
122
+ for i in range(512):
123
+ expected = (value + 1) % 256
124
+ result, carry = add_8bit(value, 1)
125
+
126
+ if result != expected:
127
+ errors.append((i, value, 1, expected, result))
128
+
129
+ # Check carry on overflow
130
+ if value == 255 and carry != 1:
131
+ errors.append((i, value, 1, "carry=1", f"carry={carry}"))
132
+
133
+ value = result
134
+
135
+ if errors:
136
+ print(f" FAILED: {len(errors)} errors")
137
+ for e in errors[:5]:
138
+ print(f" Step {e[0]}: {e[1]} + {e[2]} = {e[4]}, expected {e[3]}")
139
+ else:
140
+ print(f" PASSED: 512 additions, 2 full wraparounds verified")
141
+
142
+ return len(errors) == 0
143
+
144
+ def test_chain_sub_overflow():
145
+ """
146
+ Start at 255, subtract 1 repeatedly until we wrap around.
147
+ """
148
+ print("\n[TEST 2] Sub-1 chain: 255 -> 0 -> 255 (512 subtractions)")
149
+ print("-" * 60)
150
+
151
+ value = 255
152
+ errors = []
153
+
154
+ for i in range(512):
155
+ expected = (value - 1) % 256
156
+ result, _ = sub_8bit(value, 1)
157
+
158
+ if result != expected:
159
+ errors.append((i, value, 1, expected, result))
160
+
161
+ value = result
162
+
163
+ if errors:
164
+ print(f" FAILED: {len(errors)} errors")
165
+ for e in errors[:5]:
166
+ print(f" Step {e[0]}: {e[1]} - {e[2]} = {e[4]}, expected {e[3]}")
167
+ else:
168
+ print(f" PASSED: 512 subtractions verified")
169
+
170
+ return len(errors) == 0
171
+
172
+ def test_chain_mixed():
173
+ """
174
+ Random mix of +1, -1, +k, -k operations. Verify all intermediates.
175
+ """
176
+ print("\n[TEST 3] Mixed chain: 1000 random +/- operations")
177
+ print("-" * 60)
178
+
179
+ random.seed(42) # Reproducible
180
+
181
+ value = 128 # Start in middle
182
+ python_value = 128
183
+ errors = []
184
+
185
+ for i in range(1000):
186
+ op = random.choice(['+1', '-1', '+k', '-k'])
187
+
188
+ if op == '+1':
189
+ result, _ = add_8bit(value, 1)
190
+ python_value = (python_value + 1) % 256
191
+ elif op == '-1':
192
+ result, _ = sub_8bit(value, 1)
193
+ python_value = (python_value - 1) % 256
194
+ elif op == '+k':
195
+ k = random.randint(1, 50)
196
+ result, _ = add_8bit(value, k)
197
+ python_value = (python_value + k) % 256
198
+ else: # '-k'
199
+ k = random.randint(1, 50)
200
+ result, _ = sub_8bit(value, k)
201
+ python_value = (python_value - k) % 256
202
+
203
+ if result != python_value:
204
+ errors.append((i, op, value, python_value, result))
205
+
206
+ value = result
207
+
208
+ if errors:
209
+ print(f" FAILED: {len(errors)} errors")
210
+ for e in errors[:5]:
211
+ print(f" Step {e[0]}: {e[1]} on {e[2]} = {e[4]}, expected {e[3]}")
212
+ else:
213
+ print(f" PASSED: 1000 random ops verified")
214
+
215
+ return len(errors) == 0
216
+
217
+ def test_chain_carry_stress():
218
+ """
219
+ Worst-case carry propagation: repeatedly compute 127+128=255, 255+1=0.
220
+ """
221
+ print("\n[TEST 4] Carry stress: 127+128 and 255+1 chains (500 each)")
222
+ print("-" * 60)
223
+
224
+ errors = []
225
+
226
+ # 127 + 128 = 255 (all bits flip via carry)
227
+ for i in range(500):
228
+ result, carry = add_8bit(127, 128)
229
+ if result != 255:
230
+ errors.append((i, '127+128', 255, result))
231
+
232
+ # 255 + 1 = 0 with carry out (8-bit carry chain)
233
+ for i in range(500):
234
+ result, carry = add_8bit(255, 1)
235
+ if result != 0 or carry != 1:
236
+ errors.append((i, '255+1', '0,c=1', f'{result},c={carry}'))
237
+
238
+ if errors:
239
+ print(f" FAILED: {len(errors)} errors")
240
+ for e in errors[:5]:
241
+ print(f" Iteration {e[0]}: {e[1]} = {e[3]}, expected {e[2]}")
242
+ else:
243
+ print(f" PASSED: 1000 worst-case carry operations")
244
+
245
+ return len(errors) == 0
246
+
247
+ def test_chain_accumulator():
248
+ """
249
+ Accumulate: start at 0, add 1,2,3,...,100. Verify running sum at each step.
250
+ """
251
+ print("\n[TEST 5] Accumulator: sum(1..100) with intermediate verification")
252
+ print("-" * 60)
253
+
254
+ acc = 0
255
+ errors = []
256
+
257
+ for i in range(1, 101):
258
+ result, _ = add_8bit(acc, i)
259
+ expected = (acc + i) % 256
260
+
261
+ if result != expected:
262
+ errors.append((i, acc, i, expected, result))
263
+
264
+ acc = result
265
+
266
+ # Final value: sum(1..100) = 5050, mod 256 = 5050 % 256 = 186
267
+ final_expected = sum(range(1, 101)) % 256
268
+
269
+ if acc != final_expected:
270
+ errors.append(('final', acc, final_expected))
271
+
272
+ if errors:
273
+ print(f" FAILED: {len(errors)} errors")
274
+ for e in errors[:5]:
275
+ print(f" {e}")
276
+ else:
277
+ print(f" PASSED: sum(1..100) mod 256 = {acc} verified at every step")
278
+
279
+ return len(errors) == 0
280
+
281
+ def test_chain_fibonacci():
282
+ """
283
+ Compute Fibonacci sequence mod 256. Verify against Python.
284
+ """
285
+ print("\n[TEST 6] Fibonacci chain: F(0)..F(100) mod 256")
286
+ print("-" * 60)
287
+
288
+ a, b = 0, 1 # Circuit values
289
+ pa, pb = 0, 1 # Python values
290
+ errors = []
291
+
292
+ for i in range(100):
293
+ # Verify current values
294
+ if a != pa:
295
+ errors.append((i, 'a', pa, a))
296
+ if b != pb:
297
+ errors.append((i, 'b', pb, b))
298
+
299
+ # Compute next
300
+ next_val, _ = add_8bit(a, b)
301
+ next_python = (pa + pb) % 256
302
+
303
+ a, b = b, next_val
304
+ pa, pb = pb, next_python
305
+
306
+ if errors:
307
+ print(f" FAILED: {len(errors)} errors")
308
+ for e in errors[:5]:
309
+ print(f" F({e[0]}) {e[1]}: expected {e[2]}, got {e[3]}")
310
+ else:
311
+ print(f" PASSED: 100 Fibonacci terms verified")
312
+
313
+ return len(errors) == 0
314
+
315
+ def test_chain_alternating():
316
+ """
317
+ Alternating +127/-127 to stress positive/negative boundaries.
318
+ """
319
+ print("\n[TEST 7] Alternating +127/-127 (200 operations)")
320
+ print("-" * 60)
321
+
322
+ value = 0
323
+ python_value = 0
324
+ errors = []
325
+
326
+ for i in range(200):
327
+ if i % 2 == 0:
328
+ result, _ = add_8bit(value, 127)
329
+ python_value = (python_value + 127) % 256
330
+ else:
331
+ result, _ = sub_8bit(value, 127)
332
+ python_value = (python_value - 127) % 256
333
+
334
+ if result != python_value:
335
+ errors.append((i, value, python_value, result))
336
+
337
+ value = result
338
+
339
+ if errors:
340
+ print(f" FAILED: {len(errors)} errors")
341
+ for e in errors[:5]:
342
+ print(f" Step {e[0]}: from {e[1]}, expected {e[2]}, got {e[3]}")
343
+ else:
344
+ print(f" PASSED: 200 alternating ops verified")
345
+
346
+ return len(errors) == 0
347
+
348
+ def test_chain_powers_of_two():
349
+ """
350
+ Add powers of 2: 1+2+4+8+...+128. Verify intermediate sums.
351
+ """
352
+ print("\n[TEST 8] Powers of 2: 1+2+4+8+16+32+64+128")
353
+ print("-" * 60)
354
+
355
+ acc = 0
356
+ errors = []
357
+
358
+ for i in range(8):
359
+ power = 2 ** i
360
+ result, _ = add_8bit(acc, power)
361
+ expected = (acc + power) % 256
362
+
363
+ if result != expected:
364
+ errors.append((i, acc, power, expected, result))
365
+
366
+ acc = result
367
+
368
+ # Final: 1+2+4+8+16+32+64+128 = 255
369
+ if acc != 255:
370
+ errors.append(('final', 255, acc))
371
+
372
+ if errors:
373
+ print(f" FAILED: {len(errors)} errors")
374
+ for e in errors[:5]:
375
+ print(f" {e}")
376
+ else:
377
+ print(f" PASSED: 2^0 + 2^1 + ... + 2^7 = {acc}")
378
+
379
+ return len(errors) == 0
380
+
381
+ # =============================================================================
382
+ # MAIN
383
+ # =============================================================================
384
+
385
+ if __name__ == "__main__":
386
+ print("=" * 70)
387
+ print(" TEST #1: ARITHMETIC OVERFLOW CHAINS")
388
+ print(" Verifying every intermediate state across 3000+ chained operations")
389
+ print("=" * 70)
390
+
391
+ results = []
392
+
393
+ results.append(("Add-1 chain", test_chain_add_overflow()))
394
+ results.append(("Sub-1 chain", test_chain_sub_overflow()))
395
+ results.append(("Mixed random", test_chain_mixed()))
396
+ results.append(("Carry stress", test_chain_carry_stress()))
397
+ results.append(("Accumulator", test_chain_accumulator()))
398
+ results.append(("Fibonacci", test_chain_fibonacci()))
399
+ results.append(("Alternating", test_chain_alternating()))
400
+ results.append(("Powers of 2", test_chain_powers_of_two()))
401
+
402
+ print("\n" + "=" * 70)
403
+ print(" SUMMARY")
404
+ print("=" * 70)
405
+
406
+ passed = sum(1 for _, r in results if r)
407
+ total = len(results)
408
+
409
+ for name, r in results:
410
+ status = "PASS" if r else "FAIL"
411
+ print(f" {name:20s} [{status}]")
412
+
413
+ print(f"\n Total: {passed}/{total} tests passed")
414
+
415
+ total_ops = 512 + 512 + 1000 + 1000 + 100 + 100 + 200 + 8 # ~3400
416
+ print(f" Operations verified: ~{total_ops}")
417
+
418
+ if passed == total:
419
+ print("\n STATUS: ALL CHAINS VERIFIED - NO ACCUMULATED ERRORS")
420
+ else:
421
+ print("\n STATUS: FAILURES DETECTED")
422
+
423
+ print("=" * 70)