Spaces:
Sleeping
Sleeping
File size: 3,991 Bytes
3c25c17 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | from __future__ import annotations
import json
import re
from datetime import datetime
from agents.state import MathMentorState
from llm.client import get_llm
from tools.plotter import plot_function
# Template for the explanation request that explainer_node sends to the LLM.
# It is filled with str.format(problem_text=..., solution=..., steps=..., topic=...),
# so the literal LaTeX braces in the display-math example are doubled ({{ }}),
# and backslashes are doubled so the rendered prompt contains single ones.
EXPLAINER_PROMPT = "\n".join(
    [
        "You are a friendly math tutor explaining a solution to a JEE student.",
        "Problem: {problem_text}",
        "Solution: {solution}",
        "Steps: {steps}",
        "Topic: {topic}",
        "Write a clear, step-by-step explanation in markdown. Use LaTeX for all math expressions:",
        "- Inline math: \\(x^2 + 1\\)",
        "- Display math: $$x = \\frac{{-b \\pm \\sqrt{{b^2 - 4ac}}}}{{2a}}$$",
        "Guidelines:",
        "1. Explain each step in simple, student-friendly language",
        "2. Use LaTeX for ALL math expressions (never write raw math like x^2, always wrap in LaTeX)",
        "3. Highlight key concepts and formulas used",
        "4. Mention common mistakes to avoid",
        "5. End with the final answer clearly stated",
        "Do NOT wrap your response in JSON or code fences. Write plain markdown directly.",
        "",  # empty final item -> the prompt ends with a newline
    ]
)
# Template asking the LLM for one numpy expression in ``x`` (or the sentinel
# NONE). It is filled with str.format(problem_text=..., solution=...) and ends
# on "Expression:" with no trailing newline.
PLOT_EXTRACT_PROMPT = "\n".join(
    (
        "Given this math problem and solution, extract a numpy-compatible Python expression to plot.",
        "Use 'x' as the variable and numpy functions prefixed with 'np.' (e.g. np.sin, np.exp, np.sqrt, np.log).",
        "Problem: {problem_text}",
        "Solution: {solution}",
        "Rules:",
        "- Use ** for powers: x**2 not x^2",
        "- Use np.sin(x), np.cos(x), np.tan(x), np.exp(x), np.log(x), np.sqrt(x), np.abs(x)",
        "- Use np.pi for pi, np.e for e",
        "- For polynomials just write them: x**3 - 2*x**2 + 1",
        "- Only output the expression, nothing else. No explanation, no code fences.",
        "- If there's no clear function to plot, respond with exactly: NONE",
        "Expression:",
    )
)
def _clean_plot_expression(raw: str) -> str:
    """Normalise the LLM's reply into a bare expression string.

    Strips a surrounding code fence (with optional language tag) or inline
    backticks, drops a leading ``y =`` / ``f(x) =`` assignment, and rewrites
    ``^`` (bitwise XOR in Python/NumPy, never a power) as ``**``.
    """
    expr = raw.strip()
    if expr.startswith("```"):
        # Drop the opening fence line, including any tag such as ```python.
        expr = expr.split("\n", 1)[1] if "\n" in expr else expr[3:]
    if expr.endswith("```"):
        expr = expr[:-3]
    expr = expr.strip().strip("`").strip()
    # The prompt asks for the bare right-hand side, but models often echo "y = ...".
    expr = re.sub(r"^(?:y|f\s*\(\s*x\s*\))\s*=\s*", "", expr)
    # "x^2" raises TypeError on floats; the prompt's own rule is ** for powers.
    return expr.replace("^", "**")


def _is_safe_plot_expression(expr: str) -> bool:
    """Return True if *expr* only uses ``x``, numbers, arithmetic, calls and ``np.<name>``.

    Anything else — other bare names (builtins such as ``open``/``exit``),
    nested or underscore-prefixed attributes (``np.__dict__``), string
    literals, subscripts, lambdas, comparisons — is rejected.
    """
    import ast

    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError:
        return False

    allowed_nodes = (
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.Call,
        ast.keyword,
        ast.Name,
        ast.Attribute,
        ast.Constant,
        ast.Load,
        ast.operator,
        ast.unaryop,
    )
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            return False
        if isinstance(node, ast.Name) and node.id not in {"x", "np"}:
            return False
        if isinstance(node, ast.Attribute) and (
            not isinstance(node.value, ast.Name)
            or node.value.id != "np"
            or node.attr.startswith("_")
        ):
            return False
        if isinstance(node, ast.Constant) and not isinstance(node.value, (int, float)):
            return False
    return True


