Dinesh310 commited on
Commit
8b888df
·
verified ·
1 Parent(s): 160beae

Update src/graph/rag_graph.py

Browse files
Files changed (1) hide show
  1. src/graph/rag_graph.py +39 -47
src/graph/rag_graph.py CHANGED
@@ -11,70 +11,62 @@ class ProjectRAGGraph:
11
  self.memory = MemorySaver()
12
  self.workflow = self._build_graph()
13
 
14
- # -------- Nodes --------
15
  def retrieve(self, state: GraphState):
16
- try:
17
- retriever = self.vector_store.as_retriever(
18
- search_type=Config.SEARCH_TYPE,
19
- search_kwargs={
20
- "k": Config.TOP_K,
21
- "lambda_mult": Config.LAMBDA_MULT
22
- }
23
- )
24
- docs = retriever.invoke(state["question"])
25
- return {"context": docs}
 
26
 
27
- except Exception as e:
28
- raise RuntimeError(f"Retrieval failed: {e}")
29
 
30
- def generate(self, state: GraphState):
31
- try:
32
- prompt = ChatPromptTemplate.from_template(
33
- """
34
- You are a professional Project Analyst.
35
- Context:
36
- {context}
37
 
38
- Question:
39
- {question}
 
 
 
40
 
41
- Answer strictly using the context. Cite sources.
42
- """
43
- )
44
 
45
- formatted_context = "\n\n".join(
46
- doc.page_content for doc in state["context"]
47
- )
 
 
48
 
49
- chain = prompt | self.llm
50
- response = chain.invoke({
51
- "context": formatted_context,
52
- "question": state["question"]
53
- })
54
 
55
- return {"answer": response.content}
56
 
57
- except Exception as e:
58
- raise RuntimeError(f"Answer generation failed: {e}")
59
 
60
- # -------- Graph --------
61
  def _build_graph(self):
62
  graph = StateGraph(GraphState)
63
-
64
  graph.add_node("retrieve", self.retrieve)
65
  graph.add_node("generate", self.generate)
66
-
67
  graph.set_entry_point("retrieve")
68
  graph.add_edge("retrieve", "generate")
69
  graph.add_edge("generate", END)
70
-
71
  return graph.compile(checkpointer=self.memory)
72
 
73
  def query(self, question: str, thread_id: str):
74
- try:
75
- config = {"configurable": {"thread_id": thread_id}}
76
- result = self.workflow.invoke({"question": question}, config=config)
77
- return result["answer"]
78
-
79
- except Exception as e:
80
- raise RuntimeError(f"Graph execution failed: {e}")
 
11
  self.memory = MemorySaver()
12
  self.workflow = self._build_graph()
13
 
14
+ # ---------------- Retrieval ----------------
15
  def retrieve(self, state: GraphState):
16
+ retriever = self.vector_store.as_retriever(
17
+ search_type=Config.SEARCH_TYPE,
18
+ search_kwargs={"k": Config.TOP_K}
19
+ )
20
+ docs = retriever.invoke(state["question"])
21
+ return {"context": docs}
22
+
23
+ # ---------------- Generation ----------------
24
+ def generate(self, state: GraphState):
25
+ prompt = ChatPromptTemplate.from_template("""
26
+ You are a professional Project Analyst.
27
 
28
+ Context:
29
+ {context}
30
 
31
+ Question:
32
+ {question}
 
 
 
 
 
33
 
34
+ Rules:
35
+ - Answer ONLY from the context
36
+ - Add citations at the end
37
+ - Citation format: (Source: <file>, page <number>)
38
+ """)
39
 
40
+ context_text = []
41
+ citations = set()
 
42
 
43
+ for doc in state["context"]:
44
+ page = doc.metadata.get("page", "N/A")
45
+ source = doc.metadata.get("source", "Unknown")
46
+ context_text.append(doc.page_content)
47
+ citations.add(f"(Source: {source}, page {page + 1})")
48
 
49
+ chain = prompt | self.llm
50
+ response = chain.invoke({
51
+ "context": "\n\n".join(context_text),
52
+ "question": state["question"]
53
+ })
54
 
55
+ final_answer = response.content + "\n\n" + "\n".join(citations)
56
 
57
+ return {"answer": final_answer}
 
58
 
59
+ # ---------------- Graph ----------------
60
  def _build_graph(self):
61
  graph = StateGraph(GraphState)
 
62
  graph.add_node("retrieve", self.retrieve)
63
  graph.add_node("generate", self.generate)
 
64
  graph.set_entry_point("retrieve")
65
  graph.add_edge("retrieve", "generate")
66
  graph.add_edge("generate", END)
 
67
  return graph.compile(checkpointer=self.memory)
68
 
69
  def query(self, question: str, thread_id: str):
70
+ config = {"configurable": {"thread_id": thread_id}}
71
+ result = self.workflow.invoke({"question": question}, config=config)
72
+ return result["answer"]