| import torch
|
| from safetensors.torch import save_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parameters of a 7-layer threshold network, keyed '<layer>.<unit>.weight'
# (shape 1 x fan_in) and '<layer>.<unit>.bias' (shape 1), all float32.
weights = {}

# Ordered signal names feeding each layer.  The column order of every weight
# row must follow the order in which the previous layer's outputs are
# concatenated, so each ordering is written down exactly once here.
_INPUT_BITS = ('a1', 'a0', 'b1', 'b0')
_L1_OUT = ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
           'both1_high', 'both1_low', 'a1', 'a0', 'b1', 'b0')
_L2_OUT = ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
           'a1', 'a0', 'b1', 'b0')
_L3_OUT = ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1',
           'a1', 'a0', 'b1', 'b0', 'a1_eq_b1')
_L4_OUT = ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0')
_L5_OUT = ('a_le_b', 'a1', 'a0', 'b1', 'b0')
_L6_OUT = ('m1_a', 'm1_b', 'm0_a', 'm0_b')

_BITS = ('a1', 'a0', 'b1', 'b0')


def _add_neuron(layer, name, inputs, taps, bias):
    """Store one unit's weight row and bias in ``weights``.

    Args:
        layer: Layer prefix used in the key, e.g. ``'l3'``.
        name: Unit name used in the key.
        inputs: Ordered names of the signals feeding this layer; fixes the
            width of the row and the column of each tap.
        taps: Mapping ``input name -> coefficient``; every other column is 0.
        bias: Scalar bias of the unit.

    Raises:
        ValueError: If a tap names a signal that is not in ``inputs``.
    """
    row = [0.0] * len(inputs)
    for source, coeff in taps.items():
        row[inputs.index(source)] = coeff
    weights[f'{layer}.{name}.weight'] = torch.tensor([row], dtype=torch.float32)
    weights[f'{layer}.{name}.bias'] = torch.tensor([bias], dtype=torch.float32)


def _add_passthrough(layer, inputs, names):
    """Add one unit per name that copies the same-named 0/1 input forward.

    Weight 1 on the matching column with bias -0.5 gives x - 0.5, which is
    non-negative exactly when x is 1.
    """
    for name in names:
        _add_neuron(layer, name, inputs, {name: 1.0}, -0.5)


# Layer 1: per-bit comparisons of the raw inputs.
# x - y - 1 is non-negative only for x=1, y=0.
_add_neuron('l1', 'a1_gt_b1', _INPUT_BITS, {'a1': 1.0, 'b1': -1.0}, -1.0)
_add_neuron('l1', 'b1_gt_a1', _INPUT_BITS, {'a1': -1.0, 'b1': 1.0}, -1.0)
_add_neuron('l1', 'a0_gt_b0', _INPUT_BITS, {'a0': 1.0, 'b0': -1.0}, -1.0)
_add_neuron('l1', 'b0_gt_a0', _INPUT_BITS, {'a0': -1.0, 'b0': 1.0}, -1.0)
# a1 + b1 - 2 >= 0 only when both high bits are 1.
_add_neuron('l1', 'both1_high', _INPUT_BITS, {'a1': 1.0, 'b1': 1.0}, -2.0)
# -a1 - b1 >= 0 only when both high bits are 0.
_add_neuron('l1', 'both1_low', _INPUT_BITS, {'a1': -1.0, 'b1': -1.0}, 0.0)
_add_passthrough('l1', _INPUT_BITS, _BITS)

# Layer 2: high bits are equal when they are both 1 OR both 0.
_add_neuron('l2', 'a1_eq_b1', _L1_OUT,
            {'both1_high': 1.0, 'both1_low': 1.0}, -1.0)
