File size: 1,697 Bytes
f85d9ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Load every tensor stored in the safetensors file at *path*.

    Thin wrapper around ``safetensors.torch.load_file``; the returned mapping
    goes from tensor name to tensor and is what ``segment_decode`` expects as
    its ``weights`` argument.
    """
    state_dict = load_file(path)
    return state_dict

def _threshold_unit(x, weights, name):
    """Evaluate the Heaviside threshold neuron *name* on the 1-D input *x*.

    Reads ``weights[f'{name}.weight']`` and ``weights[f'{name}.bias']``, casts
    both to ``x``'s dtype (so weights saved in another float precision still
    work), and returns 1 if ``weight @ x + bias >= 0`` else 0.

    The weighted sum must reduce to a single element, i.e. the weight is either
    a ``(1, len(x))`` matrix or a ``(len(x),)`` vector; otherwise ``.item()``
    raises. A missing key raises ``KeyError``.
    """
    weight = weights[f'{name}.weight'].to(x.dtype)
    bias = weights[f'{name}.bias'].to(x.dtype)
    # `weight @ x` avoids `.T`, which is deprecated on 1-D tensors.
    return int((weight @ x + bias >= 0).item())

def segment_decode(b3, b2, b1, b0, weights):
    """BCD to 7-segment decoder. Returns dict with segments a-g.

    Args:
        b3, b2, b1, b0: The four BCD input bits, most significant first. Each is
            converted with ``float()``.
        weights: Mapping of tensor name to tensor (as returned by
            ``load_model``). It must contain ``d0..d9`` and ``a..g`` entries,
            each with ``.weight`` and ``.bias`` tensors.

    Returns:
        dict mapping each segment name ``'a'``..``'g'`` to 0 or 1.

    Raises:
        KeyError: if a required weight/bias tensor is missing.
    """
    inp = torch.tensor([float(b3), float(b2), float(b1), float(b0)])

    # Layer 1: one threshold unit per decimal digit d0..d9. Presumably at most
    # one of them fires for a valid BCD code (depends on the trained weights).
    digit_vec = torch.tensor(
        [float(_threshold_unit(inp, weights, f'd{d}')) for d in range(10)]
    )

    # Layer 2: one threshold unit per segment, fed by the digit-detector outputs.
    return {seg: _threshold_unit(digit_vec, weights, seg) for seg in 'abcdefg'}

def display_digit(segs):
    """Render a segment mapping as a five-row ASCII-art 7-segment digit.

    *segs* must provide keys ``'a'``..``'g'``; a truthy value means the segment
    is lit. Rows, top to bottom: a; f and b; g; e and c; d. Each row is five
    characters wide and the rows are joined with newlines.
    """
    def horizontal(key):
        # Lit horizontal segment: three underscores, padded by one space per side.
        fill = '_' * 3 if segs[key] else ' ' * 3
        return ' ' + fill + ' '

    def vertical(left_key, right_key):
        # Two vertical segments flanking a three-space gap.
        left = '|' if segs[left_key] else ' '
        right = '|' if segs[right_key] else ' '
        return left + ' ' * 3 + right

    top = horizontal('a')
    upper = vertical('f', 'b')
    middle = horizontal('g')
    lower = vertical('e', 'c')
    bottom = horizontal('d')
    return '\n'.join((top, upper, middle, lower, bottom))

if __name__ == '__main__':
    # Demo: decode each BCD digit 0-9 with the saved model and draw the result.
    model_weights = load_model()
    print('7-Segment Display Decoder:')
    for digit in range(10):
        # Four-character binary string, most significant bit first -> b3..b0.
        b3, b2, b1, b0 = (int(bit) for bit in f'{digit:04b}')
        segments = segment_decode(b3, b2, b1, b0, model_weights)
        pattern = ''.join(str(segments[name]) for name in 'abcdefg')
        print(f'\nDigit {digit} ({b3}{b2}{b1}{b0}) -> {pattern}')
        print(display_digit(segments))