File size: 5,233 Bytes
8b64896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors file.

    Args:
        path: location of the ``.safetensors`` file; defaults to
            ``'model.safetensors'`` resolved against the current working
            directory.

    Returns:
        The mapping of tensor names to tensors produced by
        ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors

def xor2(a, b, w, prefix):
    """Evaluate a two-input threshold-gate network meant to compute ``a XOR b``.

    Parameters are read from ``w`` under ``prefix``:
    - two hidden neurons (``layer1.or`` and ``layer1.nand``) see the input pair;
    - one output neuron (``layer2``) sees the hidden pair.

    Every neuron outputs 1 when ``sum(input * weight) + bias >= 0`` and 0
    otherwise. The function actually computed is therefore whatever the
    stored weights encode.

    Args:
        a: first input bit (any value accepted by ``float``).
        b: second input bit (any value accepted by ``float``).
        w: mapping from parameter name to tensor, e.g. the result of
            ``load_model``.
        prefix: key prefix that selects this gate's parameters.

    Returns:
        int: 0 or 1, the output neuron's activation.
    """
    def step(x, unit):
        # Heaviside activation of one neuron; yields a boolean tensor.
        weight = w[f'{prefix}.{unit}.weight']
        bias = w[f'{prefix}.{unit}.bias']
        return (x * weight).sum() + bias >= 0

    pair = torch.tensor([float(a), float(b)])
    hidden = torch.tensor([float(step(pair, 'layer1.or')), float(step(pair, 'layer1.nand'))])
    return int(step(hidden, 'layer2'))

def _threshold_xor(inp, w, gate):
    """Run one OR/NAND -> output threshold sub-network and return its 0/1 output.

    ``gate`` is the sub-network's full key prefix (e.g. ``'s1.xor_02'``). Each
    neuron outputs 1 when ``sum(input * weight) + bias >= 0``.
    """
    or_out = float((inp * w[f'{gate}.layer1.or.weight']).sum() + w[f'{gate}.layer1.or.bias'] >= 0)
    nand_out = float((inp * w[f'{gate}.layer1.nand.weight']).sum() + w[f'{gate}.layer1.nand.bias'] >= 0)
    hidden = torch.tensor([or_out, nand_out])
    return int((hidden * w[f'{gate}.layer2.weight']).sum() + w[f'{gate}.layer2.bias'] >= 0)


def xor4(c, indices, w, prefix):
    """4-input XOR: XOR(a,b,c,d) = XOR(XOR(a,b), XOR(c,d)), built from threshold gates.

    Args:
        c: indexable sequence of (at least) 7 bits; the first 7 are fed to the
            two first-stage XOR sub-networks.
        indices: the four codeword positions involved. They are used only to
            build parameter names (``{prefix}.xor_{i0}{i1}`` and
            ``{prefix}.xor_{i2}{i3}``). The actual bit selection is done by
            the stored first-layer weights, which are multiplied element-wise
            with the full 7-element input. Presumably these are 7-wide masks
            (TODO confirm against the model file).
        w: mapping from parameter name to tensor, e.g. from ``load_model``.
        prefix: key prefix for this 4-input XOR (e.g. ``'s1'``).

    Returns:
        int: 0 or 1, the output of the ``{prefix}.xor_final`` sub-network.
    """
    # Whole codeword as the first-stage input; c[i] for i in range(7) keeps the
    # original behaviour of raising IndexError on a too-short codeword.
    inp = torch.tensor([float(c[i]) for i in range(7)])

    i0, i1, i2, i3 = indices[0], indices[1], indices[2], indices[3]
    xor_ab = _threshold_xor(inp, w, f'{prefix}.xor_{i0}{i1}')
    xor_cd = _threshold_xor(inp, w, f'{prefix}.xor_{i2}{i3}')

    # Final stage combines the two partial parities.
    pair = torch.tensor([float(xor_ab), float(xor_cd)])
    return _threshold_xor(pair, w, f'{prefix}.xor_final')

