phanerozoic committed on
Commit
b3392d0
·
verified ·
1 Parent(s): 1b1d927

Upload skeptic_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. skeptic_test.py +215 -0
skeptic_test.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warnings
# Suppress load-time warnings from torch/safetensors so test output stays clean.
warnings.filterwarnings('ignore')
from safetensors.torch import load_file
import torch

# Flat dict mapping dotted tensor names (e.g. 'boolean.xor.layer1.neuron1.weight')
# to the weight/bias tensors of the hand-built "neural computer".
model = load_file('neural_computer.safetensors')
7
+
8
def heaviside(x):
    """Hard step activation: 1.0 where x >= 0, else 0.0 (float32 tensor)."""
    return x.ge(0.0).to(torch.float32)
10
+
11
def int_to_bits(val, width=8):
    """Encode *val* as a float32 tensor of bits, most-significant bit first.

    Generalized from a fixed 8-bit encoding: *width* defaults to 8 so all
    existing callers are unchanged, but other bit widths are now supported.

    Args:
        val: non-negative integer to encode; bits above *width* are dropped.
        width: number of bits to emit (default 8, matching the 8-bit ALU).

    Returns:
        1-D float32 tensor of length *width* with entries in {0.0, 1.0}.
    """
    bits = [(val >> (width - 1 - i)) & 1 for i in range(width)]
    return torch.tensor(bits, dtype=torch.float32)
13
+
14
def eval_xor_bool(a, b):
    """Evaluate the stored two-layer XOR network on booleans a, b; return 0 or 1."""
    x = torch.tensor([float(a), float(b)], dtype=torch.float32)

    def unit(prefix):
        # One step-activated neuron: heaviside(x . w + b) for the named tensors.
        return heaviside(x @ model[prefix + '.weight'] + model[prefix + '.bias']).item()

    # Layer 1: an OR-like neuron and a NAND-like neuron feed the output layer.
    hidden = torch.tensor([unit('boolean.xor.layer1.neuron1'),
                           unit('boolean.xor.layer1.neuron2')])
    out = hidden @ model['boolean.xor.layer2.weight'] + model['boolean.xor.layer2.bias']
    return int(heaviside(out).item())
26
+
27
def eval_xor_arith(inp, prefix):
    """Run the XOR sub-network stored under *prefix* on tensor *inp*; return 0.0/1.0."""
    def layer1(tag):
        # Step-activated layer-1 gate ('or' or 'nand') as a Python float.
        w = model[f'{prefix}.layer1.{tag}.weight']
        b = model[f'{prefix}.layer1.{tag}.bias']
        return heaviside(inp @ w + b).item()

    hidden = torch.tensor([layer1('or'), layer1('nand')])
    logits = hidden @ model[f'{prefix}.layer2.weight'] + model[f'{prefix}.layer2.bias']
    return heaviside(logits).item()
38
+
39
def eval_full_adder(a, b, cin, prefix):
    """One neural full adder: two half-adders plus an OR over their carries.

    Returns (sum_bit, carry_out) as ints.
    """
    def gate(vec, key):
        # Apply the stored single-neuron gate *key* to input vector *vec*.
        return heaviside(vec @ model[key + '.weight'] + model[key + '.bias']).item()

    # Half-adder 1 combines a and b.
    ab = torch.tensor([a, b], dtype=torch.float32)
    s1 = eval_xor_arith(ab, f'{prefix}.ha1.sum')
    c1 = gate(ab, f'{prefix}.ha1.carry')
    # Half-adder 2 combines the intermediate sum with the carry-in.
    sc = torch.tensor([s1, cin], dtype=torch.float32)
    s2 = eval_xor_arith(sc, f'{prefix}.ha2.sum')
    c2 = gate(sc, f'{prefix}.ha2.carry')
    # Carry-out is the OR of the two half-adder carries.
    cout = gate(torch.tensor([c1, c2], dtype=torch.float32), f'{prefix}.carry_or')
    return int(s2), int(cout)
55
+
56
def add_8bit(a, b):
    """Ripple-carry addition of two 8-bit ints through the stored full adders.

    Returns (sum mod 256, final carry bit).
    """
    total, carry = 0, 0.0
    for bit in range(8):
        s, carry = eval_full_adder(
            float((a >> bit) & 1),
            float((b >> bit) & 1),
            carry,
            f'arithmetic.ripplecarry8bit.fa{bit}',
        )
        total |= s << bit
    return total, int(carry)
