|
|
from langgraph.graph import StateGraph, END |
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from core.state import GraphState |
|
|
from config.config import Config |
|
|
|
|
|
class ProjectRAGGraph:
    """Retrieve-then-generate RAG pipeline compiled as a LangGraph ``StateGraph``.

    The compiled graph runs two nodes in sequence: ``retrieve`` (vector-store
    lookup for the question) and then ``generate`` (LLM answer over the
    retrieved documents). The graph is compiled with a ``MemorySaver``
    checkpointer, and ``query`` passes ``thread_id`` as the checkpoint key.
    """

    # Answer-step prompt. ``{context}`` receives the retrieved documents (each
    # labelled with its source when known) and ``{question}`` the user question.
    _PROMPT_TEMPLATE = (
        "You are a professional Project Analyst.\n"
        "\n"
        "Context:\n"
        "{context}\n"
        "\n"
        "Question:\n"
        "{question}\n"
        "\n"
        "Answer strictly using the context. Cite sources.\n"
    )

    def __init__(self, llm, vector_store):
        """Store collaborators and compile the workflow.

        Args:
            llm: Runnable that produces the answer. It is invoked through the
                chain ``prompt | llm``.
            vector_store: Object exposing
                ``as_retriever(search_type=..., search_kwargs=...)``.
        """
        self.llm = llm
        self.vector_store = vector_store
        # The template text is static, so parse it once here instead of on
        # every ``generate`` call.
        self._prompt = ChatPromptTemplate.from_template(self._PROMPT_TEMPLATE)
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    @staticmethod
    def _search_kwargs() -> dict:
        """Build ``search_kwargs`` for the configured ``Config.SEARCH_TYPE``.

        ``lambda_mult`` is an MMR-only knob. For other search types the
        retriever hands ``search_kwargs`` straight to the vector store's search
        method, and some stores may reject keywords they do not recognise. It
        is therefore only included for ``"mmr"``.
        """
        kwargs = {"k": Config.TOP_K}
        if Config.SEARCH_TYPE == "mmr":
            kwargs["lambda_mult"] = Config.LAMBDA_MULT
        return kwargs

    def retrieve(self, state: GraphState) -> dict:
        """Graph node: fetch documents relevant to ``state["question"]``.

        Returns:
            ``{"context": docs}``, where ``docs`` is whatever the retriever's
            ``invoke`` returns.

        Raises:
            RuntimeError: If building the retriever or running the search
                fails. The original exception is chained as ``__cause__``.
        """
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs=self._search_kwargs(),
            )
            docs = retriever.invoke(state["question"])
            return {"context": docs}
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e

    @staticmethod
    def _format_context(docs) -> str:
        """Join documents into one prompt string, separated by blank lines.

        The prompt asks the model to cite sources. Each chunk is therefore
        prefixed with ``[Source: ...]`` taken from ``doc.metadata["source"]``
        when that key is present and truthy. Other chunks are included
        unlabelled. ``None`` or an empty sequence yields ``""``.
        """
        parts = []
        for doc in docs or ():
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source")
            if source:
                parts.append(f"[Source: {source}]\n{doc.page_content}")
            else:
                parts.append(doc.page_content)
        return "\n\n".join(parts)

    def generate(self, state: GraphState) -> dict:
        """Graph node: answer ``state["question"]`` from ``state["context"]``.

        Returns:
            ``{"answer": text}``. ``text`` is the response's ``.content`` for
            chat models, or the raw response when the LLM returns it directly
            (e.g. a plain string).

        Raises:
            RuntimeError: If formatting the context or invoking the chain
                fails. The original exception is chained as ``__cause__``.
        """
        try:
            chain = self._prompt | self.llm
            response = chain.invoke({
                "context": self._format_context(state["context"]),
                "question": state["question"],
            })
            # Chat models return a message object; plain LLMs may return str.
            return {"answer": getattr(response, "content", response)}
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e

    def _build_graph(self):
        """Wire ``retrieve -> generate -> END`` and compile with the checkpointer."""
        graph = StateGraph(GraphState)

        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)

        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)

        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str) -> str:
        """Run the full graph for ``question`` and return the generated answer.

        Args:
            question: User question, stored as the initial ``"question"`` state.
            thread_id: Passed as ``configurable.thread_id`` so the checkpointer
                keys saved state by conversation thread.

        Raises:
            RuntimeError: If any node fails or the result has no ``"answer"``.
                The original exception is chained as ``__cause__``.
        """
        config = {"configurable": {"thread_id": thread_id}}
        try:
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e
|
|
|