File size: 1,134 Bytes
67df09a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from safetensors.torch import save_file

# 2-to-4 decoder built from four single-output linear units.
# Inputs:  A1 (high bit), A0 (low bit)
# Outputs: Y0..Y3, one-hot -- Yi is meant to fire only for input code i.
#
# For unit i each input weight is +1 where the corresponding bit of i is set
# and -1 where it is clear; the bias is minus the number of set bits in i.
# The pre-activation is therefore 0 on the matching code and negative (minus
# the number of differing bits) on every other code, so a ">= 0" threshold
# selects exactly one output.
weights = {}

for code in range(4):
    bits = ((code >> 1) & 1, code & 1)  # (A1, A0) pattern this unit detects
    row = [1.0 if bit else -1.0 for bit in bits]
    offset = -sum(bits)
    weights[f'y{code}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'y{code}.bias'] = torch.tensor([float(offset)], dtype=torch.float32)

# Persist all eight tensors to disk in safetensors format.
save_file(weights, 'model.safetensors')

def decode2to4(a1, a0, params=None):
    """Evaluate the 2-to-4 decoder on a single input pair.

    Args:
        a1: High input bit; converted with ``float()`` (expected 0 or 1).
        a0: Low input bit; converted with ``float()`` (expected 0 or 1).
        params: Optional mapping that holds the ``y{i}.weight`` and
            ``y{i}.bias`` tensors for i in 0..3. When omitted, the
            module-level ``weights`` dict is used, so existing two-argument
            calls behave exactly as before. Passing a mapping makes it
            possible to evaluate tensors obtained elsewhere (for example
            reloaded from disk).

    Returns:
        A list of four ints; entry i is 1 when
        ``sum(input * y{i}.weight) + y{i}.bias >= 0`` and 0 otherwise.
    """
    if params is None:
        params = weights
    inp = torch.tensor([float(a1), float(a0)])
    # Each unit is a thresholded dot product: fire when w . x + b >= 0.
    return [
        int(((inp * params[f'y{i}.weight']).sum() + params[f'y{i}.bias']).item() >= 0)
        for i in range(4)
    ]

# Exhaustive self-check: feed each of the four input codes through
# decode2to4 and compare against the expected one-hot vector.
print("Verifying 2to4decoder...")
errors = 0
for val in range(4):
    a1, a0 = (val >> 1) & 1, val & 1
    result = decode2to4(a1, a0)
    expected = [1 if i == val else 0 for i in range(4)]
    if result != expected:
        errors += 1
        print(f"ERROR: {val} -> {result}, expected {expected}")

if errors == 0:
    print("All 4 test cases passed!")
# Sum of absolute values over every stored weight and bias tensor.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
if errors:
    # Surface the failure through the process exit status (non-zero) so that
    # callers and CI can detect a bad decoder instead of relying on stdout.
    raise SystemExit(f"{errors} of 4 test cases failed")