Dinesh310 committed on
Commit
4302ded
·
verified ·
1 Parent(s): f953109

Create rag_graph.py

Browse files
Files changed (1) hide show
  1. src/rag_graph.py +92 -0
src/rag_graph.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/rag_graph.py
2
+ from langgraph.graph import StateGraph, END
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+
6
+ from src.core.graph_state import GraphState
7
+ from src.core.embeddings import load_embeddings
8
+ from src.core.llm import load_llm
9
+ from src.vector_store.vector_store import build_vector_store
10
+ from src.config.config import K_OFFSET, MMR_LAMBDA
11
+ from src.exceptions import VectorStoreNotInitializedError, LLMInvocationError
12
+
13
+
14
class ProjectRAGGraph:
    """Two-node LangGraph RAG pipeline (retrieve -> generate) with per-thread memory.

    Documents are indexed on demand via :meth:`process_documents`; queries run
    through the compiled workflow via :meth:`query`, checkpointed per thread_id.
    """

    def __init__(self):
        """Load models, set up the in-memory checkpointer, and compile the graph."""
        self.embeddings = load_embeddings()
        self.llm = load_llm()
        self.vector_store = None   # populated by process_documents()
        self.pdf_count = 0         # number of indexed PDFs; drives retrieval k
        self.memory = MemorySaver()  # conversation state keyed by thread_id
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Build (or rebuild) the vector store from the given PDF files.

        Args:
            pdf_paths: paths of the PDF files to embed and index.
            original_names: optional display names aligned with pdf_paths.
        """
        self.pdf_count = len(pdf_paths)
        self.vector_store = build_vector_store(
            pdf_paths,
            self.embeddings,
            original_names
        )

    # ---------- Graph Nodes ----------

    def retrieve(self, state: GraphState):
        """Fetch context documents for state["question"] using MMR search.

        Returns:
            dict with key "context" holding the retrieved documents.

        Raises:
            VectorStoreNotInitializedError: if process_documents() was never called.
        """
        if not self.vector_store:
            raise VectorStoreNotInitializedError("Vector store not initialized")

        # Scale k with corpus size, but never request fewer than 1 document.
        k_value = max(1, self.pdf_count + K_OFFSET)

        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": MMR_LAMBDA}
        )

        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Answer state["question"] grounded ONLY in the retrieved context.

        Returns:
            dict with key "answer" holding the LLM's text response.

        Raises:
            LLMInvocationError: wraps any failure during prompting or invocation.
        """
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are an expert Project Analyst.
                Answer ONLY using the provided context.
                If the answer is not present, say "I don't know".

                Context:
                {context}

                Question:
                {question}
                """
            )

            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )

            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"]
            })

            return {"answer": response.content}

        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise LLMInvocationError(f"LLM failed: {e}") from e

    # ---------- Graph Build ----------

    def _build_graph(self):
        """Compile the linear graph: retrieve -> generate -> END, with checkpointing."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the full pipeline for *question* within the given conversation thread.

        Args:
            question: the user's natural-language question.
            thread_id: checkpointer key isolating this conversation's state.

        Returns:
            The generated answer string.
        """
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]