| import torch |
| import torch.nn.functional as F |
| import sys |
|
|
# Every character the model can read or emit; a character's position in this
# string is its token index.
ALPHABET = "0123456789+-*=()| "
# Width of one one-hot encoded character (see _run).
N_DIM = len(ALPHABET)
# Index -> character and character -> index lookup tables.
INT_TO_CHAR = dict(enumerate(ALPHABET))
CHAR_TO_INT = {ch: idx for idx, ch in INT_TO_CHAR.items()}
# Index used for padding and for characters outside the alphabet.
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
# Number of characters fed to the model per prediction step.
CONTEXT_SIZE = 64
# Upper bound on the number of characters generated per prompt.
MAX_GEN = 60
|
|
def _run(model, prompt):
    """Greedily extend *prompt* one character at a time using *model*.

    At each step the current text is right-aligned into a window of
    CONTEXT_SIZE characters (left-padded with spaces), one-hot encoded and
    flattened to a single row of CONTEXT_SIZE * N_DIM floats, and passed to
    the model. The arg-max of the returned scores is taken as the next
    character.

    Generation stops after MAX_GEN characters, or as soon as the predicted
    character is a space, a '|', or an index missing from INT_TO_CHAR.

    Args:
        model: Callable taking the flattened one-hot tensor and returning
            scores whose arg-max over the last dimension is a single index
            (assumed shape (1, N_DIM) -- confirm against the trained model).
        prompt: Text to continue. Characters outside the alphabet are
            encoded as SPACE_IDX.

    Returns:
        The prompt followed by the generated characters.
    """
    pieces = []
    text = prompt
    with torch.no_grad():
        for _ in range(MAX_GEN):
            # Keep the LAST CONTEXT_SIZE characters: slicing from the front
            # would hide newly generated characters once the text outgrows
            # the window, making the model repeat the same prediction.
            window = text[-CONTEXT_SIZE:].rjust(CONTEXT_SIZE, ' ')
            indices = torch.tensor(
                [CHAR_TO_INT.get(c, SPACE_IDX) for c in window],
                dtype=torch.long,
            )
            x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
            logits = model(x)
            next_idx = torch.argmax(logits, dim=-1, keepdim=True).item()
            char = INT_TO_CHAR.get(next_idx, '')
            # Space and '|' act as end-of-answer markers.
            if not char or char == ' ' or char == '|':
                break
            pieces.append(char)
            text += char
    return prompt + ''.join(pieces)
|
|
| def _count_operators(expr): |
| count = 0 |
| for c in expr: |
| if c in '+-*': |
| count += 1 |
| return count |
|
|
| def _strip_outer_parens(expr): |
| while expr and expr[0] == '(' and expr[-1] == ')': |
| depth = 0 |
| balanced = True |
| for i, c in enumerate(expr): |
| if c == '(': |
| depth += 1 |
| elif c == ')': |
| depth -= 1 |
| if depth == 0 and i < len(expr) - 1: |
| balanced = False |
| break |
| if balanced: |
| expr = expr[1:-1] |
| else: |
| break |
| return expr |
|
|
def generate(model, prompt_text):
    """Produce the model's completion for an arithmetic prompt.

    Spaces and characters outside the alphabet are dropped from
    *prompt_text*. If no '=' remains, the original text with '?' appended
    is returned. Otherwise only the part up to and including the first '='
    is used as the prompt.

    When the left-hand side (with enclosing parentheses removed) contains
    exactly one operator character, the model is first queried with the
    rewritten prompt "0+(<expr>)=". The text after the last '=' of that
    output is accepted as the answer if it is non-empty, differs from the
    expression and does not start with '-'. In that case the result is the
    original text up to its first '=' followed by the answer.
    NOTE(review): the reason for the "0+(...)" rewrite is not visible
    here -- presumably a form the model handles better; confirm.

    In every other case the cleaned prompt is passed to the model directly
    and its full output (prompt plus generated text) is returned.

    Args:
        model: Model callable passed to _run.
        prompt_text: Raw user input, e.g. "3 + 4 =".

    Returns:
        The completed string.
    """
    filtered = ''.join(
        ch for ch in prompt_text if ch != ' ' and ch in CHAR_TO_INT
    )
    lhs, eq_sign, _ = filtered.partition('=')
    if not eq_sign:
        return prompt_text + '?'

    inner = _strip_outer_parens(lhs)
    if _count_operators(inner) == 1:
        chain = _run(model, f"0+({inner})=")
        answer = chain.rpartition('=')[2].rstrip('| ')
        rejected = (not answer) or answer == inner or answer.startswith('-')
        if not rejected:
            # Keep the caller's own formatting of the left-hand side.
            head = prompt_text.partition('=')[0]
            return head + '=' + answer

    return _run(model, lhs + '=')
|
|
def main():
    """Command-line entry point.

    Usage:
        python infer.py <model.pt> [prompt]

    Loads the TorchScript model given as the first argument. With a second
    argument, prints the completion for that prompt and returns. Without
    one, reads prompts interactively until 'q', end-of-input or Ctrl-C.

    Results go to stdout; usage, status and error messages go to stderr.
    Exits with status 1 when no model path is supplied.
    """
    if len(sys.argv) < 2:
        print("Usage: python infer.py <model.pt> [prompt]", file=sys.stderr)
        print(" python infer.py <model.pt> (interactive mode)", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    # Inputs are built on the CPU in _run, so load the model there too;
    # otherwise a model saved from a CUDA device fails to load or mismatches.
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            if prompt.lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Top-level boundary: report and keep the session alive.
            print(f"Error: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()
|
|