File size: 4,259 Bytes
5aa2260
 
4088106
5aa2260
 
 
 
 
 
 
 
 
 
4088106
 
 
 
 
 
 
 
 
5aa2260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4088106
5aa2260
 
4088106
5aa2260
 
 
 
4088106
5aa2260
 
 
4088106
5aa2260
 
 
 
4088106
5aa2260
 
 
 
 
 
 
 
4088106
5aa2260
 
4088106
5aa2260
4088106
5aa2260
 
 
 
 
 
 
 
 
 
 
4088106
5aa2260
4088106
5aa2260
 
 
4088106
5aa2260
 
4088106
5aa2260
4088106
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import time
from datetime import datetime, timezone
from typing import TypedDict, List, Any

from langgraph.graph import StateGraph, END

from agents.planner   import run_planner
from agents.retriever import run_retriever
from agents.grader    import run_grader
from agents.generator import run_generator
from agents.critic    import run_critic
from agents.critic    import run_critic


# Shared state threaded through every node of the research graph.
# Each node returns a partial dict containing only the keys it produces.
GraphState = TypedDict(
    "GraphState",
    {
        "question": str,           # the user's research question
        "query_id": str,           # id passed to the tracer on every trace entry
        "plan": str,               # planner output
        "documents": List[Any],    # chunks returned by the retriever
        "graded_docs": List[Any],  # grader output; read as dicts with a "grade" key
        "generation": str,         # generator's answer text
        "critique": str,           # critic's explanation
        "verdict": str,            # critic verdict ("APPROVED" = high confidence)
        "timestamp": str,          # ISO-8601 time at which the run was started
    },
)


