File size: 4,752 Bytes
1611b23
 
 
 
d3afdfe
7ab1a2c
01db754
 
1611b23
 
 
 
 
 
 
 
 
 
 
 
 
d3afdfe
1611b23
 
 
 
 
 
 
 
 
 
d3afdfe
1611b23
 
efe21f0
93c751a
 
efe21f0
93c751a
 
efe21f0
93c751a
 
efe21f0
93c751a
1611b23
 
 
 
 
 
 
 
 
 
 
d3afdfe
1611b23
 
01db754
7ab1a2c
 
efe21f0
7ab1a2c
d3afdfe
01db754
7ab1a2c
 
1611b23
 
 
 
 
 
 
 
93c751a
d3afdfe
 
01db754
efe21f0
9759420
efe21f0
01db754
efe21f0
 
 
9759420
efe21f0
9759420
efe21f0
01db754
efe21f0
 
 
9759420
efe21f0
9759420
efe21f0
01db754
efe21f0
 
 
3791912
ff961a3
1611b23
 
3791912
1611b23
 
d3afdfe
1611b23
 
 
 
 
3791912
1611b23
 
d3afdfe
1611b23
 
 
 
 
3791912
1611b23
 
d3afdfe
1611b23
 
 
 
 
 
d3afdfe
 
 
 
1611b23
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
LangGraph workflow for pharmaceutical data management agents.
"""

from langgraph.graph import StateGraph, END, START
from langchain_core.tools import tool
from typing import Dict, Any
import operator

from agents.state import AgentState
from agents.understanding import understanding_agent
from agents.planning import planning_agent
from agents.sql_generator import sql_generator_agent
from agents.executor import executor_agent
from agents.tools import (
    tool_list_tables, 
    tool_describe_table, 
    tool_sample_table,
    tool_execute_query,
    tool_get_confidence
)
from agents.utils.logging import log_agent_activity

def _make_stage_router(own_name, next_name):
    """
    Build the routing function for one pipeline stage.

    Args:
        own_name: Node name of the stage the router is attached to
        next_name: Node name of the stage that follows it

    Returns:
        A function that maps the workflow state to ``next_name`` when the
        state's ``current_agent`` equals ``next_name``, and to ``own_name``
        (run the same stage again) in every other case
    """
    def route(state):
        # A missing "current_agent" falls back to the stage itself, which
        # means "stay on this stage".
        if state.get("current_agent", own_name) == next_name:
            return next_name
        return own_name

    return route


def create_agent_graph(anthropic_client, db):
    """
    Create the agent workflow graph.

    The graph is a linear pipeline::

        START -> understanding_agent -> planning_agent
              -> sql_generator_agent -> executor_agent -> END

    After each of the first three stages runs, a router inspects
    ``current_agent`` in the state: the workflow advances only when that
    value names the following stage; otherwise the same stage runs again.
    ``executor_agent`` always ends the workflow.

    Args:
        anthropic_client: The Anthropic client for calling Claude API
        db: The database connection (only handed to the executor agent)

    Returns:
        Tuple ``(app, update_state_dict)``: the compiled LangGraph workflow
        and a function that replaces the contents of the internal state
        snapshot with the mapping it is given
    """
    # Wrap the agents so that each node only takes the state; the client
    # (and, for the executor, the database) is bound from this scope.
    def understanding(state):
        return understanding_agent(anthropic_client, state)

    def planning(state):
        return planning_agent(anthropic_client, state)

    def sql_generation(state):
        return sql_generator_agent(anthropic_client, state)

    def execution(state):
        return executor_agent(anthropic_client, db, state)

    # Define LangGraph nodes. Insertion order is the pipeline order.
    nodes = {
        "understanding_agent": understanding,
        "planning_agent": planning,
        "sql_generator_agent": sql_generation,
        "executor_agent": execution,
    }
    pipeline = tuple(nodes)

    # Snapshot of the most recent state, refreshed through update_state_dict.
    # NOTE(review): nothing in this function reads the snapshot; the existing
    # design notes say the confidence tool consumes it - confirm where.
    state_dict = {}

    # Create the state graph and register every stage
    workflow = StateGraph(AgentState)
    for name, node in nodes.items():
        workflow.add_node(name, node)

    # Set the entry point to the first stage (understanding_agent)
    workflow.add_edge(START, pipeline[0])

    # Every stage except the last either repeats itself or hands over to
    # the next stage, as decided by its router.
    for own_name, next_name in zip(pipeline, pipeline[1:]):
        workflow.add_conditional_edges(
            own_name,
            _make_stage_router(own_name, next_name),
            {own_name: own_name, next_name: next_name},
        )

    # Executor agent finishes the workflow
    workflow.add_edge(pipeline[-1], END)

    # Tool calls are deliberately NOT modelled as a graph node with
    # unconditional agent -> tools -> agent edges. LangGraph follows every
    # outgoing edge of a node, so such a round trip would run the tools node
    # after each stage and then fan out to *all* stages at once (including
    # the executor), and the executor's edge to END could never finish the
    # run. TODO: if a tools node is needed, enter and leave it through
    # conditional edges driven by an explicit field in the state.

    # Compile the workflow
    app = workflow.compile()

    def update_state_dict(new_state):
        """Replace the snapshot's contents in place with ``new_state``."""
        state_dict.clear()
        state_dict.update(new_state)

    # Return both the compiled workflow and the update function
    return app, update_state_dict