Spaces:
Sleeping
Sleeping
File size: 4,179 Bytes
ea8f8db ba7bcd3 ea8f8db 6882231 ea8f8db 6882231 ea8f8db 6882231 ea8f8db ba7bcd3 ea8f8db ba7bcd3 ea8f8db 6882231 ea8f8db ba7bcd3 ea8f8db ba7bcd3 ea8f8db 7bafc8f 6882231 ea8f8db 6882231 ea8f8db 6882231 ea8f8db 6882231 ea8f8db 6882231 ea8f8db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | from typing import TypedDict, List
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from core.models import get_llm
from prompts import RAG_PROMPT, REFLECTION_PROMPT, REWRITE_PROMPT
from langchain_core.output_parsers import StrOutputParser
class GraphState(TypedDict):
    """State shared by every node of the ``RAGAgent`` graph.

    Each node returns a partial dict containing a subset of these keys, and
    LangGraph folds that partial update into the running state.
    """

    # The user's original question; set once when a run starts.
    question: str
    # Query sent to the retriever; starts as ``question``, replaced by rewrites.
    current_query: str
    # Latest answer text produced by the generator node.
    generation: str
    # Documents returned by the retriever for ``current_query``.
    documents: List[Document]
    # Reflector verdict, normalised to "yes" or "no".
    reflection_score: str
    # Number of query rewrites performed so far.
    iterations: int
class RAGAgent:
    """Self-correcting RAG agent built as a LangGraph state machine.

    The compiled graph runs ``retriever -> generator -> reflector``. When the
    reflector does not approve the answer, ``rewriter`` produces a new search
    query and the loop restarts at ``retriever``. The loop ends as soon as the
    reflector answers "yes" or ``max_iterations`` rewrites have been made.
    """

    def __init__(self, retriever, max_iterations: int = 3):
        """Create the agent and compile its graph.

        Args:
            retriever: Object exposing ``invoke(query)`` that returns the
                documents for ``query`` (presumably a LangChain retriever).
            max_iterations: Maximum number of query rewrites before the loop
                stops and returns the latest generation. Defaults to 3, the
                limit that was previously hard-coded.
        """
        self.retriever = retriever
        self.max_iterations = max_iterations
        self.llm = get_llm()
        self.app = self.build_graph()

    def retriever_node(self, state: GraphState):
        """Fetch documents for the current (possibly rewritten) query."""
        docs = self.retriever.invoke(state["current_query"])
        return {"documents": docs}

    def generator_node(self, state: GraphState):
        """Answer the original question from the retrieved documents.

        Each document is prefixed with its filename and page so the LLM can
        cite sources. Returns ``{"generation": <answer text>}``.
        """
        # NOTE(review): the displayed page is ``metadata["page"] + 2``. Many PDF
        # loaders store 0-based page indices, which would normally need ``+ 1``;
        # confirm the +2 offset is intentional for this corpus.
        context = "\n\n".join(
            f"[Document: {doc.metadata.get('filename', 'Unknown')} | "
            f"Page: {doc.metadata.get('page', 0) + 2}] {doc.page_content}"
            for doc in state["documents"]
        )
        chain = RAG_PROMPT | self.llm | StrOutputParser()
        response = chain.invoke({"context": context, "question": state["question"]})
        return {"generation": response}

    def reflector_node(self, state: GraphState):
        """Ask the LLM whether the current generation is acceptable.

        The raw LLM verdict is normalised to ``"yes"`` or ``"no"`` and returned
        as ``{"reflection_score": ...}``.
        """
        import re

        context = "\n\n".join(
            f"[Source: {doc.metadata.get('filename', 'Unknown')}] {doc.page_content}"
            for doc in state["documents"]
        )
        chain = REFLECTION_PROMPT | self.llm | StrOutputParser()
        verdict = chain.invoke({
            "context": context,
            "question": state["question"],
            "generation": state["generation"],
        })
        # Match "yes" as a whole word: a plain substring test would also
        # accept replies containing e.g. "eyes" or "yesterday".
        approved = re.search(r"\byes\b", verdict, re.IGNORECASE) is not None
        return {"reflection_score": "yes" if approved else "no"}

    def rewriter_node(self, state: GraphState):
        """Produce a new search query after a rejected generation.

        Also increments ``iterations`` so ``decide_to_rewrite`` can bound the
        retrieve/generate/reflect loop.
        """
        previous_query = state["current_query"]
        chain = REWRITE_PROMPT | self.llm | StrOutputParser()
        new_query = chain.invoke({
            "question": state["question"],
            "previous_query": previous_query,
            "generation": state["generation"],
        }).strip()
        # An empty rewrite is not a usable search query; keep the previous
        # query rather than retrieving with an empty string.
        return {
            "current_query": new_query or previous_query,
            "iterations": state.get("iterations", 0) + 1,
        }

    def decide_to_rewrite(self, state: GraphState):
        """Route after reflection.

        Returns ``"end"`` when the reflector approved the answer or the rewrite
        budget (``max_iterations``) is exhausted, otherwise ``"rewrite"``.
        """
        if state["reflection_score"] == "yes":
            return "end"
        if state.get("iterations", 0) >= self.max_iterations:
            return "end"
        return "rewrite"

    def build_graph(self):
        """Wire the four nodes into a LangGraph ``StateGraph`` and compile it."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retriever", self.retriever_node)
        workflow.add_node("generator", self.generator_node)
        workflow.add_node("reflector", self.reflector_node)
        workflow.add_node("rewriter", self.rewriter_node)

        workflow.set_entry_point("retriever")
        workflow.add_edge("retriever", "generator")
        workflow.add_edge("generator", "reflector")
        # The label returned by decide_to_rewrite selects the next node.
        workflow.add_conditional_edges(
            "reflector",
            self.decide_to_rewrite,
            {
                "rewrite": "rewriter",
                "end": END,
            },
        )
        # A rewritten query goes back through retrieval.
        workflow.add_edge("rewriter", "retriever")
        return workflow.compile()

    def run(self, question: str, callback=None):
        """Run the graph to completion for ``question``.

        Args:
            question: The user's question; also used as the first search query.
            callback: Optional ``callback(node_name, state)`` called after every
                node. The same dict object is passed on each call and keeps
                being updated as the run progresses.

        Returns:
            dict: The merged state after the last node ran; ``generation``
            holds the final answer.
        """
        inputs = {
            "question": question,
            "current_query": question,
            "iterations": 0,
            "reflection_score": "no",
        }
        # Merge into a copy so the dict handed to ``stream`` is never mutated.
        final_state = dict(inputs)
        for output in self.app.stream(inputs):
            # Each streamed item maps a node name to the update that node returned.
            for node_name, update in output.items():
                if update:
                    final_state.update(update)
                if callback:
                    callback(node_name, final_state)
        return final_state

    def get_graph_image(self, file_path: str = None):
        """Render the compiled graph as a Mermaid PNG.

        Args:
            file_path: If given, the PNG bytes are also written to this path.

        Returns:
            bytes: The PNG image data.
        """
        # NOTE(review): depending on LangChain's default draw method this may
        # call an external Mermaid rendering service — confirm for offline use.
        img_bytes = self.app.get_graph().draw_mermaid_png()
        if file_path:
            with open(file_path, "wb") as f:
                f.write(img_bytes)
        return img_bytes
|