trueact
TrueMath / infer.py
CompactAI's picture
Upload 6 files
239cc86 verified
raw
history blame
2.01 kB
import torch
import torch.nn.functional as F
import sys
# Character vocabulary understood by the model (presumably must match the one
# used in training); each character owns one one-hot slot. The space is also
# used as the left-padding character for short contexts.
ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)
CHAR_TO_INT = {c: i for i, c in enumerate(ALPHABET)}
INT_TO_CHAR = {i: c for i, c in enumerate(ALPHABET)}
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
# Number of characters fed to the model per step.
CONTEXT_SIZE = 64
# Upper bound on the number of characters generated after the '='.
MAX_GEN = 60


def generate(model, prompt_text):
    """Greedily complete an arithmetic prompt such as ``"12+30="``.

    The prompt is normalised (spaces removed, characters outside ``ALPHABET``
    dropped) and cut just after its first ``'='``; anything after it is
    discarded and regenerated. The model is then queried one character at a
    time, always taking the highest-scoring character, until it emits a
    space, ``'|'``, an index outside ``ALPHABET``, or ``MAX_GEN`` characters
    have been produced.

    Args:
        model: Callable (e.g. a TorchScript module) taking a float tensor of
            shape ``(1, CONTEXT_SIZE * N_DIM)`` -- the flattened one-hot
            encoding of the context window. Its output is assumed to be
            ``(1, N_DIM)`` logits -- TODO confirm against the exported model.
        prompt_text: Raw user input.

    Returns:
        The original ``prompt_text`` up to and including its first ``'='``
        followed by the generated characters, or ``prompt_text + '?'`` when
        the prompt contains no ``'='``.
    """
    clean = prompt_text.replace(' ', '')
    stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
    if '=' not in stripped:
        return prompt_text + '?'
    prompt = stripped[:stripped.index('=') + 1]
    generated = ''
    with torch.no_grad():
        for _ in range(MAX_GEN):
            # Keep the MOST RECENT CONTEXT_SIZE characters, left-padded with
            # spaces when shorter. Slicing from the front would freeze the
            # window once prompt + output exceeds CONTEXT_SIZE, so the model
            # would stop seeing the characters it just produced.
            window = (prompt + generated)[-CONTEXT_SIZE:].rjust(CONTEXT_SIZE, ' ')
            indices = torch.tensor(
                [CHAR_TO_INT.get(c, SPACE_IDX) for c in window], dtype=torch.long
            )
            x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
            logits = model(x)
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            char = INT_TO_CHAR.get(next_token.item(), '')
            # An unknown index, a space or '|' ends the answer.
            if char in ('', ' ', '|'):
                break
            generated += char
    # '=' is guaranteed to be in prompt_text because `stripped` is a subset of it.
    return prompt_text[:prompt_text.index('=') + 1] + generated
def main():
    """Command-line entry point.

    Usage:
        python infer.py <model.pt> [prompt]

    With a prompt argument, prints a single completion. Without one, runs an
    interactive loop reading prompts from stdin until ``q``, EOF or Ctrl-C.
    Only completions are written to stdout; usage, status and error messages
    go to stderr.
    """
    if len(sys.argv) < 2:
        print("Usage: python infer.py <model.pt> [prompt]", file=sys.stderr)
        print(" python infer.py <model.pt> (interactive mode)", file=sys.stderr)
        sys.exit(1)
    model_path = sys.argv[1]
    # generate() builds its input tensors on the CPU, so restore the model on
    # the CPU as well; without map_location a module saved from a GPU is moved
    # back to that GPU (or fails to load on a CPU-only machine).
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)
    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return
    while True:
        try:
            prompt = input("Prompt > ")
            if prompt.lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Report and keep the session alive rather than exiting on one bad prompt.
            print(f"Error: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()