"""
Threshold Network for Full Adder

Adds three 1-bit inputs (a, b, cin), producing sum and carry outputs.
Built from two half adders and an OR gate.
"""

from itertools import product

import torch


def heaviside(x):
    """Return a float32 tensor holding 1.0 where ``x >= 0`` and 0.0 elsewhere."""
    return (x >= 0).float()


class ThresholdFullAdder:
    """
    Full adder: sum = a XOR b XOR cin, cout = (a AND b) OR ((a XOR b) AND cin)

    Every gate is a threshold unit ``heaviside(sum(inputs * weight) + bias)``
    whose ``weight``/``bias`` tensors are looked up by name in the weights
    mapping passed to the constructor. Required names:

    - ``{ha}.sum.layer1.or``, ``{ha}.sum.layer1.nand``, ``{ha}.sum.layer2``,
      ``{ha}.carry`` for ``ha`` in ``("ha1", "ha2")``
    - ``carry_or``

    each with a ``.weight`` and a ``.bias`` suffix.
    """

    def __init__(self, weights_dict):
        # Mapping of parameter name -> tensor (for example, the dict returned
        # by safetensors' ``load_file``).
        self.weights = weights_dict

    def _neuron(self, name, *inputs):
        """Evaluate one threshold unit and return its output as a Python int.

        Reads ``f"{name}.weight"`` and ``f"{name}.bias"`` from ``self.weights``
        and computes ``heaviside(sum(inputs * weight) + bias)``.

        Raises:
            KeyError: if either parameter is missing.
        """
        x = torch.tensor([float(v) for v in inputs])
        weight = self.weights[f"{name}.weight"]
        bias = self.weights[f"{name}.bias"]
        # ``.item()`` accepts a 0-d or single-element result, so a bias stored
        # either as a scalar or with shape (1,) works.
        return int(heaviside((x * weight).sum() + bias).item())

    def half_adder(self, prefix, a, b):
        """Evaluate the half adder whose parameters are named ``prefix.*``.

        ``a`` and ``b`` are expected to be bits (0 or 1).

        Returns:
            ``(sum, carry)`` as Python ints.
        """
        # XOR for sum: two layer-1 units (named "or" and "nand") feed a single
        # layer-2 unit.
        or_out = self._neuron(f"{prefix}.sum.layer1.or", a, b)
        nand_out = self._neuron(f"{prefix}.sum.layer1.nand", a, b)
        sum_out = self._neuron(f"{prefix}.sum.layer2", or_out, nand_out)
        # AND for carry
        carry_out = self._neuron(f"{prefix}.carry", a, b)
        return sum_out, carry_out

    def __call__(self, a, b, cin):
        """Add the three bits ``a``, ``b`` and ``cin``.

        Returns:
            ``(sum, cout)`` as Python ints.

        Raises:
            ValueError: if any input is not 0 or 1.
            KeyError: if a required parameter is missing from the weights.
        """
        # The network is only defined on bits; other values would give
        # meaningless outputs rather than an error.
        for label, bit in (("a", a), ("b", b), ("cin", cin)):
            if bit not in (0, 1):
                raise ValueError(f"{label} must be 0 or 1, got {bit!r}")

        # First half adder: a + b
        s1, c1 = self.half_adder("ha1", a, b)
        # Second half adder: s1 + cin
        sum_out, c2 = self.half_adder("ha2", s1, cin)
        # Carry out = c1 OR c2
        cout = self._neuron("carry_or", c1, c2)
        return sum_out, cout

    @classmethod
    def from_safetensors(cls, path="model.safetensors"):
        """Build a network from the ``.safetensors`` file at ``path``."""
        # Imported lazily so the class can be used with an in-memory weights
        # dict without safetensors installed.
        from safetensors.torch import load_file

        return cls(load_file(path))


def print_truth_table(model):
    """Print the full-adder truth table for ``model`` and check every row.

    Returns:
        The number of rows whose ``(sum, cout)`` disagrees with ``a + b + cin``.
    """
    print("Full Adder Truth Table:")
    print("-" * 40)
    print("a | b | cin | sum | cout")
    print("-" * 40)
    failures = 0
    for a, b, cin in product((0, 1), repeat=3):
        s, cout = model(a, b, cin)
        expected_cout, expected_sum = divmod(a + b + cin, 2)
        ok = s == expected_sum and cout == expected_cout
        if not ok:
            failures += 1
        status = "OK" if ok else "FAIL"
        # Field widths match the header columns above.
        print(f"{a} | {b} | {cin:^3} | {s:^3} | {cout:^4} [{status}]")
    print("-" * 40)
    if failures:
        print(f"{failures} of 8 rows FAILED")
    else:
        print("All 8 rows OK")
    return failures


def main(path="model.safetensors"):
    """Load the network from ``path``, verify it, and return an exit code."""
    model = ThresholdFullAdder.from_safetensors(path)
    return 1 if print_truth_table(model) else 0


if __name__ == "__main__":
    raise SystemExit(main())