Fix evaluator for multi-layer modular and parity circuits
Browse files- iron_eval.py +105 -26
iron_eval.py
CHANGED
|
@@ -121,10 +121,9 @@ class BatchedFitnessEvaluator:
|
|
| 121 |
|
| 122 |
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
|
| 123 |
expected: torch.Tensor) -> torch.Tensor:
|
| 124 |
-
"""Test two-layer gate (XOR, XNOR, BIIMPLIES)."""
|
| 125 |
pop_size = next(iter(pop.values())).shape[0]
|
| 126 |
|
| 127 |
-
# Layer 1
|
| 128 |
w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
|
| 129 |
b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
|
| 130 |
w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
|
|
@@ -134,7 +133,26 @@ class BatchedFitnessEvaluator:
|
|
| 134 |
h_b = heaviside(inputs @ w1_b.T + b1_b)
|
| 135 |
hidden = torch.stack([h_a, h_b], dim=2)
|
| 136 |
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
|
| 139 |
b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
|
| 140 |
out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
|
|
@@ -150,9 +168,8 @@ class BatchedFitnessEvaluator:
|
|
| 150 |
pop_size = next(iter(pop.values())).shape[0]
|
| 151 |
scores = torch.zeros(pop_size, device=self.device)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
self.tt2, self.expected['ha_sum'])
|
| 156 |
# Carry (AND)
|
| 157 |
w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
|
| 158 |
b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
|
|
@@ -500,13 +517,46 @@ class BatchedFitnessEvaluator:
|
|
| 500 |
# ERROR DETECTION
|
| 501 |
# =========================================================================
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
|
| 504 |
-
"""Test parity checker/generator."""
|
| 505 |
pop_size = next(iter(pop.values())).shape[0]
|
| 506 |
-
|
| 507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
-
out = heaviside(self.test_8bit_bits @ w.T + b)
|
| 510 |
popcounts = self.test_8bit_bits.sum(1)
|
| 511 |
if even:
|
| 512 |
expected = ((popcounts.long() % 2) == 0).float()
|
|
@@ -519,15 +569,50 @@ class BatchedFitnessEvaluator:
|
|
| 519 |
# MODULAR ARITHMETIC
|
| 520 |
# =========================================================================
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
|
| 523 |
"""Test modular arithmetic circuit."""
|
| 524 |
pop_size = next(iter(pop.values())).shape[0]
|
| 525 |
-
w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
|
| 526 |
-
b = pop[f'modular.mod{mod}.bias'].view(pop_size)
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
|
|
|
| 531 |
return (out == expected.unsqueeze(1)).float().sum(0)
|
| 532 |
|
| 533 |
# =========================================================================
|
|
@@ -539,35 +624,29 @@ class BatchedFitnessEvaluator:
|
|
| 539 |
pop_size = next(iter(pop.values())).shape[0]
|
| 540 |
scores = torch.zeros(pop_size, device=self.device)
|
| 541 |
|
| 542 |
-
# Test all 8 combinations of (a, b, sel)
|
| 543 |
for a in [0, 1]:
|
| 544 |
for b in [0, 1]:
|
| 545 |
for sel in [0, 1]:
|
| 546 |
expected = a if sel == 1 else b
|
| 547 |
|
| 548 |
-
# MUX uses: and_a, and_b, not_sel, or
|
| 549 |
a_t = torch.full((pop_size,), float(a), device=self.device)
|
| 550 |
b_t = torch.full((pop_size,), float(b), device=self.device)
|
| 551 |
sel_t = torch.full((pop_size,), float(sel), device=self.device)
|
| 552 |
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
b_not = pop['combinational.multiplexer2to1.not_sel.bias'].view(pop_size)
|
| 556 |
not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
|
| 557 |
|
| 558 |
-
# AND(a, sel)
|
| 559 |
inp_a = torch.stack([a_t, sel_t], dim=1)
|
| 560 |
-
w_and_a = pop['combinational.multiplexer2to1.
|
| 561 |
-
b_and_a = pop['combinational.multiplexer2to1.
|
| 562 |
and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
|
| 563 |
|
| 564 |
-
# AND(b, not_sel)
|
| 565 |
inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
|
| 566 |
-
w_and_b = pop['combinational.multiplexer2to1.
|
| 567 |
-
b_and_b = pop['combinational.multiplexer2to1.
|
| 568 |
and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
|
| 569 |
|
| 570 |
-
# OR
|
| 571 |
inp_or = torch.stack([and_a, and_b], dim=1)
|
| 572 |
w_or = pop['combinational.multiplexer2to1.or.weight'].view(pop_size, -1)
|
| 573 |
b_or = pop['combinational.multiplexer2to1.or.bias'].view(pop_size)
|
|
|
|
def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                        expected: torch.Tensor) -> torch.Tensor:
    """Test a two-layer gate (XOR, XNOR, BIIMPLIES) - boolean naming (neuron1/neuron2).

    Reads per-individual weights for two hidden neurons and one output
    neuron under `{prefix}.layer1.neuron{1,2}` / `{prefix}.layer2`, runs the
    truth-table `inputs` through the batched two-layer circuit, and returns
    the number of correct rows per individual.
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Hidden layer: two threshold neurons evaluated for every individual.
    w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
    b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
    w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
    # NOTE(review): the next two lines were elided by the diff hunk; they are
    # reconstructed by symmetry with neuron1 and the visible h_b line - confirm.
    b1_b = pop[f'{prefix}.layer1.neuron2.bias'].view(pop_size)

    h_a = heaviside(inputs @ w1_a.T + b1_a)
    h_b = heaviside(inputs @ w1_b.T + b1_b)
    hidden = torch.stack([h_a, h_b], dim=2)

    # Output layer: weighted sum of the two hidden activations, thresholded.
    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
    out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))

    # Count matching truth-table rows per individual.
    return (out == expected.unsqueeze(1)).float().sum(0)
def _test_xor_gate_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor,
                          expected: torch.Tensor) -> torch.Tensor:
    """Test a two-layer XOR gate - arithmetic naming (or/nand).

    Same structure as the boolean-named two-layer gate test, but the hidden
    neurons are stored under `{prefix}.layer1.or` and `{prefix}.layer1.nand`
    (XOR = AND(OR(a, b), NAND(a, b))). Returns per-individual correct-row counts.
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Hidden layer parameters, one row per individual.
    or_w = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
    or_b = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
    nand_w = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
    nand_b = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)

    h_or = heaviside(inputs @ or_w.T + or_b)
    h_nand = heaviside(inputs @ nand_w.T + nand_b)
    hidden = torch.stack([h_or, h_nand], dim=2)

    # Output (AND-like) neuron over the two hidden activations.
    w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
    b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
    out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))

    # NOTE(review): the trailing return was cut off by the diff hunk;
    # reconstructed to match the scoring convention used by
    # _test_twolayer_gate - confirm against the full file.
    return (out == expected.unsqueeze(1)).float().sum(0)
