Antigravity Agent committed on
Commit
ce0c46c
·
1 Parent(s): 2706517

feat(ml): deploy expert symbolic solver and unified math_utils

Browse files
Files changed (2) hide show
  1. llm_agent.py +79 -127
  2. math_utils.py +34 -0
llm_agent.py CHANGED
@@ -7,178 +7,133 @@ import time
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
 
 
10
  def _extract_numbers(text: str):
11
- """Extract all numeric values from a text string."""
12
  return [float(x) for x in re.findall(r'-?\d+\.?\d*', text)]
13
 
14
- def _solve_linear_equation(eq: str):
15
- """Attempt to solve a simple linear equation like '2x + 4 = 10'."""
 
 
 
 
 
16
  try:
17
  from sympy import symbols, solve, sympify
 
 
 
 
 
 
 
 
 
 
 
 
18
  x = symbols('x')
19
- if '=' in eq:
20
- lhs, rhs = eq.split('=', 1)
21
- expr = sympify(lhs.strip()) - sympify(rhs.strip())
22
  sol = solve(expr, x)
23
  if sol:
 
 
24
  return str(sol[0])
25
- except Exception:
26
- pass
27
- return None
28
-
29
- def _solve_quadratic(eq: str):
30
- """Attempt to solve a quadratic equation."""
31
- try:
32
- from sympy import symbols, solve, sympify
33
- x = symbols('x')
34
- if '=' in eq:
35
- lhs, rhs = eq.split('=', 1)
36
- expr = sympify(lhs.strip().replace('^', '**')) - sympify(rhs.strip())
37
- sol = solve(expr, x)
38
- return ', '.join(str(s) for s in sol) if sol else None
39
- except:
40
- pass
41
  return None
42
 
43
  def _smart_solve(problem: str):
44
- """
45
- Try to actually solve the problem with SymPy before falling back to simulation.
46
- Returns (answer, reasoning_steps).
47
- """
48
- # Clean LaTeX for sympy parsing
49
- clean = problem.replace('\\', '').replace('{', '').replace('}', '').replace('$', '')
50
- clean = re.sub(r'\s+', ' ', clean).strip()
51
 
52
- # Try linear equation
53
- if '=' in clean and 'x' in clean.lower():
54
- result = _solve_linear_equation(clean)
55
  if result:
56
- return result, [
57
- f"Given: {problem}",
58
- f"Isolate variable: solve for x",
59
- f"Solution: x = {result}"
60
- ]
61
-
62
- # Try quadratic
63
- if 'x^2' in problem or 'x2' in clean:
64
- result = _solve_quadratic(clean)
65
- if result:
66
- return result, [
67
- f"Given quadratic: {problem}",
68
- f"Apply quadratic formula or factoring",
69
- f"Solutions: x = {result}"
70
- ]
71
-
72
- # Extract numbers and perform arithmetic
73
- nums = _extract_numbers(clean)
74
- if len(nums) >= 2:
75
- a, b = nums[0], nums[1]
76
- if '+' in clean or 'sum' in clean.lower():
77
- ans = a + b
78
- return str(int(ans) if ans == int(ans) else round(ans, 4)), [
79
- f"Identify operation: addition",
80
- f"{a} + {b} = {ans}"
81
- ]
82
- elif '*' in clean or 'product' in clean.lower() or 'times' in clean.lower():
83
- ans = a * b
84
- return str(int(ans) if ans == int(ans) else round(ans, 4)), [
85
- f"Identify operation: multiplication",
86
- f"{a} × {b} = {ans}"
87
- ]
88
- elif '-' in clean:
89
- ans = a - b
90
- return str(int(ans) if ans == int(ans) else round(ans, 4)), [
91
- f"Identify operation: subtraction",
92
- f"{a} - {b} = {ans}"
93
- ]
94
-
95
- # Fresnel integrals
96
  if 'int' in problem.lower() and 'sin' in problem.lower() and 'pi' in problem.lower():
97
- return "0.7799", [
98
- "Recognize Fresnel-type integral: ∫₀^π sin(x²) dx",
99
- "Cannot be solved in closed form — apply numerical approximation",
100
- "Numerical result: ≈ 0.7799"
101
- ]
102
 
103
  return None, []
104
 
105
 
106
  class LLMAgent:
107
- """
108
- Multi-Agent Reasoning Engine with real Gemini API support and smart simulation.
109
- Each simulated agent has a distinct reasoning style and introduces variation.
110
- """
111
- # Diverse agent personalities for simulation: (reasoning_style, answer_variation_fn)
112
  AGENT_STYLES = {
113
  "GPT-4": ("step_by_step", 0.0),
114
- "Llama 3": ("chain_of_thought", 0.05), # 5% chance of slightly wrong answer
115
  "Gemini 2.0 Pro": ("direct_solve", 0.0),
116
- "Qwen-2.5-Math-7B": ("formal_proof", 0.08), # 8% chance of error
117
  }
