# src/graph/rag_graph.py  (Demo_1)
# Author: Dinesh310 — "Update src/graph/rag_graph.py" (commit 4a5f44f, verified)
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from core.state import GraphState
from config.config import Config
class ProjectRAGGraph:
    """Two-node LangGraph RAG pipeline: retrieve -> generate.

    ``retrieve`` pulls documents from the vector store for the question,
    ``generate`` answers strictly from that context via the LLM. State is
    checkpointed per ``thread_id`` with an in-memory saver.
    """

    def __init__(self, llm, vector_store):
        """
        Args:
            llm: LangChain chat model used for answer generation.
            vector_store: Vector store exposing ``as_retriever``.
        """
        self.llm = llm
        self.vector_store = vector_store
        # In-memory checkpointer keeping per-thread conversation state.
        self.memory = MemorySaver()
        # The prompt is invariant — build it once here instead of on
        # every generate() call.
        self._prompt = ChatPromptTemplate.from_template(
            """
You are a professional Project Analyst.
Context:
{context}
Question:
{question}
Answer strictly using the context. Cite sources.
"""
        )
        self.workflow = self._build_graph()

    # -------- Nodes --------
    def retrieve(self, state: GraphState):
        """Retrieve documents relevant to ``state["question"]``.

        Returns:
            dict: ``{"context": docs}`` — the retrieved document list,
            merged into the graph state.

        Raises:
            RuntimeError: if retrieval fails; the original exception is
            chained as the cause.
        """
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs={
                    "k": Config.TOP_K,
                    "lambda_mult": Config.LAMBDA_MULT,
                },
            )
            docs = retriever.invoke(state["question"])
            return {"context": docs}
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Retrieval failed: {e}") from e

    def generate(self, state: GraphState):
        """Generate an answer from the retrieved context.

        Returns:
            dict: ``{"answer": text}`` — the LLM's response content.

        Raises:
            RuntimeError: if generation fails; the original exception is
            chained as the cause.
        """
        try:
            # Flatten the retrieved documents into one prompt-ready string.
            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )
            chain = self._prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })
            return {"answer": response.content}
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e

    # -------- Graph --------
    def _build_graph(self):
        """Wire the linear retrieve -> generate -> END graph and compile it
        with the in-memory checkpointer."""
        graph = StateGraph(GraphState)
        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)
        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)
        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the full pipeline for ``question`` on the given thread.

        Args:
            question: The user's question.
            thread_id: Checkpointer thread key — reusing an id continues
                that thread's checkpointed state.

        Returns:
            str: The generated answer.

        Raises:
            RuntimeError: if graph execution fails; the original exception
            is chained as the cause.
        """
        try:
            config = {"configurable": {"thread_id": thread_id}}
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e