File size: 3,022 Bytes
a820b5b 586a3f1 a820b5b e0c585c a820b5b 7baf8ba a820b5b f160233 586a3f1 7baf8ba 586a3f1 a820b5b f160233 a820b5b f160233 a820b5b f160233 a820b5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
"""DeepBoner research workflow definition using LangGraph."""
from __future__ import annotations
from functools import partial
from typing import Any
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from src.agents.graph.nodes import (
judge_node,
resolve_node,
search_node,
supervisor_node,
synthesize_node,
)
from src.agents.graph.state import ResearchState
from src.services.embedding_protocol import EmbeddingServiceProtocol
def create_research_graph(
    llm: BaseChatModel | None = None,
    checkpointer: BaseCheckpointSaver[Any] | None = None,
    embedding_service: EmbeddingServiceProtocol | None = None,
) -> CompiledStateGraph[Any]:
    """Build and compile the research state graph.

    Every run starts at the ``supervisor`` node, which routes on
    ``state["next_step"]`` to one of ``search``, ``judge``, ``resolve``,
    ``synthesize`` or ``"finish"``. The ``search``, ``judge`` and ``resolve``
    workers hand control back to the supervisor. ``synthesize`` and
    ``"finish"`` both end the run.

    Args:
        llm: The language model for the supervisor node, injected as the
            ``llm`` keyword. When ``None``, the supervisor node is
            registered without it.
        checkpointer: Optional persistence layer, passed to ``compile``.
        embedding_service: Service for evidence storage and retrieval,
            injected as the ``embedding_service`` keyword into every worker
            node. When ``None``, the worker nodes are registered without it.

    Returns:
        The compiled graph.
    """
    graph = StateGraph(ResearchState)

    # --- Nodes ---
    # Dependencies are injected with functools.partial so the node signature
    # seen by the graph stays clean. Compare against None explicitly instead
    # of relying on truthiness: an object that defines __len__/__bool__ (e.g.
    # an empty evidence store) must still be injected.
    supervisor = (
        partial(supervisor_node, llm=llm) if llm is not None else supervisor_node
    )
    graph.add_node("supervisor", supervisor)

    # Registration order matches the original: search, judge, resolve, synthesize.
    worker_nodes = {
        "search": search_node,
        "judge": judge_node,
        "resolve": resolve_node,
        "synthesize": synthesize_node,
    }
    for node_name, node_fn in worker_nodes.items():
        bound_node = (
            partial(node_fn, embedding_service=embedding_service)
            if embedding_service is not None
            else node_fn
        )
        graph.add_node(node_name, bound_node)

    # --- Edges ---
    # All non-terminal worker nodes report back to the supervisor.
    graph.add_edge("search", "supervisor")
    graph.add_edge("judge", "supervisor")
    graph.add_edge("resolve", "supervisor")
    # Synthesis is the end of the run.
    graph.add_edge("synthesize", END)

    # --- Conditional Routing ---
    # The supervisor decides where to go next based on state["next_step"].
    # "finish" ends the run without going through synthesis.
    graph.add_conditional_edges(
        "supervisor",
        lambda state: state["next_step"],
        {
            "search": "search",
            "judge": "judge",
            "resolve": "resolve",
            "synthesize": "synthesize",
            "finish": END,
        },
    )

    # Entry point: every run starts at the supervisor.
    graph.set_entry_point("supervisor")
    return graph.compile(checkpointer=checkpointer)
|