# Multimodal_Math_Mentor/agents/solver_agent.py
# Initial commit: Multimodal Math Mentor (Amit-kr26, commit 3c25c17)
from __future__ import annotations
import json
from datetime import datetime
from agents.state import MathMentorState
from llm.client import get_llm
from tools.calculator import calculate
from tools.web_search import search as web_search
# Prompt template for the solver LLM; placeholders are filled by solver_node
# via str.format. The "ANSWER:" / "CONFIDENCE:" markers requested below MUST
# stay in sync with _extract_answer / _extract_confidence, which match lines
# starting with exactly those prefixes (the previous wording asked for ANSWER
# "on its own line", which elicited a format the extractors could not parse).
SOLVER_PROMPT = """\
You are an expert math solver for JEE-level problems. Solve the problem step by step.
Problem: {problem_text}
Topic: {topic}
Strategy: {strategy}
Retrieved knowledge base context:
{rag_context}
Similar previously solved problems:
{memory_context}
{web_context}
{retry_feedback}
Chat history (for follow-up context):
{chat_context}
Instructions:
1. Solve the problem showing your working concisely
2. Use LaTeX for all math: \\(inline\\) and $$display$$
3. If you need a computation verified, write: COMPUTE: sympy_expression
Example: COMPUTE: solve(x**2 - 4, x)
4. At the very end, write a line starting with ANSWER: followed by ONLY the final answer (short, just the result)
5. Then write a line starting with CONFIDENCE: followed by a number from 0.0 to 1.0
Keep the solution focused on the math working. Do NOT explain concepts or give tips — that is the explainer's job.
Write in plain markdown (NOT JSON).
"""
def _extract_computations(text: str) -> list[str]:
"""Extract COMPUTE: lines from solver output."""
computations = []
for line in text.split("\n"):
line = line.strip()
if line.startswith("COMPUTE:"):
expr = line[len("COMPUTE:"):].strip()
if expr:
computations.append(expr)
return computations
def _extract_answer(text: str) -> str:
"""Extract ANSWER: line from solver output."""
for line in text.split("\n"):
line = line.strip()
if line.startswith("ANSWER:"):
return line[len("ANSWER:"):].strip()
return ""
def _extract_confidence(text: str) -> float:
"""Extract CONFIDENCE: line from solver output."""
for line in text.split("\n"):
line = line.strip()
if line.startswith("CONFIDENCE:"):
try:
return float(line[len("CONFIDENCE:"):].strip())
except ValueError:
pass
return 0.7
def solver_node(state: MathMentorState) -> dict:
    """Produce a step-by-step solution for the current problem.

    Graph node: assembles prompt context from the shared state (RAG chunks,
    similar past problems, optional web search, verifier retry feedback, and
    recent chat history), invokes the LLM once, runs any COMPUTE: requests
    through the SymPy calculator tool, and returns a partial state update
    with the display-ready solution, confidence, retry count, and trace.
    """
    parsed = state.get("parsed_problem", {})
    # Fall back to the raw extracted text when the parser produced nothing.
    problem_text = parsed.get("problem_text", state.get("extracted_text", ""))
    topic = state.get("problem_topic", "general")
    strategy = state.get("solution_strategy", "")
    tools_needed = state.get("tools_needed", [])
    # Format RAG context — cap at 5 chunks to keep the prompt bounded.
    chunks = state.get("retrieved_chunks", [])
    rag_context = "\n---\n".join(
        f"[{c.get('source', 'unknown')}]: {c.get('text', '')}" for c in chunks[:5]
    ) or "No relevant context found."
    # Format memory context from previously solved problems (top 3).
    past = state.get("similar_past_problems", [])
    memory_context = "\n---\n".join(
        f"Q: {p.get('extracted_text', '')}\nA: {p.get('solution', '')}"
        for p in past[:3]
    ) or "No similar past problems."
    # Web search if router requested it. Best-effort: any search failure is
    # swallowed so an unavailable backend never blocks solving.
    web_context = ""
    if "web_search" in tools_needed:
        try:
            results = web_search(f"JEE math {topic} {problem_text[:100]}", max_results=3)
            snippets = [f"- {r['title']}: {r['snippet']}" for r in results if r.get("snippet")]
            if snippets:
                web_context = "Web search results:\n" + "\n".join(snippets)
        except Exception:
            pass
    # Retry feedback from verifier — only injected after a failed attempt
    # (retries > 0 and a verification_result is present in state).
    retry_feedback = ""
    verification = state.get("verification_result", {})
    retries = state.get("solver_retries", 0)
    if retries > 0 and verification:
        issues = verification.get("issues", [])
        suggestion = verification.get("suggestion", "")
        feedback_parts = []
        if issues:
            feedback_parts.append("Previous attempt had issues: " + "; ".join(issues))
        if suggestion:
            feedback_parts.append("Suggestion: " + suggestion)
        if feedback_parts:
            retry_feedback = "IMPORTANT — Fix these issues from your previous attempt:\n" + "\n".join(feedback_parts)
    # Chat context for follow-ups: last 6 messages, each truncated to 200 chars.
    chat_history = state.get("chat_history", [])
    chat_context = ""
    if chat_history and len(chat_history) > 1:
        recent = chat_history[-6:]
        chat_lines = []
        for msg in recent:
            # Messages may be LangChain-style objects (.content/.type) or
            # plain strings — presumably both occur; handle either shape.
            content = msg.content if hasattr(msg, "content") else str(msg)
            role = getattr(msg, "type", "unknown")
            chat_lines.append(f"{role}: {content[:200]}")
        chat_context = "\n".join(chat_lines)
    chat_context = chat_context or "No previous conversation."
    # Low temperature for focused, reproducible math output.
    llm = get_llm(temperature=0.1)
    response = llm.invoke(
        SOLVER_PROMPT.format(
            problem_text=problem_text,
            topic=topic,
            strategy=strategy,
            rag_context=rag_context,
            memory_context=memory_context,
            web_context=web_context,
            retry_feedback=retry_feedback,
            chat_context=chat_context,
        )
    )
    content = response.content if hasattr(response, "content") else str(response)
    # Extract structured parts from markdown response.
    # NOTE(review): `answer` is computed but never included in the returned
    # state update — confirm whether a downstream node expects a
    # "final_answer" key, or whether this extraction is dead code.
    answer = _extract_answer(content) or content.split("\n")[-1]
    confidence = _extract_confidence(content)
    computations = _extract_computations(content)
    # Run SymPy computations and append verified results to the solution text.
    computed_results = []
    for expr in computations:
        comp = calculate(expr)
        computed_results.append(comp)
        if comp.get("result"):
            content += f"\n\n**Computed:** `{expr}` = {comp['result']}"
            if comp.get("latex"):
                content += f" \\(= {comp['latex']}\\)"
    # Clean up the solution text (remove COMPUTE/ANSWER/CONFIDENCE lines for display)
    solution_lines = []
    for line in content.split("\n"):
        stripped = line.strip()
        if stripped.startswith(("COMPUTE:", "ANSWER:", "CONFIDENCE:")):
            continue
        solution_lines.append(line)
    solution_display = "\n".join(solution_lines).strip()
    # NOTE(review): redundant re-read — `retries` was already fetched above
    # and is not modified in between; harmless but could be removed.
    retries = state.get("solver_retries", 0)
    return {
        "solution": solution_display,
        "solution_steps": [solution_display],
        "final_confidence": confidence,
        # Incremented every pass so the graph can cap retry loops.
        "solver_retries": retries + 1,
        # Append a trace entry without mutating the list already in state.
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "solver",
                "action": "solved",
                "summary": f"Confidence: {confidence}, computations: {len(computed_results)}",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }