"""Greedy character-level inference for a TorchScript model over a small arithmetic alphabet.

Usage:
    python infer.py <model_path> [prompt]   # complete one prompt and exit
    python infer.py <model_path>            # interactive mode
"""
import sys

import torch
import torch.nn.functional as F

ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)
CHAR_TO_INT = {c: i for i, c in enumerate(ALPHABET)}
INT_TO_CHAR = {i: c for i, c in enumerate(ALPHABET)}
# Characters outside ALPHABET are encoded as a space.
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
# Number of characters fed to the model at each generation step.
CONTEXT_SIZE = 64
# Maximum number of characters generated by a single _run call.
MAX_GEN = 60


def _run(model, prompt):
    """Greedily extend ``prompt`` one character at a time using ``model``.

    At each step the most recent CONTEXT_SIZE characters of the text so far
    (left-padded with spaces when shorter) are one-hot encoded, flattened to a
    ``(1, CONTEXT_SIZE * N_DIM)`` float tensor and passed to ``model``; the
    argmax of the returned logits is taken as the next character.

    Generation stops after MAX_GEN characters, or as soon as the model emits a
    space, a ``'|'`` or an index that is not in INT_TO_CHAR. The stop
    character itself is not appended.

    Returns:
        ``prompt`` followed by the generated characters.
    """
    generated = ''
    for _ in range(MAX_GEN):
        text = prompt + generated
        # The model sees right-aligned text (rjust pads on the left), so the
        # newest character sits in the last slot. When the text outgrows the
        # window, drop the *oldest* characters; slicing from the front would
        # freeze the input and make the model repeat the same token.
        prefix = text.rjust(CONTEXT_SIZE, ' ')[-CONTEXT_SIZE:]
        indices = torch.tensor(
            [CHAR_TO_INT.get(c, SPACE_IDX) for c in prefix], dtype=torch.long
        )
        x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
        with torch.no_grad():
            logits = model(x)
        # NOTE(review): .item() requires a single argmax, i.e. the model is
        # assumed to return one row of N_DIM scores such as (1, N_DIM) — confirm.
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        char = INT_TO_CHAR.get(next_token.item(), '')
        if not char or char in (' ', '|'):
            break
        generated += char
    return prompt + generated


def _count_operators(expr):
    """Return how many of the characters '+', '-' and '*' occur in ``expr``.

    Every occurrence counts, including a leading (unary) minus.
    """
    return sum(1 for c in expr if c in '+-*')


def _strip_outer_parens(expr):
    """Repeatedly remove a pair of parentheses that wraps the whole of ``expr``.

    ``"((1+2))"`` becomes ``"1+2"``, while ``"(1)+(2)"`` is returned unchanged
    because its first ``'('`` is closed before the end of the string.
    """
    while expr and expr[0] == '(' and expr[-1] == ')':
        depth = 0
        wraps_whole = True
        for i, c in enumerate(expr):
            if c == '(':
                depth += 1
            elif c == ')':
                depth -= 1
            # Depth returning to zero before the last character means the
            # opening '(' closes early, so the outer pair is not a wrapper.
            if depth == 0 and i < len(expr) - 1:
                wraps_whole = False
                break
        if not wraps_whole:
            break
        expr = expr[1:-1]
    return expr


def generate(model, prompt_text):
    """Complete an arithmetic prompt such as ``"3 + 4 ="`` with ``model``.

    Spaces and characters outside ALPHABET are removed first. If no ``'='``
    remains, ``prompt_text + '?'`` is returned without calling the model.

    When the left-hand side (after stripping redundant outer parentheses)
    contains exactly one operator, the model is first run on the rewritten
    prompt ``"0+(<expr>)="``. The text after the last ``'='`` of that output
    is used as the answer if it is non-empty, differs from the expression and
    does not start with ``'-'``. In that case the original ``prompt_text`` up
    to and including its first ``'='`` is returned with the answer appended.

    Otherwise the cleaned prompt (up to and including the first ``'='``) is
    completed directly, and the cleaned prompt plus generated text is returned.
    """
    clean = prompt_text.replace(' ', '')
    stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
    if '=' not in stripped:
        return prompt_text + '?'

    eq_pos = stripped.index('=')
    prompt = stripped[:eq_pos + 1]
    expr = _strip_outer_parens(stripped[:eq_pos])

    if _count_operators(expr) == 1:
        # NOTE(review): the "0+(...)=" rewrite presumably matches a prompt
        # shape the model completes more reliably — confirm with its training data.
        chain = _run(model, f"0+({expr})=")
        answer = chain.split('=')[-1].rstrip('| ')
        if answer and answer != expr and not answer.startswith('-'):
            # Echo the caller's own text (including spacing) up to its first '='.
            return prompt_text[:prompt_text.index('=') + 1] + answer

    return _run(model, prompt)


def main():
    """Command-line entry point: ``infer.py <model_path> [prompt]``.

    With a prompt argument, prints one completion. Without one, reads prompts
    interactively until ``q``, EOF or Ctrl-C; any other exception raised while
    handling a prompt is printed and the loop continues.
    """
    if len(sys.argv) < 2:
        print("Usage: python infer.py <model_path> [prompt]")
        print("       python infer.py <model_path>            (interactive mode)")
        sys.exit(1)

    model_path = sys.argv[1]
    # _run builds its input tensors on the CPU, so load the model there too;
    # a model serialized from a GPU would otherwise hit a device mismatch.
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            if prompt.lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Top-level REPL boundary: report and keep accepting prompts.
            print(f"Error: {e}")


if __name__ == "__main__":
    main()