phanerozoic committed on
Commit
a76f318
·
verified ·
1 Parent(s): 926f717

Delete test_cryptographic_selftest.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_cryptographic_selftest.py +0 -516
test_cryptographic_selftest.py DELETED
@@ -1,516 +0,0 @@
1
- """
2
- TEST #6: Cryptographic Self-Test
3
- =================================
4
- Have the threshold computer compute a checksum over its own weights.
5
- Verify the result matches external (Python) computation.
6
-
7
- A skeptic would demand: "Prove the computer can verify its own integrity.
8
- Bootstrap trust by having it compute over its own weights."
9
- """
10
-
11
- import torch
12
- from safetensors.torch import load_file
13
- import struct
14
-
15
- # Load circuits
16
- model = load_file('neural_computer.safetensors')
17
-
18
def heaviside(x):
    """Step activation: 1.0 where ``x >= 0`` and 0.0 elsewhere.

    Args:
        x: Tensor of pre-activation values.

    Returns:
        A float32 tensor of the same shape containing only 0.0 and 1.0.
    """
    fired = torch.ge(x, 0)
    return fired.to(torch.float32)
20
-
21
- # =============================================================================
22
- # CIRCUIT PRIMITIVES
23
- # =============================================================================
24
-
25
def eval_xor_arith(inp, prefix):
    """Evaluate the two-layer XOR sub-circuit stored under ``prefix``.

    The hidden layer consists of an ``or`` neuron and a ``nand`` neuron; their
    step outputs feed a single ``layer2`` neuron.

    Args:
        inp: Input tensor fed to both hidden neurons
            (presumably the two operand bits -- confirm against callers).
        prefix: Key prefix in ``model`` under which the circuit is stored.

    Returns:
        The output neuron's activation as a Python float (0.0 or 1.0).
    """
    # Fetch every parameter up front, in a fixed order, before computing.
    names = (
        'layer1.or.weight', 'layer1.or.bias',
        'layer1.nand.weight', 'layer1.nand.bias',
        'layer2.weight', 'layer2.bias',
    )
    params = {name: model[f'{prefix}.{name}'] for name in names}

    or_out = heaviside(inp @ params['layer1.or.weight'] + params['layer1.or.bias'])
    nand_out = heaviside(inp @ params['layer1.nand.weight'] + params['layer1.nand.bias'])

    # Re-pack the two scalar activations as the input vector of layer 2.
    hidden = torch.tensor([or_out.item(), nand_out.item()])
    final = heaviside(hidden @ params['layer2.weight'] + params['layer2.bias'])
    return final.item()
37
-
38
def eval_full_adder(a, b, cin, prefix):
    """Evaluate the full-adder circuit stored under ``prefix``.

    The adder is built from two half adders (``ha1`` and ``ha2``), each with
    an XOR ``sum`` sub-circuit and a single-neuron ``carry``, plus a
    ``carry_or`` neuron that merges the two half-adder carries.

    Args:
        a: First operand bit (0.0 or 1.0).
        b: Second operand bit (0.0 or 1.0).
        cin: Carry-in bit.
        prefix: Key prefix in ``model`` for this adder.

    Returns:
        Tuple ``(sum, carry_out)`` of ints.
    """
    def _fire(vec, gate):
        # Single threshold neuron: step(vec . weight + bias) as a float.
        weight = model[f'{prefix}.{gate}.weight']
        bias = model[f'{prefix}.{gate}.bias']
        return heaviside(vec @ weight + bias).item()

    def _half_adder(x, y, name):
        # XOR gives the partial sum; the carry neuron sees the same inputs.
        pair = torch.tensor([x, y], dtype=torch.float32)
        partial_sum = eval_xor_arith(pair, f'{prefix}.{name}.sum')
        return partial_sum, _fire(pair, f'{name}.carry')

    first_sum, first_carry = _half_adder(a, b, 'ha1')
    second_sum, second_carry = _half_adder(first_sum, cin, 'ha2')

    carries = torch.tensor([first_carry, second_carry], dtype=torch.float32)
    carry_out = _fire(carries, 'carry_or')
    return int(second_sum), int(carry_out)
