File size: 3,010 Bytes
239cc86 f1b3a60 239cc86 f1b3a60 239cc86 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | import torch
import torch.nn.functional as F
import sys
# Model vocabulary: a character's position in this string is its token id.
ALPHABET = "0123456789+-*=()| "
N_DIM = len(ALPHABET)  # width of each one-hot character vector

# Token id -> character, and the inverse character -> token id.
INT_TO_CHAR = dict(enumerate(ALPHABET))
CHAR_TO_INT = {ch: idx for idx, ch in INT_TO_CHAR.items()}

# Id used to encode padding and any character outside ALPHABET.
SPACE_IDX = CHAR_TO_INT.get(' ', 0)

CONTEXT_SIZE = 64  # number of characters fed to the model per step
MAX_GEN = 60  # upper bound on characters generated per call
def _run(model, prompt):
generated = ''
for _ in range(MAX_GEN):
prefix = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[:CONTEXT_SIZE]
indices = torch.tensor([CHAR_TO_INT.get(c, SPACE_IDX) for c in prefix], dtype=torch.long)
x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
with torch.no_grad():
logits = model(x)
next_token = torch.argmax(logits, dim=-1, keepdim=True)
char = INT_TO_CHAR.get(next_token.item(), '')
if not char or char == ' ' or char == '|':
break
generated += char
return prompt + generated
def _count_operators(expr):
count = 0
for c in expr:
if c in '+-*':
count += 1
return count
def _strip_outer_parens(expr):
while expr and expr[0] == '(' and expr[-1] == ')':
depth = 0
balanced = True
for i, c in enumerate(expr):
if c == '(':
depth += 1
elif c == ')':
depth -= 1
if depth == 0 and i < len(expr) - 1:
balanced = False
break
if balanced:
expr = expr[1:-1]
else:
break
return expr
def generate(model, prompt_text):
    """Complete an arithmetic prompt such as "3*4=" using *model*.

    The prompt is cleaned by dropping spaces and any character not in
    CHAR_TO_INT. If no '=' remains, ``prompt_text + '?'`` is returned.

    Single-operator path: after outer parentheses are stripped, the left-hand
    side may contain exactly one '+', '-' or '*'. In that case the model is
    first run on "0+(<expr>)=" and the text after the last '=' is taken as the
    answer. It is used only if it is non-empty, differs from the expression,
    and does not start with '-'. The result is then the original *prompt_text*
    up to its first '=', followed by the answer. (The "0+(...)" wrapping
    presumably matches the model's training format; confirm.)

    Otherwise the model is run directly on the cleaned "<lhs>=", and its
    output (which starts with the cleaned prompt) is returned.
    """
    cleaned = ''.join(ch for ch in prompt_text.replace(' ', '') if ch in CHAR_TO_INT)
    eq_pos = cleaned.find('=')
    if eq_pos < 0:
        return prompt_text + '?'

    lhs = cleaned[:eq_pos]
    direct_prompt = lhs + '='
    expr = _strip_outer_parens(lhs)

    if _count_operators(expr) == 1:
        answer = _run(model, f"0+({expr})=").split('=')[-1].rstrip('| ')
        usable = answer and answer != expr and not answer.startswith('-')
        if usable:
            head = prompt_text[:prompt_text.index('=') + 1]
            return head + answer

    return _run(model, direct_prompt)
def main():
    """Command-line entry point: load a TorchScript model and answer prompts.

    Usage:
        python infer.py <model.pt> [prompt]

    With a prompt argument, prints one completion and returns. Without one,
    reads prompts interactively until 'q', EOF, or Ctrl-C. Diagnostics
    (usage, load notice, per-prompt errors) go to stderr, so stdout carries
    only completions.
    """
    if len(sys.argv) < 2:
        # Usage is a diagnostic: send it to stderr like the load notice below.
        print("Usage: python infer.py <model.pt> [prompt]", file=sys.stderr)
        print(" python infer.py <model.pt> (interactive mode)", file=sys.stderr)
        sys.exit(1)

    model_path = sys.argv[1]
    model = torch.jit.load(model_path)
    model.eval()
    print(f"Loaded model from {model_path}", file=sys.stderr)

    if len(sys.argv) >= 3:
        print(generate(model, sys.argv[2]))
        return

    while True:
        try:
            prompt = input("Prompt > ")
            # Tolerate stray whitespace around the quit command.
            if prompt.strip().lower() == 'q':
                break
            print(generate(model, prompt))
        except (EOFError, KeyboardInterrupt):
            break
        except Exception as e:
            # Top-level REPL boundary: report the error and keep accepting
            # prompts. Errors go to stderr so they don't mix with completions.
            print(f"Error: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()
|