phanerozoic commited on
Commit
f5ee650
·
verified ·
1 Parent(s): 4d5bca2

Upload arithmetic_eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. arithmetic_eval.py +1664 -0
arithmetic_eval.py ADDED
@@ -0,0 +1,1664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ARITHMETIC EVALUATOR
3
+ =====================
4
+ Introspection-based exhaustive testing for arithmetic circuits in threshold-calculus.
5
+
6
+ Stripped-down version of comprehensive_eval.py for arithmetic-only weights.
7
+ Removes: ALU, Control, Manifest, Error Detection
8
+ Keeps: Boolean, Arithmetic, Threshold, Modular, Combinational, Pattern Recognition
9
+ """
10
+
11
+ import torch
12
+ from safetensors import safe_open
13
+ from typing import Dict, List, Tuple, Optional, Callable
14
+ from dataclasses import dataclass
15
+ from collections import defaultdict
16
+ import json
17
+ import os
18
+ import re
19
+ import time
20
+
21
+
22
+ @dataclass
23
+ class TestResult:
24
+ """Result of testing a single circuit."""
25
+ circuit_name: str
26
+ passed: int
27
+ total: int
28
+ failures: List[Tuple]
29
+
30
+ @property
31
+ def success(self) -> bool:
32
+ return self.passed == self.total
33
+
34
+ @property
35
+ def rate(self) -> float:
36
+ return self.passed / self.total if self.total > 0 else 0.0
37
+
38
+
39
+ def heaviside(x: torch.Tensor) -> torch.Tensor:
40
+ """Threshold activation: 1 if x >= 0, else 0."""
41
+ return (x >= 0).float()
42
+
43
+
44
+ class TensorRegistry:
45
+ """Discovers and organizes tensors from a safetensors file."""
46
+
47
+ def __init__(self, path: str):
48
+ self.path = path
49
+ self.tensors: Dict[str, torch.Tensor] = {}
50
+ self.circuits: Dict[str, List[str]] = defaultdict(list)
51
+ self.accessed: set = set()
52
+ self._load()
53
+ self._organize()
54
+
55
+ def _load(self):
56
+ with safe_open(self.path, framework='pt') as f:
57
+ for name in f.keys():
58
+ self.tensors[name] = f.get_tensor(name).float()
59
+
60
+ def _organize(self):
61
+ for name in self.tensors:
62
+ circuit = self._extract_circuit(name)
63
+ self.circuits[circuit].append(name)
64
+
65
+ def _extract_circuit(self, tensor_name: str) -> str:
66
+ if tensor_name.endswith('.weight'):
67
+ return tensor_name[:-7]
68
+ elif tensor_name.endswith('.bias'):
69
+ return tensor_name[:-5]
70
+ return tensor_name
71
+
72
+ def get(self, name: str) -> torch.Tensor:
73
+ self.accessed.add(name)
74
+ return self.tensors[name]
75
+
76
+ def has(self, name: str) -> bool:
77
+ return name in self.tensors
78
+
79
+ def get_category(self, prefix: str) -> Dict[str, List[str]]:
80
+ return {k: v for k, v in self.circuits.items() if k.startswith(prefix)}
81
+
82
+ @property
83
+ def categories(self) -> List[str]:
84
+ cats = set()
85
+ for name in self.tensors:
86
+ cats.add(name.split('.')[0])
87
+ return sorted(cats)
88
+
89
+ @property
90
+ def untested(self) -> List[str]:
91
+ return sorted(set(self.tensors.keys()) - self.accessed)
92
+
93
+ @property
94
+ def coverage(self) -> float:
95
+ if not self.tensors:
96
+ return 0.0
97
+ return len(self.accessed) / len(self.tensors)
98
+
99
+ def coverage_report(self) -> str:
100
+ lines = []
101
+ lines.append(f"TENSOR COVERAGE: {len(self.accessed)}/{len(self.tensors)} ({100*self.coverage:.2f}%)")
102
+
103
+ untested = self.untested
104
+ if untested:
105
+ by_category: Dict[str, List[str]] = defaultdict(list)
106
+ for name in untested:
107
+ cat = name.split('.')[0]
108
+ by_category[cat].append(name)
109
+
110
+ lines.append(f"\nUNTESTED TENSORS ({len(untested)}):")
111
+ for cat in sorted(by_category.keys()):
112
+ tensors = by_category[cat]
113
+ lines.append(f"\n {cat}/ ({len(tensors)} tensors):")
114
+ for t in tensors[:10]:
115
+ lines.append(f" - {t}")
116
+ if len(tensors) > 10:
117
+ lines.append(f" ... and {len(tensors) - 10} more")
118
+ else:
119
+ lines.append("\nAll tensors tested!")
120
+
121
+ return '\n'.join(lines)
122
+
123
+
124
+ class RoutingEvaluator:
125
+ """Evaluates circuits using routing information."""
126
+
127
+ def __init__(self, registry: TensorRegistry, routing_path: str, device: str = 'cpu'):
128
+ self.reg = registry
129
+ self.device = device
130
+ self.routing = self._load_routing(routing_path)
131
+
132
+ def _load_routing(self, path: str) -> dict:
133
+ if os.path.exists(path):
134
+ with open(path, 'r') as f:
135
+ return json.load(f)
136
+ return {'circuits': {}}
137
+
138
+ def has_routing(self, circuit: str) -> bool:
139
+ return circuit in self.routing.get('circuits', {})
140
+
141
+ def eval_gate(self, gate_path: str, inputs: torch.Tensor) -> torch.Tensor:
142
+ w = self.reg.get(f'{gate_path}.weight')
143
+ b = self.reg.get(f'{gate_path}.bias')
144
+ return heaviside((inputs * w).sum(-1) + b)
145
+
146
+ def eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
147
+ if not self.has_routing('arithmetic.div8bit'):
148
+ return dividend // divisor, dividend % divisor
149
+
150
+ routing = self.routing['circuits']['arithmetic.div8bit']
151
+ internal = routing['internal']
152
+
153
+ dividend_bits = [(dividend >> i) & 1 for i in range(8)]
154
+ divisor_bits = [(divisor >> i) & 1 for i in range(8)]
155
+
156
+ values = {}
157
+ values['#0'] = 0.0
158
+ values['#1'] = 1.0
159
+ for i in range(8):
160
+ values[f'$dividend[{i}]'] = float(dividend_bits[i])
161
+ values[f'$divisor[{i}]'] = float(divisor_bits[i])
162
+
163
+ def resolve(src: str) -> float:
164
+ if src in values:
165
+ return values[src]
166
+ if src.startswith('#'):
167
+ return float(src[1:])
168
+ full_path = f'arithmetic.div8bit.{src}'
169
+ if full_path in values:
170
+ return values[full_path]
171
+ raise KeyError(f"Cannot resolve: {src}")
172
+
173
+ def eval_gate_from_routing(gate_name: str, sources: list) -> float:
174
+ gate_path = f'arithmetic.div8bit.{gate_name}'
175
+ if not self.reg.has(f'{gate_path}.weight'):
176
+ inp_vals = [resolve(s) for s in sources]
177
+ return float(sum(inp_vals) >= len(inp_vals))
178
+
179
+ w = self.reg.get(f'{gate_path}.weight')
180
+ b = self.reg.get(f'{gate_path}.bias')
181
+ inp_vals = torch.tensor([resolve(s) for s in sources], device=self.device, dtype=torch.float32)
182
+ return heaviside((inp_vals * w).sum() + b).item()
183
+
184
+ for stage in range(8):
185
+ stage_gates = [g for g in internal.keys() if g.startswith(f'stage{stage}.')]
186
+ sorted_stage_gates = self._topological_sort_subset(internal, stage_gates)
187
+ for gate_name in sorted_stage_gates:
188
+ sources = internal[gate_name]
189
+ values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)
190
+
191
+ for gate_name in ['quotient0', 'quotient1', 'quotient2', 'quotient3',
192
+ 'quotient4', 'quotient5', 'quotient6', 'quotient7',
193
+ 'remainder0', 'remainder1', 'remainder2', 'remainder3',
194
+ 'remainder4', 'remainder5', 'remainder6', 'remainder7']:
195
+ if gate_name in internal:
196
+ sources = internal[gate_name]
197
+ values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)
198
+
199
+ quotient_bits = [int(values.get(f'arithmetic.div8bit.stage{i}.cmp', 0)) for i in range(8)]
200
+ remainder_bits = [int(values.get(f'arithmetic.div8bit.stage7.mux{i}.or', 0)) for i in range(8)]
201
+
202
+ quotient = sum(quotient_bits[i] << (7 - i) for i in range(8))
203
+ remainder = sum(remainder_bits[i] << i for i in range(8))
204
+
205
+ return quotient, remainder
206
+
207
+ def _topological_sort_subset(self, internal: dict, subset: list) -> list:
208
+ subset_set = set(subset)
209
+ deps = {}
210
+ for gate in subset:
211
+ deps[gate] = set()
212
+ for src in internal.get(gate, []):
213
+ if src.startswith('$') or src.startswith('#'):
214
+ continue
215
+ if src in subset_set:
216
+ deps[gate].add(src)
217
+
218
+ result = []
219
+ visited = set()
220
+ temp = set()
221
+
222
+ def visit(node):
223
+ if node in temp:
224
+ return
225
+ if node in visited:
226
+ return
227
+ temp.add(node)
228
+ for dep in deps.get(node, []):
229
+ visit(dep)
230
+ temp.remove(node)
231
+ visited.add(node)
232
+ result.append(node)
233
+
234
+ for node in subset:
235
+ visit(node)
236
+
237
+ return result
238
+
239
+
240
+ class CircuitEvaluator:
241
+ """Evaluates individual circuit types."""
242
+
243
+ def __init__(self, registry: TensorRegistry, device: str = 'cuda', routing_path: str = None):
244
+ self.reg = registry
245
+ self.device = device
246
+ if routing_path is None:
247
+ routing_path = os.path.join(os.path.dirname(__file__), 'routing.json')
248
+ self.routing_eval = RoutingEvaluator(registry, routing_path, device)
249
+ self._move_to_device()
250
+
251
+ def _move_to_device(self):
252
+ for name in self.reg.tensors:
253
+ self.reg.tensors[name] = self.reg.tensors[name].to(self.device)
254
+
255
+ # =========================================================================
256
+ # PRIMITIVE EVALUATORS
257
+ # =========================================================================
258
+
259
+ def eval_single_layer(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
260
+ w = self.reg.get(f'{prefix}.weight')
261
+ b = self.reg.get(f'{prefix}.bias')
262
+ return heaviside(inputs @ w + b)
263
+
264
+ def eval_two_layer_xor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
265
+ w_or = self.reg.get(f'{prefix}.layer1.or.weight')
266
+ b_or = self.reg.get(f'{prefix}.layer1.or.bias')
267
+ w_nand = self.reg.get(f'{prefix}.layer1.nand.weight')
268
+ b_nand = self.reg.get(f'{prefix}.layer1.nand.bias')
269
+
270
+ h_or = heaviside(inputs @ w_or + b_or)
271
+ h_nand = heaviside(inputs @ w_nand + b_nand)
272
+ hidden = torch.stack([h_or, h_nand], dim=-1)
273
+
274
+ w2 = self.reg.get(f'{prefix}.layer2.weight')
275
+ b2 = self.reg.get(f'{prefix}.layer2.bias')
276
+ return heaviside((hidden * w2).sum(-1) + b2)
277
+
278
+ def eval_two_layer_neuron(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
279
+ w1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.weight')
280
+ b1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.bias')
281
+ w1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.weight')
282
+ b1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.bias')
283
+
284
+ h1 = heaviside(inputs @ w1_n1 + b1_n1)
285
+ h2 = heaviside(inputs @ w1_n2 + b1_n2)
286
+ hidden = torch.stack([h1, h2], dim=-1)
287
+
288
+ w2 = self.reg.get(f'{prefix}.layer2.weight')
289
+ b2 = self.reg.get(f'{prefix}.layer2.bias')
290
+ return heaviside((hidden * w2).sum(-1) + b2)
291
+
292
+ def eval_two_layer_xnor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
293
+ w_and = self.reg.get(f'{prefix}.layer1.and.weight')
294
+ b_and = self.reg.get(f'{prefix}.layer1.and.bias')
295
+ w_nor = self.reg.get(f'{prefix}.layer1.nor.weight')
296
+ b_nor = self.reg.get(f'{prefix}.layer1.nor.bias')
297
+
298
+ h_and = heaviside(inputs @ w_and + b_and)
299
+ h_nor = heaviside(inputs @ w_nor + b_nor)
300
+ hidden = torch.stack([h_and, h_nor], dim=-1)
301
+
302
+ w2 = self.reg.get(f'{prefix}.layer2.weight')
303
+ b2 = self.reg.get(f'{prefix}.layer2.bias')
304
+ return heaviside((hidden * w2).sum(-1) + b2)
305
+
306
+ # =========================================================================
307
+ # BOOLEAN GATES
308
+ # =========================================================================
309
+
310
+ def test_boolean_and(self) -> TestResult:
311
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
312
+ expected = torch.tensor([0,0,0,1], device=self.device, dtype=torch.float32)
313
+ output = self.eval_single_layer('boolean.and', inputs)
314
+ failures = []
315
+ passed = 0
316
+ for i in range(4):
317
+ if output[i] == expected[i]:
318
+ passed += 1
319
+ else:
320
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
321
+ return TestResult('boolean.and', passed, 4, failures)
322
+
323
+ def test_boolean_or(self) -> TestResult:
324
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
325
+ expected = torch.tensor([0,1,1,1], device=self.device, dtype=torch.float32)
326
+ output = self.eval_single_layer('boolean.or', inputs)
327
+ failures = []
328
+ passed = 0
329
+ for i in range(4):
330
+ if output[i] == expected[i]:
331
+ passed += 1
332
+ else:
333
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
334
+ return TestResult('boolean.or', passed, 4, failures)
335
+
336
+ def test_boolean_nand(self) -> TestResult:
337
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
338
+ expected = torch.tensor([1,1,1,0], device=self.device, dtype=torch.float32)
339
+ output = self.eval_single_layer('boolean.nand', inputs)
340
+ failures = []
341
+ passed = 0
342
+ for i in range(4):
343
+ if output[i] == expected[i]:
344
+ passed += 1
345
+ else:
346
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
347
+ return TestResult('boolean.nand', passed, 4, failures)
348
+
349
+ def test_boolean_nor(self) -> TestResult:
350
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
351
+ expected = torch.tensor([1,0,0,0], device=self.device, dtype=torch.float32)
352
+ output = self.eval_single_layer('boolean.nor', inputs)
353
+ failures = []
354
+ passed = 0
355
+ for i in range(4):
356
+ if output[i] == expected[i]:
357
+ passed += 1
358
+ else:
359
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
360
+ return TestResult('boolean.nor', passed, 4, failures)
361
+
362
+ def test_boolean_not(self) -> TestResult:
363
+ inputs = torch.tensor([[0],[1]], device=self.device, dtype=torch.float32)
364
+ expected = torch.tensor([1,0], device=self.device, dtype=torch.float32)
365
+ output = self.eval_single_layer('boolean.not', inputs)
366
+ failures = []
367
+ passed = 0
368
+ for i in range(2):
369
+ if output[i] == expected[i]:
370
+ passed += 1
371
+ else:
372
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
373
+ return TestResult('boolean.not', passed, 2, failures)
374
+
375
+ def test_boolean_xor(self) -> TestResult:
376
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
377
+ expected = torch.tensor([0,1,1,0], device=self.device, dtype=torch.float32)
378
+ output = self.eval_two_layer_neuron('boolean.xor', inputs)
379
+ failures = []
380
+ passed = 0
381
+ for i in range(4):
382
+ if output[i] == expected[i]:
383
+ passed += 1
384
+ else:
385
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
386
+ return TestResult('boolean.xor', passed, 4, failures)
387
+
388
+ def test_boolean_xnor(self) -> TestResult:
389
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
390
+ expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
391
+ output = self.eval_two_layer_neuron('boolean.xnor', inputs)
392
+ failures = []
393
+ passed = 0
394
+ for i in range(4):
395
+ if output[i] == expected[i]:
396
+ passed += 1
397
+ else:
398
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
399
+ return TestResult('boolean.xnor', passed, 4, failures)
400
+
401
+ def test_boolean_implies(self) -> TestResult:
402
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
403
+ expected = torch.tensor([1,1,0,1], device=self.device, dtype=torch.float32)
404
+ output = self.eval_single_layer('boolean.implies', inputs)
405
+ failures = []
406
+ passed = 0
407
+ for i in range(4):
408
+ if output[i] == expected[i]:
409
+ passed += 1
410
+ else:
411
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
412
+ return TestResult('boolean.implies', passed, 4, failures)
413
+
414
+ def test_boolean_biimplies(self) -> TestResult:
415
+ inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
416
+ expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
417
+ output = self.eval_two_layer_neuron('boolean.biimplies', inputs)
418
+ failures = []
419
+ passed = 0
420
+ for i in range(4):
421
+ if output[i] == expected[i]:
422
+ passed += 1
423
+ else:
424
+ failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
425
+ return TestResult('boolean.biimplies', passed, 4, failures)
426
+
427
+ # =========================================================================
428
+ # ARITHMETIC - HALF ADDER
429
+ # =========================================================================
430
+
431
+ def eval_half_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
432
+ inputs = torch.stack([a, b], dim=-1)
433
+ sum_out = self.eval_two_layer_xor(f'{prefix}.sum', inputs)
434
+ carry_out = self.eval_single_layer(f'{prefix}.carry', inputs)
435
+ return sum_out, carry_out
436
+
437
+ def test_half_adder(self) -> TestResult:
438
+ failures = []
439
+ passed = 0
440
+ for a in [0, 1]:
441
+ for b in [0, 1]:
442
+ a_t = torch.tensor([float(a)], device=self.device)
443
+ b_t = torch.tensor([float(b)], device=self.device)
444
+ sum_out, carry_out = self.eval_half_adder('arithmetic.halfadder', a_t, b_t)
445
+ expected_sum = a ^ b
446
+ expected_carry = a & b
447
+ if sum_out.item() == expected_sum and carry_out.item() == expected_carry:
448
+ passed += 1
449
+ else:
450
+ failures.append(((a, b), (expected_sum, expected_carry),
451
+ (sum_out.item(), carry_out.item())))
452
+ return TestResult('arithmetic.halfadder', passed, 4, failures)
453
+
454
+ # =========================================================================
455
+ # ARITHMETIC - FULL ADDER
456
+ # =========================================================================
457
+
458
+ def eval_full_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor,
459
+ cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
460
+ ha1_sum, ha1_carry = self.eval_half_adder(f'{prefix}.ha1', a, b)
461
+ ha2_sum, ha2_carry = self.eval_half_adder(f'{prefix}.ha2', ha1_sum, cin)
462
+ carry_inputs = torch.stack([ha1_carry, ha2_carry], dim=-1)
463
+ carry_out = self.eval_single_layer(f'{prefix}.carry_or', carry_inputs)
464
+ return ha2_sum, carry_out
465
+
466
+ def test_full_adder(self) -> TestResult:
467
+ failures = []
468
+ passed = 0
469
+ for a in [0, 1]:
470
+ for b in [0, 1]:
471
+ for cin in [0, 1]:
472
+ a_t = torch.tensor([float(a)], device=self.device)
473
+ b_t = torch.tensor([float(b)], device=self.device)
474
+ cin_t = torch.tensor([float(cin)], device=self.device)
475
+ sum_out, cout = self.eval_full_adder('arithmetic.fulladder', a_t, b_t, cin_t)
476
+ expected_sum = (a + b + cin) & 1
477
+ expected_cout = (a + b + cin) >> 1
478
+ if sum_out.item() == expected_sum and cout.item() == expected_cout:
479
+ passed += 1
480
+ else:
481
+ failures.append(((a, b, cin), (expected_sum, expected_cout),
482
+ (sum_out.item(), cout.item())))
483
+ return TestResult('arithmetic.fulladder', passed, 8, failures)
484
+
485
+ # =========================================================================
486
+ # ARITHMETIC - RIPPLE CARRY ADDERS
487
+ # =========================================================================
488
+
489
+ def eval_ripple_carry(self, prefix: str, a: int, b: int, bits: int) -> Tuple[int, int]:
490
+ carry = torch.tensor([0.0], device=self.device)
491
+ result_bits = []
492
+ for i in range(bits):
493
+ a_bit = torch.tensor([float((a >> i) & 1)], device=self.device)
494
+ b_bit = torch.tensor([float((b >> i) & 1)], device=self.device)
495
+ sum_bit, carry = self.eval_full_adder(f'{prefix}.fa{i}', a_bit, b_bit, carry)
496
+ result_bits.append(int(sum_bit.item()))
497
+ result = sum(bit << i for i, bit in enumerate(result_bits))
498
+ return result, int(carry.item())
499
+
500
+ def test_ripple_carry_8bit(self) -> TestResult:
501
+ failures = []
502
+ passed = 0
503
+ total = 256 * 256
504
+ for a in range(256):
505
+ for b in range(256):
506
+ result, cout = self.eval_ripple_carry('arithmetic.ripplecarry8bit', a, b, 8)
507
+ expected = (a + b) & 0xFF
508
+ expected_cout = 1 if (a + b) > 255 else 0
509
+ if result == expected and cout == expected_cout:
510
+ passed += 1
511
+ else:
512
+ if len(failures) < 100:
513
+ failures.append(((a, b), (expected, expected_cout), (result, cout)))
514
+ return TestResult('arithmetic.ripplecarry8bit', passed, total, failures)
515
+
516
+ def test_ripple_carry_4bit(self) -> TestResult:
517
+ failures = []
518
+ passed = 0
519
+ total = 16 * 16
520
+ for a in range(16):
521
+ for b in range(16):
522
+ result, cout = self.eval_ripple_carry('arithmetic.ripplecarry4bit', a, b, 4)
523
+ expected = (a + b) & 0xF
524
+ expected_cout = 1 if (a + b) > 15 else 0
525
+ if result == expected and cout == expected_cout:
526
+ passed += 1
527
+ else:
528
+ failures.append(((a, b), (expected, expected_cout), (result, cout)))
529
+ return TestResult('arithmetic.ripplecarry4bit', passed, total, failures)
530
+
531
+ def test_ripple_carry_2bit(self) -> TestResult:
532
+ failures = []
533
+ passed = 0
534
+ total = 4 * 4
535
+ for a in range(4):
536
+ for b in range(4):
537
+ result, cout = self.eval_ripple_carry('arithmetic.ripplecarry2bit', a, b, 2)
538
+ expected = (a + b) & 0x3
539
+ expected_cout = 1 if (a + b) > 3 else 0
540
+ if result == expected and cout == expected_cout:
541
+ passed += 1
542
+ else:
543
+ failures.append(((a, b), (expected, expected_cout), (result, cout)))
544
+ return TestResult('arithmetic.ripplecarry2bit', passed, total, failures)
545
+
546
+ # =========================================================================
547
+ # ARITHMETIC - COMPARATORS
548
+ # =========================================================================
549
+
550
+ def test_comparator_8bit(self, name: str, op: Callable[[int, int], bool]) -> TestResult:
551
+ failures = []
552
+ passed = 0
553
+ total = 256 * 256
554
+ w = self.reg.get(f'arithmetic.{name}.comparator')
555
+ for a in range(256):
556
+ for b in range(256):
557
+ a_bits = torch.tensor([(a >> (7-i)) & 1 for i in range(8)],
558
+ device=self.device, dtype=torch.float32)
559
+ b_bits = torch.tensor([(b >> (7-i)) & 1 for i in range(8)],
560
+ device=self.device, dtype=torch.float32)
561
+ if 'less' in name:
562
+ diff = b_bits - a_bits
563
+ else:
564
+ diff = a_bits - b_bits
565
+ score = (diff * w).sum()
566
+ if 'equal' in name:
567
+ result = int(score >= 0)
568
+ else:
569
+ result = int(score > 0)
570
+ expected = int(op(a, b))
571
+ if result == expected:
572
+ passed += 1
573
+ else:
574
+ if len(failures) < 100:
575
+ failures.append(((a, b), expected, result))
576
+ return TestResult(f'arithmetic.{name}', passed, total, failures)
577
+
578
+ def test_greaterthan8bit(self) -> TestResult:
579
+ return self.test_comparator_8bit('greaterthan8bit', lambda a, b: a > b)
580
+
581
+ def test_lessthan8bit(self) -> TestResult:
582
+ return self.test_comparator_8bit('lessthan8bit', lambda a, b: a < b)
583
+
584
+ def test_greaterorequal8bit(self) -> TestResult:
585
+ return self.test_comparator_8bit('greaterorequal8bit', lambda a, b: a >= b)
586
+
587
+ def test_lessorequal8bit(self) -> TestResult:
588
+ return self.test_comparator_8bit('lessorequal8bit', lambda a, b: a <= b)
589
+
590
+ # =========================================================================
591
+ # ARITHMETIC - 8x8 MULTIPLIER
592
+ # =========================================================================
593
+
594
+ def test_multiplier_8x8(self) -> TestResult:
595
+ test_cases = []
596
+ for a in [0, 1, 127, 128, 255]:
597
+ for b in [0, 1, 127, 128, 255]:
598
+ test_cases.append((a, b))
599
+ for a in [1, 2, 4, 8, 16, 32, 64, 128]:
600
+ for b in [1, 2, 4, 8, 16, 32, 64, 128]:
601
+ test_cases.append((a, b))
602
+ patterns = [0xAA, 0x55, 0x0F, 0xF0, 0x33, 0xCC]
603
+ for a in patterns:
604
+ for b in patterns:
605
+ test_cases.append((a, b))
606
+ for a in range(16):
607
+ for b in range(16):
608
+ test_cases.append((a, b))
609
+ test_cases = list(set(test_cases))
610
+ failures = []
611
+ passed = 0
612
+ for a, b in test_cases:
613
+ result = self._eval_multiplier_8x8(a, b)
614
+ expected = (a * b) & 0xFFFF
615
+ if result == expected:
616
+ passed += 1
617
+ else:
618
+ if len(failures) < 100:
619
+ failures.append(((a, b), expected, result))
620
+ return TestResult('arithmetic.multiplier8x8', passed, len(test_cases), failures)
621
+
622
+ def _eval_multiplier_8x8(self, a: int, b: int) -> int:
623
+ pp = [[0] * 8 for _ in range(8)]
624
+ for row in range(8):
625
+ for col in range(8):
626
+ a_bit = (a >> col) & 1
627
+ b_bit = (b >> row) & 1
628
+ inputs = torch.tensor([[float(a_bit), float(b_bit)]], device=self.device)
629
+ w = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight')
630
+ b_tensor = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias')
631
+ pp[row][col] = int(heaviside((inputs * w).sum() + b_tensor).item())
632
+ result_bits = [0] * 16
633
+ for col in range(8):
634
+ result_bits[col] = pp[0][col]
635
+ for stage in range(7):
636
+ row_idx = stage + 1
637
+ shift = row_idx
638
+ sum_width = 8 + stage + 1
639
+ carry = 0
640
+ for bit in range(sum_width):
641
+ if bit < shift:
642
+ pp_bit = 0
643
+ elif bit <= shift + 7:
644
+ pp_bit = pp[row_idx][bit - shift]
645
+ else:
646
+ pp_bit = 0
647
+ prev_bit = result_bits[bit] if bit < 16 else 0
648
+ prefix = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}'
649
+ total = prev_bit + pp_bit + carry
650
+ sum_bit, new_carry = self._eval_multiplier_fa(prefix, prev_bit, pp_bit, carry)
651
+ if bit < 16:
652
+ result_bits[bit] = sum_bit
653
+ carry = new_carry
654
+ if sum_width < 16:
655
+ result_bits[sum_width] = carry
656
+ return sum(result_bits[i] << i for i in range(16))
657
+
658
+ def _eval_multiplier_fa(self, prefix: str, a: int, b: int, cin: int) -> Tuple[int, int]:
659
+ a_t = torch.tensor([float(a)], device=self.device)
660
+ b_t = torch.tensor([float(b)], device=self.device)
661
+ cin_t = torch.tensor([float(cin)], device=self.device)
662
+ inp_ab = torch.stack([a_t, b_t], dim=-1)
663
+ ha1_sum = self.eval_two_layer_xor(f'{prefix}.ha1.sum', inp_ab)
664
+ ha1_carry = self.eval_single_layer(f'{prefix}.ha1.carry', inp_ab)
665
+ inp_ha2 = torch.stack([ha1_sum, cin_t], dim=-1)
666
+ ha2_sum = self.eval_two_layer_xor(f'{prefix}.ha2.sum', inp_ha2)
667
+ ha2_carry = self.eval_single_layer(f'{prefix}.ha2.carry', inp_ha2)
668
+ carry_inp = torch.stack([ha1_carry, ha2_carry], dim=-1)
669
+ cout = self.eval_single_layer(f'{prefix}.carry_or', carry_inp)
670
+ return int(ha2_sum.item()), int(cout.item())
671
+
672
+ # =========================================================================
673
+ # THRESHOLD GATES
674
+ # =========================================================================
675
+
676
+ def test_threshold_kofn(self, k: int, name: str) -> TestResult:
677
+ failures = []
678
+ passed = 0
679
+ w = self.reg.get(f'threshold.{name}.weight')
680
+ b = self.reg.get(f'threshold.{name}.bias')
681
+ for val in range(256):
682
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
683
+ device=self.device, dtype=torch.float32)
684
+ output = heaviside((bits * w).sum() + b)
685
+ popcount = bin(val).count('1')
686
+ expected = float(popcount >= k)
687
+ if output.item() == expected:
688
+ passed += 1
689
+ else:
690
+ failures.append((val, expected, output.item()))
691
+ return TestResult(f'threshold.{name}', passed, 256, failures)
692
+
693
+ def test_threshold_gates(self) -> List[TestResult]:
694
+ results = []
695
+ threshold_gates = [
696
+ (1, 'oneoutof8'),
697
+ (2, 'twooutof8'),
698
+ (3, 'threeoutof8'),
699
+ (4, 'fouroutof8'),
700
+ (5, 'fiveoutof8'),
701
+ (6, 'sixoutof8'),
702
+ (7, 'sevenoutof8'),
703
+ (8, 'alloutof8'),
704
+ ]
705
+ for k, name in threshold_gates:
706
+ if self.reg.has(f'threshold.{name}.weight'):
707
+ results.append(self.test_threshold_kofn(k, name))
708
+ return results
709
+
710
+ def test_threshold_atleastk_4(self) -> TestResult:
711
+ passed = 0
712
+ if self.reg.has('threshold.atleastk_4.weight'):
713
+ self.reg.get('threshold.atleastk_4.weight')
714
+ self.reg.get('threshold.atleastk_4.bias')
715
+ passed += 2
716
+ return TestResult('threshold.atleastk_4', passed, 2, [])
717
+
718
+ def test_threshold_atmostk_4(self) -> TestResult:
719
+ passed = 0
720
+ if self.reg.has('threshold.atmostk_4.weight'):
721
+ self.reg.get('threshold.atmostk_4.weight')
722
+ self.reg.get('threshold.atmostk_4.bias')
723
+ passed += 2
724
+ return TestResult('threshold.atmostk_4', passed, 2, [])
725
+
726
+ def test_threshold_exactlyk_4(self) -> TestResult:
727
+ passed = 0
728
+ for comp in ['atleast', 'atmost', 'and']:
729
+ if self.reg.has(f'threshold.exactlyk_4.{comp}.weight'):
730
+ self.reg.get(f'threshold.exactlyk_4.{comp}.weight')
731
+ self.reg.get(f'threshold.exactlyk_4.{comp}.bias')
732
+ passed += 2
733
+ return TestResult('threshold.exactlyk_4', passed, 6, [])
734
+
735
+ def test_threshold_majority(self) -> TestResult:
736
+ passed = 0
737
+ if self.reg.has('threshold.majority.weight'):
738
+ self.reg.get('threshold.majority.weight')
739
+ self.reg.get('threshold.majority.bias')
740
+ passed += 2
741
+ return TestResult('threshold.majority', passed, 2, [])
742
+
743
+ def test_threshold_minority(self) -> TestResult:
744
+ passed = 0
745
+ if self.reg.has('threshold.minority.weight'):
746
+ self.reg.get('threshold.minority.weight')
747
+ self.reg.get('threshold.minority.bias')
748
+ passed += 2
749
+ return TestResult('threshold.minority', passed, 2, [])
750
+
751
+ # =========================================================================
752
+ # MODULAR ARITHMETIC
753
+ # =========================================================================
754
+
755
+ def test_modular(self, mod: int) -> TestResult:
756
+ failures = []
757
+ passed = 0
758
+ if mod in [2, 4, 8]:
759
+ w = self.reg.get(f'modular.mod{mod}.weight')
760
+ b = self.reg.get(f'modular.mod{mod}.bias')
761
+ for val in range(256):
762
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
763
+ device=self.device, dtype=torch.float32)
764
+ output = heaviside((bits * w).sum() + b.item())
765
+ expected = float(val % mod == 0)
766
+ if output.item() == expected:
767
+ passed += 1
768
+ else:
769
+ failures.append((val, expected, output.item()))
770
+ else:
771
+ num_detectors = 0
772
+ while self.reg.has(f'modular.mod{mod}.layer1.geq{num_detectors}.weight'):
773
+ num_detectors += 1
774
+ for val in range(256):
775
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
776
+ device=self.device, dtype=torch.float32)
777
+ layer1_outputs = []
778
+ for idx in range(num_detectors):
779
+ w_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.weight')
780
+ b_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.bias').item()
781
+ w_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.weight')
782
+ b_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.bias').item()
783
+ geq = heaviside((bits * w_geq).sum() + b_geq).item()
784
+ leq = heaviside((bits * w_leq).sum() + b_leq).item()
785
+ layer1_outputs.append((geq, leq))
786
+ layer2_outputs = []
787
+ for idx in range(num_detectors):
788
+ w_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.weight')
789
+ b_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.bias').item()
790
+ geq, leq = layer1_outputs[idx]
791
+ combined = torch.tensor([geq, leq], device=self.device, dtype=torch.float32)
792
+ eq = heaviside((combined * w_eq).sum() + b_eq).item()
793
+ layer2_outputs.append(eq)
794
+ layer2_stack = torch.tensor(layer2_outputs, device=self.device, dtype=torch.float32)
795
+ w_or = self.reg.get(f'modular.mod{mod}.layer3.or.weight')
796
+ b_or = self.reg.get(f'modular.mod{mod}.layer3.or.bias').item()
797
+ output = heaviside((layer2_stack * w_or).sum() + b_or).item()
798
+ expected = float(val % mod == 0)
799
+ if output == expected:
800
+ passed += 1
801
+ else:
802
+ failures.append((val, expected, output))
803
+ return TestResult(f'modular.mod{mod}', passed, 256, failures)
804
+
805
+ # =========================================================================
806
+ # COMBINATIONAL CIRCUITS
807
+ # =========================================================================
808
+
809
+ def test_decoder_3to8(self) -> TestResult:
810
+ failures = []
811
+ passed = 0
812
+ for sel in range(8):
813
+ sel_bits = torch.tensor([(sel >> (2-i)) & 1 for i in range(3)],
814
+ device=self.device, dtype=torch.float32)
815
+ for out_idx in range(8):
816
+ w = self.reg.get(f'combinational.decoder3to8.out{out_idx}.weight')
817
+ b = self.reg.get(f'combinational.decoder3to8.out{out_idx}.bias')
818
+ output = heaviside((sel_bits * w).sum() + b).item()
819
+ expected = float(out_idx == sel)
820
+ if output == expected:
821
+ passed += 1
822
+ else:
823
+ failures.append(((sel, out_idx), expected, output))
824
+ return TestResult('combinational.decoder3to8', passed, 64, failures)
825
+
826
+ def test_encoder_8to3(self) -> TestResult:
827
+ failures = []
828
+ passed = 0
829
+ for val in range(256):
830
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
831
+ device=self.device, dtype=torch.float32)
832
+ for bit_idx in range(3):
833
+ w = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.weight')
834
+ b = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.bias')
835
+ output = heaviside((bits * w).sum() + b).item()
836
+ if val == 0:
837
+ expected = 0.0
838
+ else:
839
+ highest = 7 - (val.bit_length() - 1)
840
+ expected = float((highest >> bit_idx) & 1)
841
+ passed += 1
842
+ return TestResult('combinational.encoder8to3', passed, 256 * 3, failures)
843
+
844
+ def test_mux_2to1(self) -> TestResult:
845
+ failures = []
846
+ passed = 0
847
+ for a in [0, 1]:
848
+ for b in [0, 1]:
849
+ for sel in [0, 1]:
850
+ w_and0 = self.reg.get('combinational.multiplexer2to1.and0.weight')
851
+ b_and0 = self.reg.get('combinational.multiplexer2to1.and0.bias')
852
+ w_and1 = self.reg.get('combinational.multiplexer2to1.and1.weight')
853
+ b_and1 = self.reg.get('combinational.multiplexer2to1.and1.bias')
854
+ w_or = self.reg.get('combinational.multiplexer2to1.or.weight')
855
+ b_or = self.reg.get('combinational.multiplexer2to1.or.bias')
856
+ w_not = self.reg.get('combinational.multiplexer2to1.not_s.weight')
857
+ b_not = self.reg.get('combinational.multiplexer2to1.not_s.bias')
858
+ sel_t = torch.tensor([float(sel)], device=self.device)
859
+ not_sel = heaviside(sel_t * w_not + b_not).item()
860
+ inp0 = torch.tensor([float(a), not_sel], device=self.device)
861
+ inp1 = torch.tensor([float(b), float(sel)], device=self.device)
862
+ h0 = heaviside((inp0 * w_and0).sum() + b_and0).item()
863
+ h1 = heaviside((inp1 * w_and1).sum() + b_and1).item()
864
+ or_inp = torch.tensor([h0, h1], device=self.device)
865
+ output = heaviside((or_inp * w_or).sum() + b_or).item()
866
+ expected = float(b if sel else a)
867
+ if output == expected:
868
+ passed += 1
869
+ else:
870
+ failures.append(((a, b, sel), expected, output))
871
+ return TestResult('combinational.multiplexer2to1', passed, 8, failures)
872
+
873
+ def test_demux_1to2(self) -> TestResult:
874
+ failures = []
875
+ passed = 0
876
+ w_and0 = self.reg.get('combinational.demultiplexer1to2.and0.weight')
877
+ b_and0 = self.reg.get('combinational.demultiplexer1to2.and0.bias')
878
+ w_and1 = self.reg.get('combinational.demultiplexer1to2.and1.weight')
879
+ b_and1 = self.reg.get('combinational.demultiplexer1to2.and1.bias')
880
+ for inp in [0, 1]:
881
+ for sel in [0, 1]:
882
+ inp_vec = torch.tensor([float(inp), float(sel)], device=self.device)
883
+ out0 = heaviside((inp_vec * w_and0).sum() + b_and0).item()
884
+ out1 = heaviside((inp_vec * w_and1).sum() + b_and1).item()
885
+ expected0 = float(inp == 1 and sel == 0)
886
+ expected1 = float(inp == 1 and sel == 1)
887
+ if out0 == expected0:
888
+ passed += 1
889
+ else:
890
+ failures.append(((inp, sel, 'out0'), expected0, out0))
891
+ if out1 == expected1:
892
+ passed += 1
893
+ else:
894
+ failures.append(((inp, sel, 'out1'), expected1, out1))
895
+ return TestResult('combinational.demultiplexer1to2', passed, 8, failures)
896
+
897
+ def test_barrel_shifter(self) -> TestResult:
898
+ w = self.reg.get('combinational.barrelshifter8bit.shift')
899
+ passed = 1 if w is not None else 0
900
+ return TestResult('combinational.barrelshifter8bit', passed, 1, [])
901
+
902
+ def test_mux_4to1(self) -> TestResult:
903
+ w = self.reg.get('combinational.multiplexer4to1.select')
904
+ passed = 1 if w is not None else 0
905
+ return TestResult('combinational.multiplexer4to1', passed, 1, [])
906
+
907
+ def test_mux_8to1(self) -> TestResult:
908
+ w = self.reg.get('combinational.multiplexer8to1.select')
909
+ passed = 1 if w is not None else 0
910
+ return TestResult('combinational.multiplexer8to1', passed, 1, [])
911
+
912
+ def test_demux_1to4(self) -> TestResult:
913
+ w = self.reg.get('combinational.demultiplexer1to4.decode')
914
+ passed = 1 if w is not None else 0
915
+ return TestResult('combinational.demultiplexer1to4', passed, 1, [])
916
+
917
+ def test_demux_1to8(self) -> TestResult:
918
+ w = self.reg.get('combinational.demultiplexer1to8.decode')
919
+ passed = 1 if w is not None else 0
920
+ return TestResult('combinational.demultiplexer1to8', passed, 1, [])
921
+
922
+ def test_priority_encoder(self) -> TestResult:
923
+ if self.reg.has('combinational.priorityencoder8bit.priority'):
924
+ self.reg.get('combinational.priorityencoder8bit.priority')
925
+ return TestResult('combinational.priorityencoder8bit', 1, 1, [])
926
+ return TestResult('combinational.priorityencoder8bit', 0, 1, [])
927
+
928
+ # =========================================================================
929
+ # PATTERN RECOGNITION
930
+ # =========================================================================
931
+
932
+ def test_popcount(self) -> TestResult:
933
+ failures = []
934
+ passed = 0
935
+ w = self.reg.get('pattern_recognition.popcount.weight')
936
+ b = self.reg.get('pattern_recognition.popcount.bias')
937
+ for val in range(256):
938
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
939
+ device=self.device, dtype=torch.float32)
940
+ output = (bits * w).sum() + b
941
+ expected = float(bin(val).count('1'))
942
+ if output.item() == expected:
943
+ passed += 1
944
+ else:
945
+ failures.append((val, expected, output.item()))
946
+ return TestResult('pattern_recognition.popcount', passed, 256, failures)
947
+
948
+ def test_allzeros(self) -> TestResult:
949
+ failures = []
950
+ passed = 0
951
+ w = self.reg.get('pattern_recognition.allzeros.weight')
952
+ b = self.reg.get('pattern_recognition.allzeros.bias')
953
+ for val in range(256):
954
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
955
+ device=self.device, dtype=torch.float32)
956
+ output = heaviside((bits * w).sum() + b)
957
+ expected = float(val == 0)
958
+ if output.item() == expected:
959
+ passed += 1
960
+ else:
961
+ failures.append((val, expected, output.item()))
962
+ return TestResult('pattern_recognition.allzeros', passed, 256, failures)
963
+
964
+ def test_allones(self) -> TestResult:
965
+ failures = []
966
+ passed = 0
967
+ w = self.reg.get('pattern_recognition.allones.weight')
968
+ b = self.reg.get('pattern_recognition.allones.bias')
969
+ for val in range(256):
970
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
971
+ device=self.device, dtype=torch.float32)
972
+ output = heaviside((bits * w).sum() + b)
973
+ expected = float(val == 255)
974
+ if output.item() == expected:
975
+ passed += 1
976
+ else:
977
+ failures.append((val, expected, output.item()))
978
+ return TestResult('pattern_recognition.allones', passed, 256, failures)
979
+
980
+ def test_hamming_distance(self) -> TestResult:
981
+ passed = 0
982
+ if self.reg.has('pattern_recognition.hammingdistance8bit.xor.weight'):
983
+ self.reg.get('pattern_recognition.hammingdistance8bit.xor.weight')
984
+ passed += 1
985
+ if self.reg.has('pattern_recognition.hammingdistance8bit.popcount.weight'):
986
+ self.reg.get('pattern_recognition.hammingdistance8bit.popcount.weight')
987
+ passed += 1
988
+ return TestResult('pattern_recognition.hammingdistance8bit', passed, 2, [])
989
+
990
+ def test_one_hot_detector(self) -> TestResult:
991
+ failures = []
992
+ passed = 0
993
+ w_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.weight')
994
+ b_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.bias')
995
+ w_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.weight')
996
+ b_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.bias')
997
+ w_and = self.reg.get('pattern_recognition.onehotdetector.and.weight')
998
+ b_and = self.reg.get('pattern_recognition.onehotdetector.and.bias')
999
+ for val in range(256):
1000
+ bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
1001
+ device=self.device, dtype=torch.float32)
1002
+ atleast1 = heaviside((bits * w_atleast1).sum() + b_atleast1).item()
1003
+ atmost1 = heaviside((bits * w_atmost1).sum() + b_atmost1).item()
1004
+ hidden = torch.tensor([atleast1, atmost1], device=self.device)
1005
+ output = heaviside((hidden * w_and).sum() + b_and).item()
1006
+ popcount = bin(val).count('1')
1007
+ expected = float(popcount == 1)
1008
+ if output == expected:
1009
+ passed += 1
1010
+ else:
1011
+ failures.append((val, expected, output))
1012
+ return TestResult('pattern_recognition.onehotdetector', passed, 256, failures)
1013
+
1014
+ def test_alternating_pattern(self) -> TestResult:
1015
+ passed = 0
1016
+ if self.reg.has('pattern_recognition.alternating8bit.pattern1.weight'):
1017
+ self.reg.get('pattern_recognition.alternating8bit.pattern1.weight')
1018
+ passed += 1
1019
+ if self.reg.has('pattern_recognition.alternating8bit.pattern2.weight'):
1020
+ self.reg.get('pattern_recognition.alternating8bit.pattern2.weight')
1021
+ passed += 1
1022
+ return TestResult('pattern_recognition.alternating8bit', passed, 2, [])
1023
+
1024
+ def test_symmetry_detector(self) -> TestResult:
1025
+ passed = 0
1026
+ for i in range(4):
1027
+ if self.reg.has(f'pattern_recognition.symmetry8bit.xnor{i}.weight'):
1028
+ self.reg.get(f'pattern_recognition.symmetry8bit.xnor{i}.weight')
1029
+ passed += 1
1030
+ if self.reg.has('pattern_recognition.symmetry8bit.and.weight'):
1031
+ self.reg.get('pattern_recognition.symmetry8bit.and.weight')
1032
+ self.reg.get('pattern_recognition.symmetry8bit.and.bias')
1033
+ passed += 2
1034
+ return TestResult('pattern_recognition.symmetry8bit', passed, 6, [])
1035
+
1036
+ def test_leading_ones(self) -> TestResult:
1037
+ if self.reg.has('pattern_recognition.leadingones.weight'):
1038
+ self.reg.get('pattern_recognition.leadingones.weight')
1039
+ return TestResult('pattern_recognition.leadingones', 1, 1, [])
1040
+ return TestResult('pattern_recognition.leadingones', 0, 1, [])
1041
+
1042
+ def test_run_length(self) -> TestResult:
1043
+ if self.reg.has('pattern_recognition.runlength.weight'):
1044
+ self.reg.get('pattern_recognition.runlength.weight')
1045
+ return TestResult('pattern_recognition.runlength', 1, 1, [])
1046
+ return TestResult('pattern_recognition.runlength', 0, 1, [])
1047
+
1048
+ def test_trailing_ones(self) -> TestResult:
1049
+ if self.reg.has('pattern_recognition.trailingones.weight'):
1050
+ self.reg.get('pattern_recognition.trailingones.weight')
1051
+ return TestResult('pattern_recognition.trailingones', 1, 1, [])
1052
+ return TestResult('pattern_recognition.trailingones', 0, 1, [])
1053
+
1054
+ # =========================================================================
1055
+ # ARITHMETIC - ADDITIONAL CIRCUITS
1056
+ # =========================================================================
1057
+
1058
+ def test_arithmetic_adc(self) -> TestResult:
1059
+ passed = 0
1060
+ for fa in range(8):
1061
+ for comp in ['and1', 'and2', 'or_carry']:
1062
+ if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'):
1063
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight')
1064
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias')
1065
+ passed += 2
1066
+ for xor in ['xor1', 'xor2']:
1067
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1068
+ if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'):
1069
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight')
1070
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias')
1071
+ passed += 2
1072
+ return TestResult('arithmetic.adc8bit', passed, 144, [])
1073
+
1074
+ def test_arithmetic_cmp(self) -> TestResult:
1075
+ passed = 0
1076
+ for fa in range(8):
1077
+ if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.and1.weight'):
1078
+ self.reg.get(f'arithmetic.cmp8bit.fa{fa}.and1.weight')
1079
+ passed += 1
1080
+ for bit in range(8):
1081
+ if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'):
1082
+ self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight')
1083
+ self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias')
1084
+ passed += 2
1085
+ for flag in ['carry', 'negative', 'zero', 'zero_or']:
1086
+ if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'):
1087
+ self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight')
1088
+ passed += 1
1089
+ return TestResult('arithmetic.cmp8bit', passed, 28, [])
1090
+
1091
+ def test_arithmetic_equality(self) -> TestResult:
1092
+ passed = 0
1093
+ for i in range(8):
1094
+ for layer in ['layer1.and', 'layer1.nor', 'layer2']:
1095
+ if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'):
1096
+ self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight')
1097
+ self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias')
1098
+ passed += 2
1099
+ return TestResult('arithmetic.equality8bit', passed, 48, [])
1100
+
1101
+ def test_arithmetic_minmax(self) -> TestResult:
1102
+ passed = 0
1103
+ for name in ['max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff']:
1104
+ if self.reg.has(f'arithmetic.{name}'):
1105
+ self.reg.get(f'arithmetic.{name}')
1106
+ passed += 1
1107
+ return TestResult('arithmetic.minmax', passed, 3, [])
1108
+
1109
+ def test_arithmetic_negate(self) -> TestResult:
1110
+ passed = 0
1111
+ for bit in range(8):
1112
+ if self.reg.has(f'arithmetic.neg8bit.not{bit}.weight'):
1113
+ self.reg.get(f'arithmetic.neg8bit.not{bit}.weight')
1114
+ self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
1115
+ passed += 2
1116
+ for bit in range(1, 8):
1117
+ if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight'):
1118
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight')
1119
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.weight')
1120
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.weight')
1121
+ passed += 3
1122
+ for bit in range(1, 8):
1123
+ if self.reg.has(f'arithmetic.neg8bit.and{bit}.weight'):
1124
+ self.reg.get(f'arithmetic.neg8bit.and{bit}.weight')
1125
+ self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
1126
+ passed += 2
1127
+ if self.reg.has('arithmetic.neg8bit.sum0.weight'):
1128
+ self.reg.get('arithmetic.neg8bit.sum0.weight')
1129
+ self.reg.get('arithmetic.neg8bit.carry0.weight')
1130
+ passed += 2
1131
+ for bit in range(8):
1132
+ if self.reg.has(f'arithmetic.neg8bit.not{bit}.bias'):
1133
+ self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
1134
+ passed += 1
1135
+ for bit in range(1, 8):
1136
+ if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias'):
1137
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias')
1138
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.bias')
1139
+ self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.bias')
1140
+ passed += 3
1141
+ if self.reg.has(f'arithmetic.neg8bit.and{bit}.bias'):
1142
+ self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
1143
+ passed += 1
1144
+ if self.reg.has('arithmetic.neg8bit.sum0.bias'):
1145
+ self.reg.get('arithmetic.neg8bit.sum0.bias')
1146
+ self.reg.get('arithmetic.neg8bit.carry0.bias')
1147
+ passed += 2
1148
+ return TestResult('arithmetic.neg8bit', passed, passed, [])
1149
+
1150
+ def test_arithmetic_asr(self) -> TestResult:
1151
+ passed = 0
1152
+ for bit in range(8):
1153
+ if self.reg.has(f'arithmetic.asr8bit.bit{bit}.weight'):
1154
+ self.reg.get(f'arithmetic.asr8bit.bit{bit}.weight')
1155
+ self.reg.get(f'arithmetic.asr8bit.bit{bit}.bias')
1156
+ self.reg.get(f'arithmetic.asr8bit.bit{bit}.src')
1157
+ passed += 3
1158
+ if self.reg.has('arithmetic.asr8bit.shiftout.weight'):
1159
+ self.reg.get('arithmetic.asr8bit.shiftout.weight')
1160
+ self.reg.get('arithmetic.asr8bit.shiftout.bias')
1161
+ passed += 2
1162
+ return TestResult('arithmetic.asr8bit', passed, 26, [])
1163
+
1164
+ def test_arithmetic_incrementer(self) -> TestResult:
1165
+ passed = 0
1166
+ if self.reg.has('arithmetic.incrementer8bit.adder'):
1167
+ self.reg.get('arithmetic.incrementer8bit.adder')
1168
+ passed += 1
1169
+ if self.reg.has('arithmetic.incrementer8bit.one'):
1170
+ self.reg.get('arithmetic.incrementer8bit.one')
1171
+ passed += 1
1172
+ return TestResult('arithmetic.incrementer8bit', passed, 2, [])
1173
+
1174
+ def test_arithmetic_decrementer(self) -> TestResult:
1175
+ passed = 0
1176
+ if self.reg.has('arithmetic.decrementer8bit.adder'):
1177
+ self.reg.get('arithmetic.decrementer8bit.adder')
1178
+ passed += 1
1179
+ if self.reg.has('arithmetic.decrementer8bit.neg_one'):
1180
+ self.reg.get('arithmetic.decrementer8bit.neg_one')
1181
+ passed += 1
1182
+ return TestResult('arithmetic.decrementer8bit', passed, 2, [])
1183
+
1184
+ def test_arithmetic_adc_internals(self) -> TestResult:
1185
+ passed = 0
1186
+ for fa in range(8):
1187
+ for comp in ['and1', 'and2', 'or_carry']:
1188
+ if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'):
1189
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight')
1190
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias')
1191
+ passed += 2
1192
+ for xor in ['xor1', 'xor2']:
1193
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1194
+ if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'):
1195
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight')
1196
+ self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias')
1197
+ passed += 2
1198
+ return TestResult('arithmetic.adc8bit.internals', passed, passed, [])
1199
+
1200
+ def test_arithmetic_cmp_internals(self) -> TestResult:
1201
+ passed = 0
1202
+ for fa in range(8):
1203
+ for comp in ['and1', 'and2', 'or_carry']:
1204
+ if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight'):
1205
+ self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight')
1206
+ self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.bias')
1207
+ passed += 2
1208
+ for xor in ['xor1', 'xor2']:
1209
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1210
+ if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight'):
1211
+ self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight')
1212
+ self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.bias')
1213
+ passed += 2
1214
+ for bit in range(8):
1215
+ if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'):
1216
+ self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight')
1217
+ self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias')
1218
+ passed += 2
1219
+ for flag in ['carry', 'negative', 'zero', 'zero_or']:
1220
+ if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'):
1221
+ self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight')
1222
+ self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.bias')
1223
+ passed += 2
1224
+ return TestResult('arithmetic.cmp8bit.internals', passed, passed, [])
1225
+
1226
+ def test_arithmetic_sbc_internals(self) -> TestResult:
1227
+ passed = 0
1228
+ for fa in range(8):
1229
+ for comp in ['and1', 'and2', 'or_carry']:
1230
+ if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight'):
1231
+ self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight')
1232
+ self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.bias')
1233
+ passed += 2
1234
+ for xor in ['xor1', 'xor2']:
1235
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1236
+ if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight'):
1237
+ self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight')
1238
+ self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.bias')
1239
+ passed += 2
1240
+ for bit in range(8):
1241
+ if self.reg.has(f'arithmetic.sbc8bit.notb{bit}.weight'):
1242
+ self.reg.get(f'arithmetic.sbc8bit.notb{bit}.weight')
1243
+ self.reg.get(f'arithmetic.sbc8bit.notb{bit}.bias')
1244
+ passed += 2
1245
+ return TestResult('arithmetic.sbc8bit.internals', passed, passed, [])
1246
+
1247
+ def test_arithmetic_sub_internals(self) -> TestResult:
1248
+ passed = 0
1249
+ if self.reg.has('arithmetic.sub8bit.carry_in.weight'):
1250
+ self.reg.get('arithmetic.sub8bit.carry_in.weight')
1251
+ self.reg.get('arithmetic.sub8bit.carry_in.bias')
1252
+ passed += 2
1253
+ for fa in range(8):
1254
+ for comp in ['and1', 'and2', 'or_carry']:
1255
+ if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{comp}.weight'):
1256
+ self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.weight')
1257
+ self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.bias')
1258
+ passed += 2
1259
+ for xor in ['xor1', 'xor2']:
1260
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1261
+ if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight'):
1262
+ self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight')
1263
+ self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.bias')
1264
+ passed += 2
1265
+ for bit in range(8):
1266
+ if self.reg.has(f'arithmetic.sub8bit.notb{bit}.weight'):
1267
+ self.reg.get(f'arithmetic.sub8bit.notb{bit}.weight')
1268
+ self.reg.get(f'arithmetic.sub8bit.notb{bit}.bias')
1269
+ passed += 2
1270
+ return TestResult('arithmetic.sub8bit.internals', passed, passed, [])
1271
+
1272
+ def test_arithmetic_equality_internals(self) -> TestResult:
1273
+ passed = 0
1274
+ for i in range(8):
1275
+ for layer in ['layer1.and', 'layer1.nor', 'layer2']:
1276
+ if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'):
1277
+ self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight')
1278
+ self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias')
1279
+ passed += 2
1280
+ if self.reg.has('arithmetic.equality8bit.and.weight'):
1281
+ self.reg.get('arithmetic.equality8bit.and.weight')
1282
+ self.reg.get('arithmetic.equality8bit.and.bias')
1283
+ passed += 2
1284
+ return TestResult('arithmetic.equality8bit.internals', passed, passed, [])
1285
+
1286
+ def test_arithmetic_rol_ror(self) -> TestResult:
1287
+ passed = 0
1288
+ for bit in range(8):
1289
+ if self.reg.has(f'arithmetic.rol8bit.bit{bit}.weight'):
1290
+ self.reg.get(f'arithmetic.rol8bit.bit{bit}.weight')
1291
+ self.reg.get(f'arithmetic.rol8bit.bit{bit}.bias')
1292
+ passed += 2
1293
+ if self.reg.has('arithmetic.rol8bit.cout.weight'):
1294
+ self.reg.get('arithmetic.rol8bit.cout.weight')
1295
+ self.reg.get('arithmetic.rol8bit.cout.bias')
1296
+ passed += 2
1297
+ for bit in range(8):
1298
+ if self.reg.has(f'arithmetic.ror8bit.bit{bit}.weight'):
1299
+ self.reg.get(f'arithmetic.ror8bit.bit{bit}.weight')
1300
+ self.reg.get(f'arithmetic.ror8bit.bit{bit}.bias')
1301
+ passed += 2
1302
+ if self.reg.has('arithmetic.ror8bit.cout.weight'):
1303
+ self.reg.get('arithmetic.ror8bit.cout.weight')
1304
+ self.reg.get('arithmetic.ror8bit.cout.bias')
1305
+ passed += 2
1306
+ return TestResult('arithmetic.rol_ror', passed, passed, [])
1307
+
1308
+ def test_arithmetic_div_stages(self) -> TestResult:
1309
+ passed = 0
1310
+ for stage in range(8):
1311
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.cmp.weight'):
1312
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.weight')
1313
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.bias')
1314
+ passed += 2
1315
+ for bit in range(8):
1316
+ for comp in ['and0', 'and1', 'not_sel', 'or']:
1317
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight'):
1318
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight')
1319
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.bias')
1320
+ passed += 2
1321
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.or_dividend.weight'):
1322
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.weight')
1323
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.bias')
1324
+ passed += 2
1325
+ for bit in range(8):
1326
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight'):
1327
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight')
1328
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.bias')
1329
+ passed += 2
1330
+ for fa in range(8):
1331
+ for comp in ['and1', 'and2', 'or_carry']:
1332
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight'):
1333
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight')
1334
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.bias')
1335
+ passed += 2
1336
+ for xor in ['xor1', 'xor2']:
1337
+ for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1338
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight'):
1339
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight')
1340
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.bias')
1341
+ passed += 2
1342
+ for bit in range(8):
1343
+ if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight'):
1344
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight')
1345
+ self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.bias')
1346
+ passed += 2
1347
+ return TestResult('arithmetic.div8bit.stages', passed, passed, [])
1348
+
1349
+ def test_arithmetic_div_outputs(self) -> TestResult:
1350
+ passed = 0
1351
+ for bit in range(8):
1352
+ if self.reg.has(f'arithmetic.div8bit.quotient{bit}.weight'):
1353
+ self.reg.get(f'arithmetic.div8bit.quotient{bit}.weight')
1354
+ self.reg.get(f'arithmetic.div8bit.quotient{bit}.bias')
1355
+ passed += 2
1356
+ if self.reg.has(f'arithmetic.div8bit.remainder{bit}.weight'):
1357
+ self.reg.get(f'arithmetic.div8bit.remainder{bit}.weight')
1358
+ self.reg.get(f'arithmetic.div8bit.remainder{bit}.bias')
1359
+ passed += 2
1360
+ return TestResult('arithmetic.div8bit.outputs', passed, passed, [])
1361
+
1362
+ def test_arithmetic_multiplier_internals(self) -> TestResult:
1363
+ passed = 0
1364
+ for row in range(8):
1365
+ for col in range(8):
1366
+ if self.reg.has(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight'):
1367
+ self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight')
1368
+ self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias')
1369
+ passed += 2
1370
+ for stage in range(7):
1371
+ for bit in range(16):
1372
+ for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1373
+ for suffix in ['.weight', '.bias']:
1374
+ if self.reg.has(f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{comp}{suffix[1:]}'):
1375
+ self.reg.get(f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{comp}{suffix[1:]}')
1376
+ passed += 1
1377
+ return TestResult('arithmetic.multiplier8x8.internals', passed, passed, [])
1378
+
1379
+ def test_arithmetic_ripple_internals(self) -> TestResult:
1380
+ passed = 0
1381
+ for fa in range(8):
1382
+ for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1383
+ if self.reg.has(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight'):
1384
+ self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight')
1385
+ self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.bias')
1386
+ passed += 2
1387
+ for fa in range(4):
1388
+ for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1389
+ if self.reg.has(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight'):
1390
+ self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight')
1391
+ self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.bias')
1392
+ passed += 2
1393
+ for fa in range(2):
1394
+ for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1395
+ if self.reg.has(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight'):
1396
+ self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight')
1397
+ self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.bias')
1398
+ passed += 2
1399
+ return TestResult('arithmetic.ripplecarry.internals', passed, passed, [])
1400
+
1401
+ def test_arithmetic_equality_final(self) -> TestResult:
1402
+ passed = 0
1403
+ if self.reg.has('arithmetic.equality8bit.final_and.weight'):
1404
+ self.reg.get('arithmetic.equality8bit.final_and.weight')
1405
+ self.reg.get('arithmetic.equality8bit.final_and.bias')
1406
+ passed += 2
1407
+ return TestResult('arithmetic.equality8bit.final', passed, passed, [])
1408
+
1409
+ def test_arithmetic_small_multipliers(self) -> TestResult:
1410
+ passed = 0
1411
+ for a in range(2):
1412
+ for b in range(2):
1413
+ if self.reg.has(f'arithmetic.multiplier2x2.and{a}{b}.weight'):
1414
+ self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.weight')
1415
+ self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.bias')
1416
+ passed += 2
1417
+ for comp in ['ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry', 'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or']:
1418
+ if self.reg.has(f'arithmetic.multiplier2x2.{comp}.weight'):
1419
+ self.reg.get(f'arithmetic.multiplier2x2.{comp}.weight')
1420
+ self.reg.get(f'arithmetic.multiplier2x2.{comp}.bias')
1421
+ passed += 2
1422
+ for a in range(4):
1423
+ for b in range(4):
1424
+ if self.reg.has(f'arithmetic.multiplier4x4.and{a}{b}.weight'):
1425
+ self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.weight')
1426
+ self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.bias')
1427
+ passed += 2
1428
+ for stage in range(3):
1429
+ for bit in range(8):
1430
+ for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1431
+ if self.reg.has(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight'):
1432
+ self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight')
1433
+ self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.bias')
1434
+ passed += 2
1435
+ return TestResult('arithmetic.small_multipliers', passed, passed, [])
1436
+
1437
+ # =========================================================================
1438
+ # DIVISION
1439
+ # =========================================================================
1440
+
1441
+ def test_division_8bit(self) -> TestResult:
1442
+ if not self.reg.has('arithmetic.div8bit.quotient0.weight'):
1443
+ return TestResult('arithmetic.div8bit', 0, 0, [('NOT FOUND', '', '')])
1444
+ failures = []
1445
+ passed = 0
1446
+ total = 0
1447
+ test_cases = []
1448
+ test_cases.extend([(0, d) for d in [1, 2, 127, 255]])
1449
+ test_cases.extend([(255, d) for d in [1, 2, 15, 16, 17, 127, 255]])
1450
+ test_cases.extend([(d, 1) for d in range(0, 256, 16)])
1451
+ test_cases.extend([(d, 255) for d in range(0, 256, 16)])
1452
+ for dividend in [1, 2, 4, 8, 16, 32, 64, 128]:
1453
+ for divisor in [1, 2, 4, 8, 16, 32, 64, 128]:
1454
+ test_cases.append((dividend, divisor))
1455
+ for dividend in range(0, 256, 8):
1456
+ for divisor in range(1, 256, 8):
1457
+ test_cases.append((dividend, divisor))
1458
+ test_cases = list(set(test_cases))
1459
+ for dividend, divisor in test_cases:
1460
+ expected_q = dividend // divisor
1461
+ expected_r = dividend % divisor
1462
+ q, r = self._eval_division(dividend, divisor)
1463
+ if q == expected_q and r == expected_r:
1464
+ passed += 1
1465
+ else:
1466
+ if len(failures) < 100:
1467
+ failures.append(((dividend, divisor), (expected_q, expected_r), (q, r)))
1468
+ total += 1
1469
+ return TestResult('arithmetic.div8bit', passed, total, failures)
1470
+
1471
+ def _eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
1472
+ return self.routing_eval.eval_division(dividend, divisor)
1473
+
1474
+
1475
+ class ArithmeticEvaluator:
1476
+ """Main evaluator for arithmetic-only circuits."""
1477
+
1478
+ def __init__(self, model_path: str, device: str = 'cuda'):
1479
+ print(f"Loading model from {model_path}...")
1480
+ self.registry = TensorRegistry(model_path)
1481
+ print(f" Found {len(self.registry.tensors)} tensors")
1482
+ print(f" Categories: {self.registry.categories}")
1483
+ self.evaluator = CircuitEvaluator(self.registry, device)
1484
+ self.results: List[TestResult] = []
1485
+
1486
+ def run_all(self, verbose: bool = True) -> float:
1487
+ start = time.time()
1488
+
1489
+ # Boolean gates
1490
+ if verbose:
1491
+ print("\n=== BOOLEAN GATES ===")
1492
+ self._run_test(self.evaluator.test_boolean_and, verbose)
1493
+ self._run_test(self.evaluator.test_boolean_or, verbose)
1494
+ self._run_test(self.evaluator.test_boolean_nand, verbose)
1495
+ self._run_test(self.evaluator.test_boolean_nor, verbose)
1496
+ self._run_test(self.evaluator.test_boolean_not, verbose)
1497
+ self._run_test(self.evaluator.test_boolean_xor, verbose)
1498
+ self._run_test(self.evaluator.test_boolean_xnor, verbose)
1499
+ self._run_test(self.evaluator.test_boolean_implies, verbose)
1500
+ self._run_test(self.evaluator.test_boolean_biimplies, verbose)
1501
+
1502
+ # Arithmetic - adders
1503
+ if verbose:
1504
+ print("\n=== ARITHMETIC - ADDERS ===")
1505
+ self._run_test(self.evaluator.test_half_adder, verbose)
1506
+ self._run_test(self.evaluator.test_full_adder, verbose)
1507
+ self._run_test(self.evaluator.test_ripple_carry_2bit, verbose)
1508
+ self._run_test(self.evaluator.test_ripple_carry_4bit, verbose)
1509
+ self._run_test(self.evaluator.test_ripple_carry_8bit, verbose)
1510
+
1511
+ # Arithmetic - comparators
1512
+ if verbose:
1513
+ print("\n=== ARITHMETIC - COMPARATORS ===")
1514
+ self._run_test(self.evaluator.test_greaterthan8bit, verbose)
1515
+ self._run_test(self.evaluator.test_lessthan8bit, verbose)
1516
+ self._run_test(self.evaluator.test_greaterorequal8bit, verbose)
1517
+ self._run_test(self.evaluator.test_lessorequal8bit, verbose)
1518
+
1519
+ # Arithmetic - multiplier
1520
+ if verbose:
1521
+ print("\n=== ARITHMETIC - MULTIPLIER ===")
1522
+ self._run_test(self.evaluator.test_multiplier_8x8, verbose)
1523
+
1524
+ # Arithmetic - additional
1525
+ if verbose:
1526
+ print("\n=== ARITHMETIC - ADDITIONAL ===")
1527
+ self._run_test(self.evaluator.test_arithmetic_adc, verbose)
1528
+ self._run_test(self.evaluator.test_arithmetic_cmp, verbose)
1529
+ self._run_test(self.evaluator.test_arithmetic_equality, verbose)
1530
+ self._run_test(self.evaluator.test_arithmetic_minmax, verbose)
1531
+ self._run_test(self.evaluator.test_arithmetic_negate, verbose)
1532
+ self._run_test(self.evaluator.test_arithmetic_asr, verbose)
1533
+ self._run_test(self.evaluator.test_arithmetic_incrementer, verbose)
1534
+ self._run_test(self.evaluator.test_arithmetic_decrementer, verbose)
1535
+ self._run_test(self.evaluator.test_arithmetic_adc_internals, verbose)
1536
+ self._run_test(self.evaluator.test_arithmetic_cmp_internals, verbose)
1537
+ self._run_test(self.evaluator.test_arithmetic_sbc_internals, verbose)
1538
+ self._run_test(self.evaluator.test_arithmetic_sub_internals, verbose)
1539
+ self._run_test(self.evaluator.test_arithmetic_equality_internals, verbose)
1540
+ self._run_test(self.evaluator.test_arithmetic_rol_ror, verbose)
1541
+ self._run_test(self.evaluator.test_arithmetic_div_stages, verbose)
1542
+ self._run_test(self.evaluator.test_arithmetic_div_outputs, verbose)
1543
+ self._run_test(self.evaluator.test_arithmetic_multiplier_internals, verbose)
1544
+ self._run_test(self.evaluator.test_arithmetic_ripple_internals, verbose)
1545
+ self._run_test(self.evaluator.test_arithmetic_equality_final, verbose)
1546
+ self._run_test(self.evaluator.test_arithmetic_small_multipliers, verbose)
1547
+
1548
+ # Threshold gates
1549
+ if verbose:
1550
+ print("\n=== THRESHOLD GATES ===")
1551
+ for result in self.evaluator.test_threshold_gates():
1552
+ self.results.append(result)
1553
+ if verbose:
1554
+ self._print_result(result)
1555
+ self._run_test(self.evaluator.test_threshold_atleastk_4, verbose)
1556
+ self._run_test(self.evaluator.test_threshold_atmostk_4, verbose)
1557
+ self._run_test(self.evaluator.test_threshold_exactlyk_4, verbose)
1558
+ self._run_test(self.evaluator.test_threshold_majority, verbose)
1559
+ self._run_test(self.evaluator.test_threshold_minority, verbose)
1560
+
1561
+ # Modular arithmetic
1562
+ if verbose:
1563
+ print("\n=== MODULAR ARITHMETIC ===")
1564
+ for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]:
1565
+ if self.registry.has(f'modular.mod{mod}.weight') or \
1566
+ self.registry.has(f'modular.mod{mod}.layer1.geq0.weight'):
1567
+ self._run_test(lambda m=mod: self.evaluator.test_modular(m), verbose)
1568
+
1569
+ # Combinational
1570
+ if verbose:
1571
+ print("\n=== COMBINATIONAL ===")
1572
+ self._run_test(self.evaluator.test_decoder_3to8, verbose)
1573
+ self._run_test(self.evaluator.test_encoder_8to3, verbose)
1574
+ self._run_test(self.evaluator.test_mux_2to1, verbose)
1575
+ self._run_test(self.evaluator.test_demux_1to2, verbose)
1576
+ self._run_test(self.evaluator.test_barrel_shifter, verbose)
1577
+ self._run_test(self.evaluator.test_mux_4to1, verbose)
1578
+ self._run_test(self.evaluator.test_mux_8to1, verbose)
1579
+ self._run_test(self.evaluator.test_demux_1to4, verbose)
1580
+ self._run_test(self.evaluator.test_demux_1to8, verbose)
1581
+ self._run_test(self.evaluator.test_priority_encoder, verbose)
1582
+
1583
+ # Pattern recognition
1584
+ if verbose:
1585
+ print("\n=== PATTERN RECOGNITION ===")
1586
+ self._run_test(self.evaluator.test_popcount, verbose)
1587
+ self._run_test(self.evaluator.test_allzeros, verbose)
1588
+ self._run_test(self.evaluator.test_allones, verbose)
1589
+ self._run_test(self.evaluator.test_hamming_distance, verbose)
1590
+ self._run_test(self.evaluator.test_one_hot_detector, verbose)
1591
+ self._run_test(self.evaluator.test_alternating_pattern, verbose)
1592
+ self._run_test(self.evaluator.test_symmetry_detector, verbose)
1593
+ self._run_test(self.evaluator.test_leading_ones, verbose)
1594
+ self._run_test(self.evaluator.test_run_length, verbose)
1595
+ self._run_test(self.evaluator.test_trailing_ones, verbose)
1596
+
1597
+ # Division
1598
+ if verbose:
1599
+ print("\n=== DIVISION ===")
1600
+ self._run_test(self.evaluator.test_division_8bit, verbose)
1601
+
1602
+ elapsed = time.time() - start
1603
+
1604
+ # Summary
1605
+ total_passed = sum(r.passed for r in self.results)
1606
+ total_tests = sum(r.total for r in self.results)
1607
+
1608
+ print("\n" + "=" * 60)
1609
+ print("SUMMARY")
1610
+ print("=" * 60)
1611
+ print(f"Total: {total_passed}/{total_tests} ({100*total_passed/total_tests:.4f}%)")
1612
+ print(f"Time: {elapsed:.2f}s")
1613
+
1614
+ failed = [r for r in self.results if not r.success]
1615
+ if failed:
1616
+ print(f"\nFailed circuits ({len(failed)}):")
1617
+ for r in failed:
1618
+ print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)")
1619
+ if r.failures:
1620
+ print(f" First failure: input={r.failures[0][0]}, expected={r.failures[0][1]}, got={r.failures[0][2]}")
1621
+ else:
1622
+ print("\nAll circuits passed!")
1623
+
1624
+ print("\n" + "=" * 60)
1625
+ print(self.registry.coverage_report())
1626
+
1627
+ return total_passed / total_tests if total_tests > 0 else 0.0
1628
+
1629
+ def _run_test(self, test_fn: Callable, verbose: bool):
1630
+ try:
1631
+ result = test_fn()
1632
+ self.results.append(result)
1633
+ if verbose:
1634
+ self._print_result(result)
1635
+ except Exception as e:
1636
+ print(f" ERROR: {e}")
1637
+
1638
+ def _print_result(self, result: TestResult):
1639
+ status = "PASS" if result.success else "FAIL"
1640
+ print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]")
1641
+ if not result.success and result.failures:
1642
+ print(f" First failure: {result.failures[0]}")
1643
+
1644
+
1645
+ def main():
1646
+ import argparse
1647
+ parser = argparse.ArgumentParser(description='Arithmetic circuit evaluator for threshold-calculus')
1648
+ parser.add_argument('--model', type=str, default='./arithmetic.safetensors',
1649
+ help='Path to safetensors model')
1650
+ parser.add_argument('--device', type=str, default='cuda',
1651
+ help='Device (cuda or cpu)')
1652
+ parser.add_argument('--quiet', action='store_true',
1653
+ help='Suppress verbose output')
1654
+ args = parser.parse_args()
1655
+
1656
+ evaluator = ArithmeticEvaluator(args.model, args.device)
1657
+ fitness = evaluator.run_all(verbose=not args.quiet)
1658
+
1659
+ print(f"\nFitness: {fitness:.6f}")
1660
+ return 0 if fitness >= 0.9999 else 1
1661
+
1662
+
1663
+ if __name__ == '__main__':
1664
+ exit(main())