from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.prompts import ChatPromptTemplate
from core.state import GraphState
from config.config import Config
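
# NOTE (assumption): `GraphState` lives in core.state and is not shown here.
# Judging by the keys the nodes read and write below, it is presumably a
# TypedDict shaped roughly like this sketch:
#
#     class GraphState(TypedDict):
#         question: str
#         context: List[Document]   # documents returned by the retriever
#         answer: str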


class ProjectRAGGraph:
    def __init__(self, llm, vector_store):
        self.llm = llm
        self.vector_store = vector_store
        # MemorySaver checkpoints graph state per thread_id (conversation memory).
        self.memory = MemorySaver()
        self.workflow = self._build_graph()

    # -------- Nodes --------
    def retrieve(self, state: GraphState):
        """Fetch documents relevant to the question from the vector store."""
        try:
            retriever = self.vector_store.as_retriever(
                search_type=Config.SEARCH_TYPE,
                search_kwargs={
                    "k": Config.TOP_K,
                    "lambda_mult": Config.LAMBDA_MULT,
                },
            )
            docs = retriever.invoke(state["question"])
            return {"context": docs}
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e

    def generate(self, state: GraphState):
        """Answer the question with the LLM, grounded in the retrieved context."""
        try:
            prompt = ChatPromptTemplate.from_template(
                """
                You are a professional Project Analyst.
                Context:
                {context}
                Question:
                {question}
                Answer strictly using the context. Cite sources.
                """
            )
            # Flatten the retrieved documents into a single context string.
            formatted_context = "\n\n".join(
                doc.page_content for doc in state["context"]
            )
            chain = prompt | self.llm
            response = chain.invoke({
                "context": formatted_context,
                "question": state["question"],
            })
            return {"answer": response.content}
        except Exception as e:
            raise RuntimeError(f"Answer generation failed: {e}") from e

    # -------- Graph --------
    def _build_graph(self):
        """Compile the linear retrieve -> generate flow with checkpointing."""
        graph = StateGraph(GraphState)
        graph.add_node("retrieve", self.retrieve)
        graph.add_node("generate", self.generate)
        graph.set_entry_point("retrieve")
        graph.add_edge("retrieve", "generate")
        graph.add_edge("generate", END)
        return graph.compile(checkpointer=self.memory)

    def query(self, question: str, thread_id: str):
        """Run the graph; thread_id scopes the checkpointer's conversation memory."""
        try:
            config = {"configurable": {"thread_id": thread_id}}
            result = self.workflow.invoke({"question": question}, config=config)
            return result["answer"]
        except Exception as e:
            raise RuntimeError(f"Graph execution failed: {e}") from e