import torch


def load_model(path='model.safetensors'):
    """Load the gate weights from a safetensors file.

    Args:
        path: Location of the safetensors file.

    Returns:
        The mapping of parameter names (e.g. ``'flip3.weight'``) to tensors
        produced by ``safetensors.torch.load_file``.
    """
    # Imported lazily: the decoding functions below only need a mapping of
    # tensors, so they remain usable (e.g. with in-memory weights) even where
    # safetensors is not installed.
    from safetensors.torch import load_file

    return load_file(path)


def _neuron(inp, w, prefix):
    """Evaluate one threshold unit: 1 if sum(inp * weight) + bias >= 0, else 0.

    Reads ``w[f'{prefix}.weight']`` and ``w[f'{prefix}.bias']``.
    """
    return int((inp * w[f'{prefix}.weight']).sum() + w[f'{prefix}.bias'] >= 0)


def _xor_gate(inp, w, prefix):
    """Evaluate a two-layer threshold XOR network on the vector ``inp``.

    Layer 1 holds an ``or`` and a ``nand`` unit that both see all of ``inp``.
    Layer 2 combines their two outputs into the result (0 or 1). Which input
    positions the layer-1 units actually use is decided by the stored weights,
    not by this code.
    """
    or_out = _neuron(inp, w, f'{prefix}.layer1.or')
    nand_out = _neuron(inp, w, f'{prefix}.layer1.nand')
    hidden = torch.tensor([float(or_out), float(nand_out)])
    return _neuron(hidden, w, f'{prefix}.layer2')


def xor2(a, b, w, prefix):
    """2-input XOR using threshold gates.

    Args:
        a, b: Input bits (anything ``float()`` accepts, normally 0/1).
        w: Mapping of weight names to tensors.
        prefix: Weight-name prefix of this gate, e.g. ``'d1.xor'``.

    Returns:
        0 or 1.
    """
    return _xor_gate(torch.tensor([float(a), float(b)]), w, prefix)


def xor4(c, indices, w, prefix):
    """4-input XOR of codeword bits: XOR(XOR(a, b), XOR(c, d)).

    Args:
        c: The full codeword. Every bit is fed to the first-level gates, so
            their weights must have ``len(c)`` entries (7 for Hamming(7,4)).
        indices: Four codeword indices ``(a, b, c, d)``. They select the weight
            names ``{prefix}.xor_{a}{b}`` and ``{prefix}.xor_{c}{d}``; the
            weights themselves decide which bits each gate reads.
        w: Mapping of weight names to tensors.
        prefix: Weight-name prefix of this 4-input XOR, e.g. ``'s1'``.

    Returns:
        0 or 1.
    """
    i0, i1, i2, i3 = indices
    inp = torch.tensor([float(bit) for bit in c])
    xor_ab = _xor_gate(inp, w, f'{prefix}.xor_{i0}{i1}')
    xor_cd = _xor_gate(inp, w, f'{prefix}.xor_{i2}{i3}')
    # The final stage is an ordinary 2-input XOR over the two pair results.
    return xor2(xor_ab, xor_cd, w, f'{prefix}.xor_final')


# Syndrome bits as (weight-name prefix, 0-based codeword indices they check):
# s1 = c1^c3^c5^c7, s2 = c2^c3^c6^c7, s3 = c4^c5^c6^c7.
_SYNDROME_TAPS = (
    ('s1', (0, 2, 4, 6)),
    ('s2', (1, 2, 5, 6)),
    ('s3', (3, 4, 5, 6)),
)

# 1-based codeword positions that hold the data bits d1, d2, d3, d4.
_DATA_POSITIONS = (3, 5, 6, 7)


def hamming74_decode(c, w):
    """Hamming(7,4) decoder with single-error correction.

    Args:
        c: Seven bits ``[c1, c2, c3, c4, c5, c6, c7]``, laid out as
            ``[p1, p2, d1, p3, d2, d3, d4]``.
        w: Mapping of weight names to tensors (see ``load_model``).

    Returns:
        The four corrected data bits ``[d1, d2, d3, d4]``.

    Raises:
        ValueError: If ``c`` does not contain exactly 7 bits.
    """
    if len(c) != 7:
        raise ValueError(f'expected 7 codeword bits, got {len(c)}')

    syndrome = torch.tensor(
        [float(xor4(c, taps, w, name)) for name, taps in _SYNDROME_TAPS]
    )

    decoded = []
    for k, pos in enumerate(_DATA_POSITIONS, start=1):
        # flip{pos} is meant to fire only when the syndrome (s3 s2 s1) spells
        # ``pos`` in binary, i.e. the single-bit error sits on this data bit.
        flip = _neuron(syndrome, w, f'flip{pos}')
        # Corrected data bit = received bit XOR flip signal.
        decoded.append(xor2(c[pos - 1], flip, w, f'd{k}.xor'))
    return decoded


def _encode(d1, d2, d3, d4):
    """Reference Hamming(7,4) encoder returning ``[p1, p2, d1, p3, d2, d3, d4]``."""
    p1 = d1 ^ d2 ^ d4
    p2 = d1 ^ d3 ^ d4
    p3 = d2 ^ d3 ^ d4
    return [p1, p2, d1, p3, d2, d3, d4]


def main():
    """Run the decoder self-test against ``model.safetensors``.

    Checks all 16 data words without errors, then every single-bit error on
    all 16 data words.

    Returns:
        The number of failing cases.
    """
    w = load_model()
    print('Hamming(7,4) Decoder with Single-Error Correction')
    errors = 0

    print('\nNo errors:')
    for d in range(16):
        data = [(d >> k) & 1 for k in range(4)]  # [d1, d2, d3, d4], d1 = LSB
        decoded = hamming74_decode(_encode(*data), w)
        status = 'OK' if decoded == data else 'FAIL'
        if status == 'FAIL':
            errors += 1
        bits = ''.join(map(str, data))
        print(f' {bits} -> {decoded} (expected {data}) {status}')

    print('\nSingle-bit errors:')
    for d in range(16):
        data = [(d >> k) & 1 for k in range(4)]
        codeword = _encode(*data)
        bits = ''.join(map(str, data))
        # Introduce an error at each position.
        for pos in range(7):
            corrupted = codeword.copy()
            corrupted[pos] ^= 1
            decoded = hamming74_decode(corrupted, w)
            status = 'OK' if decoded == data else 'FAIL'
            if status == 'FAIL':
                errors += 1
            print(f' data={bits} err@{pos+1}: {decoded} (expected {data}) {status}')

    print(f'\nTotal errors: {errors}')
    if errors == 0:
        print('All tests passed!')
    return errors


if __name__ == '__main__':
    # Non-zero exit status on failure so scripts/CI can detect it.
    raise SystemExit(0 if main() == 0 else 1)