# docmind/graph/research_graph.py
# Repository page metadata: uploader mnoorchenar; commit "Update 2026-03-22 16:13:13" (4088106).
import time
from datetime import datetime, timezone
from typing import TypedDict, List, Any

from langgraph.graph import StateGraph, END

from agents.planner import run_planner
from agents.retriever import run_retriever
from agents.grader import run_grader
from agents.generator import run_generator
from agents.critic import run_critic
class GraphState(TypedDict):
    """Shared state dictionary handed from node to node in the research graph.

    Every node receives the full state and returns a partial dict containing
    only the keys it produced.
    """

    # The user's question; passed to every agent call.
    question: str
    # Identifier under which trace entries for this run are recorded.
    query_id: str
    # Output of the planner agent.
    plan: str
    # Chunks returned by the retriever agent.
    documents: List[Any]
    # Output of the grader agent for the retrieved chunks.
    graded_docs: List[Any]
    # Answer text produced by the generator agent.
    generation: str
    # The critic's explanation of its assessment.
    critique: str
    # The critic's verdict; "APPROVED" is treated as high confidence.
    verdict: str
    # ISO-8601 string stamped when the run is started.
    timestamp: str
class ResearchGraph:
    """Linear research pipeline: planner -> retriever -> grader -> generator -> critic.

    Each node delegates to the matching ``run_*`` agent function and reports
    progress through ``tracer.add(query_id, agent, message, status, ms)``:
    one "running" entry (0 ms) before the agent call and one "complete" entry
    carrying the elapsed milliseconds afterwards.

    Args:
        vector_store: Store handed to the retriever agent.
        tracer: Object exposing ``add(query_id, agent, message, status, ms)``.
        retrieval_k: ``k`` passed to the retriever agent.
        min_grade: Minimum grade for a graded document to be used as context.
        max_context_docs: Maximum number of documents given to the generator.
        plan_preview_chars: How many characters of the plan go into the trace.
        critique_preview_chars: How many characters of the critic's
            explanation go into the trace.
    """

    def __init__(
        self,
        vector_store,
        tracer,
        *,
        retrieval_k: int = 5,
        min_grade: float = 0.35,
        max_context_docs: int = 4,
        plan_preview_chars: int = 200,
        critique_preview_chars: int = 160,
    ):
        self.vs = vector_store
        self.tracer = tracer
        self.retrieval_k = retrieval_k
        self.min_grade = min_grade
        self.max_context_docs = max_context_docs
        self.plan_preview_chars = plan_preview_chars
        self.critique_preview_chars = critique_preview_chars
        self.graph = self._build()

    # ------------------------------------------------------------------ #
    # Tracing helpers
    # ------------------------------------------------------------------ #

    def _begin(self, state: "GraphState", agent: str, message: str) -> float:
        """Record a "running" trace entry for *agent* and return the start time."""
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the system clock is adjusted mid-run.
        started = time.perf_counter()
        self.tracer.add(state["query_id"], agent, message, "running", 0)
        return started

    def _finish(self, state: "GraphState", agent: str, message: str, started: float) -> None:
        """Record a "complete" trace entry with the milliseconds since *started*."""
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        self.tracer.add(state["query_id"], agent, message, "complete", elapsed_ms)

    # ------------------------------------------------------------------ #
    # Pure helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _mean_grade(graded: List[Any]) -> float:
        """Return the average ``"grade"`` of *graded*, or 0.0 when it is empty.

        A document without a ``"grade"`` key counts as 0, matching how
        context selection treats it.
        """
        if not graded:
            return 0.0
        return sum(doc.get("grade", 0) for doc in graded) / len(graded)

    @staticmethod
    def _select_context(graded: List[Any], min_grade: float, max_docs: int) -> List[Any]:
        """Pick the documents the generator should see.

        Keeps documents graded at least *min_grade*; when none qualify, falls
        back to all graded documents so the generator still has context.
        The result is capped at *max_docs* entries, in the original order.
        """
        relevant = [doc for doc in graded if doc.get("grade", 0) >= min_grade]
        return (relevant or graded)[:max_docs]

    # ------------------------------------------------------------------ #
    # Graph nodes
    # ------------------------------------------------------------------ #

    def _planner_node(self, state: "GraphState") -> dict:
        """Run the planner on the question and store its plan."""
        started = self._begin(state, "planner", "Planning research approach…")
        plan = run_planner(state["question"])
        # Only a preview of the plan is traced; the full plan goes into state.
        self._finish(state, "planner", plan[:self.plan_preview_chars], started)
        return {"plan": plan}

    def _retriever_node(self, state: "GraphState") -> dict:
        """Fetch candidate chunks for the question from the vector store."""
        started = self._begin(state, "retriever", "Searching knowledge base (FAISS + BM25)…")
        docs = run_retriever(state["question"], self.vs, k=self.retrieval_k)
        self._finish(state, "retriever", f"Retrieved {len(docs)} chunks.", started)
        return {"documents": docs}

    def _grader_node(self, state: "GraphState") -> dict:
        """Grade the retrieved chunks and trace their average relevance."""
        started = self._begin(
            state, "grader", f"Grading {len(state['documents'])} chunks for relevance…"
        )
        graded = run_grader(state["question"], state["documents"])
        avg = self._mean_grade(graded)
        self._finish(state, "grader", f"Avg relevance: {avg:.2f}", started)
        return {"graded_docs": graded}

    def _generator_node(self, state: "GraphState") -> dict:
        """Generate an answer from the best-graded documents."""
        started = self._begin(state, "generator", "Generating answer from relevant context…")
        context = self._select_context(
            state["graded_docs"], self.min_grade, self.max_context_docs
        )
        gen = run_generator(state["question"], context)
        self._finish(state, "generator", f"Answer generated ({len(gen)} chars).", started)
        return {"generation": gen}

    def _critic_node(self, state: "GraphState") -> dict:
        """Have the critic assess the answer and store its verdict and explanation."""
        started = self._begin(state, "critic", "Evaluating answer quality…")
        result = run_critic(state["question"], state["generation"], state["graded_docs"])
        verdict = result["verdict"]
        explanation = result["explanation"]
        if verdict == "APPROVED":
            label = "✅ High confidence."
        else:
            label = "⚠️ Low confidence — verify with source."
        self._finish(
            state, "critic", f"{label} {explanation[:self.critique_preview_chars]}", started
        )
        return {"critique": explanation, "verdict": verdict}

    # ------------------------------------------------------------------ #
    # Graph assembly and execution
    # ------------------------------------------------------------------ #

    def _build(self):
        """Assemble and compile the five-node linear state graph."""
        pipeline = [
            ("planner", self._planner_node),
            ("retriever", self._retriever_node),
            ("grader", self._grader_node),
            ("generator", self._generator_node),
            ("critic", self._critic_node),
        ]
        wf = StateGraph(GraphState)
        for name, node in pipeline:
            wf.add_node(name, node)
        wf.set_entry_point(pipeline[0][0])
        # Chain each node to its successor; the last node ends the graph.
        for (src, _), (dst, _) in zip(pipeline, pipeline[1:]):
            wf.add_edge(src, dst)
        wf.add_edge(pipeline[-1][0], END)
        return wf.compile()

    def run(self, question: str, query_id: str) -> dict:
        """Execute the pipeline for *question* and return the final state as a dict.

        Args:
            question: The question to research.
            query_id: Identifier under which trace entries are recorded.
        """
        # datetime.utcnow() is deprecated; take an aware UTC time and drop the
        # offset so the emitted string keeps the same naive ISO format.
        stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        init = GraphState(
            question=question,
            query_id=query_id,
            plan="",
            documents=[],
            graded_docs=[],
            generation="",
            critique="",
            verdict="",
            timestamp=stamp,
        )
        return dict(self.graph.invoke(init))