File size: 2,528 Bytes
2bd85c8
 
 
4a5f44f
 
2bd85c8
 
 
 
 
 
 
 
4a5f44f
2bd85c8
4a5f44f
 
 
 
 
 
 
 
 
 
2bd85c8
4a5f44f
 
2bd85c8
4a5f44f
 
 
 
 
 
 
2bd85c8
4a5f44f
 
2bd85c8
4a5f44f
 
 
2bd85c8
4a5f44f
 
 
2bd85c8
4a5f44f
 
 
 
 
2bd85c8
4a5f44f
2bd85c8
4a5f44f
 
2bd85c8
4a5f44f
2bd85c8
 
4a5f44f
2bd85c8
 
4a5f44f
2bd85c8
 
 
4a5f44f
2bd85c8
 
 
4a5f44f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from core.state import GraphState
from config.config import Config

class ProjectRAGGraph:
    """Two-node retrieval-augmented generation workflow built on LangGraph.

    The compiled graph runs ``retrieve`` -> ``generate`` -> ``END`` and is
    compiled with a ``MemorySaver`` checkpointer; each ``query`` call passes
    a ``thread_id`` in the run configuration.
    """

    def __init__(self, llm, vector_store):
        """Store the collaborators and compile the workflow.

        Args:
            llm: Model piped after the prompt; its response must expose
                ``.content``.
            vector_store: Store exposing ``as_retriever(...)``.
        """
        self.llm = llm
        self.vector_store = vector_store
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    # -------- Nodes --------
    def retrieve(self, state: "GraphState") -> dict:
        """Fetch documents for ``state["question"]``.

        Returns:
            ``{"context": docs}`` where ``docs`` is whatever the retriever
            returned.

        Raises:
            RuntimeError: If building or invoking the retriever fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            search_kwargs = {"k": Config.TOP_K}
            # lambda_mult only tunes MMR diversity. For other search types it
            # would be forwarded to the store's similarity search, which some
            # backends reject as an unexpected keyword.
            if Config.SEARCH_TYPE == "mmr":
                search_kwargs["lambda_mult"] = Config.LAMBDA_MULT

            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs=search_kwargs,
            )
            docs = retriever.invoke(state["question"])
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e

        return {"context": docs}

    @staticmethod
    def _format_context(docs) -> str:
        """Join retrieved documents into a single prompt-ready string.

        A chunk is prefixed with a ``Source: ...`` line when its metadata
        carries a truthy ``"source"`` entry, so the model has something
        concrete to cite; chunks without one are emitted as bare content.
        """
        chunks = []
        for doc in docs:
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source")
            if source:
                chunks.append(f"Source: {source}\n{doc.page_content}")
            else:
                chunks.append(doc.page_content)
        return "\n\n".join(chunks)

    def generate(self, state: "GraphState") -> dict:
        """Answer ``state["question"]`` from the documents in ``state["context"]``.

        Returns:
            ``{"answer": text}`` where ``text`` is the ``.content`` of the
            model response.

        Raises:
            RuntimeError: If prompt construction or the model call fails; the
                original exception is attached as ``__cause__``.
        """
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are a professional Project Analyst.
                Context:
                {context}

                Question:
                {question}

                Answer strictly using the context. Cite sources.
                """
            )

            chain = prompt | self.llm
            response = chain.invoke({
                "context": self._format_context(state["context"]),
                "question": state["question"],
            })

            return {"answer": response.content}

        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e

    # -------- Graph --------
    def _build_graph(self):
        """Wire retrieve -> generate -> END and compile with the checkpointer."""
        graph = StateGraph(GraphState)

        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)

        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)

        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str) -> str:
        """Run the workflow for one question and return the answer.

        Args:
            question: The user question placed in the initial graph state.
            thread_id: Identifier passed as ``configurable.thread_id`` in the
                run configuration.

        Returns:
            The ``"answer"`` entry of the final graph state.

        Raises:
            RuntimeError: If the workflow raises or the result has no
                ``"answer"``; the original exception is attached as
                ``__cause__``.
        """
        config = {"configurable": {"thread_id": thread_id}}
        try:
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e