|
|
|
| 168 |
pop_size = next(iter(pop.values())).shape[0]
|
| 169 |
scores = torch.zeros(pop_size, device=self.device)
|
| 170 |
|
| 171 |
+
scores += self._test_xor_gate_ornand(pop, 'arithmetic.halfadder.sum',
|
| 172 |
+
self.tt2, self.expected['ha_sum'])
|
|
|
|
| 173 |
# Carry (AND)
|
| 174 |
w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
|
| 175 |
b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
|
|
|
|
| 517 |
# ERROR DETECTION
|
| 518 |
# =========================================================================
|
| 519 |
|
def _eval_xor_gate(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Evaluate a two-layer XOR sub-circuit (OR/NAND feeding AND) on batched inputs.

    `a` and `b` each hold one activation per (test case, individual); the
    three neurons' weights are read per individual from `pop` under
    `{prefix}.layer1.{or,nand}` and `{prefix}.layer2`. Returns the XOR
    activations with the same (test case, individual) layout.
    """
    pop_size = next(iter(pop.values())).shape[0]

    # Per-individual parameters for the three neurons of the gate.
    or_w = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
    or_b = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
    nand_w = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
    nand_b = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
    and_w = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
    and_b = pop[f'{prefix}.layer2.bias'].view(pop_size)

    # Layer 1: OR and NAND over the (a, b) pair.
    pair = torch.stack([a, b], dim=2)
    first = heaviside((pair * or_w.unsqueeze(0)).sum(2) + or_b.unsqueeze(0))
    second = heaviside((pair * nand_w.unsqueeze(0)).sum(2) + nand_b.unsqueeze(0))

    # Layer 2: AND of the two hidden activations.
    mid = torch.stack([first, second], dim=2)
    return heaviside((mid * and_w.unsqueeze(0)).sum(2) + and_b.unsqueeze(0))
def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
    """Test a parity checker/generator built as an XOR reduction tree.

    Reduces the 8 input bits pairwise through three stages of XOR gates
    (stage1: 4 gates, stage2: 2 gates, stage3: 1 gate), then applies a NOT
    output neuron. Scores each individual against the population parity of
    the 8-bit test vectors (`even` selects even- vs odd-parity targets).
    """
    pop_size = next(iter(pop.values())).shape[0]
    prefix = f'error_detection.{name}'
    # NOTE(review): num_tests is unused in the visible span - kept to preserve
    # the original's behavior; confirm against the full file.
    num_tests = self.test_8bit_bits.shape[0]

    # Broadcast the test vectors across the population dimension.
    bits = self.test_8bit_bits.unsqueeze(1).expand(-1, pop_size, -1)

    # Stage 1: XOR adjacent bit pairs (0,1), (2,3), (4,5), (6,7).
    stage1 = [
        self._eval_xor_gate(pop, f'{prefix}.stage1.xor{i}', bits[:, :, a], bits[:, :, b])
        for i, (a, b) in enumerate([(0, 1), (2, 3), (4, 5), (6, 7)])
    ]

    # Stage 2: XOR the stage-1 results pairwise.
    stage2 = [
        self._eval_xor_gate(pop, f'{prefix}.stage2.xor0', stage1[0], stage1[1]),
        self._eval_xor_gate(pop, f'{prefix}.stage2.xor1', stage1[2], stage1[3]),
    ]

    # Stage 3: final XOR gives the parity of all 8 bits.
    xor_all = self._eval_xor_gate(pop, f'{prefix}.stage3.xor0', stage2[0], stage2[1])

    # Output NOT neuron.
    not_w = pop[f'{prefix}.output.not.weight'].view(pop_size, -1)
    not_b = pop[f'{prefix}.output.not.bias'].view(pop_size)
    out = heaviside(xor_all.unsqueeze(2) * not_w.unsqueeze(0) + not_b.unsqueeze(0)).squeeze(2)

    popcounts = self.test_8bit_bits.sum(1)
    if even:
        expected = ((popcounts.long() % 2) == 0).float()
    else:
        # NOTE(review): the else-branch and return were elided by the diff
        # hunk; reconstructed from the even-branch and the scoring pattern
        # used elsewhere in this evaluator - confirm against the full file.
        expected = ((popcounts.long() % 2) == 1).float()
    return (out == expected.unsqueeze(1)).float().sum(0)
|
|
|
| 569 |
# MODULAR ARITHMETIC
|
| 570 |
# =========================================================================
|
| 571 |
|
| 572 |
+
def _get_divisible_sums(self, mod: int) -> list:
|
| 573 |
+
"""Get sum values that indicate divisibility by mod."""
|
| 574 |
+
weights = [(2**(7-i)) % mod for i in range(8)]
|
| 575 |
+
max_sum = sum(weights)
|
| 576 |
+
return [k for k in range(0, max_sum + 1) if k % mod == 0]
|
| 577 |
+
|
def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
    """Test a modular-arithmetic (divisibility-by-*mod*) circuit.

    Power-of-two moduli (2, 4, 8) are handled by a single threshold neuron
    over the input bits. Other moduli use a three-layer network: for each
    sum value returned by `_get_divisible_sums`, a geq/leq pair (layer 1)
    feeds an equality detector (layer 2), and an OR neuron (layer 3)
    combines all detectors. Returns per-individual correct-case counts.
    """
    pop_size = next(iter(pop.values())).shape[0]

    if mod in (2, 4, 8):
        # Single-neuron circuit suffices for power-of-two moduli.
        w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
        b = pop[f'modular.mod{mod}.bias'].view(pop_size)
        out = heaviside(self.mod_test_bits @ w.T + b)
    else:
        base = f'modular.mod{mod}'
        # One interval detector per divisible sum value; the layer-1 pair
        # and its layer-2 equality unit are independent per detector, so
        # they are evaluated together.
        detector_eqs = []
        for idx in range(len(self._get_divisible_sums(mod))):
            geq_w = pop[f'{base}.layer1.geq{idx}.weight'].view(pop_size, -1)
            geq_b = pop[f'{base}.layer1.geq{idx}.bias'].view(pop_size)
            leq_w = pop[f'{base}.layer1.leq{idx}.weight'].view(pop_size, -1)
            leq_b = pop[f'{base}.layer1.leq{idx}.bias'].view(pop_size)

            geq = heaviside(self.mod_test_bits @ geq_w.T + geq_b)
            leq = heaviside(self.mod_test_bits @ leq_w.T + leq_b)

            eq_w = pop[f'{base}.layer2.eq{idx}.weight'].view(pop_size, -1)
            eq_b = pop[f'{base}.layer2.eq{idx}.bias'].view(pop_size)
            pair = torch.stack([geq, leq], dim=2)
            detector_eqs.append(
                heaviside((pair * eq_w.unsqueeze(0)).sum(2) + eq_b.unsqueeze(0)))

        # Layer 3: OR across all equality detectors.
        stacked = torch.stack(detector_eqs, dim=2)
        or_w = pop[f'{base}.layer3.or.weight'].view(pop_size, -1)
        or_b = pop[f'{base}.layer3.or.bias'].view(pop_size)
        out = heaviside((stacked * or_w.unsqueeze(0)).sum(2) + or_b.unsqueeze(0))

    expected = ((self.mod_test % mod) == 0).float()
    return (out == expected.unsqueeze(1)).float().sum(0)
| 617 |
|
| 618 |
# =========================================================================
|
|
|
|
| 624 |
pop_size = next(iter(pop.values())).shape[0]
|
| 625 |
scores = torch.zeros(pop_size, device=self.device)
|
| 626 |
|
|
|
|
| 627 |
for a in [0, 1]:
|
| 628 |
for b in [0, 1]:
|
| 629 |
for sel in [0, 1]:
|
| 630 |
expected = a if sel == 1 else b
|
| 631 |
|
|
|
|
| 632 |
a_t = torch.full((pop_size,), float(a), device=self.device)
|
| 633 |
b_t = torch.full((pop_size,), float(b), device=self.device)
|
| 634 |
sel_t = torch.full((pop_size,), float(sel), device=self.device)
|
| 635 |
|
| 636 |
+
w_not = pop['combinational.multiplexer2to1.not_s.weight'].view(pop_size, -1)
|
| 637 |
+
b_not = pop['combinational.multiplexer2to1.not_s.bias'].view(pop_size)
|
|
|
|
| 638 |
not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
|
| 639 |
|
|
|
|
| 640 |
inp_a = torch.stack([a_t, sel_t], dim=1)
|
| 641 |
+
w_and_a = pop['combinational.multiplexer2to1.and1.weight'].view(pop_size, -1)
|
| 642 |
+
b_and_a = pop['combinational.multiplexer2to1.and1.bias'].view(pop_size)
|
| 643 |
and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
|
| 644 |
|
|
|
|
| 645 |
inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
|
| 646 |
+
w_and_b = pop['combinational.multiplexer2to1.and0.weight'].view(pop_size, -1)
|
| 647 |
+
b_and_b = pop['combinational.multiplexer2to1.and0.bias'].view(pop_size)
|
| 648 |
and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
|
| 649 |
|
|
|
|
| 650 |
inp_or = torch.stack([and_a, and_b], dim=1)
|
| 651 |
w_or = pop['combinational.multiplexer2to1.or.weight'].view(pop_size, -1)
|
| 652 |
b_or = pop['combinational.multiplexer2to1.or.bias'].view(pop_size)
|