File size: 3,901 Bytes
5f77b99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the safetensors file at *path* and return what ``load_file`` yields.

    The result is passed unchanged to the caller; ``subtractor4`` indexes it
    with keys such as ``'d0_or.weight'`` and ``'d0_or.bias'``.
    """
    weights = load_file(path)
    return weights

def _threshold_gate(signals, weights, name):
    """Evaluate the single threshold unit stored under *name*.

    Looks up ``weights[name + '.weight']`` and ``weights[name + '.bias']`` and
    returns 1 when ``signals @ weight.T + bias >= 0``, otherwise 0.
    *signals* may be an already-built tensor or a sequence of numbers, which
    is converted to a tensor of floats first.
    """
    if not isinstance(signals, torch.Tensor):
        signals = torch.tensor([float(s) for s in signals])
    activation = signals @ weights[name + '.weight'].T + weights[name + '.bias']
    # .item() requires the unit to produce exactly one output value.
    return int((activation >= 0).item())


def _two_layer_gate(signals, weights, name):
    """Evaluate the two-layer unit ``<name>_or`` / ``<name>_nand`` -> ``<name>``.

    Both first-layer units see the same *signals*; their two outputs feed the
    unit called *name*.  Going by the key names this is presumably XOR built
    as AND(OR, NAND) -- the actual function depends on the loaded weights.
    """
    if not isinstance(signals, torch.Tensor):
        signals = torch.tensor([float(s) for s in signals])
    or_out = _threshold_gate(signals, weights, name + '_or')
    nand_out = _threshold_gate(signals, weights, name + '_nand')
    return _threshold_gate([or_out, nand_out], weights, name)


def subtractor4(a3, a2, a1, a0, b3, b2, b1, b0, weights):
    """4-bit subtractor evaluated as a network of threshold gates.

    Args:
        a3, a2, a1, a0: bits of A, most significant first (expected 0 or 1).
        b3, b2, b1, b0: bits of B, most significant first (expected 0 or 1).
        weights: mapping from ``'<gate>.weight'`` / ``'<gate>.bias'`` to
            tensors, e.g. the value returned by ``load_model``.

    Returns:
        The list ``[d3, d2, d1, d0, bout3]`` of ints (each 0 or 1): the four
        difference bits, most significant first, followed by the borrow out of
        bit 3.  With weights that implement the intended gates this encodes
        (A - B) mod 16 and the final borrow.
    """
    # Every first-layer gate receives the full 8-bit input vector; which bits
    # it actually uses is decided by its weight row.
    inp = torch.tensor([float(bit) for bit in (a3, a2, a1, a0, b3, b2, b1, b0)])

    # Bit 0: no incoming borrow, so both outputs come straight from the inputs.
    diff_bits = [_two_layer_gate(inp, weights, 'd0')]
    borrow = _threshold_gate(inp, weights, 'bout0')

    # Bits 1..3 share one structure: combine the xor<i> unit with the incoming
    # borrow for the difference bit, then derive the next borrow from
    # (NOT a_i, b_i, incoming borrow).
    for index, a_bit, b_bit in ((1, a1, b1), (2, a2, b2), (3, a3, b3)):
        xor_ab = _two_layer_gate(inp, weights, f'xor{index}')
        diff_bits.append(_two_layer_gate([xor_ab, borrow], weights, f'd{index}'))
        borrow = _threshold_gate([1 - a_bit, b_bit, borrow], weights, f'bout{index}')

    d0, d1, d2, d3 = diff_bits
    return [d3, d2, d1, d0, borrow]

if __name__ == '__main__':
    # Demo: load the default weights file and print a few sample subtractions.
    model_weights = load_model()
    print('Subtractor4bit examples:')
    for a, b in [(7, 3), (5, 5), (3, 7), (15, 1), (0, 1)]:
        # Split each operand into its four bits, most significant first.
        a_bits = [(a >> shift) & 1 for shift in (3, 2, 1, 0)]
        b_bits = [(b >> shift) & 1 for shift in (3, 2, 1, 0)]
        # The result is [d3, d2, d1, d0, bout3]; peel the borrow off the end.
        *diff_bits, bout = subtractor4(*a_bits, *b_bits, model_weights)
        # Reassemble the difference from its bits (MSB first).
        diff = 0
        for bit in diff_bits:
            diff = diff * 2 + bit
        print(f'  {a:2d} - {b:2d} = {diff:2d} (bout={bout})')