Spaces:
Sleeping
Sleeping
| # src/rag_graph.py | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from src.core.graph_state import GraphState | |
| from src.core.embeddings import load_embeddings | |
| from src.core.llm import load_llm | |
| from src.vector_store.vector_store import build_vector_store | |
| from src.config.config import K_OFFSET, MMR_LAMBDA | |
| from src.exceptions import VectorStoreNotInitializedError, LLMInvocationError | |
class ProjectRAGGraph:
    """Retrieve-then-generate question answering over indexed PDF documents.

    The pipeline is a two-node LangGraph ``StateGraph`` (``retrieve`` ->
    ``generate`` -> ``END``) compiled with a ``MemorySaver`` checkpointer.
    Call :meth:`process_documents` to build the vector store, then
    :meth:`query` to ask questions against it.
    """

    # Candidate-pool size LangChain documents as the default ``fetch_k`` for
    # MMR search. MMR cannot return more documents than it fetched, so a
    # larger ``k`` needs an explicit ``fetch_k``.
    _DEFAULT_MMR_FETCH_K = 20

    def __init__(self):
        """Load the embedding model and LLM and compile the workflow graph."""
        self.embeddings = load_embeddings()
        self.llm = load_llm()
        # Populated by process_documents(); retrieve() refuses to run before.
        self.vector_store = None
        self.pdf_count = 0
        self.memory = MemorySaver()
        # Must come after self.memory: _build_graph() passes it as checkpointer.
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Build the vector store from ``pdf_paths`` and record their count.

        Args:
            pdf_paths: Sized collection of PDF paths forwarded to
                ``build_vector_store``.
            original_names: Optional value forwarded unchanged to
                ``build_vector_store``.

        Any exception raised while building propagates; in that case the
        previously stored ``vector_store`` and ``pdf_count`` are left intact.
        """
        pdf_count = len(pdf_paths)
        vector_store = build_vector_store(
            pdf_paths,
            self.embeddings,
            original_names,
        )
        # Commit both fields only after the build succeeded so the count used
        # by retrieve() always describes the store that is actually in use.
        self.vector_store = vector_store
        self.pdf_count = pdf_count

    # ---------- Graph Nodes ----------
    def retrieve(self, state: GraphState):
        """Graph node: fetch documents relevant to ``state["question"]``.

        Uses MMR search with ``k = max(1, pdf_count + K_OFFSET)`` and
        ``lambda_mult = MMR_LAMBDA``.

        Returns:
            A state update ``{"context": documents}``.

        Raises:
            VectorStoreNotInitializedError: If no vector store has been built.
        """
        # Identity check: a store object that happens to be falsy (e.g. one
        # defining __len__) is still a usable store.
        if self.vector_store is None:
            raise VectorStoreNotInitializedError("Vector store not initialized")

        k_value = max(1, self.pdf_count + K_OFFSET)
        search_kwargs = {"k": k_value, "lambda_mult": MMR_LAMBDA}
        if k_value > self._DEFAULT_MMR_FETCH_K:
            # Without this, MMR would only see the default-sized candidate
            # pool and silently return fewer than k documents.
            search_kwargs["fetch_k"] = k_value

        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs=search_kwargs,
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Graph node: answer ``state["question"]`` from ``state["context"]``.

        The page contents of the context documents are joined with blank
        lines and passed, together with the question, through the prompt and
        the LLM.

        Returns:
            A state update ``{"answer": response.content}``.

        Raises:
            LLMInvocationError: If anything in prompt construction, context
                formatting or the LLM call fails; the original exception is
                attached as ``__cause__``.
        """
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are an expert Project Analyst.
                Answer ONLY using the provided context.
                If the answer is not present, say "I don't know".
                Context:
                {context}
                Question:
                {question}
                """
            )
            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )
            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })
            return {"answer": response.content}
        except Exception as e:
            # Chain explicitly so the root cause survives in tracebacks.
            raise LLMInvocationError(f"LLM failed: {e}") from e

    # ---------- Graph Build ----------
    def _build_graph(self):
        """Wire ``retrieve -> generate -> END`` and compile with the checkpointer."""
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the workflow for ``question`` and return its answer.

        Args:
            question: The user question placed in the initial graph state.
            thread_id: Identifier passed to the graph as
                ``config["configurable"]["thread_id"]``.

        Returns:
            The ``"answer"`` entry of the final graph state.
        """
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]