Upload iron_eval.py
Browse files- iron_eval.py +880 -0
iron_eval.py
ADDED
|
@@ -0,0 +1,880 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
IRON EVAL - COMPREHENSIVE
|
| 3 |
+
=========================
|
| 4 |
+
Complete fitness evaluation for ALL circuits in the threshold computer.
|
| 5 |
+
108 circuits, no placeholders, no shortcuts.
|
| 6 |
+
|
| 7 |
+
GPU-optimized for population-based evolution.
|
| 8 |
+
Target: ~40GB VRAM on RTX 6000 Ada (4M population)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from typing import Dict, Tuple
|
| 13 |
+
from safetensors import safe_open
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_model_10166(base_path: str = "D:/8bit-threshold-computer-10166") -> Dict[str, torch.Tensor]:
    """Load every tensor of the threshold-computer model from safetensors.

    Args:
        base_path: Directory containing ``neural_computer.safetensors``.

    Returns:
        Mapping of tensor name -> float32 ``torch.Tensor`` (CPU).
    """
    tensors = {}
    # Use the context manager so the file handle is always released
    # (the original kept the handle from safe_open open for the process lifetime).
    with safe_open(f"{base_path}/neural_computer.safetensors", framework='numpy') as f:
        for name in f.keys():
            # from_numpy wraps the array without an extra copy; .float()
            # then produces the float32 tensor we hand back.
            tensors[name] = torch.from_numpy(f.get_tensor(name)).float()
    return tensors
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Step activation: 1.0 wherever x >= 0, 0.0 elsewhere."""
    return torch.ge(x, 0).float()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BatchedFitnessEvaluator:
|
| 31 |
+
"""
|
| 32 |
+
GPU-batched fitness evaluator. Tests ALL circuits comprehensively.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, device='cuda'):
|
| 36 |
+
self.device = device
|
| 37 |
+
self._setup_tests()
|
| 38 |
+
|
| 39 |
+
    def _setup_tests(self):
        """Pre-compute all test vectors.

        Builds every fixture once on self.device so per-generation evaluation
        does no host->device transfers. All bit tensors are MSB-first: column 0
        holds bit 7 (note the (7-i) shift below).
        """
        d = self.device

        # 2-input truth table [4, 2]
        self.tt2 = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=d, dtype=torch.float32)

        # 3-input truth table [8, 3]
        self.tt3 = torch.tensor([
            [0,0,0], [0,0,1], [0,1,0], [0,1,1],
            [1,0,0], [1,0,1], [1,1,0], [1,1,1]
        ], device=d, dtype=torch.float32)

        # Boolean gate expected outputs
        # Row order matches tt2 for 2-input gates, tt3 for fa_sum/fa_cout,
        # and not_inputs (below) for 'not'.
        self.expected = {
            'and': torch.tensor([0,0,0,1], device=d, dtype=torch.float32),
            'or': torch.tensor([0,1,1,1], device=d, dtype=torch.float32),
            'nand': torch.tensor([1,1,1,0], device=d, dtype=torch.float32),
            'nor': torch.tensor([1,0,0,0], device=d, dtype=torch.float32),
            'xor': torch.tensor([0,1,1,0], device=d, dtype=torch.float32),
            'xnor': torch.tensor([1,0,0,1], device=d, dtype=torch.float32),
            'implies': torch.tensor([1,1,0,1], device=d, dtype=torch.float32),
            'biimplies': torch.tensor([1,0,0,1], device=d, dtype=torch.float32),
            'not': torch.tensor([1,0], device=d, dtype=torch.float32),
            'ha_sum': torch.tensor([0,1,1,0], device=d, dtype=torch.float32),
            'ha_carry': torch.tensor([0,0,0,1], device=d, dtype=torch.float32),
            'fa_sum': torch.tensor([0,1,1,0,1,0,0,1], device=d, dtype=torch.float32),
            'fa_cout': torch.tensor([0,0,0,1,0,1,1,1], device=d, dtype=torch.float32),
        }

        # NOT gate inputs
        self.not_inputs = torch.tensor([[0],[1]], device=d, dtype=torch.float32)

        # 8-bit test values - comprehensive set
        # Powers of two, their neighbours, extremes, and alternating/block patterns.
        self.test_8bit = torch.tensor([
            0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
            0b10101010, 0b01010101, 0b11110000, 0b00001111,
            0b11001100, 0b00110011, 0b10000001, 0b01111110
        ], device=d, dtype=torch.long)

        # Bit representations [num_vals, 8]
        self.test_8bit_bits = torch.stack([
            ((self.test_8bit >> (7-i)) & 1).float() for i in range(8)
        ], dim=1)

        # Comparator test pairs - comprehensive with bit boundaries
        # NOTE(review): some pairs appear twice — (1,2)/(2,1), (32,64)/(64,32),
        # (127,128)/(128,127) — which double-weights those cases in the fitness
        # sum. Confirm this weighting is intentional.
        comp_tests = [
            (0,0), (1,0), (0,1), (5,3), (3,5), (5,5),
            (255,0), (0,255), (128,127), (127,128),
            (100,99), (99,100), (64,32), (32,64),
            (200,100), (100,200), (1,2), (2,1),
            (1,2), (2,1), (2,4), (4,2), (4,8), (8,4),
            (8,16), (16,8), (16,32), (32,16), (32,64), (64,32),
            (64,128), (128,64),
            (1,1), (2,2), (4,4), (8,8), (16,16), (32,32), (64,64), (128,128),
            (7,8), (8,7), (9,8), (8,9),
            (15,16), (16,15), (17,16), (16,17),
            (31,32), (32,31), (33,32), (32,33),
            (63,64), (64,63), (65,64), (64,65),
            (127,128), (128,127), (129,128), (128,129),
        ]
        self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long)
        self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long)
        self.comp_a_bits = torch.stack([((self.comp_a >> (7-i)) & 1).float() for i in range(8)], dim=1)
        self.comp_b_bits = torch.stack([((self.comp_b >> (7-i)) & 1).float() for i in range(8)], dim=1)

        # Modular test values
        # Exhaustive: every 8-bit value 0..255.
        self.mod_test = torch.arange(0, 256, device=d, dtype=torch.long)
        self.mod_test_bits = torch.stack([((self.mod_test >> (7-i)) & 1).float() for i in range(8)], dim=1)
