File size: 2,588 Bytes
88cc76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from typing import TypedDict, List, Optional
import google.generativeai as genai
from langgraph.graph import StateGraph, END

from rag_store import search_knowledge
from eval_logger import log_eval

# Gemini model used for answer generation; read by answer_node.
MODEL_NAME = "gemini-2.5-flash"


# ===============================
# STATE
# ===============================
class RAGState(TypedDict):
    """Shared state passed between the nodes of the RAG graph."""
    query: str  # the user's question
    retrieved_chunks: List[dict]  # results from search_knowledge; each chunk is read via its "text" key
    answer: Optional[str]  # final answer text; None until an answer/no-answer node runs
    confidence: float  # heuristic in [0.0, 1.0] derived from the retrieval count
    answer_known: bool  # False when nothing was retrieved or the model said "I don't know"


# ===============================
# RETRIEVAL NODE (TOOL)
# ===============================
def retrieve_node(state: RAGState) -> RAGState:
    """Fetch knowledge-base chunks relevant to the current query.

    Returns a copy of the state in which ``retrieved_chunks`` is replaced
    by the search results; every other key is passed through unchanged.
    """
    hits = search_knowledge(state["query"])
    updated = dict(state)
    updated["retrieved_chunks"] = hits
    return updated


# ===============================
# ANSWER NODE
# ===============================
def answer_node(state: RAGState) -> RAGState:
    """Generate an answer grounded in the retrieved chunks.

    Builds a context-restricted prompt, calls Gemini, and derives two
    lightweight eval signals:

    * ``confidence`` — retrieval-count heuristic, ``min(1.0, n_chunks / 5)``.
    * ``answer_known`` — False when the model's reply contains "I don't know".

    Delegates to ``no_answer_node`` when retrieval returned nothing, and
    logs an eval record before returning the updated state.
    """
    if not state["retrieved_chunks"]:
        return no_answer_node(state)

    context = "\n\n".join(c["text"] for c in state["retrieved_chunks"])

    prompt = f"""
Answer using ONLY the context below.
If the answer is not present, say "I don't know".

Context:
{context}

Question:
{state["query"]}
"""

    model = genai.GenerativeModel(MODEL_NAME)
    resp = model.generate_content(prompt)
    # resp.text raises ValueError when the response was blocked by safety
    # filters or contains no candidates; treat that as "no usable answer"
    # instead of crashing the graph run.
    try:
        answer_text = resp.text
    except ValueError:
        answer_text = "I don't know"

    # Crude heuristic: more supporting chunks -> more confidence, capped at 1.0.
    confidence = min(1.0, len(state["retrieved_chunks"]) / 5)
    answer_known = "i don't know" not in answer_text.lower()

    log_eval(
        query=state["query"],
        retrieved_count=len(state["retrieved_chunks"]),
        confidence=confidence,
        answer_known=answer_known
    )

    return {
        **state,
        "answer": answer_text,
        "confidence": confidence,
        "answer_known": answer_known
    }


# ===============================
# NO ANSWER NODE
# ===============================
def no_answer_node(state: RAGState) -> RAGState:
    """Terminal node for queries the knowledge base cannot answer.

    Logs a zero-confidence eval record, then returns the state with a
    fixed refusal answer and the confidence flags zeroed out.
    """
    log_eval(
        query=state["query"],
        retrieved_count=0,
        confidence=0.0,
        answer_known=False
    )

    result = dict(state)
    result.update(
        answer="I don't know based on the provided documents.",
        confidence=0.0,
        answer_known=False,
    )
    return result


# ===============================
# GRAPH BUILDER
# ===============================
def build_rag_graph():
    """Compile the retrieval-augmented generation graph.

    Flow: retrieve -> (answer | no_answer) -> END.  Routing after the
    retrieve step is conditional on whether any chunks were found, so the
    ``no_answer`` node is reachable through the graph itself.  Previously
    it was registered but had no incoming edge — dead in the graph — and
    was only ever invoked because answer_node called it directly.
    """
    graph = StateGraph(RAGState)

    graph.add_node("retrieve", retrieve_node)
    graph.add_node("answer", answer_node)
    graph.add_node("no_answer", no_answer_node)

    graph.set_entry_point("retrieve")

    # Route empty retrievals straight to the refusal node; answer_node's
    # own empty-chunk guard remains as a harmless second line of defense.
    graph.add_conditional_edges(
        "retrieve",
        lambda state: "answer" if state["retrieved_chunks"] else "no_answer",
        {"answer": "answer", "no_answer": "no_answer"},
    )
    graph.add_edge("answer", END)
    graph.add_edge("no_answer", END)

    return graph.compile()