# threshold-min2 / model.py
# Author: phanerozoic — commit 6e6bb88 ("Upload folder using huggingface_hub").
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read the network's tensors from a safetensors file.

    Args:
        path: File to read. Defaults to 'model.safetensors', a relative path,
            so it is resolved against the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``: the
        name-to-tensor mapping that ``min2`` indexes with keys such as
        ``'l1.a1.weight'``.
    """
    tensors = load_file(path)
    return tensors
# Hidden layers in evaluation order: (weight-name prefix, neuron names).
# The order of names inside each tuple fixes the column order of that
# layer's output vector, which is what the next layer's weights are
# multiplied against -- never reorder these independently of the weights.
_HIDDEN_LAYERS = (
    ('l1', ('a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
            'both1_high', 'both1_low', 'a1', 'a0', 'b1', 'b0')),
    ('l2', ('a1_eq_b1', 'a1_gt_b1', 'b1_gt_a1', 'a0_gt_b0', 'b0_gt_a0',
            'a1', 'a0', 'b1', 'b0')),
    ('l3', ('a_gt_b_part2', 'a0_neq_b0', 'a1_gt_b1',
            'a1', 'a0', 'b1', 'b0', 'a1_eq_b1')),
    ('l4', ('a_gt_b', 'a_eq_b', 'a1', 'a0', 'b1', 'b0')),
    ('l5', ('a_le_b', 'a1', 'a0', 'b1', 'b0')),
    ('l6', ('m1_a', 'm1_b', 'm0_a', 'm0_b')),
)


def _fire(x, name, weights):
    """Evaluate one threshold neuron: return 1 if ``x @ W.T + b >= 0``, else 0.

    ``W`` is ``weights[f'{name}.weight']`` and ``b`` is
    ``weights[f'{name}.bias']``.
    """
    # NOTE(review): ``.item()`` needs the comparison to produce exactly one
    # element, so W is presumably (1, fan_in) or (fan_in,) with a 1-element
    # bias -- confirm against the checkpoint. If W is 1-D, PyTorch may warn
    # that ``.T`` on non-2-D tensors is deprecated.
    w = weights[f'{name}.weight']
    b = weights[f'{name}.bias']
    return int((x @ w.T + b >= 0).item())


def _layer(x, prefix, keys, weights):
    """Fire each neuron ``f'{prefix}.{key}'`` on ``x``; return the 0/1 outputs as a float tensor."""
    return torch.tensor([float(_fire(x, f'{prefix}.{key}', weights)) for key in keys])


def min2(a1, a0, b1, b0, weights):
    """Minimum of two 2-bit unsigned integers, computed by a threshold network.

    The inputs are ``a = 2*a1 + a0`` and ``b = 2*b1 + b0``. Each neuron fires
    (outputs 1) when ``input @ W.T + bias >= 0``. Hidden layers ``l1``..``l6``
    are evaluated in order, each feeding its 0/1 outputs to the next, and two
    output neurons ``l7.m1`` / ``l7.m0`` give the result bits. Neuron names
    (e.g. ``a_le_b``) indicate intent only; the actual logic lives entirely in
    ``weights``.

    Args:
        a1, a0: High and low bit of ``a`` (expected 0 or 1; passed through ``float()``).
        b1, b0: High and low bit of ``b`` (expected 0 or 1; passed through ``float()``).
        weights: Mapping from ``'<layer>.<neuron>.weight'`` / ``'.bias'`` names
            to tensors, e.g. the result of ``load_model()``.

    Returns:
        Tuple ``(m1, m0)`` of ints in {0, 1}, where ``m = 2*m1 + m0 = min(a, b)``
        for correctly trained weights.

    Raises:
        KeyError: if ``weights`` is a dict lacking a required tensor.
    """
    x = torch.tensor([float(a1), float(a0), float(b1), float(b0)])
    for prefix, keys in _HIDDEN_LAYERS:
        x = _layer(x, prefix, keys, weights)
    # Output layer: one neuron per result bit.
    m1 = _fire(x, 'l7.m1', weights)
    m0 = _fire(x, 'l7.m0', weights)
    return m1, m0
if __name__ == '__main__':
    # Print the full 4x4 truth table of min2, using weights loaded from the
    # default 'model.safetensors' path.
    weights = load_model()
    print('min2 truth table:')
    print(' a b | min')
    print('-------+----')
    # Enumerate all 16 (a, b) pairs: a is the outer index, b the inner one.
    for pair_index in range(16):
        a, b = divmod(pair_index, 4)
        # Split each 2-bit operand into (high bit, low bit).
        a1, a0 = divmod(a, 2)
        b1, b0 = divmod(b, 2)
        hi, lo = min2(a1, a0, b1, b0, weights)
        result = 2 * hi + lo
        print(f' {a} {b} | {result}')