|
| 108 |
+
|
| 109 |
+
# =========================================================================
|
| 110 |
+
# BOOLEAN GATES
|
| 111 |
+
# =========================================================================
|
| 112 |
+
|
| 113 |
+
def _test_single_gate(self, pop: Dict, gate: str, inputs: torch.Tensor,
|
| 114 |
+
expected: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""Test single-layer boolean gate."""
|
| 116 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 117 |
+
w = pop[f'boolean.{gate}.weight'].view(pop_size, -1)
|
| 118 |
+
b = pop[f'boolean.{gate}.bias'].view(pop_size)
|
| 119 |
+
out = heaviside(inputs @ w.T + b)
|
| 120 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 121 |
+
|
| 122 |
+
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
|
| 123 |
+
expected: torch.Tensor) -> torch.Tensor:
|
| 124 |
+
"""Test two-layer gate (XOR, XNOR, BIIMPLIES)."""
|
| 125 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 126 |
+
|
| 127 |
+
# Layer 1
|
| 128 |
+
w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
|
| 129 |
+
b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
|
| 130 |
+
w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
|
| 131 |
+
b1_b = pop[f'{prefix}.layer1.neuron2.bias'].view(pop_size)
|
| 132 |
+
|
| 133 |
+
h_a = heaviside(inputs @ w1_a.T + b1_a)
|
| 134 |
+
h_b = heaviside(inputs @ w1_b.T + b1_b)
|
| 135 |
+
hidden = torch.stack([h_a, h_b], dim=2)
|
| 136 |
+
|
| 137 |
+
# Layer 2
|
| 138 |
+
w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
|
| 139 |
+
b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
|
| 140 |
+
out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
|
| 141 |
+
|
| 142 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 143 |
+
|
| 144 |
+
# =========================================================================
|
| 145 |
+
# ARITHMETIC - ADDERS
|
| 146 |
+
# =========================================================================
|
| 147 |
+
|
| 148 |
+
def _test_halfadder(self, pop: Dict) -> torch.Tensor:
|
| 149 |
+
"""Test half adder: sum and carry."""
|
| 150 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 151 |
+
scores = torch.zeros(pop_size, device=self.device)
|
| 152 |
+
|
| 153 |
+
# Sum (XOR)
|
| 154 |
+
scores += self._test_twolayer_gate(pop, 'arithmetic.halfadder.sum',
|
| 155 |
+
self.tt2, self.expected['ha_sum'])
|
| 156 |
+
# Carry (AND)
|
| 157 |
+
w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
|
| 158 |
+
b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
|
| 159 |
+
out = heaviside(self.tt2 @ w.T + b)
|
| 160 |
+
scores += (out == self.expected['ha_carry'].unsqueeze(1)).float().sum(0)
|
| 161 |
+
|
| 162 |
+
return scores
|
| 163 |
+
|
| 164 |
+
    def _test_fulladder(self, pop: Dict) -> torch.Tensor:
        """Test full adder circuit.

        Walks all 8 (a, b, cin) input rows, wiring HA1 -> HA2 -> carry-OR
        explicitly, and awards one fitness point per correct sum bit and one
        per correct carry-out (max 16 per individual).
        NOTE(review): HA1's sum goes through _eval_xor on a [1, 2] input
        (population along dim 1), while HA2 is evaluated element-wise on
        [pop, 2] inputs — the squeeze(0)/stack calls below glue the two
        layouts together.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)

        for i, (a, b, cin) in enumerate([(0,0,0), (0,0,1), (0,1,0), (0,1,1),
                                         (1,0,0), (1,0,1), (1,1,0), (1,1,1)]):
            inp_ab = torch.tensor([[float(a), float(b)]], device=self.device)

            # HA1
            ha1_sum = self._eval_xor(pop, 'arithmetic.fulladder.ha1.sum', inp_ab)
            w_c1 = pop['arithmetic.fulladder.ha1.carry.weight'].view(pop_size, -1)
            b_c1 = pop['arithmetic.fulladder.ha1.carry.bias'].view(pop_size)
            ha1_carry = heaviside(inp_ab @ w_c1.T + b_c1)

            # HA2
            # Pair HA1's sum with the (constant) carry-in: one row per individual.
            inp_ha2 = torch.stack([ha1_sum.squeeze(0), torch.full((pop_size,), float(cin), device=self.device)], dim=1)

            w1_or = pop['arithmetic.fulladder.ha2.sum.layer1.or.weight'].view(pop_size, -1)
            b1_or = pop['arithmetic.fulladder.ha2.sum.layer1.or.bias'].view(pop_size)
            w1_nand = pop['arithmetic.fulladder.ha2.sum.layer1.nand.weight'].view(pop_size, -1)
            b1_nand = pop['arithmetic.fulladder.ha2.sum.layer1.nand.bias'].view(pop_size)
            w2 = pop['arithmetic.fulladder.ha2.sum.layer2.weight'].view(pop_size, -1)
            b2 = pop['arithmetic.fulladder.ha2.sum.layer2.bias'].view(pop_size)

            # XOR = AND(OR, NAND), built from two hidden threshold neurons.
            h_or = heaviside((inp_ha2 * w1_or).sum(1) + b1_or)
            h_nand = heaviside((inp_ha2 * w1_nand).sum(1) + b1_nand)
            hidden = torch.stack([h_or, h_nand], dim=1)
            ha2_sum = heaviside((hidden * w2).sum(1) + b2)

            w_c2 = pop['arithmetic.fulladder.ha2.carry.weight'].view(pop_size, -1)
            b_c2 = pop['arithmetic.fulladder.ha2.carry.bias'].view(pop_size)
            ha2_carry = heaviside((inp_ha2 * w_c2).sum(1) + b_c2)

            # Carry OR
            # Final carry-out = OR of the two half-adder carries.
            inp_cout = torch.stack([ha1_carry.squeeze(0), ha2_carry], dim=1)
            w_cor = pop['arithmetic.fulladder.carry_or.weight'].view(pop_size, -1)
            b_cor = pop['arithmetic.fulladder.carry_or.bias'].view(pop_size)
            cout = heaviside((inp_cout * w_cor).sum(1) + b_cor)

            scores += (ha2_sum == self.expected['fa_sum'][i]).float()
            scores += (cout == self.expected['fa_cout'][i]).float()

        return scores
