File size: 4,313 Bytes
2e8d6bf
 
 
 
7428575
2e8d6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7428575
2e8d6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7428575
2e8d6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import random
import re

from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage

from src.graph.state import StudyState
from src.tools.ingest import ingest_document
from src.tools.retriever import retrieve_chunks
from src.prompts.question_prompt import QUESTION_PROMPT
from src.prompts.answer_prompt import ANSWER_PROMPT
from src.prompts.evaluate_prompt import EVALUATE_PROMPT


# Module-level vectorstore reference, set during ingest
_vectorstore = None


def get_vectorstore():
    """Return the module-level vectorstore (None until ingest_node has run)."""
    store = _vectorstore
    return store

def ingest_node(state: StudyState) -> dict:
    """Chunk and index the document named in state['document_path'].

    Side effect: stores the vectorstore in the module-level reference so
    later retrieval nodes can reach it via get_vectorstore().

    Returns the initial session bookkeeping: chunks plus zeroed counters.
    """
    global _vectorstore
    doc_path = state["document_path"]
    chunks, _vectorstore = ingest_document(doc_path)
    print(f"Ingested {len(chunks)} chunks from {doc_path}")

    initial = {
        "chunks": chunks,
        "questions_asked": 0,
        "questions_correct": 0,
        "weak_chunks": [],
        "session_history": [],
        "mastery_reached": False,
    }
    return initial


def generate_question_node(state: StudyState) -> dict:
    """Generate one quiz question from a passage via the question LLM.

    Weak chunks take priority over the full chunk pool so the learner is
    re-tested on material they previously missed.
    """
    all_chunks = state["chunks"]
    weak = state.get("weak_chunks", [])
    passage = random.choice(weak or all_chunks)

    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)
    reply = model.invoke(
        [HumanMessage(content=QUESTION_PROMPT.format(passage=passage))]
    )
    question = reply.content.strip()

    print(f"\nQ{state['questions_asked'] + 1}: {question}")
    return {"current_question": question}


def answer_node(state: StudyState) -> dict:
    """Retrieve supporting chunks for the current question and have the
    answering LLM respond grounded in that context."""
    question = state["current_question"]
    retrieved = retrieve_chunks(get_vectorstore(), question)

    # Label each retrieved chunk so the prompt can cite them by number.
    labeled = [f"[Chunk {idx + 1}]: {text}" for idx, text in enumerate(retrieved)]
    context = "\n\n".join(labeled)

    model = ChatAnthropic(model="claude-sonnet-4-20250514", temperature=0.3)
    reply = model.invoke(
        [HumanMessage(content=ANSWER_PROMPT.format(question=question, context=context))]
    )
    answer = reply.content.strip()

    print(f"Answer: {answer[:200]}...")
    return {"current_answer": answer}


def _parse_evaluation(result: str) -> tuple[float, str]:
    """Extract (score, reasoning) from the evaluator's reply.

    Scans for lines beginning with "Score:" and "Reasoning:". Returns
    (0.0, "") for anything absent or unparseable.
    """
    score = 0.0
    reasoning = ""
    for line in result.split("\n"):
        if line.startswith("Score:"):
            # Require a well-formed number. The previous pattern [\d.]+
            # could match a bare "." (e.g. in "Score: N/A.") and crash
            # float() with ValueError.
            match = re.search(r"\d+(?:\.\d+)?", line)
            if match:
                score = float(match.group())
        elif line.startswith("Reasoning:"):
            reasoning = line.replace("Reasoning:", "").strip()
    return score, reasoning


def evaluate_node(state: StudyState) -> dict:
    """Grade the current answer against its source chunk with the evaluator LLM.

    Updates counters, records weak chunks (score < 0.75) for re-quizzing,
    and appends the Q/A/score/reasoning record to the session history.
    """
    vectorstore = get_vectorstore()
    question = state["current_question"]
    answer = state["current_answer"]

    # Retrieve the most relevant source chunk for grading.
    source_chunks = retrieve_chunks(vectorstore, question, top_k=1)
    source = source_chunks[0] if source_chunks else ""

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
    prompt = EVALUATE_PROMPT.format(question=question, answer=answer, source=source)
    response = llm.invoke([HumanMessage(content=prompt)])
    result = response.content.strip()

    score, reasoning = _parse_evaluation(result)

    questions_asked = state["questions_asked"] + 1
    questions_correct = state["questions_correct"] + (1 if score >= 0.75 else 0)

    # Track weak chunks; skip the empty-source case so question generation
    # never receives an empty passage.
    weak_chunks = list(state.get("weak_chunks", []))
    if score < 0.75 and source:
        weak_chunks.append(source)

    # Log to session history.
    history = list(state.get("session_history", []))
    history.append({
        "question": question,
        "answer": answer,
        "score": score,
        "reasoning": reasoning,
    })

    print(f"Score: {score} | {reasoning}")
    return {
        "current_score": score,
        "questions_asked": questions_asked,
        "questions_correct": questions_correct,
        "weak_chunks": weak_chunks,
        "session_history": history,
    }


def reread_node(state: StudyState) -> dict:
    """Reinforcement pass: a no-op state update.

    The weak chunk already lives in state, so the next question-generation
    step will prioritize it without further action here.
    """
    print("Re-reading weak chunk for reinforcement...")
    return {}


def summarize_node(state: StudyState) -> dict:
    """Terminal node: announce the report and mark the session as mastered."""
    print("\nMastery reached. Generating session report...")
    result = {"mastery_reached": True}
    return result