|
|
"""DeepBoner research workflow definition using LangGraph.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from functools import partial |
|
|
from typing import Any |
|
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from langgraph.checkpoint.base import BaseCheckpointSaver |
|
|
from langgraph.graph import END, StateGraph |
|
|
from langgraph.graph.state import CompiledStateGraph |
|
|
|
|
|
from src.agents.graph.nodes import ( |
|
|
judge_node, |
|
|
resolve_node, |
|
|
search_node, |
|
|
supervisor_node, |
|
|
synthesize_node, |
|
|
) |
|
|
from src.agents.graph.state import ResearchState |
|
|
from src.services.embedding_protocol import EmbeddingServiceProtocol |
|
|
|
|
|
|
|
|
def _bind_if_provided(node: Any, **dependencies: Any) -> Any:
    """Pre-bind the non-``None`` keyword *dependencies* onto a graph node.

    Args:
        node: The node callable to register in the graph.
        **dependencies: Keyword arguments to bind. Entries whose value is
            ``None`` are skipped, so the node keeps its own default for them.

    Returns:
        ``functools.partial(node, **provided)`` when at least one dependency
        was provided, otherwise *node* itself.
    """
    # Compare against None explicitly: a truthiness test would silently skip
    # binding a real dependency whose class defines __bool__ or __len__
    # (for example, a collection-like service that currently reports len 0).
    provided = {
        name: value for name, value in dependencies.items() if value is not None
    }
    return partial(node, **provided) if provided else node


def create_research_graph(
    llm: BaseChatModel | None = None,
    checkpointer: BaseCheckpointSaver[Any] | None = None,
    embedding_service: EmbeddingServiceProtocol | None = None,
) -> CompiledStateGraph[Any]:
    """Build and compile the research state graph.

    The graph is hub-and-spoke: ``supervisor`` is the entry point and routes
    to a worker node based on ``state["next_step"]``. ``search``, ``judge`` and
    ``resolve`` hand control back to the supervisor. ``synthesize`` (or a
    ``"finish"`` decision) ends the run.

    Args:
        llm: The language model for the supervisor node. If ``None``, the
            supervisor node is registered without an ``llm`` binding.
        checkpointer: Optional persistence layer passed to ``compile``.
        embedding_service: Service for evidence storage and retrieval, bound
            into the search, judge, resolve and synthesize nodes. If ``None``,
            those nodes are registered without an ``embedding_service`` binding.

    Returns:
        The compiled graph produced by ``StateGraph.compile``.
    """
    graph = StateGraph(ResearchState)

    # Register nodes, injecting dependencies only when they were supplied.
    graph.add_node("supervisor", _bind_if_provided(supervisor_node, llm=llm))
    embedding_nodes = (
        ("search", search_node),
        ("judge", judge_node),
        ("resolve", resolve_node),
        ("synthesize", synthesize_node),
    )
    for node_name, node_fn in embedding_nodes:
        graph.add_node(
            node_name,
            _bind_if_provided(node_fn, embedding_service=embedding_service),
        )

    # Worker nodes always return control to the supervisor for the next decision.
    graph.add_edge("search", "supervisor")
    graph.add_edge("judge", "supervisor")
    graph.add_edge("resolve", "supervisor")

    # Synthesis is terminal.
    graph.add_edge("synthesize", END)

    # The supervisor writes its routing decision into state["next_step"].
    # The value is expected to be one of the keys below.
    graph.add_conditional_edges(
        "supervisor",
        lambda state: state["next_step"],
        {
            "search": "search",
            "judge": "judge",
            "resolve": "resolve",
            "synthesize": "synthesize",
            "finish": END,
        },
    )

    graph.set_entry_point("supervisor")

    return graph.compile(checkpointer=checkpointer)
|
|
|