Spaces:
Running
Running
| from typing import TypedDict, List | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.documents import Document | |
| from core.models import get_llm | |
| from prompts import RAG_PROMPT, REFLECTION_PROMPT, REWRITE_PROMPT | |
| from langchain_core.output_parsers import StrOutputParser | |
# Shared state schema for the RAGAgent LangGraph. Each node method of
# RAGAgent returns a dict containing only the keys it changes.
GraphState = TypedDict(
    "GraphState",
    {
        # The user's original question; passed to the generator, reflector and rewriter prompts.
        "question": str,
        # Query handed to the retriever; starts as the question and is replaced by the rewriter.
        "current_query": str,
        # Latest answer text produced by the generator.
        "generation": str,
        # Documents returned by the retriever for current_query.
        "documents": List[Document],
        # Reflector verdict, normalised to "yes" or "no".
        "reflection_score": str,
        # Number of query rewrites performed so far.
        "iterations": int,
    },
)
class RAGAgent:
    """Self-correcting retrieval-augmented generation agent built on LangGraph.

    The compiled graph runs ``retriever -> generator -> reflector``. If the
    reflector's verdict is not "yes" and fewer than ``max_iterations`` query
    rewrites have happened, ``rewriter`` produces a new retrieval query and the
    loop restarts at ``retriever``; otherwise the graph ends.

    Node methods annotate ``state`` with the string ``"GraphState"`` (a forward
    reference), so the state type is resolved lazily rather than when the
    class body is executed.
    """

    # Default cap on query rewrites; __init__ overrides it per instance.
    max_iterations = 3

    # Added to a document's numeric ``page`` metadata when labelling citations.
    # NOTE(review): if the loader stores 0-based page indices, +1 would give
    # 1-based page numbers; confirm why the offset is 2.
    page_offset = 2

    def __init__(self, retriever, max_iterations: int = 3):
        """Wire up the retriever, the LLM and the compiled graph.

        Args:
            retriever: Object exposing ``invoke(query)`` whose result is stored
                as the ``documents`` state key.
            max_iterations: Maximum number of query rewrites before the graph
                stops even if the reflector still answers "no".

        Raises:
            ValueError: If ``max_iterations`` is negative.
        """
        if max_iterations < 0:
            raise ValueError("max_iterations must be >= 0")
        self.retriever = retriever
        self.max_iterations = max_iterations
        self.llm = get_llm()
        self.app = self.build_graph()

    # ------------------------------------------------------------------ helpers

    def _page_label(self, metadata) -> str:
        """Return the display page for a document's metadata, or "Unknown"."""
        page = metadata.get("page")
        if page is None:
            # A missing page must not masquerade as a real page number.
            return "Unknown"
        if isinstance(page, (int, float)):
            return str(page + self.page_offset)
        return str(page)

    def _format_context(self, docs, with_pages: bool) -> str:
        """Join documents into one prompt context string, one labelled chunk each.

        With ``with_pages`` each chunk is labelled
        ``[Document: <filename> | Page: <page>]``; otherwise
        ``[Source: <filename>]``. A missing filename renders as "Unknown".
        Chunks are separated by a blank line.
        """
        chunks = []
        for doc in docs:
            filename = doc.metadata.get("filename", "Unknown")
            if with_pages:
                label = f"[Document: {filename} | Page: {self._page_label(doc.metadata)}]"
            else:
                label = f"[Source: {filename}]"
            chunks.append(f"{label} {doc.page_content}")
        return "\n\n".join(chunks)

    @staticmethod
    def _parse_reflection(raw: str) -> str:
        """Normalise a free-text reflection verdict to "yes" or "no".

        The first standalone word "yes" or "no" (case-insensitive) decides, so
        substrings such as "eyes" or "yesterday" do not count as "yes", and a
        reply like "No, ... yes ..." resolves to "no". Text containing neither
        word is treated as "no".
        """
        import re

        match = re.search(r"\b(yes|no)\b", raw.lower())
        return match.group(1) if match else "no"

    # -------------------------------------------------------------------- nodes

    def retriever_node(self, state: "GraphState"):
        """Fetch documents for the current (possibly rewritten) query.

        Returns:
            Partial state update ``{"documents": <retriever result>}``.
        """
        return {"documents": self.retriever.invoke(state["current_query"])}

    def generator_node(self, state: "GraphState"):
        """Answer the original question from the retrieved documents.

        The context labels each document with its filename and page so the
        answer can cite them.

        Returns:
            Partial state update ``{"generation": <answer text>}``.
        """
        context = self._format_context(state["documents"], with_pages=True)
        chain = RAG_PROMPT | self.llm | StrOutputParser()
        response = chain.invoke({"context": context, "question": state["question"]})
        return {"generation": response}

    def reflector_node(self, state: "GraphState"):
        """Ask the LLM whether the generated answer is acceptable.

        Returns:
            Partial state update ``{"reflection_score": "yes" | "no"}``.
        """
        context = self._format_context(state["documents"], with_pages=False)
        chain = REFLECTION_PROMPT | self.llm | StrOutputParser()
        raw_verdict = chain.invoke(
            {
                "context": context,
                "question": state["question"],
                "generation": state["generation"],
            }
        )
        return {"reflection_score": self._parse_reflection(raw_verdict)}

    def rewriter_node(self, state: "GraphState"):
        """Ask the LLM for a better retrieval query after a rejected answer.

        The new query is stripped of surrounding whitespace. If nothing is
        left, the previous query is kept rather than retrieving with an empty
        string.

        Returns:
            Partial state update with the new ``current_query`` and
            ``iterations`` incremented by one.
        """
        previous_query = state["current_query"]
        chain = REWRITE_PROMPT | self.llm | StrOutputParser()
        raw_query = chain.invoke(
            {
                "question": state["question"],
                "previous_query": previous_query,
                "generation": state["generation"],
            }
        )
        new_query = raw_query.strip() or previous_query
        # Same default as decide_to_rewrite, so a state lacking the counter
        # does not raise KeyError here.
        return {"current_query": new_query, "iterations": state.get("iterations", 0) + 1}

    def decide_to_rewrite(self, state: "GraphState"):
        """Route after reflection: "end" on a "yes" verdict or when the rewrite
        budget (``max_iterations``) is used up, otherwise "rewrite"."""
        if state["reflection_score"] == "yes":
            return "end"
        if state.get("iterations", 0) >= self.max_iterations:
            return "end"
        return "rewrite"

    # -------------------------------------------------------------------- graph

    def build_graph(self):
        """Assemble and compile the retrieve/generate/reflect/rewrite graph."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retriever", self.retriever_node)
        workflow.add_node("generator", self.generator_node)
        workflow.add_node("reflector", self.reflector_node)
        workflow.add_node("rewriter", self.rewriter_node)
        workflow.set_entry_point("retriever")
        workflow.add_edge("retriever", "generator")
        workflow.add_edge("generator", "reflector")
        workflow.add_conditional_edges(
            "reflector",
            self.decide_to_rewrite,
            {
                "rewrite": "rewriter",
                "end": END,
            },
        )
        # Loop back so a rewritten query triggers a fresh retrieval.
        workflow.add_edge("rewriter", "retriever")
        return workflow.compile()

    def run(self, question: str, callback=None):
        """Run the graph on ``question`` and return the accumulated state.

        Each streamed item is consumed as ``{node_name: partial_update}``; the
        update is merged into a running copy of the state.

        Args:
            question: The user question; also the initial retrieval query.
            callback: Optional ``callback(node_name, state)`` invoked after each
                streamed node. ``state`` is the same dict object on every call
                and keeps being updated afterwards.

        Returns:
            The merged state dict after the graph finishes.
        """
        inputs = {
            "question": question,
            "current_query": question,
            "iterations": 0,
            "reflection_score": "no",
        }
        # Copy so the dict handed to app.stream is not mutated while streaming.
        final_state = dict(inputs)
        for output in self.app.stream(inputs):
            for node_name, update in output.items():
                if isinstance(update, dict):
                    final_state.update(update)
                if callback:
                    callback(node_name, final_state)
        return final_state

    def get_graph_image(self, file_path: str = None):
        """Render the compiled graph as a Mermaid PNG.

        Args:
            file_path: If given, the PNG bytes are also written to this path.

        Returns:
            The image bytes returned by ``draw_mermaid_png``.
        """
        img_bytes = self.app.get_graph().draw_mermaid_png()
        if file_path:
            with open(file_path, "wb") as f:
                f.write(img_bytes)
        return img_bytes