# threshold-onehot-encoder / create_safetensors.py
# CharlesCNorton
# Add 2-to-4 one-hot encoder threshold circuit
# 3f960d8
import torch
from safetensors.torch import save_file
def _threshold_neuron(w_a1, w_a0, bias):
    """Build the float32 tensors for one two-input threshold neuron.

    Returns a ``(weight, bias)`` pair: the weight has shape (1, 2) with its
    columns ordered (a1, a0), and the bias has shape (1,).
    """
    weight = torch.tensor([[w_a1, w_a0]], dtype=torch.float32)
    offset = torch.tensor([bias], dtype=torch.float32)
    return weight, offset


# One-Hot Encoder (2-to-4)
#
# Inputs:  a1, a0          (binary value 0-3, a1 is the high bit)
# Outputs: y3, y2, y1, y0  (one-hot encoding of that value)
#
#   a1 a0 | y3 y2 y1 y0
#   ------+------------
#    0  0 |  0  0  0  1
#    0  1 |  0  0  1  0
#    1  0 |  0  1  0  0
#    1  1 |  1  0  0  0
#
# Single-layer implementation: each output is one threshold neuron that
# fires (outputs 1) iff  w_a1 * a1 + w_a0 * a0 + bias >= 0.
#
#   y0 = NOR(a1, a0)      fires iff a1 + a0 = 0
#   y1 = NOT(a1) AND a0   fires iff a0 - a1 >= 1
#   y2 = a1 AND NOT(a0)   fires iff a1 - a0 >= 1
#   y3 = a1 AND a0        fires iff a1 + a0 >= 2
weights = {}

# y0: neither input is 1.
weights['y0.weight'], weights['y0.bias'] = _threshold_neuron(-1.0, -1.0, 0.0)

# y1: a0 is 1 but a1 is 0.
weights['y1.weight'], weights['y1.bias'] = _threshold_neuron(-1.0, 1.0, -1.0)

# y2: a1 is 1 but a0 is 0.
weights['y2.weight'], weights['y2.bias'] = _threshold_neuron(1.0, -1.0, -1.0)

# y3: both inputs are 1.
weights['y3.weight'], weights['y3.bias'] = _threshold_neuron(1.0, 1.0, -2.0)
# Persist every neuron's weight and bias tensor to a single safetensors file
# in the current working directory.
save_file(tensors=weights, filename='model.safetensors')
def onehot_encode(a1, a0):
    """Evaluate the four threshold neurons on one pair of input bits.

    Args:
        a1: High input bit; converted with ``float()`` before use.
        a0: Low input bit; converted with ``float()`` before use.

    Returns:
        The tuple ``(y3, y2, y1, y0)`` of ints, where each entry is 1 when
        the corresponding neuron's pre-activation ``x @ W.T + b`` is >= 0
        and 0 otherwise.
    """
    x = torch.tensor([float(a1), float(a0)])

    def fires(neuron):
        # Pre-activation of a single neuron, thresholded at zero.
        pre_activation = x @ weights[f'{neuron}.weight'].T + weights[f'{neuron}.bias']
        return int((pre_activation >= 0).item())

    y0, y1, y2, y3 = (fires(neuron) for neuron in ('y0', 'y1', 'y2', 'y3'))
    return y3, y2, y1, y0
def reference_onehot(a1, a0):
    """Return the expected one-hot code for the two-bit value ``a1 a0``.

    The value is ``a1 * 2 + a0``. The result is the tuple
    ``(y3, y2, y1, y0)`` in which the entry whose index equals the value is
    1 and every other entry is 0; a value outside 0-3 yields all zeros.
    """
    value = a1 * 2 + a0
    # Highest output first, matching the (y3, y2, y1, y0) ordering.
    return tuple(int(value == position) for position in (3, 2, 1, 0))
print("Verifying One-Hot Encoder (2-to-4)...")

# Exhaustively compare the threshold circuit with the reference encoder over
# all four input combinations, visiting them in the order 00, 01, 10, 11.
errors = 0
for bits in range(4):
    a1, a0 = divmod(bits, 2)
    result = onehot_encode(a1, a0)
    expected = reference_onehot(a1, a0)
    if result == expected:
        continue
    errors += 1
    print(f"ERROR: ({a1},{a0}) -> {result}, expected {expected}")

print("All 4 test cases passed!" if errors == 0 else f"FAILED: {errors} errors")
# Print the truth table actually produced by the circuit, one row per input
# value from 0 to 3.
print("\nTruth Table:", "a1 a0 | y3 y2 y1 y0 | value", "-" * 30, sep="\n")
for val in range(4):
    # Split the value into its high (a1) and low (a0) bits.
    a1, a0 = divmod(val, 2)
    y3, y2, y1, y0 = onehot_encode(a1, a0)
    print(f" {a1} {a0} | {y3} {y2} {y1} {y0} | {val}")
# Summary statistics over every stored tensor (weights and biases alike).

# Magnitude: sum of the absolute values of all parameters.
mag = 0
for tensor in weights.values():
    mag += tensor.abs().sum().item()
print(f"\nMagnitude: {mag:.0f}")

# Parameters: total number of scalar entries across all tensors.
param_count = sum(tensor.numel() for tensor in weights.values())
print(f"Parameters: {param_count}")

# Neurons: one per stored tensor whose key contains 'weight'.
neuron_count = sum(1 for key in weights if 'weight' in key)
print(f"Neurons: {neuron_count}")