"""Build a threshold-logic network for 4-bit two's complement negation.

Inputs:  a3, a2, a1, a0 (4 bits; a3 is the most significant / sign bit)
Outputs: n3, n2, n1, n0 (the 4-bit pattern of -A) and overflow (5 outputs)

Negation uses the identity -A = ~A + 1. Every input bit is inverted, and the
inverted word is then incremented by a ripple chain of four half adders whose
initial carry-in is the constant 1. The only 4-bit input whose negation is
not representable is A = -8 (0b1000), because +8 does not fit in 4 signed
bits. ``overflow`` therefore fires exactly for that input.

Every gate is a single threshold neuron: out = 1 if (w . x + b) >= 0 else 0.
The wiring between gates is implicit in the key names. ``run_network`` is the
executable description of that wiring.

Run as a script to write ``model.safetensors`` and to check both the
reference arithmetic and the gate-level network on all 16 inputs.
"""

import torch

weights = {}

# --- Stage 1: invert each input bit (inv{i} computes NOT a_i) ---------------
for i in range(4):
    weights[f'inv{i}.weight'] = torch.tensor([[-1.0]], dtype=torch.float32)
    weights[f'inv{i}.bias'] = torch.tensor([0.0], dtype=torch.float32)


def add_xor(name):
    """Register the three gates of a two-layer XOR under the prefix ``name``.

    XOR(x, y) = AND(OR(x, y), NAND(x, y)). The ``or`` and ``nand`` gates read
    the inputs (x, y); the ``and`` gate reads their two outputs.
    """
    weights[f'{name}.or.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[f'{name}.or.bias'] = torch.tensor([-1.0], dtype=torch.float32)
    weights[f'{name}.nand.weight'] = torch.tensor([[-1.0, -1.0]], dtype=torch.float32)
    weights[f'{name}.nand.bias'] = torch.tensor([1.0], dtype=torch.float32)
    weights[f'{name}.and.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[f'{name}.and.bias'] = torch.tensor([-2.0], dtype=torch.float32)


def add_ha(name):
    """Register a half adder: ``{name}.sum`` = XOR(x, y), ``{name}.carry`` = AND(x, y)."""
    add_xor(f'{name}.sum')
    weights[f'{name}.carry.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
    weights[f'{name}.carry.bias'] = torch.tensor([-2.0], dtype=torch.float32)


# --- Stage 2: increment ~A with a ripple chain of half adders ---------------
# inc0 adds ~a0 + 1 (the constant carry-in supplies the "+ 1"). Each inc{i}
# with i > 0 adds ~a_i to the carry out of inc{i-1}. The carry is 1 exactly
# while the original low bits are all 1. The carry out of inc3 is discarded
# because the arithmetic is mod 16; it is NOT signed overflow.
# Example A = 5 = 0b0101:  ~A = 0b1010,  ~A + 1 = 0b1011 = -5.
# Example A = 0 = 0b0000:  ~A + 1 = 0b1111 + 1 = 0b1_0000 -> 0b0000.
for i in range(4):
    add_ha(f'inc{i}')

# --- Stage 3: overflow detection --------------------------------------------
# Overflow iff A == 0b1000 (-8), i.e. a3 AND NOR(a2, a1, a0).
weights['ov_nora.weight'] = torch.tensor([[-1.0, -1.0, -1.0]], dtype=torch.float32)
weights['ov_nora.bias'] = torch.tensor([0.0], dtype=torch.float32)
weights['overflow.weight'] = torch.tensor([[1.0, 1.0]], dtype=torch.float32)
weights['overflow.bias'] = torch.tensor([-2.0], dtype=torch.float32)


def _fire(params, name, *inputs):
    """Evaluate one threshold neuron: 1 if w . x + b >= 0, else 0."""
    x = torch.tensor(inputs, dtype=torch.float32)
    z = params[f'{name}.weight'] @ x + params[f'{name}.bias']
    return int(z.item() >= 0)


def _xor(params, name, x, y):
    """Evaluate the two-layer XOR registered by ``add_xor(name)``."""
    return _fire(params, f'{name}.and',
                 _fire(params, f'{name}.or', x, y),
                 _fire(params, f'{name}.nand', x, y))


def run_network(a, params=None):
    """Evaluate the gate-level network on a 4-bit input.

    Args:
        a: Input bit pattern as an int in [0, 16); bit i is a_i.
        params: Weight dict to evaluate. Defaults to the module-level
            ``weights``.

    Returns:
        ``(neg, overflow)``. ``neg`` is the 4-bit output pattern as an int in
        [0, 16), and ``overflow`` is 0 or 1.

    Raises:
        ValueError: If ``a`` is outside [0, 16).
    """
    if not 0 <= a < 16:
        raise ValueError(f"a must be a 4-bit value in [0, 16), got {a}")
    if params is None:
        params = weights
    bits = [(a >> k) & 1 for k in range(4)]
    inverted = [_fire(params, f'inv{k}', b) for k, b in enumerate(bits)]

    carry = 1  # constant carry-in implements the "+ 1" of ~A + 1
    neg = 0
    for k, bit in enumerate(inverted):
        neg |= _xor(params, f'inc{k}.sum', bit, carry) << k
        carry = _fire(params, f'inc{k}.carry', bit, carry)
    # The final carry (out of inc3) is intentionally dropped: results are mod 16.

    low_bits_zero = _fire(params, 'ov_nora', bits[2], bits[1], bits[0])
    overflow = _fire(params, 'overflow', bits[3], low_bits_zero)
    return neg, overflow


def twos_comp(a):
    """Reference integer implementation: return ``(-a mod 16, overflow)``.

    ``overflow`` is 1 only for a == 8 (the pattern for -8).
    """
    inv = (~a) & 0xF
    neg = (inv + 1) & 0xF
    overflow = 1 if a == 8 else 0
    return neg, overflow


def main():
    """Save the weights, then verify the reference and the network on all 16 inputs."""
    # Imported here so the module can be imported (e.g. to call run_network)
    # without safetensors installed; it is only needed to write the file.
    from safetensors.torch import save_file

    save_file(weights, 'model.safetensors')

    print("Verifying 4-bit Two's Complement...")
    errors = 0
    for a in range(16):
        expected = (-a) & 0xF
        exp_ov = 1 if a == 8 else 0
        for source, (result, ov) in (("reference", twos_comp(a)),
                                     ("network", run_network(a))):
            if result != expected or ov != exp_ov:
                errors += 1
                if errors <= 5:
                    print(f"ERROR ({source}): -({a}) = {result} (ov={ov}), "
                          f"expected {expected} (ov={exp_ov})")

    if errors == 0:
        print("All 16 test cases passed!")
    else:
        print(f"FAILED: {errors} errors")

    print("\nSigned interpretation:")
    for a in range(16):
        signed_a = a if a < 8 else a - 16
        neg, ov = run_network(a)
        signed_neg = neg if neg < 8 else neg - 16
        ov_str = " (OVERFLOW)" if ov else ""
        print(f"  -({signed_a:+d}) = {signed_neg:+d}{ov_str}")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"\nMagnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")


if __name__ == "__main__":
    main()