class ResearchGraph:
    """Linear research pipeline: planner -> retriever -> grader -> generator -> critic.

    Each node calls one agent function and records two trace entries on
    ``tracer``: a "running" entry when it starts and a "complete" entry
    carrying the elapsed milliseconds. Each node then returns a partial
    state dict containing only the keys it produced.
    """

    def __init__(
        self,
        vector_store,
        tracer,
        *,
        top_k: int = 5,
        min_grade: float = 0.35,
        max_context_docs: int = 4,
    ):
        """Store collaborators and compile the graph.

        Args:
            vector_store: Passed unchanged to ``run_retriever``.
            tracer: Object exposing ``add(query_id, agent, message, status, ms)``.
            top_k: ``k`` requested from the retriever.
            min_grade: Minimum ``"grade"`` a chunk needs to be preferred as
                generator context.
            max_context_docs: Maximum number of chunks handed to the generator.
        """
        self.vs = vector_store
        self.tracer = tracer
        self.top_k = top_k
        self.min_grade = min_grade
        self.max_context_docs = max_context_docs
        # Compile last so the node methods can rely on the settings above.
        self.graph = self._build()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a perf_counter value)."""
        return int((time.perf_counter() - start) * 1000)

    def _trace_start(self, query_id: str, agent: str, message: str) -> float:
        """Record a "running" trace entry and return the start time for ``_trace_done``."""
        # perf_counter is monotonic; time.time() can jump with wall-clock adjustments.
        start = time.perf_counter()
        self.tracer.add(query_id, agent, message, "running", 0)
        return start

    def _trace_done(self, query_id: str, agent: str, message: str, start: float) -> None:
        """Record a "complete" trace entry with the milliseconds elapsed since ``start``."""
        # NOTE(review): if an agent raises, its "running" entry is never completed;
        # consider recording a failure status once the tracer's statuses are known.
        self.tracer.add(query_id, agent, message, "complete", self._elapsed_ms(start))

    @staticmethod
    def _average_grade(graded: List[Any]) -> float:
        """Mean ``"grade"`` over ``graded``; a missing grade counts as 0, empty input gives 0.0."""
        if not graded:
            return 0.0
        # .get matches how the generator treats ungraded chunks (instead of KeyError).
        return sum(d.get("grade", 0) for d in graded) / len(graded)

    def _select_context_docs(self, graded: List[Any]) -> List[Any]:
        """Choose the chunks passed to the generator.

        Keeps chunks whose ``"grade"`` is at least ``self.min_grade``. If none
        qualify, it falls back to all graded chunks so the generator still
        receives context. At most ``self.max_context_docs`` chunks are
        returned, in their incoming order.
        """
        relevant = [d for d in graded if d.get("grade", 0) >= self.min_grade]
        return (relevant or graded)[: self.max_context_docs]

    # -------------------------------------------------------------------- nodes

    def _planner_node(self, state: "GraphState") -> dict:
        """Run the planner agent on the question and store its plan."""
        qid = state["query_id"]
        start = self._trace_start(qid, "planner", "Planning research approach…")
        plan = run_planner(state["question"])
        self._trace_done(qid, "planner", plan[:200], start)
        return {"plan": plan}

    def _retriever_node(self, state: "GraphState") -> dict:
        """Retrieve chunks for the question from the vector store (``k=self.top_k``)."""
        qid = state["query_id"]
        start = self._trace_start(qid, "retriever", "Searching knowledge base (FAISS + BM25)…")
        docs = run_retriever(state["question"], self.vs, k=self.top_k)
        self._trace_done(qid, "retriever", f"Retrieved {len(docs)} chunks.", start)
        return {"documents": docs}

    def _grader_node(self, state: "GraphState") -> dict:
        """Grade the retrieved chunks for relevance and trace their average grade."""
        qid = state["query_id"]
        docs = state["documents"]
        start = self._trace_start(qid, "grader", f"Grading {len(docs)} chunks for relevance…")
        graded = run_grader(state["question"], docs)
        avg = self._average_grade(graded)
        self._trace_done(qid, "grader", f"Avg relevance: {avg:.2f}", start)
        return {"graded_docs": graded}

    def _generator_node(self, state: "GraphState") -> dict:
        """Generate an answer from the selected graded chunks."""
        qid = state["query_id"]
        start = self._trace_start(qid, "generator", "Generating answer from relevant context…")
        context = self._select_context_docs(state["graded_docs"])
        gen = run_generator(state["question"], context)
        self._trace_done(qid, "generator", f"Answer generated ({len(gen)} chars).", start)
        return {"generation": gen}

    def _critic_node(self, state: "GraphState") -> dict:
        """Have the critic judge the answer; store its explanation and verdict."""
        qid = state["query_id"]
        start = self._trace_start(qid, "critic", "Evaluating answer quality…")
        result = run_critic(state["question"], state["generation"], state["graded_docs"])
        approved = result["verdict"] == "APPROVED"
        label = "✅ High confidence." if approved else "⚠️ Low confidence — verify with source."
        self._trace_done(qid, "critic", f"{label} {result['explanation'][:160]}", start)
        return {"critique": result["explanation"], "verdict": result["verdict"]}

    # -------------------------------------------------------------------- graph

    def _build(self):
        """Wire the five nodes into a linear graph ending at END and compile it."""
        wf = StateGraph(GraphState)
        wf.add_node("planner",   self._planner_node)
        wf.add_node("retriever", self._retriever_node)
        wf.add_node("grader",    self._grader_node)
        wf.add_node("generator", self._generator_node)
        wf.add_node("critic",    self._critic_node)
        wf.set_entry_point("planner")
        wf.add_edge("planner",   "retriever")
        wf.add_edge("retriever", "grader")
        wf.add_edge("grader",    "generator")
        wf.add_edge("generator", "critic")
        wf.add_edge("critic",    END)
        return wf.compile()

    def run(self, question: str, query_id: str) -> dict:
        """Execute the pipeline for one question and return the final state as a plain dict."""
        init = GraphState(
            question=question, query_id=query_id, plan="",
            documents=[], graded_docs=[], generation="",
            critique="", verdict="",
            # Timezone-aware UTC: datetime.utcnow() is deprecated (3.12+), and its
            # naive ISO string has no offset, so consumers may read it as local time.
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        return dict(self.graph.invoke(init))