File size: 3,022 Bytes
a820b5b
 
586a3f1
 
a820b5b
 
 
 
e0c585c
a820b5b
 
 
 
 
 
 
 
 
 
 
7baf8ba
a820b5b
 
 
f160233
586a3f1
7baf8ba
586a3f1
a820b5b
 
 
 
 
f160233
a820b5b
 
 
 
 
 
 
f160233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a820b5b
f160233
 
 
 
a820b5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""DeepBoner research workflow definition using LangGraph."""

from __future__ import annotations

from functools import partial
from typing import Any

from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from src.agents.graph.nodes import (
    judge_node,
    resolve_node,
    search_node,
    supervisor_node,
    synthesize_node,
)
from src.agents.graph.state import ResearchState
from src.services.embedding_protocol import EmbeddingServiceProtocol


def create_research_graph(
    llm: BaseChatModel | None = None,
    checkpointer: BaseCheckpointSaver[Any] | None = None,
    embedding_service: EmbeddingServiceProtocol | None = None,
) -> CompiledStateGraph[Any]:
    """Build and compile the research state graph.

    The graph has a hub-and-spoke shape: a supervisor node routes to worker
    nodes (search / judge / resolve), each of which reports back to the
    supervisor; the synthesize node terminates the run.

    Args:
        llm: The language model for the supervisor node. When ``None``, the
            supervisor node runs with its own default model handling.
        checkpointer: Optional persistence layer passed to ``compile``.
        embedding_service: Service for evidence storage and retrieval,
            injected into every worker node. When ``None``, nodes use their
            own defaults.

    Returns:
        The compiled state graph, ready to be invoked.
    """

    def _bind(node: Any, **deps: Any) -> Any:
        """Partially apply the non-None dependencies to a node callable.

        Keeps the node signature clean while injecting services; a node with
        no supplied dependencies is returned unchanged. Binding on
        ``is not None`` (rather than truthiness) ensures a valid but falsy
        dependency object is still injected.
        """
        given = {name: dep for name, dep in deps.items() if dep is not None}
        return partial(node, **given) if given else node

    graph = StateGraph(ResearchState)

    # --- Nodes ---
    graph.add_node("supervisor", _bind(supervisor_node, llm=llm))
    graph.add_node("search", _bind(search_node, embedding_service=embedding_service))
    graph.add_node("judge", _bind(judge_node, embedding_service=embedding_service))
    graph.add_node("resolve", _bind(resolve_node, embedding_service=embedding_service))
    graph.add_node(
        "synthesize", _bind(synthesize_node, embedding_service=embedding_service)
    )

    # --- Edges ---
    # All worker nodes report back to supervisor for the next routing decision.
    graph.add_edge("search", "supervisor")
    graph.add_edge("judge", "supervisor")
    graph.add_edge("resolve", "supervisor")

    # Synthesis is terminal.
    graph.add_edge("synthesize", END)

    # --- Conditional Routing ---
    # Supervisor decides where to go next based on state["next_step"].
    graph.add_conditional_edges(
        "supervisor",
        lambda state: state["next_step"],
        {
            "search": "search",
            "judge": "judge",
            "resolve": "resolve",
            "synthesize": "synthesize",
            "finish": END,
        },
    )

    # Every run starts at the supervisor.
    graph.set_entry_point("supervisor")

    return graph.compile(checkpointer=checkpointer)