import torch
from safetensors.torch import save_file


# Name -> tensor mapping that every add_neuron call below writes into.
weights = dict()


def add_neuron(name, w_list, bias):
    """Register one single-output linear unit under ``name``.

    Stores ``<name>.weight`` as a float32 tensor of shape (1, len(w_list))
    and ``<name>.bias`` as a float32 tensor of shape (1,) in the
    module-level ``weights`` dict, weight first.
    """
    weight_key, bias_key = f'{name}.weight', f'{name}.bias'
    # Wrap in an extra list so the weight is a 1-row matrix, not a vector.
    weights[weight_key] = torch.tensor([w_list], dtype=torch.float32)
    weights[bias_key] = torch.tensor([bias], dtype=torch.float32)


# Every neuron reads the received word in the order [r6, r5, r4, r3, r2, r1, r0],
# the argument order of bch_decode and the tuple order produced by encode().
# With 0/1 inputs, w.x + b >= 0 exactly when the tapped bits sum to at least -b.
# NOTE(review): the activation function and the way the s*_at* count units are
# combined into parity bits are defined by whatever evaluates
# model.safetensors, not by this file -- confirm.

# Data outputs: d3..d0 copy r6..r3 (the code is systematic).
add_neuron('d3', [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], -1.0)
add_neuron('d2', [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], -1.0)
add_neuron('d1', [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], -1.0)
add_neuron('d0', [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], -1.0)


def _add_parity_counter(name, taps):
    """Add ``<name>_at1`` .. ``<name>_atN`` (N = number of taps) over ``taps``.

    Unit K gets bias -K, so for 0/1 inputs its pre-activation is >= 0
    exactly when at least K of the tapped bits are 1.
    """
    for k in range(1, int(sum(taps)) + 1):
        add_neuron(f'{name}_at{k}', taps, -float(k))


# Syndrome taps.  Each one re-evaluates a parity equation of encode(): s_k
# XORs the received check bit c_k (= r_k) with the three data bits c_k was
# computed from, so s_k == 0 on every valid codeword.  Every position,
# including r3 (d0), lies in a distinct non-empty subset of the checks, so
# each single-bit error produces its own non-zero syndrome.
_add_parity_counter('s0', [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0])  # r5^r4^r3^r0
_add_parity_counter('s1', [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0])  # r6^r4^r3^r1
_add_parity_counter('s2', [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0])  # r6^r5^r3^r2

save_file(weights, 'model.safetensors')


def xor4(a, b, c, d):
    """Return the XOR (parity) of four bits."""
    return a ^ b ^ c ^ d


def bch_decode(r6, r5, r4, r3, r2, r1, r0):
    """Pure-Python reference for the network saved above.

    Takes a received 7-bit word (r6 first) and returns
    ``(d3, d2, d1, d0, s2, s1, s0)``.  The first four values are r6..r3
    passed through unchanged; the last three are the syndrome bits.

    Each syndrome bit re-evaluates one parity equation of encode().  For
    the word ``(d3, d2, d1, d0, c2, c1, c0)`` that encode() produces, all
    three bits are 0.  Flipping any single bit gives a distinct non-zero
    syndrome.  No correction is applied to the data bits.
    """
    s0 = xor4(r5, r4, r3, r0)  # c0 = d2 ^ d1 ^ d0
    s1 = xor4(r6, r4, r3, r1)  # c1 = d3 ^ d1 ^ d0
    s2 = xor4(r6, r5, r3, r2)  # c2 = d3 ^ d2 ^ d0

    d3, d2, d1, d0 = r6, r5, r4, r3
    return d3, d2, d1, d0, s2, s1, s0


# Running failure count for the exhaustive check below.
errors = 0
print("Verifying BCH(7,4) decoder...")


def encode(d3, d2, d1, d0):
    """Return the systematic 7-bit codeword ``(d3, d2, d1, d0, c2, c1, c0)``.

    The check bits are c2 = d3^d2^d0, c1 = d3^d1^d0 and c0 = d2^d1^d0.
    """
    overall = d3 ^ d2 ^ d1 ^ d0
    # Each check bit omits exactly one data bit, so XOR that bit back out
    # of the overall parity.
    check2 = overall ^ d1
    check1 = overall ^ d2
    check0 = overall ^ d3
    return d3, d2, d1, d0, check2, check1, check0


# Exhaustively check every 4-bit data word.  Comparing only the data bits is
# not enough: they are plain copies of r6..r3, so that check passes no matter
# what the syndrome equations are.
#
# Three more properties are therefore checked:
# - every clean codeword has an all-zero syndrome;
# - each of the 7 single-bit errors has a non-zero syndrome;
# - those 7 syndromes are all different, so the flipped bit can be located.
for d in range(16):
    d3, d2, d1, d0 = (d >> 3) & 1, (d >> 2) & 1, (d >> 1) & 1, d & 1
    codeword = encode(d3, d2, d1, d0)
    decoded = bch_decode(*codeword)

    if decoded[:4] != (d3, d2, d1, d0):
        errors += 1
        print(f"Data error for d={d}")

    if any(decoded[4:]):
        errors += 1
        print(f"Syndrome error for d={d}: {decoded[4:]}")

    error_syndromes = set()
    for pos in range(len(codeword)):
        corrupted = list(codeword)
        corrupted[pos] ^= 1  # index 0 is r6, index 6 is r0
        syndrome = tuple(bch_decode(*corrupted)[4:])
        if not any(syndrome):
            errors += 1
            print(f"Undetected single-bit error at codeword index {pos} for d={d}")
        error_syndromes.add(syndrome)
    if len(error_syndromes) != len(codeword):
        errors += 1
        print(f"Ambiguous single-bit error syndromes for d={d}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    print(f"FAILED: {errors} errors")


# Summary statistics over every stored tensor: the total absolute value of
# all entries, and the total number of scalar entries.
mag = 0
param_count = 0
for tensor in weights.values():
    mag += tensor.abs().sum().item()
    param_count += tensor.numel()
print(f"Magnitude: {mag:.0f}")
print(f"Parameters: {param_count}")