# agent-project/core/graph.py
# ego — "update prompt" (commit 6882231)
import re
from typing import List, Literal, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import END, StateGraph

from core.models import get_llm
from prompts import RAG_PROMPT, REFLECTION_PROMPT, REWRITE_PROMPT
class GraphState(TypedDict):
    """State shared by the nodes of the RAG agent graph.

    Each node returns a partial dict that LangGraph merges into this state.
    No field declares a reducer, so every key is last-write-wins.
    """

    # The user's original question; no node writes this key.
    question: str
    # Query sent to the retriever. Starts equal to ``question`` and is
    # replaced by the rewriter node on each retry.
    current_query: str
    # Latest answer produced by the generator node.
    generation: str
    # Documents returned by the retriever for ``current_query``.
    documents: List[Document]
    # Reflector verdict, normalized to exactly "yes" or "no". Presumably "yes"
    # means the answer was judged acceptable -- confirm against REFLECTION_PROMPT.
    reflection_score: Literal["yes", "no"]
    # Number of query rewrites performed so far.
    iterations: int
class RAGAgent:
    """Self-correcting retrieval-augmented generation (RAG) agent.

    The compiled LangGraph workflow runs ``retriever -> generator ->
    reflector``. If the reflector's verdict is "no" and fewer than
    ``max_iterations`` rewrites have happened, ``rewriter`` produces a new
    search query and control returns to ``retriever``. Otherwise the run
    ends.
    """

    # Maximum number of query rewrites per run. This class-level default keeps
    # instances that bypass ``__init__`` working; ``__init__`` overrides it.
    max_iterations = 3

    # Added to a document's ``page`` metadata before it is shown to the LLM.
    # NOTE(review): kept at the original +2. If the loader stores 0-based page
    # indexes, +1 would give the printed page number -- confirm the intent.
    PAGE_OFFSET = 2

    # Reflection output that *starts* with a verdict, optionally after
    # punctuation such as quotes or markdown asterisks.
    _LEADING_VERDICT = re.compile(r"\W*(yes|no)\b")
    # The standalone word "yes" anywhere in the reflection output.
    _YES_WORD = re.compile(r"\byes\b")

    # LangGraph's documented default recursion limit (super-steps per run).
    _DEFAULT_RECURSION_LIMIT = 25

    def __init__(self, retriever, max_iterations: int = 3):
        """Build the agent and compile its graph.

        Args:
            retriever: Object with an ``invoke(query)`` method that returns
                the documents for ``query`` (e.g. a LangChain retriever).
            max_iterations: Maximum number of query rewrites before the loop
                stops, even if the reflector still answers "no".
        """
        self.retriever = retriever
        self.max_iterations = max_iterations
        self.llm = get_llm()
        self.app = self.build_graph()

    @staticmethod
    def _source_name(doc) -> str:
        """Return the ``filename`` metadata of *doc*, or "Unknown"."""
        return doc.metadata.get("filename", "Unknown")

    @classmethod
    def _page_label(cls, doc) -> str:
        """Return the page number to display for *doc*.

        ``PAGE_OFFSET`` is added to an integer-like ``page`` metadata value.
        A missing, ``None`` or non-numeric page is shown as "Unknown" rather
        than as a fabricated page number.
        """
        page = doc.metadata.get("page")
        if page is None:
            return "Unknown"
        try:
            return str(int(page) + cls.PAGE_OFFSET)
        except (TypeError, ValueError):
            return "Unknown"

    @classmethod
    def _normalize_score(cls, raw: str) -> str:
        """Reduce free-form reflection output to "yes" or "no".

        A verdict at the very start of the reply wins, so "No, ... yes" counts
        as "no". Otherwise the reply is "yes" only if it contains the
        standalone word "yes". Unlike a plain substring test, words such as
        "eyes" or "yesterday" do not count as approval.
        """
        text = (raw or "").strip().lower()
        leading = cls._LEADING_VERDICT.match(text)
        if leading:
            return leading.group(1)
        return "yes" if cls._YES_WORD.search(text) else "no"

    def retriever_node(self, state: "GraphState") -> dict:
        """Retrieve documents for ``current_query``.

        Returns:
            ``{"documents": <retriever result>}``.
        """
        docs = self.retriever.invoke(state["current_query"])
        return {"documents": docs}

    def generator_node(self, state: "GraphState") -> dict:
        """Answer ``question`` using the retrieved documents as context.

        Each document is tagged with its filename and page so the model can
        cite its sources.

        Returns:
            ``{"generation": <answer text>}``.
        """
        context = "\n\n".join(
            f"[Document: {self._source_name(doc)} | Page: {self._page_label(doc)}] "
            f"{doc.page_content}"
            for doc in state["documents"]
        )
        chain = RAG_PROMPT | self.llm | StrOutputParser()
        response = chain.invoke({"context": context, "question": state["question"]})
        return {"generation": response}

    def reflector_node(self, state: "GraphState") -> dict:
        """Ask the LLM to judge the latest generation against its sources.

        Returns:
            ``{"reflection_score": "yes" | "no"}``.
        """
        context = "\n\n".join(
            f"[Source: {self._source_name(doc)}] {doc.page_content}"
            for doc in state["documents"]
        )
        chain = REFLECTION_PROMPT | self.llm | StrOutputParser()
        verdict = chain.invoke({
            "context": context,
            "question": state["question"],
            "generation": state["generation"],
        })
        return {"reflection_score": self._normalize_score(verdict)}

    def rewriter_node(self, state: "GraphState") -> dict:
        """Produce a new retrieval query after a rejected generation.

        Surrounding whitespace is stripped from the LLM's reply. An empty
        reply falls back to the previous query instead of searching for "".

        Returns:
            ``{"current_query": <new query>, "iterations": <count + 1>}``.
        """
        previous_query = state["current_query"]
        chain = REWRITE_PROMPT | self.llm | StrOutputParser()
        new_query = chain.invoke({
            "question": state["question"],
            "previous_query": previous_query,
            "generation": state["generation"],
        }).strip()
        return {
            "current_query": new_query or previous_query,
            # .get mirrors decide_to_rewrite, which tolerates a missing count.
            "iterations": state.get("iterations", 0) + 1,
        }

    def decide_to_rewrite(self, state: "GraphState") -> str:
        """Route after reflection.

        Returns:
            "end" when the verdict is "yes" or ``max_iterations`` rewrites
            have been done, otherwise "rewrite".
        """
        if state["reflection_score"] == "yes":
            return "end"
        if state.get("iterations", 0) >= self.max_iterations:
            return "end"
        return "rewrite"

    def build_graph(self):
        """Assemble and compile the retrieve/generate/reflect/rewrite loop."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retriever", self.retriever_node)
        workflow.add_node("generator", self.generator_node)
        workflow.add_node("reflector", self.reflector_node)
        workflow.add_node("rewriter", self.rewriter_node)

        workflow.set_entry_point("retriever")
        workflow.add_edge("retriever", "generator")
        workflow.add_edge("generator", "reflector")
        workflow.add_conditional_edges(
            "reflector",
            self.decide_to_rewrite,
            {
                "rewrite": "rewriter",
                "end": END,
            },
        )
        workflow.add_edge("rewriter", "retriever")
        return workflow.compile()

    def _recursion_limit(self) -> int:
        """Return a LangGraph recursion limit large enough for ``max_iterations``.

        A run takes 3 steps (retriever, generator, reflector), plus 4 steps per
        rewrite (rewriter, retriever, generator, reflector), plus one step of
        slack. The default limit is kept whenever it already suffices, so the
        default configuration behaves as before.
        """
        needed = 3 + 4 * max(self.max_iterations, 0) + 1
        return max(self._DEFAULT_RECURSION_LIMIT, needed)

    def run(self, question: str, callback=None) -> dict:
        """Answer *question*, streaming per-node updates.

        Args:
            question: The user's question; also the first retrieval query.
            callback: Optional ``callback(node_name, state)``, invoked after
                each node finishes. ``state`` is the same dict on every call
                and keeps being updated, so copy it if you need a snapshot.

        Returns:
            The accumulated state after the graph finishes.
        """
        state = {
            "question": question,
            "current_query": question,
            "iterations": 0,
            "reflection_score": "no",
        }
        config = {"recursion_limit": self._recursion_limit()}
        # Give LangGraph its own copy so updating ``state`` below never
        # mutates the object passed in as graph input.
        for output in self.app.stream(dict(state), config=config):
            for node_name, update in output.items():
                state.update(update or {})
                if callback:
                    callback(node_name, state)
        return state

    def get_graph_image(self, file_path: Optional[str] = None) -> bytes:
        """Render the compiled graph as a Mermaid PNG.

        Args:
            file_path: If given, the PNG bytes are also written to this path.

        Returns:
            The PNG image bytes.

        NOTE(review): LangChain's ``draw_mermaid_png`` renders through the
        mermaid.ink web API by default, so this may need network access.
        """
        img_bytes = self.app.get_graph().draw_mermaid_png()
        if file_path:
            with open(file_path, "wb") as f:
                f.write(img_bytes)
        return img_bytes