threshold-twos-complement / create_safetensors.py
CharlesCNorton
Two's complement negation
ad8baa5
import torch
from safetensors.torch import save_file
# 4-bit two's complement negation built from threshold neurons.
#
# Inputs:  a3, a2, a1, a0              (4 inputs, a3 = MSB)
# Outputs: n3, n2, n1, n0, overflow    (5 outputs)
#
# Identity used: -A = ~A + 1 (mod 16).
# overflow = 1 only for A = 1000b (-8): +8 is outside the 4-bit signed range [-8, 7].
weights = {}

# Stage 1: one single-input NOT neuron per bit (weight -1, bias 0).
for bit in range(4):
    weights.update({
        f'inv{bit}.weight': torch.tensor([[-1.0]], dtype=torch.float32),
        f'inv{bit}.bias': torch.tensor([0.0], dtype=torch.float32),
    })
# Add 1 using a half-adder chain. Building block: a two-input XOR made of three neurons.
def add_xor(name):
    """Register the parameters of a three-neuron XOR under ``name``.

    Adds ``{name}.or``, ``{name}.nand`` and ``{name}.and`` weight/bias pairs to the
    module-level ``weights`` dict. The intended wiring is
    XOR(x, y) = AND(OR(x, y), NAND(x, y)); only the gate parameters are stored here,
    not the connections.
    """
    gate_params = (
        ('or', [1.0, 1.0], -1.0),
        ('nand', [-1.0, -1.0], 1.0),
        ('and', [1.0, 1.0], -2.0),
    )
    for gate, row, bias in gate_params:
        weights[f'{name}.{gate}.weight'] = torch.tensor([row], dtype=torch.float32)
        weights[f'{name}.{gate}.bias'] = torch.tensor([bias], dtype=torch.float32)
