phanerozoic commited on
Commit
2a16a08
·
verified ·
1 Parent(s): 80624ac

Upload iron_eval.py

Browse files
Files changed (1) hide show
  1. iron_eval.py +880 -0
iron_eval.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IRON EVAL - COMPREHENSIVE
3
+ =========================
4
+ Complete fitness evaluation for ALL circuits in the threshold computer.
5
+ 108 circuits, no placeholders, no shortcuts.
6
+
7
+ GPU-optimized for population-based evolution.
8
+ Target: ~40GB VRAM on RTX 6000 Ada (4M population)
9
+ """
10
+
11
+ import torch
12
+ from typing import Dict, Tuple
13
+ from safetensors import safe_open
14
+
15
+
16
def load_model_10166(base_path: str = "D:/8bit-threshold-computer-10166") -> Dict[str, torch.Tensor]:
    """Load the threshold-computer weights from a safetensors checkpoint.

    Args:
        base_path: Directory containing ``neural_computer.safetensors``.

    Returns:
        Mapping of tensor name -> float32 ``torch.Tensor`` (CPU).
    """
    tensors: Dict[str, torch.Tensor] = {}
    # Use the handle as a context manager so the file is closed even if a
    # read raises (the original left the handle open).
    with safe_open(f"{base_path}/neural_computer.safetensors", framework='numpy') as f:
        for name in f.keys():
            # torch.tensor() copies the numpy buffer; .float() normalizes dtype.
            tensors[name] = torch.tensor(f.get_tensor(name)).float()
    return tensors
23
+
24
+
25
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Unit-step activation: 1.0 where x >= 0, 0.0 elsewhere."""
    return x.ge(0).to(torch.float32)
28
+
29
+
30
class BatchedFitnessEvaluator:
    """
    GPU-batched fitness evaluator. Tests ALL circuits comprehensively.
    """

    def __init__(self, device: str = 'cuda'):
        # Device string used for every tensor this evaluator allocates
        # (test batteries, expected outputs, score accumulators).
        self.device = device
        # Precompute all fixed test inputs / expected outputs once, up front.
        self._setup_tests()
38
+
39
    def _setup_tests(self):
        """Pre-compute all test vectors.

        Builds every fixed input battery and expected-output tensor on
        ``self.device``. All multi-bit values are decomposed MSB-first:
        column 0 of a ``*_bits`` tensor is bit 7 of the integer.
        """
        d = self.device

        # 2-input truth table [4, 2]
        self.tt2 = torch.tensor([[0,0],[0,1],[1,0],[1,1]], device=d, dtype=torch.float32)

        # 3-input truth table [8, 3]
        # NOTE(review): tt3 is not referenced by any method visible in this
        # file chunk — presumably used elsewhere (e.g. full-adder variants);
        # confirm before removing.
        self.tt3 = torch.tensor([
            [0,0,0], [0,0,1], [0,1,0], [0,1,1],
            [1,0,0], [1,0,1], [1,1,0], [1,1,1]
        ], device=d, dtype=torch.float32)

        # Boolean gate expected outputs, indexed row-for-row against tt2
        # (or tt3 for the full-adder entries, not_inputs for 'not').
        self.expected = {
            'and': torch.tensor([0,0,0,1], device=d, dtype=torch.float32),
            'or': torch.tensor([0,1,1,1], device=d, dtype=torch.float32),
            'nand': torch.tensor([1,1,1,0], device=d, dtype=torch.float32),
            'nor': torch.tensor([1,0,0,0], device=d, dtype=torch.float32),
            'xor': torch.tensor([0,1,1,0], device=d, dtype=torch.float32),
            'xnor': torch.tensor([1,0,0,1], device=d, dtype=torch.float32),
            'implies': torch.tensor([1,1,0,1], device=d, dtype=torch.float32),
            'biimplies': torch.tensor([1,0,0,1], device=d, dtype=torch.float32),
            'not': torch.tensor([1,0], device=d, dtype=torch.float32),
            'ha_sum': torch.tensor([0,1,1,0], device=d, dtype=torch.float32),
            'ha_carry': torch.tensor([0,0,0,1], device=d, dtype=torch.float32),
            'fa_sum': torch.tensor([0,1,1,0,1,0,0,1], device=d, dtype=torch.float32),
            'fa_cout': torch.tensor([0,0,0,1,0,1,1,1], device=d, dtype=torch.float32),
        }

        # NOT gate inputs: single-bit column vectors [2, 1].
        self.not_inputs = torch.tensor([[0],[1]], device=d, dtype=torch.float32)

        # 8-bit test values - comprehensive set: powers of two and their
        # neighbors, plus alternating/block bit patterns.
        self.test_8bit = torch.tensor([
            0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
            0b10101010, 0b01010101, 0b11110000, 0b00001111,
            0b11001100, 0b00110011, 0b10000001, 0b01111110
        ], device=d, dtype=torch.long)

        # Bit representations [num_vals, 8], MSB first (column 0 = bit 7).
        self.test_8bit_bits = torch.stack([
            ((self.test_8bit >> (7-i)) & 1).float() for i in range(8)
        ], dim=1)

        # Comparator test pairs - comprehensive with bit boundaries
        # NOTE(review): (1,2)/(2,1) and (32,64)/(64,32) each appear twice,
        # and (127,128)/(128,127) recur at the end — those pairs carry double
        # weight in comparator scoring. Confirm this is intentional before
        # deduplicating (removing them changes the fitness scale).
        comp_tests = [
            (0,0), (1,0), (0,1), (5,3), (3,5), (5,5),
            (255,0), (0,255), (128,127), (127,128),
            (100,99), (99,100), (64,32), (32,64),
            (200,100), (100,200), (1,2), (2,1),
            (1,2), (2,1), (2,4), (4,2), (4,8), (8,4),
            (8,16), (16,8), (16,32), (32,16), (32,64), (64,32),
            (64,128), (128,64),
            (1,1), (2,2), (4,4), (8,8), (16,16), (32,32), (64,64), (128,128),
            (7,8), (8,7), (9,8), (8,9),
            (15,16), (16,15), (17,16), (16,17),
            (31,32), (32,31), (33,32), (32,33),
            (63,64), (64,63), (65,64), (64,65),
            (127,128), (128,127), (129,128), (128,129),
        ]
        self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long)
        self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long)
        # MSB-first bit decompositions of each operand, [num_pairs, 8].
        self.comp_a_bits = torch.stack([((self.comp_a >> (7-i)) & 1).float() for i in range(8)], dim=1)
        self.comp_b_bits = torch.stack([((self.comp_b >> (7-i)) & 1).float() for i in range(8)], dim=1)

        # Modular test values: exhaustive 0..255 sweep.
        self.mod_test = torch.arange(0, 256, device=d, dtype=torch.long)
        self.mod_test_bits = torch.stack([((self.mod_test >> (7-i)) & 1).float() for i in range(8)], dim=1)
