File size: 4,179 Bytes
ea8f8db
 
 
 
ba7bcd3
ea8f8db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6882231
ea8f8db
 
 
 
6882231
ea8f8db
 
 
 
 
 
 
 
 
6882231
ea8f8db
 
ba7bcd3
 
 
ea8f8db
 
ba7bcd3
 
 
 
 
ea8f8db
 
 
 
6882231
ea8f8db
ba7bcd3
 
ea8f8db
 
ba7bcd3
 
 
 
 
ea8f8db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bafc8f
6882231
 
 
 
ea8f8db
6882231
 
 
ea8f8db
 
6882231
ea8f8db
 
6882231
ea8f8db
 
 
6882231
ea8f8db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import TypedDict, List
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from core.models import get_llm
from prompts import RAG_PROMPT, REFLECTION_PROMPT, REWRITE_PROMPT
from langchain_core.output_parsers import StrOutputParser

class GraphState(TypedDict):
    """State dictionary passed between the nodes of the RAGAgent graph.

    Each node returns a partial dict containing only the keys it changes.
    """

    # Original user question; RAGAgent.run seeds it and no node rewrites it.
    question: str
    # Query sent to the retriever: starts as the question and is replaced by
    # rewriter_node after a rejected answer.
    current_query: str
    # Latest answer text produced by generator_node.
    generation: str
    # Documents returned by retriever_node for current_query.
    documents: List[Document]
    # Verdict from reflector_node, normalised to "yes" or "no".
    reflection_score: str
    # Number of query rewrites so far; incremented by rewriter_node.
    iterations: int