55
-
56
def add_8bit(a, b):
    """Add two bytes with the 8-stage ripple-carry adder circuit.

    Bits are processed from least to most significant; stage ``i`` uses the
    full adder stored under ``arithmetic.ripplecarry8bit.fa{i}``.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        Tuple ``(result, carry)``: the 8-bit sum and the carry out of the
        last stage, both ints.
    """
    carry = 0.0
    total = 0
    for pos in range(8):
        bit, carry = eval_full_adder(
            float((a >> pos) & 1),
            float((b >> pos) & 1),
            carry,
            f'arithmetic.ripplecarry8bit.fa{pos}',
        )
        # Each stage contributes its sum bit at weight 2**pos.
        total += bit * (2 ** pos)
    return total, int(carry)
68
-
69
def eval_xor_byte(a, b):
    """XOR two bytes bit by bit using the ``boolean.xor`` circuit.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        The 8-bit result as an int.
    """
    # The circuit parameters are the same for every bit position, so look
    # them up once instead of on each of the eight iterations.
    w1_n1 = model['boolean.xor.layer1.neuron1.weight']
    b1_n1 = model['boolean.xor.layer1.neuron1.bias']
    w1_n2 = model['boolean.xor.layer1.neuron2.weight']
    b1_n2 = model['boolean.xor.layer1.neuron2.bias']
    w2 = model['boolean.xor.layer2.weight']
    b2 = model['boolean.xor.layer2.bias']

    result = 0
    for i in range(8):
        a_bit = (a >> i) & 1
        b_bit = (b >> i) & 1
        inp = torch.tensor([float(a_bit), float(b_bit)])

        # Two hidden threshold neurons feed one output neuron.
        h1 = heaviside(inp @ w1_n1 + b1_n1)
        h2 = heaviside(inp @ w1_n2 + b1_n2)
        hidden = torch.tensor([h1.item(), h2.item()])
        out = int(heaviside(hidden @ w2 + b2).item())

        result |= out << i

    return result
92
-
93
def eval_and_byte(a, b):
    """AND two bytes bit by bit using the ``boolean.and`` neuron.

    Args:
        a: First operand; only its low 8 bits are read.
        b: Second operand; only its low 8 bits are read.

    Returns:
        The 8-bit result as an int.
    """
    # The neuron's parameters do not depend on the bit position: fetch once.
    w = model['boolean.and.weight']
    bias = model['boolean.and.bias']

    result = 0
    for i in range(8):
        a_bit = (a >> i) & 1
        b_bit = (b >> i) & 1
        inp = torch.tensor([float(a_bit), float(b_bit)])
        out = int(heaviside(inp @ w + bias).item())
        result |= out << i
    return result
105
-
106
def shift_left_1(val):
    """Shift a byte left by one position.

    Args:
        val: Integer value to shift.

    Returns:
        Tuple ``(result, bit_shifted_out)``: the shifted value truncated to
        8 bits, and the former bit 7 (0 or 1).
    """
    dropped = 1 if val & 0x80 else 0
    shifted = (val << 1) & 0xFF
    return shifted, dropped
111
-
112
def shift_right_1(val):
    """Shift a byte right by one position.

    Args:
        val: Integer value to shift.

    Returns:
        Tuple ``(result, bit_shifted_out)``: the shifted value masked to
        8 bits, and the former bit 0 (0 or 1).
    """
    dropped = val & 1
    shifted = (val >> 1) & 0xFF
    return shifted, dropped
117
-
118
- # =============================================================================
119
- # CHECKSUM ALGORITHMS IMPLEMENTED ON THRESHOLD CIRCUITS
120
- # =============================================================================
121
-
122
def circuit_checksum_simple(data_bytes):
    """Additive checksum (sum of all bytes, 8-bit wraparound) via circuits.

    Every addition goes through ``add_8bit``; the carry out of each addition
    is discarded.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        The 8-bit running sum as an int (0 for empty input).
    """
    running = 0
    for value in data_bytes:
        running = add_8bit(running, value)[0]
    return running
131
-
132
def circuit_checksum_xor(data_bytes):
    """XOR checksum: fold all bytes together with the XOR circuit.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        The XOR of all bytes as an int (0 for empty input).
    """
    folded = 0
    for value in data_bytes:
        folded = eval_xor_byte(folded, value)
    return folded
141
-
142
def circuit_fletcher8(data_bytes):
    """Fletcher-style checksum computed with the 8-bit adder circuit.

    Keeps two running 8-bit sums: ``low`` accumulates the bytes and ``high``
    accumulates the successive values of ``low``. Carries are discarded.

    Args:
        data_bytes: Iterable of byte values.

    Returns:
        A 16-bit int with ``high`` in the upper byte and ``low`` in the lower.
    """
    low = 0
    high = 0
    for value in data_bytes:
        low = add_8bit(low, value)[0]
        high = add_8bit(high, low)[0]
    return low | (high << 8)