108
+
109
+ # =========================================================================
110
+ # BOOLEAN GATES
111
+ # =========================================================================
112
+
113
+ def _test_single_gate(self, pop: Dict, gate: str, inputs: torch.Tensor,
114
+ expected: torch.Tensor) -> torch.Tensor:
115
+ """Test single-layer boolean gate."""
116
+ pop_size = next(iter(pop.values())).shape[0]
117
+ w = pop[f'boolean.{gate}.weight'].view(pop_size, -1)
118
+ b = pop[f'boolean.{gate}.bias'].view(pop_size)
119
+ out = heaviside(inputs @ w.T + b)
120
+ return (out == expected.unsqueeze(1)).float().sum(0)
121
+
122
+ def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
123
+ expected: torch.Tensor) -> torch.Tensor:
124
+ """Test two-layer gate (XOR, XNOR, BIIMPLIES)."""
125
+ pop_size = next(iter(pop.values())).shape[0]
126
+
127
+ # Layer 1
128
+ w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
129
+ b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
130
+ w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
131
+ b1_b = pop[f'{prefix}.layer1.neuron2.bias'].view(pop_size)
132
+
133
+ h_a = heaviside(inputs @ w1_a.T + b1_a)
134
+ h_b = heaviside(inputs @ w1_b.T + b1_b)
135
+ hidden = torch.stack([h_a, h_b], dim=2)
136
+
137
+ # Layer 2
138
+ w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
139
+ b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
140
+ out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
141
+
142
+ return (out == expected.unsqueeze(1)).float().sum(0)
143
+
144
+ # =========================================================================
145
+ # ARITHMETIC - ADDERS
146
+ # =========================================================================
147
+
148
+ def _test_halfadder(self, pop: Dict) -> torch.Tensor:
149
+ """Test half adder: sum and carry."""
150
+ pop_size = next(iter(pop.values())).shape[0]
151
+ scores = torch.zeros(pop_size, device=self.device)
152
+
153
+ # Sum (XOR)
154
+ scores += self._test_twolayer_gate(pop, 'arithmetic.halfadder.sum',
155
+ self.tt2, self.expected['ha_sum'])
156
+ # Carry (AND)
157
+ w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
158
+ b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
159
+ out = heaviside(self.tt2 @ w.T + b)
160
+ scores += (out == self.expected['ha_carry'].unsqueeze(1)).float().sum(0)
161
+
162
+ return scores
163
+
164
    def _test_fulladder(self, pop: Dict) -> torch.Tensor:
        """Test full adder circuit.

        Walks the full 3-input truth table one row at a time, wiring two
        half adders (HA1 on a,b; HA2 on HA1's sum and cin) plus a final OR
        of the two carries. Scores 1 point per correct sum bit and 1 per
        correct carry-out, so the maximum is 16 per individual.

        NOTE(review): the standalone fulladder uses 'layer1.or'/'layer1.nand'
        hidden-neuron names for its HA2 XOR, while the generic two-layer
        tester uses 'neuron1'/'neuron2' — presumably matching the checkpoint
        naming per circuit; confirm against the model file.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)

        for i, (a, b, cin) in enumerate([(0,0,0), (0,0,1), (0,1,0), (0,1,1),
                                         (1,0,0), (1,0,1), (1,1,0), (1,1,1)]):
            # Row input for HA1, shape [1, 2].
            inp_ab = torch.tensor([[float(a), float(b)]], device=self.device)

            # HA1: XOR (via _eval_xor, returns [1, pop]) and AND carry.
            ha1_sum = self._eval_xor(pop, 'arithmetic.fulladder.ha1.sum', inp_ab)
            w_c1 = pop['arithmetic.fulladder.ha1.carry.weight'].view(pop_size, -1)
            b_c1 = pop['arithmetic.fulladder.ha1.carry.bias'].view(pop_size)
            ha1_carry = heaviside(inp_ab @ w_c1.T + b_c1)

            # HA2: input is [pop, 2] — per-individual HA1 sum paired with cin.
            inp_ha2 = torch.stack([ha1_sum.squeeze(0), torch.full((pop_size,), float(cin), device=self.device)], dim=1)

            w1_or = pop['arithmetic.fulladder.ha2.sum.layer1.or.weight'].view(pop_size, -1)
            b1_or = pop['arithmetic.fulladder.ha2.sum.layer1.or.bias'].view(pop_size)
            w1_nand = pop['arithmetic.fulladder.ha2.sum.layer1.nand.weight'].view(pop_size, -1)
            b1_nand = pop['arithmetic.fulladder.ha2.sum.layer1.nand.bias'].view(pop_size)
            w2 = pop['arithmetic.fulladder.ha2.sum.layer2.weight'].view(pop_size, -1)
            b2 = pop['arithmetic.fulladder.ha2.sum.layer2.bias'].view(pop_size)

            # Elementwise per-individual evaluation (inputs differ per
            # individual here, so no shared matmul like HA1).
            h_or = heaviside((inp_ha2 * w1_or).sum(1) + b1_or)
            h_nand = heaviside((inp_ha2 * w1_nand).sum(1) + b1_nand)
            hidden = torch.stack([h_or, h_nand], dim=1)
            ha2_sum = heaviside((hidden * w2).sum(1) + b2)

            w_c2 = pop['arithmetic.fulladder.ha2.carry.weight'].view(pop_size, -1)
            b_c2 = pop['arithmetic.fulladder.ha2.carry.bias'].view(pop_size)
            ha2_carry = heaviside((inp_ha2 * w_c2).sum(1) + b_c2)

            # Carry OR: cout = ha1_carry OR ha2_carry.
            inp_cout = torch.stack([ha1_carry.squeeze(0), ha2_carry], dim=1)
            w_cor = pop['arithmetic.fulladder.carry_or.weight'].view(pop_size, -1)
            b_cor = pop['arithmetic.fulladder.carry_or.bias'].view(pop_size)
            cout = heaviside((inp_cout * w_cor).sum(1) + b_cor)

            # One point each for a correct sum bit and a correct carry-out.
            scores += (ha2_sum == self.expected['fa_sum'][i]).float()
            scores += (cout == self.expected['fa_cout'][i]).float()

        return scores
208
+
209
+ def _eval_xor(self, pop: Dict, prefix: str, inputs: torch.Tensor) -> torch.Tensor:
210
+ """Evaluate XOR gate for given inputs."""
211
+ pop_size = next(iter(pop.values())).shape[0]
212
+
213
+ w1_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
214
+ b1_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
215
+ w1_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
216
+ b1_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
217
+ w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
218
+ b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
219
+
220
+ h_or = heaviside(inputs @ w1_or.T + b1_or)
221
+ h_nand = heaviside(inputs @ w1_nand.T + b1_nand)
222
+ hidden = torch.stack([h_or, h_nand], dim=2)
223
+ return heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
224
+
225
    def _eval_single_fa(self, pop: Dict, prefix: str, a: torch.Tensor,
                        b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate a single full adder.

        All three inputs are per-individual bit vectors of shape [pop_size],
        so every neuron is evaluated elementwise (input * weight, summed)
        rather than with a shared matmul.

        Args:
            pop: Population tensors keyed by parameter name.
            prefix: Parameter-key prefix of this full-adder instance.
            a, b, cin: [pop_size] float bit vectors.

        Returns:
            (sum, carry_out), each a [pop_size] float bit vector.
        """
        pop_size = a.shape[0]
        inp_ab = torch.stack([a, b], dim=1)

        # HA1 XOR: OR/NAND hidden layer feeding a second-layer neuron.
        w1_or = pop[f'{prefix}.ha1.sum.layer1.or.weight'].view(pop_size, -1)
        b1_or = pop[f'{prefix}.ha1.sum.layer1.or.bias'].view(pop_size)
        w1_nand = pop[f'{prefix}.ha1.sum.layer1.nand.weight'].view(pop_size, -1)
        b1_nand = pop[f'{prefix}.ha1.sum.layer1.nand.bias'].view(pop_size)
        w1_l2 = pop[f'{prefix}.ha1.sum.layer2.weight'].view(pop_size, -1)
        b1_l2 = pop[f'{prefix}.ha1.sum.layer2.bias'].view(pop_size)

        h_or = heaviside((inp_ab * w1_or).sum(1) + b1_or)
        h_nand = heaviside((inp_ab * w1_nand).sum(1) + b1_nand)
        hidden1 = torch.stack([h_or, h_nand], dim=1)
        ha1_sum = heaviside((hidden1 * w1_l2).sum(1) + b1_l2)

        # HA1 carry: single AND neuron on (a, b).
        w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, -1)
        b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size)
        ha1_carry = heaviside((inp_ab * w_c1).sum(1) + b_c1)

        # HA2 XOR: pairs HA1's sum with the incoming carry.
        inp_ha2 = torch.stack([ha1_sum, cin], dim=1)

        w2_or = pop[f'{prefix}.ha2.sum.layer1.or.weight'].view(pop_size, -1)
        b2_or = pop[f'{prefix}.ha2.sum.layer1.or.bias'].view(pop_size)
        w2_nand = pop[f'{prefix}.ha2.sum.layer1.nand.weight'].view(pop_size, -1)
        b2_nand = pop[f'{prefix}.ha2.sum.layer1.nand.bias'].view(pop_size)
        w2_l2 = pop[f'{prefix}.ha2.sum.layer2.weight'].view(pop_size, -1)
        b2_l2 = pop[f'{prefix}.ha2.sum.layer2.bias'].view(pop_size)

        h2_or = heaviside((inp_ha2 * w2_or).sum(1) + b2_or)
        h2_nand = heaviside((inp_ha2 * w2_nand).sum(1) + b2_nand)
        hidden2 = torch.stack([h2_or, h2_nand], dim=1)
        ha2_sum = heaviside((hidden2 * w2_l2).sum(1) + b2_l2)

        # HA2 carry: single AND neuron on (ha1_sum, cin).
        w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, -1)
        b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size)
        ha2_carry = heaviside((inp_ha2 * w_c2).sum(1) + b_c2)

        # Carry OR: cout = ha1_carry OR ha2_carry.
        inp_cout = torch.stack([ha1_carry, ha2_carry], dim=1)
        w_cor = pop[f'{prefix}.carry_or.weight'].view(pop_size, -1)
        b_cor = pop[f'{prefix}.carry_or.bias'].view(pop_size)
        cout = heaviside((inp_cout * w_cor).sum(1) + b_cor)

        return ha2_sum, cout
