File size: 2,013 Bytes
8676f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load the network parameters from a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file; a relative path is
            resolved against the current working directory.

    Returns:
        Dict mapping tensor names to ``torch.Tensor`` objects, exactly as
        returned by ``safetensors.torch.load_file``.
    """
    state = load_file(path)
    return state

def _fire(x, weights, name):
    """Evaluate the threshold neuron ``name`` on the 1-D input tensor ``x``.

    Computes ``x · weight + bias`` and applies a Heaviside step: the neuron
    fires when the pre-activation is >= 0.

    Args:
        x: 1-D float32 input tensor.
        weights: Mapping containing ``f'{name}.weight'`` and ``f'{name}.bias'``.
        name: Neuron prefix, e.g. ``'x1_or'``.

    Returns:
        1 if the neuron fires, otherwise 0.
    """
    # Flatten so both (1, N) and (N,) weight layouts reduce to a plain dot
    # product (``.T`` is deprecated on non-2-D tensors), and cast to float32 so
    # half/double-precision checkpoints do not cause a matmul dtype mismatch.
    w = weights[f'{name}.weight'].float().reshape(-1)
    b = weights[f'{name}.bias'].float()
    return int((x @ w + b >= 0).item())


def _xor_stage(x, weights, prefix, out):
    """Compute one XOR stage of the network.

    Runs the ``{prefix}_or`` and ``{prefix}_nand`` neurons on ``x``, then feeds
    their two outputs into neuron ``out`` (expected to act as AND, which makes
    the stage an XOR). Returns 0 or 1.
    """
    hidden = torch.tensor([
        float(_fire(x, weights, f'{prefix}_or')),
        float(_fire(x, weights, f'{prefix}_nand')),
    ])
    return _fire(hidden, weights, out)


def gray2binary(g3, g2, g1, g0, weights):
    """Convert 4-bit Gray code to binary using threshold-logic neurons.

    Implements ``b3 = g3`` and ``b_k = b_{k+1} XOR g_k`` for k = 2, 1, 0. Each
    XOR is an OR neuron and a NAND neuron feeding a combining neuron.

    Args:
        g3, g2, g1, g0: Gray-code bits, most significant first (0 or 1).
        weights: Mapping of ``'<neuron>.weight'`` / ``'<neuron>.bias'`` tensors
            for neurons ``b3``, ``x{1,2,3}_or``, ``x{1,2,3}_nand``, ``b2``,
            ``b1`` and ``b0``.

    Returns:
        Tuple ``(b3, b2, b1, b0)`` of binary bits as ints, MSB first.
    """
    inp = torch.tensor([float(g3), float(g2), float(g1), float(g0)])

    # B3 = G3
    b3 = _fire(inp, weights, 'b3')

    # B2 = XOR(G3, G2); this first stage reads the full 4-bit input vector.
    b2 = _xor_stage(inp, weights, 'x1', 'b2')

    # B1 = XOR(B2, G1)
    b1 = _xor_stage(torch.tensor([float(b2), float(g1)]), weights, 'x2', 'b1')

    # B0 = XOR(B1, G0)
    b0 = _xor_stage(torch.tensor([float(b1), float(g0)]), weights, 'x3', 'b0')

    return b3, b2, b1, b0

if __name__ == '__main__':
    # Round-trip demo: encode each 4-bit value as Gray code, decode it back
    # through the neuron network, and print every stage.
    weights = load_model()
    print('Gray to Binary conversion:')
    print('Binary -> Gray -> Binary')
    for i in range(16):
        gray = i ^ (i >> 1)  # binary -> reflected Gray code
        g3, g2, g1, g0 = ((gray >> shift) & 1 for shift in (3, 2, 1, 0))
        b3, b2, b1, b0 = gray2binary(g3, g2, g1, g0, weights)
        result = (b3 << 3) | (b2 << 2) | (b1 << 1) | b0
        print(f'  {i:2d} -> {g3}{g2}{g1}{g0} -> {b3}{b2}{b1}{b0} = {result:2d}')