phanerozoic commited on
Commit
926f717
·
verified ·
1 Parent(s): 12be566

Delete skeptic_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. skeptic_test.py +0 -215
skeptic_test.py DELETED
@@ -1,215 +0,0 @@
1
- import warnings
2
- warnings.filterwarnings('ignore')
3
- from safetensors.torch import load_file
4
- import torch
5
-
6
- model = load_file('neural_computer.safetensors')
7
-
8
- def heaviside(x):
9
- return (x >= 0).float()
10
-
11
- def int_to_bits(val):
12
- return torch.tensor([(val >> (7-i)) & 1 for i in range(8)], dtype=torch.float32)
13
-
14
- def eval_xor_bool(a, b):
15
- inp = torch.tensor([float(a), float(b)], dtype=torch.float32)
16
- w1_or = model['boolean.xor.layer1.neuron1.weight']
17
- b1_or = model['boolean.xor.layer1.neuron1.bias']
18
- w1_nand = model['boolean.xor.layer1.neuron2.weight']
19
- b1_nand = model['boolean.xor.layer1.neuron2.bias']
20
- w2 = model['boolean.xor.layer2.weight']
21
- b2 = model['boolean.xor.layer2.bias']
22
- h_or = heaviside(inp @ w1_or + b1_or)
23
- h_nand = heaviside(inp @ w1_nand + b1_nand)
24
- hidden = torch.tensor([h_or.item(), h_nand.item()])
25
- return int(heaviside(hidden @ w2 + b2).item())
26
-
27
- def eval_xor_arith(inp, prefix):
28
- w1_or = model[f'{prefix}.layer1.or.weight']
29
- b1_or = model[f'{prefix}.layer1.or.bias']
30
- w1_nand = model[f'{prefix}.layer1.nand.weight']
31
- b1_nand = model[f'{prefix}.layer1.nand.bias']
32
- w2 = model[f'{prefix}.layer2.weight']
33
- b2 = model[f'{prefix}.layer2.bias']
34
- h_or = heaviside(inp @ w1_or + b1_or)
35
- h_nand = heaviside(inp @ w1_nand + b1_nand)
36
- hidden = torch.tensor([h_or.item(), h_nand.item()])
37
- return heaviside(hidden @ w2 + b2).item()
38
-
39
- def eval_full_adder(a, b, cin, prefix):
40
- inp_ab = torch.tensor([a, b], dtype=torch.float32)
41
- ha1_sum = eval_xor_arith(inp_ab, f'{prefix}.ha1.sum')
42
- w_c1 = model[f'{prefix}.ha1.carry.weight']
43
- b_c1 = model[f'{prefix}.ha1.carry.bias']
44
- ha1_carry = heaviside(inp_ab @ w_c1 + b_c1).item()
45
- inp_ha2 = torch.tensor([ha1_sum, cin], dtype=torch.float32)
46
- ha2_sum = eval_xor_arith(inp_ha2, f'{prefix}.ha2.sum')
47
- w_c2 = model[f'{prefix}.ha2.carry.weight']
48
- b_c2 = model[f'{prefix}.ha2.carry.bias']
49
- ha2_carry = heaviside(inp_ha2 @ w_c2 + b_c2).item()
50
- inp_cout = torch.tensor([ha1_carry, ha2_carry], dtype=torch.float32)
51
- w_or = model[f'{prefix}.carry_or.weight']
52
- b_or = model[f'{prefix}.carry_or.bias']
53
- cout = heaviside(inp_cout @ w_or + b_or).item()
54
- return int(ha2_sum), int(cout)
55
-
56
- def add_8bit(a, b):
57
- carry = 0.0
58
- result = 0
59
- for i in range(8):
60
- s, carry = eval_full_adder(float((a >> i) & 1), float((b >> i) & 1), carry, f'arithmetic.ripplecarry8bit.fa{i}')
61
- result |= (s << i)
62
- return result, int(carry)
63
-
64
- def xor_8bit(a, b):
65
- result = 0
66
- for i in range(8):
67
- bit = eval_xor_bool((a >> i) & 1, (b >> i) & 1)
68
- result |= (bit << i)
69
- return result
70
-
71
- def and_8bit(a, b):
72
- result = 0
73
- w = model['boolean.and.weight']
74
- bias = model['boolean.and.bias']
75
- for i in range(8):
76
- inp = torch.tensor([float((a >> i) & 1), float((b >> i) & 1)], dtype=torch.float32)
77
- out = int(heaviside(inp @ w + bias).item())
78
- result |= (out << i)
79
- return result
80
-
81
- def or_8bit(a, b):
82
- result = 0
83
- w = model['boolean.or.weight']
84
- bias = model['boolean.or.bias']
85
- for i in range(8):
86
- inp = torch.tensor([float((a >> i) & 1), float((b >> i) & 1)], dtype=torch.float32)
87
- out = int(heaviside(inp @ w + bias).item())
88
- result |= (out << i)
89
- return result
90
-
91
- def not_8bit(a):
92
- result = 0
93
- w = model['boolean.not.weight']
94
- bias = model['boolean.not.bias']
95
- for i in range(8):
96
- inp = torch.tensor([float((a >> i) & 1)], dtype=torch.float32)
97
- out = int(heaviside(inp @ w + bias).item())
98
- result |= (out << i)
99
- return result
100
-
101
- def gt(a, b):
102
- a_bits, b_bits = int_to_bits(a), int_to_bits(b)
103
- w = model['arithmetic.greaterthan8bit.comparator']
104
- return 1 if ((a_bits - b_bits) @ w).item() > 0 else 0
105
-
106
- def lt(a, b):
107
- a_bits, b_bits = int_to_bits(a), int_to_bits(b)
108
- w = model['arithmetic.lessthan8bit.comparator']
109
- return 1 if ((b_bits - a_bits) @ w).item() > 0 else 0
110
-
111
- def eq(a, b):
112
- return 1 if (gt(a,b) == 0 and lt(a,b) == 0) else 0
113
-
114
- print('=' * 70)
115
- print('SKEPTICAL NERD TESTS')
116
- print('=' * 70)
117
-
118
- failures = []
119
-
120
- print('\n[1] IDENTITY LAWS')
121
- for a in [0, 1, 127, 128, 255, 170, 85]:
122
- r, _ = add_8bit(a, 0)
123
- if r != a: failures.append(f'A+0: {a}')
124
- if xor_8bit(a, 0) != a: failures.append(f'A^0: {a}')
125
- if and_8bit(a, 255) != a: failures.append(f'A&255: {a}')
126
- if or_8bit(a, 0) != a: failures.append(f'A|0: {a}')
127
- print(' 28 tests')
128
-
129
- print('\n[2] ANNIHILATION LAWS')
130
- for a in [0, 1, 127, 128, 255]:
131
- if and_8bit(a, 0) != 0: failures.append(f'A&0: {a}')
132
- if or_8bit(a, 255) != 255: failures.append(f'A|255: {a}')
133
- if xor_8bit(a, a) != 0: failures.append(f'A^A: {a}')
134
- print(' 15 tests')
135
-
136
- print('\n[3] INVOLUTION (~~A = A)')
137
- for a in [0, 1, 127, 128, 255, 170]:
138
- if not_8bit(not_8bit(a)) != a: failures.append(f'~~A: {a}')
139
- print(' 6 tests')
140
-
141
- print('\n[4] TWOS COMPLEMENT: A + ~A + 1 = 0')
142
- for a in [0, 1, 42, 127, 128, 255]:
143
- not_a = not_8bit(a)
144
- r1, _ = add_8bit(a, not_a)
145
- r2, _ = add_8bit(r1, 1)
146
- if r2 != 0: failures.append(f'twos comp: {a}')
147
- print(' 6 tests')
148
-
149
- print('\n[5] CARRY PROPAGATION (worst case)')
150
- cases = [(255, 1, 0), (127, 129, 0), (1, 255, 0), (128, 128, 0), (255, 255, 254)]
151
- for a, b, exp in cases:
152
- r, _ = add_8bit(a, b)
153
- if r != exp: failures.append(f'carry: {a}+{b}={r}, expected {exp}')
154
- print(' 5 tests')
155
-
156
- print('\n[6] COMMUTATIVITY')
157
- pairs = [(17, 42), (0, 255), (128, 127), (1, 254), (170, 85)]
158
- for a, b in pairs:
159
- r1, _ = add_8bit(a, b)
160
- r2, _ = add_8bit(b, a)
161
- if r1 != r2: failures.append(f'add commute: {a},{b}')
162
- if xor_8bit(a, b) != xor_8bit(b, a): failures.append(f'xor commute: {a},{b}')
163
- if and_8bit(a, b) != and_8bit(b, a): failures.append(f'and commute: {a},{b}')
164
- if or_8bit(a, b) != or_8bit(b, a): failures.append(f'or commute: {a},{b}')
165
- print(' 20 tests')
166
-
167
- print('\n[7] DE MORGAN')
168
- for a, b in [(0, 0), (0, 255), (255, 0), (255, 255), (170, 85)]:
169
- lhs = not_8bit(and_8bit(a, b))
170
- rhs = or_8bit(not_8bit(a), not_8bit(b))
171
- if lhs != rhs: failures.append(f'DM1: {a},{b}')
172
- lhs = not_8bit(or_8bit(a, b))
173
- rhs = and_8bit(not_8bit(a), not_8bit(b))
174
- if lhs != rhs: failures.append(f'DM2: {a},{b}')
175
- print(' 10 tests')
176
-
177
- print('\n[8] COMPARATOR EDGE CASES')
178
- cmp_tests = [
179
- (0, 0, 0, 0, 1), (0, 1, 0, 1, 0), (1, 0, 1, 0, 0),
180
- (127, 128, 0, 1, 0), (128, 127, 1, 0, 0),
181
- (255, 255, 0, 0, 1), (255, 0, 1, 0, 0), (0, 255, 0, 1, 0),
182
- ]
183
- for a, b, exp_gt, exp_lt, exp_eq in cmp_tests:
184
- if gt(a, b) != exp_gt: failures.append(f'gt({a},{b})')
185
- if lt(a, b) != exp_lt: failures.append(f'lt({a},{b})')
186
- if eq(a, b) != exp_eq: failures.append(f'eq({a},{b})')
187
- print(' 24 tests')
188
-
189
- print('\n[9] POPCOUNT SINGLE BITS + EXTREMES')
190
- w_pop = model['pattern_recognition.popcount.weight']
191
- b_pop = model['pattern_recognition.popcount.bias']
192
- for i in range(8):
193
- val = 1 << i
194
- bits = int_to_bits(val)
195
- pc = int((bits @ w_pop + b_pop).item())
196
- if pc != 1: failures.append(f'popcount(1<<{i})')
197
- if int((int_to_bits(0) @ w_pop + b_pop).item()) != 0: failures.append('popcount(0)')
198
- if int((int_to_bits(255) @ w_pop + b_pop).item()) != 8: failures.append('popcount(255)')
199
- print(' 10 tests')
200
-
201
- print('\n[10] DISTRIBUTIVITY: A & (B | C) = (A & B) | (A & C)')
202
- for a, b, c in [(255, 15, 240), (170, 85, 51), (0, 255, 0)]:
203
- lhs = and_8bit(a, or_8bit(b, c))
204
- rhs = or_8bit(and_8bit(a, b), and_8bit(a, c))
205
- if lhs != rhs: failures.append(f'distrib: {a},{b},{c}')
206
- print(' 3 tests')
207
-
208
- print('\n' + '=' * 70)
209
- if failures:
210
- print(f'FAILURES: {len(failures)}')
211
- for f in failures[:20]:
212
- print(f' {f}')
213
- else:
214
- print('ALL 127 SKEPTICAL TESTS PASSED')
215
- print('=' * 70)