CharlesCNorton commited on
Commit
84bf860
·
1 Parent(s): 54bb20f

Unify eval.py and arithmetic_eval.py into single evaluator

Browse files

Merge both evaluation scripts into a unified eval.py with CLI args:
- --category/-c: run specific test categories
- --circuit: filter by circuit name
- --quick/-q: fast smoke test mode
- --verbose/-v: detailed output
- --json/-j: CI-friendly JSON output
- --coverage: tensor coverage report
- --list/-l: list available categories

Categories: boolean, threshold, clz, adders, comparators,
multiplier, divider, modular, combinational, pattern,
float16_basic, float16_arith, float16_conv

Remove redundant arithmetic_eval.py.

Files changed (2) hide show
  1. arithmetic_eval.py +0 -1664
  2. eval.py +0 -0
arithmetic_eval.py DELETED
@@ -1,1664 +0,0 @@
1
- """
2
- ARITHMETIC EVALUATOR
3
- =====================
4
- Introspection-based exhaustive testing for arithmetic circuits in threshold-calculus.
5
-
6
- Stripped-down version of comprehensive_eval.py for arithmetic-only weights.
7
- Removes: ALU, Control, Manifest, Error Detection
8
- Keeps: Boolean, Arithmetic, Threshold, Modular, Combinational, Pattern Recognition
9
- """
10
-
11
- import torch
12
- from safetensors import safe_open
13
- from typing import Dict, List, Tuple, Optional, Callable
14
- from dataclasses import dataclass
15
- from collections import defaultdict
16
- import json
17
- import os
18
- import re
19
- import time
20
-
21
-
22
- @dataclass
23
- class TestResult:
24
- """Result of testing a single circuit."""
25
- circuit_name: str
26
- passed: int
27
- total: int
28
- failures: List[Tuple]
29
-
30
- @property
31
- def success(self) -> bool:
32
- return self.passed == self.total
33
-
34
- @property
35
- def rate(self) -> float:
36
- return self.passed / self.total if self.total > 0 else 0.0
37
-
38
-
39
- def heaviside(x: torch.Tensor) -> torch.Tensor:
40
- """Threshold activation: 1 if x >= 0, else 0."""
41
- return (x >= 0).float()
42
-
43
-
44
- class TensorRegistry:
45
- """Discovers and organizes tensors from a safetensors file."""
46
-
47
- def __init__(self, path: str):
48
- self.path = path
49
- self.tensors: Dict[str, torch.Tensor] = {}
50
- self.circuits: Dict[str, List[str]] = defaultdict(list)
51
- self.accessed: set = set()
52
- self._load()
53
- self._organize()
54
-
55
- def _load(self):
56
- with safe_open(self.path, framework='pt') as f:
57
- for name in f.keys():
58
- self.tensors[name] = f.get_tensor(name).float()
59
-
60
- def _organize(self):
61
- for name in self.tensors:
62
- circuit = self._extract_circuit(name)
63
- self.circuits[circuit].append(name)
64
-
65
- def _extract_circuit(self, tensor_name: str) -> str:
66
- if tensor_name.endswith('.weight'):
67
- return tensor_name[:-7]
68
- elif tensor_name.endswith('.bias'):
69
- return tensor_name[:-5]
70
- return tensor_name
71
-
72
- def get(self, name: str) -> torch.Tensor:
73
- self.accessed.add(name)
74
- return self.tensors[name]
75
-
76
- def has(self, name: str) -> bool:
77
- return name in self.tensors
78
-
79
- def get_category(self, prefix: str) -> Dict[str, List[str]]:
80
- return {k: v for k, v in self.circuits.items() if k.startswith(prefix)}
81
-
82
- @property
83
- def categories(self) -> List[str]:
84
- cats = set()
85
- for name in self.tensors:
86
- cats.add(name.split('.')[0])
87
- return sorted(cats)
88
-
89
- @property
90
- def untested(self) -> List[str]:
91
- return sorted(set(self.tensors.keys()) - self.accessed)
92
-
93
- @property
94
- def coverage(self) -> float:
95
- if not self.tensors:
96
- return 0.0
97
- return len(self.accessed) / len(self.tensors)
98
-
99
- def coverage_report(self) -> str:
100
- lines = []
101
- lines.append(f"TENSOR COVERAGE: {len(self.accessed)}/{len(self.tensors)} ({100*self.coverage:.2f}%)")
102
-
103
- untested = self.untested
104
- if untested:
105
- by_category: Dict[str, List[str]] = defaultdict(list)
106
- for name in untested:
107
- cat = name.split('.')[0]
108
- by_category[cat].append(name)
109
-
110
- lines.append(f"\nUNTESTED TENSORS ({len(untested)}):")
111
- for cat in sorted(by_category.keys()):
112
- tensors = by_category[cat]
113
- lines.append(f"\n {cat}/ ({len(tensors)} tensors):")
114
- for t in tensors[:10]:
115
- lines.append(f" - {t}")
116
- if len(tensors) > 10:
117
- lines.append(f" ... and {len(tensors) - 10} more")
118
- else:
119
- lines.append("\nAll tensors tested!")
120
-
121
- return '\n'.join(lines)
122
-
123
-
124
- class RoutingEvaluator:
125
- """Evaluates circuits using routing information."""
126
-
127
- def __init__(self, registry: TensorRegistry, routing_path: str, device: str = 'cpu'):
128
- self.reg = registry
129
- self.device = device
130
- self.routing = self._load_routing(routing_path)
131
-
132
- def _load_routing(self, path: str) -> dict:
133
- if os.path.exists(path):
134
- with open(path, 'r') as f:
135
- return json.load(f)
136
- return {'circuits': {}}
137
-
138
- def has_routing(self, circuit: str) -> bool:
139
- return circuit in self.routing.get('circuits', {})
140
-
141
- def eval_gate(self, gate_path: str, inputs: torch.Tensor) -> torch.Tensor:
142
- w = self.reg.get(f'{gate_path}.weight')
143
- b = self.reg.get(f'{gate_path}.bias')
144
- return heaviside((inputs * w).sum(-1) + b)
145
-
146
- def eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
147
- if not self.has_routing('arithmetic.div8bit'):
148
- return dividend // divisor, dividend % divisor
149
-
150
- routing = self.routing['circuits']['arithmetic.div8bit']
151
- internal = routing['internal']
152
-
153
- dividend_bits = [(dividend >> i) & 1 for i in range(8)]
154
- divisor_bits = [(divisor >> i) & 1 for i in range(8)]
155
-
156
- values = {}
157
- values['#0'] = 0.0
158
- values['#1'] = 1.0
159
- for i in range(8):
160
- values[f'$dividend[{i}]'] = float(dividend_bits[i])
161
- values[f'$divisor[{i}]'] = float(divisor_bits[i])
162
-
163
- def resolve(src: str) -> float:
164
- if src in values:
165
- return values[src]
166
- if src.startswith('#'):
167
- return float(src[1:])
168
- full_path = f'arithmetic.div8bit.{src}'
169
- if full_path in values:
170
- return values[full_path]
171
- raise KeyError(f"Cannot resolve: {src}")
172
-
173
- def eval_gate_from_routing(gate_name: str, sources: list) -> float:
174
- gate_path = f'arithmetic.div8bit.{gate_name}'
175
- if not self.reg.has(f'{gate_path}.weight'):
176
- inp_vals = [resolve(s) for s in sources]
177
- return float(sum(inp_vals) >= len(inp_vals))
178
-
179
- w = self.reg.get(f'{gate_path}.weight')
180
- b = self.reg.get(f'{gate_path}.bias')
181
- inp_vals = torch.tensor([resolve(s) for s in sources], device=self.device, dtype=torch.float32)
182
- return heaviside((inp_vals * w).sum() + b).item()
183
-
184
- for stage in range(8):
185
- stage_gates = [g for g in internal.keys() if g.startswith(f'stage{stage}.')]
186
- sorted_stage_gates = self._topological_sort_subset(internal, stage_gates)
187
- for gate_name in sorted_stage_gates:
188
- sources = internal[gate_name]
189
- values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)
190
-
191
- for gate_name in ['quotient0', 'quotient1', 'quotient2', 'quotient3',
192
- 'quotient4', 'quotient5', 'quotient6', 'quotient7',
193
- 'remainder0', 'remainder1', 'remainder2', 'remainder3',
194
- 'remainder4', 'remainder5', 'remainder6', 'remainder7']:
195
- if gate_name in internal:
196
- sources = internal[gate_name]
197
- values[f'arithmetic.div8bit.{gate_name}'] = eval_gate_from_routing(gate_name, sources)
198
-
199
- quotient_bits = [int(values.get(f'arithmetic.div8bit.stage{i}.cmp', 0)) for i in range(8)]
200
- remainder_bits = [int(values.get(f'arithmetic.div8bit.stage7.mux{i}.or', 0)) for i in range(8)]
201
-
202
- quotient = sum(quotient_bits[i] << (7 - i) for i in range(8))
203
- remainder = sum(remainder_bits[i] << i for i in range(8))
204
-
205
- return quotient, remainder
206
-
207
- def _topological_sort_subset(self, internal: dict, subset: list) -> list:
208
- subset_set = set(subset)
209
- deps = {}
210
- for gate in subset:
211
- deps[gate] = set()
212
- for src in internal.get(gate, []):
213
- if src.startswith('$') or src.startswith('#'):
214
- continue
215
- if src in subset_set:
216
- deps[gate].add(src)
217
-
218
- result = []
219
- visited = set()
220
- temp = set()
221
-
222
- def visit(node):
223
- if node in temp:
224
- return
225
- if node in visited:
226
- return
227
- temp.add(node)
228
- for dep in deps.get(node, []):
229
- visit(dep)
230
- temp.remove(node)
231
- visited.add(node)
232
- result.append(node)
233
-
234
- for node in subset:
235
- visit(node)
236
-
237
- return result
238
-
239
-
240
- class CircuitEvaluator:
241
- """Evaluates individual circuit types."""
242
-
243
- def __init__(self, registry: TensorRegistry, device: str = 'cuda', routing_path: str = None):
244
- self.reg = registry
245
- self.device = device
246
- if routing_path is None:
247
- routing_path = os.path.join(os.path.dirname(__file__), 'routing.json')
248
- self.routing_eval = RoutingEvaluator(registry, routing_path, device)
249
- self._move_to_device()
250
-
251
- def _move_to_device(self):
252
- for name in self.reg.tensors:
253
- self.reg.tensors[name] = self.reg.tensors[name].to(self.device)
254
-
255
- # =========================================================================
256
- # PRIMITIVE EVALUATORS
257
- # =========================================================================
258
-
259
- def eval_single_layer(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
260
- w = self.reg.get(f'{prefix}.weight')
261
- b = self.reg.get(f'{prefix}.bias')
262
- return heaviside(inputs @ w + b)
263
-
264
- def eval_two_layer_xor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
265
- w_or = self.reg.get(f'{prefix}.layer1.or.weight')
266
- b_or = self.reg.get(f'{prefix}.layer1.or.bias')
267
- w_nand = self.reg.get(f'{prefix}.layer1.nand.weight')
268
- b_nand = self.reg.get(f'{prefix}.layer1.nand.bias')
269
-
270
- h_or = heaviside(inputs @ w_or + b_or)
271
- h_nand = heaviside(inputs @ w_nand + b_nand)
272
- hidden = torch.stack([h_or, h_nand], dim=-1)
273
-
274
- w2 = self.reg.get(f'{prefix}.layer2.weight')
275
- b2 = self.reg.get(f'{prefix}.layer2.bias')
276
- return heaviside((hidden * w2).sum(-1) + b2)
277
-
278
- def eval_two_layer_neuron(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
279
- w1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.weight')
280
- b1_n1 = self.reg.get(f'{prefix}.layer1.neuron1.bias')
281
- w1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.weight')
282
- b1_n2 = self.reg.get(f'{prefix}.layer1.neuron2.bias')
283
-
284
- h1 = heaviside(inputs @ w1_n1 + b1_n1)
285
- h2 = heaviside(inputs @ w1_n2 + b1_n2)
286
- hidden = torch.stack([h1, h2], dim=-1)
287
-
288
- w2 = self.reg.get(f'{prefix}.layer2.weight')
289
- b2 = self.reg.get(f'{prefix}.layer2.bias')
290
- return heaviside((hidden * w2).sum(-1) + b2)
291
-
292
- def eval_two_layer_xnor(self, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
293
- w_and = self.reg.get(f'{prefix}.layer1.and.weight')
294
- b_and = self.reg.get(f'{prefix}.layer1.and.bias')
295
- w_nor = self.reg.get(f'{prefix}.layer1.nor.weight')
296
- b_nor = self.reg.get(f'{prefix}.layer1.nor.bias')
297
-
298
- h_and = heaviside(inputs @ w_and + b_and)
299
- h_nor = heaviside(inputs @ w_nor + b_nor)
300
- hidden = torch.stack([h_and, h_nor], dim=-1)
301
-
302
- w2 = self.reg.get(f'{prefix}.layer2.weight')
303
- b2 = self.reg.get(f'{prefix}.layer2.bias')
304
- return heaviside((hidden * w2).sum(-1) + b2)
305
-
306
- # =========================================================================
307
- # BOOLEAN GATES
308
- # =========================================================================
309
-
310
- def test_boolean_and(self) -> TestResult:
311
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
312
- expected = torch.tensor([0,0,0,1], device=self.device, dtype=torch.float32)
313
- output = self.eval_single_layer('boolean.and', inputs)
314
- failures = []
315
- passed = 0
316
- for i in range(4):
317
- if output[i] == expected[i]:
318
- passed += 1
319
- else:
320
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
321
- return TestResult('boolean.and', passed, 4, failures)
322
-
323
- def test_boolean_or(self) -> TestResult:
324
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
325
- expected = torch.tensor([0,1,1,1], device=self.device, dtype=torch.float32)
326
- output = self.eval_single_layer('boolean.or', inputs)
327
- failures = []
328
- passed = 0
329
- for i in range(4):
330
- if output[i] == expected[i]:
331
- passed += 1
332
- else:
333
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
334
- return TestResult('boolean.or', passed, 4, failures)
335
-
336
- def test_boolean_nand(self) -> TestResult:
337
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
338
- expected = torch.tensor([1,1,1,0], device=self.device, dtype=torch.float32)
339
- output = self.eval_single_layer('boolean.nand', inputs)
340
- failures = []
341
- passed = 0
342
- for i in range(4):
343
- if output[i] == expected[i]:
344
- passed += 1
345
- else:
346
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
347
- return TestResult('boolean.nand', passed, 4, failures)
348
-
349
- def test_boolean_nor(self) -> TestResult:
350
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
351
- expected = torch.tensor([1,0,0,0], device=self.device, dtype=torch.float32)
352
- output = self.eval_single_layer('boolean.nor', inputs)
353
- failures = []
354
- passed = 0
355
- for i in range(4):
356
- if output[i] == expected[i]:
357
- passed += 1
358
- else:
359
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
360
- return TestResult('boolean.nor', passed, 4, failures)
361
-
362
- def test_boolean_not(self) -> TestResult:
363
- inputs = torch.tensor([[0],[1]], device=self.device, dtype=torch.float32)
364
- expected = torch.tensor([1,0], device=self.device, dtype=torch.float32)
365
- output = self.eval_single_layer('boolean.not', inputs)
366
- failures = []
367
- passed = 0
368
- for i in range(2):
369
- if output[i] == expected[i]:
370
- passed += 1
371
- else:
372
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
373
- return TestResult('boolean.not', passed, 2, failures)
374
-
375
- def test_boolean_xor(self) -> TestResult:
376
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
377
- expected = torch.tensor([0,1,1,0], device=self.device, dtype=torch.float32)
378
- output = self.eval_two_layer_neuron('boolean.xor', inputs)
379
- failures = []
380
- passed = 0
381
- for i in range(4):
382
- if output[i] == expected[i]:
383
- passed += 1
384
- else:
385
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
386
- return TestResult('boolean.xor', passed, 4, failures)
387
-
388
- def test_boolean_xnor(self) -> TestResult:
389
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
390
- expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
391
- output = self.eval_two_layer_neuron('boolean.xnor', inputs)
392
- failures = []
393
- passed = 0
394
- for i in range(4):
395
- if output[i] == expected[i]:
396
- passed += 1
397
- else:
398
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
399
- return TestResult('boolean.xnor', passed, 4, failures)
400
-
401
- def test_boolean_implies(self) -> TestResult:
402
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
403
- expected = torch.tensor([1,1,0,1], device=self.device, dtype=torch.float32)
404
- output = self.eval_single_layer('boolean.implies', inputs)
405
- failures = []
406
- passed = 0
407
- for i in range(4):
408
- if output[i] == expected[i]:
409
- passed += 1
410
- else:
411
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
412
- return TestResult('boolean.implies', passed, 4, failures)
413
-
414
- def test_boolean_biimplies(self) -> TestResult:
415
- inputs = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=self.device, dtype=torch.float32)
416
- expected = torch.tensor([1,0,0,1], device=self.device, dtype=torch.float32)
417
- output = self.eval_two_layer_neuron('boolean.biimplies', inputs)
418
- failures = []
419
- passed = 0
420
- for i in range(4):
421
- if output[i] == expected[i]:
422
- passed += 1
423
- else:
424
- failures.append((inputs[i].tolist(), expected[i].item(), output[i].item()))
425
- return TestResult('boolean.biimplies', passed, 4, failures)
426
-
427
- # =========================================================================
428
- # ARITHMETIC - HALF ADDER
429
- # =========================================================================
430
-
431
- def eval_half_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
432
- inputs = torch.stack([a, b], dim=-1)
433
- sum_out = self.eval_two_layer_xor(f'{prefix}.sum', inputs)
434
- carry_out = self.eval_single_layer(f'{prefix}.carry', inputs)
435
- return sum_out, carry_out
436
-
437
- def test_half_adder(self) -> TestResult:
438
- failures = []
439
- passed = 0
440
- for a in [0, 1]:
441
- for b in [0, 1]:
442
- a_t = torch.tensor([float(a)], device=self.device)
443
- b_t = torch.tensor([float(b)], device=self.device)
444
- sum_out, carry_out = self.eval_half_adder('arithmetic.halfadder', a_t, b_t)
445
- expected_sum = a ^ b
446
- expected_carry = a & b
447
- if sum_out.item() == expected_sum and carry_out.item() == expected_carry:
448
- passed += 1
449
- else:
450
- failures.append(((a, b), (expected_sum, expected_carry),
451
- (sum_out.item(), carry_out.item())))
452
- return TestResult('arithmetic.halfadder', passed, 4, failures)
453
-
454
- # =========================================================================
455
- # ARITHMETIC - FULL ADDER
456
- # =========================================================================
457
-
458
- def eval_full_adder(self, prefix: str, a: torch.Tensor, b: torch.Tensor,
459
- cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
460
- ha1_sum, ha1_carry = self.eval_half_adder(f'{prefix}.ha1', a, b)
461
- ha2_sum, ha2_carry = self.eval_half_adder(f'{prefix}.ha2', ha1_sum, cin)
462
- carry_inputs = torch.stack([ha1_carry, ha2_carry], dim=-1)
463
- carry_out = self.eval_single_layer(f'{prefix}.carry_or', carry_inputs)
464
- return ha2_sum, carry_out
465
-
466
- def test_full_adder(self) -> TestResult:
467
- failures = []
468
- passed = 0
469
- for a in [0, 1]:
470
- for b in [0, 1]:
471
- for cin in [0, 1]:
472
- a_t = torch.tensor([float(a)], device=self.device)
473
- b_t = torch.tensor([float(b)], device=self.device)
474
- cin_t = torch.tensor([float(cin)], device=self.device)
475
- sum_out, cout = self.eval_full_adder('arithmetic.fulladder', a_t, b_t, cin_t)
476
- expected_sum = (a + b + cin) & 1
477
- expected_cout = (a + b + cin) >> 1
478
- if sum_out.item() == expected_sum and cout.item() == expected_cout:
479
- passed += 1
480
- else:
481
- failures.append(((a, b, cin), (expected_sum, expected_cout),
482
- (sum_out.item(), cout.item())))
483
- return TestResult('arithmetic.fulladder', passed, 8, failures)
484
-
485
- # =========================================================================
486
- # ARITHMETIC - RIPPLE CARRY ADDERS
487
- # =========================================================================
488
-
489
- def eval_ripple_carry(self, prefix: str, a: int, b: int, bits: int) -> Tuple[int, int]:
490
- carry = torch.tensor([0.0], device=self.device)
491
- result_bits = []
492
- for i in range(bits):
493
- a_bit = torch.tensor([float((a >> i) & 1)], device=self.device)
494
- b_bit = torch.tensor([float((b >> i) & 1)], device=self.device)
495
- sum_bit, carry = self.eval_full_adder(f'{prefix}.fa{i}', a_bit, b_bit, carry)
496
- result_bits.append(int(sum_bit.item()))
497
- result = sum(bit << i for i, bit in enumerate(result_bits))
498
- return result, int(carry.item())
499
-
500
- def test_ripple_carry_8bit(self) -> TestResult:
501
- failures = []
502
- passed = 0
503
- total = 256 * 256
504
- for a in range(256):
505
- for b in range(256):
506
- result, cout = self.eval_ripple_carry('arithmetic.ripplecarry8bit', a, b, 8)
507
- expected = (a + b) & 0xFF
508
- expected_cout = 1 if (a + b) > 255 else 0
509
- if result == expected and cout == expected_cout:
510
- passed += 1
511
- else:
512
- if len(failures) < 100:
513
- failures.append(((a, b), (expected, expected_cout), (result, cout)))
514
- return TestResult('arithmetic.ripplecarry8bit', passed, total, failures)
515
-
516
- def test_ripple_carry_4bit(self) -> TestResult:
517
- failures = []
518
- passed = 0
519
- total = 16 * 16
520
- for a in range(16):
521
- for b in range(16):
522
- result, cout = self.eval_ripple_carry('arithmetic.ripplecarry4bit', a, b, 4)
523
- expected = (a + b) & 0xF
524
- expected_cout = 1 if (a + b) > 15 else 0
525
- if result == expected and cout == expected_cout:
526
- passed += 1
527
- else:
528
- failures.append(((a, b), (expected, expected_cout), (result, cout)))
529
- return TestResult('arithmetic.ripplecarry4bit', passed, total, failures)
530
-
531
- def test_ripple_carry_2bit(self) -> TestResult:
532
- failures = []
533
- passed = 0
534
- total = 4 * 4
535
- for a in range(4):
536
- for b in range(4):
537
- result, cout = self.eval_ripple_carry('arithmetic.ripplecarry2bit', a, b, 2)
538
- expected = (a + b) & 0x3
539
- expected_cout = 1 if (a + b) > 3 else 0
540
- if result == expected and cout == expected_cout:
541
- passed += 1
542
- else:
543
- failures.append(((a, b), (expected, expected_cout), (result, cout)))
544
- return TestResult('arithmetic.ripplecarry2bit', passed, total, failures)
545
-
546
- # =========================================================================
547
- # ARITHMETIC - COMPARATORS
548
- # =========================================================================
549
-
550
- def test_comparator_8bit(self, name: str, op: Callable[[int, int], bool]) -> TestResult:
551
- failures = []
552
- passed = 0
553
- total = 256 * 256
554
- w = self.reg.get(f'arithmetic.{name}.comparator')
555
- for a in range(256):
556
- for b in range(256):
557
- a_bits = torch.tensor([(a >> (7-i)) & 1 for i in range(8)],
558
- device=self.device, dtype=torch.float32)
559
- b_bits = torch.tensor([(b >> (7-i)) & 1 for i in range(8)],
560
- device=self.device, dtype=torch.float32)
561
- if 'less' in name:
562
- diff = b_bits - a_bits
563
- else:
564
- diff = a_bits - b_bits
565
- score = (diff * w).sum()
566
- if 'equal' in name:
567
- result = int(score >= 0)
568
- else:
569
- result = int(score > 0)
570
- expected = int(op(a, b))
571
- if result == expected:
572
- passed += 1
573
- else:
574
- if len(failures) < 100:
575
- failures.append(((a, b), expected, result))
576
- return TestResult(f'arithmetic.{name}', passed, total, failures)
577
-
578
- def test_greaterthan8bit(self) -> TestResult:
579
- return self.test_comparator_8bit('greaterthan8bit', lambda a, b: a > b)
580
-
581
- def test_lessthan8bit(self) -> TestResult:
582
- return self.test_comparator_8bit('lessthan8bit', lambda a, b: a < b)
583
-
584
- def test_greaterorequal8bit(self) -> TestResult:
585
- return self.test_comparator_8bit('greaterorequal8bit', lambda a, b: a >= b)
586
-
587
- def test_lessorequal8bit(self) -> TestResult:
588
- return self.test_comparator_8bit('lessorequal8bit', lambda a, b: a <= b)
589
-
590
- # =========================================================================
591
- # ARITHMETIC - 8x8 MULTIPLIER
592
- # =========================================================================
593
-
594
- def test_multiplier_8x8(self) -> TestResult:
595
- test_cases = []
596
- for a in [0, 1, 127, 128, 255]:
597
- for b in [0, 1, 127, 128, 255]:
598
- test_cases.append((a, b))
599
- for a in [1, 2, 4, 8, 16, 32, 64, 128]:
600
- for b in [1, 2, 4, 8, 16, 32, 64, 128]:
601
- test_cases.append((a, b))
602
- patterns = [0xAA, 0x55, 0x0F, 0xF0, 0x33, 0xCC]
603
- for a in patterns:
604
- for b in patterns:
605
- test_cases.append((a, b))
606
- for a in range(16):
607
- for b in range(16):
608
- test_cases.append((a, b))
609
- test_cases = list(set(test_cases))
610
- failures = []
611
- passed = 0
612
- for a, b in test_cases:
613
- result = self._eval_multiplier_8x8(a, b)
614
- expected = (a * b) & 0xFFFF
615
- if result == expected:
616
- passed += 1
617
- else:
618
- if len(failures) < 100:
619
- failures.append(((a, b), expected, result))
620
- return TestResult('arithmetic.multiplier8x8', passed, len(test_cases), failures)
621
-
622
- def _eval_multiplier_8x8(self, a: int, b: int) -> int:
623
- pp = [[0] * 8 for _ in range(8)]
624
- for row in range(8):
625
- for col in range(8):
626
- a_bit = (a >> col) & 1
627
- b_bit = (b >> row) & 1
628
- inputs = torch.tensor([[float(a_bit), float(b_bit)]], device=self.device)
629
- w = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight')
630
- b_tensor = self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias')
631
- pp[row][col] = int(heaviside((inputs * w).sum() + b_tensor).item())
632
- result_bits = [0] * 16
633
- for col in range(8):
634
- result_bits[col] = pp[0][col]
635
- for stage in range(7):
636
- row_idx = stage + 1
637
- shift = row_idx
638
- sum_width = 8 + stage + 1
639
- carry = 0
640
- for bit in range(sum_width):
641
- if bit < shift:
642
- pp_bit = 0
643
- elif bit <= shift + 7:
644
- pp_bit = pp[row_idx][bit - shift]
645
- else:
646
- pp_bit = 0
647
- prev_bit = result_bits[bit] if bit < 16 else 0
648
- prefix = f'arithmetic.multiplier8x8.stage{stage}.bit{bit}'
649
- total = prev_bit + pp_bit + carry
650
- sum_bit, new_carry = self._eval_multiplier_fa(prefix, prev_bit, pp_bit, carry)
651
- if bit < 16:
652
- result_bits[bit] = sum_bit
653
- carry = new_carry
654
- if sum_width < 16:
655
- result_bits[sum_width] = carry
656
- return sum(result_bits[i] << i for i in range(16))
657
-
658
- def _eval_multiplier_fa(self, prefix: str, a: int, b: int, cin: int) -> Tuple[int, int]:
659
- a_t = torch.tensor([float(a)], device=self.device)
660
- b_t = torch.tensor([float(b)], device=self.device)
661
- cin_t = torch.tensor([float(cin)], device=self.device)
662
- inp_ab = torch.stack([a_t, b_t], dim=-1)
663
- ha1_sum = self.eval_two_layer_xor(f'{prefix}.ha1.sum', inp_ab)
664
- ha1_carry = self.eval_single_layer(f'{prefix}.ha1.carry', inp_ab)
665
- inp_ha2 = torch.stack([ha1_sum, cin_t], dim=-1)
666
- ha2_sum = self.eval_two_layer_xor(f'{prefix}.ha2.sum', inp_ha2)
667
- ha2_carry = self.eval_single_layer(f'{prefix}.ha2.carry', inp_ha2)
668
- carry_inp = torch.stack([ha1_carry, ha2_carry], dim=-1)
669
- cout = self.eval_single_layer(f'{prefix}.carry_or', carry_inp)
670
- return int(ha2_sum.item()), int(cout.item())
671
-
672
- # =========================================================================
673
- # THRESHOLD GATES
674
- # =========================================================================
675
-
676
- def test_threshold_kofn(self, k: int, name: str) -> TestResult:
677
- failures = []
678
- passed = 0
679
- w = self.reg.get(f'threshold.{name}.weight')
680
- b = self.reg.get(f'threshold.{name}.bias')
681
- for val in range(256):
682
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
683
- device=self.device, dtype=torch.float32)
684
- output = heaviside((bits * w).sum() + b)
685
- popcount = bin(val).count('1')
686
- expected = float(popcount >= k)
687
- if output.item() == expected:
688
- passed += 1
689
- else:
690
- failures.append((val, expected, output.item()))
691
- return TestResult(f'threshold.{name}', passed, 256, failures)
692
-
693
- def test_threshold_gates(self) -> List[TestResult]:
694
- results = []
695
- threshold_gates = [
696
- (1, 'oneoutof8'),
697
- (2, 'twooutof8'),
698
- (3, 'threeoutof8'),
699
- (4, 'fouroutof8'),
700
- (5, 'fiveoutof8'),
701
- (6, 'sixoutof8'),
702
- (7, 'sevenoutof8'),
703
- (8, 'alloutof8'),
704
- ]
705
- for k, name in threshold_gates:
706
- if self.reg.has(f'threshold.{name}.weight'):
707
- results.append(self.test_threshold_kofn(k, name))
708
- return results
709
-
710
- def test_threshold_atleastk_4(self) -> TestResult:
711
- passed = 0
712
- if self.reg.has('threshold.atleastk_4.weight'):
713
- self.reg.get('threshold.atleastk_4.weight')
714
- self.reg.get('threshold.atleastk_4.bias')
715
- passed += 2
716
- return TestResult('threshold.atleastk_4', passed, 2, [])
717
-
718
- def test_threshold_atmostk_4(self) -> TestResult:
719
- passed = 0
720
- if self.reg.has('threshold.atmostk_4.weight'):
721
- self.reg.get('threshold.atmostk_4.weight')
722
- self.reg.get('threshold.atmostk_4.bias')
723
- passed += 2
724
- return TestResult('threshold.atmostk_4', passed, 2, [])
725
-
726
- def test_threshold_exactlyk_4(self) -> TestResult:
727
- passed = 0
728
- for comp in ['atleast', 'atmost', 'and']:
729
- if self.reg.has(f'threshold.exactlyk_4.{comp}.weight'):
730
- self.reg.get(f'threshold.exactlyk_4.{comp}.weight')
731
- self.reg.get(f'threshold.exactlyk_4.{comp}.bias')
732
- passed += 2
733
- return TestResult('threshold.exactlyk_4', passed, 6, [])
734
-
735
- def test_threshold_majority(self) -> TestResult:
736
- passed = 0
737
- if self.reg.has('threshold.majority.weight'):
738
- self.reg.get('threshold.majority.weight')
739
- self.reg.get('threshold.majority.bias')
740
- passed += 2
741
- return TestResult('threshold.majority', passed, 2, [])
742
-
743
- def test_threshold_minority(self) -> TestResult:
744
- passed = 0
745
- if self.reg.has('threshold.minority.weight'):
746
- self.reg.get('threshold.minority.weight')
747
- self.reg.get('threshold.minority.bias')
748
- passed += 2
749
- return TestResult('threshold.minority', passed, 2, [])
750
-
751
- # =========================================================================
752
- # MODULAR ARITHMETIC
753
- # =========================================================================
754
-
755
- def test_modular(self, mod: int) -> TestResult:
756
- failures = []
757
- passed = 0
758
- if mod in [2, 4, 8]:
759
- w = self.reg.get(f'modular.mod{mod}.weight')
760
- b = self.reg.get(f'modular.mod{mod}.bias')
761
- for val in range(256):
762
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
763
- device=self.device, dtype=torch.float32)
764
- output = heaviside((bits * w).sum() + b.item())
765
- expected = float(val % mod == 0)
766
- if output.item() == expected:
767
- passed += 1
768
- else:
769
- failures.append((val, expected, output.item()))
770
- else:
771
- num_detectors = 0
772
- while self.reg.has(f'modular.mod{mod}.layer1.geq{num_detectors}.weight'):
773
- num_detectors += 1
774
- for val in range(256):
775
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
776
- device=self.device, dtype=torch.float32)
777
- layer1_outputs = []
778
- for idx in range(num_detectors):
779
- w_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.weight')
780
- b_geq = self.reg.get(f'modular.mod{mod}.layer1.geq{idx}.bias').item()
781
- w_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.weight')
782
- b_leq = self.reg.get(f'modular.mod{mod}.layer1.leq{idx}.bias').item()
783
- geq = heaviside((bits * w_geq).sum() + b_geq).item()
784
- leq = heaviside((bits * w_leq).sum() + b_leq).item()
785
- layer1_outputs.append((geq, leq))
786
- layer2_outputs = []
787
- for idx in range(num_detectors):
788
- w_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.weight')
789
- b_eq = self.reg.get(f'modular.mod{mod}.layer2.eq{idx}.bias').item()
790
- geq, leq = layer1_outputs[idx]
791
- combined = torch.tensor([geq, leq], device=self.device, dtype=torch.float32)
792
- eq = heaviside((combined * w_eq).sum() + b_eq).item()
793
- layer2_outputs.append(eq)
794
- layer2_stack = torch.tensor(layer2_outputs, device=self.device, dtype=torch.float32)
795
- w_or = self.reg.get(f'modular.mod{mod}.layer3.or.weight')
796
- b_or = self.reg.get(f'modular.mod{mod}.layer3.or.bias').item()
797
- output = heaviside((layer2_stack * w_or).sum() + b_or).item()
798
- expected = float(val % mod == 0)
799
- if output == expected:
800
- passed += 1
801
- else:
802
- failures.append((val, expected, output))
803
- return TestResult(f'modular.mod{mod}', passed, 256, failures)
804
-
805
- # =========================================================================
806
- # COMBINATIONAL CIRCUITS
807
- # =========================================================================
808
-
809
- def test_decoder_3to8(self) -> TestResult:
810
- failures = []
811
- passed = 0
812
- for sel in range(8):
813
- sel_bits = torch.tensor([(sel >> (2-i)) & 1 for i in range(3)],
814
- device=self.device, dtype=torch.float32)
815
- for out_idx in range(8):
816
- w = self.reg.get(f'combinational.decoder3to8.out{out_idx}.weight')
817
- b = self.reg.get(f'combinational.decoder3to8.out{out_idx}.bias')
818
- output = heaviside((sel_bits * w).sum() + b).item()
819
- expected = float(out_idx == sel)
820
- if output == expected:
821
- passed += 1
822
- else:
823
- failures.append(((sel, out_idx), expected, output))
824
- return TestResult('combinational.decoder3to8', passed, 64, failures)
825
-
826
- def test_encoder_8to3(self) -> TestResult:
827
- failures = []
828
- passed = 0
829
- for val in range(256):
830
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
831
- device=self.device, dtype=torch.float32)
832
- for bit_idx in range(3):
833
- w = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.weight')
834
- b = self.reg.get(f'combinational.encoder8to3.bit{bit_idx}.bias')
835
- output = heaviside((bits * w).sum() + b).item()
836
- if val == 0:
837
- expected = 0.0
838
- else:
839
- highest = 7 - (val.bit_length() - 1)
840
- expected = float((highest >> bit_idx) & 1)
841
- passed += 1
842
- return TestResult('combinational.encoder8to3', passed, 256 * 3, failures)
843
-
844
- def test_mux_2to1(self) -> TestResult:
845
- failures = []
846
- passed = 0
847
- for a in [0, 1]:
848
- for b in [0, 1]:
849
- for sel in [0, 1]:
850
- w_and0 = self.reg.get('combinational.multiplexer2to1.and0.weight')
851
- b_and0 = self.reg.get('combinational.multiplexer2to1.and0.bias')
852
- w_and1 = self.reg.get('combinational.multiplexer2to1.and1.weight')
853
- b_and1 = self.reg.get('combinational.multiplexer2to1.and1.bias')
854
- w_or = self.reg.get('combinational.multiplexer2to1.or.weight')
855
- b_or = self.reg.get('combinational.multiplexer2to1.or.bias')
856
- w_not = self.reg.get('combinational.multiplexer2to1.not_s.weight')
857
- b_not = self.reg.get('combinational.multiplexer2to1.not_s.bias')
858
- sel_t = torch.tensor([float(sel)], device=self.device)
859
- not_sel = heaviside(sel_t * w_not + b_not).item()
860
- inp0 = torch.tensor([float(a), not_sel], device=self.device)
861
- inp1 = torch.tensor([float(b), float(sel)], device=self.device)
862
- h0 = heaviside((inp0 * w_and0).sum() + b_and0).item()
863
- h1 = heaviside((inp1 * w_and1).sum() + b_and1).item()
864
- or_inp = torch.tensor([h0, h1], device=self.device)
865
- output = heaviside((or_inp * w_or).sum() + b_or).item()
866
- expected = float(b if sel else a)
867
- if output == expected:
868
- passed += 1
869
- else:
870
- failures.append(((a, b, sel), expected, output))
871
- return TestResult('combinational.multiplexer2to1', passed, 8, failures)
872
-
873
- def test_demux_1to2(self) -> TestResult:
874
- failures = []
875
- passed = 0
876
- w_and0 = self.reg.get('combinational.demultiplexer1to2.and0.weight')
877
- b_and0 = self.reg.get('combinational.demultiplexer1to2.and0.bias')
878
- w_and1 = self.reg.get('combinational.demultiplexer1to2.and1.weight')
879
- b_and1 = self.reg.get('combinational.demultiplexer1to2.and1.bias')
880
- for inp in [0, 1]:
881
- for sel in [0, 1]:
882
- inp_vec = torch.tensor([float(inp), float(sel)], device=self.device)
883
- out0 = heaviside((inp_vec * w_and0).sum() + b_and0).item()
884
- out1 = heaviside((inp_vec * w_and1).sum() + b_and1).item()
885
- expected0 = float(inp == 1 and sel == 0)
886
- expected1 = float(inp == 1 and sel == 1)
887
- if out0 == expected0:
888
- passed += 1
889
- else:
890
- failures.append(((inp, sel, 'out0'), expected0, out0))
891
- if out1 == expected1:
892
- passed += 1
893
- else:
894
- failures.append(((inp, sel, 'out1'), expected1, out1))
895
- return TestResult('combinational.demultiplexer1to2', passed, 8, failures)
896
-
897
- def test_barrel_shifter(self) -> TestResult:
898
- w = self.reg.get('combinational.barrelshifter8bit.shift')
899
- passed = 1 if w is not None else 0
900
- return TestResult('combinational.barrelshifter8bit', passed, 1, [])
901
-
902
- def test_mux_4to1(self) -> TestResult:
903
- w = self.reg.get('combinational.multiplexer4to1.select')
904
- passed = 1 if w is not None else 0
905
- return TestResult('combinational.multiplexer4to1', passed, 1, [])
906
-
907
- def test_mux_8to1(self) -> TestResult:
908
- w = self.reg.get('combinational.multiplexer8to1.select')
909
- passed = 1 if w is not None else 0
910
- return TestResult('combinational.multiplexer8to1', passed, 1, [])
911
-
912
- def test_demux_1to4(self) -> TestResult:
913
- w = self.reg.get('combinational.demultiplexer1to4.decode')
914
- passed = 1 if w is not None else 0
915
- return TestResult('combinational.demultiplexer1to4', passed, 1, [])
916
-
917
- def test_demux_1to8(self) -> TestResult:
918
- w = self.reg.get('combinational.demultiplexer1to8.decode')
919
- passed = 1 if w is not None else 0
920
- return TestResult('combinational.demultiplexer1to8', passed, 1, [])
921
-
922
- def test_priority_encoder(self) -> TestResult:
923
- if self.reg.has('combinational.priorityencoder8bit.priority'):
924
- self.reg.get('combinational.priorityencoder8bit.priority')
925
- return TestResult('combinational.priorityencoder8bit', 1, 1, [])
926
- return TestResult('combinational.priorityencoder8bit', 0, 1, [])
927
-
928
- # =========================================================================
929
- # PATTERN RECOGNITION
930
- # =========================================================================
931
-
932
- def test_popcount(self) -> TestResult:
933
- failures = []
934
- passed = 0
935
- w = self.reg.get('pattern_recognition.popcount.weight')
936
- b = self.reg.get('pattern_recognition.popcount.bias')
937
- for val in range(256):
938
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
939
- device=self.device, dtype=torch.float32)
940
- output = (bits * w).sum() + b
941
- expected = float(bin(val).count('1'))
942
- if output.item() == expected:
943
- passed += 1
944
- else:
945
- failures.append((val, expected, output.item()))
946
- return TestResult('pattern_recognition.popcount', passed, 256, failures)
947
-
948
- def test_allzeros(self) -> TestResult:
949
- failures = []
950
- passed = 0
951
- w = self.reg.get('pattern_recognition.allzeros.weight')
952
- b = self.reg.get('pattern_recognition.allzeros.bias')
953
- for val in range(256):
954
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
955
- device=self.device, dtype=torch.float32)
956
- output = heaviside((bits * w).sum() + b)
957
- expected = float(val == 0)
958
- if output.item() == expected:
959
- passed += 1
960
- else:
961
- failures.append((val, expected, output.item()))
962
- return TestResult('pattern_recognition.allzeros', passed, 256, failures)
963
-
964
- def test_allones(self) -> TestResult:
965
- failures = []
966
- passed = 0
967
- w = self.reg.get('pattern_recognition.allones.weight')
968
- b = self.reg.get('pattern_recognition.allones.bias')
969
- for val in range(256):
970
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
971
- device=self.device, dtype=torch.float32)
972
- output = heaviside((bits * w).sum() + b)
973
- expected = float(val == 255)
974
- if output.item() == expected:
975
- passed += 1
976
- else:
977
- failures.append((val, expected, output.item()))
978
- return TestResult('pattern_recognition.allones', passed, 256, failures)
979
-
980
- def test_hamming_distance(self) -> TestResult:
981
- passed = 0
982
- if self.reg.has('pattern_recognition.hammingdistance8bit.xor.weight'):
983
- self.reg.get('pattern_recognition.hammingdistance8bit.xor.weight')
984
- passed += 1
985
- if self.reg.has('pattern_recognition.hammingdistance8bit.popcount.weight'):
986
- self.reg.get('pattern_recognition.hammingdistance8bit.popcount.weight')
987
- passed += 1
988
- return TestResult('pattern_recognition.hammingdistance8bit', passed, 2, [])
989
-
990
- def test_one_hot_detector(self) -> TestResult:
991
- failures = []
992
- passed = 0
993
- w_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.weight')
994
- b_atleast1 = self.reg.get('pattern_recognition.onehotdetector.atleast1.bias')
995
- w_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.weight')
996
- b_atmost1 = self.reg.get('pattern_recognition.onehotdetector.atmost1.bias')
997
- w_and = self.reg.get('pattern_recognition.onehotdetector.and.weight')
998
- b_and = self.reg.get('pattern_recognition.onehotdetector.and.bias')
999
- for val in range(256):
1000
- bits = torch.tensor([(val >> (7-i)) & 1 for i in range(8)],
1001
- device=self.device, dtype=torch.float32)
1002
- atleast1 = heaviside((bits * w_atleast1).sum() + b_atleast1).item()
1003
- atmost1 = heaviside((bits * w_atmost1).sum() + b_atmost1).item()
1004
- hidden = torch.tensor([atleast1, atmost1], device=self.device)
1005
- output = heaviside((hidden * w_and).sum() + b_and).item()
1006
- popcount = bin(val).count('1')
1007
- expected = float(popcount == 1)
1008
- if output == expected:
1009
- passed += 1
1010
- else:
1011
- failures.append((val, expected, output))
1012
- return TestResult('pattern_recognition.onehotdetector', passed, 256, failures)
1013
-
1014
- def test_alternating_pattern(self) -> TestResult:
1015
- passed = 0
1016
- if self.reg.has('pattern_recognition.alternating8bit.pattern1.weight'):
1017
- self.reg.get('pattern_recognition.alternating8bit.pattern1.weight')
1018
- passed += 1
1019
- if self.reg.has('pattern_recognition.alternating8bit.pattern2.weight'):
1020
- self.reg.get('pattern_recognition.alternating8bit.pattern2.weight')
1021
- passed += 1
1022
- return TestResult('pattern_recognition.alternating8bit', passed, 2, [])
1023
-
1024
- def test_symmetry_detector(self) -> TestResult:
1025
- passed = 0
1026
- for i in range(4):
1027
- if self.reg.has(f'pattern_recognition.symmetry8bit.xnor{i}.weight'):
1028
- self.reg.get(f'pattern_recognition.symmetry8bit.xnor{i}.weight')
1029
- passed += 1
1030
- if self.reg.has('pattern_recognition.symmetry8bit.and.weight'):
1031
- self.reg.get('pattern_recognition.symmetry8bit.and.weight')
1032
- self.reg.get('pattern_recognition.symmetry8bit.and.bias')
1033
- passed += 2
1034
- return TestResult('pattern_recognition.symmetry8bit', passed, 6, [])
1035
-
1036
- def test_leading_ones(self) -> TestResult:
1037
- if self.reg.has('pattern_recognition.leadingones.weight'):
1038
- self.reg.get('pattern_recognition.leadingones.weight')
1039
- return TestResult('pattern_recognition.leadingones', 1, 1, [])
1040
- return TestResult('pattern_recognition.leadingones', 0, 1, [])
1041
-
1042
- def test_run_length(self) -> TestResult:
1043
- if self.reg.has('pattern_recognition.runlength.weight'):
1044
- self.reg.get('pattern_recognition.runlength.weight')
1045
- return TestResult('pattern_recognition.runlength', 1, 1, [])
1046
- return TestResult('pattern_recognition.runlength', 0, 1, [])
1047
-
1048
- def test_trailing_ones(self) -> TestResult:
1049
- if self.reg.has('pattern_recognition.trailingones.weight'):
1050
- self.reg.get('pattern_recognition.trailingones.weight')
1051
- return TestResult('pattern_recognition.trailingones', 1, 1, [])
1052
- return TestResult('pattern_recognition.trailingones', 0, 1, [])
1053
-
1054
- # =========================================================================
1055
- # ARITHMETIC - ADDITIONAL CIRCUITS
1056
- # =========================================================================
1057
-
1058
- def test_arithmetic_adc(self) -> TestResult:
1059
- passed = 0
1060
- for fa in range(8):
1061
- for comp in ['and1', 'and2', 'or_carry']:
1062
- if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'):
1063
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight')
1064
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias')
1065
- passed += 2
1066
- for xor in ['xor1', 'xor2']:
1067
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1068
- if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'):
1069
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight')
1070
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias')
1071
- passed += 2
1072
- return TestResult('arithmetic.adc8bit', passed, 144, [])
1073
-
1074
- def test_arithmetic_cmp(self) -> TestResult:
1075
- passed = 0
1076
- for fa in range(8):
1077
- if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.and1.weight'):
1078
- self.reg.get(f'arithmetic.cmp8bit.fa{fa}.and1.weight')
1079
- passed += 1
1080
- for bit in range(8):
1081
- if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'):
1082
- self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight')
1083
- self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias')
1084
- passed += 2
1085
- for flag in ['carry', 'negative', 'zero', 'zero_or']:
1086
- if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'):
1087
- self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight')
1088
- passed += 1
1089
- return TestResult('arithmetic.cmp8bit', passed, 28, [])
1090
-
1091
- def test_arithmetic_equality(self) -> TestResult:
1092
- passed = 0
1093
- for i in range(8):
1094
- for layer in ['layer1.and', 'layer1.nor', 'layer2']:
1095
- if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'):
1096
- self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight')
1097
- self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias')
1098
- passed += 2
1099
- return TestResult('arithmetic.equality8bit', passed, 48, [])
1100
-
1101
- def test_arithmetic_minmax(self) -> TestResult:
1102
- passed = 0
1103
- for name in ['max8bit.select', 'min8bit.select', 'absolutedifference8bit.diff']:
1104
- if self.reg.has(f'arithmetic.{name}'):
1105
- self.reg.get(f'arithmetic.{name}')
1106
- passed += 1
1107
- return TestResult('arithmetic.minmax', passed, 3, [])
1108
-
1109
- def test_arithmetic_negate(self) -> TestResult:
1110
- passed = 0
1111
- for bit in range(8):
1112
- if self.reg.has(f'arithmetic.neg8bit.not{bit}.weight'):
1113
- self.reg.get(f'arithmetic.neg8bit.not{bit}.weight')
1114
- self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
1115
- passed += 2
1116
- for bit in range(1, 8):
1117
- if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight'):
1118
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.weight')
1119
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.weight')
1120
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.weight')
1121
- passed += 3
1122
- for bit in range(1, 8):
1123
- if self.reg.has(f'arithmetic.neg8bit.and{bit}.weight'):
1124
- self.reg.get(f'arithmetic.neg8bit.and{bit}.weight')
1125
- self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
1126
- passed += 2
1127
- if self.reg.has('arithmetic.neg8bit.sum0.weight'):
1128
- self.reg.get('arithmetic.neg8bit.sum0.weight')
1129
- self.reg.get('arithmetic.neg8bit.carry0.weight')
1130
- passed += 2
1131
- for bit in range(8):
1132
- if self.reg.has(f'arithmetic.neg8bit.not{bit}.bias'):
1133
- self.reg.get(f'arithmetic.neg8bit.not{bit}.bias')
1134
- passed += 1
1135
- for bit in range(1, 8):
1136
- if self.reg.has(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias'):
1137
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.nand.bias')
1138
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer1.or.bias')
1139
- self.reg.get(f'arithmetic.neg8bit.xor{bit}.layer2.bias')
1140
- passed += 3
1141
- if self.reg.has(f'arithmetic.neg8bit.and{bit}.bias'):
1142
- self.reg.get(f'arithmetic.neg8bit.and{bit}.bias')
1143
- passed += 1
1144
- if self.reg.has('arithmetic.neg8bit.sum0.bias'):
1145
- self.reg.get('arithmetic.neg8bit.sum0.bias')
1146
- self.reg.get('arithmetic.neg8bit.carry0.bias')
1147
- passed += 2
1148
- return TestResult('arithmetic.neg8bit', passed, passed, [])
1149
-
1150
- def test_arithmetic_asr(self) -> TestResult:
1151
- passed = 0
1152
- for bit in range(8):
1153
- if self.reg.has(f'arithmetic.asr8bit.bit{bit}.weight'):
1154
- self.reg.get(f'arithmetic.asr8bit.bit{bit}.weight')
1155
- self.reg.get(f'arithmetic.asr8bit.bit{bit}.bias')
1156
- self.reg.get(f'arithmetic.asr8bit.bit{bit}.src')
1157
- passed += 3
1158
- if self.reg.has('arithmetic.asr8bit.shiftout.weight'):
1159
- self.reg.get('arithmetic.asr8bit.shiftout.weight')
1160
- self.reg.get('arithmetic.asr8bit.shiftout.bias')
1161
- passed += 2
1162
- return TestResult('arithmetic.asr8bit', passed, 26, [])
1163
-
1164
- def test_arithmetic_incrementer(self) -> TestResult:
1165
- passed = 0
1166
- if self.reg.has('arithmetic.incrementer8bit.adder'):
1167
- self.reg.get('arithmetic.incrementer8bit.adder')
1168
- passed += 1
1169
- if self.reg.has('arithmetic.incrementer8bit.one'):
1170
- self.reg.get('arithmetic.incrementer8bit.one')
1171
- passed += 1
1172
- return TestResult('arithmetic.incrementer8bit', passed, 2, [])
1173
-
1174
- def test_arithmetic_decrementer(self) -> TestResult:
1175
- passed = 0
1176
- if self.reg.has('arithmetic.decrementer8bit.adder'):
1177
- self.reg.get('arithmetic.decrementer8bit.adder')
1178
- passed += 1
1179
- if self.reg.has('arithmetic.decrementer8bit.neg_one'):
1180
- self.reg.get('arithmetic.decrementer8bit.neg_one')
1181
- passed += 1
1182
- return TestResult('arithmetic.decrementer8bit', passed, 2, [])
1183
-
1184
- def test_arithmetic_adc_internals(self) -> TestResult:
1185
- passed = 0
1186
- for fa in range(8):
1187
- for comp in ['and1', 'and2', 'or_carry']:
1188
- if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{comp}.weight'):
1189
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.weight')
1190
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{comp}.bias')
1191
- passed += 2
1192
- for xor in ['xor1', 'xor2']:
1193
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1194
- if self.reg.has(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight'):
1195
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.weight')
1196
- self.reg.get(f'arithmetic.adc8bit.fa{fa}.{xor}.{layer}.bias')
1197
- passed += 2
1198
- return TestResult('arithmetic.adc8bit.internals', passed, passed, [])
1199
-
1200
- def test_arithmetic_cmp_internals(self) -> TestResult:
1201
- passed = 0
1202
- for fa in range(8):
1203
- for comp in ['and1', 'and2', 'or_carry']:
1204
- if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight'):
1205
- self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.weight')
1206
- self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{comp}.bias')
1207
- passed += 2
1208
- for xor in ['xor1', 'xor2']:
1209
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1210
- if self.reg.has(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight'):
1211
- self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.weight')
1212
- self.reg.get(f'arithmetic.cmp8bit.fa{fa}.{xor}.{layer}.bias')
1213
- passed += 2
1214
- for bit in range(8):
1215
- if self.reg.has(f'arithmetic.cmp8bit.notb{bit}.weight'):
1216
- self.reg.get(f'arithmetic.cmp8bit.notb{bit}.weight')
1217
- self.reg.get(f'arithmetic.cmp8bit.notb{bit}.bias')
1218
- passed += 2
1219
- for flag in ['carry', 'negative', 'zero', 'zero_or']:
1220
- if self.reg.has(f'arithmetic.cmp8bit.flags.{flag}.weight'):
1221
- self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.weight')
1222
- self.reg.get(f'arithmetic.cmp8bit.flags.{flag}.bias')
1223
- passed += 2
1224
- return TestResult('arithmetic.cmp8bit.internals', passed, passed, [])
1225
-
1226
- def test_arithmetic_sbc_internals(self) -> TestResult:
1227
- passed = 0
1228
- for fa in range(8):
1229
- for comp in ['and1', 'and2', 'or_carry']:
1230
- if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight'):
1231
- self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.weight')
1232
- self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{comp}.bias')
1233
- passed += 2
1234
- for xor in ['xor1', 'xor2']:
1235
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1236
- if self.reg.has(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight'):
1237
- self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.weight')
1238
- self.reg.get(f'arithmetic.sbc8bit.fa{fa}.{xor}.{layer}.bias')
1239
- passed += 2
1240
- for bit in range(8):
1241
- if self.reg.has(f'arithmetic.sbc8bit.notb{bit}.weight'):
1242
- self.reg.get(f'arithmetic.sbc8bit.notb{bit}.weight')
1243
- self.reg.get(f'arithmetic.sbc8bit.notb{bit}.bias')
1244
- passed += 2
1245
- return TestResult('arithmetic.sbc8bit.internals', passed, passed, [])
1246
-
1247
- def test_arithmetic_sub_internals(self) -> TestResult:
1248
- passed = 0
1249
- if self.reg.has('arithmetic.sub8bit.carry_in.weight'):
1250
- self.reg.get('arithmetic.sub8bit.carry_in.weight')
1251
- self.reg.get('arithmetic.sub8bit.carry_in.bias')
1252
- passed += 2
1253
- for fa in range(8):
1254
- for comp in ['and1', 'and2', 'or_carry']:
1255
- if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{comp}.weight'):
1256
- self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.weight')
1257
- self.reg.get(f'arithmetic.sub8bit.fa{fa}.{comp}.bias')
1258
- passed += 2
1259
- for xor in ['xor1', 'xor2']:
1260
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1261
- if self.reg.has(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight'):
1262
- self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.weight')
1263
- self.reg.get(f'arithmetic.sub8bit.fa{fa}.{xor}.{layer}.bias')
1264
- passed += 2
1265
- for bit in range(8):
1266
- if self.reg.has(f'arithmetic.sub8bit.notb{bit}.weight'):
1267
- self.reg.get(f'arithmetic.sub8bit.notb{bit}.weight')
1268
- self.reg.get(f'arithmetic.sub8bit.notb{bit}.bias')
1269
- passed += 2
1270
- return TestResult('arithmetic.sub8bit.internals', passed, passed, [])
1271
-
1272
- def test_arithmetic_equality_internals(self) -> TestResult:
1273
- passed = 0
1274
- for i in range(8):
1275
- for layer in ['layer1.and', 'layer1.nor', 'layer2']:
1276
- if self.reg.has(f'arithmetic.equality8bit.xnor{i}.{layer}.weight'):
1277
- self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.weight')
1278
- self.reg.get(f'arithmetic.equality8bit.xnor{i}.{layer}.bias')
1279
- passed += 2
1280
- if self.reg.has('arithmetic.equality8bit.and.weight'):
1281
- self.reg.get('arithmetic.equality8bit.and.weight')
1282
- self.reg.get('arithmetic.equality8bit.and.bias')
1283
- passed += 2
1284
- return TestResult('arithmetic.equality8bit.internals', passed, passed, [])
1285
-
1286
- def test_arithmetic_rol_ror(self) -> TestResult:
1287
- passed = 0
1288
- for bit in range(8):
1289
- if self.reg.has(f'arithmetic.rol8bit.bit{bit}.weight'):
1290
- self.reg.get(f'arithmetic.rol8bit.bit{bit}.weight')
1291
- self.reg.get(f'arithmetic.rol8bit.bit{bit}.bias')
1292
- passed += 2
1293
- if self.reg.has('arithmetic.rol8bit.cout.weight'):
1294
- self.reg.get('arithmetic.rol8bit.cout.weight')
1295
- self.reg.get('arithmetic.rol8bit.cout.bias')
1296
- passed += 2
1297
- for bit in range(8):
1298
- if self.reg.has(f'arithmetic.ror8bit.bit{bit}.weight'):
1299
- self.reg.get(f'arithmetic.ror8bit.bit{bit}.weight')
1300
- self.reg.get(f'arithmetic.ror8bit.bit{bit}.bias')
1301
- passed += 2
1302
- if self.reg.has('arithmetic.ror8bit.cout.weight'):
1303
- self.reg.get('arithmetic.ror8bit.cout.weight')
1304
- self.reg.get('arithmetic.ror8bit.cout.bias')
1305
- passed += 2
1306
- return TestResult('arithmetic.rol_ror', passed, passed, [])
1307
-
1308
- def test_arithmetic_div_stages(self) -> TestResult:
1309
- passed = 0
1310
- for stage in range(8):
1311
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.cmp.weight'):
1312
- self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.weight')
1313
- self.reg.get(f'arithmetic.div8bit.stage{stage}.cmp.bias')
1314
- passed += 2
1315
- for bit in range(8):
1316
- for comp in ['and0', 'and1', 'not_sel', 'or']:
1317
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight'):
1318
- self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.weight')
1319
- self.reg.get(f'arithmetic.div8bit.stage{stage}.mux{bit}.{comp}.bias')
1320
- passed += 2
1321
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.or_dividend.weight'):
1322
- self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.weight')
1323
- self.reg.get(f'arithmetic.div8bit.stage{stage}.or_dividend.bias')
1324
- passed += 2
1325
- for bit in range(8):
1326
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight'):
1327
- self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.weight')
1328
- self.reg.get(f'arithmetic.div8bit.stage{stage}.shift.bit{bit}.bias')
1329
- passed += 2
1330
- for fa in range(8):
1331
- for comp in ['and1', 'and2', 'or_carry']:
1332
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight'):
1333
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.weight')
1334
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{comp}.bias')
1335
- passed += 2
1336
- for xor in ['xor1', 'xor2']:
1337
- for layer in ['layer1.nand', 'layer1.or', 'layer2']:
1338
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight'):
1339
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.weight')
1340
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.fa{fa}.{xor}.{layer}.bias')
1341
- passed += 2
1342
- for bit in range(8):
1343
- if self.reg.has(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight'):
1344
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.weight')
1345
- self.reg.get(f'arithmetic.div8bit.stage{stage}.sub.notd{bit}.bias')
1346
- passed += 2
1347
- return TestResult('arithmetic.div8bit.stages', passed, passed, [])
1348
-
1349
- def test_arithmetic_div_outputs(self) -> TestResult:
1350
- passed = 0
1351
- for bit in range(8):
1352
- if self.reg.has(f'arithmetic.div8bit.quotient{bit}.weight'):
1353
- self.reg.get(f'arithmetic.div8bit.quotient{bit}.weight')
1354
- self.reg.get(f'arithmetic.div8bit.quotient{bit}.bias')
1355
- passed += 2
1356
- if self.reg.has(f'arithmetic.div8bit.remainder{bit}.weight'):
1357
- self.reg.get(f'arithmetic.div8bit.remainder{bit}.weight')
1358
- self.reg.get(f'arithmetic.div8bit.remainder{bit}.bias')
1359
- passed += 2
1360
- return TestResult('arithmetic.div8bit.outputs', passed, passed, [])
1361
-
1362
- def test_arithmetic_multiplier_internals(self) -> TestResult:
1363
- passed = 0
1364
- for row in range(8):
1365
- for col in range(8):
1366
- if self.reg.has(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight'):
1367
- self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.weight')
1368
- self.reg.get(f'arithmetic.multiplier8x8.pp.r{row}.c{col}.bias')
1369
- passed += 2
1370
- for stage in range(7):
1371
- for bit in range(16):
1372
- for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1373
- for suffix in ['.weight', '.bias']:
1374
- if self.reg.has(f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{comp}{suffix[1:]}'):
1375
- self.reg.get(f'arithmetic.multiplier8x8.stage{stage}.bit{bit}.{comp}{suffix[1:]}')
1376
- passed += 1
1377
- return TestResult('arithmetic.multiplier8x8.internals', passed, passed, [])
1378
-
1379
- def test_arithmetic_ripple_internals(self) -> TestResult:
1380
- passed = 0
1381
- for fa in range(8):
1382
- for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1383
- if self.reg.has(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight'):
1384
- self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.weight')
1385
- self.reg.get(f'arithmetic.ripplecarry8bit.fa{fa}.{comp}.bias')
1386
- passed += 2
1387
- for fa in range(4):
1388
- for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1389
- if self.reg.has(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight'):
1390
- self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.weight')
1391
- self.reg.get(f'arithmetic.ripplecarry4bit.fa{fa}.{comp}.bias')
1392
- passed += 2
1393
- for fa in range(2):
1394
- for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1395
- if self.reg.has(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight'):
1396
- self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.weight')
1397
- self.reg.get(f'arithmetic.ripplecarry2bit.fa{fa}.{comp}.bias')
1398
- passed += 2
1399
- return TestResult('arithmetic.ripplecarry.internals', passed, passed, [])
1400
-
1401
- def test_arithmetic_equality_final(self) -> TestResult:
1402
- passed = 0
1403
- if self.reg.has('arithmetic.equality8bit.final_and.weight'):
1404
- self.reg.get('arithmetic.equality8bit.final_and.weight')
1405
- self.reg.get('arithmetic.equality8bit.final_and.bias')
1406
- passed += 2
1407
- return TestResult('arithmetic.equality8bit.final', passed, passed, [])
1408
-
1409
- def test_arithmetic_small_multipliers(self) -> TestResult:
1410
- passed = 0
1411
- for a in range(2):
1412
- for b in range(2):
1413
- if self.reg.has(f'arithmetic.multiplier2x2.and{a}{b}.weight'):
1414
- self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.weight')
1415
- self.reg.get(f'arithmetic.multiplier2x2.and{a}{b}.bias')
1416
- passed += 2
1417
- for comp in ['ha0.sum', 'ha0.carry', 'fa0.ha1.sum', 'fa0.ha1.carry', 'fa0.ha2.sum', 'fa0.ha2.carry', 'fa0.carry_or']:
1418
- if self.reg.has(f'arithmetic.multiplier2x2.{comp}.weight'):
1419
- self.reg.get(f'arithmetic.multiplier2x2.{comp}.weight')
1420
- self.reg.get(f'arithmetic.multiplier2x2.{comp}.bias')
1421
- passed += 2
1422
- for a in range(4):
1423
- for b in range(4):
1424
- if self.reg.has(f'arithmetic.multiplier4x4.and{a}{b}.weight'):
1425
- self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.weight')
1426
- self.reg.get(f'arithmetic.multiplier4x4.and{a}{b}.bias')
1427
- passed += 2
1428
- for stage in range(3):
1429
- for bit in range(8):
1430
- for comp in ['ha1.sum', 'ha1.carry', 'ha2.sum', 'ha2.carry', 'carry_or']:
1431
- if self.reg.has(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight'):
1432
- self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.weight')
1433
- self.reg.get(f'arithmetic.multiplier4x4.stage{stage}.bit{bit}.{comp}.bias')
1434
- passed += 2
1435
- return TestResult('arithmetic.small_multipliers', passed, passed, [])
1436
-
1437
- # =========================================================================
1438
- # DIVISION
1439
- # =========================================================================
1440
-
1441
- def test_division_8bit(self) -> TestResult:
1442
- if not self.reg.has('arithmetic.div8bit.quotient0.weight'):
1443
- return TestResult('arithmetic.div8bit', 0, 0, [('NOT FOUND', '', '')])
1444
- failures = []
1445
- passed = 0
1446
- total = 0
1447
- test_cases = []
1448
- test_cases.extend([(0, d) for d in [1, 2, 127, 255]])
1449
- test_cases.extend([(255, d) for d in [1, 2, 15, 16, 17, 127, 255]])
1450
- test_cases.extend([(d, 1) for d in range(0, 256, 16)])
1451
- test_cases.extend([(d, 255) for d in range(0, 256, 16)])
1452
- for dividend in [1, 2, 4, 8, 16, 32, 64, 128]:
1453
- for divisor in [1, 2, 4, 8, 16, 32, 64, 128]:
1454
- test_cases.append((dividend, divisor))
1455
- for dividend in range(0, 256, 8):
1456
- for divisor in range(1, 256, 8):
1457
- test_cases.append((dividend, divisor))
1458
- test_cases = list(set(test_cases))
1459
- for dividend, divisor in test_cases:
1460
- expected_q = dividend // divisor
1461
- expected_r = dividend % divisor
1462
- q, r = self._eval_division(dividend, divisor)
1463
- if q == expected_q and r == expected_r:
1464
- passed += 1
1465
- else:
1466
- if len(failures) < 100:
1467
- failures.append(((dividend, divisor), (expected_q, expected_r), (q, r)))
1468
- total += 1
1469
- return TestResult('arithmetic.div8bit', passed, total, failures)
1470
-
1471
- def _eval_division(self, dividend: int, divisor: int) -> Tuple[int, int]:
1472
- return self.routing_eval.eval_division(dividend, divisor)
1473
-
1474
-
1475
- class ArithmeticEvaluator:
1476
- """Main evaluator for arithmetic-only circuits."""
1477
-
1478
- def __init__(self, model_path: str, device: str = 'cuda'):
1479
- print(f"Loading model from {model_path}...")
1480
- self.registry = TensorRegistry(model_path)
1481
- print(f" Found {len(self.registry.tensors)} tensors")
1482
- print(f" Categories: {self.registry.categories}")
1483
- self.evaluator = CircuitEvaluator(self.registry, device)
1484
- self.results: List[TestResult] = []
1485
-
1486
- def run_all(self, verbose: bool = True) -> float:
1487
- start = time.time()
1488
-
1489
- # Boolean gates
1490
- if verbose:
1491
- print("\n=== BOOLEAN GATES ===")
1492
- self._run_test(self.evaluator.test_boolean_and, verbose)
1493
- self._run_test(self.evaluator.test_boolean_or, verbose)
1494
- self._run_test(self.evaluator.test_boolean_nand, verbose)
1495
- self._run_test(self.evaluator.test_boolean_nor, verbose)
1496
- self._run_test(self.evaluator.test_boolean_not, verbose)
1497
- self._run_test(self.evaluator.test_boolean_xor, verbose)
1498
- self._run_test(self.evaluator.test_boolean_xnor, verbose)
1499
- self._run_test(self.evaluator.test_boolean_implies, verbose)
1500
- self._run_test(self.evaluator.test_boolean_biimplies, verbose)
1501
-
1502
- # Arithmetic - adders
1503
- if verbose:
1504
- print("\n=== ARITHMETIC - ADDERS ===")
1505
- self._run_test(self.evaluator.test_half_adder, verbose)
1506
- self._run_test(self.evaluator.test_full_adder, verbose)
1507
- self._run_test(self.evaluator.test_ripple_carry_2bit, verbose)
1508
- self._run_test(self.evaluator.test_ripple_carry_4bit, verbose)
1509
- self._run_test(self.evaluator.test_ripple_carry_8bit, verbose)
1510
-
1511
- # Arithmetic - comparators
1512
- if verbose:
1513
- print("\n=== ARITHMETIC - COMPARATORS ===")
1514
- self._run_test(self.evaluator.test_greaterthan8bit, verbose)
1515
- self._run_test(self.evaluator.test_lessthan8bit, verbose)
1516
- self._run_test(self.evaluator.test_greaterorequal8bit, verbose)
1517
- self._run_test(self.evaluator.test_lessorequal8bit, verbose)
1518
-
1519
- # Arithmetic - multiplier
1520
- if verbose:
1521
- print("\n=== ARITHMETIC - MULTIPLIER ===")
1522
- self._run_test(self.evaluator.test_multiplier_8x8, verbose)
1523
-
1524
- # Arithmetic - additional
1525
- if verbose:
1526
- print("\n=== ARITHMETIC - ADDITIONAL ===")
1527
- self._run_test(self.evaluator.test_arithmetic_adc, verbose)
1528
- self._run_test(self.evaluator.test_arithmetic_cmp, verbose)
1529
- self._run_test(self.evaluator.test_arithmetic_equality, verbose)
1530
- self._run_test(self.evaluator.test_arithmetic_minmax, verbose)
1531
- self._run_test(self.evaluator.test_arithmetic_negate, verbose)
1532
- self._run_test(self.evaluator.test_arithmetic_asr, verbose)
1533
- self._run_test(self.evaluator.test_arithmetic_incrementer, verbose)
1534
- self._run_test(self.evaluator.test_arithmetic_decrementer, verbose)
1535
- self._run_test(self.evaluator.test_arithmetic_adc_internals, verbose)
1536
- self._run_test(self.evaluator.test_arithmetic_cmp_internals, verbose)
1537
- self._run_test(self.evaluator.test_arithmetic_sbc_internals, verbose)
1538
- self._run_test(self.evaluator.test_arithmetic_sub_internals, verbose)
1539
- self._run_test(self.evaluator.test_arithmetic_equality_internals, verbose)
1540
- self._run_test(self.evaluator.test_arithmetic_rol_ror, verbose)
1541
- self._run_test(self.evaluator.test_arithmetic_div_stages, verbose)
1542
- self._run_test(self.evaluator.test_arithmetic_div_outputs, verbose)
1543
- self._run_test(self.evaluator.test_arithmetic_multiplier_internals, verbose)
1544
- self._run_test(self.evaluator.test_arithmetic_ripple_internals, verbose)
1545
- self._run_test(self.evaluator.test_arithmetic_equality_final, verbose)
1546
- self._run_test(self.evaluator.test_arithmetic_small_multipliers, verbose)
1547
-
1548
- # Threshold gates
1549
- if verbose:
1550
- print("\n=== THRESHOLD GATES ===")
1551
- for result in self.evaluator.test_threshold_gates():
1552
- self.results.append(result)
1553
- if verbose:
1554
- self._print_result(result)
1555
- self._run_test(self.evaluator.test_threshold_atleastk_4, verbose)
1556
- self._run_test(self.evaluator.test_threshold_atmostk_4, verbose)
1557
- self._run_test(self.evaluator.test_threshold_exactlyk_4, verbose)
1558
- self._run_test(self.evaluator.test_threshold_majority, verbose)
1559
- self._run_test(self.evaluator.test_threshold_minority, verbose)
1560
-
1561
- # Modular arithmetic
1562
- if verbose:
1563
- print("\n=== MODULAR ARITHMETIC ===")
1564
- for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]:
1565
- if self.registry.has(f'modular.mod{mod}.weight') or \
1566
- self.registry.has(f'modular.mod{mod}.layer1.geq0.weight'):
1567
- self._run_test(lambda m=mod: self.evaluator.test_modular(m), verbose)
1568
-
1569
- # Combinational
1570
- if verbose:
1571
- print("\n=== COMBINATIONAL ===")
1572
- self._run_test(self.evaluator.test_decoder_3to8, verbose)
1573
- self._run_test(self.evaluator.test_encoder_8to3, verbose)
1574
- self._run_test(self.evaluator.test_mux_2to1, verbose)
1575
- self._run_test(self.evaluator.test_demux_1to2, verbose)
1576
- self._run_test(self.evaluator.test_barrel_shifter, verbose)
1577
- self._run_test(self.evaluator.test_mux_4to1, verbose)
1578
- self._run_test(self.evaluator.test_mux_8to1, verbose)
1579
- self._run_test(self.evaluator.test_demux_1to4, verbose)
1580
- self._run_test(self.evaluator.test_demux_1to8, verbose)
1581
- self._run_test(self.evaluator.test_priority_encoder, verbose)
1582
-
1583
- # Pattern recognition
1584
- if verbose:
1585
- print("\n=== PATTERN RECOGNITION ===")
1586
- self._run_test(self.evaluator.test_popcount, verbose)
1587
- self._run_test(self.evaluator.test_allzeros, verbose)
1588
- self._run_test(self.evaluator.test_allones, verbose)
1589
- self._run_test(self.evaluator.test_hamming_distance, verbose)
1590
- self._run_test(self.evaluator.test_one_hot_detector, verbose)
1591
- self._run_test(self.evaluator.test_alternating_pattern, verbose)
1592
- self._run_test(self.evaluator.test_symmetry_detector, verbose)
1593
- self._run_test(self.evaluator.test_leading_ones, verbose)
1594
- self._run_test(self.evaluator.test_run_length, verbose)
1595
- self._run_test(self.evaluator.test_trailing_ones, verbose)
1596
-
1597
- # Division
1598
- if verbose:
1599
- print("\n=== DIVISION ===")
1600
- self._run_test(self.evaluator.test_division_8bit, verbose)
1601
-
1602
- elapsed = time.time() - start
1603
-
1604
- # Summary
1605
- total_passed = sum(r.passed for r in self.results)
1606
- total_tests = sum(r.total for r in self.results)
1607
-
1608
- print("\n" + "=" * 60)
1609
- print("SUMMARY")
1610
- print("=" * 60)
1611
- print(f"Total: {total_passed}/{total_tests} ({100*total_passed/total_tests:.4f}%)")
1612
- print(f"Time: {elapsed:.2f}s")
1613
-
1614
- failed = [r for r in self.results if not r.success]
1615
- if failed:
1616
- print(f"\nFailed circuits ({len(failed)}):")
1617
- for r in failed:
1618
- print(f" {r.circuit_name}: {r.passed}/{r.total} ({100*r.rate:.2f}%)")
1619
- if r.failures:
1620
- print(f" First failure: input={r.failures[0][0]}, expected={r.failures[0][1]}, got={r.failures[0][2]}")
1621
- else:
1622
- print("\nAll circuits passed!")
1623
-
1624
- print("\n" + "=" * 60)
1625
- print(self.registry.coverage_report())
1626
-
1627
- return total_passed / total_tests if total_tests > 0 else 0.0
1628
-
1629
- def _run_test(self, test_fn: Callable, verbose: bool):
1630
- try:
1631
- result = test_fn()
1632
- self.results.append(result)
1633
- if verbose:
1634
- self._print_result(result)
1635
- except Exception as e:
1636
- print(f" ERROR: {e}")
1637
-
1638
- def _print_result(self, result: TestResult):
1639
- status = "PASS" if result.success else "FAIL"
1640
- print(f" {result.circuit_name}: {result.passed}/{result.total} [{status}]")
1641
- if not result.success and result.failures:
1642
- print(f" First failure: {result.failures[0]}")
1643
-
1644
-
1645
- def main():
1646
- import argparse
1647
- parser = argparse.ArgumentParser(description='Arithmetic circuit evaluator for threshold-calculus')
1648
- parser.add_argument('--model', type=str, default='./arithmetic.safetensors',
1649
- help='Path to safetensors model')
1650
- parser.add_argument('--device', type=str, default='cuda',
1651
- help='Device (cuda or cpu)')
1652
- parser.add_argument('--quiet', action='store_true',
1653
- help='Suppress verbose output')
1654
- args = parser.parse_args()
1655
-
1656
- evaluator = ArithmeticEvaluator(args.model, args.device)
1657
- fitness = evaluator.run_all(verbose=not args.quiet)
1658
-
1659
- print(f"\nFitness: {fitness:.6f}")
1660
- return 0 if fitness >= 0.9999 else 1
1661
-
1662
-
1663
- if __name__ == '__main__':
1664
- exit(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval.py CHANGED
The diff for this file is too large to render. See raw diff