phanerozoic committed on
Commit
b44a3ba
·
verified ·
1 Parent(s): 2a16a08

Fix evaluator for multi-layer modular and parity circuits

Browse files
Files changed (1) hide show
  1. iron_eval.py +105 -26
iron_eval.py CHANGED
@@ -121,10 +121,9 @@ class BatchedFitnessEvaluator:
121
 
122
  def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
123
  expected: torch.Tensor) -> torch.Tensor:
124
- """Test two-layer gate (XOR, XNOR, BIIMPLIES)."""
125
  pop_size = next(iter(pop.values())).shape[0]
126
 
127
- # Layer 1
128
  w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
129
  b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
130
  w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
@@ -134,7 +133,26 @@ class BatchedFitnessEvaluator:
134
  h_b = heaviside(inputs @ w1_b.T + b1_b)
135
  hidden = torch.stack([h_a, h_b], dim=2)
136
 
137
- # Layer 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
139
  b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
140
  out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
@@ -150,9 +168,8 @@ class BatchedFitnessEvaluator:
150
  pop_size = next(iter(pop.values())).shape[0]
151
  scores = torch.zeros(pop_size, device=self.device)
152
 
153
- # Sum (XOR)
154
- scores += self._test_twolayer_gate(pop, 'arithmetic.halfadder.sum',
155
- self.tt2, self.expected['ha_sum'])
156
  # Carry (AND)
157
  w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
158
  b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
@@ -500,13 +517,46 @@ class BatchedFitnessEvaluator:
500
  # ERROR DETECTION
501
  # =========================================================================
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
504
- """Test parity checker/generator."""
505
  pop_size = next(iter(pop.values())).shape[0]
506
- w = pop[f'error_detection.{name}.weight'].view(pop_size, -1)
507
- b = pop[f'error_detection.{name}.bias'].view(pop_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
- out = heaviside(self.test_8bit_bits @ w.T + b)
510
  popcounts = self.test_8bit_bits.sum(1)
511
  if even:
512
  expected = ((popcounts.long() % 2) == 0).float()
@@ -519,15 +569,50 @@ class BatchedFitnessEvaluator:
519
  # MODULAR ARITHMETIC
520
  # =========================================================================
521
 
 
 
 
 
 
 
522
  def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
523
  """Test modular arithmetic circuit."""
524
  pop_size = next(iter(pop.values())).shape[0]
525
- w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
526
- b = pop[f'modular.mod{mod}.bias'].view(pop_size)
527
 
528
- out = heaviside(self.mod_test_bits @ w.T + b)
529
- expected = ((self.mod_test % mod) == 0).float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
 
531
  return (out == expected.unsqueeze(1)).float().sum(0)
532
 
533
  # =========================================================================
@@ -539,35 +624,29 @@ class BatchedFitnessEvaluator:
539
  pop_size = next(iter(pop.values())).shape[0]
540
  scores = torch.zeros(pop_size, device=self.device)
541
 
542
- # Test all 8 combinations of (a, b, sel)
543
  for a in [0, 1]:
544
  for b in [0, 1]:
545
  for sel in [0, 1]:
546
  expected = a if sel == 1 else b
547
 
548
- # MUX uses: and_a, and_b, not_sel, or
549
  a_t = torch.full((pop_size,), float(a), device=self.device)
550
  b_t = torch.full((pop_size,), float(b), device=self.device)
551
  sel_t = torch.full((pop_size,), float(sel), device=self.device)
552
 
553
- # NOT sel
554
- w_not = pop['combinational.multiplexer2to1.not_sel.weight'].view(pop_size, -1)
555
- b_not = pop['combinational.multiplexer2to1.not_sel.bias'].view(pop_size)
556
  not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
557
 
558
- # AND(a, sel)
559
  inp_a = torch.stack([a_t, sel_t], dim=1)
560
- w_and_a = pop['combinational.multiplexer2to1.and_a.weight'].view(pop_size, -1)
561
- b_and_a = pop['combinational.multiplexer2to1.and_a.bias'].view(pop_size)
562
  and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
563
 
564
- # AND(b, not_sel)
565
  inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
566
- w_and_b = pop['combinational.multiplexer2to1.and_b.weight'].view(pop_size, -1)
567
- b_and_b = pop['combinational.multiplexer2to1.and_b.bias'].view(pop_size)
568
  and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
569
 
570
- # OR
571
  inp_or = torch.stack([and_a, and_b], dim=1)
572
  w_or = pop['combinational.multiplexer2to1.or.weight'].view(pop_size, -1)
573
  b_or = pop['combinational.multiplexer2to1.or.bias'].view(pop_size)
 
121
 
122
  def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor,
123
  expected: torch.Tensor) -> torch.Tensor:
124
+ """Test two-layer gate (XOR, XNOR, BIIMPLIES) - boolean naming (neuron1/neuron2)."""
125
  pop_size = next(iter(pop.values())).shape[0]
