| | import torch
|
| | from safetensors.torch import save_file
|
| |
|
weights = {}

# One threshold neuron per decoder output y0..y15. Neuron ``idx`` gets weight
# +1.0 on each address bit (ordered a3, a2, a1, a0) that is set in ``idx`` and
# -1.0 on each bit that is clear, plus bias -popcount(idx). For 0/1 inputs the
# pre-activation w.x + b is therefore 0 exactly when the address equals idx
# and at most -1 for every other address.
for idx in range(16):
    signs = [float(2 * ((idx >> shift) & 1) - 1) for shift in (3, 2, 1, 0)]
    ones = bin(idx).count('1')
    weights[f'y{idx}.weight'] = torch.tensor([signs], dtype=torch.float32)
    # float(-ones) rather than -float(ones): keeps y0's bias a positive 0.0.
    weights[f'y{idx}.bias'] = torch.tensor([float(-ones)], dtype=torch.float32)
|
| |
|
# Persist all 32 decoder tensors (16 weights + 16 biases) to a safetensors file.
save_file(tensors=weights, filename='model.safetensors')
|
| |
|
| |
|
def decode(a3, a2, a1, a0, params=None):
    """Evaluate the 16 threshold neurons of the 4-to-16 decoder.

    Each output neuron ``y{i}`` computes ``w . x + b`` over the input vector
    ``x = [a3, a2, a1, a0]`` and outputs 1 when the result is ``>= 0``,
    otherwise 0.

    Args:
        a3, a2, a1, a0: Address bits, most significant first; each is
            converted with ``float``. The verification loop passes 0/1 ints.
        params: Optional mapping holding ``'y{i}.weight'`` and ``'y{i}.bias'``
            tensors for i in 0..15. Defaults to the module-level ``weights``
            dict, so existing callers are unaffected; pass another dict (for
            example tensors read back from disk) to check those instead.

    Returns:
        A list of 16 ints (0 or 1); index ``i`` holds output ``y{i}``.
    """
    if params is None:
        params = weights  # module-level tensors built by this script
    inp = torch.tensor([float(a3), float(a2), float(a1), float(a0)])
    outputs = []
    for i in range(16):
        # weight is stored as a (1, 4) row; broadcasting against the 4-vector
        # and summing yields the dot product.
        pre_activation = (inp * params[f'y{i}.weight']).sum() + params[f'y{i}.bias']
        outputs.append(int(pre_activation >= 0))
    return outputs
|
| |
|
# Exhaustively check the decoder: for every 4-bit address, exactly the
# matching output y{val} must be 1 and all other outputs 0.
print("Verifying 4to16decoder...")
errors = 0
for val in range(16):
    a3, a2, a1, a0 = (val >> 3) & 1, (val >> 2) & 1, (val >> 1) & 1, val & 1
    result = decode(a3, a2, a1, a0)
    expected = [1 if i == val else 0 for i in range(16)]
    if result != expected:
        errors += 1
        print(f"ERROR: {val} ({a3}{a2}{a1}{a0}) -> {result}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    # Summarize failures so a broken run is obvious even if the per-case
    # ERROR lines scroll by.
    print(f"{errors} of 16 test cases FAILED")

# Sum of the absolute values of every parameter (all weights and biases).
mag = sum(t.abs().sum().item() for t in weights.values())
print(f"Magnitude: {mag:.0f}")

if errors:
    # Non-zero exit status so scripted/CI runs detect a broken decoder;
    # raised after the magnitude report so that output is still printed.
    raise SystemExit(1)
|
| |
|