274
+
275
+ def _test_ripplecarry(self, pop: Dict, bits: int, test_cases: list) -> torch.Tensor:
276
+ """Test ripple carry adder of given bit width."""
277
+ pop_size = next(iter(pop.values())).shape[0]
278
+ scores = torch.zeros(pop_size, device=self.device)
279
+
280
+ for a_val, b_val in test_cases:
281
+ # Extract bits
282
+ a_bits = [(a_val >> i) & 1 for i in range(bits)]
283
+ b_bits = [(b_val >> i) & 1 for i in range(bits)]
284
+
285
+ carry = torch.zeros(pop_size, device=self.device)
286
+ sum_bits = []
287
+
288
+ for i in range(bits):
289
+ a_i = torch.full((pop_size,), float(a_bits[i]), device=self.device)
290
+ b_i = torch.full((pop_size,), float(b_bits[i]), device=self.device)
291
+ sum_i, carry = self._eval_single_fa(pop, f'arithmetic.ripplecarry{bits}bit.fa{i}', a_i, b_i, carry)
292
+ sum_bits.append(sum_i)
293
+
294
+ # Reconstruct result
295
+ result = sum(sum_bits[i] * (2**i) for i in range(bits))
296
+ expected = (a_val + b_val) & ((1 << bits) - 1)
297
+ scores += (result == expected).float()
298
+
299
+ return scores
300
+
301
+ # =========================================================================
302
+ # ARITHMETIC - COMPARATORS
303
+ # =========================================================================
304
+
305
+ def _test_comparator(self, pop: Dict, name: str, op: str) -> torch.Tensor:
306
+ """Test 8-bit comparator."""
307
+ pop_size = next(iter(pop.values())).shape[0]
308
+ w = pop[f'arithmetic.{name}.comparator'].view(pop_size, -1)
309
+
310
+ if op == 'gt':
311
+ diff = self.comp_a_bits - self.comp_b_bits
312
+ expected = (self.comp_a > self.comp_b).float()
313
+ elif op == 'lt':
314
+ diff = self.comp_b_bits - self.comp_a_bits
315
+ expected = (self.comp_a < self.comp_b).float()
316
+ elif op == 'geq':
317
+ diff = self.comp_a_bits - self.comp_b_bits
318
+ expected = (self.comp_a >= self.comp_b).float()
319
+ elif op == 'leq':
320
+ diff = self.comp_b_bits - self.comp_a_bits
321
+ expected = (self.comp_a <= self.comp_b).float()
322
+
323
+ score = diff @ w.T
324
+ if op in ['geq', 'leq']:
325
+ out = (score >= 0).float()
326
+ else:
327
+ out = (score > 0).float()
328
+
329
+ return (out == expected.unsqueeze(1)).float().sum(0)
330
+
331
    def _test_equality(self, pop: Dict) -> torch.Tensor:
        """Test 8-bit equality circuit.

        For every comparator test pair, runs eight two-layer XNOR sub-circuits
        (one per bit position) and ANDs their outputs with a final 8-input
        neuron. Awards one point per pair whose equality verdict is correct.

        NOTE(review): this loops over pairs x 8 bits in Python — O(pairs * 8)
        small GPU launches per call. It could be batched across pairs if this
        shows up in profiling.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)

        for i in range(len(self.comp_a)):
            a_bits = self.comp_a_bits[i]
            b_bits = self.comp_b_bits[i]

            # Compute XNOR for each bit pair
            xnor_results = []
            for bit in range(8):
                # [pop, 2] input: the same (a_bit, b_bit) broadcast to every
                # individual, since weights differ per individual.
                inp = torch.stack([
                    torch.full((pop_size,), a_bits[bit].item(), device=self.device),
                    torch.full((pop_size,), b_bits[bit].item(), device=self.device)
                ], dim=1)

                # XNOR = (a AND b) OR (NOR(a,b))
                w_and = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.and.weight'].view(pop_size, -1)
                b_and = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.and.bias'].view(pop_size)
                w_nor = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.nor.weight'].view(pop_size, -1)
                b_nor = pop[f'arithmetic.equality8bit.xnor{bit}.layer1.nor.bias'].view(pop_size)
                w_l2 = pop[f'arithmetic.equality8bit.xnor{bit}.layer2.weight'].view(pop_size, -1)
                b_l2 = pop[f'arithmetic.equality8bit.xnor{bit}.layer2.bias'].view(pop_size)

                h_and = heaviside((inp * w_and).sum(1) + b_and)
                h_nor = heaviside((inp * w_nor).sum(1) + b_nor)
                hidden = torch.stack([h_and, h_nor], dim=1)
                xnor_out = heaviside((hidden * w_l2).sum(1) + b_l2)
                xnor_results.append(xnor_out)

            # Final AND of all XNORs: equal iff every bit pair matched.
            xnor_stack = torch.stack(xnor_results, dim=1)
            w_final = pop['arithmetic.equality8bit.final_and.weight'].view(pop_size, -1)
            b_final = pop['arithmetic.equality8bit.final_and.bias'].view(pop_size)
            eq_out = heaviside((xnor_stack * w_final).sum(1) + b_final)

            expected = (self.comp_a[i] == self.comp_b[i]).float()
            scores += (eq_out == expected).float()

        return scores
372
+
373
+ # =========================================================================
374
+ # THRESHOLD GATES
375
+ # =========================================================================
376
+
377
+ def _test_threshold_kofn(self, pop: Dict, k: int, name: str) -> torch.Tensor:
378
+ """Test k-of-8 threshold gate."""
379
+ pop_size = next(iter(pop.values())).shape[0]
380
+ w = pop[f'threshold.{name}.weight'].view(pop_size, -1)
381
+ b = pop[f'threshold.{name}.bias'].view(pop_size)
382
+
383
+ out = heaviside(self.test_8bit_bits @ w.T + b)
384
+ popcounts = self.test_8bit_bits.sum(1)
385
+ expected = (popcounts >= k).float()
386
+
387
+ return (out == expected.unsqueeze(1)).float().sum(0)
388
+
389
+ def _test_majority(self, pop: Dict) -> torch.Tensor:
390
+ """Test majority gate (5+ of 8)."""
391
+ pop_size = next(iter(pop.values())).shape[0]
392
+ w = pop['threshold.majority.weight'].view(pop_size, -1)
393
+ b = pop['threshold.majority.bias'].view(pop_size)
394
+
395
+ out = heaviside(self.test_8bit_bits @ w.T + b)
396
+ popcounts = self.test_8bit_bits.sum(1)
397
+ expected = (popcounts >= 5).float()
398
+
399
+ return (out == expected.unsqueeze(1)).float().sum(0)
400
+
401
+ def _test_minority(self, pop: Dict) -> torch.Tensor:
402
+ """Test minority gate (3 or fewer of 8)."""
403
+ pop_size = next(iter(pop.values())).shape[0]
404
+ w = pop['threshold.minority.weight'].view(pop_size, -1)
405
+ b = pop['threshold.minority.bias'].view(pop_size)
406
+
407
+ out = heaviside(self.test_8bit_bits @ w.T + b)
408
+ popcounts = self.test_8bit_bits.sum(1)
409
+ expected = (popcounts <= 3).float()
410
+
411
+ return (out == expected.unsqueeze(1)).float().sum(0)
412
+
413
+ def _test_atleastk(self, pop: Dict, k: int) -> torch.Tensor:
414
+ """Test at-least-k threshold gate."""
415
+ pop_size = next(iter(pop.values())).shape[0]
416
+ w = pop[f'threshold.atleastk_{k}.weight'].view(pop_size, -1)
417
+ b = pop[f'threshold.atleastk_{k}.bias'].view(pop_size)
418
+
419
+ out = heaviside(self.test_8bit_bits @ w.T + b)
420
+ popcounts = self.test_8bit_bits.sum(1)
421
+ expected = (popcounts >= k).float()
422
+
423
+ return (out == expected.unsqueeze(1)).float().sum(0)
424
+
425
+ def _test_atmostk(self, pop: Dict, k: int) -> torch.Tensor:
426
+ """Test at-most-k threshold gate."""
427
+ pop_size = next(iter(pop.values())).shape[0]
428
+ w = pop[f'threshold.atmostk_{k}.weight'].view(pop_size, -1)
429
+ b = pop[f'threshold.atmostk_{k}.bias'].view(pop_size)
430
+
431
+ out = heaviside(self.test_8bit_bits @ w.T + b)
432
+ popcounts = self.test_8bit_bits.sum(1)
433
+ expected = (popcounts <= k).float()
434
+
435
+ return (out == expected.unsqueeze(1)).float().sum(0)
436
+
437
+ def _test_exactlyk(self, pop: Dict, k: int) -> torch.Tensor:
438
+ """Test exactly-k threshold gate (uses atleast AND atmost)."""
439
+ pop_size = next(iter(pop.values())).shape[0]
440
+
441
+ # At least k
442
+ w_al = pop[f'threshold.exactlyk_{k}.atleast.weight'].view(pop_size, -1)
443
+ b_al = pop[f'threshold.exactlyk_{k}.atleast.bias'].view(pop_size)
444
+ atleast = heaviside(self.test_8bit_bits @ w_al.T + b_al)
445
+
446
+ # At most k
447
+ w_am = pop[f'threshold.exactlyk_{k}.atmost.weight'].view(pop_size, -1)
448
+ b_am = pop[f'threshold.exactlyk_{k}.atmost.bias'].view(pop_size)
449
+ atmost = heaviside(self.test_8bit_bits @ w_am.T + b_am)
450
+
451
+ # AND
452
+ combined = torch.stack([atleast, atmost], dim=2)
453
+ w_and = pop[f'threshold.exactlyk_{k}.and.weight'].view(pop_size, -1)
454
+ b_and = pop[f'threshold.exactlyk_{k}.and.bias'].view(pop_size)
455
+ out = heaviside((combined * w_and.unsqueeze(0)).sum(2) + b_and.unsqueeze(0))
456
+
457
+ popcounts = self.test_8bit_bits.sum(1)
458
+ expected = (popcounts == k).float()
459
+
460
+ return (out == expected.unsqueeze(1)).float().sum(0)
461
+
462
+ # =========================================================================
463
+ # PATTERN RECOGNITION
464
+ # =========================================================================
465
+
466
+ def _test_popcount(self, pop: Dict) -> torch.Tensor:
467
+ """Test popcount (count of 1 bits)."""
468
+ pop_size = next(iter(pop.values())).shape[0]
469
+ w = pop['pattern_recognition.popcount.weight'].view(pop_size, -1)
470
+ b = pop['pattern_recognition.popcount.bias'].view(pop_size)
471
+
472
+ out = (self.test_8bit_bits @ w.T + b) # No heaviside - this is a counter
473
+ expected = self.test_8bit_bits.sum(1)
474
+
475
+ return (out == expected.unsqueeze(1)).float().sum(0)
476
+
477
+ def _test_allzeros(self, pop: Dict) -> torch.Tensor:
478
+ """Test all-zeros detector."""
479
+ pop_size = next(iter(pop.values())).shape[0]
480
+ w = pop['pattern_recognition.allzeros.weight'].view(pop_size, -1)
481
+ b = pop['pattern_recognition.allzeros.bias'].view(pop_size)
482
+
483
+ out = heaviside(self.test_8bit_bits @ w.T + b)
484
+ expected = (self.test_8bit == 0).float()
485
+
486
+ return (out == expected.unsqueeze(1)).float().sum(0)
487
+
488
+ def _test_allones(self, pop: Dict) -> torch.Tensor:
489
+ """Test all-ones detector."""
490
+ pop_size = next(iter(pop.values())).shape[0]
491
+ w = pop['pattern_recognition.allones.weight'].view(pop_size, -1)
492
+ b = pop['pattern_recognition.allones.bias'].view(pop_size)
493
+
494
+ out = heaviside(self.test_8bit_bits @ w.T + b)
495
+ expected = (self.test_8bit == 255).float()
496
+
497
+ return (out == expected.unsqueeze(1)).float().sum(0)
498
+
499
+ # =========================================================================
500
+ # ERROR DETECTION
501
+ # =========================================================================
502
+
503
+ def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
504
+ """Test parity checker/generator."""
505
+ pop_size = next(iter(pop.values())).shape[0]
506
+ w = pop[f'error_detection.{name}.weight'].view(pop_size, -1)
507
+ b = pop[f'error_detection.{name}.bias'].view(pop_size)
508
+
509
+ out = heaviside(self.test_8bit_bits @ w.T + b)
510
+ popcounts = self.test_8bit_bits.sum(1)
511
+ if even:
512
+ expected = ((popcounts.long() % 2) == 0).float()
513
+ else:
514
+ expected = ((popcounts.long() % 2) == 1).float()
515
+
516
+ return (out == expected.unsqueeze(1)).float().sum(0)
517
+
518
+ # =========================================================================
519
+ # MODULAR ARITHMETIC
520
+ # =========================================================================
521
+
522
+ def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
523
+ """Test modular arithmetic circuit."""
524
+ pop_size = next(iter(pop.values())).shape[0]
525
+ w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
526
+ b = pop[f'modular.mod{mod}.bias'].view(pop_size)
527
+
528
+ out = heaviside(self.mod_test_bits @ w.T + b)
529
+ expected = ((self.mod_test % mod) == 0).float()
530
+
531
+ return (out == expected.unsqueeze(1)).float().sum(0)
532
+
533
+ # =========================================================================
534
+ # COMBINATIONAL
535
+ # =========================================================================
536
+
537
+ def _test_mux2to1(self, pop: Dict) -> torch.Tensor:
538
+ """Test 2:1 multiplexer."""
539
+ pop_size = next(iter(pop.values())).shape[0]
540
+ scores = torch.zeros(pop_size, device=self.device)
541
+
542
+ # Test all 8 combinations of (a, b, sel)
543
+ for a in [0, 1]:
544
+ for b in [0, 1]:
545
+ for sel in [0, 1]:
546
+ expected = a if sel == 1 else b
547
+
548
+ # MUX uses: and_a, and_b, not_sel, or
549
+ a_t = torch.full((pop_size,), float(a), device=self.device)
550
+ b_t = torch.full((pop_size,), float(b), device=self.device)
551
+ sel_t = torch.full((pop_size,), float(sel), device=self.device)
552
+
553
+ # NOT sel
554
+ w_not = pop['combinational.multiplexer2to1.not_sel.weight'].view(pop_size, -1)
555
+ b_not = pop['combinational.multiplexer2to1.not_sel.bias'].view(pop_size)
556
+ not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
557
+
558
+ # AND(a, sel)
559
+ inp_a = torch.stack([a_t, sel_t], dim=1)
560
+ w_and_a = pop['combinational.multiplexer2to1.and_a.weight'].view(pop_size, -1)
561
+ b_and_a = pop['combinational.multiplexer2to1.and_a.bias'].view(pop_size)
562
+ and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
563
+
564
+ # AND(b, not_sel)
565
+ inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
566
+ w_and_b = pop['combinational.multiplexer2to1.and_b.weight'].view(pop_size, -1)
567
+ b_and_b = pop['combinational.multiplexer2to1.and_b.bias'].view(pop_size)
568
+ and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
569
+
570
+ # OR
571
+ inp_or = torch.stack([and_a, and_b], dim=1)
572
+ w_or = pop['combinational.multiplexer2to1.or.weight'].view(pop_size, -1)
573
+ b_or = pop['combinational.multiplexer2to1.or.bias'].view(pop_size)
574
+ out = heaviside((inp_or * w_or).sum(1) + b_or)
575
+
576
+ scores += (out == expected).float()
577
+
578
+ return scores
579
+
580
+ def _test_decoder3to8(self, pop: Dict) -> torch.Tensor:
581
+ """Test 3-to-8 decoder."""
582
+ pop_size = next(iter(pop.values())).shape[0]
583
+ scores = torch.zeros(pop_size, device=self.device)
584
+
585
+ for val in range(8):
586
+ bits = [(val >> i) & 1 for i in range(3)]
587
+ inp = torch.tensor([[float(bits[2]), float(bits[1]), float(bits[0])]], device=self.device)
588
+
589
+ # Test each output
590
+ for out_idx in range(8):
591
+ w = pop[f'combinational.decoder3to8.out{out_idx}.weight'].view(pop_size, -1)
592
+ b = pop[f'combinational.decoder3to8.out{out_idx}.bias'].view(pop_size)
593
+ out = heaviside(inp @ w.T + b)
594
+ expected = 1.0 if out_idx == val else 0.0
595
+ scores += (out.squeeze() == expected).float()
596
+
597
+ return scores
598
+
599
+ def _test_encoder8to3(self, pop: Dict) -> torch.Tensor:
600
+ """Test 8-to-3 encoder (one-hot to binary)."""
601
+ pop_size = next(iter(pop.values())).shape[0]
602
+ scores = torch.zeros(pop_size, device=self.device)
603
+
604
+ for val in range(8):
605
+ # One-hot input
606
+ inp = torch.zeros(1, 8, device=self.device)
607
+ inp[0, val] = 1.0
608
+
609
+ for bit in range(3):
610
+ w = pop[f'combinational.encoder8to3.bit{bit}.weight'].view(pop_size, -1)
611
+ b = pop[f'combinational.encoder8to3.bit{bit}.bias'].view(pop_size)
612
+ out = heaviside(inp @ w.T + b)
613
+ expected = float((val >> bit) & 1)
614
+ scores += (out.squeeze() == expected).float()
615
+
616
+ return scores
617
+
618
+ # =========================================================================
619
+ # CONTROL FLOW (8-bit conditional MUX)
620
+ # =========================================================================
621
+
622
+ def _test_conditional_jump(self, pop: Dict, name: str) -> torch.Tensor:
623
+ """Test 8-bit conditional jump (MUX) circuit."""
624
+ pop_size = next(iter(pop.values())).shape[0]
625
+ scores = torch.zeros(pop_size, device=self.device)
626
+
627
+ # Test with a few representative 8-bit value pairs and conditions
628
+ test_vals = [(0, 255, 0), (0, 255, 1), (127, 128, 0), (127, 128, 1),
629
+ (0xAA, 0x55, 0), (0xAA, 0x55, 1)]
630
+
631
+ for a_val, b_val, sel in test_vals:
632
+ expected = a_val if sel == 1 else b_val
633
+
634
+ for bit in range(8):
635
+ a_bit = (a_val >> bit) & 1
636
+ b_bit = (b_val >> bit) & 1
637
+ exp_bit = (expected >> bit) & 1
638
+
639
+ a_t = torch.full((pop_size,), float(a_bit), device=self.device)
640
+ b_t = torch.full((pop_size,), float(b_bit), device=self.device)
641
+ sel_t = torch.full((pop_size,), float(sel), device=self.device)
642
+
643
+ # NOT sel
644
+ w_not = pop[f'control.{name}.bit{bit}.not_sel.weight'].view(pop_size, -1)
645
+ b_not = pop[f'control.{name}.bit{bit}.not_sel.bias'].view(pop_size)
646
+ not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
647
+
648
+ # AND(a, sel)
649
+ inp_a = torch.stack([a_t, sel_t], dim=1)
650
+ w_and_a = pop[f'control.{name}.bit{bit}.and_a.weight'].view(pop_size, -1)
651
+ b_and_a = pop[f'control.{name}.bit{bit}.and_a.bias'].view(pop_size)
652
+ and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
653
+
654
+ # AND(b, not_sel)
655
+ inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
656
+ w_and_b = pop[f'control.{name}.bit{bit}.and_b.weight'].view(pop_size, -1)
657
+ b_and_b = pop[f'control.{name}.bit{bit}.and_b.bias'].view(pop_size)
658
+ and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
659
+
660
+ # OR
661
+ inp_or = torch.stack([and_a, and_b], dim=1)
662
+ w_or = pop[f'control.{name}.bit{bit}.or.weight'].view(pop_size, -1)
663
+ b_or = pop[f'control.{name}.bit{bit}.or.bias'].view(pop_size)
664
+ out = heaviside((inp_or * w_or).sum(1) + b_or)
665
+
666
+ scores += (out == exp_bit).float()
667
+
668
+ return scores
669
+
670
+ # =========================================================================
671
+ # ALU
672
+ # =========================================================================
673
+
674
    def _test_alu_op(self, pop: Dict, op: str, test_fn) -> torch.Tensor:
        """Test an 8-bit ALU operation.

        WARNING(review): this is effectively a placeholder, despite the
        module's "no placeholders" claim. ``out_val`` is computed in pure
        Python from the operand values and compared against ``expected``,
        which is ``test_fn(a_val, b_val) & 0xFF`` — when ``test_fn`` encodes
        the same operation, every comparison is trivially true and EVERY
        individual receives the maximum score regardless of its weights.
        The population tensors fetched below are never used in the result.
        Implementing the real circuit requires the actual ALU parameter
        layout, which is not visible here.

        Additional issues to fix alongside the real implementation:
        - 'not' only uses ``a_val``, but both operands are iterated;
        - an unrecognized ``op`` leaves ``out_val`` unbound (NameError);
        - ``scores += (out_val == expected)`` adds a Python bool (0/1) to a
          float tensor — works, but only by accident of broadcasting.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)

        test_pairs = [(0, 0), (255, 255), (0, 255), (255, 0),
                      (0xAA, 0x55), (0x0F, 0xF0), (1, 1), (127, 128)]

        for a_val, b_val in test_pairs:
            expected = test_fn(a_val, b_val) & 0xFF

            # MSB-first bit decompositions (currently unused by the scoring).
            a_bits = torch.tensor([(a_val >> (7-i)) & 1 for i in range(8)], device=self.device, dtype=torch.float32)
            b_bits = torch.tensor([(b_val >> (7-i)) & 1 for i in range(8)], device=self.device, dtype=torch.float32)

            if op == 'and':
                # Dead code: inp/w/b are fetched but never evaluated.
                inp = torch.stack([a_bits, b_bits], dim=0).T.unsqueeze(0)  # [1, 8, 2]
                w = pop['alu.alu8bit.and.weight'].view(pop_size, -1)  # [pop, 16]
                b = pop['alu.alu8bit.and.bias'].view(pop_size, -1)  # [pop, 8]
                # This needs proper reshaping based on actual circuit structure
                # Simplified: check if result bits match
                out_val = a_val & b_val
            elif op == 'or':
                out_val = a_val | b_val
            elif op == 'xor':
                out_val = a_val ^ b_val
            elif op == 'not':
                out_val = (~a_val) & 0xFF

            scores += (out_val == expected)

        return scores