class RAGAgent:
    """Self-correcting RAG pipeline implemented as a LangGraph state graph.

    Flow: retriever -> generator -> reflector. When the reflector does not
    accept the answer, the rewriter produces a new search query and the flow
    loops back to the retriever, up to ``MAX_ITERATIONS`` rewrites.
    """

    # Number of query rewrites after which the loop stops even if the
    # reflector still rejects the answer.
    MAX_ITERATIONS = 3

    def __init__(self, retriever):
        """Store the retriever, obtain the LLM and compile the graph.

        Args:
            retriever: Object whose ``invoke(query)`` returns the documents
                for a query string.
        """
        self.retriever = retriever
        self.llm = get_llm()
        self.app = self.build_graph()

    @staticmethod
    def _is_affirmative(verdict: str) -> bool:
        """Return True when *verdict* contains "yes" as a standalone word.

        Punctuation is treated as a separator, so ``"Yes."`` and
        ``'{"score": "yes"}'`` match while ``"eyes"`` and ``"yesterday"``
        do not.
        """
        words = "".join(
            ch if ch.isalnum() else " " for ch in verdict.lower()
        ).split()
        return "yes" in words

    def retriever_node(self, state: GraphState):
        """Fetch documents for ``state["current_query"]``.

        Returns:
            Partial state update ``{"documents": ...}``.
        """
        query = state["current_query"]
        docs = self.retriever.invoke(query)
        return {"documents": docs}

    def generator_node(self, state: GraphState):
        """Answer the original question from the retrieved documents.

        Returns:
            Partial state update ``{"generation": <answer text>}``.
        """
        question = state["question"]
        docs = state["documents"]

        # NOTE(review): the page label is metadata "page" + 2, so a document
        # without a page renders as "Page: 2". A 0-based page index would
        # normally need +1 -- confirm the offset is intentional.
        context = "\n\n".join(
            f"[Document: {doc.metadata.get('filename', 'Unknown')} | Page: {doc.metadata.get('page', 0) + 2}] {doc.page_content}"
            for doc in docs
        )

        chain = RAG_PROMPT | self.llm | StrOutputParser()
        response = chain.invoke({"context": context, "question": question})
        return {"generation": response}

    def reflector_node(self, state: GraphState):
        """Ask the LLM to judge the generated answer against its sources.

        Returns:
            Partial state update ``{"reflection_score": "yes" | "no"}``.
        """
        question = state["question"]
        generation = state["generation"]
        docs = state["documents"]

        context = "\n\n".join(
            f"[Source: {doc.metadata.get('filename', 'Unknown')}] {doc.page_content}"
            for doc in docs
        )

        chain = REFLECTION_PROMPT | self.llm | StrOutputParser()
        score = chain.invoke({
            "context": context,
            "question": question,
            "generation": generation,
        })

        # Whole-word match: a plain substring test would also accept words
        # such as "eyes" or "yesterday" as approval.
        normalized_score = "yes" if self._is_affirmative(score) else "no"
        return {"reflection_score": normalized_score}

    def rewriter_node(self, state: GraphState):
        """Produce a new retrieval query after a rejected answer.

        Returns:
            Partial state update with the new ``current_query`` and the
            incremented ``iterations`` counter.
        """
        question = state["question"]
        previous_query = state["current_query"]
        failed_gen = state["generation"]

        chain = REWRITE_PROMPT | self.llm | StrOutputParser()
        new_query = chain.invoke({
            "question": question,
            "previous_query": previous_query,
            "generation": failed_gen,
        }).strip()

        # A blank rewrite would send an empty query to the retriever; reuse
        # the previous query instead.
        if not new_query:
            new_query = previous_query

        # .get mirrors decide_to_rewrite, so a state created without the
        # counter (e.g. the compiled graph invoked directly) still works.
        return {
            "current_query": new_query,
            "iterations": state.get("iterations", 0) + 1,
        }

    def decide_to_rewrite(self, state: GraphState):
        """Choose the edge to follow after reflection.

        Returns:
            ``"end"`` when the answer was accepted or ``MAX_ITERATIONS``
            rewrites have been performed, otherwise ``"rewrite"``.
        """
        score = state["reflection_score"]
        iterations = state.get("iterations", 0)

        if score == "yes" or iterations >= self.MAX_ITERATIONS:
            return "end"
        return "rewrite"

    def build_graph(self):
        """Wire the four nodes together and return the compiled graph."""
        workflow = StateGraph(GraphState)

        workflow.add_node("retriever", self.retriever_node)
        workflow.add_node("generator", self.generator_node)
        workflow.add_node("reflector", self.reflector_node)
        workflow.add_node("rewriter", self.rewriter_node)

        workflow.set_entry_point("retriever")
        workflow.add_edge("retriever", "generator")
        workflow.add_edge("generator", "reflector")

        # The reflector either finishes the run or triggers a rewrite, which
        # loops back into retrieval.
        workflow.add_conditional_edges(
            "reflector",
            self.decide_to_rewrite,
            {
                "rewrite": "rewriter",
                "end": END,
            },
        )
        workflow.add_edge("rewriter", "retriever")

        return workflow.compile()

    def run(self, question: str, callback=None):
        """Stream the graph for *question* and return the accumulated state.

        Args:
            question: User question; also used as the first retrieval query.
            callback: Optional callable invoked as ``callback(node_name,
                state)`` after each streamed node update, where ``state`` is
                the accumulated state so far.

        Returns:
            Dict holding the initial inputs merged with every node update.
        """
        inputs = {
            "question": question,
            "current_query": question,
            "iterations": 0,
            "reflection_score": "no",
        }

        # Accumulate into a copy so the dict handed to the graph is not
        # mutated while the graph is still streaming.
        final_state = dict(inputs)
        for output in self.app.stream(inputs):
            for key, value in output.items():
                # Skip non-mapping payloads (e.g. None) so a single empty
                # update cannot abort the whole run.
                if isinstance(value, dict):
                    final_state.update(value)
                if callback:
                    callback(key, final_state)

        return final_state

    def get_graph_image(self, file_path: str = None):
        """Render the compiled graph as a Mermaid PNG.

        Args:
            file_path: When given (and non-empty), the PNG bytes are also
                written to this path.

        Returns:
            The PNG image as bytes.
        """
        img_bytes = self.app.get_graph().draw_mermaid_png()
        if file_path:
            with open(file_path, "wb") as f:
                f.write(img_bytes)
        return img_bytes