File size: 3,010 Bytes
239cc86
 
 
 
 
 
 
 
 
 
 
 
f1b3a60
239cc86
 
 
 
 
 
 
 
 
 
 
 
f1b3a60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239cc86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn.functional as F
import sys

# Character vocabulary of the model: a character's position is its token id.
# ' ' is the padding character for the fixed-size context window, and both
# ' ' and '|' stop generation when the model predicts them.
ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)  # vocabulary size; also the one-hot width per position
INT_TO_CHAR = dict(enumerate(ALPHABET))  # token id -> character
CHAR_TO_INT = {ch: tok for tok, ch in INT_TO_CHAR.items()}  # character -> token id
# Token id substituted for any character missing from ALPHABET
# (falls back to 0 if ' ' were ever removed from the vocabulary).
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
CONTEXT_SIZE = 64  # number of characters fed to the model per decoding step
MAX_GEN = 60  # upper bound on characters produced by one decoding run

def _run(model, prompt):
    """Greedily extend ``prompt`` one character at a time using ``model``.

    Each step takes the last CONTEXT_SIZE characters of the running text
    (left-padded with spaces when shorter) and encodes them as one-hot
    vectors over ALPHABET. These are flattened into a
    (1, CONTEXT_SIZE * N_DIM) float tensor, and the argmax of the model's
    logits becomes the next character.

    Args:
        model: Callable mapping that tensor to logits. Assumed to yield a
            single argmax (``.item()`` is called on it); TODO confirm against
            the exported model's output shape.
        prompt: Text to continue. Characters not in ALPHABET are encoded
            as the space token.

    Returns:
        ``prompt`` followed by the generated characters. Generation stops,
        without appending the terminator, on a space, on '|', or on an id
        missing from INT_TO_CHAR. It also stops after MAX_GEN characters.
    """
    generated = ''
    with torch.no_grad():
        for _ in range(MAX_GEN):
            # Keep the *newest* characters when the text outgrows the window.
            # Truncating from the front would freeze the model input after
            # CONTEXT_SIZE characters, so greedy decoding would repeat the
            # same prediction on every subsequent step.
            window = (prompt + generated)[-CONTEXT_SIZE:].rjust(CONTEXT_SIZE, ' ')
            indices = torch.tensor(
                [CHAR_TO_INT.get(c, SPACE_IDX) for c in window], dtype=torch.long
            )
            x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
            logits = model(x)
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            char = INT_TO_CHAR.get(next_token.item(), '')
            if not char or char == ' ' or char == '|':
                break
            generated += char
    return prompt + generated

def _count_operators(expr):
    count = 0
    for c in expr:
        if c in '+-*':
            count += 1
    return count

def _strip_outer_parens(expr):
    while expr and expr[0] == '(' and expr[-1] == ')':
        depth = 0
        balanced = True
        for i, c in enumerate(expr):
            if c == '(':
                depth += 1
            elif c == ')':
                depth -= 1
            if depth == 0 and i < len(expr) - 1:
                balanced = False
                break
        if balanced:
            expr = expr[1:-1]
        else:
            break
    return expr

def generate(model, prompt_text):
    """Complete an arithmetic prompt such as ``"3+4="`` with ``model``.

    The prompt is cleaned by removing spaces and dropping characters outside
    ALPHABET. If no '=' survives, ``prompt_text + '?'`` is returned.

    Otherwise the left-hand side, with redundant outer parentheses stripped,
    is inspected. If it contains exactly one '+', '-' or '*', the model is
    run on ``"0+(<lhs>)="`` and the text after the final '=' is taken as the
    answer. The answer is accepted only if it is non-empty, differs from the
    expression, and does not start with '-'. An accepted answer is appended
    to the raw ``prompt_text`` up to and including its first '='.

    In every other case the result is the cleaned ``"<lhs>="`` followed by
    the model's greedy continuation.
    """
    cleaned = ''.join(ch for ch in prompt_text.replace(' ', '') if ch in CHAR_TO_INT)
    eq_at = cleaned.find('=')
    if eq_at == -1:
        return prompt_text + '?'

    lhs = cleaned[:eq_at]
    core = _strip_outer_parens(lhs)

    if _count_operators(core) == 1:
        # NOTE(review): the "0+(...)" wrapper presumably steers the model into
        # its multi-step chain format; only the final '='-separated segment
        # is kept. Confirm the intent with the model's training data.
        tail = _run(model, f"0+({core})=").split('=')[-1].rstrip('| ')
        # Reject empty output, an echo of the input, and negative results;
        # those fall through to plain decoding below.
        if tail and tail != core and not tail.startswith('-'):
            return prompt_text[:prompt_text.index('=') + 1] + tail

    return _run(model, lhs + '=')

def main():
    """Command-line entry point: load a TorchScript model and complete prompts.

    Usage::

        python infer.py <model.pt> [prompt]

    With a prompt argument, one completion is printed to stdout. Without
    one, prompts are read interactively until 'q', EOF or Ctrl-C.
    Diagnostics (usage text, the load message and per-prompt errors) go to
    stderr, so stdout carries only completions and stays safe to pipe.
    Exits with status 1 when no model path is given.
    """
    if len(sys.argv) < 2:
        print("Usage: python infer.py <model.pt> [prompt]", file=sys.stderr)
        print("       python infer.py <model.pt> (interactive mode)", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    model = torch.jit.load(model_path)
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            # Tolerate stray whitespace around the quit command.
            if prompt.strip().lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Top-level REPL boundary: report the failure and keep the session alive.
            print(f"Error: {e}", file=sys.stderr)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()