# Multimodal_Math_Mentor/agents/explainer_agent.py
# Amit-kr26's picture
# Initial commit: Multimodal Math Mentor
# 3c25c17
from __future__ import annotations

import json
import logging
import re
from datetime import datetime

from agents.state import MathMentorState
from llm.client import get_llm
from tools.plotter import plot_function
# Prompt template for the tutoring explanation. It is filled with str.format
# using the fields problem_text, solution, steps and topic. Because of that,
# literal LaTeX braces are doubled ({{ }}) so str.format leaves them in place,
# and backslashes are doubled because this is not a raw string.
EXPLAINER_PROMPT = """\
You are a friendly math tutor explaining a solution to a JEE student.
Problem: {problem_text}
Solution: {solution}
Steps: {steps}
Topic: {topic}
Write a clear, step-by-step explanation in markdown. Use LaTeX for all math expressions:
- Inline math: \\(x^2 + 1\\)
- Display math: $$x = \\frac{{-b \\pm \\sqrt{{b^2 - 4ac}}}}{{2a}}$$
Guidelines:
1. Explain each step in simple, student-friendly language
2. Use LaTeX for ALL math expressions (never write raw math like x^2, always wrap in LaTeX)
3. Highlight key concepts and formulas used
4. Mention common mistakes to avoid
5. End with the final answer clearly stated
Do NOT wrap your response in JSON or code fences. Write plain markdown directly.
"""
# Prompt template asking the model for a single numpy-style expression in x.
# It is filled with str.format using the fields problem_text and solution.
# The model is told to answer with the bare expression, or with the sentinel
# NONE when there is nothing to plot; _try_generate_plot checks for both.
PLOT_EXTRACT_PROMPT = """\
Given this math problem and solution, extract a numpy-compatible Python expression to plot.
Use 'x' as the variable and numpy functions prefixed with 'np.' (e.g. np.sin, np.exp, np.sqrt, np.log).
Problem: {problem_text}
Solution: {solution}
Rules:
- Use ** for powers: x**2 not x^2
- Use np.sin(x), np.cos(x), np.tan(x), np.exp(x), np.log(x), np.sqrt(x), np.abs(x)
- Use np.pi for pi, np.e for e
- For polynomials just write them: x**3 - 2*x**2 + 1
- Only output the expression, nothing else. No explanation, no code fences.
- If there's no clear function to plot, respond with exactly: NONE
Expression:"""
def _try_generate_plot(problem_text: str, solution: str) -> str:
    """Ask the LLM for a plottable expression and render it with ``plot_function``.

    This step is best-effort: any failure (LLM call, unusable reply, plotting
    error) is logged and reported as "no diagram" instead of being raised, so
    a missing plot never blocks the explanation.

    Args:
        problem_text: Problem statement inserted into ``PLOT_EXTRACT_PROMPT``.
        solution: Solution text inserted into ``PLOT_EXTRACT_PROMPT``.

    Returns:
        The value returned by ``plot_function`` for the extracted expression
        (presumably the path of the rendered image; confirm against
        ``tools.plotter``), or ``""`` when no plot was produced.
    """
    try:
        llm = get_llm(temperature=0.0)
        response = llm.invoke(
            PLOT_EXTRACT_PROMPT.format(problem_text=problem_text, solution=solution)
        )
        # Accept both message objects exposing ``.content`` and plain-string
        # replies, the same way explainer_node reads its response.
        raw = response.content if hasattr(response, "content") else str(response)
        expr = raw.strip()

        # The prompt forbids code fences, but the reply may carry them anyway.
        if expr.startswith("```"):
            # Drop the whole opening fence line, which may hold a language tag.
            expr = expr.split("\n", 1)[1] if "\n" in expr else expr[3:]
        if expr.endswith("```"):
            expr = expr[:-3]
        expr = expr.strip()

        # "NONE" is the sentinel the prompt asks for when nothing is plottable;
        # replies longer than 200 characters are rejected as well.
        if not expr or expr.upper() == "NONE" or len(expr) > 200:
            return ""

        # NOTE(review): ``expr`` is model output derived from user-supplied
        # problem text. If ``plot_function`` evaluates it (eval or similar),
        # it must be validated or sandboxed there -- confirm in tools.plotter.
        path = plot_function(
            expression=expr,
            title=f"y = {expr}",
        )
        # Keep the declared ``str`` contract even if the plotter yields None.
        return path or ""
    except Exception:
        # Deliberate catch-all at this best-effort boundary; log with the
        # traceback instead of silently discarding the failure.
        logging.getLogger(__name__).warning(
            "Diagram generation failed; continuing without a plot",
            exc_info=True,
        )
        return ""
def _strip_wrapping_fence(text: str) -> str:
    """Return *text* without a Markdown code fence wrapped around the whole reply."""
    text = text.strip()
    if not text.startswith("```"):
        # No opening fence: leave any trailing fence alone, since it would be
        # the end of a real code block inside the explanation.
        return text
    # Drop the whole opening fence line, which may hold a language tag.
    text = text.split("\n", 1)[1] if "\n" in text else text[3:]
    if text.endswith("```"):
        text = text[:-3].strip()
    return text


def explainer_node(state: MathMentorState) -> dict:
    """Generate a student-facing explanation and, if requested, a diagram.

    Reads ``parsed_problem`` (falling back to ``extracted_text`` when the
    parsed problem text is missing or empty), ``solution``, ``solution_steps``,
    ``problem_topic``, ``tools_needed`` and ``agent_trace`` from *state*.
    Keys that are absent or set to ``None`` are treated as empty.

    Args:
        state: Current graph state; accessed through ``.get`` only.

    Returns:
        A dict with:
        ``explanation`` -- the model's markdown reply with a wrapping code
        fence removed;
        ``diagram_path`` -- result of ``_try_generate_plot`` when ``"plotter"``
        is in ``tools_needed``, otherwise ``""``;
        ``agent_trace`` -- the previous trace entries plus one entry for this
        agent.
    """
    # ``or`` (rather than a .get default) also covers keys present as None.
    parsed = state.get("parsed_problem") or {}
    problem_text = parsed.get("problem_text") or state.get("extracted_text") or ""
    solution = state.get("solution") or ""
    steps = state.get("solution_steps") or []
    topic = state.get("problem_topic") or "math"

    llm = get_llm(temperature=0.3)
    prompt = EXPLAINER_PROMPT.format(
        problem_text=problem_text,
        solution=solution,
        # ensure_ascii=False keeps math symbols readable in the prompt instead
        # of turning them into \uXXXX escapes.
        steps=json.dumps(steps, ensure_ascii=False),
        topic=topic,
    )
    response = llm.invoke(prompt)
    raw_reply = response.content if hasattr(response, "content") else str(response)
    explanation = _strip_wrapping_fence(raw_reply)

    # Generate a diagram only if the router requested the plotter tool.
    diagram_path = ""
    if "plotter" in (state.get("tools_needed") or []):
        diagram_path = _try_generate_plot(problem_text, solution)

    trace_entry = {
        "agent": "explainer",
        "action": "explained",
        "summary": "Generated explanation" + (" + diagram" if diagram_path else ""),
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "explanation": explanation,
        "diagram_path": diagram_path,
        # Build a new list so the incoming trace is never mutated.
        "agent_trace": [*(state.get("agent_trace") or []), trace_entry],
    }