Spaces:
Sleeping
Sleeping
| import time | |
| from datetime import datetime | |
| from typing import TypedDict, List, Any | |
| from langgraph.graph import StateGraph, END | |
| from agents.planner import run_planner | |
| from agents.retriever import run_retriever | |
| from agents.grader import run_grader | |
| from agents.generator import run_generator | |
| from agents.critic import run_critic | |
| class GraphState(TypedDict): | |
| question: str | |
| query_id: str | |
| plan: str | |
| documents: List[Any] | |
| graded_docs: List[Any] | |
| generation: str | |
| critique: str | |
| verdict: str | |
| timestamp: str | |
| class ResearchGraph: | |
| def __init__(self, vector_store, tracer): | |
| self.vs = vector_store | |
| self.tracer = tracer | |
| self.graph = self._build() | |
| def _planner_node(self, state: GraphState) -> dict: | |
| t0 = time.time() | |
| self.tracer.add(state["query_id"], "planner", "Planning research approach…", "running", 0) | |
| plan = run_planner(state["question"]) | |
| ms = int((time.time() - t0) * 1000) | |
| self.tracer.add(state["query_id"], "planner", plan[:200], "complete", ms) | |
| return {"plan": plan} | |
| def _retriever_node(self, state: GraphState) -> dict: | |
| t0 = time.time() | |
| self.tracer.add(state["query_id"], "retriever", "Searching knowledge base (FAISS + BM25)…", "running", 0) | |
| docs = run_retriever(state["question"], self.vs, k=5) | |
| ms = int((time.time() - t0) * 1000) | |
| self.tracer.add(state["query_id"], "retriever", f"Retrieved {len(docs)} chunks.", "complete", ms) | |
| return {"documents": docs} | |
| def _grader_node(self, state: GraphState) -> dict: | |
| t0 = time.time() | |
| self.tracer.add(state["query_id"], "grader", f"Grading {len(state['documents'])} chunks for relevance…", "running", 0) | |
| graded = run_grader(state["question"], state["documents"]) | |
| avg = sum(d["grade"] for d in graded) / len(graded) if graded else 0.0 | |
| ms = int((time.time() - t0) * 1000) | |
| self.tracer.add(state["query_id"], "grader", f"Avg relevance: {avg:.2f}", "complete", ms) | |
| return {"graded_docs": graded} | |
| def _generator_node(self, state: GraphState) -> dict: | |
| t0 = time.time() | |
| self.tracer.add(state["query_id"], "generator", "Generating answer from relevant context…", "running", 0) | |
| good_docs = [d for d in state["graded_docs"] if d.get("grade", 0) >= 0.35] or state["graded_docs"] | |
| gen = run_generator(state["question"], good_docs[:4]) | |
| ms = int((time.time() - t0) * 1000) | |
| self.tracer.add(state["query_id"], "generator", f"Answer generated ({len(gen)} chars).", "complete", ms) | |
| return {"generation": gen} | |
| def _critic_node(self, state: GraphState) -> dict: | |
| t0 = time.time() | |
| self.tracer.add(state["query_id"], "critic", "Evaluating answer quality…", "running", 0) | |
| result = run_critic(state["question"], state["generation"], state["graded_docs"]) | |
| ms = int((time.time() - t0) * 1000) | |
| label = "✅ High confidence." if result["verdict"] == "APPROVED" else "⚠️ Low confidence — verify with source." | |
| self.tracer.add(state["query_id"], "critic", f"{label} {result['explanation'][:160]}", "complete", ms) | |
| return {"critique": result["explanation"], "verdict": result["verdict"]} | |
| def _build(self): | |
| wf = StateGraph(GraphState) | |
| wf.add_node("planner", self._planner_node) | |
| wf.add_node("retriever", self._retriever_node) | |
| wf.add_node("grader", self._grader_node) | |
| wf.add_node("generator", self._generator_node) | |
| wf.add_node("critic", self._critic_node) | |
| wf.set_entry_point("planner") | |
| wf.add_edge("planner", "retriever") | |
| wf.add_edge("retriever", "grader") | |
| wf.add_edge("grader", "generator") | |
| wf.add_edge("generator", "critic") | |
| wf.add_edge("critic", END) | |
| return wf.compile() | |
| def run(self, question: str, query_id: str) -> dict: | |
| init = GraphState( | |
| question=question, query_id=query_id, plan="", | |
| documents=[], graded_docs=[], generation="", | |
| critique="", verdict="", timestamp=datetime.utcnow().isoformat(), | |
| ) | |
| return dict(self.graph.invoke(init)) | |