153
-
154
def circuit_crc8_simple(data_bytes, poly=0x07):
    """Bitwise CRC-8 whose XOR steps run on the threshold XOR circuit.

    Starts from a zero register and processes the most significant bit first.
    The default polynomial 0x07 corresponds to x^8 + x^2 + x + 1.

    Args:
        data_bytes: Iterable of byte values.
        poly: Generator polynomial (low 8 bits).

    Returns:
        The 8-bit CRC as an int.
    """
    register = 0
    for value in data_bytes:
        register = eval_xor_byte(register, value)
        for _ in range(8):
            shifted, msb = shift_left_1(register)
            # Fold the polynomial in only when a 1 was shifted out.
            register = eval_xor_byte(shifted, poly) if msb else shifted
    return register
169
-
170
- # =============================================================================
171
- # PYTHON REFERENCE IMPLEMENTATIONS
172
- # =============================================================================
173
-
174
def python_checksum_simple(data_bytes):
    """Reference additive checksum: sum of all bytes modulo 256."""
    total = 0
    for value in data_bytes:
        total += value
    return total % 256
177
-
178
def python_checksum_xor(data_bytes):
    """Reference XOR checksum: XOR of all bytes (0 for empty input)."""
    folded = 0
    for value in data_bytes:
        folded = folded ^ value
    return folded
184
-
185
def python_fletcher8(data_bytes):
    """Reference Fletcher-style checksum with two modulo-256 running sums.

    Returns:
        A 16-bit int: the sum-of-sums in the upper byte, the plain byte sum
        in the lower byte.
    """
    low = high = 0
    for value in data_bytes:
        low = (low + value) % 256
        high = (high + low) % 256
    return low | (high << 8)
193
-
194
def python_crc8(data_bytes, poly=0x07):
    """Reference bitwise CRC-8 (zero initial value, MSB first).

    Args:
        data_bytes: Iterable of byte values.
        poly: Generator polynomial; defaults to 0x07.

    Returns:
        The 8-bit CRC as an int.
    """
    register = 0
    for value in data_bytes:
        register ^= value
        for _ in range(8):
            # Apply the polynomial only when the bit about to leave is set.
            feedback = poly if register & 0x80 else 0
            register = ((register << 1) ^ feedback) & 0xFF
    return register
205
-
206
- # =============================================================================
207
- # WEIGHT SERIALIZATION
208
- # =============================================================================
209
-
210
def serialize_weights():
    """Flatten every tensor in ``model`` into one list of byte values.

    Tensors are visited in sorted key order so the result is deterministic.
    Each element is truncated toward zero with ``int()`` and then reduced to
    its low 8 bits. Values are therefore *wrapped* modulo 256, not clamped:
    negative integers map to their two's-complement byte (e.g. -1 -> 255).

    Returns:
        List of ints in the range 0..255; this is the data the circuits
        checksum.
    """
    all_bytes = []
    for key in sorted(model.keys()):
        values = model[key].flatten().tolist()
        # ``& 0xFF`` already yields the two's-complement byte for negative
        # ints in Python, so no separate sign handling is required.
        all_bytes.extend(int(val) & 0xFF for val in values)
    return all_bytes
230
-
231
- # =============================================================================
232
- # TESTS
233
- # =============================================================================
234
-
235
def test_checksum_primitives():
    """Cross-check the additive and XOR circuit checksums on fixed vectors.

    For every test vector, the circuit result is compared with the Python
    reference. At most five mismatches are printed.

    Returns:
        True when every vector matches for both checksums, else False.
    """
    print("\n[TEST 1] Checksum Primitive Verification")
    print("-" * 60)

    test_cases = [
        [0, 0, 0, 0],
        [1, 2, 3, 4],
        [255, 255, 255, 255],
        [0x12, 0x34, 0x56, 0x78],
        list(range(10)),
        [0xAA, 0x55, 0xAA, 0x55],
    ]

    # (label, circuit implementation, Python reference)
    checks = (
        ('SUM', circuit_checksum_simple, python_checksum_simple),
        ('XOR', circuit_checksum_xor, python_checksum_xor),
    )

    mismatches = []
    for vector in test_cases:
        for label, circuit_fn, reference_fn in checks:
            actual = circuit_fn(vector)
            expected = reference_fn(vector)
            if actual != expected:
                mismatches.append((label, vector, expected, actual))

    if not mismatches:
        print(f" PASSED: {len(test_cases)} test vectors verified")
        print(f" - Simple additive checksum: OK")
        print(f" - XOR checksum: OK")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    for label, vector, expected, actual in mismatches[:5]:
        print(f" {label} on {vector}: expected {expected}, got {actual}")
    return False