705
+
706
+ # =========================================================================
707
+ # MAIN EVALUATE
708
+ # =========================================================================
709
+
710
    def evaluate(self, population: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate fitness for the entire population.

        Runs every circuit test suite against the batched population
        tensors and returns, per individual, the fraction of tests
        passed.

        Args:
            population: parameter tensors keyed by circuit name; every
                tensor's leading dimension is the population size.

        Returns:
            Float tensor of shape ``[pop_size]`` with values in
            ``[0, 1]`` (tests passed / total tests).

        Side effects:
            Stores the grand total in ``self.total_tests`` so callers
            can convert the normalized fitness back to raw counts.
        """
        pop_size = next(iter(population.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total_tests = 0

        # =================================================================
        # BOOLEAN GATES (34 tests: 4 gates x 4 rows + NOT x 2 + IMPLIES x 4
        # + three two-layer gates x 4)
        # =================================================================
        for gate in ['and', 'or', 'nand', 'nor']:
            scores += self._test_single_gate(population, gate, self.tt2, self.expected[gate])
            total_tests += 4

        # NOT: evaluated inline as a single 1-input perceptron.
        # out presumably has shape [num_inputs, pop]; sum(0) accumulates
        # per-individual hits over both truth-table rows — TODO confirm
        # self.not_inputs is [2, 1].
        w = population['boolean.not.weight'].view(pop_size, -1)
        b = population['boolean.not.bias'].view(pop_size)
        out = heaviside(self.not_inputs @ w.T + b)
        scores += (out == self.expected['not'].unsqueeze(1)).float().sum(0)
        total_tests += 2

        # IMPLIES: single-layer, linearly separable.
        scores += self._test_single_gate(population, 'implies', self.tt2, self.expected['implies'])
        total_tests += 4

        # XOR, XNOR, BIIMPLIES: not linearly separable, need two layers.
        scores += self._test_twolayer_gate(population, 'boolean.xor', self.tt2, self.expected['xor'])
        scores += self._test_twolayer_gate(population, 'boolean.xnor', self.tt2, self.expected['xnor'])
        scores += self._test_twolayer_gate(population, 'boolean.biimplies', self.tt2, self.expected['biimplies'])
        total_tests += 12

        # =================================================================
        # ARITHMETIC - ADDERS (306 tests: 8 + 16 + 16 + 256 + 10)
        # NOTE(review): an earlier header here said 340; the visible
        # increments sum to 306.
        # =================================================================
        scores += self._test_halfadder(population)
        total_tests += 8

        scores += self._test_fulladder(population)
        total_tests += 16

        # Ripple carry adders: exhaustive for 2- and 4-bit widths,
        # representative edge pairs (carry chains, overflow) for 8-bit.
        rc2_tests = [(a, b) for a in range(4) for b in range(4)]
        scores += self._test_ripplecarry(population, 2, rc2_tests)
        total_tests += 16

        rc4_tests = [(a, b) for a in range(16) for b in range(16)]
        scores += self._test_ripplecarry(population, 4, rc4_tests)
        total_tests += 256

        rc8_tests = [(0,0), (1,1), (127,128), (255,1), (128,127), (255,255),
                     (0xAA, 0x55), (0x0F, 0xF0), (100, 155), (200, 55)]
        scores += self._test_ripplecarry(population, 8, rc8_tests)
        total_tests += len(rc8_tests)

        # =================================================================
        # ARITHMETIC - COMPARATORS (5 x len(self.comp_a) tests;
        # 240 if there are 48 operand pairs — TODO confirm)
        # =================================================================
        scores += self._test_comparator(population, 'greaterthan8bit', 'gt')
        scores += self._test_comparator(population, 'lessthan8bit', 'lt')
        scores += self._test_comparator(population, 'greaterorequal8bit', 'geq')
        scores += self._test_comparator(population, 'lessorequal8bit', 'leq')
        total_tests += 4 * len(self.comp_a)

        scores += self._test_equality(population)
        total_tests += len(self.comp_a)

        # =================================================================
        # THRESHOLD GATES (13 x len(self.test_8bit) tests)
        # NOTE(review): an earlier header said 264, which does not divide
        # by the 13 suites below — verify against len(self.test_8bit).
        # =================================================================
        # enumerate(..., 1) pairs each k-of-8 circuit name with its k.
        for k, name in enumerate(['oneoutof8', 'twooutof8', 'threeoutof8', 'fouroutof8',
                                  'fiveoutof8', 'sixoutof8', 'sevenoutof8', 'alloutof8'], 1):
            scores += self._test_threshold_kofn(population, k, name)
            total_tests += len(self.test_8bit)

        scores += self._test_majority(population)
        scores += self._test_minority(population)
        total_tests += 2 * len(self.test_8bit)

        scores += self._test_atleastk(population, 4)
        scores += self._test_atmostk(population, 4)
        scores += self._test_exactlyk(population, 4)
        total_tests += 3 * len(self.test_8bit)

        # =================================================================
        # PATTERN RECOGNITION (3 x len(self.test_8bit) tests;
        # 72 if there are 24 test vectors)
        # =================================================================
        scores += self._test_popcount(population)
        scores += self._test_allzeros(population)
        scores += self._test_allones(population)
        total_tests += 3 * len(self.test_8bit)

        # =================================================================
        # ERROR DETECTION (2 x len(self.test_8bit) tests)
        # =================================================================
        scores += self._test_parity(population, 'paritychecker8bit', True)
        scores += self._test_parity(population, 'paritygenerator8bit', True)
        total_tests += 2 * len(self.test_8bit)

        # =================================================================
        # MODULAR ARITHMETIC (11 moduli x len(self.mod_test) values;
        # 2816 if mod_test covers all 256 byte values)
        # =================================================================
        for mod in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]:
            scores += self._test_modular(population, mod)
            total_tests += len(self.mod_test)

        # =================================================================
        # COMBINATIONAL (96 tests: 8 + 64 + 24)
        # NOTE(review): an earlier header here said 88.
        # =================================================================
        scores += self._test_mux2to1(population)
        total_tests += 8

        scores += self._test_decoder3to8(population)
        total_tests += 64

        scores += self._test_encoder8to3(population)
        total_tests += 24

        # =================================================================
        # CONTROL FLOW (432 tests: 9 circuits x 6 cases x 8 bits)
        # NOTE(review): an earlier header said "10 circuits / 480 tests";
        # only 9 are listed below — confirm no circuit was dropped.
        # =================================================================
        for ctrl in ['conditionaljump', 'jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv']:
            scores += self._test_conditional_jump(population, ctrl)
            total_tests += 6 * 8

        # Expose the raw count, then normalize to a [0, 1] fitness.
        self.total_tests = total_tests
        return scores / total_tests
835
+
836
+
837
+ def create_population(base_tensors: Dict[str, torch.Tensor],
838
+ pop_size: int,
839
+ device='cuda') -> Dict[str, torch.Tensor]:
840
+ """Create population by replicating base tensors."""
841
+ population = {}
842
+ for name, weight in base_tensors.items():
843
+ population[name] = weight.unsqueeze(0).expand(pop_size, *weight.shape).clone().to(device)
844
+ return population
845
+
846
+
847
if __name__ == "__main__":
    import time

    print("="*70)
    print(" IRON EVAL - COMPREHENSIVE TEST")
    print("="*70)

    # Load the known-good reference weights for every circuit.
    print("\nLoading model...")
    model = load_model_10166()
    print(f"Loaded {len(model)} tensors, {sum(t.numel() for t in model.values())} params")

    print("\nInitializing evaluator...")
    evaluator = BatchedFitnessEvaluator(device='cuda')

    # A population of 1 sanity-checks the evaluator against the
    # reference model: fitness should come back exactly 1.0.
    print("\nCreating population (size 1)...")
    pop = create_population(model, pop_size=1, device='cuda')

    # Synchronize on both sides of the timed region so the wall-clock
    # figure measures GPU work, not just asynchronous kernel launches.
    print("\nRunning evaluation...")
    torch.cuda.synchronize()
    start = time.perf_counter()
    fitness = evaluator.evaluate(pop)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    print(f"\nResults:")
    print(f" Fitness: {fitness[0]:.6f}")
    print(f" Total tests: {evaluator.total_tests}")
    print(f" Time: {elapsed*1000:.2f} ms")

    if fitness[0] == 1.0:
        print("\n STATUS: PASS - All circuits functional")
    else:
        # round(), not int(): float32 rounding in the fitness ratio can
        # put (1 - fitness) * total_tests just below the true integer,
        # and int() truncation would then under-report the failure count.
        failed = round((1 - fitness[0].item()) * evaluator.total_tests)
        print(f"\n STATUS: FAIL - {failed} tests failed")