phanerozoic commited on
Commit
0a7c400
·
verified ·
1 Parent(s): e3a06f3

Delete test_overflow_chains.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_overflow_chains.py +0 -423
test_overflow_chains.py DELETED
@@ -1,423 +0,0 @@
1
- """
2
- TEST #1: Arithmetic Overflow Chains
3
- ====================================
4
- Chains 1000+ arithmetic operations, verifying every intermediate state.
5
- Tests carry/borrow propagation across long sequences, not just single ops.
6
-
7
- A skeptic would demand: "Prove your adder doesn't accumulate errors over
8
- repeated use. Show me every intermediate value matches Python's arithmetic."
9
- """
10
-
11
- import torch
12
- from safetensors.torch import load_file
13
- import random
14
-
15
- # Load circuits
16
- model = load_file('neural_computer.safetensors')
17
-
18
- def heaviside(x):
19
- return (x >= 0).float()
20
-
21
- def int_to_bits_lsb(val, width=8):
22
- """Convert int to bits, LSB first (for arithmetic)."""
23
- return torch.tensor([(val >> i) & 1 for i in range(width)], dtype=torch.float32)
24
-
25
- def bits_to_int_lsb(bits):
26
- """Convert bits back to int, LSB first."""
27
- return sum(int(bits[i].item()) * (2**i) for i in range(len(bits)))
28
-
29
- def eval_xor(a, b, prefix='boolean.xor'):
30
- """Evaluate XOR gate."""
31
- inp = torch.tensor([a, b], dtype=torch.float32)
32
- w1_n1 = model[f'{prefix}.layer1.neuron1.weight']
33
- b1_n1 = model[f'{prefix}.layer1.neuron1.bias']
34
- w1_n2 = model[f'{prefix}.layer1.neuron2.weight']
35
- b1_n2 = model[f'{prefix}.layer1.neuron2.bias']
36
- w2 = model[f'{prefix}.layer2.weight']
37
- b2 = model[f'{prefix}.layer2.bias']
38
- h1 = heaviside(inp @ w1_n1 + b1_n1)
39
- h2 = heaviside(inp @ w1_n2 + b1_n2)
40
- hidden = torch.tensor([h1.item(), h2.item()])
41
- return heaviside(hidden @ w2 + b2).item()
42
-
43
- def eval_xor_arith(inp, prefix):
44
- """Evaluate XOR for arithmetic circuits (different naming)."""
45
- w1_or = model[f'{prefix}.layer1.or.weight']
46
- b1_or = model[f'{prefix}.layer1.or.bias']
47
- w1_nand = model[f'{prefix}.layer1.nand.weight']
48
- b1_nand = model[f'{prefix}.layer1.nand.bias']
49
- w2 = model[f'{prefix}.layer2.weight']
50
- b2 = model[f'{prefix}.layer2.bias']
51
- h_or = heaviside(inp @ w1_or + b1_or)
52
- h_nand = heaviside(inp @ w1_nand + b1_nand)
53
- hidden = torch.tensor([h_or.item(), h_nand.item()])
54
- return heaviside(hidden @ w2 + b2).item()
55
-
56
- def eval_full_adder(a, b, cin, prefix):
57
- """Evaluate full adder, return (sum, carry_out)."""
58
- inp_ab = torch.tensor([a, b], dtype=torch.float32)
59
-
60
- # HA1: a XOR b
61
- ha1_sum = eval_xor_arith(inp_ab, f'{prefix}.ha1.sum')
62
-
63
- # HA1 carry: a AND b
64
- w_c1 = model[f'{prefix}.ha1.carry.weight']
65
- b_c1 = model[f'{prefix}.ha1.carry.bias']
66
- ha1_carry = heaviside(inp_ab @ w_c1 + b_c1).item()
67
-
68
- # HA2: ha1_sum XOR cin
69
- inp_ha2 = torch.tensor([ha1_sum, cin], dtype=torch.float32)
70
- ha2_sum = eval_xor_arith(inp_ha2, f'{prefix}.ha2.sum')
71
-
72
- # HA2 carry
73
- w_c2 = model[f'{prefix}.ha2.carry.weight']
74
- b_c2 = model[f'{prefix}.ha2.carry.bias']
75
- ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2).item()
76
-
77
- # Carry out = ha1_carry OR ha2_carry
78
- inp_cout = torch.tensor([ha1_carry, ha2_carry], dtype=torch.float32)
79
- w_or = model[f'{prefix}.carry_or.weight']
80
- b_or = model[f'{prefix}.carry_or.bias']
81
- cout = heaviside(inp_cout @ w_or + b_or).item()
82
-
83
- return int(ha2_sum), int(cout)
84
-
85
- def add_8bit(a, b):
86
- """8-bit addition using ripple carry adder. Returns (result, carry)."""
87
- carry = 0.0
88
- result_bits = []
89
-
90
- for i in range(8):
91
- a_bit = (a >> i) & 1
92
- b_bit = (b >> i) & 1
93
- s, carry = eval_full_adder(float(a_bit), float(b_bit), carry,
94
- f'arithmetic.ripplecarry8bit.fa{i}')
95
- result_bits.append(s)
96
-
97
- result = sum(result_bits[i] * (2**i) for i in range(8))
98
- return result, int(carry)
99
-
100
- def sub_8bit(a, b):
101
- """8-bit subtraction via two's complement: a - b = a + (~b) + 1."""
102
- not_b = (~b) & 0xFF
103
- temp, c1 = add_8bit(a, not_b)
104
- result, c2 = add_8bit(temp, 1)
105
- return result, c1 | c2
106
-
107
- # =============================================================================
108
- # TEST CHAINS
109
- # =============================================================================
110
-
111
- def test_chain_add_overflow():
112
- """
113
- Start at 0, add 1 repeatedly until we wrap around multiple times.
114
- Verify every single intermediate value.
115
- """
116
- print("\n[TEST 1] Add-1 chain: 0 -> 255 -> 0 -> 255 (512 additions)")
117
- print("-" * 60)
118
-
119
- value = 0
120
- errors = []
121
-
122
- for i in range(512):
123
- expected = (value + 1) % 256
124
- result, carry = add_8bit(value, 1)
125
-
126
- if result != expected:
127
- errors.append((i, value, 1, expected, result))
128
-
129
- # Check carry on overflow
130
- if value == 255 and carry != 1:
131
- errors.append((i, value, 1, "carry=1", f"carry={carry}"))
132
-
133
- value = result
134
-
135
- if errors:
136
- print(f" FAILED: {len(errors)} errors")
137
- for e in errors[:5]:
138
- print(f" Step {e[0]}: {e[1]} + {e[2]} = {e[4]}, expected {e[3]}")
139
- else:
140
- print(f" PASSED: 512 additions, 2 full wraparounds verified")
141
-
142
- return len(errors) == 0
143
-
144
- def test_chain_sub_overflow():
145
- """
146
- Start at 255, subtract 1 repeatedly until we wrap around.
147
- """
148
- print("\n[TEST 2] Sub-1 chain: 255 -> 0 -> 255 (512 subtractions)")
149
- print("-" * 60)
150
-
151
- value = 255
152
- errors = []
153
-
154
- for i in range(512):
155
- expected = (value - 1) % 256
156
- result, _ = sub_8bit(value, 1)
157
-
158
- if result != expected:
159
- errors.append((i, value, 1, expected, result))
160
-
161
- value = result
162
-
163
- if errors:
164
- print(f" FAILED: {len(errors)} errors")
165
- for e in errors[:5]:
166
- print(f" Step {e[0]}: {e[1]} - {e[2]} = {e[4]}, expected {e[3]}")
167
- else:
168
- print(f" PASSED: 512 subtractions verified")
169
-
170
- return len(errors) == 0
171
-
172
- def test_chain_mixed():
173
- """
174
- Random mix of +1, -1, +k, -k operations. Verify all intermediates.
175
- """
176
- print("\n[TEST 3] Mixed chain: 1000 random +/- operations")
177
- print("-" * 60)
178
-
179
- random.seed(42) # Reproducible
180
-
181
- value = 128 # Start in middle
182
- python_value = 128
183
- errors = []
184
-
185
- for i in range(1000):
186
- op = random.choice(['+1', '-1', '+k', '-k'])
187
-
188
- if op == '+1':
189
- result, _ = add_8bit(value, 1)
190
- python_value = (python_value + 1) % 256
191
- elif op == '-1':
192
- result, _ = sub_8bit(value, 1)
193
- python_value = (python_value - 1) % 256
194
- elif op == '+k':
195
- k = random.randint(1, 50)
196
- result, _ = add_8bit(value, k)
197
- python_value = (python_value + k) % 256
198
- else: # '-k'
199
- k = random.randint(1, 50)
200
- result, _ = sub_8bit(value, k)
201
- python_value = (python_value - k) % 256
202
-
203
- if result != python_value:
204
- errors.append((i, op, value, python_value, result))
205
-
206
- value = result
207
-
208
- if errors:
209
- print(f" FAILED: {len(errors)} errors")
210
- for e in errors[:5]:
211
- print(f" Step {e[0]}: {e[1]} on {e[2]} = {e[4]}, expected {e[3]}")
212
- else:
213
- print(f" PASSED: 1000 random ops verified")
214
-
215
- return len(errors) == 0
216
-
217
- def test_chain_carry_stress():
218
- """
219
- Worst-case carry propagation: repeatedly compute 127+128=255, 255+1=0.
220
- """
221
- print("\n[TEST 4] Carry stress: 127+128 and 255+1 chains (500 each)")
222
- print("-" * 60)
223
-
224
- errors = []
225
-
226
- # 127 + 128 = 255 (all bits flip via carry)
227
- for i in range(500):
228
- result, carry = add_8bit(127, 128)
229
- if result != 255:
230
- errors.append((i, '127+128', 255, result))
231
-
232
- # 255 + 1 = 0 with carry out (8-bit carry chain)
233
- for i in range(500):
234
- result, carry = add_8bit(255, 1)
235
- if result != 0 or carry != 1:
236
- errors.append((i, '255+1', '0,c=1', f'{result},c={carry}'))
237
-
238
- if errors:
239
- print(f" FAILED: {len(errors)} errors")
240
- for e in errors[:5]:
241
- print(f" Iteration {e[0]}: {e[1]} = {e[3]}, expected {e[2]}")
242
- else:
243
- print(f" PASSED: 1000 worst-case carry operations")
244
-
245
- return len(errors) == 0
246
-
247
- def test_chain_accumulator():
248
- """
249
- Accumulate: start at 0, add 1,2,3,...,100. Verify running sum at each step.
250
- """
251
- print("\n[TEST 5] Accumulator: sum(1..100) with intermediate verification")
252
- print("-" * 60)
253
-
254
- acc = 0
255
- errors = []
256
-
257
- for i in range(1, 101):
258
- result, _ = add_8bit(acc, i)
259
- expected = (acc + i) % 256
260
-
261
- if result != expected:
262
- errors.append((i, acc, i, expected, result))
263
-
264
- acc = result
265
-
266
- # Final value: sum(1..100) = 5050, mod 256 = 5050 % 256 = 186
267
- final_expected = sum(range(1, 101)) % 256
268
-
269
- if acc != final_expected:
270
- errors.append(('final', acc, final_expected))
271
-
272
- if errors:
273
- print(f" FAILED: {len(errors)} errors")
274
- for e in errors[:5]:
275
- print(f" {e}")
276
- else:
277
- print(f" PASSED: sum(1..100) mod 256 = {acc} verified at every step")
278
-
279
- return len(errors) == 0
280
-
281
- def test_chain_fibonacci():
282
- """
283
- Compute Fibonacci sequence mod 256. Verify against Python.
284
- """
285
- print("\n[TEST 6] Fibonacci chain: F(0)..F(100) mod 256")
286
- print("-" * 60)
287
-
288
- a, b = 0, 1 # Circuit values
289
- pa, pb = 0, 1 # Python values
290
- errors = []
291
-
292
- for i in range(100):
293
- # Verify current values
294
- if a != pa:
295
- errors.append((i, 'a', pa, a))
296
- if b != pb:
297
- errors.append((i, 'b', pb, b))
298
-
299
- # Compute next
300
- next_val, _ = add_8bit(a, b)
301
- next_python = (pa + pb) % 256
302
-
303
- a, b = b, next_val
304
- pa, pb = pb, next_python
305
-
306
- if errors:
307
- print(f" FAILED: {len(errors)} errors")
308
- for e in errors[:5]:
309
- print(f" F({e[0]}) {e[1]}: expected {e[2]}, got {e[3]}")
310
- else:
311
- print(f" PASSED: 100 Fibonacci terms verified")
312
-
313
- return len(errors) == 0
314
-
315
- def test_chain_alternating():
316
- """
317
- Alternating +127/-127 to stress positive/negative boundaries.
318
- """
319
- print("\n[TEST 7] Alternating +127/-127 (200 operations)")
320
- print("-" * 60)
321
-
322
- value = 0
323
- python_value = 0
324
- errors = []
325
-
326
- for i in range(200):
327
- if i % 2 == 0:
328
- result, _ = add_8bit(value, 127)
329
- python_value = (python_value + 127) % 256
330
- else:
331
- result, _ = sub_8bit(value, 127)
332
- python_value = (python_value - 127) % 256
333
-
334
- if result != python_value:
335
- errors.append((i, value, python_value, result))
336
-
337
- value = result
338
-
339
- if errors:
340
- print(f" FAILED: {len(errors)} errors")
341
- for e in errors[:5]:
342
- print(f" Step {e[0]}: from {e[1]}, expected {e[2]}, got {e[3]}")
343
- else:
344
- print(f" PASSED: 200 alternating ops verified")
345
-
346
- return len(errors) == 0
347
-
348
- def test_chain_powers_of_two():
349
- """
350
- Add powers of 2: 1+2+4+8+...+128. Verify intermediate sums.
351
- """
352
- print("\n[TEST 8] Powers of 2: 1+2+4+8+16+32+64+128")
353
- print("-" * 60)
354
-
355
- acc = 0
356
- errors = []
357
-
358
- for i in range(8):
359
- power = 2 ** i
360
- result, _ = add_8bit(acc, power)
361
- expected = (acc + power) % 256
362
-
363
- if result != expected:
364
- errors.append((i, acc, power, expected, result))
365
-
366
- acc = result
367
-
368
- # Final: 1+2+4+8+16+32+64+128 = 255
369
- if acc != 255:
370
- errors.append(('final', 255, acc))
371
-
372
- if errors:
373
- print(f" FAILED: {len(errors)} errors")
374
- for e in errors[:5]:
375
- print(f" {e}")
376
- else:
377
- print(f" PASSED: 2^0 + 2^1 + ... + 2^7 = {acc}")
378
-
379
- return len(errors) == 0
380
-
381
- # =============================================================================
382
- # MAIN
383
- # =============================================================================
384
-
385
- if __name__ == "__main__":
386
- print("=" * 70)
387
- print(" TEST #1: ARITHMETIC OVERFLOW CHAINS")
388
- print(" Verifying every intermediate state across 3000+ chained operations")
389
- print("=" * 70)
390
-
391
- results = []
392
-
393
- results.append(("Add-1 chain", test_chain_add_overflow()))
394
- results.append(("Sub-1 chain", test_chain_sub_overflow()))
395
- results.append(("Mixed random", test_chain_mixed()))
396
- results.append(("Carry stress", test_chain_carry_stress()))
397
- results.append(("Accumulator", test_chain_accumulator()))
398
- results.append(("Fibonacci", test_chain_fibonacci()))
399
- results.append(("Alternating", test_chain_alternating()))
400
- results.append(("Powers of 2", test_chain_powers_of_two()))
401
-
402
- print("\n" + "=" * 70)
403
- print(" SUMMARY")
404
- print("=" * 70)
405
-
406
- passed = sum(1 for _, r in results if r)
407
- total = len(results)
408
-
409
- for name, r in results:
410
- status = "PASS" if r else "FAIL"
411
- print(f" {name:20s} [{status}]")
412
-
413
- print(f"\n Total: {passed}/{total} tests passed")
414
-
415
- total_ops = 512 + 512 + 1000 + 1000 + 100 + 100 + 200 + 8 # ~3400
416
- print(f" Operations verified: ~{total_ops}")
417
-
418
- if passed == total:
419
- print("\n STATUS: ALL CHAINS VERIFIED - NO ACCUMULATED ERRORS")
420
- else:
421
- print("\n STATUS: FAILURES DETECTED")
422
-
423
- print("=" * 70)