Spaces:
Sleeping
Sleeping
File size: 4,259 Bytes
5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 5aa2260 4088106 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import time
from datetime import datetime
from typing import TypedDict, List, Any
from langgraph.graph import StateGraph, END
from agents.planner import run_planner
from agents.retriever import run_retriever
from agents.grader import run_grader
from agents.generator import run_generator
from agents.critic import run_critic
# State mapping shared by every node of the research workflow.
# All keys are required (total TypedDict); ``ResearchGraph.run`` seeds each
# one with an empty value and the nodes fill them in as the pipeline advances.
GraphState = TypedDict(
    "GraphState",
    {
        "question": str,  # the user's research question
        "query_id": str,  # identifier forwarded to every tracer event
        "plan": str,  # planner output
        "documents": List[Any],  # raw chunks returned by the retriever
        "graded_docs": List[Any],  # retriever chunks annotated by the grader
        "generation": str,  # generated answer text
        "critique": str,  # critic's explanation
        "verdict": str,  # critic's verdict label
        "timestamp": str,  # ISO-8601 creation time of the run
    },
)
class ResearchGraph:
    """Sequential multi-agent research pipeline compiled as a LangGraph graph.

    The compiled graph runs five nodes in a fixed order:
    planner -> retriever -> grader -> generator -> critic -> END.
    Each node delegates to the matching ``run_*`` agent function. Every node
    reports a ``"running"`` event before calling its agent and a
    ``"complete"`` event with the elapsed milliseconds afterwards, through
    ``tracer.add(query_id, agent_name, message, status, ms)``.
    """

    def __init__(
        self,
        vector_store,
        tracer,
        *,
        top_k: int = 5,
        grade_threshold: float = 0.35,
        max_context_docs: int = 4,
    ):
        """Store collaborators and compile the workflow graph.

        Args:
            vector_store: Passed unchanged to ``run_retriever``.
            tracer: Progress sink exposing
                ``add(query_id, agent, message, status, ms)``.
            top_k: Number of chunks requested from the retriever (its ``k``).
            grade_threshold: Minimum ``"grade"`` a graded chunk needs to be
                preferred as generator context.
            max_context_docs: Maximum number of chunks passed to the generator.
        """
        self.vs = vector_store
        self.tracer = tracer
        self.top_k = top_k
        self.grade_threshold = grade_threshold
        self.max_context_docs = max_context_docs
        self.graph = self._build()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _average_grade(graded: List[Any]) -> float:
        """Return the mean ``"grade"`` of the chunks; 0.0 when there are none.

        A chunk without a ``"grade"`` key counts as 0, matching how
        ``_select_context_docs`` treats it, instead of raising ``KeyError``.
        """
        if not graded:
            return 0.0
        return sum(d.get("grade", 0) for d in graded) / len(graded)

    def _select_context_docs(self, graded: List[Any]) -> List[Any]:
        """Choose the chunks handed to the generator.

        Keeps chunks whose ``"grade"`` (missing counts as 0) is at least
        ``grade_threshold``. If none qualify, it falls back to every graded
        chunk so the generator still receives context. The result keeps the
        incoming order and is truncated to ``max_context_docs`` items.
        """
        relevant = [d for d in graded if d.get("grade", 0) >= self.grade_threshold]
        return (relevant or graded)[: self.max_context_docs]

    # -------------------------------------------------------------------- nodes

    def _planner_node(self, state: GraphState) -> dict:
        """Run the planner on the question; return a state update with ``plan``."""
        t0 = time.perf_counter()
        self.tracer.add(state["query_id"], "planner", "Planning research approach…", "running", 0)
        plan = run_planner(state["question"])
        ms = self._elapsed_ms(t0)
        # Only a 200-char preview goes to the tracer; the full plan is kept in state.
        self.tracer.add(state["query_id"], "planner", plan[:200], "complete", ms)
        return {"plan": plan}

    def _retriever_node(self, state: GraphState) -> dict:
        """Fetch ``top_k`` chunks for the question; return an update with ``documents``."""
        t0 = time.perf_counter()
        self.tracer.add(state["query_id"], "retriever", "Searching knowledge base (FAISS + BM25)…", "running", 0)
        docs = run_retriever(state["question"], self.vs, k=self.top_k)
        ms = self._elapsed_ms(t0)
        self.tracer.add(state["query_id"], "retriever", f"Retrieved {len(docs)} chunks.", "complete", ms)
        return {"documents": docs}

    def _grader_node(self, state: GraphState) -> dict:
        """Grade retrieved chunks for relevance; return an update with ``graded_docs``.

        The average grade is reported to the tracer only.
        """
        t0 = time.perf_counter()
        self.tracer.add(state["query_id"], "grader", f"Grading {len(state['documents'])} chunks for relevance…", "running", 0)
        graded = run_grader(state["question"], state["documents"])
        avg = self._average_grade(graded)
        ms = self._elapsed_ms(t0)
        self.tracer.add(state["query_id"], "grader", f"Avg relevance: {avg:.2f}", "complete", ms)
        return {"graded_docs": graded}

    def _generator_node(self, state: GraphState) -> dict:
        """Generate an answer from the selected context; return an update with ``generation``."""
        t0 = time.perf_counter()
        self.tracer.add(state["query_id"], "generator", "Generating answer from relevant context…", "running", 0)
        context = self._select_context_docs(state["graded_docs"])
        gen = run_generator(state["question"], context)
        ms = self._elapsed_ms(t0)
        self.tracer.add(state["query_id"], "generator", f"Answer generated ({len(gen)} chars).", "complete", ms)
        return {"generation": gen}

    def _critic_node(self, state: GraphState) -> dict:
        """Evaluate the generated answer.

        Returns a state update with ``critique`` (the critic's explanation)
        and ``verdict``. A verdict of ``"APPROVED"`` is traced as high
        confidence; anything else is traced as low confidence.
        """
        t0 = time.perf_counter()
        self.tracer.add(state["query_id"], "critic", "Evaluating answer quality…", "running", 0)
        # NOTE(review): the critic is given all graded chunks, not only the
        # subset that was passed to the generator — confirm this is intended.
        result = run_critic(state["question"], state["generation"], state["graded_docs"])
        ms = self._elapsed_ms(t0)
        label = "✅ High confidence." if result["verdict"] == "APPROVED" else "⚠️ Low confidence — verify with source."
        self.tracer.add(state["query_id"], "critic", f"{label} {result['explanation'][:160]}", "complete", ms)
        return {"critique": result["explanation"], "verdict": result["verdict"]}

    # -------------------------------------------------------------------- graph

    def _build(self):
        """Wire the five nodes into a linear ``StateGraph`` and compile it."""
        wf = StateGraph(GraphState)
        wf.add_node("planner", self._planner_node)
        wf.add_node("retriever", self._retriever_node)
        wf.add_node("grader", self._grader_node)
        wf.add_node("generator", self._generator_node)
        wf.add_node("critic", self._critic_node)
        wf.set_entry_point("planner")
        wf.add_edge("planner", "retriever")
        wf.add_edge("retriever", "grader")
        wf.add_edge("grader", "generator")
        wf.add_edge("generator", "critic")
        wf.add_edge("critic", END)
        return wf.compile()

    def run(self, question: str, query_id: str) -> dict:
        """Run the full pipeline for one question.

        Args:
            question: The research question.
            query_id: Identifier attached to every tracer event for this run.

        Returns:
            The final graph state, converted to a plain ``dict``.
        """
        from datetime import timezone  # file-level import only brings in ``datetime``

        init = GraphState(
            question=question,
            query_id=query_id,
            plan="",
            documents=[],
            graded_docs=[],
            generation="",
            critique="",
            verdict="",
            # Same naive-UTC ISO string that ``datetime.utcnow()`` produced,
            # without its Python 3.12+ deprecation warning.
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        )
        return dict(self.graph.invoke(init))
|