| """ |
| Threshold Network for 2:1 MUX - Magnitude 7 (Proven Optimal) |
| |
| Four equivalent solutions exist at magnitude 7. All compute MUX correctly. |
| Magnitudes 0-6 proven impossible via exhaustive Coq verification. |
| |
| MUX(a, b, s) = a if s=0, b if s=1 |
| """ |
|
|
import itertools

import torch
from safetensors.torch import load_file
|
|
|
|
class ThresholdMUX:
    """
    2:1 Multiplexer as a 2-layer threshold network.
    Magnitude 7 - proven optimal by Coq exhaustive computation.

    Each layer computes ``heaviside(x @ W.T + b)`` with
    ``heaviside(z) = 1.0 if z >= 0 else 0.0``.
    """

    def __init__(self, weights_dict):
        """Store the network parameters.

        Args:
            weights_dict: Mapping with keys ``'layer1.weight'``,
                ``'layer1.bias'``, ``'layer2.weight'`` and ``'layer2.bias'``
                (e.g. the dict returned by ``safetensors.torch.load_file``).
                Values may be tensors of any numeric dtype, or nested
                sequences of numbers.

        Raises:
            KeyError: If one of the four required keys is missing.
        """
        # torch matmul does not type-promote: float32 inputs against float64
        # or integer weights would raise. Normalise every parameter to float32
        # once here so __call__ never hits a dtype mismatch.
        self.l1_weight = self._as_float(weights_dict['layer1.weight'])
        self.l1_bias = self._as_float(weights_dict['layer1.bias'])
        self.l2_weight = self._as_float(weights_dict['layer2.weight'])
        self.l2_bias = self._as_float(weights_dict['layer2.bias'])

    @staticmethod
    def _as_float(value):
        """Return *value* as a float32 tensor (no copy if it already is one)."""
        return torch.as_tensor(value, dtype=torch.float32)

    def __call__(self, a, b, s) -> torch.Tensor:
        """Evaluate the network on the inputs (a, b, s).

        Args:
            a, b, s: Input values, each passed through ``float()``; the
                network is intended for bits 0/1.

        Returns:
            A float32 tensor of 0.0/1.0 values, one per row of
            ``layer2.weight`` (a single element for a one-output network);
            use ``.item()`` to get a Python scalar.
        """
        # Explicit dtype so the inputs match the float32 parameters even if
        # torch's global default dtype has been changed.
        inputs = torch.tensor([float(a), float(b), float(s)], dtype=torch.float32)

        hidden = (inputs @ self.l1_weight.T + self.l1_bias >= 0).float()
        output = (hidden @ self.l2_weight.T + self.l2_bias >= 0).float()

        return output

    @classmethod
    def from_safetensors(cls, path="solution1.safetensors") -> "ThresholdMUX":
        """Build a model from the weights stored in the safetensors file *path*."""
        return cls(load_file(path))
|
|
|
|
def forward(a, b, s, weights) -> torch.Tensor:
    """Forward pass with Heaviside activation.

    Each layer computes ``heaviside(x @ W.T + b)`` with
    ``heaviside(z) = 1.0 if z >= 0 else 0.0``.

    Args:
        a, b, s: Input values, each passed through ``float()``; intended 0/1.
        weights: Mapping with keys ``'layer1.weight'``, ``'layer1.bias'``,
            ``'layer2.weight'`` and ``'layer2.bias'``. Values may be tensors of
            any numeric dtype, or nested sequences of numbers.

    Returns:
        A float32 tensor of 0.0/1.0 values, one per row of ``layer2.weight``.

    Raises:
        KeyError: If one of the four required keys is missing.
    """
    def param(key):
        # torch matmul does not type-promote; cast every parameter to float32
        # so float64/integer weights work against the float32 inputs.
        return torch.as_tensor(weights[key], dtype=torch.float32)

    inputs = torch.tensor([float(a), float(b), float(s)], dtype=torch.float32)

    hidden = (inputs @ param('layer1.weight').T + param('layer1.bias') >= 0).float()
    output = (hidden @ param('layer2.weight').T + param('layer2.bias') >= 0).float()

    return output
|
|
|
|
def verify_all_solutions(num_solutions=4):
    """Verify all stored solutions compute MUX correctly.

    Loads ``solution1.safetensors`` through
    ``solution{num_solutions}.safetensors`` (paths relative to the current
    working directory), checks each against the full 2:1 MUX truth table and
    prints one status line per solution with its magnitude (sum of absolute
    values of every tensor in the file).

    Args:
        num_solutions: Number of numbered solution files to check (default 4).

    Returns:
        True if every solution matched the truth table, False otherwise.
    """
    print(f"Verifying all {num_solutions} magnitude-7 MUX solutions:")
    print("=" * 50)

    every_solution_ok = True
    for i in range(1, num_solutions + 1):
        weights = load_file(f'solution{i}.safetensors')
        model = ThresholdMUX(weights)

        # product((0, 1), repeat=3) yields (a, b, s) in a-major / s-minor order.
        all_correct = all(
            int(model(a, b, s).item()) == (a if s == 0 else b)
            for a, b, s in itertools.product((0, 1), repeat=3)
        )
        every_solution_ok = every_solution_ok and all_correct

        status = "OK" if all_correct else "FAIL"
        mag = sum(abs(v).sum().item() for v in weights.values())
        print(f" Solution {i}: {status} (magnitude={mag:.0f})")

    print("=" * 50)
    return every_solution_ok
|
|
|
|
if __name__ == "__main__":
    verify_all_solutions()

    print("\nSolution 1 truth table:")
    print("-" * 30)
    weights = load_file("solution1.safetensors")
    model = ThresholdMUX(weights)

    # Walk the 8 input rows as a 3-bit counter: bit 2 -> a, bit 1 -> b,
    # bit 0 -> s. This reproduces the a-outer / s-inner row ordering.
    for row in range(8):
        a, b, s = (row >> 2) & 1, (row >> 1) & 1, row & 1
        out = int(model(a, b, s).item())
        expected = b if s else a
        status = "FAIL" if out != expected else "OK"
        print(f"MUX({a}, {b}, s={s}) = {out} [{status}]")
|
|