126
 
 
127
  w1_a = pop[f'{prefix}.layer1.neuron1.weight'].view(pop_size, -1)
128
  b1_a = pop[f'{prefix}.layer1.neuron1.bias'].view(pop_size)
129
  w1_b = pop[f'{prefix}.layer1.neuron2.weight'].view(pop_size, -1)
 
133
  h_b = heaviside(inputs @ w1_b.T + b1_b)
134
  hidden = torch.stack([h_a, h_b], dim=2)
135
 
136
+ w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
137
+ b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
138
+ out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
139
+
140
+ return (out == expected.unsqueeze(1)).float().sum(0)
141
+
142
+ def _test_xor_gate_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor,
143
+ expected: torch.Tensor) -> torch.Tensor:
144
+ """Test two-layer XOR gate - arithmetic naming (or/nand)."""
145
+ pop_size = next(iter(pop.values())).shape[0]
146
+
147
+ w1_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
148
+ b1_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
149
+ w1_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
150
+ b1_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
151
+
152
+ h_or = heaviside(inputs @ w1_or.T + b1_or)
153
+ h_nand = heaviside(inputs @ w1_nand.T + b1_nand)
154
+ hidden = torch.stack([h_or, h_nand], dim=2)
155
+
156
  w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
157
  b2 = pop[f'{prefix}.layer2.bias'].view(pop_size)
158
  out = heaviside((hidden * w2.unsqueeze(0)).sum(2) + b2.unsqueeze(0))
 
168
  pop_size = next(iter(pop.values())).shape[0]
169
  scores = torch.zeros(pop_size, device=self.device)
170
 
171
+ scores += self._test_xor_gate_ornand(pop, 'arithmetic.halfadder.sum',
172
+ self.tt2, self.expected['ha_sum'])
 
173
  # Carry (AND)
174
  w = pop['arithmetic.halfadder.carry.weight'].view(pop_size, -1)
175
  b = pop['arithmetic.halfadder.carry.bias'].view(pop_size)
 
517
  # ERROR DETECTION
518
  # =========================================================================
519
 
520
+ def _eval_xor_gate(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
521
+ """Evaluate XOR gate on batched inputs."""
522
+ pop_size = next(iter(pop.values())).shape[0]
523
+
524
+ w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1)
525
+ b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size)
526
+ w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1)
527
+ b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size)
528
+ w_and = pop[f'{prefix}.layer2.weight'].view(pop_size, -1)
529
+ b_and = pop[f'{prefix}.layer2.bias'].view(pop_size)
530
+
531
+ inp = torch.stack([a, b], dim=2)
532
+ h_or = heaviside((inp * w_or.unsqueeze(0)).sum(2) + b_or.unsqueeze(0))
533
+ h_nand = heaviside((inp * w_nand.unsqueeze(0)).sum(2) + b_nand.unsqueeze(0))
534
+ hidden = torch.stack([h_or, h_nand], dim=2)
535
+ return heaviside((hidden * w_and.unsqueeze(0)).sum(2) + b_and.unsqueeze(0))
536
+
537
  def _test_parity(self, pop: Dict, name: str, even: bool) -> torch.Tensor:
538
+ """Test parity checker/generator with XOR tree."""
539
  pop_size = next(iter(pop.values())).shape[0]
540
+ prefix = f'error_detection.{name}'
541
+ num_tests = self.test_8bit_bits.shape[0]
542
+
543
+ bits = self.test_8bit_bits.unsqueeze(1).expand(-1, pop_size, -1)
544
+
545
+ stage1 = []
546
+ for i, (a, b) in enumerate([(0, 1), (2, 3), (4, 5), (6, 7)]):
547
+ xor_out = self._eval_xor_gate(pop, f'{prefix}.stage1.xor{i}', bits[:,:,a], bits[:,:,b])
548
+ stage1.append(xor_out)
549
+
550
+ stage2 = []
551
+ stage2.append(self._eval_xor_gate(pop, f'{prefix}.stage2.xor0', stage1[0], stage1[1]))
552
+ stage2.append(self._eval_xor_gate(pop, f'{prefix}.stage2.xor1', stage1[2], stage1[3]))
553
+
554
+ xor_all = self._eval_xor_gate(pop, f'{prefix}.stage3.xor0', stage2[0], stage2[1])
555
+
556
+ w_not = pop[f'{prefix}.output.not.weight'].view(pop_size, -1)
557
+ b_not = pop[f'{prefix}.output.not.bias'].view(pop_size)
558
+ out = heaviside(xor_all.unsqueeze(2) * w_not.unsqueeze(0) + b_not.unsqueeze(0)).squeeze(2)
559
 
 
560
  popcounts = self.test_8bit_bits.sum(1)
561
  if even:
562
  expected = ((popcounts.long() % 2) == 0).float()
 
569
  # MODULAR ARITHMETIC
570
  # =========================================================================
571
 
