"""Build and verify a 2-to-4 one-hot decoder made of threshold neurons.

Inputs are two bits (A1, A0); outputs Y0..Y3 are one-hot, with Yi firing
exactly when the input encodes i. Each output is a single neuron that fires
when ``w . x + b >= 0``. Running this file as a script writes the weights to
``model.safetensors``, checks all four input combinations, and reports the
total magnitude (sum of absolute values) of every parameter.
"""

import torch

NUM_OUTPUTS = 4


def _build_weights():
    """Return the ``y{i}.weight`` / ``y{i}.bias`` tensors for outputs Y0..Y3.

    Output i gets weight +1 on each input bit that is set in i and -1 on each
    bit that is clear, plus a bias of minus the number of set bits. The
    weighted sum therefore reaches exactly 0 for the input equal to i and is
    negative for every other binary input.
    """
    params = {}
    for i in range(NUM_OUTPUTS):
        a1_bit = (i >> 1) & 1
        a0_bit = i & 1
        w = [1.0 if a1_bit else -1.0, 1.0 if a0_bit else -1.0]
        bias = -(a1_bit + a0_bit)  # minus the popcount of i
        params[f"y{i}.weight"] = torch.tensor([w], dtype=torch.float32)
        params[f"y{i}.bias"] = torch.tensor([float(bias)], dtype=torch.float32)
    return params


weights = _build_weights()


def decode2to4(a1, a0):
    """Decode the 2-bit input (a1, a0) into a one-hot list of four ints.

    Args:
        a1: High input bit, 0 or 1.
        a0: Low input bit, 0 or 1.

    Returns:
        ``[y0, y1, y2, y3]`` where ``yi`` is 1 iff ``2*a1 + a0 == i``.

    Raises:
        ValueError: If either input is not 0 or 1. The neurons are designed
            only for binary inputs; e.g. ``a1=2, a0=0`` would make Y2 and Y3
            fire together.
    """
    if a1 not in (0, 1) or a0 not in (0, 1):
        raise ValueError(
            f"decode2to4 expects bits 0/1, got a1={a1!r}, a0={a0!r}"
        )
    inp = torch.tensor([float(a1), float(a0)])
    outputs = []
    for i in range(NUM_OUTPUTS):
        activation = (inp * weights[f"y{i}.weight"]).sum() + weights[f"y{i}.bias"]
        outputs.append(int(activation.item() >= 0))
    return outputs


def main():
    """Export the weights, verify every input, and return a process exit status."""
    # Imported here so the decoder itself is usable without safetensors installed.
    from safetensors.torch import save_file

    save_file(weights, "model.safetensors")

    print("Verifying 2to4decoder...")
    errors = 0
    for val in range(NUM_OUTPUTS):
        a1, a0 = (val >> 1) & 1, val & 1
        result = decode2to4(a1, a0)
        expected = [1 if i == val else 0 for i in range(NUM_OUTPUTS)]
        if result != expected:
            errors += 1
            print(f"ERROR: {val} -> {result}, expected {expected}")
    if errors == 0:
        print("All 4 test cases passed!")
    else:
        print(f"{errors} of {NUM_OUTPUTS} test cases failed!")
    # Sum of absolute values over all weights and biases.
    print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
    return 1 if errors else 0


if __name__ == "__main__":
    raise SystemExit(main())