# src/rag_graph.py
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from src.core.graph_state import GraphState
from src.core.embeddings import load_embeddings
from src.core.llm import load_llm
from src.vector_store.vector_store import build_vector_store
from src.config.config import K_OFFSET, MMR_LAMBDA
from src.exceptions import VectorStoreNotInitializedError, LLMInvocationError
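
# Assumed shape of GraphState (defined in src.core.graph_state), inferred from
# the keys this module reads and writes -- a minimal sketch, not the real class:
#
#     class GraphState(TypedDict):
#         question: str
#         context: list[Document]   # populated by retrieve()
#         answer: str               # populated by generate()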

class ProjectRAGGraph:
    """Minimal two-node RAG workflow: retrieve context, then generate an answer."""

    def __init__(self):
        self.embeddings = load_embeddings()
        self.llm = load_llm()
        self.vector_store = None
        self.pdf_count = 0
        # In-memory checkpointer: the compiled graph persists state per thread_id.
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    def process_documents(self, pdf_paths, original_names=None):
        """Embed the given PDFs and (re)build the vector store."""
        self.pdf_count = len(pdf_paths)
        self.vector_store = build_vector_store(
            pdf_paths,
            self.embeddings,
            original_names,
        )

    # ---------- Graph Nodes ----------

    def retrieve(self, state: GraphState):
        """Fetch context documents for the current question via MMR search."""
        if not self.vector_store:
            raise VectorStoreNotInitializedError("Vector store not initialized")
        # Scale k with the number of indexed PDFs, never dropping below 1.
        k_value = max(1, self.pdf_count + K_OFFSET)
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": k_value, "lambda_mult": MMR_LAMBDA},
        )
        documents = retriever.invoke(state["question"])
        return {"context": documents}

    def generate(self, state: GraphState):
        """Answer the question using only the retrieved context."""
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are an expert Project Analyst.
                Answer ONLY using the provided context.
                If the answer is not present, say "I don't know".

                Context:
                {context}

                Question:
                {question}
                """
            )
            # Join retrieved chunks into a single context block for the prompt.
            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )
            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })
            return {"answer": response.content}
        except Exception as e:
            # Chain the original exception to preserve the traceback.
            raise LLMInvocationError(f"LLM failed: {e}") from e

    # ---------- Graph Build ----------

    def _build_graph(self):
        workflow = StateGraph(GraphState)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("generate", self.generate)
        # Linear pipeline: retrieve -> generate -> END.
        workflow.set_entry_point("retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge("generate", END)
        return workflow.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the graph for one question; thread_id scopes the checkpointed state."""
        config = {"configurable": {"thread_id": thread_id}}
        result = self.workflow.invoke({"question": question}, config=config)
        return result["answer"]
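

if __name__ == "__main__":
    # Minimal usage sketch. The PDF path and question below are hypothetical;
    # this assumes src.config, the embedding model, and the LLM are configured.
    rag = ProjectRAGGraph()
    rag.process_documents(["docs/report.pdf"], original_names=["report.pdf"])
    print(rag.query("What are the key project milestones?", thread_id="demo"))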