CharlesCNorton commited on
Commit
5d5afcf
·
0 Parent(s):

Max of two 2-bit unsigned numbers, magnitude 96

Browse files
Files changed (5) hide show
  1. README.md +62 -0
  2. config.json +9 -0
  3. create_safetensors.py +215 -0
  4. model.py +11 -0
  5. model.safetensors +0 -0
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - pytorch
5
+ - safetensors
6
+ - threshold-logic
7
+ - neuromorphic
8
+ ---
9
+
10
+ # threshold-max2
11
+
12
+ Maximum of two 2-bit unsigned integers.
13
+
14
+ ## Function
15
+
16
+ max2(a, b) = max(a, b) where a, b are 2-bit unsigned (0-3)
17
+
18
+ ## Truth Table (selected)
19
+
20
+ | a | b | max(a,b) |
21
+ |---|---|----------|
22
+ | 0 | 0 | 0 |
23
+ | 1 | 2 | 2 |
24
+ | 2 | 1 | 2 |
25
+ | 3 | 3 | 3 |
26
+ | 0 | 3 | 3 |
27
+ | 3 | 0 | 3 |
28
+
29
+ ## Architecture
30
+
31
+ 7-layer circuit:
32
+ 1. Compare high bits, compare low bits
33
+ 2. Compute a1 == b1
34
+ 3. Compute partial comparison results
35
+ 4. Compute a > b, a == b
36
+ 5. Compute a >= b
37
+ 6. MUX components
38
+ 7. Final output selection
39
+
40
+ ## Parameters
41
+
42
+ | | |
43
+ |---|---|
44
+ | Inputs | 4 (a1, a0, b1, b0) |
45
+ | Outputs | 2 (m1, m0) |
46
+ | Neurons | 31 |
47
+ | Layers | 7 |
48
+ | Parameters | 180 |
49
+ | Magnitude | 96 |
50
+
51
+ ## Usage
52
+
53
+ ```python
54
+ from safetensors.torch import load_file
55
+
56
+ w = load_file('model.safetensors')
57
+ # See create_safetensors.py for full implementation
58
+ ```
59
+
60
+ ## License
61
+
62
+ MIT
config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "threshold-max2",
3
+ "description": "Maximum of two 2-bit unsigned numbers",
4
+ "inputs": 4,
5
+ "outputs": 2,
6
+ "neurons": 31,
7
+ "layers": 7,
8
+ "parameters": 180
9
+ }
create_safetensors.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Construct, save, and verify a threshold-logic circuit computing
max(a, b) for two 2-bit unsigned integers.

Inputs:  a1, a0, b1, b0   (MSB-first bits of a and b)
Outputs: m1, m0 = bits of max(a, b)

Each neuron is a linear threshold gate over 0/1 inputs: it fires
(outputs 1) when w @ x + bias >= 0.  The circuit evaluates

    a >= b  iff  (a1 > b1) OR (a1 == b1 AND a0 >= b0)

