import textwrap

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate

from core.state import GraphState
from config.config import Config


class ProjectRAGGraph:
    """Retrieve-then-generate RAG pipeline compiled as a LangGraph ``StateGraph``.

    The graph runs ``retrieve`` -> ``generate`` -> END and is compiled with an
    in-memory ``MemorySaver`` checkpointer, keyed by the ``thread_id`` passed
    to :meth:`query`.

    Args:
        llm: Model invoked as ``prompt | llm``. A chat model's response is
            read via ``.content``; a plain string response is used as-is.
        vector_store: Store exposing
            ``as_retriever(search_type=..., search_kwargs=...)``.
    """

    # The template is identical for every question, so build it once.
    _PROMPT = ChatPromptTemplate.from_template(
        textwrap.dedent(
            """\
            You are a professional Project Analyst.

            Context:
            {context}

            Question:
            {question}

            Answer strictly using the context. Cite sources.
            """
        )
    )

    # Sent instead of an empty string so the model is told explicitly that
    # retrieval returned nothing, rather than seeing a blank context section.
    _NO_CONTEXT = "No relevant context was retrieved."

    def __init__(self, llm, vector_store):
        self.llm = llm
        self.vector_store = vector_store
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    # -------- Helpers --------

    @staticmethod
    def _search_kwargs() -> dict:
        """Build the retriever's ``search_kwargs`` from ``Config``.

        ``lambda_mult`` only applies to MMR search. For other search types
        LangChain forwards ``search_kwargs`` directly to the vector store's
        search method, where an unexpected keyword may be rejected, so it is
        included only when ``Config.SEARCH_TYPE == "mmr"``.
        """
        kwargs = {"k": Config.TOP_K}
        if Config.SEARCH_TYPE == "mmr":
            kwargs["lambda_mult"] = Config.LAMBDA_MULT
        return kwargs

    @classmethod
    def _format_context(cls, docs) -> str:
        """Join retrieved documents into one context string, each labelled with a source.

        The prompt asks the model to cite sources, so every chunk is prefixed
        with ``metadata["source"]`` when present, otherwise with its 1-based
        position. Page text alone would give the model nothing to cite.
        Returns a fixed placeholder when ``docs`` is empty.
        """
        if not docs:
            return cls._NO_CONTEXT
        blocks = []
        for index, doc in enumerate(docs, start=1):
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source", f"document {index}")
            blocks.append(f"[{index}] Source: {source}\n{doc.page_content}")
        return "\n\n".join(blocks)

    # -------- Nodes --------

    def retrieve(self, state: GraphState) -> dict:
        """Graph node: fetch documents relevant to ``state["question"]``.

        Returns:
            ``{"context": docs}`` with whatever the retriever returned.

        Raises:
            RuntimeError: if building the retriever or running the search
                fails; the underlying exception is chained as ``__cause__``.
        """
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs=self._search_kwargs(),
            )
            docs = retriever.invoke(state["question"])
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e
        return {"context": docs}

    def generate(self, state: GraphState) -> dict:
        """Graph node: answer ``state["question"]`` from ``state["context"]``.

        Returns:
            ``{"answer": text}`` where ``text`` is the model's reply.

        Raises:
            RuntimeError: if formatting the context or invoking the model
                fails; the underlying exception is chained as ``__cause__``.
        """
        try:
            chain = self._PROMPT | self.llm
            response = chain.invoke({
                "context": self._format_context(state["context"]),
                "question": state["question"],
            })
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e
        # Chat models return a message object with ``.content``; plain LLMs
        # return a ``str`` directly.
        return {"answer": getattr(response, "content", response)}

    # -------- Graph --------

    def _build_graph(self):
        """Wire ``retrieve`` -> ``generate`` -> END and compile with the in-memory checkpointer."""
        graph = StateGraph(GraphState)
        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)
        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)
        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the graph for ``question`` and return the generated answer.

        Args:
            question: User question passed to retrieval and generation.
            thread_id: Checkpointer key; runs sharing a ``thread_id`` share
                checkpointed state in ``self.memory``.

        Raises:
            RuntimeError: if any step of the graph fails; the underlying
                exception is chained as ``__cause__``.
        """
        config = {"configurable": {"thread_id": thread_id}}
        try:
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e