|
|
import torch
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Load all tensors stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file to read.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*: a mapping
        from parameter name to tensor. The other functions in this module index
        it with keys such as ``'fs0.hs1.xor.layer1.weight'``.
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
|
def half_sub(a, b, w, prefix):
    """Evaluate a threshold-gate half subtractor on two single bits.

    The difference bit comes from a two-layer network stored under
    ``{prefix}.xor.*``. The borrow bit comes from one neuron stored under
    ``{prefix}.borrow.*``. Each neuron fires (outputs 1) when
    ``x @ W.T + bias >= 0``.

    Args:
        a, b: Input bits (anything convertible with ``float``).
        w: Mapping from parameter name to tensor.
        prefix: Key prefix that selects this subtractor's parameters.

    Returns:
        ``(difference, borrow)``, each as the float 0.0 or 1.0.
    """
    # NOTE(review): `.item()` below assumes the xor.layer2 and borrow layers each
    # have a single output neuron, and that the weights are float32-compatible
    # with the float tensor built here. Confirm against the saved model.

    def fires(x, weight_key, bias_key):
        # Step activation: True wherever the neuron's pre-activation is >= 0.
        return x @ w[weight_key].T + w[bias_key] >= 0

    bits = torch.tensor([float(a), float(b)])

    hidden = fires(
        bits, f'{prefix}.xor.layer1.weight', f'{prefix}.xor.layer1.bias'
    ).float()
    difference = float(
        fires(hidden, f'{prefix}.xor.layer2.weight', f'{prefix}.xor.layer2.bias').item()
    )
    borrow_out = float(
        fires(bits, f'{prefix}.borrow.weight', f'{prefix}.borrow.bias').item()
    )
    return difference, borrow_out
|
|
|
|
|
|
def full_sub(a, b, bin_in, w, prefix):
    """Evaluate a threshold-gate full subtractor computing ``a - b - bin_in``.

    Two half subtractors run in sequence (``{prefix}.hs1`` then
    ``{prefix}.hs2``). Their borrow bits are combined by one neuron stored
    under ``{prefix}.bout.*``. That neuron is presumably trained to act as an
    OR gate; this depends on the saved weights and is not visible here.

    Args:
        a, b: Minuend and subtrahend bits.
        bin_in: Incoming borrow bit.
        w: Mapping from parameter name to tensor.
        prefix: Key prefix that selects this subtractor's parameters.

    Returns:
        ``(difference, borrow_out)`` as Python ints.
    """
    partial_diff, first_borrow = half_sub(a, b, w, f'{prefix}.hs1')
    diff, second_borrow = half_sub(partial_diff, bin_in, w, f'{prefix}.hs2')

    stage_borrows = torch.tensor([first_borrow, second_borrow])
    pre_activation = stage_borrows @ w[f'{prefix}.bout.weight'].T + w[f'{prefix}.bout.bias']
    borrow_out = int((pre_activation >= 0).item())

    return int(diff), borrow_out
|
|
|
|
|
|
def greater_than(a, b, weights):
    """8-bit greater-than comparator.

    a, b: lists of 8 bits each (LSB first)
    weights: mapping from parameter name to tensor (e.g. the result of
        ``load_model()``) containing full-subtractor parameters ``fs0`` .. ``fs7``.

    Returns: 1 if a > b, 0 otherwise

    Computes b - a; borrow out means b < a, i.e., a > b.

    Raises:
        ValueError: if a or b does not contain exactly 8 bits. Longer inputs
            would otherwise have their high bits silently ignored, and shorter
            ones would fail partway through the ripple.
    """
    width = 8
    if len(a) != width or len(b) != width:
        raise ValueError(
            f'greater_than expects {width}-bit inputs (LSB first), '
            f'got {len(a)} and {len(b)} bits'
        )

    borrow = 0
    for i, (a_bit, b_bit) in enumerate(zip(a, b)):
        # Ripple b - a from the LSB upward. Only the final borrow matters for
        # the comparison, so the per-bit difference is discarded.
        _, borrow = full_sub(b_bit, a_bit, borrow, weights, f'fs{i}')
    return borrow
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: run the neural comparator on a few hand-picked pairs and print results.
    model_weights = load_model()

    print('8-bit GreaterThan Comparator')
    print('a > b tests:')

    def _bits_lsb_first(value):
        # Split an unsigned value into its low 8 bits, least significant first.
        return [(value >> shift) & 1 for shift in range(8)]

    tests = (
        (0, 0), (1, 0), (0, 1), (128, 127),
        (0, 255), (255, 0), (100, 100), (100, 99),
    )
    for a_val, b_val in tests:
        result = greater_than(_bits_lsb_first(a_val), _bits_lsb_first(b_val), model_weights)
        print(f'{a_val:3d} > {b_val:3d} = {result}')
|
|
|
|