# File size: 1,965 Bytes
# 81fa364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
from safetensors.torch import save_file

# Input order: i3, i2, i1, i0 (i3 = highest priority)
# Outputs: y1, y0 (binary encoding), v (valid)
#
# Each output is a single threshold unit that fires when
# dot(weight, [i3, i2, i1, i0]) + bias >= 0.
# Table rows: (output name, weight row over the 4 inputs, bias).
_UNITS = (
    ('y1', [1.0, 1.0, 0.0, 0.0], -1.0),   # y1 = i3 OR i2
    # y0 = i3 OR (NOT i2 AND i1): the weight 2 on i3 outvotes the -1 on i2,
    # so i3 alone always fires the unit regardless of i2/i1.
    ('y0', [2.0, -1.0, 1.0, 0.0], -1.0),
    ('v', [1.0, 1.0, 1.0, 1.0], -1.0),    # v = i3 OR i2 OR i1 OR i0
)

# Flatten the table into '<name>.weight' (shape 1x4) and '<name>.bias'
# (shape 1) float32 tensors, keeping the y1, y0, v ordering.
weights = {
    f'{unit}.{part}': torch.tensor(value, dtype=torch.float32)
    for unit, row, bias in _UNITS
    for part, value in (('weight', [row]), ('bias', [bias]))
}
save_file(weights, 'model.safetensors')

def priority_encode(i3, i2, i1, i0, params=None):
    """Evaluate the 4-input priority encoder built from threshold units.

    Each output is computed as ``1 if w . x + b >= 0 else 0`` with
    ``x = [i3, i2, i1, i0]``.

    Args:
        i3, i2, i1, i0: Input bits, i3 having the highest priority. Each is
            converted with ``float()``; the default weights are designed
            for 0/1 inputs.
        params: Optional mapping holding the tensors ``'y1.weight'``,
            ``'y1.bias'``, ``'y0.weight'``, ``'y0.bias'``, ``'v.weight'`` and
            ``'v.bias'`` (for example the dict read back from
            ``model.safetensors``). Defaults to the module-level ``weights``.

    Returns:
        Tuple ``(y1, y0, v)`` of ints, each 0 or 1. With the default
        weights, ``(y1, y0)`` is the binary index of the highest-priority
        set input and ``v`` is 1 when any input is set.
    """
    if params is None:
        params = weights
    inp = torch.tensor([float(i3), float(i2), float(i1), float(i0)])

    def fire(name):
        # Heaviside step on the unit's pre-activation: fires at >= 0.
        pre = inp @ params[f'{name}.weight'].T + params[f'{name}.bias']
        return int((pre >= 0).item())

    return fire('y1'), fire('y0'), fire('v')

def _expected_outputs(i3, i2, i1, i0):
    """Return the reference ``(y1, y0, v)`` for a 4-input priority encoder.

    ``(y1, y0)`` binary-encodes the index of the highest-priority set input
    (i3 first) and ``v`` is 1 if any input is set. With no input set,
    ``(0, 0, 0)`` is returned; y1/y0 are don't-cares in that case.
    """
    for idx, bit in ((3, i3), (2, i2), (1, i1), (0, i0)):
        if bit:
            return (idx >> 1) & 1, idx & 1, 1
    return 0, 0, 0


print("Verifying priorityencoder4...")
errors = 0
for val in range(16):
    i3, i2, i1, i0 = (val >> 3) & 1, (val >> 2) & 1, (val >> 1) & 1, val & 1
    y1, y0, v = priority_encode(i3, i2, i1, i0)
    exp_y1, exp_y0, exp_v = _expected_outputs(i3, i2, i1, i0)

    # y1/y0 are only compared when the output is valid; for v == 0 they are
    # don't-cares.
    if v != exp_v or (v == 1 and (y1 != exp_y1 or y0 != exp_y0)):
        errors += 1
        print(f"ERROR: {i3}{i2}{i1}{i0} -> y1={y1},y0={y0},v={v}, expected {exp_y1},{exp_y0},{exp_v}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    # Previously a failing run ended with no summary at all.
    print(f"{errors} of 16 test cases failed.")

# Total L1 magnitude of the network: sum of |value| over all weights and biases.
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")