Dinesh310 committed on
Commit
4a5f44f
·
verified ·
1 Parent(s): d6afd8b

Update src/graph/rag_graph.py

Browse files
Files changed (1) hide show
  1. src/graph/rag_graph.py +49 -41
src/graph/rag_graph.py CHANGED
@@ -1,8 +1,8 @@
1
  from langgraph.graph import StateGraph, END
2
  from langgraph.checkpoint.memory import MemorySaver
3
  from langchain_core.prompts import ChatPromptTemplate
4
- from src.core.state import GraphState
5
- from src.config.config import Config
6
 
7
  class ProjectRAGGraph:
8
  def __init__(self, llm, vector_store):
@@ -11,62 +11,70 @@ class ProjectRAGGraph:
11
  self.memory = MemorySaver()
12
  self.workflow = self._build_graph()
13
 
14
# ---------------- Retrieval ----------------
def retrieve(self, state: GraphState):
    """Fetch documents relevant to the current question.

    Reads ``state["question"]``, queries the vector store through a
    retriever configured from ``Config``, and returns the documents
    under the ``"context"`` state key.
    """
    question = state["question"]
    retriever = self.vector_store.as_retriever(
        search_type=Config.SEARCH_TYPE,
        search_kwargs={"k": Config.TOP_K},
    )
    return {"context": retriever.invoke(question)}
22
-
23
# ---------------- Generation ----------------
def generate(self, state: GraphState):
    """Answer the question strictly from the retrieved context.

    Reads ``state["context"]`` (retrieved documents) and
    ``state["question"]``, prompts the LLM, and returns
    ``{"answer": <model text + citation lines>}``.
    """
    prompt = ChatPromptTemplate.from_template("""
    You are a professional Project Analyst.

    Context:
    {context}

    Question:
    {question}

    Rules:
    - Answer ONLY from the context
    - Add citations at the end
    - Citation format: (Source: <file>, page <number>)
    """)

    context_text = []
    citations = set()

    for doc in state["context"]:
        page = doc.metadata.get("page", "N/A")
        source = doc.metadata.get("source", "Unknown")
        context_text.append(doc.page_content)
        # BUG FIX: `page + 1` raised TypeError whenever the metadata had
        # no integer page (the "N/A" default, or a string page number).
        # Convert 0-based pages to 1-based only for int values.
        page_label = page + 1 if isinstance(page, int) else page
        citations.add(f"(Source: {source}, page {page_label})")

    chain = prompt | self.llm
    response = chain.invoke({
        "context": "\n\n".join(context_text),
        "question": state["question"]
    })

    # sorted() makes the citation order deterministic; plain set
    # iteration produced a different order on every run.
    final_answer = response.content + "\n\n" + "\n".join(sorted(citations))

    return {"answer": final_answer}
 
58
 
59
# ---------------- Graph ----------------
def _build_graph(self):
    """Assemble and compile the linear retrieve -> generate pipeline."""
    workflow = StateGraph(GraphState)

    # Register both nodes, then wire them into a straight line
    # ending at the terminal END marker.
    for node_name, node_fn in (("retrieve", self.retrieve),
                               ("generate", self.generate)):
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", END)

    return workflow.compile(checkpointer=self.memory)
68
 
69
def query(self, question: str, thread_id: str):
    """Run the compiled workflow for one question on the given thread.

    The thread id scopes the checkpointer's conversation memory.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    final_state = self.workflow.invoke({"question": question}, config=run_config)
    return final_state["answer"]
 
 
 
 
 
1
  from langgraph.graph import StateGraph, END
2
  from langgraph.checkpoint.memory import MemorySaver
3
  from langchain_core.prompts import ChatPromptTemplate
4
+ from core.state import GraphState
5
+ from config.config import Config
6
 
7
  class ProjectRAGGraph:
8
  def __init__(self, llm, vector_store):
 
11
  self.memory = MemorySaver()
12
  self.workflow = self._build_graph()
13
 
14
# -------- Nodes --------
def retrieve(self, state: GraphState):
    """Graph node: fetch documents relevant to ``state["question"]``.

    Returns ``{"context": <retrieved documents>}``.

    Raises:
        RuntimeError: when retrieval fails, with the original
            exception chained as the cause.
    """
    try:
        retriever = self.vector_store.as_retriever(
            search_type=Config.SEARCH_TYPE,
            search_kwargs={
                "k": Config.TOP_K,
                "lambda_mult": Config.LAMBDA_MULT
            }
        )
        docs = retriever.invoke(state["question"])
        return {"context": docs}

    except Exception as e:
        # BUG FIX: re-raising without `from e` severed the exception
        # chain; keep the root-cause traceback attached for debugging.
        raise RuntimeError(f"Retrieval failed: {e}") from e
29
 
30
def generate(self, state: GraphState):
    """Graph node: answer the question strictly from retrieved context.

    Reads ``state["context"]`` and ``state["question"]``; returns
    ``{"answer": <model response text>}``.

    Raises:
        RuntimeError: when generation fails, with the original
            exception chained as the cause.
    """
    try:
        prompt = ChatPromptTemplate.from_template(
            """
            You are a professional Project Analyst.
            Context:
            {context}

            Question:
            {question}

            Answer strictly using the context. Cite sources.
            """
        )

        # BUG FIX: the prompt instructs the model to cite sources, but
        # the context previously contained only raw page text with no
        # provenance. Prefix each chunk with its source/page metadata
        # so the requested citations are actually possible.
        formatted_context = "\n\n".join(
            f"[Source: {doc.metadata.get('source', 'Unknown')}, "
            f"page {doc.metadata.get('page', 'N/A')}]\n{doc.page_content}"
            for doc in state["context"]
        )

        chain = prompt | self.llm
        response = chain.invoke({
            "context": formatted_context,
            "question": state["question"]
        })

        return {"answer": response.content}

    except Exception as e:
        # `from e` keeps the root-cause traceback attached.
        raise RuntimeError(f"Answer generation failed: {e}") from e
59
 
60
# -------- Graph --------
def _build_graph(self):
    """Build and compile the two-node retrieve -> generate workflow."""
    builder = StateGraph(GraphState)

    # Nodes first, then the linear wiring from entry point to END.
    builder.add_node("retrieve", self.retrieve)
    builder.add_node("generate", self.generate)
    builder.set_entry_point("retrieve")
    builder.add_edge("retrieve", "generate")
    builder.add_edge("generate", END)

    # Attach the in-memory checkpointer so per-thread state persists
    # across invocations.
    return builder.compile(checkpointer=self.memory)
72
 
73
def query(self, question: str, thread_id: str):
    """Run the compiled workflow for one question on a conversation thread.

    Args:
        question: the user question passed into graph state.
        thread_id: checkpointer thread id scoping conversation memory.

    Returns:
        The ``"answer"`` field of the final graph state.

    Raises:
        RuntimeError: when graph execution fails, with the original
            exception chained as the cause.
    """
    try:
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]

    except Exception as e:
        # BUG FIX: chain with `from e` so the underlying failure's
        # traceback is not lost when wrapping it.
        raise RuntimeError(f"Graph execution failed: {e}") from e