File size: 4,319 Bytes
927c050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Graph builder module. Assembles the complete multi-agent LangGraph workflow."""

import logging
from typing import Optional

from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore

from src.state import State
from src.tools import music_tools, invoice_tools
from src.agents.prompts import INVOICE_SUBAGENT_PROMPT, SUPERVISOR_PROMPT
from src.agents.nodes import (
    create_music_assistant_node,
    should_continue,
    should_interrupt,
    create_verify_info_node,
    human_input,
    load_memory,
    create_memory_node,
)

# Module-level logger, named after this module via ``__name__``; used by
# build_graph to report each compilation step.
logger = logging.getLogger(__name__)


def _build_llm(
    model_name: str,
    temperature: float,
    openai_api_key: Optional[str],
    openai_api_base: Optional[str],
):
    """Create the ``ChatOpenAI`` model shared by every agent in the workflow.

    ``api_key`` and ``base_url`` are passed to the constructor only when a
    truthy value is supplied; otherwise the keyword is omitted from the call.
    """
    llm_kwargs = {"model": model_name, "temperature": temperature}
    if openai_api_key:
        llm_kwargs["api_key"] = openai_api_key
    if openai_api_base:
        llm_kwargs["base_url"] = openai_api_base
    return ChatOpenAI(**llm_kwargs)


def _build_music_catalog_subagent(llm, checkpointer, store):
    """Compile the hand-built ReAct loop that serves music catalog queries.

    Topology: START -> music_assistant; ``should_continue`` then routes
    "continue" to music_tool_node (which always returns to music_assistant)
    and "end" to END.
    """
    music_assistant_fn = create_music_assistant_node(llm, music_tools)
    music_tool_node = ToolNode(music_tools)

    workflow = StateGraph(State)
    workflow.add_node("music_assistant", music_assistant_fn)
    workflow.add_node("music_tool_node", music_tool_node)
    workflow.add_edge(START, "music_assistant")
    workflow.add_conditional_edges(
        "music_assistant",
        should_continue,
        {"continue": "music_tool_node", "end": END},
    )
    workflow.add_edge("music_tool_node", "music_assistant")

    return workflow.compile(
        name="music_catalog_subagent",
        checkpointer=checkpointer,
        store=store,
    )


def _build_supervisor(llm, subagents, checkpointer, store):
    """Compile the supervisor that delegates to the given sub-agents."""
    # Kept as a function-scope import, as in the original code.
    # NOTE(review): presumably deferred so importing this module does not
    # require langgraph_supervisor - confirm before hoisting to file level.
    from langgraph_supervisor import create_supervisor

    workflow = create_supervisor(
        agents=subagents,
        output_mode="last_message",
        model=llm,
        prompt=SUPERVISOR_PROMPT,
        state_schema=State,
    )
    return workflow.compile(
        name="supervisor",
        checkpointer=checkpointer,
        store=store,
    )


def _build_multi_agent_graph(llm, supervisor, checkpointer, store):
    """Compile the outer graph that wraps the supervisor.

    Topology: START -> verify_info; ``should_interrupt`` routes "interrupt"
    to human_input (which loops back to verify_info) and "continue" to
    load_memory -> supervisor -> create_memory -> END.
    """
    verify_info_fn = create_verify_info_node(llm)
    create_memory_fn = create_memory_node(llm)

    graph = StateGraph(State)
    graph.add_node("verify_info", verify_info_fn)
    graph.add_node("human_input", human_input)
    graph.add_node("load_memory", load_memory)
    graph.add_node("supervisor", supervisor)
    graph.add_node("create_memory", create_memory_fn)

    graph.add_edge(START, "verify_info")
    graph.add_conditional_edges(
        "verify_info",
        should_interrupt,
        {"continue": "load_memory", "interrupt": "human_input"},
    )
    graph.add_edge("human_input", "verify_info")
    graph.add_edge("load_memory", "supervisor")
    graph.add_edge("supervisor", "create_memory")
    graph.add_edge("create_memory", END)

    return graph.compile(
        name="multi_agent_final",
        checkpointer=checkpointer,
        store=store,
    )


def build_graph(
    model_name: str = "gpt-4o-mini",
    temperature: float = 0,
    openai_api_key: Optional[str] = None,
    openai_api_base: Optional[str] = None,
):
    """Assemble and compile the complete multi-agent workflow.

    Builds, in order: the shared LLM, the music catalog sub-agent, the
    invoice information sub-agent, the supervisor over both sub-agents, and
    the outer graph that runs verification and memory steps around the
    supervisor.

    Args:
        model_name: Model identifier passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        openai_api_key: Optional API key; forwarded only when truthy.
        openai_api_base: Optional base URL; forwarded only when truthy.

    Returns:
        A 3-tuple ``(compiled_graph, checkpointer, in_memory_store)``: the
        compiled outer graph plus the ``MemorySaver`` and ``InMemoryStore``
        instances that every compiled graph here was given.
    """
    llm = _build_llm(model_name, temperature, openai_api_key, openai_api_base)
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("LLM initialized: %s, temperature=%s", model_name, temperature)

    # NOTE: Both stores are in-memory only — all data is lost on restart.
    # For production, replace with SqliteSaver / persistent store.
    in_memory_store = InMemoryStore()
    checkpointer = MemorySaver()

    # NOTE(review): the same checkpointer/store pair is handed to every
    # sub-graph as well as the outer graph, exactly as in the original code.
    # LangGraph's docs describe parent checkpointers propagating to
    # sub-graphs - confirm whether the explicit pass-through is required.

    # Music Catalog Sub-Agent (hand-built ReAct)
    music_catalog_subagent = _build_music_catalog_subagent(
        llm, checkpointer, in_memory_store
    )
    logger.info("Music catalog sub-agent compiled.")

    # Invoice Information Sub-Agent (pre-built ReAct)
    invoice_information_subagent = create_react_agent(
        llm,
        tools=invoice_tools,
        name="invoice_information_subagent",
        prompt=INVOICE_SUBAGENT_PROMPT,
        state_schema=State,
        checkpointer=checkpointer,
        store=in_memory_store,
    )
    logger.info("Invoice information sub-agent compiled.")

    # Supervisor
    supervisor_prebuilt = _build_supervisor(
        llm,
        [invoice_information_subagent, music_catalog_subagent],
        checkpointer,
        in_memory_store,
    )
    logger.info("Supervisor compiled.")

    # Final Multi-Agent Graph
    compiled_graph = _build_multi_agent_graph(
        llm, supervisor_prebuilt, checkpointer, in_memory_store
    )
    logger.info("Final multi-agent graph compiled successfully.")

    return compiled_graph, checkpointer, in_memory_store