File size: 4,157 Bytes
29db051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the safetensors file at ``path`` and return its contents.

    The result is exactly what ``safetensors.torch.load_file`` returns; the
    rest of this file indexes it by tensor-name strings such as
    ``'d1.weight'``.
    """
    weights = load_file(path)
    return weights

def xor2_from_weights(a, b, w, or_w, or_b, nand_w, nand_b, and_w, and_b):
    """Evaluate a two-layer threshold-gate XOR network on bits ``a`` and ``b``.

    Layer 1 applies an OR gate and a NAND gate to ``(a, b)``; layer 2 applies
    the AND gate to those two outputs. Each gate fires when its weighted
    input sum plus bias is >= 0. ``w`` is accepted for signature
    compatibility but is not read here. Returns the final gate output as an
    int (0 or 1).
    """
    def _fires(signal, weight, bias):
        # A threshold gate: weighted sum plus bias, compared against zero.
        return (signal * weight).sum() + bias >= 0

    pair = torch.tensor([float(a), float(b)])
    hidden = torch.tensor([
        float(_fires(pair, or_w, or_b)),
        float(_fires(pair, nand_w, nand_b)),
    ])
    return int(_fires(hidden, and_w, and_b))

def _threshold_gate(inputs, weight, bias):
    """Return whether a threshold gate fires: sum(inputs * weight) + bias >= 0."""
    return (inputs * weight).sum() + bias >= 0


def _xor_stage(inputs, w, prefix):
    """Evaluate the two-layer XOR sub-network stored under ``prefix`` in ``w``.

    Layer 1 runs the ``layer1.or`` and ``layer1.nand`` gates on ``inputs``;
    layer 2 runs the ``layer2`` gate on those two outputs. Returns the
    layer-2 output as an int (0 or 1).
    """
    or_out = float(_threshold_gate(
        inputs, w[f'{prefix}.layer1.or.weight'], w[f'{prefix}.layer1.or.bias']))
    nand_out = float(_threshold_gate(
        inputs, w[f'{prefix}.layer1.nand.weight'], w[f'{prefix}.layer1.nand.bias']))
    hidden = torch.tensor([or_out, nand_out])
    return int(_threshold_gate(
        hidden, w[f'{prefix}.layer2.weight'], w[f'{prefix}.layer2.bias']))


def hamming74_encode(d1, d2, d3, d4, w):
    """Hamming(7,4) encoder: 4 data bits -> 7 coded bits.

    Every output bit is produced by threshold gates whose weights and biases
    are read from ``w`` (a mapping from tensor name to tensor).

    Intended parity equations:
        p1 = d1 XOR d2 XOR d4
        p2 = d1 XOR d3 XOR d4
        p3 = d2 XOR d3 XOR d4

    Each parity bit is computed in two XOR stages. The first stage
    (``pN.xorAB``) is fed the full 4-bit input vector, so its layer-1 weights
    must be 4 wide; which data bits it actually combines is decided by those
    weights (presumably zeros on the unused bits -- confirm against the model
    file). The second stage (``pN.xor_final``) is fed the 2-vector
    ``[first_stage_output, d4]``.

    Args:
        d1, d2, d3, d4: data bits (anything ``float()`` accepts, normally 0/1).
        w: weight mapping with keys ``pN.xorAB.*``, ``pN.xor_final.*`` and
            ``dN.weight`` / ``dN.bias``.

    Returns:
        ``[c1, ..., c7]`` as ints, laid out as
        c1=p1, c2=p2, c3=d1, c4=p3, c5=d2, c6=d3, c7=d4.

    Raises:
        KeyError: if a required tensor name is missing from ``w``.
    """
    inp = torch.tensor([float(d1), float(d2), float(d3), float(d4)])

    # (parity name, name of its first XOR stage); the second stage is always
    # '<parity>.xor_final' and mixes in d4.
    parity_bits = []
    for parity, first_stage in (('p1', 'xor12'), ('p2', 'xor13'), ('p3', 'xor23')):
        partial = _xor_stage(inp, w, f'{parity}.{first_stage}')
        stage2_in = torch.tensor([float(partial), float(d4)])
        parity_bits.append(_xor_stage(stage2_in, w, f'{parity}.xor_final'))
    p1, p2, p3 = parity_bits

    # Data pass-through: one single-layer gate per data bit, each fed the
    # full input vector.
    c3, c5, c6, c7 = (
        int(_threshold_gate(inp, w[f'd{i}.weight'], w[f'd{i}.bias']))
        for i in range(1, 5)
    )

    # Output: c1=p1, c2=p2, c3=d1, c4=p3, c5=d2, c6=d3, c7=d4
    return [p1, p2, c3, p3, c5, c6, c7]

if __name__ == '__main__':
    # Self-test: encode all 16 possible 4-bit inputs with the loaded weights
    # and compare each codeword against a plain-XOR reference encoder.
    w = load_model()
    print('Hamming(7,4) Encoder')
    print('Input (d1d2d3d4) -> Output (c1c2c3c4c5c6c7)')

    def ref_encode(d1, d2, d3, d4):
        """Reference Hamming(7,4) codeword computed with integer XOR."""
        return [d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d1, d2 ^ d3 ^ d4, d2, d3, d4]

    passed = 0
    for value in range(16):
        # d1 is the least significant bit of the counter, d4 the most.
        bits = [(value >> shift) & 1 for shift in range(4)]
        result = hamming74_encode(*bits, w)
        if result == ref_encode(*bits):
            verdict = 'OK'
            passed += 1
        else:
            verdict = 'FAIL'
        in_str = ''.join(str(bit) for bit in bits)
        out_str = ''.join(str(bit) for bit in result)
        print(f'{in_str} -> {out_str} {verdict}')

    print(f'\n{passed}/16 correct')