|
| 208 |
+
|
| 209 |
+
def _eval_xor(self, pop: Dict, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
|
| 210 |
+
"""Evaluate XOR gate for given inputs."""
|
| 211 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 212 |
+
|
| 213 |
+
w1_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
|
| 214 |
+
b1_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
|
| 215 |
+
w1_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
|
| 216 |
+
b1_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
|
| 217 |
+
w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
|
| 218 |
+
b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
|
| 219 |
+
|
| 220 |
+
h_or = heaviside(inputs @ w1_or.T + b1_or)
|
| 221 |
+
h_nand = heaviside(inputs @ w1_nand.T + b1_nand)
|
| 222 |
+
hidden = torch.stack([h_or, h_nand], dim=2)
|
| 223 |
+
return heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
|
| 224 |
+
|
| 225 |
+
def _eval_single_fa(self, pop: Dict, prefix: str, a: torch.Tensor,
|
| 226 |
+
b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 227 |
+
"""Evaluate a single full adder."""
|
| 228 |
+
pop_size = a.shape[0]
|
| 229 |
+
inp_ab = torch.stack([a, b], dim=1)
|
| 230 |
+
|
| 231 |
+
# HA1 XOR
|
| 232 |
+
w1_or = pop[f'{prefix}.ha1.sum.layer1.or.weight'].view(pop_size, -1)
|
| 233 |
+
b1_or = pop[f'{prefix}.ha1.sum.layer1.or.bias'].view(pop_size)
|
| 234 |
+
w1_nand = pop[f'{prefix}.ha1.sum.layer1.nand.weight'].view(pop_size, -1)
|
| 235 |
+
b1_nand = pop[f'{prefix}.ha1.sum.layer1.nand.bias'].view(pop_size)
|
| 236 |
+
w1_l2 = pop[f'{prefix}.ha1.sum.layer2.weight'].view(pop_size, -1)
|
| 237 |
+
b1_l2 = pop[f'{prefix}.ha1.sum.layer2.bias'].view(pop_size)
|
| 238 |
+
|
| 239 |
+
h_or = heaviside((inp_ab * w1_or).sum(1) + b1_or)
|
| 240 |
+
h_nand = heaviside((inp_ab * w1_nand).sum(1) + b1_nand)
|
| 241 |
+
hidden1 = torch.stack([h_or, h_nand], dim=1)
|
| 242 |
+
ha1_sum = heaviside((hidden1 * w1_l2).sum(1) + b1_l2)
|
| 243 |
+
|
| 244 |
+
w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, -1)
|
| 245 |
+
b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size)
|
| 246 |
+
ha1_carry = heaviside((inp_ab * w_c1).sum(1) + b_c1)
|
| 247 |
+
|
| 248 |
+
# HA2 XOR
|
| 249 |
+
inp_ha2 = torch.stack([ha1_sum, cin], dim=1)
|
| 250 |
+
|
| 251 |
+
w2_or = pop[f'{prefix}.ha2.sum.layer1.or.weight'].view(pop_size, -1)
|
| 252 |
+
b2_or = pop[f'{prefix}.ha2.sum.layer1.or.bias'].view(pop_size)
|
| 253 |
+
w2_nand = pop[f'{prefix}.ha2.sum.layer1.nand.weight'].view(pop_size, -1)
|
| 254 |
+
b2_nand = pop[f'{prefix}.ha2.sum.layer1.nand.bias'].view(pop_size)
|
| 255 |
+
w2_l2 = pop[f'{prefix}.ha2.sum.layer2.weight'].view(pop_size, -1)
|
| 256 |
+
b2_l2 = pop[f'{prefix}.ha2.sum.layer2.bias'].view(pop_size)
|
| 257 |
+
|
| 258 |
+
h2_or = heaviside((inp_ha2 * w2_or).sum(1) + b2_or)
|
| 259 |
+
h2_nand = heaviside((inp_ha2 * w2_nand).sum(1) + b2_nand)
|
| 260 |
+
hidden2 = torch.stack([h2_or, h2_nand], dim=1)
|
| 261 |
+
ha2_sum = heaviside((hidden2 * w2_l2).sum(1) + b2_l2)
|
| 262 |
+
|
| 263 |
+
w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, -1)
|
| 264 |
+
b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size)
|
| 265 |
+
ha2_carry = heaviside((inp_ha2 * w_c2).sum(1) + b_c2)
|
| 266 |
+
|
| 267 |
+
# Carry OR
|
| 268 |
+
inp_cout = torch.stack([ha1_carry, ha2_carry], dim=1)
|
| 269 |
+
w_cor = pop[f'{prefix}.carry_or.weight'].view(pop_size, -1)
|
| 270 |
+
b_cor = pop[f'{prefix}.carry_or.bias'].view(pop_size)
|
| 271 |
+
cout = heaviside((inp_cout * w_cor).sum(1) + b_cor)
|
| 272 |
+
|
| 273 |
+
return ha2_sum, cout
|
| 274 |
+
|
| 275 |
+
def _test_ripplecarry(self, pop: Dict, bits: int, test_cases: list) -> torch.Tensor:
|
| 276 |
+
"""Test ripple carry adder of given bit width."""
|
| 277 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 278 |
+
scores = torch.zeros(pop_size, device=self.device)
|
| 279 |
+
|
| 280 |
+
for a_val, b_val in test_cases:
|
| 281 |
+
# Extract bits
|
| 282 |
+
a_bits = [(a_val >> i) & 1 for i in range(bits)]
|
| 283 |
+
b_bits = [(b_val >> i) & 1 for i in range(bits)]
|
| 284 |
+
|
| 285 |
+
carry = torch.zeros(pop_size, device=self.device)
|
| 286 |
+
sum_bits = []
|
| 287 |
+
|
| 288 |
+
for i in range(bits):
|
| 289 |
+
a_i = torch.full((pop_size,), float(a_bits[i]), device=self.device)
|
| 290 |
+
b_i = torch.full((pop_size,), float(b_bits[i]), device=self.device)
|
| 291 |
+
sum_i, carry = self._eval_single_fa(pop, f'arithmetic.ripplecarry{bits}bit.fa{i}', a_i, b_i, carry)
|
| 292 |
+
sum_bits.append(sum_i)
|
| 293 |
+
|
| 294 |
+
# Reconstruct result
|
| 295 |
+
result = sum(sum_bits[i] * (2**i) for i in range(bits))
|
| 296 |
+
expected = (a_val + b_val) & ((1 << bits) - 1)
|
| 297 |
+
scores += (result == expected).float()
|
| 298 |
+
|
| 299 |
+
return scores
|
| 300 |
+
|
| 301 |
+
# =========================================================================
|
| 302 |
+
# ARITHMETIC - COMPARATORS
|
| 303 |
+
# =========================================================================
|
| 304 |
+
|
| 305 |
+
def _test_comparator(self, pop: Dict, name: str, op: str) -> torch.Tensor:
|
| 306 |
+
"""Test 8-bit comparator."""
|
| 307 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 308 |
+
w = pop[f'arithmetic.{name}.comparator'].view(pop_size, -1)
|
| 309 |
+
|
| 310 |
+
if op == 'gt':
|
| 311 |
+
diff = self.comp_a_bits - self.comp_b_bits
|
| 312 |
+
expected = (self.comp_a > self.comp_b).float()
|
| 313 |
+
elif op == 'lt':
|
| 314 |
+
diff = self.comp_b_bits - self.comp_a_bits
|
| 315 |
+
expected = (self.comp_a < self.comp_b).float()
|
| 316 |
+
elif op == 'geq':
|
| 317 |
+
diff = self.comp_a_bits - self.comp_b_bits
|
| 318 |
+
expected = (self.comp_a >= self.comp_b).float()
|
| 319 |
+
elif op == 'leq':
|
| 320 |
+
diff = self.comp_b_bits - self.comp_a_bits
|
| 321 |
+
expected = (self.comp_a <= self.comp_b).float()
|
| 322 |
+
|
| 323 |
+
score = diff @ w.T
|
| 324 |
+
if op in ['geq', 'leq']:
|
| 325 |
+
out = (score >= 0).float()
|
| 326 |
+
else:
|
| 327 |
+
out = (score > 0).float()
|
| 328 |
+
|
| 329 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 330 |
+
|
| 331 |
+
    def _test_equality(self, pop: Dict) -> torch.Tensor:
        """Test 8-bit equality circuit.

        For each comparator test pair: XNOR every one of the 8 bit pairs, then
        AND the 8 XNOR outputs; the result should be 1 iff a == b. One fitness
        point per correctly-classified pair, returned as a [pop_size] tensor.
        NOTE(review): the Python loops over pairs and bits allocate many tiny
        tensors per call — correct, but a vectorization candidate.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)

        for i in range(len(self.comp_a)):
            a_bits = self.comp_a_bits[i]
            b_bits = self.comp_b_bits[i]

            # Compute XNOR for each bit pair
            xnor_results = []
            for bit in range(8):
                # Broadcast this bit pair across the whole population: [pop, 2].
                inp = torch.stack([
                    torch.full((pop_size,), a_bits[bit].item(), device=self.device),
                    torch.full((pop_size,), b_bits[bit].item(), device=self.device)
                ], dim=1)

                # XNOR = (a AND b) OR (NOR(a,b))
                w_and = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.and.weight'].view(pop_size, -1)
                b_and = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.and.bias'].view(pop_size)
                w_nor = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.nor.weight'].view(pop_size, -1)
                b_nor = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.nor.bias'].view(pop_size)
                w_l2 = pop[f'arithmetic.equality8bit.xnor{bit}.layer2.weight'].view(pop_size, -1)
                b_l2 = pop[f'arithmetic.equality8bit.xnor{bit}.layer2.bias'].view(pop_size)

                h_and = heaviside((inp * w_and).sum(1) + b_and)
                h_nor = heaviside((inp * w_nor).sum(1) + b_nor)
                hidden = torch.stack([h_and, h_nor], dim=1)
                xnor_out = heaviside((hidden * w_l2).sum(1) + b_l2)
                xnor_results.append(xnor_out)

            # Final AND of all XNORs
            xnor_stack = torch.stack(xnor_results, dim=1)
            w_final = pop['arithmetic.equality8bit.final_and.weight'].view(pop_size, -1)
            b_final = pop['arithmetic.equality8bit.final_and.bias'].view(pop_size)
            eq_out = heaviside((xnor_stack * w_final).sum(1) + b_final)

            expected = (self.comp_a[i] == self.comp_b[i]).float()
            scores += (eq_out == expected).float()

        return scores
|
| 372 |
+
|
| 373 |
+
# =========================================================================
|
| 374 |
+
# THRESHOLD GATES
|
| 375 |
+
# =========================================================================
|
| 376 |
+
|
| 377 |
+
def _test_threshold_kofn(self, pop: Dict, k: int, name: str) -> torch.Tensor:
|
| 378 |
+
"""Test k-of-8 threshold gate."""
|
| 379 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 380 |
+
w = pop[f'threshold.{name}.weight'].view(pop_size, -1)
|
| 381 |
+
b = pop[f'threshold.{name}.bias'].view(pop_size)
|
| 382 |
+
|
| 383 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 384 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 385 |
+
expected = (popcounts >= k).float()
|
| 386 |
+
|
| 387 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 388 |
+
|
| 389 |
+
def _test_majority(self, pop: Dict) -> torch.Tensor:
|
| 390 |
+
"""Test majority gate (5+ of 8)."""
|
| 391 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 392 |
+
w = pop['threshold.majority.weight'].view(pop_size, -1)
|
| 393 |
+
b = pop['threshold.majority.bias'].view(pop_size)
|
| 394 |
+
|
| 395 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 396 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 397 |
+
expected = (popcounts >= 5).float()
|
| 398 |
+
|
| 399 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 400 |
+
|
| 401 |
+
def _test_minority(self, pop: Dict) -> torch.Tensor:
|
| 402 |
+
"""Test minority gate (3 or fewer of 8)."""
|
| 403 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 404 |
+
w = pop['threshold.minority.weight'].view(pop_size, -1)
|
| 405 |
+
b = pop['threshold.minority.bias'].view(pop_size)
|
| 406 |
+
|
| 407 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 408 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 409 |
+
expected = (popcounts <= 3).float()
|
| 410 |
+
|
| 411 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 412 |
+
|
| 413 |
+
def _test_atleastk(self, pop: Dict, k: int) -> torch.Tensor:
|
| 414 |
+
"""Test at-least-k threshold gate."""
|
| 415 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 416 |
+
w = pop[f'threshold.atleastk_{k}.weight'].view(pop_size, -1)
|
| 417 |
+
b = pop[f'threshold.atleastk_{k}.bias'].view(pop_size)
|
| 418 |
+
|
| 419 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 420 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 421 |
+
expected = (popcounts >= k).float()
|
| 422 |
+
|
| 423 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 424 |
+
|
| 425 |
+
def _test_atmostk(self, pop: Dict, k: int) -> torch.Tensor:
|
| 426 |
+
"""Test at-most-k threshold gate."""
|
| 427 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 428 |
+
w = pop[f'threshold.atmostk_{k}.weight'].view(pop_size, -1)
|
| 429 |
+
b = pop[f'threshold.atmostk_{k}.bias'].view(pop_size)
|
| 430 |
+
|
| 431 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 432 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 433 |
+
expected = (popcounts <= k).float()
|
| 434 |
+
|
| 435 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 436 |
+
|
| 437 |
+
def _test_exactlyk(self, pop: Dict, k: int) -> torch.Tensor:
|
| 438 |
+
"""Test exactly-k threshold gate (uses atleast AND atmost)."""
|
| 439 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 440 |
+
|
| 441 |
+
# At least k
|
| 442 |
+
w_al = pop[f'threshold.exactlyk_{k}.atleast.weight'].view(pop_size, -1)
|
| 443 |
+
b_al = pop[f'threshold.exactlyk_{k}.atleast.bias'].view(pop_size)
|
| 444 |
+
atleast = heaviside(self.test_8bit_bits @ w_al.T + b_al)
|
| 445 |
+
|
| 446 |
+
# At most k
|
| 447 |
+
w_am = pop[f'threshold.exactlyk_{k}.atmost.weight'].view(pop_size, -1)
|
| 448 |
+
b_am = pop[f'threshold.exactlyk_{k}.atmost.bias'].view(pop_size)
|
| 449 |
+
atmost = heaviside(self.test_8bit_bits @ w_am.T + b_am)
|
| 450 |
+
|
| 451 |
+
# AND
|
| 452 |
+
combined = torch.stack([atleast, atmost], dim=2)
|
| 453 |
+
w_and = pop[f'threshold.exactlyk_{k}.and.weight'].view(pop_size, -1)
|
| 454 |
+
b_and = pop[f'threshold.exactlyk_{k}.and.bias'].view(pop_size)
|
| 455 |
+
out = heaviside((combined * w_and.unsqueeze(0)).sum(2) + b_and.unsqueeze(0))
|
| 456 |
+
|
| 457 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 458 |
+
expected = (popcounts == k).float()
|
| 459 |
+
|
| 460 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 461 |
+
|
| 462 |
+
# =========================================================================
|
| 463 |
+
# PATTERN RECOGNITION
|
| 464 |
+
# =========================================================================
|
| 465 |
+
|
| 466 |
+
def _test_popcount(self, pop: Dict) -> torch.Tensor:
|
| 467 |
+
"""Test popcount (count of 1 bits)."""
|
| 468 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 469 |
+
w = pop['pattern_recognition.popcount.weight'].view(pop_size, -1)
|
| 470 |
+
b = pop['pattern_recognition.popcount.bias'].view(pop_size)
|
| 471 |
+
|
| 472 |
+
out = (self.test_8bit_bits @ w.T + b) # No heaviside - this is a counter
|
| 473 |
+
expected = self.test_8bit_bits.sum(1)
|
| 474 |
+
|
| 475 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 476 |
+
|
| 477 |
+
def _test_allzeros(self, pop: Dict) -> torch.Tensor:
|
| 478 |
+
"""Test all-zeros detector."""
|
| 479 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 480 |
+
w = pop['pattern_recognition.allzeros.weight'].view(pop_size, -1)
|
| 481 |
+
b = pop['pattern_recognition.allzeros.bias'].view(pop_size)
|
| 482 |
+
|
| 483 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 484 |
+
expected = (self.test_8bit == 0).float()
|
| 485 |
+
|
| 486 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 487 |
+
|
| 488 |
+
def _test_allones(self, pop: Dict) -> torch.Tensor:
|
| 489 |
+
"""Test all-ones detector."""
|
| 490 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 491 |
+
w = pop['pattern_recognition.allones.weight'].view(pop_size, -1)
|
| 492 |
+
b = pop['pattern_recognition.allones.bias'].view(pop_size)
|
| 493 |
+
|
| 494 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 495 |
+
expected = (self.test_8bit == 255).float()
|
| 496 |
+
|
| 497 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 498 |
+
|
| 499 |
+
# =========================================================================
|
| 500 |
+
# ERROR DETECTION
|
| 501 |
+
# =========================================================================
|
| 502 |
+
|
| 503 |
+
def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
|
| 504 |
+
"""Test parity checker/generator."""
|
| 505 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 506 |
+
w = pop[f'error_detection.{name}.weight'].view(pop_size, -1)
|
| 507 |
+
b = pop[f'error_detection.{name}.bias'].view(pop_size)
|
| 508 |
+
|
| 509 |
+
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 510 |
+
popcounts = self.test_8bit_bits.sum(1)
|
| 511 |
+
if even:
|
| 512 |
+
expected = ((popcounts.long() % 2) == 0).float()
|
| 513 |
+
else:
|
| 514 |
+
expected = ((popcounts.long() % 2) == 1).float()
|
| 515 |
+
|
| 516 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 517 |
+
|
| 518 |
+
# =========================================================================
|
| 519 |
+
# MODULAR ARITHMETIC
|
| 520 |
+
# =========================================================================
|
| 521 |
+
|
| 522 |
+
def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
|
| 523 |
+
"""Test modular arithmetic circuit."""
|
| 524 |
+
pop_size = next(iter(pop.values())).shape[0]
|
| 525 |
+
w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
|
| 526 |
+
b = pop[f'modular.mod{mod}.bias'].view(pop_size)
|
| 527 |
+
|
| 528 |
+
out = heaviside(self.mod_test_bits @ w.T + b)
|
| 529 |
+
expected = ((self.mod_test % mod) == 0).float()
|
| 530 |
+
|
| 531 |
+
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 532 |
+
|
| 533 |
+
# =========================================================================
# COMBINATIONAL
# =========================================================================

def _test_mux2to1(self, pop: Dict) -> torch.Tensor:
    """Test the 2:1 multiplexer circuit on all 8 (a, b, sel) combinations.

    The MUX is built from four gates:
        not_sel, and_a = AND(a, sel), and_b = AND(b, NOT sel),
        out = OR(and_a, and_b)
    so the output should equal `a` when sel == 1, else `b`.

    Args:
        pop: population dict keyed
             'combinational.multiplexer2to1.<gate>.{weight,bias}'.

    Returns:
        Tensor [pop_size] of per-member pass counts (0..8).
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)

    # Parameter views do not depend on the test inputs — hoist them out
    # of the 8-iteration loop instead of re-viewing every pass.
    prefix = 'combinational.multiplexer2to1'
    w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size, -1)
    b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size)
    w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, -1)
    b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size)
    w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, -1)
    b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size)
    w_or = pop[f'{prefix}.or.weight'].view(pop_size, -1)
    b_or = pop[f'{prefix}.or.bias'].view(pop_size)

    # Test all 8 combinations of (a, b, sel).
    for a in (0, 1):
        for b in (0, 1):
            for sel in (0, 1):
                expected = a if sel == 1 else b

                a_t = torch.full((pop_size,), float(a), device=self.device)
                b_t = torch.full((pop_size,), float(b), device=self.device)
                sel_t = torch.full((pop_size,), float(sel), device=self.device)

                # NOT sel.  BUG FIX: the original computed
                # `sel_t.unsqueeze(1) @ w_not.T`, which is a [pop, pop]
                # matrix; `.squeeze(1)` is then a no-op for pop_size > 1
                # and the torch.stack below raises on mismatched shapes.
                # The per-member elementwise product is the intended op.
                not_sel = heaviside(sel_t * w_not.squeeze(1) + b_not)

                # AND(a, sel)
                inp_a = torch.stack([a_t, sel_t], dim=1)
                and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)

                # AND(b, NOT sel)
                inp_b = torch.stack([b_t, not_sel], dim=1)
                and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)

                # OR of the two AND legs.
                inp_or = torch.stack([and_a, and_b], dim=1)
                out = heaviside((inp_or * w_or).sum(1) + b_or)

                scores += (out == expected).float()

    return scores