def _try_generate_plot(problem_text: str, solution: str) -> str:
    """Ask the LLM to extract a plottable expression and generate a diagram.

    Best-effort: any ``Exception`` (LLM failure, plotting failure) and any
    unusable or disallowed expression yield ``""`` so the caller can still
    return its explanation.

    Args:
        problem_text: Problem statement inserted into PLOT_EXTRACT_PROMPT.
        solution: Solution text inserted into PLOT_EXTRACT_PROMPT.

    Returns:
        Whatever ``plot_function`` returns (presumably the saved image path —
        NOTE(review): confirm), or ``""`` when no plot was produced.
    """
    try:
        llm = get_llm(temperature=0.0)
        response = llm.invoke(
            PLOT_EXTRACT_PROMPT.format(problem_text=problem_text, solution=solution)
        )
        # Same content extraction as explainer_node.
        raw = response.content if hasattr(response, "content") else str(response)
        expr = _clean_plot_expression(raw)
        if not expr or expr.rstrip(".").upper() == "NONE" or len(expr) > 200:
            return ""
        # NOTE(review): expr is LLM output derived from user-supplied problem
        # text, and plot_function presumably evaluates it — only a restricted
        # arithmetic / np.* subset is passed through.
        if not _is_safe_plot_expression(expr):
            return ""
        return plot_function(expression=expr, title=f"y = {expr}")
    except Exception:
        # Diagrams are optional, so don't fail the node — but don't hide the cause.
        import logging

        logging.getLogger(__name__).warning("Diagram generation failed", exc_info=True)
        return ""
def _strip_markdown_fence(text: str) -> str:
    """Remove a code fence wrapping the whole reply, if the reply opens with one.

    The trailing ``` is only removed when an opening fence was removed too, so
    an explanation that legitimately ends with its own fenced code block keeps
    its closing fence.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    # Drop the opening fence line, including any tag such as ```markdown.
    text = text.split("\n", 1)[1] if "\n" in text else text[3:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def explainer_node(state: MathMentorState) -> dict:
    """Explainer agent: turn the solver's output into a student-facing explanation.

    Reads ``parsed_problem`` (falling back to ``extracted_text``), ``solution``,
    ``solution_steps``, ``problem_topic``, ``tools_needed`` and ``agent_trace``
    from *state*, asks the LLM for a markdown explanation, and — when
    ``"plotter"`` is in ``tools_needed`` — tries to attach a diagram.

    Returns:
        A partial state update with ``explanation``, ``diagram_path`` (``""``
        when no diagram was produced) and ``agent_trace`` extended by one entry.
    """
    # `or` (rather than a .get default) also covers keys explicitly set to None/"".
    parsed = state.get("parsed_problem") or {}
    problem_text = parsed.get("problem_text") or state.get("extracted_text") or ""
    solution = state.get("solution") or ""
    steps = state.get("solution_steps") or []
    topic = state.get("problem_topic") or "math"

    llm = get_llm(temperature=0.3)
    response = llm.invoke(
        EXPLAINER_PROMPT.format(
            problem_text=problem_text,
            solution=solution,
            # ensure_ascii=False keeps symbols like π or √ readable for the model
            # instead of \uXXXX escapes; default=str stops a non-JSON step object
            # from crashing the node (it only needs to be readable prompt text).
            steps=json.dumps(steps, ensure_ascii=False, default=str),
            topic=topic,
        )
    )
    raw = response.content if hasattr(response, "content") else str(response)
    explanation = _strip_markdown_fence(raw)

    # Generate a diagram only if the router requested the plotter tool.
    diagram_path = ""
    if "plotter" in (state.get("tools_needed") or []):
        diagram_path = _try_generate_plot(problem_text, solution)

    trace_entry = {
        "agent": "explainer",
        "action": "explained",
        "summary": "Generated explanation" + (" + diagram" if diagram_path else ""),
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "explanation": explanation,
        "diagram_path": diagram_path,
        # A new list is built, so the incoming state's trace is not mutated.
        "agent_trace": [*(state.get("agent_trace") or []), trace_entry],
    }
|