# threshold-7segment / model.py
# Author: phanerozoic
# Upload folder using huggingface_hub
# Commit f85d9ac (verified)
import torch
from safetensors.torch import load_file
def load_model(path='model.safetensors'):
    """Read the decoder's parameters from a safetensors file.

    Args:
        path: filename passed straight to ``safetensors.torch.load_file``;
            a relative name is looked up from the current working directory.

    Returns:
        The mapping of parameter name -> tensor returned by ``load_file``.
    """
    state = load_file(path)
    return state
def segment_decode(b3, b2, b1, b0, weights):
    """Decode one BCD nibble into 7-segment outputs with a two-layer threshold net.

    Args:
        b3, b2, b1, b0: input bits, most significant first; each is converted
            with ``float()``.
        weights: mapping of parameter name -> tensor. Keys read are
            ``d0``..``d9`` (first layer) and ``a``..``g`` (second layer), each
            with a ``.weight`` and ``.bias`` suffix.
            NOTE(review): weights are used via ``.T``, so they are presumably
            2-D of shape (1, fan_in) -- confirm against the checkpoint.

    Returns:
        dict mapping each segment letter 'a'..'g' to 0 or 1.
    """
    def fire(x, name):
        # Heaviside step neuron: outputs 1 when w.x + b >= 0, else 0.
        pre_activation = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
        return int((pre_activation >= 0).item())

    bits = torch.tensor([float(b3), float(b2), float(b1), float(b0)])

    # Layer 1: one neuron per decimal digit 0-9, intended to detect which digit
    # the nibble encodes.
    detected = torch.tensor([float(fire(bits, f'd{n}')) for n in range(10)])

    # Layer 2: one neuron per segment, fed by the ten digit-detector outputs.
    return {seg: fire(detected, seg) for seg in 'abcdefg'}
def display_digit(segs):
    """Render a 7-segment pattern as five lines of ASCII art.

    ``segs`` maps segment letters to truthy/falsy values; all seven keys
    'a'..'g' are read. Horizontal segments (a top, g middle, d bottom) are
    drawn as ``___`` bars. Vertical segments are drawn as ``|``: f/b on the
    row between a and g, and e/c on the row between g and d. Every row is
    five characters wide.
    """
    def horizontal(key):
        return ' ' + ('_' * 3 if segs[key] else ' ' * 3) + ' '

    def vertical(left, right):
        left_char = '|' if segs[left] else ' '
        right_char = '|' if segs[right] else ' '
        return left_char + ' ' * 3 + right_char

    rows = (
        horizontal('a'),
        vertical('f', 'b'),
        horizontal('g'),
        vertical('e', 'c'),
        horizontal('d'),
    )
    return '\n'.join(rows)
if __name__ == '__main__':
    # Demo: decode every BCD digit 0-9, then print its segment bit pattern
    # (in a..g order) and its ASCII rendering.
    weights = load_model()
    print('7-Segment Display Decoder:')
    for digit in range(10):
        # Split the digit into its four BCD bits, most significant first.
        b3, b2, b1, b0 = ((digit >> shift) & 1 for shift in (3, 2, 1, 0))
        segs = segment_decode(b3, b2, b1, b0, weights)
        pattern = ''.join(str(segs[s]) for s in 'abcdefg')
        print(f'\nDigit {digit} ({b3}{b2}{b1}{b0}) -> {pattern}')
        print(display_digit(segs))