import torch

try:
    from safetensors.torch import load_file
except ImportError:  # safetensors is only needed to read weights from disk
    load_file = None


# Segment letters in the order the decoder reports them.
_SEGMENTS = ('a', 'b', 'c', 'd', 'e', 'f', 'g')


def load_model(path='model.safetensors'):
    """Load the decoder weights from a safetensors file.

    Args:
        path: path to the ``.safetensors`` file.

    Returns:
        The mapping of tensor names to tensors returned by
        ``safetensors.torch.load_file``.

    Raises:
        ImportError: if the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError("load_model requires the 'safetensors' package")
    return load_file(path)


def _fire(x, weight, bias):
    """Return 1 if the threshold unit ``weight . x + bias >= 0`` fires, else 0.

    ``weight`` is flattened, so both ``(n,)`` and ``(1, n)`` layouts work
    without calling ``.T`` on a 1-D tensor, which PyTorch deprecates.
    The parameters are cast to ``x``'s dtype so ``@`` does not fail on a
    dtype mismatch (for example, float16 weights against a float32 input).
    """
    w = weight.reshape(-1).to(x.dtype)
    b = bias.to(x.dtype)
    return int((x @ w + b >= 0).item())


def segment_decode(b3, b2, b1, b0, weights):
    """BCD to 7-segment decoder. Returns dict with segments a-g.

    Evaluates a two-layer network of threshold (step) units:
    ten layer-1 units ``d0``..``d9`` read the four input bits, and
    seven layer-2 units ``a``..``g`` read the ten layer-1 outputs.

    Args:
        b3, b2, b1, b0: input bits, most significant first. Each is
            converted with ``float``; they are expected to be 0 or 1.
        weights: mapping with ``'d{0..9}.weight'``/``'d{0..9}.bias'`` and
            ``'{a..g}.weight'``/``'{a..g}.bias'`` tensors. Each bias must
            hold a single element.

    Returns:
        dict mapping each segment letter ``'a'``..``'g'`` to 0 or 1.
    """
    inp = torch.tensor([float(b3), float(b2), float(b1), float(b0)])

    # Layer 1: one unit per decimal digit (presumably one-hot "is digit d"
    # detectors; that depends on the trained weights).
    digit_vec = torch.tensor([
        float(_fire(inp, weights[f'd{d}.weight'], weights[f'd{d}.bias']))
        for d in range(10)
    ])

    # Layer 2: one unit per segment, fed by the ten digit outputs.
    return {
        seg: _fire(digit_vec, weights[f'{seg}.weight'], weights[f'{seg}.bias'])
        for seg in _SEGMENTS
    }


def display_digit(segs):
    """ASCII art display of 7-segment pattern.

    Renders five rows: top (a), upper verticals (f left, b right),
    middle (g), lower verticals (e left, c right), and bottom (d).
    Segment values are tested for truthiness.

    Args:
        segs: mapping with keys ``'a'``..``'g'``.

    Returns:
        The five rows joined with newlines.
    """
    def horizontal(on):
        return ' ' + ('___' if on else '   ') + ' '

    def vertical(left, right):
        return ('|' if left else ' ') + '   ' + ('|' if right else ' ')

    rows = [
        horizontal(segs['a']),
        vertical(segs['f'], segs['b']),
        horizontal(segs['g']),
        vertical(segs['e'], segs['c']),
        horizontal(segs['d']),
    ]
    return '\n'.join(rows)


def main(path='model.safetensors'):
    """Decode and print every BCD digit 0-9 using weights loaded from *path*."""
    w = load_model(path)
    print('7-Segment Display Decoder:')
    for digit in range(10):
        b3, b2, b1, b0 = (digit >> 3) & 1, (digit >> 2) & 1, (digit >> 1) & 1, digit & 1
        result = segment_decode(b3, b2, b1, b0, w)
        pattern = ''.join(str(result[s]) for s in 'abcdefg')
        print(f'\nDigit {digit} ({b3}{b2}{b1}{b0}) -> {pattern}')
        print(display_digit(result))


if __name__ == '__main__':
    main()