63
+
64
def xor_8bit(a, b):
    """Bitwise XOR of two 8-bit ints via the boolean XOR network, bit by bit."""
    # Bits land in disjoint positions, so summing the shifted bits equals OR-ing.
    return sum(eval_xor_bool((a >> i) & 1, (b >> i) & 1) << i for i in range(8))
70
+
71
def and_8bit(a, b):
    """Bitwise AND of two 8-bit ints via the stored single-neuron AND gate."""
    w = model['boolean.and.weight']
    bias = model['boolean.and.bias']
    out = 0
    for i in range(8):
        pair = torch.tensor([float((a >> i) & 1), float((b >> i) & 1)],
                            dtype=torch.float32)
        out |= int(heaviside(pair @ w + bias).item()) << i
    return out
80
+
81
def or_8bit(a, b):
    """Bitwise OR of two 8-bit ints via the stored single-neuron OR gate."""
    w = model['boolean.or.weight']
    bias = model['boolean.or.bias']
    out = 0
    for i in range(8):
        pair = torch.tensor([float((a >> i) & 1), float((b >> i) & 1)],
                            dtype=torch.float32)
        out |= int(heaviside(pair @ w + bias).item()) << i
    return out
90
+
91
def not_8bit(a):
    """Bitwise NOT of an 8-bit int via the stored single-input NOT neuron."""
    w = model['boolean.not.weight']
    bias = model['boolean.not.bias']
    out = 0
    for i in range(8):
        bit = torch.tensor([float((a >> i) & 1)], dtype=torch.float32)
        out |= int(heaviside(bit @ w + bias).item()) << i
    return out
100
+
101
def gt(a, b):
    """Return 1 if a > b according to the stored comparator weights, else 0."""
    w = model['arithmetic.greaterthan8bit.comparator']
    diff = int_to_bits(a) - int_to_bits(b)
    # int(bool) yields exactly 0 or 1, matching the original ternary.
    return int((diff @ w).item() > 0)
105
+
106
def lt(a, b):
    """Return 1 if a < b according to the stored comparator weights, else 0."""
    w = model['arithmetic.lessthan8bit.comparator']
    # Same comparator trick as gt, with the operands swapped in the difference.
    diff = int_to_bits(b) - int_to_bits(a)
    return int((diff @ w).item() > 0)
110
+
111
def eq(a, b):
    """Return 1 when neither comparator fires (a == b), else 0."""
    # gt/lt each return 0 or 1, so truthiness of either means "not equal".
    return 0 if gt(a, b) or lt(a, b) else 1
113
+
114
# ---- Property-based sanity suite over the neural ALU ----
# Each section checks an algebraic law the hardware-style circuits must obey;
# every violation is recorded as a human-readable label in `failures`.
print('=' * 70)
print('SKEPTICAL NERD TESTS')
print('=' * 70)

failures = []  # labels of every failed check, reported at the end

# [1] A+0 = A, A^0 = A, A&255 = A, A|0 = A over edge and pattern values.
print('\n[1] IDENTITY LAWS')
for a in [0, 1, 127, 128, 255, 170, 85]:
    r, _ = add_8bit(a, 0)
    if r != a: failures.append(f'A+0: {a}')
    if xor_8bit(a, 0) != a: failures.append(f'A^0: {a}')
    if and_8bit(a, 255) != a: failures.append(f'A&255: {a}')
    if or_8bit(a, 0) != a: failures.append(f'A|0: {a}')
print(' 28 tests')

# [2] A&0 = 0, A|255 = 255, A^A = 0.
print('\n[2] ANNIHILATION LAWS')
for a in [0, 1, 127, 128, 255]:
    if and_8bit(a, 0) != 0: failures.append(f'A&0: {a}')
    if or_8bit(a, 255) != 255: failures.append(f'A|255: {a}')
    if xor_8bit(a, a) != 0: failures.append(f'A^A: {a}')
print(' 15 tests')

# [3] Double negation restores the original value.
print('\n[3] INVOLUTION (~~A = A)')
for a in [0, 1, 127, 128, 255, 170]:
    if not_8bit(not_8bit(a)) != a: failures.append(f'~~A: {a}')
print(' 6 tests')

# [4] Two's complement identity: A + ~A + 1 wraps to 0 in 8 bits.
print('\n[4] TWOS COMPLEMENT: A + ~A + 1 = 0')
for a in [0, 1, 42, 127, 128, 255]:
    not_a = not_8bit(a)
    r1, _ = add_8bit(a, not_a)
    r2, _ = add_8bit(r1, 1)
    if r2 != 0: failures.append(f'twos comp: {a}')
