|
|
import torch |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Threshold-gate parameters for the 8-input priority encoder, keyed by the
# tensor names used when the table is written to disk.
weights = {}

# Layer 1: one detector per bit position.  Input index 0 carries bit 7, so
# bit k sits at index 7-k and every higher-priority bit sits at a lower index.
# Detector k puts +1 on its own bit and -1 on each higher-priority bit; with
# bias -1 the pre-activation reaches 0 only when bit k is set and no higher
# bit is set.
for k in range(8):
    own = 7 - k
    row = [-1.0] * own + [1.0] + [0.0] * k
    weights[f'layer1.h{k}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'layer1.h{k}.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# Layer 2: output bit b is an OR over the detectors whose index has bit b set
# (weight 1 on those detectors, bias -1, so one active detector is enough).
for b in (2, 1, 0):
    taps = [float((idx >> b) & 1) for idx in range(8)]
    weights[f'layer2.y{b}.weight'] = torch.tensor([taps], dtype=torch.float32)
    weights[f'layer2.y{b}.bias'] = torch.tensor([-1.0], dtype=torch.float32)

# Valid flag: OR over all eight detectors.
weights['layer2.v.weight'] = torch.tensor([[1.0] * 8], dtype=torch.float32)
weights['layer2.v.bias'] = torch.tensor([-1.0], dtype=torch.float32)
|
|
|
|
|
# Write every tensor in `weights` to model.safetensors (relative path).
save_file(tensors=weights, filename='model.safetensors')
|
|
|
|
|
def priority_encode(inputs, params=None):
    """Evaluate the two-layer threshold network on one input vector.

    Args:
        inputs: Sequence of eight values, each converted with ``float``.
            With the weights built in this file, index 0 carries bit 7
            (highest priority) and index 7 carries bit 0.  Any other length
            makes the matrix product against the 8-wide weight rows fail.
        params: Optional mapping from tensor name (``layer1.h0.weight`` ...
            ``layer2.v.bias``) to tensor.  When omitted, the module-level
            ``weights`` table is used, which is the original behaviour.

    Returns:
        Tuple ``(y2, y1, y0, v)`` of ints, each 0 or 1.  With this file's
        weights, ``v`` is 1 when any input bit is set and ``y2 y1 y0`` is the
        binary index of the highest set bit; all four are 0 for an all-zero
        input.
    """
    table = weights if params is None else params

    def fires(signal, gate):
        # Step-activation neuron: outputs 1 when w.x + b >= 0, otherwise 0.
        pre = signal @ table[f'{gate}.weight'].T + table[f'{gate}.bias']
        return int((pre >= 0).item())

    # Pin the dtype so the product matches the float32 weights even when the
    # process-wide torch default dtype has been changed.
    x = torch.tensor([float(bit) for bit in inputs], dtype=torch.float32)

    # Layer 1: eight detectors, h[k] for bit position k.
    hidden = torch.tensor(
        [float(fires(x, f'layer1.h{k}')) for k in range(8)],
        dtype=torch.float32,
    )

    # Layer 2: three index bits plus the valid flag, all fed by the detectors.
    y2 = fires(hidden, 'layer2.y2')
    y1 = fires(hidden, 'layer2.y1')
    y0 = fires(hidden, 'layer2.y0')
    v = fires(hidden, 'layer2.v')
    return y2, y1, y0, v
|
|
|
|
|
# Exhaustive self-check: push every 8-bit pattern through the network and
# compare it with a reference encoder computed with plain integer arithmetic.
print("Verifying priorityencoder8...")

errors = 0
for val in range(256):
    # Most significant bit first: list index j holds bit 7-j of `val`.
    inputs = [(val >> shift) & 1 for shift in range(7, -1, -1)]
    y2, y1, y0, v = priority_encode(inputs)

    # Position of the highest set bit, or -1 when no bit is set.
    highest = val.bit_length() - 1
    if highest < 0:
        exp_y2, exp_y1, exp_y0, exp_v = 0, 0, 0, 0
    else:
        exp_y2, exp_y1, exp_y0 = ((highest >> s) & 1 for s in (2, 1, 0))
        exp_v = 1

    # The index bits are only compared when the network reports a valid input.
    index_ok = v != 1 or (y2, y1, y0) == (exp_y2, exp_y1, exp_y0)
    if v == exp_v and index_ok:
        continue

    errors += 1
    # Only the first three mismatches are printed in full.
    if errors <= 3:
        print(f"ERROR: val={val}, inputs={inputs}, got ({y2},{y1},{y0},{v}), expected ({exp_y2},{exp_y1},{exp_y0},{exp_v})")

if errors:
    print(f"FAILED: {errors} errors")
else:
    print("All 256 test cases passed!")
|
|
|
|
|
# Report the summed absolute value of every entry in every tensor of
# `weights` (weights and biases alike), printed without decimals.
mag = 0
for tensor in weights.values():
    mag += tensor.abs().sum().item()
print(f"Magnitude: {mag:.0f}")
|
|
|