572
+ def _get_divisible_sums(self, mod: int) -> list:
573
+ """Get sum values that indicate divisibility by mod."""
574
+ weights = [(2**(7-i)) % mod for i in range(8)]
575
+ max_sum = sum(weights)
576
+ return [k for k in range(0, max_sum + 1) if k % mod == 0]
577
+
578
  def _test_modular(self, pop: Dict, mod: int) -> torch.Tensor:
579
  """Test modular arithmetic circuit."""
580
  pop_size = next(iter(pop.values())).shape[0]
 
 
581
 
582
+ if mod in [2, 4, 8]:
583
+ w = pop[f'modular.mod{mod}.weight'].view(pop_size, -1)
584
+ b = pop[f'modular.mod{mod}.bias'].view(pop_size)
585
+ out = heaviside(self.mod_test_bits @ w.T + b)
586
+ else:
587
+ divisible_sums = self._get_divisible_sums(mod)
588
+ num_detectors = len(divisible_sums)
589
+
590
+ layer1_outputs = []
591
+ for idx in range(num_detectors):
592
+ w_geq = pop[f'modular.mod{mod}.layer1.geq{idx}.weight'].view(pop_size, -1)
593
+ b_geq = pop[f'modular.mod{mod}.layer1.geq{idx}.bias'].view(pop_size)
594
+ w_leq = pop[f'modular.mod{mod}.layer1.leq{idx}.weight'].view(pop_size, -1)
595
+ b_leq = pop[f'modular.mod{mod}.layer1.leq{idx}.bias'].view(pop_size)
596
+
597
+ geq = heaviside(self.mod_test_bits @ w_geq.T + b_geq)
598
+ leq = heaviside(self.mod_test_bits @ w_leq.T + b_leq)
599
+ layer1_outputs.append((geq, leq))
600
+
601
+ layer2_outputs = []
602
+ for idx in range(num_detectors):
603
+ w_eq = pop[f'modular.mod{mod}.layer2.eq{idx}.weight'].view(pop_size, -1)
604
+ b_eq = pop[f'modular.mod{mod}.layer2.eq{idx}.bias'].view(pop_size)
605
+ geq, leq = layer1_outputs[idx]
606
+ combined = torch.stack([geq, leq], dim=2)
607
+ eq = heaviside((combined * w_eq.unsqueeze(0)).sum(2) + b_eq.unsqueeze(0))
608
+ layer2_outputs.append(eq)
609
+
610
+ layer2_stack = torch.stack(layer2_outputs, dim=2)
611
+ w_or = pop[f'modular.mod{mod}.layer3.or.weight'].view(pop_size, -1)
612
+ b_or = pop[f'modular.mod{mod}.layer3.or.bias'].view(pop_size)
613
+ out = heaviside((layer2_stack * w_or.unsqueeze(0)).sum(2) + b_or.unsqueeze(0))
614
 
615
+ expected = ((self.mod_test % mod) == 0).float()
616
  return (out == expected.unsqueeze(1)).float().sum(0)
617
 
618
  # =========================================================================
 
624
  pop_size = next(iter(pop.values())).shape[0]
625
  scores = torch.zeros(pop_size, device=self.device)
626
 
 
627
  for a in [0, 1]:
628
  for b in [0, 1]:
629
  for sel in [0, 1]:
630
  expected = a if sel == 1 else b
631
 
 
632
  a_t = torch.full((pop_size,), float(a), device=self.device)
633
  b_t = torch.full((pop_size,), float(b), device=self.device)
634
  sel_t = torch.full((pop_size,), float(sel), device=self.device)
635
 
636
+ w_not = pop['combinational.multiplexer2to1.not_s.weight'].view(pop_size, -1)
637
+ b_not = pop['combinational.multiplexer2to1.not_s.bias'].view(pop_size)
 
638
  not_sel = heaviside(sel_t.unsqueeze(1) @ w_not.T + b_not)
639
 
 
640
  inp_a = torch.stack([a_t, sel_t], dim=1)
641
+ w_and_a = pop['combinational.multiplexer2to1.and1.weight'].view(pop_size, -1)
642
+ b_and_a = pop['combinational.multiplexer2to1.and1.bias'].view(pop_size)
643
  and_a = heaviside((inp_a * w_and_a).sum(1) + b_and_a)
644
 
 
645
  inp_b = torch.stack([b_t, not_sel.squeeze(1)], dim=1)
646
+ w_and_b = pop['combinational.multiplexer2to1.and0.weight'].view(pop_size, -1)
647
+ b_and_b = pop['combinational.multiplexer2to1.and0.bias'].view(pop_size)
648
  and_b = heaviside((inp_b * w_and_b).sum(1) + b_and_b)
649
 
 
650
  inp_or = torch.stack([and_a, and_b], dim=1)
651
  w_or = pop['combinational.multiplexer2to1.or.weight'].view(pop_size, -1)
652
  b_or = pop['combinational.multiplexer2to1.or.bias'].view(pop_size)