# threshold-parity7 / create_safetensors.py
# phanerozoic's picture
# Upload folder using huggingface_hub
# 6a0af35 verified
import torch
from safetensors.torch import save_file
# 7-bit parity using tree of XOR gates
# Structure: parity7 = XOR(XOR(XOR(x0,x1), XOR(x2,x3)), XOR(XOR(x4,x5), x6))
# 6 XOR2 gates total
def xor_block(prefix):
    """Return the parameter tensors of one two-input XOR gate.

    The gate is made of three threshold neurons named ``or``, ``nand`` and
    ``and``; each contributes a two-element ``weight`` tensor and a
    one-element ``bias`` tensor, all float32.

    Args:
        prefix: Name of the gate; every key is ``'{prefix}.<neuron>.weight'``
            or ``'{prefix}.<neuron>.bias'``.

    Returns:
        A dict with six entries, ordered or / nand / and, weight before bias.
    """
    # neuron name -> (input weights, bias)
    neurons = {
        'or': ([1.0, 1.0], -1.0),
        'nand': ([-1.0, -1.0], 1.0),
        'and': ([1.0, 1.0], -2.0),
    }
    block = {}
    for neuron, (input_weights, bias) in neurons.items():
        block[f'{prefix}.{neuron}.weight'] = torch.tensor(input_weights, dtype=torch.float32)
        block[f'{prefix}.{neuron}.bias'] = torch.tensor([bias], dtype=torch.float32)
    return block
# Collect the parameters of all six XOR gates into one flat table, in tree
# order, then write the table to disk.
weights = {
    key: tensor
    for gate in (
        # Level 1: XOR pairs
        'xor_01',     # XOR(x0, x1)
        'xor_23',     # XOR(x2, x3)
        'xor_45',     # XOR(x4, x5)
        # Level 2: Combine
        'xor_0123',   # XOR(xor01, xor23)
        'xor_456',    # XOR(xor45, x6)
        # Level 3: Final
        'xor_final',  # XOR(xor0123, xor456)
    )
    for key, tensor in xor_block(gate).items()
}
save_file(weights, 'model.safetensors')
def xor2(a, b, prefix):
    """Evaluate the XOR gate stored under ``prefix`` in the module-level ``weights``.

    The gate has two layers of threshold neurons: ``or`` and ``nand`` both
    read ``(a, b)``, and ``and`` reads their two outputs. A neuron outputs 1
    when its weighted input sum plus bias is >= 0, otherwise 0.

    Args:
        a: First input (the verification loop in this file passes 0 or 1).
        b: Second input (same convention as ``a``).
        prefix: Gate name used to look up ``'{prefix}.<neuron>.weight'`` and
            ``'{prefix}.<neuron>.bias'`` in ``weights``.

    Returns:
        The output of the ``and`` neuron as an ``int``.
    """
    def fire(neuron, left, right):
        # One threshold neuron: step(left * w0 + right * w1 + bias).
        w = weights[f'{prefix}.{neuron}.weight']
        bias = weights[f'{prefix}.{neuron}.bias']
        return int(left * w[0] + right * w[1] + bias >= 0)

    # Arguments evaluate left to right: or, then nand, then the and on top.
    return fire('and', fire('or', a, b), fire('nand', a, b))
def parity7(x0, x1, x2, x3, x4, x5, x6):
    """Combine seven inputs through the tree of six XOR gates.

    Structure: XOR(XOR(XOR(x0, x1), XOR(x2, x3)), XOR(XOR(x4, x5), x6)).
    The verification loop in this file feeds 0/1 bits and expects the result
    to equal ``sum(bits) % 2``.

    Returns:
        The output of the final gate as returned by ``xor2``.
    """
    # Level 1: pair up the first six inputs.
    pair_01 = xor2(x0, x1, 'xor_01')
    pair_23 = xor2(x2, x3, 'xor_23')
    pair_45 = xor2(x4, x5, 'xor_45')
    # Levels 2 and 3: merge the pairs (x6 joins at level 2), then the halves.
    return xor2(
        xor2(pair_01, pair_23, 'xor_0123'),
        xor2(pair_45, x6, 'xor_456'),
        'xor_final',
    )
# Exhaustively compare the gate network with arithmetic parity on every
# 7-bit input, then report size statistics for the weight table.
print("Verifying parity7...")
errors = 0
for i in range(128):
    # Bit j of i is fed to input x_j, so the loop covers all 2**7 patterns.
    bits = [(i >> j) & 1 for j in range(7)]
    result = parity7(*bits)
    expected = sum(bits) % 2
    if result != expected:
        errors += 1
        print(f"ERROR: parity({bits}) = {result}, expected {expected}")
if errors == 0:
    print("All 128 test cases passed!")
else:
    print(f"FAILED: {errors} errors")
# Sum of absolute parameter values, and total number of parameters.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
print(f"Parameters: {sum(t.numel() for t in weights.values())}")
if errors:
    # A failed verification must be visible to shells/CI: previously the
    # script printed "FAILED" but still exited with status 0.
    raise SystemExit(1)