def add_ha(name):
    """Register the parameters of a half adder under ``name``.

    The sum output is an XOR: three neurons under ``{name}.sum``, created by
    ``add_xor``. The carry output is a single AND neuron ``{name}.carry``
    (weights [1, 1], bias -2). Everything is written into the module-level
    ``weights`` dict.
    """
    add_xor(name + '.sum')
    carry_key = name + '.carry'
    weights[carry_key + '.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[carry_key + '.bias'] = torch.tensor([-2.0], dtype=torch.float32)
# Stage 2: increment ~A with a ripple chain of four half adders.
#
# The carry into bit 0 is the constant 1 (the "+1"); each later stage adds the
# previous stage's carry:
#   inc0: n0 = ~a0 XOR 1,  c0 = ~a0 AND 1 = ~a0
#   inc1: n1 = ~a1 XOR c0, c1 = ~a1 AND c0
#   inc2: n2 = ~a2 XOR c1, c2 = ~a2 AND c1
#   inc3: n3 = ~a3 XOR c2
# Since ~a0 XOR 1 = a0, n0 = a0: negation never changes the lowest bit.
# The carry keeps rippling only while the inverted bits are 1, i.e. while the
# original low bits are 0.
#
# Worked example, A = 5 = 0101:
#   ~A = 1010;  1010 + 0001 = 1011 = -5 in 4-bit two's complement.
# A = 0 = 0000:
#   ~A = 1111;  1111 + 1 = 1_0000, truncated to 0000 (= -0).
#   The carry out of inc3 is 1 only in this case. The 4-bit result simply drops
#   it; it is NOT the overflow flag (overflow is computed separately, for A = -8).
for stage in ('inc0', 'inc1', 'inc2', 'inc3'):
    add_ha(stage)
# Stage 3: overflow flag. It is set only for A = 1000b (-8), because +8 does not
# fit in 4 signed bits. Intended wiring:
#   ov_nora  = NOR(a2, a1, a0)   -> 1 when the three low bits are all 0
#   overflow = AND(a3, ov_nora)
overflow_gates = {
    'ov_nora': ([-1.0, -1.0, -1.0], 0.0),
    'overflow': ([1.0, 1.0], -2.0),
}
for gate, (row, bias) in overflow_gates.items():
    weights[f'{gate}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'{gate}.bias'] = torch.tensor([bias], dtype=torch.float32)

# Persist every gate's parameters.
save_file(weights, 'model.safetensors')
def twos_comp(a):
    """Pure-Python reference for 4-bit two's complement negation.

    Returns ``(neg, overflow)``:
    - ``neg`` is (~a + 1) masked to 4 bits, i.e. -a mod 16 as an unsigned value.
    - ``overflow`` is 1 only when ``a == 8`` (0b1000 = -8), whose negation +8
      cannot be represented; otherwise it is 0.
    """
    flipped = ~a & 0xF
    return (flipped + 1) & 0xF, int(a == 8)
# ---------------------------------------------------------------------------
# Verification.
#
# twos_comp() is only a pure-Python reference. Checking it against arithmetic
# alone never touches the threshold weights built above. So the network itself
# is also evaluated here, gate by gate, from `weights`, using the wiring that the
# construction comments describe.
#
# Firing rule: a neuron outputs 1 when w . x + b >= 0. This is the only rule
# under which the bias-0 gates work (inv*: -x + 0 must fire for x = 0; ov_nora
# likewise), so it is assumed here.
# ---------------------------------------------------------------------------


def _fire(gate, *inputs):
    """Evaluate threshold neuron `gate` on 0/1 `inputs`; return 0 or 1."""
    w = weights[f'{gate}.weight'][0]
    b = weights[f'{gate}.bias'][0]
    x = torch.tensor(inputs, dtype=w.dtype)
    return int(torch.dot(w, x).item() + b.item() >= 0)


def _xor(prefix, x, y):
    """XOR(x, y) = AND(OR(x, y), NAND(x, y)) using the neurons under `prefix`."""
    return _fire(f'{prefix}.and', _fire(f'{prefix}.or', x, y), _fire(f'{prefix}.nand', x, y))


def _network_negate(a):
    """Run the threshold network on the 4-bit value `a`; return (neg, overflow)."""
    bits = [(a >> i) & 1 for i in range(4)]  # bits[0] = a0 (LSB)
    inverted = [_fire(f'inv{i}', bits[i]) for i in range(4)]
    carry = 1  # the "+1" enters as the carry into bit 0
    neg = 0
    for i in range(4):
        neg |= _xor(f'inc{i}.sum', inverted[i], carry) << i
        carry = _fire(f'inc{i}.carry', inverted[i], carry)
    low_zero = _fire('ov_nora', bits[2], bits[1], bits[0])
    return neg, _fire('overflow', bits[3], low_zero)


print("Verifying 4-bit Two's Complement...")
errors = 0
for a in range(16):
    expected = (16 - a) & 0xF  # -a mod 16 (already 0 for a == 0)
    exp_ov = 1 if a == 8 else 0
    # Check both the Python reference and the actual threshold network.
    for label, (result, ov) in (('reference', twos_comp(a)), ('network', _network_negate(a))):
        if result != expected or ov != exp_ov:
            errors += 1
            if errors <= 5:
                print(f"ERROR [{label}]: -({a}) = {result} (overflow={ov}), "
                      f"expected {expected} (overflow={exp_ov})")
if errors == 0:
    print("All 16 test cases passed!")
else:
    print(f"FAILED: {errors} errors")

print("\nSigned interpretation:")
for a in range(16):
    signed_a = a if a < 8 else a - 16
    neg, ov = twos_comp(a)
    signed_neg = neg if neg < 8 else neg - 16
    ov_str = " (OVERFLOW)" if ov else ""
    print(f" -({signed_a:+d}) = {signed_neg:+d}{ov_str}")

# Sum of absolute values of all weights/biases, and total parameter count.
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"\nMagnitude: {mag:.0f}")
print(f"Parameters: {sum(t.numel() for t in weights.values())}")