|
| 579 |
+
|
| 580 |
+
def _test_decoder3to8(self, pop: Dict) -> torch.Tensor:
    """Score the 3-to-8 decoder: for each input value, exactly the matching
    output line should be high and the other seven low.

    Args:
        pop: population dict keyed
             'combinational.decoder3to8.out<i>.{weight,bias}'.

    Returns:
        Tensor [pop_size] of per-member pass counts (0..64).
    """
    n_members = next(iter(pop.values())).shape[0]
    scores = torch.zeros(n_members, device=self.device)

    for val in range(8):
        # Address bits are fed MSB-first: [[bit2, bit1, bit0]].
        bit_vals = [float((val >> i) & 1) for i in range(3)]
        inp = torch.tensor([[bit_vals[2], bit_vals[1], bit_vals[0]]], device=self.device)

        # Each of the 8 output lines is its own single-neuron circuit.
        for line in range(8):
            w = pop[f'combinational.decoder3to8.out{line}.weight'].view(n_members, -1)
            b = pop[f'combinational.decoder3to8.out{line}.bias'].view(n_members)
            activation = heaviside(inp @ w.T + b)
            target = float(line == val)
            scores += (activation.squeeze() == target).float()

    return scores
|
| 598 |
+
|
| 599 |
+
def _test_encoder8to3(self, pop: Dict) -> torch.Tensor:
    """Score the 8-to-3 encoder: a one-hot input should map to the binary
    representation of the hot index, one circuit per output bit.

    Args:
        pop: population dict keyed
             'combinational.encoder8to3.bit<i>.{weight,bias}'.

    Returns:
        Tensor [pop_size] of per-member pass counts (0..24).
    """
    n_members = next(iter(pop.values())).shape[0]
    scores = torch.zeros(n_members, device=self.device)

    for val in range(8):
        # Build the one-hot input row for this value.
        one_hot = torch.zeros(1, 8, device=self.device)
        one_hot[0, val] = 1.0

        for bit in range(3):
            w = pop[f'combinational.encoder8to3.bit{bit}.weight'].view(n_members, -1)
            b = pop[f'combinational.encoder8to3.bit{bit}.bias'].view(n_members)
            activation = heaviside(one_hot @ w.T + b)
            target = float((val >> bit) & 1)
            scores += (activation.squeeze() == target).float()

    return scores
