# TrueMath / infer.py
# Author: CompactAI
# Upload infer.py (commit f1b3a60, verified)
import torch
import torch.nn.functional as F
import sys
# Character vocabulary: digits, the arithmetic operators, '=', parentheses,
# '|' and ' '. A character's position in this string is its token index.
ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)  # one-hot width used for each character
INT_TO_CHAR = dict(enumerate(ALPHABET))
CHAR_TO_INT = {ch: idx for idx, ch in INT_TO_CHAR.items()}
# Index substituted for characters outside ALPHABET (0 only if ' ' were absent).
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
# Model input window length, in characters.
CONTEXT_SIZE = 64
# Upper bound on characters generated per model query.
MAX_GEN = 60
def _run(model, prompt):
    """Greedily extend ``prompt`` one character at a time using ``model``.

    On every step the current text (``prompt`` plus everything generated so
    far) is right-aligned in a window of ``CONTEXT_SIZE`` characters
    (left-padded with spaces), one-hot encoded over ``ALPHABET`` and flattened
    to a single row of shape ``(1, CONTEXT_SIZE * N_DIM)``. The arg-max of
    the model's logits is taken as the next character.

    Generation stops after ``MAX_GEN`` characters, or earlier when the model
    predicts ``' '``, ``'|'`` or an index that is not in ``INT_TO_CHAR``.

    Args:
        model: Callable mapping the flattened one-hot window to logits over
            the alphabet. Presumably it returns shape ``(1, N_DIM)``; this is
            not confirmed here, but ``.item()`` requires a single arg-max.
        prompt: Text to continue. Characters outside ``ALPHABET`` are
            encoded as spaces.

    Returns:
        ``prompt`` with the generated characters appended.
    """
    generated = ''
    # Inference only: no autograd bookkeeping for any step.
    with torch.no_grad():
        for _ in range(MAX_GEN):
            # The text is right-aligned, so once it outgrows the window, keep
            # the *last* CONTEXT_SIZE characters. Slicing from the front would
            # drop the newest characters, freeze the input and make the model
            # repeat the same prediction.
            window = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[-CONTEXT_SIZE:]
            indices = torch.tensor(
                [CHAR_TO_INT.get(c, SPACE_IDX) for c in window], dtype=torch.long
            )
            x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
            logits = model(x)
            next_idx = torch.argmax(logits, dim=-1).item()
            char = INT_TO_CHAR.get(next_idx, '')
            if not char or char == ' ' or char == '|':
                break
            generated += char
    return prompt + generated
def _count_operators(expr):
count = 0
for c in expr:
if c in '+-*':
count += 1
return count
def _strip_outer_parens(expr):
while expr and expr[0] == '(' and expr[-1] == ')':
depth = 0
balanced = True
for i, c in enumerate(expr):
if c == '(':
depth += 1
elif c == ')':
depth -= 1
if depth == 0 and i < len(expr) - 1:
balanced = False
break
if balanced:
expr = expr[1:-1]
else:
break
return expr
def generate(model, prompt_text):
    """Answer an arithmetic prompt such as ``"2 + 3 ="`` using ``model``.

    Normalisation and the query sent to the model:

    - The prompt is normalised by removing spaces and every character that is
      not in ``ALPHABET``.
    - Only the text up to and including the first ``'='`` is sent to the
      model.

    Single-operator path. When the left-hand side, after dropping redundant
    outer parentheses, contains exactly one ``+``/``-``/``*`` character, the
    model is first queried with the rewritten prompt ``"0+(<expr>)="``. The
    reason for this rewrite is not visible here; presumably it is a format
    the model was trained on (assumption, confirm). The text after the last
    ``'='`` of that output is accepted as the answer only if it:

    - is non-empty;
    - differs from the expression itself;
    - does not start with ``'-'``.

    An accepted answer is appended to the caller's *original* text up to and
    including its first ``'='``.

    Fallback. In every other case the model continues the normalised prompt,
    and its whole output (prompt plus generated characters) is returned.

    Returns:
        The completed text, or ``prompt_text + '?'`` when the prompt contains
        no ``'='``.
    """
    normalised = ''.join(
        ch for ch in prompt_text.replace(' ', '') if ch in CHAR_TO_INT
    )
    lhs, eq_sign, _ = normalised.partition('=')
    if not eq_sign:
        return prompt_text + '?'

    expr = _strip_outer_parens(lhs)
    if _count_operators(expr) == 1:
        completion = _run(model, f"0+({expr})=")
        answer = completion.rsplit('=', 1)[-1].rstrip('| ')
        if answer and answer != expr and not answer.startswith('-'):
            # '=' survives normalisation, so prompt_text is known to contain it.
            original_lhs = prompt_text.partition('=')[0]
            return original_lhs + '=' + answer

    return _run(model, lhs + '=')
def main():
    """Command-line entry point.

    Usage::

        python infer.py <model.pt> [prompt]

    Loads a TorchScript model from ``<model.pt>`` onto the CPU.

    - With a prompt argument, it prints the single result of ``generate``.
    - Without one, it reads prompts interactively until ``q``, EOF or Ctrl-C.

    Exits with status 1 when no model path is given.
    """
    if len(sys.argv) < 2:
        # Usage errors go to stderr so stdout carries only results.
        print("Usage: python infer.py <model.pt> [prompt]", file=sys.stderr)
        print(" python infer.py <model.pt> (interactive mode)", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    # Inputs are built as CPU tensors (no device is given when they are
    # created), so load the weights onto the CPU as well. Without
    # map_location, a model saved from a GPU is restored to CUDA, or fails
    # to load on a CPU-only machine.
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            if prompt.strip().lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Keep the interactive session alive on a bad prompt, but report
            # the problem on stderr rather than mixing it into the results.
            print(f"Error: {e}", file=sys.stderr)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()