| from typing import TypedDict, List, Optional |
| import google.generativeai as genai |
| from langgraph.graph import StateGraph, END |
|
|
| from rag_store import search_knowledge |
| from eval_logger import log_eval |
|
|
# Gemini model identifier handed to genai.GenerativeModel when answering queries.
MODEL_NAME: str = "gemini-2.5-flash"
|
|
|
|
| |
| |
| |
class RAGState(TypedDict):
    """State dictionary passed between the nodes of the RAG graph.

    ``retrieve_node`` fills ``retrieved_chunks``; ``answer_node`` and
    ``no_answer_node`` fill ``answer``, ``confidence`` and ``answer_known``.
    """

    # The user's question.
    query: str
    # Chunks returned by search_knowledge; each is read via its "text" key.
    retrieved_chunks: List[dict]
    # Final answer text, or None before an answering node has run.
    answer: Optional[str]
    # Retrieval-coverage score in [0.0, 1.0] set by the answering nodes.
    confidence: float
    # False when the answer is an "I don't know" style response.
    answer_known: bool
|
|
|
|
| |
| |
| |
def retrieve_node(state: RAGState) -> RAGState:
    """Fetch knowledge-base chunks relevant to ``state["query"]``.

    Returns a shallow copy of *state* whose ``retrieved_chunks`` entry holds
    whatever ``search_knowledge`` returned; every other key is carried over.
    """
    hits = search_knowledge(state["query"])
    updated = dict(state)
    updated["retrieved_chunks"] = hits
    return updated
|
|
|
|
| |
| |
| |
# Retrieved-chunk count at which the retrieval-based confidence score saturates at 1.0.
_FULL_CONFIDENCE_CHUNKS = 5

# Phrase the prompt tells the model to emit when the context lacks the answer
# (matched case-insensitively, after normalizing typographic apostrophes).
_UNKNOWN_MARKER = "i don't know"

# Answer substituted when the model response carries no usable text.
_EMPTY_RESPONSE_ANSWER = "I don't know based on the provided documents."


def _build_prompt(query: str, chunks: List[dict]) -> str:
    """Return the grounded-QA prompt: instructions, joined chunk texts, then the query."""
    context = "\n\n".join(c["text"] for c in chunks)
    return (
        "\n"
        "Answer using ONLY the context below.\n"
        'If the answer is not present, say "I don\'t know".\n'
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        "Question:\n"
        f"{query}\n"
    )


def _response_text(resp) -> str:
    """Return the text of a ``generate_content`` response, or "" if it has none.

    google-generativeai's ``.text`` accessor raises ValueError when the response
    has no valid Part (e.g. the candidate was blocked), so that case is treated
    as an empty answer instead of crashing the graph.
    """
    try:
        text = resp.text
    except ValueError:
        return ""
    return text or ""


def _is_unknown_answer(text: str) -> bool:
    """Return True if *text* is blank or contains the "I don't know" marker."""
    # Models often emit a typographic apostrophe (U+2019); normalize before matching.
    normalized = text.replace("\u2019", "'").lower()
    return not normalized.strip() or _UNKNOWN_MARKER in normalized


def answer_node(state: RAGState) -> RAGState:
    """Answer ``state["query"]`` from the retrieved chunks using ``MODEL_NAME``.

    Delegates to ``no_answer_node`` when ``retrieved_chunks`` is missing or
    empty. Otherwise prompts the model with the chunk texts as context, logs an
    evaluation record through ``log_eval`` and returns a copy of *state* with:

    * ``answer`` - the model's text, or a fixed "I don't know" message when the
      response carried no text;
    * ``confidence`` - retrieval coverage, ``min(1.0, n_chunks / 5)``; it does
      not measure the model's own certainty;
    * ``answer_known`` - False when the answer contains "I don't know".

    Errors raised by ``generate_content`` itself are not caught here.
    """
    chunks = state.get("retrieved_chunks") or []
    if not chunks:
        return no_answer_node(state)

    model = genai.GenerativeModel(MODEL_NAME)
    resp = model.generate_content(_build_prompt(state["query"], chunks))

    answer_text = _response_text(resp)
    if not answer_text.strip():
        answer_text = _EMPTY_RESPONSE_ANSWER

    confidence = min(1.0, len(chunks) / _FULL_CONFIDENCE_CHUNKS)
    answer_known = not _is_unknown_answer(answer_text)

    log_eval(
        query=state["query"],
        retrieved_count=len(chunks),
        confidence=confidence,
        answer_known=answer_known,
    )

    return {
        **state,
        "answer": answer_text,
        "confidence": confidence,
        "answer_known": answer_known,
    }
|
|
|
|
| |
| |
| |
def no_answer_node(state: RAGState) -> RAGState:
    """Produce the canned "I don't know" result when nothing was retrieved.

    Logs an evaluation record with zero retrieved chunks, zero confidence and
    ``answer_known=False``, then returns a copy of *state* carrying those same
    answer fields.
    """
    fallback = {
        "answer": "I don't know based on the provided documents.",
        "confidence": 0.0,
        "answer_known": False,
    }

    log_eval(
        query=state["query"],
        retrieved_count=0,
        confidence=fallback["confidence"],
        answer_known=fallback["answer_known"],
    )

    return {**state, **fallback}
|
|
|
|
| |
| |
| |
def _route_after_retrieve(state: RAGState) -> str:
    """Choose the node after retrieval: "answer" if any chunks came back, else "no_answer"."""
    return "answer" if state.get("retrieved_chunks") else "no_answer"


def build_rag_graph():
    """Build and compile the retrieve -> (answer | no_answer) -> END graph.

    After ``retrieve`` runs, ``_route_after_retrieve`` sends the state to
    ``no_answer`` when no chunks were retrieved and to ``answer`` otherwise.
    Both branches end at END.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(RAGState)

    graph.add_node("retrieve", retrieve_node)
    graph.add_node("answer", answer_node)
    graph.add_node("no_answer", no_answer_node)

    graph.set_entry_point("retrieve")

    # Route empty retrievals to the dedicated no_answer node; without an
    # incoming edge that node would be unreachable.
    graph.add_conditional_edges(
        "retrieve",
        _route_after_retrieve,
        {"answer": "answer", "no_answer": "no_answer"},
    )
    graph.add_edge("answer", END)
    graph.add_edge("no_answer", END)

    return graph.compile()
|
|