|
|
import torch
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Read every tensor stored in a safetensors file.

    Args:
        path: Filesystem location of the ``.safetensors`` file.

    Returns:
        The object produced by ``safetensors.torch.load_file`` for *path*
        (a mapping from tensor name to tensor).
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
|
def _fire(x, weights, name):
    """Evaluate threshold unit *name* on vector *x*; return 1 if it fires, else 0."""
    weight = weights[f'{name}.weight']
    bias = weights[f'{name}.bias']
    # torch matmul does not type-promote, so weights stored as float16/float64/
    # integers in the model file would clash with the float32 input. Promote all
    # operands to a common dtype (never narrower than the float32 input).
    dtype = torch.promote_types(
        x.dtype, torch.promote_types(weight.dtype, bias.dtype)
    )
    # weight @ x equals x @ weight.T for a (1, n) weight and avoids the
    # transpose, which is deprecated for non-2-D tensors.
    pre_activation = weight.to(dtype) @ x.to(dtype) + bias.to(dtype)
    return int((pre_activation >= 0).item())


def segment_decode(b3, b2, b1, b0, weights):
    """BCD to 7-segment decoder. Returns dict with segments a-g.

    The decoder is two layers of threshold units whose parameters are read
    from ``weights``:

    * units ``d0`` .. ``d9`` read the input vector ``[b3, b2, b1, b0]``;
    * units ``a`` .. ``g`` read the ten ``d*`` outputs (in order d0..d9).

    A unit outputs 1 when ``weight @ x + bias >= 0`` and 0 otherwise.

    Args:
        b3, b2, b1, b0: Input bits, each converted with ``float()``. The
            ``__main__`` loop passes b3 as the most significant bit.
        weights: Mapping containing ``'<unit>.weight'`` and ``'<unit>.bias'``
            tensors for every unit listed above. Each unit presumably has a
            single output row (``.item()`` requires one element) -- TODO
            confirm against the model file.

    Returns:
        dict mapping each segment letter ``'a'`` .. ``'g'`` (in that order)
        to 0 or 1.

    Raises:
        KeyError: if a required tensor name is missing from ``weights``.
    """
    bits = torch.tensor([float(b3), float(b2), float(b1), float(b0)])

    # Hidden layer: one unit per decimal digit.
    digit_vec = torch.tensor(
        [float(_fire(bits, weights, f'd{k}')) for k in range(10)]
    )

    # Output layer: one unit per display segment.
    return {seg: _fire(digit_vec, weights, seg) for seg in 'abcdefg'}
|
|
|
|
|
|
def display_digit(segs):
    """Render a 7-segment pattern as five rows of ASCII art.

    Rows, top to bottom: bar ``a``; verticals ``f`` (left) and ``b``
    (right); bar ``g``; verticals ``e`` (left) and ``c`` (right); bar ``d``.

    Args:
        segs: Mapping from segment letter ``'a'`` .. ``'g'`` to a value whose
            truthiness decides whether that segment is drawn.

    Returns:
        The five rows, each five characters wide, joined with newlines.
    """
    def bar(lit):
        # Horizontal segment: padded three-character underscore run.
        return ' ' + ('___' if lit else '   ') + ' '

    def posts(left, right):
        # Pair of vertical segments separated by a three-space gap.
        return ('|' if left else ' ') + '   ' + ('|' if right else ' ')

    rows = [
        bar(segs['a']),
        posts(segs['f'], segs['b']),
        bar(segs['g']),
        posts(segs['e'], segs['c']),
        bar(segs['d']),
    ]
    return '\n'.join(rows)
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: decode digits 0-9 with the default model file and print each
    # bit pattern, the resulting segment string, and its ASCII rendering.
    weights = load_model()
    print('7-Segment Display Decoder:')
    for value in range(10):
        # Split the digit into 4 bits, most significant first.
        bits = tuple((value >> shift) & 1 for shift in (3, 2, 1, 0))
        segs = segment_decode(*bits, weights)
        code = ''.join(str(segs[name]) for name in 'abcdefg')
        bit_text = ''.join(str(bit) for bit in bits)
        print(f'\nDigit {value} ({bit_text}) -> {code}')
        print(display_digit(segs))
|
|
|
|