and uses that signal as the select line of a 2:1 MUX over the bits of
a and b.  Tensor names are '<layer>.<signal>.weight' / '.bias'.
"""
import torch
from safetensors.torch import save_file

weights = {}


def _neuron(name, w, bias):
    """Register one threshold gate with weight row *w* and bias *bias*."""
    weights[f'{name}.weight'] = torch.tensor([w], dtype=torch.float32)
    weights[f'{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


def _gate(layer, name, order, coeffs, bias):
    """Register a gate reading named signals from the previous layer.

    *order* is the previous layer's output signal order; *coeffs* maps
    signal name -> weight, every unnamed signal getting weight 0.
    """
    w = [0.0] * len(order)
    for src, c in coeffs.items():
        w[order.index(src)] = c
    _neuron(f'{layer}.{name}', w, bias)


def _passthrough(layer, names, order):
    """Register identity gates forwarding *names* to the next layer.

    Bias -0.5 makes each gate fire exactly when its 0/1 input is 1.
    """
    for v in names:
        _gate(layer, v, order, {v: 1.0}, -0.5)


# Output signal order of each layer (= input order of the next layer).
IN = ['a1', 'a0', 'b1', 'b0']
L1 = ['a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
      'both1_high', 'both1_low', 'a1', 'a0', 'b1', 'b0']
L2 = ['a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
      'a1', 'a0', 'b1', 'b0']
L3 = ['a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1', 'a1', 'a0', 'b1', 'b0',
      'a1_eq_b1']
L4 = ['a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0']
L5 = ['a_ge_b', 'a1', 'a0', 'b1', 'b0']
L6 = ['m1_a', 'm1_b', 'm0_a', 'm0_b']
L7 = ['m1', 'm0']

# Layer 1: single-bit strict comparisons (x AND NOT y fires only at x=1, y=0).
_gate('l1', 'a1_gt_b1', IN, {'a1': 1.0, 'b1': -1.0}, -1.0)
_gate('l1', 'b1_gt_a1', IN, {'b1': 1.0, 'a1': -1.0}, -1.0)
_gate('l1', 'a0_gt_b0', IN, {'a0': 1.0, 'b0': -1.0}, -1.0)
_gate('l1', 'b0_gt_a0', IN, {'b0': 1.0, 'a0': -1.0}, -1.0)
# XNOR(a1, b1) is split into "both high" (AND) and "both low" (NOR);
# layer 2 ORs the halves together.
_gate('l1', 'both1_high', IN, {'a1': 1.0, 'b1': 1.0}, -2.0)
_gate('l1', 'both1_low', IN, {'a1': -1.0, 'b1': -1.0}, 0.0)
_passthrough('l1', IN, IN)

# Layer 2: a1_eq_b1 = both1_high OR both1_low.
_gate('l2', 'a1_eq_b1', L1, {'both1_high': 1.0, 'both1_low': 1.0}, -1.0)
_passthrough('l2', [v for v in L2 if v != 'a1_eq_b1'], L1)

# Layer 3: partial comparison terms.
_gate('l3', 'a_gt_b_part2', L2,
      {'a1_eq_b1': 1.0, 'a0_gt_b0': 1.0}, -2.0)   # a1==b1 AND a0>b0
_gate('l3', 'a0_neq_b0', L2,
      {'a0_gt_b0': 1.0, 'b0_gt_a0': 1.0}, -1.0)   # XOR as OR of strict compares
_passthrough('l3', ['a1_gt_b1', 'a1', 'a0', 'b1', 'b0', 'a1_eq_b1'], L2)

# Layer 4: full comparison results.
_gate('l4', 'a_gt_b', L3,
      {'a_gt_b_part2': 1.0, 'a1_gt_b1': 1.0}, -1.0)  # a1>b1 OR (a1==b1 AND a0>b0)
_gate('l4', 'a_eq_b', L3,
      {'a1_eq_b1': 1.0, 'a0_neq_b0': -1.0}, -1.0)    # a1==b1 AND a0==b0
_passthrough('l4', IN, L3)

# Layer 5: MUX select line (select a when a >= b).
_gate('l5', 'a_ge_b', L4, {'a_gt_b': 1.0, 'a_eq_b': 1.0}, -1.0)
_passthrough('l5', IN, L4)

# Layer 6: MUX halves, mX_a = aX AND a_ge_b, mX_b = bX AND NOT a_ge_b.
_gate('l6', 'm1_a', L5, {'a_ge_b': 1.0, 'a1': 1.0}, -2.0)
_gate('l6', 'm1_b', L5, {'a_ge_b': -1.0, 'b1': 1.0}, -1.0)
_gate('l6', 'm0_a', L5, {'a_ge_b': 1.0, 'a0': 1.0}, -2.0)
_gate('l6', 'm0_b', L5, {'a_ge_b': -1.0, 'b0': 1.0}, -1.0)

# Layer 7: final OR of the MUX halves.
_gate('l7', 'm1', L6, {'m1_a': 1.0, 'm1_b': 1.0}, -1.0)
_gate('l7', 'm0', L6, {'m0_a': 1.0, 'm0_b': 1.0}, -1.0)

save_file(weights, 'model.safetensors')

# ---------------------------------------------------------------------------
# Verification: run the stored gates over all 16 input pairs.
# ---------------------------------------------------------------------------
LAYERS = [('l1', L1), ('l2', L2), ('l3', L3), ('l4', L4),
          ('l5', L5), ('l6', L6), ('l7', L7)]


def _step(x, layer, keys):
    """Apply one layer of threshold gates to activation vector *x*."""
    fired = [float((x @ weights[f'{layer}.{k}.weight'].T
                    + weights[f'{layer}.{k}.bias'] >= 0).item())
             for k in keys]
    return torch.tensor(fired)


def max2(a1, a0, b1, b0):
    """Run the circuit on single-bit inputs; return the output bits (m1, m0)."""
    x = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    for layer, keys in LAYERS:
        x = _step(x, layer, keys)
    return int(x[0].item()), int(x[1].item())


print("Verifying max2...")
errors = 0
for a in range(4):
    for b in range(4):
        m1, m0 = max2((a >> 1) & 1, a & 1, (b >> 1) & 1, b & 1)
        result = 2 * m1 + m0
        expected = max(a, b)
        if result != expected:
            errors += 1
            print(f"ERROR: max({a}, {b}) = {result}, expected {expected}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    print(f"FAILED: {errors} errors")

# Total L1 magnitude (sum of |weight| + |bias| over all tensors); this is
# the "Magnitude: 96" figure quoted in the README.
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")
model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
"""Loader utility for the threshold-max2 safetensors checkpoint."""
import torch
from safetensors.torch import load_file


def load_model(path='model.safetensors'):
    """Return the checkpoint at *path* as a dict of tensor name -> tensor."""
    return load_file(path)


if __name__ == '__main__':
    # Reference table of the function the stored circuit implements.
    print('Max of two 2-bit numbers:')
    for left in range(4):
        for right in range(4):
            print(f' max({left}, {right}) = {max(left, right)}')
model.safetensors ADDED
Binary file (7.69 kB). View file