import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Load every tensor stored in a safetensors checkpoint.

    Args:
        path: Location of the ``.safetensors`` file. It defaults to
            ``'model.safetensors'``, which is resolved against the current
            working directory.

    Returns:
        The name -> tensor dict produced by ``safetensors.torch.load_file``.
        The script passes this dict to ``min2`` as its ``weights`` argument.
    """
    state_dict = load_file(path)
    return state_dict
def _threshold_layer(x, weights, prefix, keys):
    """Evaluate one layer of step-activation threshold units.

    Unit ``key`` outputs 1.0 when
    ``x @ weights[f'{prefix}.{key}.weight'].T + weights[f'{prefix}.{key}.bias'] >= 0``,
    and 0.0 otherwise.

    Returns:
        A 1-D float tensor of the unit outputs in ``keys`` order. It is used
        as the input vector of the next layer.
    """
    fired = []
    for key in keys:
        # matmul does not promote dtypes. Parameters stored as e.g. float16
        # would otherwise make ``x @ w.T`` raise, so align them with ``x``.
        w = weights[f'{prefix}.{key}.weight'].to(x.dtype)
        b = weights[f'{prefix}.{key}.bias'].to(x.dtype)
        fired.append(float(int((x @ w.T + b >= 0).item())))
    return torch.tensor(fired)


# Layer prefixes and their unit names, in evaluation order. Within a layer the
# key order fixes each unit's position in the next layer's input vector. The
# weight columns of layer k+1 are therefore indexed by layer k's key order.
_MIN2_LAYERS = (
    ('l1', ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0', 'both1_high',
            'both1_low', 'a1', 'a0', 'b1', 'b0')),
    ('l2', ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
            'a1', 'a0', 'b1', 'b0')),
    ('l3', ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1', 'a1', 'a0', 'b1', 'b0',
            'a1_eq_b1')),
    ('l4', ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0')),
    ('l5', ('a_le_b', 'a1', 'a0', 'b1', 'b0')),
    ('l6', ('m1_a', 'm1_b', 'm0_a', 'm0_b')),
    ('l7', ('m1', 'm0')),
)


def min2(a1, a0, b1, b0, weights):
    """Minimum of two 2-bit unsigned integers, computed by a threshold network.

    The operands are ``a = 2*a1 + a0`` and ``b = 2*b1 + b0``. The four input
    bits pass through the seven layers listed in ``_MIN2_LAYERS``. Each layer
    is made of step-activation units whose parameters are read from *weights*.

    Args:
        a1, a0: High and low bit of ``a``; each must be 0 or 1.
        b1, b0: High and low bit of ``b``; each must be 0 or 1.
        weights: Mapping from parameter name (``'l<k>.<unit>.weight'`` /
            ``'l<k>.<unit>.bias'``) to tensor, e.g. the dict returned by
            ``load_model``.

    Returns:
        ``(m1, m0)`` as Python ints: the outputs of the final layer's ``m1``
        and ``m0`` units. These are the bits of ``min(a, b)`` provided the
        loaded weights implement that function.

    Raises:
        ValueError: If any input bit is not 0 or 1.
        KeyError: If *weights* lacks a required parameter.
    """
    for name, bit in (('a1', a1), ('a0', a0), ('b1', b1), ('b0', b0)):
        # Non-binary inputs would silently produce a meaningless result.
        if bit not in (0, 1):
            raise ValueError(f'{name} must be 0 or 1, got {bit!r}')
    x = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    for prefix, keys in _MIN2_LAYERS:
        x = _threshold_layer(x, weights, prefix, keys)
    m1, m0 = (int(v) for v in x.tolist())
    return m1, m0
def main():
    """Print min2's truth table for every pair of 2-bit inputs.

    Uses the parameters from the default checkpoint loaded by ``load_model``.
    Any row where the network's answer differs from Python's ``min`` is
    flagged, so a bad checkpoint is visible at a glance.
    """
    weights = load_model()
    print('min2 truth table:')
    print(' a b | min')
    # Five dashes span ' a b ', so '+' sits directly under the '|' column.
    print('-----+----')
    for a in range(4):
        a1, a0 = (a >> 1) & 1, a & 1
        for b in range(4):
            b1, b0 = (b >> 1) & 1, b & 1
            m1, m0 = min2(a1, a0, b1, b0, weights)
            got = 2 * m1 + m0
            expected = min(a, b)
            note = '' if got == expected else f'  MISMATCH (expected {expected})'
            print(f' {a} {b} | {got}{note}')


if __name__ == '__main__':
    main()