# Gemini-Rag-Fastapi-Pro / langgraph_rag.py
# lvvignesh2122's picture
# Remove auth, fix quota issues with retries, and update agent graph
# 88cc76a
# raw
# history blame
# 2.59 kB
from typing import TypedDict, List, Optional
import google.generativeai as genai
from langgraph.graph import StateGraph, END
from rag_store import search_knowledge
from eval_logger import log_eval
# Model identifier passed to genai.GenerativeModel when generating answers.
MODEL_NAME = "gemini-2.5-flash"
# ===============================
# STATE
# ===============================
class RAGState(TypedDict):
    """State dictionary handed from node to node in the RAG graph."""

    # Question being answered; forwarded to retrieval and embedded in the prompt.
    query: str
    # Search results stored by the retrieval node; each entry is read via its
    # "text" key when the prompt context is assembled.
    retrieved_chunks: List[dict]
    # Generated answer text, or the fixed fallback message when nothing is known.
    answer: Optional[str]
    # Score in [0.0, 1.0] derived from the number of retrieved chunks.
    confidence: float
    # False when the answer is an "I don't know" style reply.
    answer_known: bool
# ===============================
# RETRIEVAL NODE (TOOL)
# ===============================
def retrieve_node(state: RAGState) -> RAGState:
    """Look up knowledge for the current query and attach it to the state.

    Args:
        state: Current graph state; only ``state["query"]`` is read.

    Returns:
        A new dict with every key of ``state`` plus ``"retrieved_chunks"``
        set to whatever ``search_knowledge`` returned for the query.
    """
    hits = search_knowledge(state["query"])

    # Work on a shallow copy so the incoming state object is left untouched.
    next_state = dict(state)
    next_state["retrieved_chunks"] = hits
    return next_state
# ===============================
# ANSWER NODE
# ===============================
def answer_node(state: RAGState) -> RAGState:
    """Generate an answer for the query from the retrieved chunks.

    With no retrieved chunks the call is delegated to ``no_answer_node``.
    Otherwise the chunk texts are joined into a context, the model named by
    ``MODEL_NAME`` is prompted, and the outcome is logged through ``log_eval``.

    If the model response carries no usable text (``resp.text`` raises
    ``ValueError`` or the text is blank), the fixed "I don't know" fallback
    is returned with confidence 0.0 instead of crashing the node.

    Args:
        state: Current graph state; reads ``"query"`` and
            ``"retrieved_chunks"``.

    Returns:
        A new dict with every key of ``state`` plus ``"answer"``,
        ``"confidence"`` and ``"answer_known"``.

    Raises:
        KeyError: If a retrieved chunk has no ``"text"`` key.
    """
    chunks = state["retrieved_chunks"]
    if not chunks:
        return no_answer_node(state)

    context = "\n\n".join(c["text"] for c in chunks)
    prompt = (
        "\n"
        "Answer using ONLY the context below.\n"
        "If the answer is not present, say \"I don't know\".\n"
        "Context:\n"
        f"{context}\n"
        "Question:\n"
        f"{state['query']}\n"
    )

    model = genai.GenerativeModel(MODEL_NAME)
    resp = model.generate_content(prompt)

    # The ``text`` accessor raises ValueError when the response has no valid
    # text part (e.g. a blocked generation); treat that as "no answer".
    try:
        answer_text = resp.text
    except ValueError:
        answer_text = ""

    if answer_text and answer_text.strip():
        # Confidence is a simple heuristic: five or more chunks count as 1.0.
        confidence = min(1.0, len(chunks) / 5)
        # Fold the typographic apostrophe (U+2019) to ASCII so a reply such as
        # "I don\u2019t know" is still recognised as a non-answer.
        normalized = answer_text.lower().replace("\u2019", "'")
        answer_known = "i don't know" not in normalized
    else:
        answer_text = "I don't know based on the provided documents."
        confidence = 0.0
        answer_known = False

    log_eval(
        query=state["query"],
        retrieved_count=len(chunks),
        confidence=confidence,
        answer_known=answer_known,
    )

    return {
        **state,
        "answer": answer_text,
        "confidence": confidence,
        "answer_known": answer_known,
    }
# ===============================
# NO ANSWER NODE
# ===============================
def no_answer_node(state: RAGState) -> RAGState:
    """Produce the fixed "I don't know" result and log it.

    Args:
        state: Current graph state; only ``state["query"]`` is read.

    Returns:
        A new dict with every key of ``state`` plus the fallback answer,
        a confidence of 0.0 and ``answer_known`` set to False.
    """
    fallback = {
        "answer": "I don't know based on the provided documents.",
        "confidence": 0.0,
        "answer_known": False,
    }

    # Always logged with a retrieved count of zero.
    log_eval(
        query=state["query"],
        retrieved_count=0,
        confidence=fallback["confidence"],
        answer_known=fallback["answer_known"],
    )

    return {**state, **fallback}
# ===============================
# GRAPH BUILDER
# ===============================
def build_rag_graph():
    """Assemble and compile the RAG state graph.

    Flow: ``retrieve`` runs first; when it found chunks the graph continues
    to ``answer``, otherwise to ``no_answer``. Both end the run.

    The ``no_answer`` node previously had an outgoing edge but no incoming
    one, so it could never execute as a graph step. A conditional edge from
    ``retrieve`` now makes it reachable. The final state for an empty
    retrieval is unchanged: it is still the output of ``no_answer_node``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """

    def route_after_retrieve(state: RAGState) -> str:
        """Pick the next node name based on whether any chunks were found."""
        return "answer" if state.get("retrieved_chunks") else "no_answer"

    graph = StateGraph(RAGState)

    graph.add_node("retrieve", retrieve_node)
    graph.add_node("answer", answer_node)
    graph.add_node("no_answer", no_answer_node)

    graph.set_entry_point("retrieve")

    # Explicit path map so both targets are known when the graph is built.
    graph.add_conditional_edges(
        "retrieve",
        route_after_retrieve,
        {"answer": "answer", "no_answer": "no_answer"},
    )
    graph.add_edge("answer", END)
    graph.add_edge("no_answer", END)

    return graph.compile()