# src/rag_graph.py
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate

from src.core.graph_state import GraphState
from src.core.embeddings import load_embeddings
from src.core.llm import load_llm
from src.vector_store.vector_store import build_vector_store
from src.config.config import K_OFFSET, MMR_LAMBDA
from src.exceptions import VectorStoreNotInitializedError, LLMInvocationError


class ProjectRAGGraph:
    """Two-node LangGraph RAG pipeline: retrieve (MMR search) -> generate (LLM answer)."""

    def __init__(self):
        self.embeddings = load_embeddings()
        self.llm = load_llm()
        self.vector_store = None
        self.pdf_count = 0
        self.memory = MemorySaver()  # in-memory checkpointer, keyed by thread_id
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Index the given PDFs into the vector store; pdf_count drives retrieval k."""
        self.pdf_count = len(pdf_paths)
        self.vector_store = build_vector_store(
            pdf_paths, self.embeddings, original_names
        )

    # ---------- Graph Nodes ----------

    def retrieve(self, state: GraphState):
        """Fetch context documents for the question via MMR retrieval."""
        if not self.vector_store:
            raise VectorStoreNotInitializedError("Vector store not initialized")
        # Scale k with the number of indexed PDFs, never dropping below 1.
        k_value = max(1, self.pdf_count + K_OFFSET)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": MMR_LAMBDA},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Answer the question strictly from the retrieved context."""
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are an expert Project Analyst.
                Answer ONLY using the provided context.
                If the answer is not present, say "I don't know".

                Context:
                {context}

                Question:
                {question}
                """
            )
            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )
            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })
            return {"answer": response.content}
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise LLMInvocationError(f"LLM failed: {e}") from e

    # ---------- Graph Build ----------

    def _build_graph(self):
        """Wire retrieve -> generate -> END and compile with the checkpointer."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the graph for one question; thread_id scopes checkpointed state."""
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]
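

# ---------- Usage sketch ----------
# A minimal usage sketch, not part of the module's tested surface: the PDF
# paths, names, thread_id, and question below are hypothetical placeholders,
# and it assumes load_embeddings()/load_llm() are already configured
# (e.g. via environment variables).
if __name__ == "__main__":
    rag = ProjectRAGGraph()
    rag.process_documents(
        ["docs/project_plan.pdf", "docs/status_report.pdf"],  # hypothetical paths
        original_names=["Project Plan", "Status Report"],
    )
    # Reusing the same thread_id lets MemorySaver carry checkpointed graph
    # state across follow-up questions in one conversation.
    print(rag.query("What are the key project milestones?", thread_id="demo-thread"))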