File size: 2,588 Bytes
88cc76a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from typing import TypedDict, List, Optional
import google.generativeai as genai
from langgraph.graph import StateGraph, END
from rag_store import search_knowledge
from eval_logger import log_eval
# Gemini model identifier passed to genai.GenerativeModel by answer_node.
MODEL_NAME: str = "gemini-2.5-flash"
# ===============================
# STATE
# ===============================
class RAGState(TypedDict):
    """State dictionary threaded through the nodes of the RAG graph.

    Each node receives the current state and returns a copy with the
    fields it is responsible for filled in.
    """

    # The user's question; used for retrieval and embedded in the prompt.
    query: str
    # Chunks returned by search_knowledge; answer_node reads chunk["text"].
    retrieved_chunks: List[dict]
    # Answer text written by the answer / no-answer node.
    answer: Optional[str]
    # Heuristic score in [0.0, 1.0]; 0.0 on the no-answer path.
    confidence: float
    # False when nothing was retrieved or the model said it doesn't know.
    answer_known: bool
# ===============================
# RETRIEVAL NODE (TOOL)
# ===============================
def retrieve_node(state: RAGState) -> RAGState:
    """Look up knowledge-base chunks for the state's query.

    Args:
        state: Current graph state; only ``state["query"]`` is read.

    Returns:
        A shallow copy of ``state`` whose ``retrieved_chunks`` holds whatever
        ``search_knowledge`` returned for the query.
    """
    hits = search_knowledge(state["query"])
    updated = dict(state)
    updated["retrieved_chunks"] = hits
    return updated
# ===============================
# ANSWER NODE
# ===============================
# Answer stored when the model reply carries no usable text.
_NO_TEXT_ANSWER = "I don't know based on the provided documents."
# Lower-cased phrase the prompt asks the model to use when it can't answer.
_UNKNOWN_PHRASE = "i don't know"


def _build_answer_prompt(context: str, query: str) -> str:
    """Return the context-grounded prompt sent to the model."""
    return (
        "\n"
        "Answer using ONLY the context below.\n"
        "If the answer is not present, say \"I don't know\".\n"
        "Context:\n"
        f"{context}\n"
        "Question:\n"
        f"{query}\n"
    )


def _response_text(resp) -> Optional[str]:
    """Return the reply text, or None when there is no usable text.

    The google-generativeai ``response.text`` accessor raises ValueError
    when the response has no valid text part (e.g. a blocked candidate),
    which would otherwise abort the whole graph run.
    """
    try:
        text = resp.text
    except ValueError:
        return None
    return text if text.strip() else None


def _admits_unknown(answer_text: str) -> bool:
    """Return True if the answer contains the "I don't know" phrase.

    Typographic apostrophes (U+2019), which models often emit, are
    normalised to ASCII before matching.
    """
    normalized = answer_text.replace("\u2019", "'").lower()
    return _UNKNOWN_PHRASE in normalized


def answer_node(state: RAGState) -> RAGState:
    """Generate an answer from the retrieved chunks with the Gemini model.

    Delegates to ``no_answer_node`` when retrieval returned nothing.
    Replies that say "I don't know" (with either apostrophe style), and
    replies that carry no usable text, are recorded with
    ``answer_known=False``. Each call is logged via ``log_eval``.

    Args:
        state: Graph state; reads ``query`` and ``retrieved_chunks`` (each
            chunk must have a ``"text"`` key).

    Returns:
        A copy of ``state`` with ``answer``, ``confidence`` and
        ``answer_known`` set.
    """
    chunks = state["retrieved_chunks"]
    if not chunks:
        return no_answer_node(state)

    context = "\n\n".join(c["text"] for c in chunks)
    prompt = _build_answer_prompt(context, state["query"])
    model = genai.GenerativeModel(MODEL_NAME)
    resp = model.generate_content(prompt)

    answer_text = _response_text(resp)
    if answer_text is None:
        # Blocked or empty reply: surface a fallback instead of crashing.
        answer_text = _NO_TEXT_ANSWER
    answer_known = not _admits_unknown(answer_text)
    # Heuristic: confidence grows with chunk count, saturating at 5 chunks.
    confidence = min(1.0, len(chunks) / 5)

    log_eval(
        query=state["query"],
        retrieved_count=len(chunks),
        confidence=confidence,
        answer_known=answer_known,
    )
    return {
        **state,
        "answer": answer_text,
        "confidence": confidence,
        "answer_known": answer_known,
    }
# ===============================
# NO ANSWER NODE
# ===============================
def no_answer_node(state: RAGState) -> RAGState:
    """Produce the fixed "don't know" answer used when nothing was retrieved.

    Logs an evaluation record with zero retrieved chunks and zero
    confidence, then returns the state with the fallback answer filled in.

    Args:
        state: Graph state; only ``state["query"]`` is read.

    Returns:
        A copy of ``state`` with ``answer``, ``confidence`` (0.0) and
        ``answer_known`` (False) set.
    """
    fallback = {
        "answer": "I don't know based on the provided documents.",
        "confidence": 0.0,
        "answer_known": False,
    }
    log_eval(
        query=state["query"],
        retrieved_count=0,
        confidence=fallback["confidence"],
        answer_known=fallback["answer_known"],
    )
    return {**state, **fallback}
# ===============================
# GRAPH BUILDER
# ===============================
def _route_after_retrieve(state: RAGState) -> str:
    """Return the next node name: "answer" if chunks were found, else "no_answer"."""
    return "answer" if state.get("retrieved_chunks") else "no_answer"


def build_rag_graph():
    """Build and compile the retrieve -> (answer | no_answer) pipeline.

    Control flow:
      * "retrieve" always runs first.
      * With chunks, "retrieve" routes to "answer"; without chunks, it
        routes to "no_answer".
      * Both "answer" and "no_answer" lead to END.

    Returns:
        The compiled LangGraph graph, ready for ``invoke`` with an initial
        state containing at least ``query``.
    """
    graph = StateGraph(RAGState)
    graph.add_node("retrieve", retrieve_node)
    graph.add_node("answer", answer_node)
    graph.add_node("no_answer", no_answer_node)

    graph.set_entry_point("retrieve")
    # Route on retrieval results so empty retrievals go through the registered
    # "no_answer" node; a plain retrieve->answer edge leaves it unreachable.
    graph.add_conditional_edges(
        "retrieve",
        _route_after_retrieve,
        {"answer": "answer", "no_answer": "no_answer"},
    )
    graph.add_edge("answer", END)
    graph.add_edge("no_answer", END)
    return graph.compile()
|