|
|
import torch
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Read the contents of a safetensors file.

    Args:
        path: Path handed to ``safetensors.torch.load_file``; defaults to
            ``'model.safetensors'``.

    Returns:
        The object produced by ``load_file(path)``, returned unchanged.
    """
    weights = load_file(path)
    return weights
|
|
|
|
|
|
def xor2(a, b, w, prefix):
    """Evaluate a two-layer threshold-gate network on two inputs.

    The first layer applies an "or" gate and a "nand" gate to ``(a, b)``;
    the second layer applies one more gate to those two outputs.  With
    suitable weights this computes XOR(a, b).

    Args:
        a: First input; converted with ``float()``.
        b: Second input; converted with ``float()``.
        w: Mapping from parameter name to weight/bias value
            (presumably tensors loaded from the model file -- confirm).
        prefix: Key prefix; parameters are read from
            ``{prefix}.layer1.or``, ``{prefix}.layer1.nand`` and
            ``{prefix}.layer2`` (each with ``.weight`` and ``.bias``).

    Returns:
        ``int`` 0 or 1: whether the output gate fired.
    """
    def fires(vec, weight, bias):
        # A gate is active when its weighted input sum plus bias is >= 0.
        return (vec * weight).sum() + bias >= 0

    bits = torch.tensor([float(a), float(b)])

    hidden_or = fires(
        bits,
        w[f'{prefix}.layer1.or.weight'],
        w[f'{prefix}.layer1.or.bias'],
    )
    hidden_nand = fires(
        bits,
        w[f'{prefix}.layer1.nand.weight'],
        w[f'{prefix}.layer1.nand.bias'],
    )

    hidden = torch.tensor([float(hidden_or), float(hidden_nand)])
    output = fires(
        hidden,
        w[f'{prefix}.layer2.weight'],
        w[f'{prefix}.layer2.bias'],
    )
    return int(output)
|
|
|
|
|
|
def xor4(c, indices, w, prefix):
    """Evaluate a tree of three XOR sub-networks over a 7-element input.

    Structure: XOR(a, b, c, d) = XOR(XOR(a, b), XOR(c, d)).  The two leaf
    sub-networks each receive the full 7-element input vector; the final
    sub-network receives the two leaf outputs.

    Args:
        c: Indexable sequence; elements 0..6 are read and converted with
            ``float()``.
        indices: Sequence of at least four labels.  They are only used to
            build the weight key names ``xor_{i0}{i1}`` and ``xor_{i2}{i3}``;
            this function never uses them to index ``c`` (which inputs each
            leaf actually looks at is presumably encoded in the 7-wide
            weights -- confirm against the model file).
        w: Mapping from parameter name to weight/bias value.
        prefix: Key prefix for all three sub-networks, e.g. ``'s1'``.

    Returns:
        ``int`` 0 or 1: whether the final output gate fired.
    """
    def xor_stage(vec, base):
        # One XOR sub-network: "or" and "nand" threshold gates feed a
        # second-layer gate.  A gate fires when weighted sum + bias >= 0.
        fired_or = float((vec * w[f'{base}.layer1.or.weight']).sum() + w[f'{base}.layer1.or.bias'] >= 0)
        fired_nand = float((vec * w[f'{base}.layer1.nand.weight']).sum() + w[f'{base}.layer1.nand.bias'] >= 0)
        hidden = torch.tensor([fired_or, fired_nand])
        return int((hidden * w[f'{base}.layer2.weight']).sum() + w[f'{base}.layer2.bias'] >= 0)

    codeword = torch.tensor([float(c[i]) for i in range(7)])

    # First leaf, named after the first pair of labels.
    first, second = indices[0], indices[1]
    left = xor_stage(codeword, f'{prefix}.xor_{first}{second}')

    # Second leaf, named after the second pair of labels.
    third, fourth = indices[2], indices[3]
    right = xor_stage(codeword, f'{prefix}.xor_{third}{fourth}')

    # Combine the two leaf results.
    leaf_outputs = torch.tensor([float(left), float(right)])
    return xor_stage(leaf_outputs, f'{prefix}.xor_final')
|
|
|
|
|
|
def hamming74_decode(c, w):
    """Hamming(7,4) decoder with single-error correction.

    Three syndrome bits are computed with ``xor4``, four threshold gates
    turn the syndrome into per-data-bit "flip" signals, and each data bit
    is XOR-ed with its flip signal via ``xor2``.

    Args:
        c: list of 7 bits ``[c1, c2, c3, c4, c5, c6, c7]``.
        w: Mapping from parameter name to weight/bias value; the
            correction behaviour depends entirely on these weights.

    Returns:
        list of 4 corrected data bits ``[d1, d2, d3, d4]``.
    """
    # (index labels, weight prefix) for each syndrome bit.
    parity_checks = (
        ([0, 2, 4, 6], 's1'),
        ([1, 2, 5, 6], 's2'),
        ([3, 4, 5, 6], 's3'),
    )
    syndrome = torch.tensor(
        [float(xor4(c, labels, w, name)) for labels, name in parity_checks]
    )

    # One threshold gate per data-bit position (codeword positions 3, 5, 6, 7);
    # a gate fires when weighted syndrome sum + bias >= 0.
    flips = [
        int((syndrome * w[f'{gate}.weight']).sum() + w[f'{gate}.bias'] >= 0)
        for gate in ('flip3', 'flip5', 'flip6', 'flip7')
    ]

    # Data bits live at 0-based indices 2, 4, 5, 6 of the codeword; XOR each
    # with its flip signal to undo a detected error.
    data_slots = ((2, 'd1.xor'), (4, 'd2.xor'), (5, 'd3.xor'), (6, 'd4.xor'))
    return [
        xor2(c[index], flip, w, name)
        for (index, name), flip in zip(data_slots, flips)
    ]
|
|
|
|
|
|
if __name__ == '__main__':
    # Self-test: decode clean and single-bit-corrupted codewords and count
    # every mismatch against the original data bits.
    w = load_model()
    print('Hamming(7,4) Decoder with Single-Error Correction')

    def encode(d1, d2, d3, d4):
        """Return the codeword [p1, p2, d1, p3, d2, d3, d4] for four data bits."""
        parity = (d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4)
        return [parity[0], parity[1], d1, parity[2], d2, d3, d4]

    errors = 0

    # Pass 1: all 16 data words, no corruption.
    print('\nNo errors:')
    for d in range(16):
        d1, d2, d3, d4 = [(d >> shift) & 1 for shift in range(4)]
        expected = [d1, d2, d3, d4]
        codeword = encode(d1, d2, d3, d4)
        decoded = hamming74_decode(codeword, w)
        matches = decoded == expected
        if not matches:
            errors += 1
        status = 'OK' if matches else 'FAIL'
        print(f' {d1}{d2}{d3}{d4} -> {decoded} (expected {expected}) {status}')

    # Pass 2: a few data words, each with every possible single-bit flip.
    print('\nSingle-bit errors:')
    test_data = [0b1011, 0b0000, 0b1111, 0b0101]
    for d in test_data:
        d1, d2, d3, d4 = [(d >> shift) & 1 for shift in range(4)]
        expected = [d1, d2, d3, d4]
        codeword = encode(d1, d2, d3, d4)

        for pos, bit in enumerate(codeword):
            # Copy the codeword with exactly one position inverted.
            corrupted = codeword[:pos] + [bit ^ 1] + codeword[pos + 1:]
            decoded = hamming74_decode(corrupted, w)
            matches = decoded == expected
            if not matches:
                errors += 1
            status = 'OK' if matches else 'FAIL'
            print(f' data={d1}{d2}{d3}{d4} err@{pos+1}: {decoded} (expected {expected}) {status}')

    print(f'\nTotal errors: {errors}')
    if not errors:
        print('All tests passed!')