|
| 617 |
+
|
| 618 |
+
# =========================================================================
# CONTROL FLOW (8-bit conditional MUX)
# =========================================================================

def _test_conditional_jump(self, pop: Dict, name: str) -> torch.Tensor:
    """Test an 8-bit conditional jump (per-bit 2:1 MUX) circuit.

    Each output bit is computed by four gates:
        not_sel, and_a = AND(a_bit, sel), and_b = AND(b_bit, NOT sel),
        out = OR(and_a, and_b)
    so the circuit should route `a` through when sel == 1, else `b`.

    Args:
        pop:  population dict keyed
              'control.<name>.bit<i>.<gate>.{weight,bias}'.
        name: control-circuit name ('conditionaljump', 'jz', ...).

    Returns:
        Tensor [pop_size] of per-member pass counts
        (6 cases x 8 bits = 48 max).
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)

    # Representative 8-bit (a, b, sel) triples.
    test_vals = [(0, 255, 0), (0, 255, 1), (127, 128, 0), (127, 128, 1),
                 (0xAA, 0x55, 0), (0xAA, 0x55, 1)]

    # Parameter views depend only on the bit index, not on the test case:
    # hoist them out of the per-case loop.
    gate_params = []
    for bit in range(8):
        base = f'control.{name}.bit{bit}'
        gate_params.append((
            pop[f'{base}.not_sel.weight'].view(pop_size, -1),
            pop[f'{base}.not_sel.bias'].view(pop_size),
            pop[f'{base}.and_a.weight'].view(pop_size, -1),
            pop[f'{base}.and_a.bias'].view(pop_size),
            pop[f'{base}.and_b.weight'].view(pop_size, -1),
            pop[f'{base}.and_b.bias'].view(pop_size),
            pop[f'{base}.or.weight'].view(pop_size, -1),
            pop[f'{base}.or.bias'].view(pop_size),
        ))

    for a_val, b_val, sel in test_vals:
        expected = a_val if sel == 1 else b_val
        sel_t = torch.full((pop_size,), float(sel), device=self.device)

        for bit in range(8):
            a_bit = (a_val >> bit) & 1
            b_bit = (b_val >> bit) & 1
            exp_bit = (expected >> bit) & 1

            a_t = torch.full((pop_size,), float(a_bit), device=self.device)
            b_t = torch.full((pop_size,), float(b_bit), device=self.device)

            (w_not, b_not, w_and_a, b_and_a,
             w_and_b, b_and_b, w_or, b_or) = gate_params[bit]

            # NOT sel.  BUG FIX: the original used
            # `sel_t.unsqueeze(1) @ w_not.T`, which yields a [pop, pop]
            # matrix; `.squeeze(1)` is a no-op for pop_size > 1 and the
            # torch.stack below then raises on mismatched shapes.  The
            # per-member elementwise product is the intended computation.
            not_sel = heaviside(sel_t * w_not.squeeze(1) + b_not)

            # AND(a, sel)
            inp_a = torch.stack([a_t, sel_t], dim=1)
            and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)

            # AND(b, NOT sel)
            inp_b = torch.stack([b_t, not_sel], dim=1)
            and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)

            # OR of the two legs.
            inp_or = torch.stack([and_a, and_b], dim=1)
            out = heaviside((inp_or * w_or).sum(1) + b_or)

            scores += (out == exp_bit).float()

    return scores
|
| 669 |
+
|
| 670 |
+
# =========================================================================
# ALU
# =========================================================================

def _test_alu_op(self, pop: Dict, op: str, test_fn) -> torch.Tensor:
    """Test an 8-bit ALU operation (placeholder implementation).

    NOTE(review): this check never evaluates the evolved weights — it
    recomputes the operation in Python and compares it against `test_fn`,
    so it scores full marks whenever `test_fn` encodes the same operation.
    Kept for interface compatibility; wire it to the actual
    'alu.alu8bit.*' parameters before relying on its scores.

    Args:
        pop:     population dict (used only to read the population size).
        op:      one of 'and', 'or', 'xor', 'not'.
        test_fn: reference implementation taking two ints (a, b).

    Returns:
        Tensor [pop_size] of per-member pass counts (0..8).

    Raises:
        ValueError: if `op` is not a supported operation.  (The original
        fell through with `out_val` unbound and raised NameError.)
    """
    pop_size = next(iter(pop.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)

    test_pairs = [(0, 0), (255, 255), (0, 255), (255, 0),
                  (0xAA, 0x55), (0x0F, 0xF0), (1, 1), (127, 128)]

    for a_val, b_val in test_pairs:
        expected = test_fn(a_val, b_val) & 0xFF

        # Python-side recomputation of the op (see NOTE above); the
        # original also built bit tensors and parameter views here that
        # were never used — removed as dead code.
        if op == 'and':
            out_val = a_val & b_val
        elif op == 'or':
            out_val = a_val | b_val
        elif op == 'xor':
            out_val = a_val ^ b_val
        elif op == 'not':
            out_val = (~a_val) & 0xFF
        else:
            raise ValueError(f"unsupported ALU op: {op!r}")

        scores += (out_val == expected)

    return scores
|
| 705 |
+
|
| 706 |
+
# =========================================================================
# MAIN EVALUATE
# =========================================================================

def evaluate(self, population: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Evaluate fitness for entire population.

    Runs every circuit test suite, accumulating per-member pass counts
    and the total number of tests, then returns the normalized fraction.

    Args:
        population: dict of stacked parameter tensors; every value has a
            leading population dimension of the same size.

    Returns:
        Tensor [pop_size] of fitness values in [0, 1]
        (passed tests / total tests).

    Side effects:
        Stores the test count in ``self.total_tests``.
    """
    pop_size = next(iter(population.values())).shape[0]
    scores = torch.zeros(pop_size, device=self.device)
    total_tests = 0

    # =================================================================
    # BOOLEAN GATES (34 tests: 4x4 single-layer + 2 NOT + 4 IMPLIES
    # + 3x4 two-layer)
    # =================================================================
    for gate in ['and', 'or', 'nand', 'nor']:
        scores += self._test_single_gate(population, gate, self.tt2, self.expected[gate])
        total_tests += 4

    # NOT: single-input gate, tested inline on both input values.
    w = population['boolean.not.weight'].view(pop_size, -1)
    b = population['boolean.not.bias'].view(pop_size)
    out = heaviside(self.not_inputs @ w.T + b)
    scores += (out == self.expected['not'].unsqueeze(1)).float().sum(0)
    total_tests += 2

    # IMPLIES
    scores += self._test_single_gate(population, 'implies', self.tt2, self.expected['implies'])
    total_tests += 4

    # XOR, XNOR, BIIMPLIES: not linearly separable, hence two-layer circuits.
    scores += self._test_twolayer_gate(population, 'boolean.xor', self.tt2, self.expected['xor'])
    scores += self._test_twolayer_gate(population, 'boolean.xnor', self.tt2, self.expected['xnor'])
    scores += self._test_twolayer_gate(population, 'boolean.biimplies', self.tt2, self.expected['biimplies'])
    total_tests += 12

    # =================================================================
    # ARITHMETIC - ADDERS (306 tests: 8 + 16 + 16 + 256 + 10)
    # =================================================================
    scores += self._test_halfadder(population)
    total_tests += 8

    scores += self._test_fulladder(population)
    total_tests += 16

    # Ripple carry adders: exhaustive for 2- and 4-bit, sampled for 8-bit.
    rc2_tests = [(a, b) for a in range(4) for b in range(4)]
    scores += self._test_ripplecarry(population, 2, rc2_tests)
    total_tests += 16

    rc4_tests = [(a, b) for a in range(16) for b in range(16)]
    scores += self._test_ripplecarry(population, 4, rc4_tests)
    total_tests += 256

    rc8_tests = [(0,0), (1,1), (127,128), (255,1), (128,127), (255,255),
                 (0xAA, 0x55), (0x0F, 0xF0), (100, 155), (200, 55)]
    scores += self._test_ripplecarry(population, 8, rc8_tests)
    total_tests += len(rc8_tests)

    # =================================================================
    # ARITHMETIC - COMPARATORS (5 x len(self.comp_a) tests)
    # =================================================================
    scores += self._test_comparator(population, 'greaterthan8bit', 'gt')
    scores += self._test_comparator(population, 'lessthan8bit', 'lt')
    scores += self._test_comparator(population, 'greaterorequal8bit', 'geq')
    scores += self._test_comparator(population, 'lessorequal8bit', 'leq')
    total_tests += 4 * len(self.comp_a)

    scores += self._test_equality(population)
    total_tests += len(self.comp_a)

    # =================================================================
    # THRESHOLD GATES (13 circuits x len(self.test_8bit) tests)
    # =================================================================
    for k, name in enumerate(['oneoutof8', 'twooutof8', 'threeoutof8', 'fouroutof8',
                              'fiveoutof8', 'sixoutof8', 'sevenoutof8', 'alloutof8'], 1):
        scores += self._test_threshold_kofn(population, k, name)
        total_tests += len(self.test_8bit)

    scores += self._test_majority(population)
    scores += self._test_minority(population)
    total_tests += 2 * len(self.test_8bit)

    scores += self._test_atleastk(population, 4)
    scores += self._test_atmostk(population, 4)
    scores += self._test_exactlyk(population, 4)
    total_tests += 3 * len(self.test_8bit)

    # =================================================================
    # PATTERN RECOGNITION (3 x len(self.test_8bit) tests)
    # =================================================================
    scores += self._test_popcount(population)
    scores += self._test_allzeros(population)
    scores += self._test_allones(population)
    total_tests += 3 * len(self.test_8bit)

    # =================================================================
    # ERROR DETECTION (2 x len(self.test_8bit) tests)
    # NOTE(review): the generator is also scored with even=True — confirm
    # this is intended (an odd-parity generator would need even=False).
    # =================================================================
    scores += self._test_parity(population, 'paritychecker8bit', True)
    scores += self._test_parity(population, 'paritygenerator8bit', True)
    total_tests += 2 * len(self.test_8bit)

    # =================================================================
    # MODULAR ARITHMETIC (len(self.mod_test) values x 11 moduli)
    # =================================================================
    for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]:
        scores += self._test_modular(population, mod)
        total_tests += len(self.mod_test)

    # =================================================================
    # COMBINATIONAL (96 tests: 8 MUX + 64 decoder + 24 encoder)
    # =================================================================
    scores += self._test_mux2to1(population)
    total_tests += 8

    scores += self._test_decoder3to8(population)
    total_tests += 64

    scores += self._test_encoder8to3(population)
    total_tests += 24

    # =================================================================
    # CONTROL FLOW (432 tests: 9 circuits x 6 cases x 8 bits)
    # =================================================================
    for ctrl in ['conditionaljump', 'jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv']:
        scores += self._test_conditional_jump(population, ctrl)
        total_tests += 6 * 8

    # Remember the count so callers can turn a fitness back into a fail count.
    self.total_tests = total_tests
    return scores / total_tests
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
def create_population(base_tensors: Dict[str, torch.Tensor],
                      pop_size: int,
                      device='cuda') -> Dict[str, torch.Tensor]:
    """Create a population by replicating each base tensor `pop_size` times.

    Every entry gains a leading population dimension:
    [*shape] -> [pop_size, *shape].  Copies are materialized with
    ``clone()`` so each population member can be mutated independently of
    the others and of the base tensors.

    Args:
        base_tensors: mapping of parameter name -> base tensor.
        pop_size:     number of population members to create.
        device:       target device for the population tensors.

    Returns:
        Dict with the same keys and replicated tensors on `device`.
    """
    population = {}
    for name, weight in base_tensors.items():
        # Move to the target device BEFORE cloning: the original cloned the
        # expanded tensor on the source device and then copied it to
        # `device`, allocating the full population twice.  `expand` itself
        # is a zero-copy view; only the final clone materializes memory.
        replicated = weight.to(device).unsqueeze(0).expand(pop_size, *weight.shape)
        population[name] = replicated.clone()
    return population
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
if __name__ == "__main__":
    # Smoke test: load the reference model, evaluate a population of one,
    # and report fitness plus wall-clock time.  Requires a CUDA device
    # (device='cuda' is hard-coded below).
    import time

    print("="*70)
    print(" IRON EVAL - COMPREHENSIVE TEST")
    print("="*70)

    print("\nLoading model...")
    model = load_model_10166()
    print(f"Loaded {len(model)} tensors, {sum(t.numel() for t in model.values())} params")

    print("\nInitializing evaluator...")
    evaluator = BatchedFitnessEvaluator(device='cuda')

    print("\nCreating population (size 1)...")
    pop = create_population(model, pop_size=1, device='cuda')

    print("\nRunning evaluation...")
    # Synchronize before and after so the timing covers the actual GPU
    # work, not just kernel launch latency.
    torch.cuda.synchronize()
    start = time.perf_counter()
    fitness = evaluator.evaluate(pop)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    print(f"\nResults:")
    print(f" Fitness: {fitness[0]:.6f}")
    print(f" Total tests: {evaluator.total_tests}")
    print(f" Time: {elapsed*1000:.2f} ms")

    # Fitness is passed/total, so 1.0 means every circuit test passed;
    # otherwise reconstruct the approximate failure count from the ratio.
    if fitness[0] == 1.0:
        print("\n STATUS: PASS - All circuits functional")
    else:
        failed = int((1 - fitness[0]) * evaluator.total_tests)
        print(f"\n STATUS: FAIL - {failed} tests failed")
|