| import torch |
| from safetensors.torch import save_file |
|
|
# Hand-set parameters of a 5-2-1 threshold network for "exactly four of five".
# With a unit firing when its affine pre-activation is >= 0:
#   hidden unit 0: sum(x) - 4 >= 0   -> at least four inputs are set
#   hidden unit 1: 4 - sum(x) >= 0   -> at most four inputs are set
#   output unit:   h0 + h1 - 2 >= 0  -> both hidden units fire
# Every tensor is created as float32.
weights = {
    name: torch.tensor(values, dtype=torch.float32)
    for name, values in (
        ('layer1.weight', [[1.0] * 5, [-1.0] * 5]),
        ('layer1.bias', [-4.0, 4.0]),
        ('layer2.weight', [[1.0, 1.0]]),
        ('layer2.bias', [-2.0]),
    )
}
# Serialize the parameter dict to 'model.safetensors' (a path relative to the
# current working directory); this executes as a module-level side effect.
save_file(weights, filename='model.safetensors')
|
|
def exactly4of5(a, b, c, d, e, *, params=None):
    """Evaluate the two-layer threshold network on five inputs.

    Each layer computes ``x @ W.T + b`` and applies a Heaviside step
    (1.0 where the pre-activation is >= 0, else 0.0). With the module-level
    ``weights`` the network outputs 1 exactly when four of the five inputs
    are 1.

    Args:
        a, b, c, d, e: Input values; each is converted with ``float()``.
        params: Optional mapping holding ``layer1.weight``, ``layer1.bias``,
            ``layer2.weight`` and ``layer2.bias`` tensors (for example a
            dict reloaded from ``model.safetensors``). Defaults to the
            module-level ``weights``, so existing callers are unaffected.

    Returns:
        The single output unit as an ``int`` (0 or 1).
    """
    if params is None:
        params = weights

    def step(x, layer):
        # Heaviside step: a unit fires when its affine pre-activation is >= 0.
        pre = x @ params[f'{layer}.weight'].T + params[f'{layer}.bias']
        return (pre >= 0).float()

    inp = torch.tensor([float(v) for v in (a, b, c, d, e)])
    out = step(step(inp, 'layer1'), 'layer2')
    # .item() requires layer 2 to have exactly one output unit.
    return int(out.item())
|
|
# Exhaustively check the network against the "exactly four of five" truth table.
print("Verifying exactly4of5...")
n_inputs = 5
n_cases = 1 << n_inputs  # every combination of n_inputs binary inputs
errors = 0
for i in range(n_cases):
    # Bit j of i becomes input j (least-significant bit first).
    bits = [(i >> j) & 1 for j in range(n_inputs)]
    result = exactly4of5(*bits)
    expected = 1 if sum(bits) == 4 else 0
    if result != expected:
        errors += 1
        print(f"ERROR: {bits} -> {result}, expected {expected}")
if errors == 0:
    print(f"All {n_cases} test cases passed!")
else:
    print(f"{errors} of {n_cases} test cases FAILED")
# Total absolute value of all parameters; reported whatever the outcome.
print(f"Magnitude: {sum(t.abs().sum().item() for t in weights.values()):.0f}")
|
|