File size: 3,991 Bytes
3c25c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations

import json
import logging
import re
from datetime import datetime

from agents.state import MathMentorState
from llm.client import get_llm
from tools.plotter import plot_function

# Prompt for the explainer LLM call in explainer_node, filled via str.format
# with problem_text, solution, steps and topic. Literal braces in the LaTeX
# example are doubled ({{ }}) so .format emits single braces, and backslashes
# are doubled because this is not a raw string (the model sees \( and \frac).
EXPLAINER_PROMPT = """\
You are a friendly math tutor explaining a solution to a JEE student.

Problem: {problem_text}
Solution: {solution}
Steps: {steps}
Topic: {topic}

Write a clear, step-by-step explanation in markdown. Use LaTeX for all math expressions:
- Inline math: \\(x^2 + 1\\)
- Display math: $$x = \\frac{{-b \\pm \\sqrt{{b^2 - 4ac}}}}{{2a}}$$

Guidelines:
1. Explain each step in simple, student-friendly language
2. Use LaTeX for ALL math expressions (never write raw math like x^2, always wrap in LaTeX)
3. Highlight key concepts and formulas used
4. Mention common mistakes to avoid
5. End with the final answer clearly stated

Do NOT wrap your response in JSON or code fences. Write plain markdown directly.
"""

# Prompt asking the LLM for a single numpy-compatible expression in `x` to
# plot; filled via str.format with problem_text and solution in
# _try_generate_plot. The reply is expected to be the bare expression, or the
# sentinel NONE when there is nothing to plot.
PLOT_EXTRACT_PROMPT = """\
Given this math problem and solution, extract a numpy-compatible Python expression to plot.
Use 'x' as the variable and numpy functions prefixed with 'np.' (e.g. np.sin, np.exp, np.sqrt, np.log).

Problem: {problem_text}
Solution: {solution}

Rules:
- Use ** for powers: x**2 not x^2
- Use np.sin(x), np.cos(x), np.tan(x), np.exp(x), np.log(x), np.sqrt(x), np.abs(x)
- Use np.pi for pi, np.e for e
- For polynomials just write them: x**3 - 2*x**2 + 1
- Only output the expression, nothing else. No explanation, no code fences.
- If there's no clear function to plot, respond with exactly: NONE

Expression:"""


# Extracted plot expressions longer than this are treated as "nothing to plot".
_MAX_EXPR_LEN = 200

# Leading labels an LLM sometimes echoes before the bare expression:
# "Expression:", "y =" or "f(x) =" (but not a comparison such as "y ==").
_EXPR_LABEL = re.compile(
    r"^\s*(?:expression\s*:|(?:y|f\s*\(\s*x\s*\))\s*=(?!=))\s*",
    re.IGNORECASE,
)

_logger = logging.getLogger(__name__)


def _response_text(response: object) -> str:
    """Return the text of an LLM reply.

    Uses ``response.content`` when the object has that attribute and the
    object itself otherwise; non-string values are passed through ``str()``.
    """
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)


def _strip_code_fences(text: str) -> str:
    """Remove a markdown code fence that wraps the whole of *text*.

    The opening fence line (including any language tag such as
    ``markdown``) is dropped, and a closing fence is dropped only when an
    opening one was found. Text that merely *ends* with a legitimate code
    block is therefore returned unchanged apart from surrounding whitespace.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    _, newline, body = text.partition("\n")
    # No newline means a one-line reply such as "```x**2```".
    body = body if newline else text[3:]
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _clean_plot_expression(raw: str) -> str:
    """Reduce an LLM reply to a bare plottable expression.

    Returns "" when the reply is empty, is the ``NONE`` sentinel requested
    by ``PLOT_EXTRACT_PROMPT`` (case-insensitive, trailing period allowed),
    or is longer than ``_MAX_EXPR_LEN`` characters.
    """
    text = _strip_code_fences(raw)
    # Keep only the first non-empty line: the prompt asks for the expression
    # alone, so anything after it is treated as commentary.
    first_line = next((ln.strip() for ln in text.splitlines() if ln.strip()), "")
    # Inline-code backticks and an echoed label are not part of the expression.
    expr = _EXPR_LABEL.sub("", first_line.strip("`").strip(), count=1).strip()
    if not expr or expr.rstrip(".").upper() == "NONE" or len(expr) > _MAX_EXPR_LEN:
        return ""
    return expr


def _try_generate_plot(problem_text: str, solution: str) -> str:
    """Ask the LLM for a plottable expression and render it with plot_function.

    Best-effort: returns the path produced by ``plot_function`` as a string,
    or "" when there is nothing to plot or any step fails. Failures are
    logged as warnings rather than raised, so the explanation is never lost
    because of a diagram.
    """
    try:
        llm = get_llm(temperature=0.0)
        response = llm.invoke(
            PLOT_EXTRACT_PROMPT.format(problem_text=problem_text, solution=solution)
        )
        expr = _clean_plot_expression(_response_text(response))
        if not expr:
            return ""
        # NOTE(review): expr is raw LLM output; confirm plot_function evaluates
        # it in a restricted namespace rather than with a plain eval().
        path = plot_function(expression=expr, title=f"y = {expr}")
    except Exception:
        _logger.warning("Diagram generation failed", exc_info=True)
        return ""
    return str(path) if path else ""


def explainer_node(state: MathMentorState) -> dict:
    """Produce a student-facing markdown explanation and an optional diagram.

    Reads from *state*: ``parsed_problem["problem_text"]`` (falling back to
    ``extracted_text``), ``solution``, ``solution_steps``, ``problem_topic``
    (default ``"math"``), ``tools_needed`` and ``agent_trace``. Missing or
    ``None`` values are treated as empty.

    Returns a dict with ``explanation`` (markdown with any wrapping code
    fence removed), ``diagram_path`` ("" when no diagram was produced) and
    ``agent_trace`` (a new list: the incoming trace plus one entry for this
    agent). Errors from the explainer LLM call itself are not caught here.
    """
    parsed = state.get("parsed_problem") or {}
    problem_text = parsed.get("problem_text") or state.get("extracted_text") or ""
    solution = state.get("solution") or ""
    steps = state.get("solution_steps") or []
    topic = state.get("problem_topic") or "math"

    llm = get_llm(temperature=0.3)
    response = llm.invoke(
        EXPLAINER_PROMPT.format(
            problem_text=problem_text,
            solution=solution,
            # ensure_ascii=False keeps symbols such as π or θ readable for the model.
            steps=json.dumps(steps, ensure_ascii=False),
            topic=topic,
        )
    )
    # The prompt forbids code fences, but strip an accidental wrapping fence.
    explanation = _strip_code_fences(_response_text(response))

    # Generate a diagram only if the router requested the plotter tool.
    diagram_path = ""
    if "plotter" in (state.get("tools_needed") or []):
        diagram_path = _try_generate_plot(problem_text, solution)

    trace_entry = {
        "agent": "explainer",
        "action": "explained",
        "summary": "Generated explanation" + (" + diagram" if diagram_path else ""),
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "explanation": explanation,
        "diagram_path": diagram_path,
        # Build a new list so the incoming state's trace is not mutated.
        "agent_trace": [*(state.get("agent_trace") or []), trace_entry],
    }