275
-
276
def test_fletcher8():
    """Compare the circuit Fletcher checksum with the Python reference.

    Returns:
        True when all fixed vectors match, else False.
    """
    print("\n[TEST 2] Fletcher-8 Checksum")
    print("-" * 60)

    test_cases = [
        [0x01, 0x02],
        [0x00, 0x00, 0x00, 0x00],
        [0xFF, 0xFF],
        list(range(16)),
    ]

    # Evaluate circuit first, then reference, for each vector.
    outcomes = ((data, circuit_fletcher8(data), python_fletcher8(data))
                for data in test_cases)
    mismatches = [(data, expected, actual)
                  for data, actual, expected in outcomes
                  if actual != expected]

    if not mismatches:
        print(f" PASSED: {len(test_cases)} Fletcher-8 tests")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    for data, expected, actual in mismatches:
        print(f" Data {data[:4]}...: expected {expected:04x}, got {actual:04x}")
    return False
305
-
306
def test_crc8():
    """Compare the circuit CRC-8 with the Python reference on fixed vectors.

    Returns:
        True when all vectors match, else False.
    """
    print("\n[TEST 3] CRC-8 Checksum")
    print("-" * 60)

    test_cases = [
        [0x00],
        [0x01],
        [0x01, 0x02, 0x03],
        [0xFF],
        [0xAA, 0x55],
    ]

    mismatches = []
    for data in test_cases:
        actual = circuit_crc8_simple(data)
        expected = python_crc8(data)
        if actual == expected:
            continue
        mismatches.append((data, expected, actual))

    if not mismatches:
        print(f" PASSED: {len(test_cases)} CRC-8 tests")
        return True

    print(f" FAILED: {len(mismatches)} mismatches")
    for data, expected, actual in mismatches:
        print(f" Data {data}: expected {expected:02x}, got {actual:02x}")
    return False
336
-
337
def test_self_checksum():
    """Checksum the model's own weights with the circuits and with Python.

    The additive and XOR checksums run over all serialized weight bytes.
    Fletcher and CRC-8 are slower, so they only see the first 256 and first
    64 bytes respectively. Each circuit result is printed next to the Python
    reference.

    Returns:
        True when all four circuit checksums equal their references,
        else False.
    """
    print("\n[TEST 4] Self-Checksum: Computing checksum of own weights")
    print("-" * 60)

    print(" Serializing weights...")
    weight_bytes = serialize_weights()
    print(f" Total bytes: {len(weight_bytes)}")
    print(f" First 16 bytes: {weight_bytes[:16]}")

    # Prefixes used to keep the expensive bit-serial checksums tractable.
    subset = weight_bytes[:256]
    crc_subset = weight_bytes[:64]

    # (name, heading, circuit fn, reference fn, data, decimal width, hex digits)
    checks = [
        ('simple',
         "\n Computing simple additive checksum (full weights)...",
         circuit_checksum_simple, python_checksum_simple, weight_bytes, 3, 2),
        ('xor',
         "\n Computing XOR checksum (full weights)...",
         circuit_checksum_xor, python_checksum_xor, weight_bytes, 3, 2),
        ('fletcher8',
         f"\n Computing Fletcher-8 (first {len(subset)} bytes)...",
         circuit_fletcher8, python_fletcher8, subset, 5, 4),
        ('crc8',
         f"\n Computing CRC-8 (first {len(crc_subset)} bytes)...",
         circuit_crc8_simple, python_crc8, crc_subset, 3, 2),
    ]

    errors = []
    for name, heading, circuit_fn, reference_fn, data, width, digits in checks:
        print(heading)
        circuit_value = circuit_fn(data)
        python_value = reference_fn(data)
        matched = circuit_value == python_value
        print(f" Circuit: {circuit_value:{width}d} (0x{circuit_value:0{digits}x})")
        print(f" Python: {python_value:{width}d} (0x{python_value:0{digits}x})")
        print(f" Match: {'YES' if matched else 'NO'}")
        if not matched:
            errors.append(name)

    print()
    if errors:
        print(f" FAILED: {len(errors)} checksums did not match")
        return False

    print(f" PASSED: All {len(checks)} self-checksums match Python reference")
    return True
