"""Build a threshold-neuron Hamming(7,4) / BCH(7,4) syndrome decoder.

The received word is R6..R0, laid out exactly as ``encode`` produces it:
R6..R3 carry the data bits d3..d0 and R2, R1, R0 carry the check bits
c2, c1, c0. The stored neurons pass the data bits through unchanged (no
correction) and compute the three syndrome bits S2, S1, S0.

Running this file as a script writes ``model.safetensors`` and verifies the
stored neurons against a pure-Python reference decoder.
"""
import torch

weights = {}

# Bit order of every input vector and weight row: R6, R5, R4, R3, R2, R1, R0.
# Each syndrome bit re-computes one check-bit equation from ``encode``:
#   S2 = R6 ^ R5 ^ R3 ^ R2   (c2 = d3 ^ d2 ^ d0)
#   S1 = R6 ^ R4 ^ R3 ^ R1   (c1 = d3 ^ d1 ^ d0)
#   S0 = R5 ^ R4 ^ R3 ^ R0   (c0 = d2 ^ d1 ^ d0)
# A valid codeword therefore yields S = 000, and every input bit appears in a
# distinct, nonzero combination of checks, so any single flipped bit is
# both detected and located (see SYNDROME_TO_INDEX).
SYNDROME_MASKS = {
    's0': (0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0),
    's1': (1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0),
    's2': (1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0),
}

# Syndrome (S2 S1 S0 read as the integer s2*4 + s1*2 + s0) -> index into
# (R6, R5, R4, R3, R2, R1, R0) of the single flipped bit it indicates.
SYNDROME_TO_INDEX = {6: 0, 5: 1, 3: 2, 7: 3, 4: 4, 2: 5, 1: 6}


def add_neuron(name, w_list, bias):
    """Store one neuron as a (1, len(w_list)) weight row and a (1,) bias."""
    weights[f'{name}.weight'] = torch.tensor([w_list], dtype=torch.float32)
    weights[f'{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


# Pass-through neurons: weight 1 on a single input and bias -1, so each fires
# exactly when its input is 1 under a step activation H(z) = [z >= 0]
# (assumed activation; with a strict z > 0 these neurons could never fire).
for _idx, _name in enumerate(('d3', 'd2', 'd1', 'd0')):
    _row = [0.0] * 7
    _row[_idx] = 1.0
    add_neuron(_name, _row, -1.0)

# Threshold neurons: `<syn>_atK` fires when at least K of its four inputs are
# 1 (bias -K). The parity of the four inputs is the alternating sum
# at1 - at2 + at3 - at4; that combining layer is not stored in this file.
for _syn in ('s0', 's1', 's2'):
    for _k in range(1, 5):
        add_neuron(f'{_syn}_at{_k}', list(SYNDROME_MASKS[_syn]), -float(_k))


def xor4(a, b, c, d):
    """Return the XOR (parity) of four bits."""
    return a ^ b ^ c ^ d


def bch_decode(r6, r5, r4, r3, r2, r1, r0):
    """Return (d3, d2, d1, d0, s2, s1, s0) for a received 7-bit word.

    Data bits are passed through uncorrected (for simplicity). The syndrome
    is 000 for a valid codeword; for a single-bit error,
    ``SYNDROME_TO_INDEX[s2*4 + s1*2 + s0]`` gives the flipped position.
    """
    s2 = xor4(r6, r5, r3, r2)
    s1 = xor4(r6, r4, r3, r1)
    s0 = xor4(r5, r4, r3, r0)
    return r6, r5, r4, r3, s2, s1, s0


def encode(d3, d2, d1, d0):
    """Return the systematic codeword (d3, d2, d1, d0, c2, c1, c0)."""
    c2 = d3 ^ d2 ^ d0
    c1 = d3 ^ d1 ^ d0
    c0 = d2 ^ d1 ^ d0
    return d3, d2, d1, d0, c2, c1, c0


def run_network(bits):
    """Evaluate the stored neurons on a 7-bit word given in R6..R0 order.

    Uses the step activation H(z) = [z >= 0] and recombines each syndrome's
    threshold neurons as at1 - at2 + at3 - at4, which equals the parity of
    0..4 active inputs. Returns the same tuple layout as ``bch_decode``.
    """
    x = torch.tensor(bits, dtype=torch.float32)

    def fire(name):
        z = weights[f'{name}.weight'][0] @ x + weights[f'{name}.bias'][0]
        return int(z.item() >= 0)

    data = tuple(fire(n) for n in ('d3', 'd2', 'd1', 'd0'))
    syndrome = tuple(
        fire(f'{s}_at1') - fire(f'{s}_at2') + fire(f'{s}_at3') - fire(f'{s}_at4')
        for s in ('s2', 's1', 's0')
    )
    return data + syndrome


def _bits(value, width):
    """Return ``value`` as a tuple of ``width`` bits, most significant first."""
    return tuple((value >> (width - 1 - i)) & 1 for i in range(width))


def verify():
    """Check the reference decoder and the stored neurons.

    Covers all 16 valid codewords (data extracted, syndrome zero), all 112
    single-bit errors (syndrome locates the flipped bit), and all 128 inputs
    (network output equals ``bch_decode``). Returns the number of failures.
    """
    errors = 0
    for d in range(16):
        data = _bits(d, 4)
        codeword = encode(*data)
        if bch_decode(*codeword) != data + (0, 0, 0):
            errors += 1
            print(f"Valid codeword d={d} not decoded cleanly")
        for pos in range(7):
            received = list(codeword)
            received[pos] ^= 1
            s2, s1, s0 = bch_decode(*received)[4:]
            if SYNDROME_TO_INDEX.get(s2 * 4 + s1 * 2 + s0) != pos:
                errors += 1
                print(f"Bad syndrome for d={d}, flipped index {pos}")
    for word in range(128):
        bits = _bits(word, 7)
        if run_network(bits) != bch_decode(*bits):
            errors += 1
            print(f"Network disagrees with reference for input {bits}")
    return errors


def main():
    """Write model.safetensors, run all checks, and print a summary."""
    # Imported here so the decoder and its checks can be used without
    # safetensors installed; it is only needed to write the file.
    from safetensors.torch import save_file

    save_file(weights, 'model.safetensors')

    print("Verifying BCH(7,4) decoder...")
    errors = verify()
    if errors == 0:
        print("All 256 checks passed "
              "(16 codewords, 112 single-bit errors, 128 network inputs)!")
    else:
        print(f"FAILED: {errors} errors")

    mag = sum(t.abs().sum().item() for t in weights.values())
    print(f"Magnitude: {mag:.0f}")
    print(f"Parameters: {sum(t.numel() for t in weights.values())}")


if __name__ == "__main__":
    main()