File size: 2,423 Bytes
6e6bb88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read every tensor stored in the safetensors file at *path*.

    The result of ``load_file`` is handed back unchanged; ``min2`` looks
    tensors up in it by name (e.g. ``'l1.a1_gt_b1.weight'``).
    """
    tensors = load_file(path)
    return tensors

# Neuron names per layer, in the order their outputs are stacked to form the
# next layer's input vector. Weight columns are positional over that vector,
# so the order must match how the model was built -- do not reorder.
_MIN2_LAYERS = (
    ('l1', ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0', 'both1_high',
            'both1_low', 'a1', 'a0', 'b1', 'b0')),
    ('l2', ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
            'a1', 'a0', 'b1', 'b0')),
    ('l3', ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1', 'a1', 'a0', 'b1', 'b0',
            'a1_eq_b1')),
    ('l4', ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0')),
    ('l5', ('a_le_b', 'a1', 'a0', 'b1', 'b0')),
    ('l6', ('m1_a', 'm1_b', 'm0_a', 'm0_b')),
    ('l7', ('m1', 'm0')),
)

def _threshold_layer(x, weights, prefix, keys):
    """Evaluate one layer of binary threshold neurons on input vector *x*.

    Neuron ``key`` uses ``weights[f'{prefix}.{key}.weight']`` and
    ``weights[f'{prefix}.{key}.bias']`` and outputs 1.0 when
    ``x @ weight.T + bias >= 0``, else 0.0. The pre-activation must reduce
    to a single element, because ``.item()`` is applied to it.

    Returns a 1-D float tensor of the outputs in ``keys`` order, which is the
    input layout consumed by the following layer.
    """
    outputs = []
    for key in keys:
        weight = weights[f'{prefix}.{key}.weight']
        bias = weights[f'{prefix}.{key}.bias']
        fired = (x @ weight.T + bias >= 0).item()
        outputs.append(float(fired))
    return torch.tensor(outputs)

def min2(a1, a0, b1, b0, weights):
    """Minimum of two 2-bit unsigned integers, computed by a threshold network.

    With a = 2*a1 + a0 and b = 2*b1 + b0, returns the bits of min(a, b).
    Layers 'l1' through 'l7' are evaluated in sequence. Each layer's 0/1
    outputs form the input vector of the next layer.

    Args:
        a1, a0: high and low bit of a; each must be 0 or 1 (bools accepted).
        b1, b0: high and low bit of b; each must be 0 or 1 (bools accepted).
        weights: mapping from tensor name (e.g. ``'l1.a1_gt_b1.weight'``,
            ``'l7.m0.bias'``) to tensor, such as the one ``load_model`` returns.

    Returns:
        (m1, m0): high and low bit of min(a, b), each the int 0 or 1.

    Raises:
        ValueError: if any input bit is not 0 or 1. Out-of-range values would
            otherwise flow through the network and yield a meaningless result.
        KeyError: if *weights* is a dict missing a neuron's weight or bias.
    """
    bits = {'a1': a1, 'a0': a0, 'b1': b1, 'b0': b0}
    for name, bit in bits.items():
        if bit not in (0, 1):
            raise ValueError(f'{name} must be 0 or 1, got {bit!r}')

    # Input layout for layer 1 is [a1, a0, b1, b0] (dict preserves order).
    x = torch.tensor([float(bit) for bit in bits.values()])
    for prefix, keys in _MIN2_LAYERS:
        x = _threshold_layer(x, weights, prefix, keys)

    # The final layer ('l7') has exactly two neurons: m1 then m0.
    m1, m0 = (int(v) for v in x.tolist())
    return m1, m0

if __name__ == '__main__':
    # Print min(a, b) for every pair of 2-bit operands, as computed by the
    # network loaded from the default model file.
    weights = load_model()
    for header_line in ('min2 truth table:', '  a  b | min', '-------+----'):
        print(header_line)
    for a in range(4):
        hi_a, lo_a = divmod(a, 2)  # for 0..3: (a >> 1) & 1, a & 1
        for b in range(4):
            hi_b, lo_b = divmod(b, 2)
            m1, m0 = min2(hi_a, lo_a, hi_b, lo_b, weights)
            print(f'  {a}  {b} |  {2*m1 + m0}')