import torch

try:
    from safetensors.torch import load_file
except ImportError:  # only load_model() needs safetensors; min2() does not
    load_file = None


# Gate names of layers l1..l6, in the order their outputs are concatenated
# to form the input vector of the next layer.  The order is significant:
# it fixes which column of the next layer's weight rows each gate feeds.
_HIDDEN_LAYER_KEYS = (
    ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
     'both1_high', 'both1_low', 'a1', 'a0', 'b1', 'b0'),
    ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
     'a1', 'a0', 'b1', 'b0'),
    ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1',
     'a1', 'a0', 'b1', 'b0', 'a1_eq_b1'),
    ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0'),
    ('a_le_b', 'a1', 'a0', 'b1', 'b0'),
    ('m1_a', 'm1_b', 'm0_a', 'm0_b'),
)


def load_model(path='model.safetensors'):
    """Load the gate weights from a safetensors file.

    Args:
        path: Location of the weights file.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``;
        ``min2`` indexes it with keys of the form ``'l<n>.<gate>.weight'``
        and ``'l<n>.<gate>.bias'``.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError(
            "safetensors is required to load model weights; install it or "
            "pass your own weights mapping to min2()"
        )
    return load_file(path)


def _gate(x, weights, name):
    """Evaluate one threshold gate: 1 if ``x @ W.T + b >= 0``, else 0."""
    pre_activation = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
    # .item() requires the result to hold exactly one element, so each
    # gate's weight/bias must reduce the input vector to a single value.
    return int((pre_activation >= 0).item())


def _layer(x, weights, prefix, keys):
    """Evaluate every gate of one layer and pack the 0/1 outputs into a tensor."""
    return torch.tensor(
        [float(_gate(x, weights, f'{prefix}.{key}')) for key in keys]
    )


def min2(a1, a0, b1, b0, weights):
    """Minimum of two 2-bit unsigned integers.

    The inputs are fed, as the vector ``[a1, a0, b1, b0]``, through six
    layers of threshold gates (``l1`` .. ``l6``) followed by the two output
    gates ``l7.m1`` and ``l7.m0``.  Each layer's 0/1 outputs, in the order
    listed in ``_HIDDEN_LAYER_KEYS``, form the next layer's input.

    Args:
        a1, a0: High and low bit of ``a``.
        b1, b0: High and low bit of ``b``.
        weights: Mapping from ``'l<n>.<gate>.weight'`` / ``'l<n>.<gate>.bias'``
            to tensors, e.g. the result of ``load_model()``.

    Returns:
        ``(m1, m0)`` where m = min(a, b), each an ``int`` produced by the
        corresponding ``l7`` gate.

    Raises:
        KeyError: If ``weights`` lacks an entry for one of the gates.
    """
    signal = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    for depth, keys in enumerate(_HIDDEN_LAYER_KEYS, start=1):
        signal = _layer(signal, weights, f'l{depth}', keys)

    m1 = _gate(signal, weights, 'l7.m1')
    m0 = _gate(signal, weights, 'l7.m0')
    return m1, m0


def main():
    """Print the min2 truth table for every pair of 2-bit operands."""
    w = load_model()
    print('min2 truth table:')
    print(' a b | min')
    print('-------+----')
    for a in range(4):
        for b in range(4):
            a1, a0 = (a >> 1) & 1, a & 1
            b1, b0 = (b >> 1) & 1, b & 1
            m1, m0 = min2(a1, a0, b1, b0, w)
            print(f' {a} {b} | {2*m1 + m0}')


if __name__ == '__main__':
    main()