import re
from typing import Callable, List, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import END, StateGraph

from core.models import get_llm
from prompts import RAG_PROMPT, REFLECTION_PROMPT, REWRITE_PROMPT

# Default cap on query rewrites before the graph gives up and returns.
DEFAULT_MAX_ITERATIONS = 3

# Matches "yes" only as a standalone word, so that words such as "eyes" or
# "yesterday" in the grader's reply are not mistaken for an approval.
_YES_PATTERN = re.compile(r"(?<![a-z])yes(?![a-z])", re.IGNORECASE)


class GraphState(TypedDict):
    """State dictionary passed between the nodes of the RAG graph."""

    question: str  # the user's original question; never rewritten
    current_query: str  # query sent to the retriever; replaced by the rewriter
    generation: str  # latest answer produced by the generator
    documents: List[Document]  # documents returned by the latest retrieval
    reflection_score: str  # "yes" or "no", as normalized by the reflector
    iterations: int  # number of query rewrites performed so far


class RAGAgent:
    """Retrieve -> generate -> reflect loop with query rewriting on failure.

    The graph retrieves documents for the current query, generates an answer,
    asks the LLM to grade that answer, and either finishes or rewrites the
    query and starts over, up to ``max_iterations`` rewrites.
    """

    def __init__(self, retriever, max_iterations: int = DEFAULT_MAX_ITERATIONS):
        """Store the retriever, create the LLM and compile the graph.

        Args:
            retriever: Object exposing ``invoke(query)`` that returns the
                documents for a query.
            max_iterations: Number of rewrites after which the graph stops
                even if the reflection score is still "no".
        """
        self.retriever = retriever
        self.max_iterations = max_iterations
        self.llm = get_llm()
        self.app = self.build_graph()

    @staticmethod
    def _page_label(doc: Document):
        """Return the human-readable page label for a document.

        NOTE(review): assumes the loader stores a zero-based page index in
        ``metadata["page"]`` (as PyPDFLoader does), hence ``+ 1``. The
        previous ``+ 2`` offset shifted every citation by one page; confirm
        against the ingestion pipeline.
        """
        page = doc.metadata.get("page", 0)
        if isinstance(page, int):
            return page + 1
        # Non-integer page metadata is shown as-is instead of raising.
        return page

    def retriever_node(self, state: GraphState):
        """Fetch documents for ``state["current_query"]``."""
        query = state["current_query"]
        docs = self.retriever.invoke(query)
        return {"documents": docs}

    def generator_node(self, state: GraphState):
        """Generate an answer to the original question from the documents."""
        question = state["question"]
        docs = state["documents"]
        context = "\n\n".join(
            f"[Document: {doc.metadata.get('filename', 'Unknown')} | "
            f"Page: {self._page_label(doc)}] {doc.page_content}"
            for doc in docs
        )
        chain = RAG_PROMPT | self.llm | StrOutputParser()
        response = chain.invoke({"context": context, "question": question})
        return {"generation": response}

    def reflector_node(self, state: GraphState):
        """Ask the LLM to grade the generation; normalize to "yes"/"no"."""
        question = state["question"]
        generation = state["generation"]
        docs = state["documents"]
        context = "\n\n".join(
            f"[Source: {doc.metadata.get('filename', 'Unknown')}] {doc.page_content}"
            for doc in docs
        )
        chain = REFLECTION_PROMPT | self.llm | StrOutputParser()
        score = chain.invoke(
            {
                "context": context,
                "question": question,
                "generation": generation,
            }
        )
        # Anything that is not an explicit standalone "yes" counts as "no".
        normalized_score = "yes" if _YES_PATTERN.search(score) else "no"
        return {"reflection_score": normalized_score}

    def rewriter_node(self, state: GraphState):
        """Produce a new retrieval query and bump the iteration counter."""
        question = state["question"]
        previous_query = state["current_query"]
        failed_gen = state["generation"]
        chain = REWRITE_PROMPT | self.llm | StrOutputParser()
        new_query = chain.invoke(
            {
                "question": question,
                "previous_query": previous_query,
                "generation": failed_gen,
            }
        )
        # Same default as decide_to_rewrite, so a state without an
        # "iterations" key does not raise KeyError here.
        return {
            "current_query": new_query,
            "iterations": state.get("iterations", 0) + 1,
        }

    def decide_to_rewrite(self, state: GraphState):
        """Return "end" if the answer passed or the cap is hit, else "rewrite"."""
        score = state["reflection_score"]
        iterations = state.get("iterations", 0)
        if score == "yes" or iterations >= self.max_iterations:
            return "end"
        return "rewrite"

    def build_graph(self):
        """Wire the nodes into a StateGraph and return the compiled graph."""
        workflow = StateGraph(GraphState)

        workflow.add_node("retriever", self.retriever_node)
        workflow.add_node("generator", self.generator_node)
        workflow.add_node("reflector", self.reflector_node)
        workflow.add_node("rewriter", self.rewriter_node)

        workflow.set_entry_point("retriever")
        workflow.add_edge("retriever", "generator")
        workflow.add_edge("generator", "reflector")
        workflow.add_conditional_edges(
            "reflector",
            self.decide_to_rewrite,
            {
                "rewrite": "rewriter",
                "end": END,
            },
        )
        # A rewritten query goes back through retrieval.
        workflow.add_edge("rewriter", "retriever")

        return workflow.compile()

    def run(self, question: str, callback: Optional[Callable] = None):
        """Stream the graph for ``question`` and return the accumulated state.

        Args:
            question: The user's question; also used as the first query.
            callback: Optional callable invoked as ``callback(node_name,
                state)`` after each node update, where ``state`` is the
                accumulated state dictionary so far.

        Returns:
            The dictionary of initial inputs merged with every node update.
        """
        inputs = {
            "question": question,
            "current_query": question,
            "iterations": 0,
            "reflection_score": "no",
        }
        # Accumulate into a copy so the dict handed to the graph is not
        # mutated while it is being streamed.
        final_state = dict(inputs)
        for output in self.app.stream(inputs):
            for key, value in output.items():
                if value:
                    final_state.update(value)
                if callback:
                    callback(key, final_state)
        return final_state

    def get_graph_image(self, file_path: Optional[str] = None):
        """Render the compiled graph as a Mermaid PNG and return the bytes.

        Args:
            file_path: If given (and non-empty), the PNG bytes are also
                written to this path.
        """
        img_bytes = self.app.get_graph().draw_mermaid_png()
        if file_path:
            with open(file_path, "wb") as f:
                f.write(img_bytes)
        return img_bytes