print(' 6 tests')

# [5] Additions that force a carry to ripple through every adder stage.
print('\n[5] CARRY PROPAGATION (worst case)')
cases = [(255, 1, 0), (127, 129, 0), (1, 255, 0), (128, 128, 0), (255, 255, 254)]
for a, b, exp in cases:
    r, _ = add_8bit(a, b)
    if r != exp: failures.append(f'carry: {a}+{b}={r}, expected {exp}')
print(' 5 tests')

# [6] Operand order must not matter for +, ^, &, |.
print('\n[6] COMMUTATIVITY')
pairs = [(17, 42), (0, 255), (128, 127), (1, 254), (170, 85)]
for a, b in pairs:
    r1, _ = add_8bit(a, b)
    r2, _ = add_8bit(b, a)
    if r1 != r2: failures.append(f'add commute: {a},{b}')
    if xor_8bit(a, b) != xor_8bit(b, a): failures.append(f'xor commute: {a},{b}')
    if and_8bit(a, b) != and_8bit(b, a): failures.append(f'and commute: {a},{b}')
    if or_8bit(a, b) != or_8bit(b, a): failures.append(f'or commute: {a},{b}')
print(' 20 tests')

# [7] De Morgan's laws: ~(A&B) = ~A|~B and ~(A|B) = ~A&~B.
print('\n[7] DE MORGAN')
for a, b in [(0, 0), (0, 255), (255, 0), (255, 255), (170, 85)]:
    lhs = not_8bit(and_8bit(a, b))
    rhs = or_8bit(not_8bit(a), not_8bit(b))
    if lhs != rhs: failures.append(f'DM1: {a},{b}')
    lhs = not_8bit(or_8bit(a, b))
    rhs = and_8bit(not_8bit(a), not_8bit(b))
    if lhs != rhs: failures.append(f'DM2: {a},{b}')
print(' 10 tests')

# [8] Comparator triples: (a, b, expected gt, expected lt, expected eq).
print('\n[8] COMPARATOR EDGE CASES')
cmp_tests = [
    (0, 0, 0, 0, 1), (0, 1, 0, 1, 0), (1, 0, 1, 0, 0),
    (127, 128, 0, 1, 0), (128, 127, 1, 0, 0),
    (255, 255, 0, 0, 1), (255, 0, 1, 0, 0), (0, 255, 0, 1, 0),
]
for a, b, exp_gt, exp_lt, exp_eq in cmp_tests:
    if gt(a, b) != exp_gt: failures.append(f'gt({a},{b})')
    if lt(a, b) != exp_lt: failures.append(f'lt({a},{b})')
    if eq(a, b) != exp_eq: failures.append(f'eq({a},{b})')
print(' 24 tests')

# [9] Popcount as a single linear layer: each 1<<i must count to 1,
# all-zeros to 0, all-ones to 8.
print('\n[9] POPCOUNT SINGLE BITS + EXTREMES')
w_pop = model['pattern_recognition.popcount.weight']
b_pop = model['pattern_recognition.popcount.bias']
for i in range(8):
    val = 1 << i
    bits = int_to_bits(val)
    pc = int((bits @ w_pop + b_pop).item())
    if pc != 1: failures.append(f'popcount(1<<{i})')
if int((int_to_bits(0) @ w_pop + b_pop).item()) != 0: failures.append('popcount(0)')
if int((int_to_bits(255) @ w_pop + b_pop).item()) != 8: failures.append('popcount(255)')
print(' 10 tests')

# [10] AND distributes over OR.
print('\n[10] DISTRIBUTIVITY: A & (B | C) = (A & B) | (A & C)')
for a, b, c in [(255, 15, 240), (170, 85, 51), (0, 255, 0)]:
    lhs = and_8bit(a, or_8bit(b, c))
    rhs = or_8bit(and_8bit(a, b), and_8bit(a, c))
    if lhs != rhs: failures.append(f'distrib: {a},{b},{c}')
print(' 3 tests')

# Summary: list at most the first 20 failures, or declare a clean run.
print('\n' + '=' * 70)
if failures:
    print(f'FAILURES: {len(failures)}')
    for f in failures[:20]:
        print(f' {f}')
else:
    print('ALL 127 SKEPTICAL TESTS PASSED')
print('=' * 70)