phanerozoic committed on
Commit
ea808d6
·
verified ·
1 Parent(s): def2700

Upload test_cryptographic_selftest.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_cryptographic_selftest.py +516 -0
test_cryptographic_selftest.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TEST #6: Cryptographic Self-Test
3
+ =================================
4
+ Have the threshold computer compute a checksum over its own weights.
5
+ Verify the result matches external (Python) computation.
6
+
7
+ A skeptic would demand: "Prove the computer can verify its own integrity.
8
+ Bootstrap trust by having it compute over its own weights."
9
+ """
10
+
11
+ import torch
12
+ from safetensors.torch import load_file
13
+ import struct
14
+
15
# Load circuits: read the stored tensors into a mapping keyed by parameter
# name (e.g. '<prefix>.weight' / '<prefix>.bias'); every circuit evaluator
# below indexes this mapping. The path is relative, so the file is looked up
# from the current working directory.
model = load_file('neural_computer.safetensors')
17
+
18
def heaviside(x):
    """Step activation: 1.0 where the input is non-negative, 0.0 elsewhere.

    Args:
        x: Tensor of pre-activation values.

    Returns:
        Float tensor with the same shape as ``x`` holding 0.0 or 1.0.
    """
    fired = x >= 0
    return fired.float()
20
+
21
+ # =============================================================================
22
+ # CIRCUIT PRIMITIVES
23
+ # =============================================================================
24
+
25
def eval_xor_arith(inp, prefix):
    """Evaluate the two-layer XOR circuit stored under ``prefix``.

    The hidden layer consists of an OR neuron and a NAND neuron; the output
    neuron thresholds a weighted combination of the two hidden activations.

    Args:
        inp: Tensor holding the two input bits.
        prefix: Key prefix of the circuit inside ``model``.

    Returns:
        The output activation as a Python float (0.0 or 1.0).
    """
    def params(neuron):
        # Each neuron is stored as a '<name>.weight' / '<name>.bias' pair.
        return model[f'{prefix}.{neuron}.weight'], model[f'{prefix}.{neuron}.bias']

    or_w, or_b = params('layer1.or')
    nand_w, nand_b = params('layer1.nand')
    out_w, out_b = params('layer2')

    hidden = torch.tensor([
        heaviside(inp @ or_w + or_b).item(),
        heaviside(inp @ nand_w + nand_b).item(),
    ])
    return heaviside(hidden @ out_w + out_b).item()
37
+
38
def eval_full_adder(a, b, cin, prefix):
    """Evaluate one full adder built from two half adders and an OR gate.

    Args:
        a: First input bit (0/1 as a number).
        b: Second input bit.
        cin: Carry-in bit.
        prefix: Key prefix of this adder inside ``model``.

    Returns:
        Tuple ``(sum_bit, carry_out)`` as ints.
    """
    def gate(name, signal):
        # Single threshold neuron stored under '<prefix>.<name>'.
        weight = model[f'{prefix}.{name}.weight']
        bias = model[f'{prefix}.{name}.bias']
        return heaviside(signal @ weight + bias).item()

    # Half adder 1 combines the two operand bits.
    operands = torch.tensor([a, b], dtype=torch.float32)
    partial_sum = eval_xor_arith(operands, f'{prefix}.ha1.sum')
    partial_carry = gate('ha1.carry', operands)

    # Half adder 2 folds the carry-in into the partial sum.
    with_carry = torch.tensor([partial_sum, cin], dtype=torch.float32)
    sum_bit = eval_xor_arith(with_carry, f'{prefix}.ha2.sum')
    chained_carry = gate('ha2.carry', with_carry)

    # The carry-out is the OR of the two half-adder carries.
    carries = torch.tensor([partial_carry, chained_carry], dtype=torch.float32)
    carry_out = gate('carry_or', carries)
    return int(sum_bit), int(carry_out)
55
+
56
def add_8bit(a, b):
    """Add two bytes with the 8-stage ripple-carry adder circuit.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        Tuple ``(result, carry)``: the 8-bit sum and the final carry-out
        as ints.
    """
    carry = 0.0
    total = 0
    for position in range(8):
        bit_a = float((a >> position) & 1)
        bit_b = float((b >> position) & 1)
        sum_bit, carry = eval_full_adder(
            bit_a, bit_b, carry,
            f'arithmetic.ripplecarry8bit.fa{position}')
        # Weight each sum bit by its position to rebuild the integer.
        total += sum_bit * (2 ** position)
    return total, int(carry)
68
+
69
def eval_xor_byte(a, b):
    """XOR two bytes using the boolean XOR circuit, one bit at a time.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        The 8-bit XOR of the operands as an int.
    """
    # The circuit parameters do not depend on the bit position, so fetch
    # them once instead of on every loop iteration.
    w1_n1 = model['boolean.xor.layer1.neuron1.weight']
    b1_n1 = model['boolean.xor.layer1.neuron1.bias']
    w1_n2 = model['boolean.xor.layer1.neuron2.weight']
    b1_n2 = model['boolean.xor.layer1.neuron2.bias']
    w2 = model['boolean.xor.layer2.weight']
    b2 = model['boolean.xor.layer2.bias']

    result = 0
    for i in range(8):
        a_bit = (a >> i) & 1
        b_bit = (b >> i) & 1
        inp = torch.tensor([float(a_bit), float(b_bit)])

        # Two hidden neurons feed a single output neuron.
        h1 = heaviside(inp @ w1_n1 + b1_n1)
        h2 = heaviside(inp @ w1_n2 + b1_n2)
        hidden = torch.tensor([h1.item(), h2.item()])
        out = int(heaviside(hidden @ w2 + b2).item())

        result |= out << i

    return result
92
+
93
def eval_and_byte(a, b):
    """AND two bytes using the boolean AND circuit, one bit at a time.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        The 8-bit AND of the operands as an int.
    """
    # The neuron parameters are the same for every bit; load them once.
    w = model['boolean.and.weight']
    bias = model['boolean.and.bias']

    result = 0
    for i in range(8):
        a_bit = (a >> i) & 1
        b_bit = (b >> i) & 1
        inp = torch.tensor([float(a_bit), float(b_bit)])
        out = int(heaviside(inp @ w + bias).item())
        result |= out << i
    return result
105
+
106
def shift_left_1(val):
    """Shift a byte left by one position.

    Args:
        val: Integer value to shift.

    Returns:
        Tuple ``(result, bit_shifted_out)``: the shifted value masked to
        8 bits, and bit 7 of the input.
    """
    return (val << 1) & 0xFF, (val >> 7) & 1
111
+
112
def shift_right_1(val):
    """Shift a byte right by one position.

    Args:
        val: Integer value to shift.

    Returns:
        Tuple ``(result, bit_shifted_out)``: the shifted value masked to
        8 bits, and bit 0 of the input.
    """
    return (val >> 1) & 0xFF, val & 1
117
+
118
+ # =============================================================================
119
+ # CHECKSUM ALGORITHMS IMPLEMENTED ON THRESHOLD CIRCUITS
120
+ # =============================================================================
121
+
122
def circuit_checksum_simple(data_bytes):
    """Additive checksum computed with the threshold-circuit adder.

    Every byte is folded into an 8-bit accumulator via ``add_8bit``; the
    carry-out is discarded, so the result is the byte sum modulo 256.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        The 8-bit checksum as an int (0 for empty input).
    """
    running = 0
    for value in data_bytes:
        running = add_8bit(running, value)[0]
    return running
131
+
132
def circuit_checksum_xor(data_bytes):
    """XOR checksum computed with the threshold-circuit XOR gate.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        The XOR of all bytes as an int (0 for empty input).
    """
    folded = 0
    for value in data_bytes:
        folded = eval_xor_byte(folded, value)
    return folded
141
+
142
def circuit_fletcher8(data_bytes):
    """Fletcher-style checksum computed with the threshold-circuit adder.

    Keeps two 8-bit running sums: the first accumulates the bytes, the
    second accumulates the successive values of the first.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        16-bit int with the second sum in the high byte and the first sum
        in the low byte.
    """
    low = high = 0
    for value in data_bytes:
        low = add_8bit(low, value)[0]
        high = add_8bit(high, low)[0]
    return low | (high << 8)
153
+
154
def circuit_crc8_simple(data_bytes, poly=0x07):
    """Bitwise CRC-8 whose XOR steps run on the threshold XOR circuit.

    Args:
        data_bytes: Iterable of byte values.
        poly: Generator polynomial without the x^8 term; the default 0x07
            is x^8 + x^2 + x + 1.

    Returns:
        The 8-bit CRC as an int (0 for empty input).
    """
    register = 0
    for value in data_bytes:
        register = eval_xor_byte(register, value)
        for _ in range(8):
            register, overflow = shift_left_1(register)
            # A 1 shifted out of the top bit triggers the polynomial XOR.
            if overflow:
                register = eval_xor_byte(register, poly)
    return register
169
+
170
+ # =============================================================================
171
+ # PYTHON REFERENCE IMPLEMENTATIONS
172
+ # =============================================================================
173
+
174
def python_checksum_simple(data_bytes):
    """Python reference: sum of all bytes, reduced modulo 256."""
    total = 0
    for value in data_bytes:
        total += value
    return total % 256
177
+
178
def python_checksum_xor(data_bytes):
    """Python reference: XOR of all bytes (0 for empty input)."""
    folded = 0
    for value in data_bytes:
        folded = folded ^ value
    return folded
184
+
185
def python_fletcher8(data_bytes):
    """Python reference: two running sums, each reduced modulo 256.

    Returns:
        16-bit int with the sum-of-sums in the high byte and the plain byte
        sum in the low byte.
    """
    low = high = 0
    for value in data_bytes:
        low = (low + value) % 256
        high = (high + low) % 256
    return low | (high << 8)
193
+
194
def python_crc8(data_bytes, poly=0x07):
    """Python reference: bitwise, MSB-first CRC-8 with zero initial value.

    Args:
        data_bytes: Iterable of byte values.
        poly: Generator polynomial without the x^8 term (default 0x07).

    Returns:
        The 8-bit CRC as an int (0 for empty input).
    """
    register = 0
    for value in data_bytes:
        register ^= value
        for _ in range(8):
            shifted = register << 1
            # XOR in the polynomial when the bit about to fall off is set.
            if register & 0x80:
                shifted ^= poly
            register = shifted & 0xFF
    return register
205
+
206
+ # =============================================================================
207
+ # WEIGHT SERIALIZATION
208
+ # =============================================================================
209
+
210
def serialize_weights():
    """Serialize every tensor in ``model`` to a flat list of byte values.

    Tensors are visited in sorted key order so the output is deterministic.
    Each element is truncated to an int and reduced to its low 8 bits, which
    maps negative values to their two's-complement byte (-1 -> 255). Values
    outside the int8 range wrap rather than clamp, and any fractional part
    is dropped.

    Returns:
        List of ints in the range 0..255; this is the data the circuits
        checksum.
    """
    all_bytes = []
    for key in sorted(model.keys()):
        values = model[key].flatten().tolist()
        # `& 0xFF` alone yields the two's-complement byte for negatives, so
        # no separate sign handling is needed.
        all_bytes.extend(int(val) & 0xFF for val in values)
    return all_bytes
230
+
231
+ # =============================================================================
232
+ # TESTS
233
+ # =============================================================================
234
+
235
def test_checksum_primitives():
    """Check the circuit additive and XOR checksums against Python.

    Runs both checksums over a fixed set of test vectors and prints a
    report.

    Returns:
        True when every vector matches the Python reference, else False.
    """
    print("\n[TEST 1] Checksum Primitive Verification")
    print("-" * 60)

    vectors = [
        [0, 0, 0, 0],
        [1, 2, 3, 4],
        [255, 255, 255, 255],
        [0x12, 0x34, 0x56, 0x78],
        list(range(10)),
        [0xAA, 0x55, 0xAA, 0x55],
    ]
    # (label, circuit implementation, Python reference)
    checks = (
        ('SUM', circuit_checksum_simple, python_checksum_simple),
        ('XOR', circuit_checksum_xor, python_checksum_xor),
    )

    mismatches = []
    for data in vectors:
        for label, circuit_fn, reference_fn in checks:
            got = circuit_fn(data)
            expected = reference_fn(data)
            if got != expected:
                mismatches.append((label, data, expected, got))

    if not mismatches:
        print(f" PASSED: {len(vectors)} test vectors verified")
        print(" - Simple additive checksum: OK")
        print(" - XOR checksum: OK")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    # Only the first five mismatches are listed.
    for label, data, expected, got in mismatches[:5]:
        print(f" {label} on {data}: expected {expected}, got {got}")
    return False
275
+
276
def test_fletcher8():
    """Check the circuit Fletcher checksum against the Python reference.

    Returns:
        True when every test vector matches, else False.
    """
    print("\n[TEST 2] Fletcher-8 Checksum")
    print("-" * 60)

    vectors = [
        [0x01, 0x02],
        [0x00, 0x00, 0x00, 0x00],
        [0xFF, 0xFF],
        list(range(16)),
    ]

    mismatches = []
    for data in vectors:
        got = circuit_fletcher8(data)
        expected = python_fletcher8(data)
        if got != expected:
            mismatches.append((data, expected, got))

    if not mismatches:
        print(f" PASSED: {len(vectors)} Fletcher-8 tests")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    for data, expected, got in mismatches:
        # Only the first four bytes of each failing vector are shown.
        print(f" Data {data[:4]}...: expected {expected:04x}, got {got:04x}")
    return False
305
+
306
def test_crc8():
    """Check the circuit CRC-8 against the Python reference.

    Returns:
        True when every test vector matches, else False.
    """
    print("\n[TEST 3] CRC-8 Checksum")
    print("-" * 60)

    vectors = [
        [0x00],
        [0x01],
        [0x01, 0x02, 0x03],
        [0xFF],
        [0xAA, 0x55],
    ]

    mismatches = []
    for data in vectors:
        got = circuit_crc8_simple(data)
        expected = python_crc8(data)
        if got != expected:
            mismatches.append((data, expected, got))

    if not mismatches:
        print(f" PASSED: {len(vectors)} CRC-8 tests")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    for data, expected, got in mismatches:
        print(f" Data {data}: expected {expected:02x}, got {got:02x}")
    return False
336
+
337
def _report_checksum_pair(circuit_value, python_value, dec_width, hex_width):
    """Print a circuit/Python checksum pair and return whether they match.

    Args:
        circuit_value: Checksum produced by the threshold circuits.
        python_value: Checksum produced by the Python reference.
        dec_width: Minimum field width of the decimal rendering.
        hex_width: Zero-padded width of the hexadecimal rendering.

    Returns:
        True when both values are equal.
    """
    matched = circuit_value == python_value
    print(f" Circuit: {circuit_value:{dec_width}d} (0x{circuit_value:0{hex_width}x})")
    print(f" Python: {python_value:{dec_width}d} (0x{python_value:0{hex_width}x})")
    print(f" Match: {'YES' if matched else 'NO'}")
    return matched


def test_self_checksum():
    """Checksum the model's own weights with the circuits and with Python.

    The additive and XOR checksums run over all serialized weight bytes.
    Fletcher runs over the first 256 bytes and CRC-8 over the first 64
    bytes, because the circuit versions are slow.

    Returns:
        True when all four circuit checksums equal their Python reference.
    """
    print("\n[TEST 4] Self-Checksum: Computing checksum of own weights")
    print("-" * 60)

    # Serialize weights
    print(" Serializing weights...")
    weight_bytes = serialize_weights()
    print(f" Total bytes: {len(weight_bytes)}")
    print(f" First 16 bytes: {weight_bytes[:16]}")

    # For performance, use a subset for the intensive checksums
    subset = weight_bytes[:256]
    crc_subset = weight_bytes[:64]

    errors = []

    # Simple checksum (full weights)
    print("\n Computing simple additive checksum (full weights)...")
    if not _report_checksum_pair(circuit_checksum_simple(weight_bytes),
                                 python_checksum_simple(weight_bytes), 3, 2):
        errors.append('simple')

    # XOR checksum (full weights)
    print("\n Computing XOR checksum (full weights)...")
    if not _report_checksum_pair(circuit_checksum_xor(weight_bytes),
                                 python_checksum_xor(weight_bytes), 3, 2):
        errors.append('xor')

    # Fletcher-8 (subset for performance); 16-bit result, so wider fields
    print(f"\n Computing Fletcher-8 (first {len(subset)} bytes)...")
    if not _report_checksum_pair(circuit_fletcher8(subset),
                                 python_fletcher8(subset), 5, 4):
        errors.append('fletcher8')

    # CRC-8 (smaller subset - it's slow)
    print(f"\n Computing CRC-8 (first {len(crc_subset)} bytes)...")
    if not _report_checksum_pair(circuit_crc8_simple(crc_subset),
                                 python_crc8(crc_subset), 3, 2):
        errors.append('crc8')

    print()
    if errors:
        print(f" FAILED: {len(errors)} checksums did not match")
        return False
    print(" PASSED: All 4 self-checksums match Python reference")
    return True
409
+
410
def test_tamper_detection():
    """Verify that changing one weight byte changes the additive checksum.

    One byte of the serialized weights is incremented (modulo 256). The
    Python checksum over all bytes and the circuit checksum over the first
    128 bytes must both differ from their untampered values.

    Returns:
        True when both Python and the circuit detect the change; False
        otherwise, including when there are no weight bytes to tamper with.
    """
    print("\n[TEST 5] Tamper Detection")
    print("-" * 60)

    weight_bytes = serialize_weights()
    if not weight_bytes:
        # Nothing to tamper with; report failure instead of raising.
        print(" FAILED: No weight bytes available to tamper with")
        return False

    original_checksum = python_checksum_simple(weight_bytes)

    print(f" Original checksum: {original_checksum} (0x{original_checksum:02x})")

    # Tamper with one byte. Index 100 is used when available; for shorter
    # data fall back to the last byte. Either way the index stays below 128,
    # so the change is inside the window the circuit checks below.
    tamper_index = min(100, len(weight_bytes) - 1)
    tampered = weight_bytes.copy()
    tampered[tamper_index] = (tampered[tamper_index] + 1) % 256
    tampered_checksum = python_checksum_simple(tampered)

    print(f" Tampered checksum: {tampered_checksum} (0x{tampered_checksum:02x})")
    print(f" Checksums differ: {'YES' if original_checksum != tampered_checksum else 'NO'}")

    # Verify circuit detects the same difference
    circuit_original = circuit_checksum_simple(weight_bytes[:128])
    circuit_tampered = circuit_checksum_simple(tampered[:128])

    print("\n Circuit verification (first 128 bytes):")
    print(f" Original: {circuit_original}")
    print(f" Tampered: {circuit_tampered}")
    print(f" Detects tampering: {'YES' if circuit_original != circuit_tampered else 'NO'}")

    if original_checksum != tampered_checksum and circuit_original != circuit_tampered:
        print("\n PASSED: Tampering detected by both Python and circuit")
        return True
    print("\n FAILED: Tampering not properly detected")
    return False
445
+
446
def test_weight_statistics():
    """Print descriptive statistics and reference checksums of the weights.

    Purely informational: nothing is compared against the circuits.

    Returns:
        Always True.
    """
    from collections import Counter

    print("\n[TEST 6] Weight Statistics")
    print("-" * 60)

    weight_bytes = serialize_weights()
    byte_count = len(weight_bytes)

    print(f" Total weight bytes: {byte_count}")
    print(f" Unique values: {len(set(weight_bytes))}")
    print(f" Min value: {min(weight_bytes)}")
    print(f" Max value: {max(weight_bytes)}")

    # Value distribution: the five most frequent byte values.
    print(" Most common values:")
    for val, count in Counter(weight_bytes).most_common(5):
        pct = 100 * count / byte_count
        print(f" {val:3d} (0x{val:02x}): {count:4d} occurrences ({pct:.1f}%)")

    # Checksums for reference (Python implementations only).
    simple_sum = python_checksum_simple(weight_bytes)
    xor_sum = python_checksum_xor(weight_bytes)
    fletcher = python_fletcher8(weight_bytes)
    crc = python_crc8(weight_bytes[:256])
    print("\n Reference checksums:")
    print(f" Simple sum: {simple_sum}")
    print(f" XOR: {xor_sum}")
    print(f" Fletcher-8: 0x{fletcher:04x}")
    print(f" CRC-8: 0x{crc:02x} (first 256 bytes)")

    return True
477
+
478
+ # =============================================================================
479
+ # MAIN
480
+ # =============================================================================
481
+
482
if __name__ == "__main__":
    # Script entry point: run every test in order, then print a summary.
    banner = "=" * 70

    print(banner)
    print(" TEST #6: CRYPTOGRAPHIC SELF-TEST")
    print(" Computing checksums of weights using the weights themselves")
    print(banner)

    suite = [
        ("Checksum primitives", test_checksum_primitives),
        ("Fletcher-8", test_fletcher8),
        ("CRC-8", test_crc8),
        ("Self-checksum", test_self_checksum),
        ("Tamper detection", test_tamper_detection),
        ("Weight statistics", test_weight_statistics),
    ]
    # Each entry becomes (display name, boolean outcome).
    results = [(name, run()) for name, run in suite]

    print("\n" + banner)
    print(" SUMMARY")
    print(banner)

    passed = sum(bool(outcome) for _, outcome in results)
    total = len(results)

    for name, outcome in results:
        status = "PASS" if outcome else "FAIL"
        print(f" {name:25s} [{status}]")

    print(f"\n Total: {passed}/{total} tests passed")

    if passed != total:
        print("\n STATUS: SOME SELF-TESTS FAILED")
    else:
        print("\n STATUS: CRYPTOGRAPHIC SELF-TEST COMPLETE")
        print(" The computer verified its own integrity.")

    print(banner)