409
-
410
def test_tamper_detection():
    """Check that changing one weight byte changes the additive checksum.

    The byte at index 100 is incremented (modulo 256). The Python checksum
    is compared over the full data, and the circuit checksum over the first
    128 bytes, which contains the modified byte.

    Returns:
        True when both the Python and the circuit checksums change,
        else False.
    """
    print("\n[TEST 5] Tamper Detection")
    print("-" * 60)

    pristine = serialize_weights()
    pristine_sum = python_checksum_simple(pristine)

    print(f" Original checksum: {pristine_sum} (0x{pristine_sum:02x})")

    # Flip a single byte in a private copy so the original stays intact.
    corrupted = list(pristine)
    corrupted[100] = (corrupted[100] + 1) % 256
    corrupted_sum = python_checksum_simple(corrupted)

    python_differs = pristine_sum != corrupted_sum
    print(f" Tampered checksum: {corrupted_sum} (0x{corrupted_sum:02x})")
    print(f" Checksums differ: {'YES' if python_differs else 'NO'}")

    # Repeat on the circuit, restricted to a prefix that covers index 100.
    circuit_before = circuit_checksum_simple(pristine[:128])
    circuit_after = circuit_checksum_simple(corrupted[:128])
    circuit_differs = circuit_before != circuit_after

    print(f"\n Circuit verification (first 128 bytes):")
    print(f" Original: {circuit_before}")
    print(f" Tampered: {circuit_after}")
    print(f" Detects tampering: {'YES' if circuit_differs else 'NO'}")

    if not python_differs or not circuit_differs:
        print("\n FAILED: Tampering not properly detected")
        return False

    print("\n PASSED: Tampering detected by both Python and circuit")
    return True
445
-
446
def test_weight_statistics():
    """Print summary statistics and reference checksums of the weight bytes.

    This is informational only; nothing is compared.

    Returns:
        Always True.
    """
    from collections import Counter

    print("\n[TEST 6] Weight Statistics")
    print("-" * 60)

    weight_bytes = serialize_weights()
    total = len(weight_bytes)

    print(f" Total weight bytes: {total}")
    print(f" Unique values: {len(set(weight_bytes))}")
    print(f" Min value: {min(weight_bytes)}")
    print(f" Max value: {max(weight_bytes)}")

    # Five most frequent byte values with their share of the total.
    print(f" Most common values:")
    for val, count in Counter(weight_bytes).most_common(5):
        pct = 100 * count / total
        print(f" {val:3d} (0x{val:02x}): {count:4d} occurrences ({pct:.1f}%)")

    # Python-side checksums, printed for reference.
    simple_sum = python_checksum_simple(weight_bytes)
    xor_sum = python_checksum_xor(weight_bytes)
    fletcher = python_fletcher8(weight_bytes)
    crc = python_crc8(weight_bytes[:256])

    print(f"\n Reference checksums:")
    print(f" Simple sum: {simple_sum}")
    print(f" XOR: {xor_sum}")
    print(f" Fletcher-8: 0x{fletcher:04x}")
    print(f" CRC-8: 0x{crc:02x} (first 256 bytes)")

    return True
477
-
478
- # =============================================================================
479
- # MAIN
480
- # =============================================================================
481
-
482
if __name__ == "__main__":
    # Script entry point: run every self-test in order, then print a summary.
    banner = "=" * 70

    print(banner)
    print(" TEST #6: CRYPTOGRAPHIC SELF-TEST")
    print(" Computing checksums of weights using the weights themselves")
    print(banner)

    suite = (
        ("Checksum primitives", test_checksum_primitives),
        ("Fletcher-8", test_fletcher8),
        ("CRC-8", test_crc8),
        ("Self-checksum", test_self_checksum),
        ("Tamper detection", test_tamper_detection),
        ("Weight statistics", test_weight_statistics),
    )
    # Each entry becomes (label, outcome); tests run in the listed order.
    results = [(label, run()) for label, run in suite]

    print("\n" + banner)
    print(" SUMMARY")
    print(banner)

    total = len(results)
    passed = len([outcome for _, outcome in results if outcome])

    for name, outcome in results:
        status = "PASS" if outcome else "FAIL"
        print(f" {name:25s} [{status}]")

    print(f"\n Total: {passed}/{total} tests passed")

    if passed != total:
        print("\n STATUS: SOME SELF-TESTS FAILED")
    else:
        print("\n STATUS: CRYPTOGRAPHIC SELF-TEST COMPLETE")
        print(" The computer verified its own integrity.")

    print(banner)