import sys

import torch
from safetensors.torch import save_file
# Build a single-layer threshold network implementing a 2-to-4 line decoder.
#   Inputs:  A1, A0 (2 bits, 0/1 values).
#   Outputs: Y0..Y3, one-hot; Yi fires when the input encodes i.
# Neuron `code` weights each input +1 where the matching bit of `code` is 1
# and -1 where it is 0, with bias equal to minus the number of 1 bits in
# `code`.  For 0/1 inputs its pre-activation is exactly 0 on input `code` and
# negative on every other input, so a ">= 0" step gives the one-hot code.
weights = {}
for code in range(4):
    bits = ((code >> 1) & 1, code & 1)  # (A1, A0) pattern this neuron matches
    row = [1.0 if b else -1.0 for b in bits]
    ones = sum(bits)  # number of set bits in `code`
    weights[f'y{code}.weight'] = torch.tensor([row], dtype=torch.float32)  # shape (1, 2)
    weights[f'y{code}.bias'] = torch.tensor([float(-ones)], dtype=torch.float32)
save_file(weights, 'model.safetensors')
def decode2to4(a1, a0, params=None):
    """Evaluate the threshold-network 2-to-4 decoder on input bits (a1, a0).

    Args:
        a1: High input bit; must equal 0 or 1 (bools are accepted).
        a0: Low input bit; must equal 0 or 1 (bools are accepted).
        params: Mapping holding 'y{i}.weight' and 'y{i}.bias' tensors for
            i in 0..3.  Defaults to the module-level ``weights`` dict, which
            keeps the original two-argument call working.

    Returns:
        A list of four ints [y0, y1, y2, y3].  Each output is 1 when its
        weighted sum plus bias is >= 0, and 0 otherwise.

    Raises:
        ValueError: If a1 or a0 is not a 0/1 bit.  The weights are designed
            for binary inputs and give meaningless, non-one-hot output
            otherwise.
    """
    if params is None:
        params = weights
    if a1 not in (0, 1) or a0 not in (0, 1):
        raise ValueError(f"decode2to4 expects 0/1 bits, got a1={a1!r}, a0={a0!r}")

    inp = torch.tensor([float(a1), float(a0)])
    outputs = []
    for i in range(4):
        # Element-wise product summed over all entries gives the dot product
        # of the input with the neuron's weight row, whatever its leading shape.
        activation = (inp * params[f'y{i}.weight']).sum() + params[f'y{i}.bias']
        # Heaviside step: the neuron fires at or above zero.
        outputs.append(int(activation.item() >= 0))
    return outputs
# Exhaustively check the decoder against the one-hot truth table.  Any
# mismatch is reported, and the script exits with status 1 so automated
# runs (CI, make) can detect a broken decoder.
print("Verifying 2to4decoder...")
errors = 0
for val in range(4):
    a1, a0 = (val >> 1) & 1, val & 1
    result = decode2to4(a1, a0)
    expected = [int(i == val) for i in range(4)]
    if result != expected:
        errors += 1
        print(f"ERROR: {val} -> {result}, expected {expected}")
if errors == 0:
    print("All 4 test cases passed!")
else:
    print(f"{errors} of 4 test cases failed!")
# Total absolute value of all weights and biases.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
if errors:
    sys.exit(1)