_add_passthrough('l2', _L1_OUT,
                 ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0') + _BITS)

# Layer 3: (high bits equal AND a0 > b0), and (a0 > b0 OR b0 > a0).
_add_neuron('l3', 'a_gt_b_part2', _L2_OUT,
            {'a1_eq_b1': 1.0, 'a0_gt_b0': 1.0}, -2.0)
_add_neuron('l3', 'a0_neq_b0', _L2_OUT,
            {'a0_gt_b0': 1.0, 'b0_gt_a0': 1.0}, -1.0)
_add_passthrough('l3', _L2_OUT, ('a1_gt_b1',) + _BITS + ('a1_eq_b1',))

# Layer 4: a > b is (a1 > b1) OR the low-bit tie-break;
# a == b is (high bits equal) AND NOT (low bits differ).
_add_neuron('l4', 'a_gt_b', _L3_OUT,
            {'a_gt_b_part2': 1.0, 'a1_gt_b1': 1.0}, -1.0)
_add_neuron('l4', 'a_eq_b', _L3_OUT,
            {'a0_neq_b0': -1.0, 'a1_eq_b1': 1.0}, -1.0)
_add_passthrough('l4', _L3_OUT, _BITS)

# Layer 5: -a_gt_b + a_eq_b >= 0 holds unless a_gt_b is 1 while a_eq_b is 0.
_add_neuron('l5', 'a_le_b', _L4_OUT, {'a_gt_b': -1.0, 'a_eq_b': 1.0}, 0.0)
_add_passthrough('l5', _L4_OUT, _BITS)

# Layer 6: gate each output bit by the selector:
# m*_a = a_le_b AND a-bit, m*_b = (NOT a_le_b) AND b-bit.
_add_neuron('l6', 'm1_a', _L5_OUT, {'a_le_b': 1.0, 'a1': 1.0}, -2.0)
_add_neuron('l6', 'm1_b', _L5_OUT, {'a_le_b': -1.0, 'b1': 1.0}, -1.0)
_add_neuron('l6', 'm0_a', _L5_OUT, {'a_le_b': 1.0, 'a0': 1.0}, -2.0)
_add_neuron('l6', 'm0_b', _L5_OUT, {'a_le_b': -1.0, 'b0': 1.0}, -1.0)

# Layer 7: OR the two gated candidates for each output bit.
_add_neuron('l7', 'm1', _L6_OUT, {'m1_a': 1.0, 'm1_b': 1.0}, -1.0)
_add_neuron('l7', 'm0', _L6_OUT, {'m0_a': 1.0, 'm0_b': 1.0}, -1.0)
|
|
|
# Write every weight/bias tensor built above to 'model.safetensors'.
save_file(tensors=weights, filename='model.safetensors')
|
|
|
|
|
def min2(a1, a0, b1, b0, params=None):
    """Run the 7-layer threshold network on two 2-bit numbers.

    Each unit computes ``x @ W.T + b`` over the previous layer's outputs and
    emits 1 when the result is ``>= 0``, else 0.  With the weights built in
    this file the two outputs are the bits of ``min(a, b)``, where
    ``a = 2*a1 + a0`` and ``b = 2*b1 + b0``.

    Args:
        a1: High bit of the first number (0 or 1).
        a0: Low bit of the first number (0 or 1).
        b1: High bit of the second number (0 or 1).
        b0: Low bit of the second number (0 or 1).
        params: Optional mapping holding ``'<layer>.<unit>.weight'`` and
            ``'<layer>.<unit>.bias'`` tensors.  Defaults to the module-level
            ``weights`` dict, so existing four-argument calls are unchanged.

    Returns:
        Tuple ``(m1, m0)`` of Python ints (each 0 or 1): the high and low
        output bits.

    Raises:
        KeyError: If ``params`` lacks a weight or bias for any unit below.
    """
    if params is None:
        params = weights

    # Units of each layer, in the order their outputs are concatenated to
    # form the next layer's input vector.
    layers = (
        ('l1', ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
                'both1_high', 'both1_low', 'a1', 'a0', 'b1', 'b0')),
        ('l2', ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
                'a1', 'a0', 'b1', 'b0')),
        ('l3', ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1',
                'a1', 'a0', 'b1', 'b0', 'a1_eq_b1')),
        ('l4', ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0')),
        ('l5', ('a_le_b', 'a1', 'a0', 'b1', 'b0')),
        ('l6', ('m1_a', 'm1_b', 'm0_a', 'm0_b')),
        ('l7', ('m1', 'm0')),
    )

    signal = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    fired = []
    for prefix, units in layers:
        fired = []
        for unit in units:
            pre_activation = (
                signal @ params[f'{prefix}.{unit}.weight'].T
                + params[f'{prefix}.{unit}.bias']
            )
            # Heaviside step with the boundary included: fire on >= 0.
            fired.append(int((pre_activation >= 0).item()))
        signal = torch.tensor([float(bit) for bit in fired])

    # The final layer has exactly two units: m1 then m0.
    m1, m0 = fired
    return m1, m0
|
|
|
# Exhaustive self-check: run min2 on all 16 pairs of 2-bit operands and
# compare the decoded output against Python's min().
print("Verifying min2...")

errors = 0
for pair in range(16):
    # pair enumerates (a, b) with a as the slow index, b as the fast one.
    a, b = divmod(pair, 4)
    m1, m0 = min2((a >> 1) & 1, a & 1, (b >> 1) & 1, b & 1)
    result = m0 + 2 * m1
    expected = min(a, b)
    if result == expected:
        continue
    errors += 1
    print(f"ERROR: min({a}, {b}) = {result}, expected {expected}")

if errors:
    print(f"FAILED: {errors} errors")
else:
    print("All 16 test cases passed!")

# Total L1 norm over every stored weight and bias tensor.
mag = 0
for t in weights.values():
    mag += t.abs().sum().item()
print(f"Magnitude: {mag:.0f}")
|
|
|