# threshold-4to16decoder / create_safetensors.py
# Uploaded by phanerozoic ("Upload folder using huggingface_hub"), commit ed72eb8 (verified).
import torch
from safetensors.torch import save_file
weights = {}

# Input order: a3, a2, a1, a0 (MSB to LSB).
# Output: one-hot y0..y15, where yi = 1 iff the input encodes i.
#
# Each yi is an exact-match threshold neuron for the bit pattern of i:
#   weight = +1 on bits that are 1 in i, -1 on bits that are 0 in i
#   bias   = -(number of 1 bits in i)
# For an input x the pre-activation works out to -(Hamming distance between
# x and i), so it reaches the >= 0 firing threshold only when x == i.
for target in range(16):
    target_bits = [(target >> shift) & 1 for shift in (3, 2, 1, 0)]
    row = [1.0 if bit else -1.0 for bit in target_bits]
    weights[f'y{target}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'y{target}.bias'] = torch.tensor(
        [float(-sum(target_bits))], dtype=torch.float32
    )

save_file(weights, 'model.safetensors')
# Verify
def decode(a3, a2, a1, a0):
    """Run all 16 decoder neurons on one 4-bit input.

    Reads the neuron parameters from the module-level ``weights`` dict
    (keys ``y{i}.weight`` / ``y{i}.bias``).

    Args:
        a3, a2, a1, a0: Input bits, most significant first; each is passed
            through ``float()`` before evaluation.

    Returns:
        A list of 16 ints. Entry i is 1 when neuron yi's weighted input sum
        plus its bias is >= 0, otherwise 0.
    """
    x = torch.tensor([float(bit) for bit in (a3, a2, a1, a0)])

    def fires(idx):
        # Heaviside step: the neuron fires when its pre-activation is >= 0.
        pre_activation = (x * weights[f'y{idx}.weight']).sum() + weights[f'y{idx}.bias']
        return int(pre_activation >= 0)

    return [fires(idx) for idx in range(16)]
print("Verifying 4to16decoder...")

# Exhaustively check every 4-bit input: exactly output y<val> must fire.
errors = 0
for val in range(16):
    a3, a2, a1, a0 = (val >> 3) & 1, (val >> 2) & 1, (val >> 1) & 1, val & 1
    result = decode(a3, a2, a1, a0)
    expected = [1 if i == val else 0 for i in range(16)]
    if result != expected:
        errors += 1
        print(f"ERROR: {val} ({a3}{a2}{a1}{a0}) -> {result}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    # Give a one-line summary instead of leaving only the per-case lines.
    print(f"{errors} of 16 test cases FAILED!")

# Sum of absolute values of every weight and bias element in the model.
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")

if errors:
    # model.safetensors has already been written above; exit non-zero so
    # shell pipelines / CI do not treat a failed verification as success.
    raise SystemExit(1)