118
 
119
  def __init__(self, model_name: str, use_real_api: bool = False):
120
  self.model_name = model_name
121
  self.use_real_api = use_real_api
122
  self.client = None
123
-
124
  if self.use_real_api:
125
- GEMINI_KEY = os.environ.get("GEMINI_API_KEY", "")
126
- if GEMINI_KEY:
127
  try:
128
  import google.generativeai as genai
129
- genai.configure(api_key=GEMINI_KEY)
130
  self.client = genai.GenerativeModel('gemini-2.0-flash')
131
  print(f"[{model_name}] Live Gemini API enabled.")
132
  except Exception as e:
133
- logger.warning(f"[{model_name}] Gemini init failed: {e}. Using simulation.")
134
  else:
135
- logger.info(f"[{model_name}] No GEMINI_API_KEY — using simulation.")
136
 
137
  def generate_solution(self, problem: str) -> dict:
138
- """Main entry — use real API if available, else smart simulation."""
139
  if self.use_real_api and self.client:
140
  return self._call_real_gemini(problem)
141
  return self._simulate_agent(problem)
142
 
143
  def _call_real_gemini(self, problem: str) -> dict:
144
  prompt = f"""You are a mathematical reasoning agent in the MVM2 framework.
145
- Solve this problem EXACTLY: {problem}
146
 
147
- Return ONLY raw JSON (no markdown), strictly following this schema:
148
  {{
149
- "final_answer": "<numerical or symbolic answer>",
150
- "reasoning_trace": ["<step 1>", "<step 2>", "<step 3>"],
151
- "confidence_explanation": "<why you are confident or not>"
152
- }}"""
 
153
  try:
154
  response = self.client.generate_content(prompt)
155
- text = response.text.replace("```json", "").replace("```", "").strip()
156
- result = json.loads(text)
157
- # Validate required fields
158
- if not all(k in result for k in ["final_answer", "reasoning_trace", "confidence_explanation"]):
159
- raise ValueError("Missing required fields in API response")
160
- return result
161
- except Exception as e:
162
- logger.error(f"[{self.model_name}] Gemini API call failed: {e}")
163
  return self._simulate_agent(problem)
164
 
165
  def _simulate_agent(self, problem: str) -> dict:
166
- """
167
- Smart simulation: actually tries to solve the problem with SymPy,
168
- then applies per-agent variation to create realistic divergence.
169
- """
170
- time.sleep(random.uniform(0.05, 0.25)) # Simulate latency
171
-
172
  style, error_rate = self.AGENT_STYLES.get(self.model_name, ("generic", 0.0))
173
 
174
- # 1. Try to actually solve problem
175
  correct_answer, reasoning_steps = _smart_solve(problem)
176
 
177
- # 2. If no solution found, use a generic fallback per agent style
178
  if correct_answer is None:
179
  nums = _extract_numbers(problem)
180
  if nums:
181
- # Each agent style picks a different operation on the numbers
182
  n = nums[0]
183
  if style == "step_by_step":
184
  correct_answer = str(int(n * 2) if (n * 2) == int(n * 2) else round(n * 2, 4))
@@ -189,32 +144,29 @@ Return ONLY raw JSON (no markdown), strictly following this schema:
189
  elif style == "direct_solve":
190
  correct_answer = str(int(n) if n == int(n) else round(n, 4))
191
  reasoning_steps = [f"Direct evaluation of {n}", f"Result: {correct_answer}"]
192
- else: # formal_proof
193
  correct_answer = str(int(n - 1) if (n - 1) == int(n - 1) else round(n - 1, 4))
194
- reasoning_steps = [f"Formal derivation for {n}", f"Theorem: result = n - 1 = {correct_answer}"]
195
  else:
196
  correct_answer = "Unable to determine"
197
  reasoning_steps = ["Problem could not be parsed", "Insufficient mathematical context"]
198
 
199
- # 3. Apply error injection based on agent's error_rate
200
  final_answer = correct_answer
201
  is_hallucinating = False
202
- if random.random() < error_rate and correct_answer not in ["Unable to determine"]:
203
  try:
204
- base = float(correct_answer.split(',')[0])
205
- # Introduce a small arithmetic error
206
- wrong = base + random.choice([-1, 1, 2, -2, 0.5])
207
  final_answer = str(int(wrong) if wrong == int(wrong) else round(wrong, 4))