def hamming74_decode(c, w):
    """Decode a Hamming(7,4) codeword, correcting a single flipped bit.

    Args:
        c: sequence of 7 bits ordered [c1, ..., c7]. Data bits sit at
            positions 3, 5, 6 and 7 (indices 2, 4, 5, 6).
        w: parameter mapping (e.g. from ``load_model``). It holds:
            - the syndrome XOR networks (prefixes ``'s1'``, ``'s2'``, ``'s3'``);
            - the ``flip3``/``flip5``/``flip6``/``flip7`` detector neurons;
            - the ``d1.xor`` .. ``d4.xor`` correction gates.

    Returns:
        list of 4 ints [d1, d2, d3, d4], the corrected data bits.
    """
    # Parity checks: s_k covers every position whose 1-based index has bit k set,
    # so the syndrome read as (s3 s2 s1) is the 1-based error position.
    checks = (
        ('s1', [0, 2, 4, 6]),  # c1 ^ c3 ^ c5 ^ c7
        ('s2', [1, 2, 5, 6]),  # c2 ^ c3 ^ c6 ^ c7
        ('s3', [3, 4, 5, 6]),  # c4 ^ c5 ^ c6 ^ c7
    )
    syndrome = torch.tensor([float(xor4(c, taps, w, name)) for name, taps in checks])

    # Per data bit: (codeword index, flip-detector key, correction-gate prefix).
    # Each detector is intended to fire when (s3 s2 s1) equals the bit's position:
    # 3 = 011, 5 = 101, 6 = 110, 7 = 111. Note the syndrome tensor itself is
    # ordered [s1, s2, s3].
    data_bits = (
        (2, 'flip3', 'd1.xor'),  # c3 -> d1
        (4, 'flip5', 'd2.xor'),  # c5 -> d2
        (5, 'flip6', 'd3.xor'),  # c6 -> d3
        (6, 'flip7', 'd4.xor'),  # c7 -> d4
    )
    flips = [
        int((syndrome * w[f'{key}.weight']).sum() + w[f'{key}.bias'] >= 0)
        for _, key, _ in data_bits
    ]

    # Corrected bit = received bit XOR its flip signal.
    return [xor2(c[idx], flip, w, gate) for (idx, _, gate), flip in zip(data_bits, flips)]

if __name__ == '__main__':
    w = load_model()
    print('Hamming(7,4) Decoder with Single-Error Correction')

    # Reference encoder: parity bits at positions 1, 2, 4; data at 3, 5, 6, 7.
    def encode(d1, d2, d3, d4):
        p1 = d1 ^ d2 ^ d4
        p2 = d1 ^ d3 ^ d4
        p3 = d2 ^ d3 ^ d4
        return [p1, p2, d1, p3, d2, d3, d4]

    def unpack(d):
        """Split a 4-bit integer into [d1, d2, d3, d4], with d1 the least significant bit."""
        return [(d >> k) & 1 for k in range(4)]

    errors = 0

    # All 16 data words, transmitted without errors.
    print('\nNo errors:')
    for d in range(16):
        expected = unpack(d)
        decoded = hamming74_decode(encode(*expected), w)
        if decoded != expected:
            errors += 1
            bits = ''.join(map(str, expected))
            print(f'  {bits} -> {decoded} (expected {expected}) FAIL')

    # Every data word with every single-bit error (16 * 7 = 112 cases).
    # Hamming(7,4) must correct all of them, so test exhaustively rather
    # than on a hand-picked sample.
    print('\nSingle-bit errors:')
    for d in range(16):
        expected = unpack(d)
        codeword = encode(*expected)
        for pos in range(7):
            corrupted = codeword.copy()
            corrupted[pos] ^= 1
            decoded = hamming74_decode(corrupted, w)
            if decoded != expected:
                errors += 1
                bits = ''.join(map(str, expected))
                print(f'  data={bits} err@{pos+1}: {decoded} (expected {expected}) FAIL')

    print(f'\nTotal errors: {errors}')
    if errors == 0:
        print('All tests passed!')
    else:
        # Non-zero exit status so shell/CI callers can detect a failing run.
        raise SystemExit(1)