|
|
import torch |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parameters of three linear threshold gates. priority_encode evaluates each
# one as int(x @ W.T + b >= 0) with x = [i3, i2, i1, i0]; for 0/1 inputs:
#   y1: i3 + i2 - 1 >= 0             -> i3 OR i2
#   y0: 2*i3 - i2 + i1 - 1 >= 0      -> i3 OR (NOT i2 AND i1)
#   v:  i3 + i2 + i1 + i0 - 1 >= 0   -> at least one input set
# Weights are single-row (1, 4) tensors, biases are length-1 tensors.
weights = {
    key: torch.tensor(values, dtype=torch.float32)
    for key, values in (
        ('y1.weight', [[1.0, 1.0, 0.0, 0.0]]),
        ('y1.bias', [-1.0]),
        ('y0.weight', [[2.0, -1.0, 1.0, 0.0]]),
        ('y0.bias', [-1.0]),
        ('v.weight', [[1.0, 1.0, 1.0, 1.0]]),
        ('v.bias', [-1.0]),
    )
}
|
|
# Persist the gate parameters; the path is relative to the current working directory.
save_file(tensors=weights, filename='model.safetensors')
|
|
|
|
|
def priority_encode(i3, i2, i1, i0, *, params=None):
    """Run the threshold-gate priority encoder on one 4-bit input.

    Each output is a single linear threshold unit: it is 1 when
    ``x @ W.T + b >= 0`` and 0 otherwise, where ``x = [i3, i2, i1, i0]``.

    Args:
        i3, i2, i1, i0: Input bits, highest priority first. Any value that
            ``float()`` accepts; the gates are designed for 0/1 inputs.
        params: Optional keyword-only mapping holding ``'<gate>.weight'``
            (a single-row tensor over the 4 inputs) and ``'<gate>.bias'``
            for the gates ``y1``, ``y0`` and ``v``. Defaults to the
            module-level ``weights`` dict, so existing calls are unchanged;
            pass another dict (e.g. one loaded from a safetensors file) to
            evaluate a different parameter set.

    Returns:
        ``(y1, y0, v)`` as Python ints in {0, 1}.
    """
    if params is None:
        params = weights
    x = torch.tensor([float(i3), float(i2), float(i1), float(i0)])

    def fire(gate):
        # Evaluate one threshold gate on x and return its 0/1 output.
        w = params[f'{gate}.weight']
        b = params[f'{gate}.bias']
        # Match the input dtype to the weights so e.g. float64 parameter sets
        # don't raise a matmul dtype mismatch; a no-op for float32 weights.
        pre = x.to(w.dtype) @ w.T + b
        # A single output unit, so .item() extracts the lone boolean.
        return int((pre >= 0).item())

    return fire('y1'), fire('y0'), fire('v')
|
|
|
|
|
print("Verifying priorityencoder4...")
errors = 0
for val in range(16):
    # Split the 4-bit pattern into inputs, MSB first: val == 0b(i3 i2 i1 i0).
    i3, i2, i1, i0 = (val >> 3) & 1, (val >> 2) & 1, (val >> 1) & 1, val & 1
    y1, y0, v = priority_encode(i3, i2, i1, i0)

    # Reference model: the highest-numbered set input wins. Because input k
    # is bit k of `val`, its index is the position of val's top set bit.
    # With no input set, the valid flag is 0 and the index defaults to 0.
    exp_v = int(val != 0)
    exp_idx = max(val.bit_length() - 1, 0)
    exp_y1, exp_y0 = (exp_idx >> 1) & 1, exp_idx & 1

    # The index bits are "don't care" when nothing is set, so they are only
    # compared when the encoder reports a valid output.
    if v != exp_v or (v == 1 and (y1 != exp_y1 or y0 != exp_y0)):
        errors += 1
        print(f"ERROR: {i3}{i2}{i1}{i0} -> y1={y1},y0={y0},v={v}, expected {exp_y1},{exp_y0},{exp_v}")

if errors == 0:
    print("All 16 test cases passed!")
else:
    # Give failing runs an explicit verdict too, not just per-case ERROR lines.
    print(f"{errors} of 16 test cases FAILED")
|
|
|
|
|
# Total magnitude of the model: sum of |x| over every weight and bias entry.
per_tensor_l1 = [param.abs().sum().item() for param in weights.values()]
mag = sum(per_tensor_l1)
print(f"Magnitude: {mag:.0f}")
|
|
|