"""Greedy character-level inference for a TorchScript arithmetic model.

Run with a model path and an optional prompt; without a prompt the script
enters an interactive loop (enter ``q`` to quit).
"""

import sys

import torch
import torch.nn.functional as F

# Vocabulary: one class per character. '|' and ' ' double as stop symbols
# during generation (see ``generate``).
ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)
CHAR_TO_INT = {c: i for i, c in enumerate(ALPHABET)}
INT_TO_CHAR = {i: c for i, c in enumerate(ALPHABET)}
SPACE_IDX = CHAR_TO_INT.get(' ', 0)

CONTEXT_SIZE = 64  # number of characters fed to the model per step
MAX_GEN = 60       # upper bound on generated characters

# Predicting any of these characters ends generation.
_STOP_CHARS = frozenset(' |')


def generate(model, prompt_text: str) -> str:
    """Greedily complete an arithmetic prompt of the form ``<expr>=``.

    Spaces and characters outside ``ALPHABET`` are removed before the prompt
    is fed to the model; anything after the first ``=`` is ignored.

    Args:
        model: Callable taking a float tensor of shape
            ``(1, CONTEXT_SIZE * N_DIM)`` (flattened one-hot window) and
            returning logits whose argmax over the last dimension is a single
            class index. (Output shape is assumed to be ``(1, N_DIM)`` --
            confirm against the trained model.)
        prompt_text: Raw user prompt.

    Returns:
        The original prompt up to and including its first ``=`` followed by
        the generated characters, or ``prompt_text + '?'`` when the cleaned
        prompt contains no ``=``.
    """
    clean = prompt_text.replace(' ', '')
    stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
    if '=' not in stripped:
        return prompt_text + '?'

    prompt = stripped[:stripped.index('=') + 1]
    generated = []

    with torch.no_grad():
        for _ in range(MAX_GEN):
            text = prompt + ''.join(generated)
            # Keep the MOST RECENT characters when the text outgrows the
            # window, then left-pad short text with spaces. Slicing the head
            # instead would hide the '=' and every new character from the
            # model once the text exceeds CONTEXT_SIZE.
            window = text[-CONTEXT_SIZE:].rjust(CONTEXT_SIZE, ' ')
            indices = torch.tensor(
                [CHAR_TO_INT.get(c, SPACE_IDX) for c in window],
                dtype=torch.long,
            )
            x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)

            logits = model(x)
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            char = INT_TO_CHAR.get(next_token.item(), '')
            # An unknown index or a stop symbol terminates the answer.
            if not char or char in _STOP_CHARS:
                break
            generated.append(char)

    return prompt_text[:prompt_text.index('=') + 1] + ''.join(generated)


def main():
    """Command-line entry point.

    ``sys.argv[1]`` is the TorchScript model path. If ``sys.argv[2]`` is
    present it is completed once and printed; otherwise prompts are read
    interactively until ``q``, EOF, or Ctrl-C.
    """
    if len(sys.argv) < 2:
        # Diagnostics go to stderr so stdout carries only model output.
        print("Usage: python infer.py <model_path> [prompt]", file=sys.stderr)
        print(
            "       python infer.py <model_path>            (interactive mode)",
            file=sys.stderr,
        )
        sys.exit(1)

    model_path = sys.argv[1]
    model = torch.jit.load(model_path)
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            if prompt.lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Keep the interactive session alive after a bad prompt.
            print(f"Error: {e}")


if __name__ == "__main__":
    main()