# Uploaded by phanerozoic via huggingface_hub ("Upload folder using huggingface_hub"), commit 45d0364 (verified).
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read all tensors from a safetensors checkpoint.

    Args:
        path: file to read. The default is a relative path, so it is looked
            up relative to the current working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for *path*: a
        mapping of tensor names to tensors.
    """
    tensors = load_file(path)
    return tensors
def _threshold(inputs, w, name):
    """Evaluate one threshold neuron; return 1.0 if it fires, else 0.0.

    The neuron fires when ``(inputs * w[f'{name}.weight']).sum() +
    w[f'{name}.bias'] >= 0``.
    """
    activation = (inputs * w[f'{name}.weight']).sum() + w[f'{name}.bias']
    return float(activation >= 0)


def _half_adder(x, y, w, prefix):
    """Evaluate the half adder whose neurons live under *prefix*.

    The sum bit goes through two layers. Layer 1 is the
    ``{prefix}.sum.layer1.or`` and ``{prefix}.sum.layer1.nand`` neurons on
    ``(x, y)``. Layer 2 is ``{prefix}.sum.layer2`` on their outputs. The key
    names suggest XOR = layer2(OR, NAND); what layer2 actually computes
    depends on the stored weights.

    The carry bit is the single ``{prefix}.carry`` neuron on ``(x, y)``.

    Returns:
        ``(sum, carry)`` as floats, each 0.0 or 1.0.
    """
    inp = torch.tensor([float(x), float(y)])
    or_bit = _threshold(inp, w, f'{prefix}.sum.layer1.or')
    nand_bit = _threshold(inp, w, f'{prefix}.sum.layer1.nand')
    sum_bit = _threshold(torch.tensor([or_bit, nand_bit]), w, f'{prefix}.sum.layer2')
    carry_bit = _threshold(inp, w, f'{prefix}.carry')
    return sum_bit, carry_bit


def full_adder(a, b, cin, w, prefix):
    """Single full adder: a + b + cin = (sum, cout).

    The adder is built only from threshold neurons read from *w*:

    - half adder ``{prefix}.ha1`` adds ``a`` and ``b``;
    - half adder ``{prefix}.ha2`` adds that partial sum and ``cin``;
    - neuron ``{prefix}.carry_or`` combines the two half-adder carries.

    Args:
        a, b, cin: input bits. They are expected to be 0 or 1 and are
            converted with ``float``.
        w: mapping from weight names to tensors, e.g. the result of
            ``load_model``.
        prefix: key prefix that selects this adder's neurons, e.g. ``'fa0'``.

    Returns:
        ``(sum, cout)`` as Python ints.
    """
    s1, c1 = _half_adder(a, b, w, f'{prefix}.ha1')
    s, c2 = _half_adder(s1, cin, w, f'{prefix}.ha2')
    cout = _threshold(torch.tensor([c1, c2]), w, f'{prefix}.carry_or')
    # Callers expect plain ints, so convert from the helpers' 0.0/1.0 floats.
    return int(s), int(cout)
def ripple_carry_8bit(a, b, cin, weights):
    """Add two 8-bit numbers with a chain of threshold-neuron full adders.

    Bit position ``i`` uses the adder whose weights are prefixed ``fa{i}``.
    Its carry-out feeds the carry-in of position ``i + 1``.

    Args:
        a, b: bit lists (LSB first). Indices 0..7 are read.
        cin: carry into bit 0.
        weights: mapping of weight names to tensors (see ``load_model``).

    Returns:
        ``(sums, cout)``: ``sums`` holds the 8 sum bits (LSB first) and
        ``cout`` is the carry out of bit 7.
    """
    sums = []
    carry = cin
    for pos in range(8):
        bit, carry = full_adder(a[pos], b[pos], carry, weights, f'fa{pos}')
        sums.append(bit)
    return sums, carry
if __name__ == '__main__':
    # Demo: add a few operand pairs and print each 9-bit result
    # (8 sum bits plus the carry-out as bit 8).
    weights = load_model()
    print('8-bit Ripple Carry Adder')
    pairs = [(255, 1), (127, 128), (100, 55), (200, 200), (0, 0), (128, 128)]
    for x, y in pairs:
        x_bits = [(x >> pos) & 1 for pos in range(8)]
        y_bits = [(y >> pos) & 1 for pos in range(8)]
        sum_bits, carry_out = ripple_carry_8bit(x_bits, y_bits, 0, weights)
        total = carry_out << 8
        for pos, bit in enumerate(sum_bits):
            total += bit << pos
        print(f'{x:3d} + {y:3d} = {total:3d}')