|
|
"""
Threshold network for a 4-input AND gate.

A formally verified single-neuron threshold network computing 4-way logical conjunction.
Weights are integer-constrained and activation uses the Heaviside step function.
"""
|
|
|
|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
class ThresholdAND4:
    """
    4-input AND gate implemented as a threshold neuron.

    Circuit: output = (w1*x1 + w2*x2 + w3*x3 + w4*x4 + bias >= 0)

    With weights=[1,1,1,1], bias=-4: only (1,1,1,1) reaches threshold.
    """

    def __init__(self, weights_dict):
        """
        Args:
            weights_dict: Mapping with a 'weight' tensor (one entry per gate
                input) and a 'bias' tensor, e.g. the dict returned by
                safetensors' ``load_file``.
        """
        self.weight = weights_dict['weight']
        self.bias = weights_dict['bias']

    @staticmethod
    def _to_float_tensor(value, device):
        """Convert one gate input to a float32 tensor on ``device``, keeping its shape."""
        if isinstance(value, torch.Tensor):
            return value.to(dtype=torch.float32, device=device)
        try:
            # Scalars (int, bool, float, numeric strings, NumPy scalars) are
            # coerced with float() exactly as the scalar-only version did.
            return torch.tensor(float(value), dtype=torch.float32, device=device)
        except TypeError:
            # Multi-element array-likes (lists, NumPy arrays) cannot go through
            # float(); convert them element-wise instead.
            return torch.as_tensor(value, dtype=torch.float32, device=device)

    def __call__(self, x1, x2, x3, x4):
        """
        Evaluate the gate on four inputs.

        Each input may be a scalar or a tensor/array-like. Inputs are broadcast
        against each other and stacked along a new last axis, so equally shaped
        batches evaluate many input rows in one call. Inputs are placed on the
        same device as ``self.weight``.

        Returns:
            Float tensor of 0.0/1.0 values. Its shape is the broadcast input
            shape combined with the bias shape. For scalar inputs this is the
            same as before, e.g. [1] when the bias has shape [1].
        """
        device = self.weight.device
        parts = [self._to_float_tensor(x, device) for x in (x1, x2, x3, x4)]
        # Broadcast first so that mixing a scalar with a batch works, then
        # stack so that the last axis indexes the four gate inputs.
        inputs = torch.stack(torch.broadcast_tensors(*parts), dim=-1)
        weighted_sum = (inputs * self.weight).sum(dim=-1) + self.bias
        # Heaviside step: fire when the weighted sum reaches the threshold.
        return (weighted_sum >= 0).float()

    @classmethod
    def from_safetensors(cls, path="model.safetensors", device="cpu"):
        """
        Build a gate from a safetensors file.

        Args:
            path: Path to the safetensors file holding 'weight' and 'bias'.
            device: Device the tensors are loaded onto (forwarded to
                ``load_file``; defaults to CPU as before).
        """
        return cls(load_file(path, device=device))
|
|
|
|
|
|
|
|
def forward(x, weights):
    """
    Forward pass with Heaviside activation.

    Computes ``(sum(x * weight, dim=-1) + bias >= 0)`` as a float tensor.

    Args:
        x: Input tensor or array-like of shape [..., 4]. It is converted to
            float32 on the same device as ``weights['weight']``.
        weights: Dict with 'weight' and 'bias' tensors.

    Returns:
        Float tensor of 0.0/1.0 values, i.e. AND(x[..., 0], ..., x[..., 3])
        for weights [1, 1, 1, 1] and bias -4. Its shape is ``x.shape[:-1]``
        broadcast with the bias shape.

    Raises:
        ValueError: If ``x`` is a scalar or its last dimension does not match
            the number of weights.
    """
    weight = weights['weight']
    bias = weights['bias']
    x = torch.as_tensor(x, dtype=torch.float32, device=weight.device)
    # Without this guard, a trailing dimension of 1 (or a scalar) would be
    # silently broadcast across all four weights, e.g. [[1]] acting as (1,1,1,1).
    if x.ndim == 0 or x.shape[-1] != weight.shape[-1]:
        raise ValueError(
            f"expected input with last dimension {weight.shape[-1]}, "
            f"got shape {tuple(x.shape)}"
        )
    weighted_sum = (x * weight).sum(dim=-1) + bias
    return (weighted_sum >= 0).float()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from itertools import product

    weights = load_file("model.safetensors")
    model = ThresholdAND4(weights)

    print("4-input AND Gate Truth Table:")
    print("-" * 35)

    # Every 4-bit input combination, in the same order as nested loops with
    # x1 outermost and x4 innermost.
    cases = list(product([0, 1], repeat=4))
    correct = 0
    for x1, x2, x3, x4 in cases:
        out = int(model(x1, x2, x3, x4).item())
        expected = x1 & x2 & x3 & x4
        ok = out == expected
        if ok:
            correct += 1
        status = "OK" if ok else "FAIL"
        print(f"AND4({x1}, {x2}, {x3}, {x4}) = {out} [{status}]")

    print(f"\nTotal: {correct}/{len(cases)} correct")
|
|
|