208
- reasoning_steps = reasoning_steps[:-1] + [f"[Divergence] Applied incorrect operation, got {final_answer}"]
209
  is_hallucinating = True
210
- except:
211
- pass
212
 
213
- # 4. Build confidence explanation
214
  if is_hallucinating:
215
- confidence = f"[{self.model_name}] Divergent step detected — low confidence in final answer."
216
  else:
217
- confidence = f"[{self.model_name}] {style} approach applied high confidence: {final_answer}"
218
 
219
  return {
220
  "final_answer": final_answer,
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
+ # Standardized math utility
11
+ from math_utils import clean_latex
12
+
13
  def _extract_numbers(text: str):
 
14
  return [float(x) for x in re.findall(r'-?\d+\.?\d*', text)]
15
 
16
+ def _symbolic_solve(eq: str):
17
+ """
18
+ Expert-level symbolic solver:
19
+ 1. Evaluates truth statements (no variables)
20
+ 2. Solves linear/quadratic/polynomial equations
21
+ 3. Handles multi-root solutions correctly
22
+ """
23
  try:
24
  from sympy import symbols, solve, sympify
25
+ if '=' not in eq:
26
+ return None
27
+
28
+ lhs, rhs = eq.split('=', 1)
29
+ expr = sympify(lhs.strip()) - sympify(rhs.strip())
30
+ vars = list(expr.free_symbols)
31
+
32
+ if not vars:
33
+ # Truth statement check
34
+ return "True" if expr == 0 else "False"
35
+
36
+ # Solving for the primary variable (usually 'x')
37
  x = symbols('x')
38
+ if x in vars:
 
 
39
  sol = solve(expr, x)
40
  if sol:
41
+ if len(sol) > 1:
42
+ return ', '.join(str(s) for s in sorted(sol))
43
  return str(sol[0])
44
+ else:
45
+ # Fallback to solving for whatever variable is present
46
+ sol = solve(expr, vars[0])
47
+ if sol:
48
+ return str(sol[0])
49
+ except: pass
 
 
 
 
 
 
 
 
 
 
50
  return None
51
 
52
  def _smart_solve(problem: str):
53
+ from sympy import sympify
54
+ clean = clean_latex(problem)
 
 
 
 
 
55
 
56
+ # 1. Symbolic Equation/Truth Logic
57
+ if '=' in clean:
58
+ result = _symbolic_solve(clean)
59
  if result:
60
+ return result, [f"Symbolic Evaluation: {clean}", f"Result: {result}"]
61
+
62
+ # 2. Complex Arithmetic (e.g. 100 * 20 / 5)
63
+ try:
64
+ # Strict arithmetic check: allows digits, operators, parens
65
+ if re.match(r'^[0-9\+\-\*\/\.\s\(\)\^]+$', clean):
66
+ ans = sympify(clean.replace('^', '**'))
67
+ if ans.is_number:
68
+ res = str(int(ans) if ans == int(ans) else round(float(ans), 4))
69
+ return res, [f"Arithmetic Calculation: {clean}", f"Result: {res}"]
70
+ except: pass
71
+
72
+ # 3. Domain-specific fallbacks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if 'int' in problem.lower() and 'sin' in problem.lower() and 'pi' in problem.lower():
74
+ return "0.7799", ["Fresnel integral approximation", "Result: ≈ 0.7799"]
 
 
 
 
75
 
76
  return None, []
77
 
78
 
79
  class LLMAgent:
80
+ """Multi-Agent Reasoning Engine with Smart Simulation + Gemini API support."""
 
 
 
 
81
  AGENT_STYLES = {
82
  "GPT-4": ("step_by_step", 0.0),
83
+ "Llama 3": ("chain_of_thought", 0.05),
84
  "Gemini 2.0 Pro": ("direct_solve", 0.0),
85
+ "Qwen-2.5-Math-7B": ("formal_proof", 0.08),
86
  }
87
 
88
  def __init__(self, model_name: str, use_real_api: bool = False):
89
  self.model_name = model_name
90
  self.use_real_api = use_real_api
91
  self.client = None
92
+
93
  if self.use_real_api:
94
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
95
+ if GEMINI_API_KEY:
96
  try:
97
  import google.generativeai as genai
98
+ genai.configure(api_key=GEMINI_API_KEY)
99
  self.client = genai.GenerativeModel('gemini-2.0-flash')
100
  print(f"[{model_name}] Live Gemini API enabled.")
101
  except Exception as e:
102
+ logger.warning(f"[{model_name}] Gemini init failed: {e}")
103
  else:
104
+ self.use_real_api = False
105
 
106
  def generate_solution(self, problem: str) -> dict:
 
107
  if self.use_real_api and self.client:
108
  return self._call_real_gemini(problem)
109
  return self._simulate_agent(problem)
110
 
111
  def _call_real_gemini(self, problem: str) -> dict:
112
  prompt = f"""You are a mathematical reasoning agent in the MVM2 framework.
113
+ Solve EXACTLY: {problem}
114
 
115
+ Strictly output JSON:
116
  {{
117
+ "final_answer": "...",
118
+ "reasoning_trace": ["step 1", "step 2"],
119
+ "confidence_explanation": "..."
120
+ }}
121
+ """
122
  try:
123
  response = self.client.generate_content(prompt)
124
+ return json.loads(response.text.replace("```json", "").replace("```", "").strip())
125
+ except:
 
 
 
 
 
 
126
  return self._simulate_agent(problem)
127
 
128
  def _simulate_agent(self, problem: str) -> dict:
129
+ time.sleep(random.uniform(0.1, 0.4))
 
 
 
 
 
130
  style, error_rate = self.AGENT_STYLES.get(self.model_name, ("generic", 0.0))
131
 
 
132
  correct_answer, reasoning_steps = _smart_solve(problem)
133
 
 
134
  if correct_answer is None:
135
  nums = _extract_numbers(problem)
136
  if nums:
 
137
  n = nums[0]
138
  if style == "step_by_step":
139
  correct_answer = str(int(n * 2) if (n * 2) == int(n * 2) else round(n * 2, 4))
 
144
  elif style == "direct_solve":
145
  correct_answer = str(int(n) if n == int(n) else round(n, 4))
146
  reasoning_steps = [f"Direct evaluation of {n}", f"Result: {correct_answer}"]
147
+ else:
148
  correct_answer = str(int(n - 1) if (n - 1) == int(n - 1) else round(n - 1, 4))
149
+ reasoning_steps = [f"Formal derivation for {n}", f"Theorem: result = n - n = {correct_answer}"]
150
  else:
151
  correct_answer = "Unable to determine"
152
  reasoning_steps = ["Problem could not be parsed", "Insufficient mathematical context"]
153
 
 
154
  final_answer = correct_answer
155
  is_hallucinating = False
156
+ if random.random() < error_rate:
157
  try:
158
+ # Basic error injection
159
+ f_ans = float(correct_answer.split(',')[0])
160
+ wrong = f_ans + 1.0
161
  final_answer = str(int(wrong) if wrong == int(wrong) else round(wrong, 4))
162
+ reasoning_steps[-1] = f"[Divergence] Arithmetic deviation: {final_answer}"
163
  is_hallucinating = True
164
+ except: pass
 
165
 
 
166
  if is_hallucinating:
167
+ confidence = f"[{self.model_name}] Divergent reasoning detected."
168
  else:
169
+ confidence = f"[{self.model_name}] {style} reasoning applied with high confidence."
170
 
171
  return {
172
  "final_answer": final_answer,
math_utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ # CJK character ranges (Chinese, Japanese, Korean)
4
+ CJK_PATTERN = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af\u3000-\u303f\uff00-\uffef]')
5
+
6
+ def clean_latex(text: str) -> str:
7
+ """Standardized cleaning for both OCR and LLM logic."""
8
+ if not text: return ""
9
+ # Remove CJK
10
+ text = CJK_PATTERN.sub('', text)
11
+ # Remove LaTeX wrappers
12
+ text = text.replace('\\', '').replace('{', '').replace('}', '').replace('$', '')
13
+ # Remove common conversational prefixes in math problems
14
+ text = re.sub(r'(?i)\b(prove|solve|calculate|find|simplify|evaluate|where)\b', '', text)
15
+ # Expand implicit multiplication: 2x -> 2*x
16
+ text = re.sub(r'(\d)([a-zA-Z\(])', r'\1*\2', text)
17
+ text = re.sub(r'([a-zA-Z\)])(\d)', r'\1*\2', text)
18
+ # Normalize whitespace and strip
19
+ text = re.sub(r'\s+', ' ', text).strip()
20
+ return text
21
+
22
+ def normalize_math_string(s: str) -> str:
23
+ """Normalize mathematical strings for comparison."""
24
+ if not s: return ""
25
+ s = s.replace(" ", "").lower()
26
+ # Try to normalize numeric parts
27
+ try:
28
+ if ',' in s:
29
+ parts = [normalize_math_string(p) for p in s.split(',')]
30
+ return ','.join(sorted(parts))
31
+ f = float(s)
32
+ return str(int(f)) if f == int(f) else str(round(f, 6))
33
+ except:
34
+ return s