CompactAI commited on
Commit
f1b3a60
·
verified ·
1 Parent(s): e15a52f

Upload infer.py

Browse files
Files changed (1) hide show
  1. infer.py +43 -7
infer.py CHANGED
@@ -10,12 +10,7 @@ SPACE_IDX = CHAR_TO_INT.get(' ', 0)
10
  CONTEXT_SIZE = 64
11
  MAX_GEN = 60
12
 
13
- def generate(model, prompt_text):
14
- clean = prompt_text.replace(' ', '')
15
- stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
16
- if '=' not in stripped:
17
- return prompt_text + '?'
18
- prompt = stripped[:stripped.index('=') + 1]
19
  generated = ''
20
  for _ in range(MAX_GEN):
21
  prefix = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[:CONTEXT_SIZE]
@@ -28,7 +23,48 @@ def generate(model, prompt_text):
28
  if not char or char == ' ' or char == '|':
29
  break
30
  generated += char
31
- return prompt_text[:prompt_text.index('=') + 1] + generated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def main():
34
  if len(sys.argv) < 2:
 
10
  CONTEXT_SIZE = 64
11
  MAX_GEN = 60
12
 
13
+ def _run(model, prompt):
 
 
 
 
 
14
  generated = ''
15
  for _ in range(MAX_GEN):
16
  prefix = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[:CONTEXT_SIZE]
 
23
  if not char or char == ' ' or char == '|':
24
  break
25
  generated += char
26
+ return prompt + generated
27
+
28
+ def _count_operators(expr):
29
+ count = 0
30
+ for c in expr:
31
+ if c in '+-*':
32
+ count += 1
33
+ return count
34
+
35
+ def _strip_outer_parens(expr):
36
+ while expr and expr[0] == '(' and expr[-1] == ')':
37
+ depth = 0
38
+ balanced = True
39
+ for i, c in enumerate(expr):
40
+ if c == '(':
41
+ depth += 1
42
+ elif c == ')':
43
+ depth -= 1
44
+ if depth == 0 and i < len(expr) - 1:
45
+ balanced = False
46
+ break
47
+ if balanced:
48
+ expr = expr[1:-1]
49
+ else:
50
+ break
51
+ return expr
52
+
53
+ def generate(model, prompt_text):
54
+ clean = prompt_text.replace(' ', '')
55
+ stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
56
+ if '=' not in stripped:
57
+ return prompt_text + '?'
58
+ prompt = stripped[:stripped.index('=') + 1]
59
+ expr = _strip_outer_parens(stripped[:stripped.index('=')])
60
+ if _count_operators(expr) == 1:
61
+ wrapped = f"0+({expr})="
62
+ chain = _run(model, wrapped)
63
+ parts = chain.split('=')
64
+ answer = parts[-1].rstrip('| ')
65
+ if answer and answer != expr and not answer.startswith('-'):
66
+ return prompt_text[:prompt_text.index('=') + 1] + answer
67
+ return _run(model, prompt)
68
 
69
  def main():
70
  if len(sys.argv) < 2: