Spaces:
Runtime error
Runtime error
File size: 4,319 Bytes
"""Graph builder module. Assembles the complete multi-agent LangGraph workflow."""
import logging
from typing import Optional

from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore

from src.state import State
from src.tools import music_tools, invoice_tools
from src.agents.prompts import INVOICE_SUBAGENT_PROMPT, SUPERVISOR_PROMPT
from src.agents.nodes import (
    create_music_assistant_node,
    should_continue,
    should_interrupt,
    create_verify_info_node,
    human_input,
    load_memory,
    create_memory_node,
)
logger = logging.getLogger(__name__)
def _build_llm(
    model_name: str,
    temperature: float,
    openai_api_key: Optional[str],
    openai_api_base: Optional[str],
):
    """Create the ChatOpenAI client shared by every agent in the workflow."""
    llm_kwargs = {
        "model": model_name,
        "temperature": temperature,
    }
    # Forward the key / base URL only when a non-empty value was supplied, so
    # ChatOpenAI is never handed an explicit None or "" for them.
    if openai_api_key:
        llm_kwargs["api_key"] = openai_api_key
    if openai_api_base:
        llm_kwargs["base_url"] = openai_api_base
    return ChatOpenAI(**llm_kwargs)


def _build_music_catalog_subagent(llm, checkpointer, store):
    """Compile the hand-built ReAct loop for the music catalog sub-agent."""
    music_assistant_fn = create_music_assistant_node(llm, music_tools)
    music_tool_node = ToolNode(music_tools)

    music_workflow = StateGraph(State)
    music_workflow.add_node("music_assistant", music_assistant_fn)
    music_workflow.add_node("music_tool_node", music_tool_node)
    music_workflow.add_edge(START, "music_assistant")
    # should_continue selects "continue" (run the tool node) or "end".
    music_workflow.add_conditional_edges(
        "music_assistant",
        should_continue,
        {"continue": "music_tool_node", "end": END},
    )
    # After the tools run, control returns to the assistant node.
    music_workflow.add_edge("music_tool_node", "music_assistant")
    return music_workflow.compile(
        name="music_catalog_subagent",
        checkpointer=checkpointer,
        store=store,
    )


def _build_invoice_subagent(llm, checkpointer, store):
    """Create the invoice sub-agent with the pre-built ReAct agent factory."""
    return create_react_agent(
        llm,
        tools=invoice_tools,
        name="invoice_information_subagent",
        prompt=INVOICE_SUBAGENT_PROMPT,
        state_schema=State,
        checkpointer=checkpointer,
        store=store,
    )


def _build_supervisor(llm, agents, checkpointer, store):
    """Compile the supervisor graph created over the given sub-agents."""
    # Imported at call time, exactly as the original function body did.
    # NOTE(review): presumably this keeps langgraph_supervisor from being
    # required when the module is merely imported — confirm before hoisting.
    from langgraph_supervisor import create_supervisor

    supervisor_workflow = create_supervisor(
        agents=agents,
        output_mode="last_message",
        model=llm,
        prompt=SUPERVISOR_PROMPT,
        state_schema=State,
    )
    return supervisor_workflow.compile(
        name="supervisor",
        checkpointer=checkpointer,
        store=store,
    )


def _build_multi_agent_graph(llm, supervisor, checkpointer, store):
    """Compile the outer graph that wraps the supervisor with verification and memory nodes."""
    verify_info_fn = create_verify_info_node(llm)
    create_memory_fn = create_memory_node(llm)

    multi_agent = StateGraph(State)
    multi_agent.add_node("verify_info", verify_info_fn)
    multi_agent.add_node("human_input", human_input)
    multi_agent.add_node("load_memory", load_memory)
    multi_agent.add_node("supervisor", supervisor)
    multi_agent.add_node("create_memory", create_memory_fn)

    multi_agent.add_edge(START, "verify_info")
    # should_interrupt selects "continue" (go on to load_memory) or
    # "interrupt" (go to human_input).
    multi_agent.add_conditional_edges(
        "verify_info",
        should_interrupt,
        {"continue": "load_memory", "interrupt": "human_input"},
    )
    # human_input always leads back to verify_info, forming a loop.
    multi_agent.add_edge("human_input", "verify_info")
    multi_agent.add_edge("load_memory", "supervisor")
    multi_agent.add_edge("supervisor", "create_memory")
    multi_agent.add_edge("create_memory", END)
    return multi_agent.compile(
        name="multi_agent_final",
        checkpointer=checkpointer,
        store=store,
    )


def build_graph(
    model_name: str = "gpt-4o-mini",
    temperature: float = 0,
    openai_api_key: Optional[str] = None,
    openai_api_base: Optional[str] = None,
):
    """Assemble and compile the complete multi-agent workflow.

    The outer graph is wired as: START -> verify_info, then either
    human_input (which loops back to verify_info) or load_memory ->
    supervisor -> create_memory -> END. The supervisor node is itself a
    compiled graph created over the invoice and music catalog sub-agents.

    Args:
        model_name: Model identifier passed to ChatOpenAI as ``model``.
        temperature: Sampling temperature passed to ChatOpenAI.
        openai_api_key: Passed to ChatOpenAI as ``api_key``; omitted from
            the call when None or empty.
        openai_api_base: Passed to ChatOpenAI as ``base_url``; omitted from
            the call when None or empty.

    Returns:
        A 3-tuple ``(compiled_graph, checkpointer, in_memory_store)``: the
        compiled outer graph plus the MemorySaver and InMemoryStore
        instances that every graph built here was given.
    """
    llm = _build_llm(model_name, temperature, openai_api_key, openai_api_base)
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("LLM initialized: %s, temperature=%s", model_name, temperature)

    # NOTE: Both stores are in-memory only — all data is lost on restart.
    # For production, replace with SqliteSaver / persistent store.
    in_memory_store = InMemoryStore()
    checkpointer = MemorySaver()

    # Music Catalog Sub-Agent (hand-built ReAct)
    music_catalog_subagent = _build_music_catalog_subagent(
        llm, checkpointer, in_memory_store
    )
    logger.info("Music catalog sub-agent compiled.")

    # Invoice Information Sub-Agent (pre-built ReAct)
    invoice_information_subagent = _build_invoice_subagent(
        llm, checkpointer, in_memory_store
    )
    logger.info("Invoice information sub-agent compiled.")

    # Supervisor
    supervisor_prebuilt = _build_supervisor(
        llm,
        [invoice_information_subagent, music_catalog_subagent],
        checkpointer,
        in_memory_store,
    )
    logger.info("Supervisor compiled.")

    # Final Multi-Agent Graph
    compiled_graph = _build_multi_agent_graph(
        llm, supervisor_prebuilt, checkpointer, in_memory_store
    )
    logger.info("Final multi-agent graph compiled successfully.")
    return compiled_graph, checkpointer, in_memory_store
|