Dinesh310 committed on
Commit
2bd85c8
·
verified ·
1 Parent(s): dc6497c

Create graph/rag_graph.py

Browse files
Files changed (1) hide show
  1. src/graph/rag_graph.py +80 -0
src/graph/rag_graph.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from langgraph.checkpoint.memory import MemorySaver
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from core.state import GraphState
5
+ from config.config import Config
6
+
7
class ProjectRAGGraph:
    """Two-node LangGraph RAG pipeline: retrieve -> generate.

    Wraps a vector store and an LLM in a compiled ``StateGraph`` whose
    per-thread state is checkpointed in memory, so repeated ``query`` calls
    with the same ``thread_id`` share conversation state.
    """

    def __init__(self, llm, vector_store):
        """Build and compile the workflow.

        Args:
            llm: A LangChain chat model (must support ``prompt | llm`` piping).
            vector_store: A LangChain vector store exposing ``as_retriever``.
        """
        self.llm = llm
        self.vector_store = vector_store
        # In-memory checkpointer: keeps per-thread_id state across invocations.
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    # -------- Nodes --------
    def retrieve(self, state: GraphState):
        """Fetch documents relevant to ``state['question']``.

        Returns:
            dict: ``{"context": docs}`` merged into the graph state.

        Raises:
            RuntimeError: If the retriever cannot be built or invoked.
        """
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs={
                    "k": Config.TOP_K,
                    # lambda_mult is only honored for MMR-style search types.
                    "lambda_mult": Config.LAMBDA_MULT,
                },
            )
            docs = retriever.invoke(state["question"])
            return {"context": docs}
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Retrieval failed: {e}") from e

    def generate(self, state: GraphState):
        """Answer ``state['question']`` strictly from the retrieved context.

        Returns:
            dict: ``{"answer": <model text>}`` merged into the graph state.

        Raises:
            RuntimeError: If prompt formatting or the LLM call fails.
        """
        try:
            prompt = ChatPromptTemplate.from_template(
                """
You are a professional Project Analyst.
Context:
{context}

Question:
{question}

Answer strictly using the context. Cite sources.
"""
            )

            # Label each chunk with its source so the model can actually
            # honor the "Cite sources" instruction in the prompt.
            formatted_context = "\n\n".join(
                f"[Source: {doc.metadata.get('source', 'unknown')}]\n"
                f"{doc.page_content}"
                for doc in state["context"]
            )

            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })

            return {"answer": response.content}
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e

    # -------- Graph --------
    def _build_graph(self):
        """Assemble and compile the retrieve -> generate graph."""
        graph = StateGraph(GraphState)

        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)

        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)

        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the workflow for one question on a given conversation thread.

        Args:
            question: The user's question.
            thread_id: Checkpointer key; reuse to continue a conversation.

        Returns:
            str: The generated answer text.

        Raises:
            RuntimeError: If any node in the workflow fails.
        """
        try:
            config = {"configurable": {"thread_id": thread_id}}
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e