# threshold-2to4decoder / create_safetensors.py
# Author: phanerozoic
# Commit 67df09a (verified): "Upload folder using huggingface_hub"
import torch
from safetensors.torch import save_file
# Build a single-layer threshold network implementing a 2-to-4 decoder.
#
# Input:  A1 (high bit), A0 (low bit).
# Output: Y0..Y3, one-hot -- Yi fires when the 2-bit input equals i.
#
# Neuron i gets weight +1 on every input bit that is set in i and -1 on every
# bit that is clear, plus a bias of minus the number of set bits in i.  Its
# weighted sum + bias is then exactly 0 on input i and negative on every other
# input, so a ">= 0" step activation selects only pattern i.
weights = {}
for idx in range(4):
    bits = ((idx >> 1) & 1, idx & 1)  # (A1, A0) for pattern idx
    row = [1.0 if b else -1.0 for b in bits]
    threshold = -sum(bits)  # minus popcount(idx)
    weights[f'y{idx}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'y{idx}.bias'] = torch.tensor([float(threshold)], dtype=torch.float32)

save_file(weights, 'model.safetensors')
def decode2to4(a1, a0, params=None):
    """Evaluate the 2-to-4 threshold decoder on a single input pair.

    Args:
        a1: High input bit (any value ``float()`` accepts; normally 0 or 1).
        a0: Low input bit.
        params: Mapping holding ``y{i}.weight`` and ``y{i}.bias`` tensors for
            i in 0..3, laid out like the ones this script builds (a (1, 2)
            weight row and a (1,) bias).  Defaults to the module-level
            ``weights``; pass e.g. a dict reloaded from ``model.safetensors``
            to check the saved parameters instead.

    Returns:
        A list of four ints (0 or 1).  Entry ``i`` is 1 when neuron ``i``'s
        weighted sum plus bias is >= 0.
    """
    if params is None:
        params = weights
    inp = torch.tensor([float(a1), float(a0)])
    outputs = []
    for i in range(4):
        # (2,) * (1, 2) broadcasts to (1, 2); .sum() gives the scalar
        # pre-activation.  The step uses >= 0 (not > 0) because, with this
        # script's weights, the matching neuron's sum is exactly 0.
        pre = (inp * params[f'y{i}.weight']).sum() + params[f'y{i}.bias']
        outputs.append(int(pre >= 0))
    return outputs
# Exhaustively check the decoder against the one-hot truth table, report the
# total parameter magnitude, and exit with a non-zero status if any case failed
# so a calling shell / CI step can detect a broken model.  Note that the
# weights file has already been written by the time this check runs.
print("Verifying 2to4decoder...")
errors = 0
for val in range(4):
    a1, a0 = (val >> 1) & 1, val & 1
    result = decode2to4(a1, a0)
    expected = [1 if i == val else 0 for i in range(4)]
    if result != expected:
        errors += 1
        print(f"ERROR: {val} -> {result}, expected {expected}")
if errors == 0:
    print("All 4 test cases passed!")
else:
    print(f"{errors} of 4 test cases failed!")
# Sum of absolute values over every weight and bias tensor.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
if errors:
    raise SystemExit(1)