cryogenic22 commited on
Commit
7ab1a2c
·
verified ·
1 Parent(s): d5a7540

Update graph/workflow.py

Browse files
Files changed (1) hide show
  1. graph/workflow.py +75 -11
graph/workflow.py CHANGED
@@ -3,7 +3,9 @@ LangGraph workflow for pharmaceutical data management agents.
3
  """
4
 
5
  from langgraph.graph import StateGraph, END
6
- from langgraph.prebuilt import ToolNode, tools_to_langchain
 
 
7
 
8
  from agents.state import AgentState
9
  from agents.understanding import understanding_agent
@@ -48,16 +50,78 @@ def create_agent_graph(anthropic_client, db):
48
  state_dict = {} # This will be updated in the Streamlit app
49
 
50
  # Create tools node with database-related tools
51
- tools = [
52
- tool_list_tables(db),
53
- tool_describe_table(db),
54
- tool_sample_table(db),
55
- tool_execute_query(db),
56
- tool_get_confidence(lambda: state_dict)
57
- ]
58
-
59
- tools_node = ToolNode(tools_to_langchain(tools))
60
- nodes["tools"] = tools_node
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Create the state graph
63
  workflow = StateGraph(AgentState)
 
3
  """
4
 
5
  from langgraph.graph import StateGraph, END
6
+ from langchain.tools import StructuredTool
7
+ from langchain_core.tools import tool
8
+ from typing import Dict, Any, List
9
 
10
  from agents.state import AgentState
11
  from agents.understanding import understanding_agent
 
50
  state_dict = {} # This will be updated in the Streamlit app
51
 
52
  # Create tools node with database-related tools
53
+ # Convert our custom tools to LangChain StructuredTool format
54
+ lc_tools = []
55
+
56
+ # List Tables Tool
57
+ @tool
58
+ def list_tables():
59
+ """List available tables in the database, categorized by pipeline stage."""
60
+ return db.get_tables()
61
+ lc_tools.append(list_tables)
62
+
63
+ # Describe Table Tool
64
+ @tool
65
+ def describe_table(table_name: str):
66
+ """Get the schema of a specific table."""
67
+ query = f"DESCRIBE {table_name}"
68
+ return db.execute_query(query)
69
+ lc_tools.append(describe_table)
70
+
71
+ # Sample Table Tool
72
+ @tool
73
+ def sample_table(table_name: str, rows: int = 5):
74
+ """Get a sample of rows from a specific table."""
75
+ return db.get_table_sample(table_name, rows)
76
+ lc_tools.append(sample_table)
77
+
78
+ # Execute Query Tool
79
+ @tool
80
+ def execute_query(query: str):
81
+ """Execute a SQL query on the database."""
82
+ return db.execute_query(query)
83
+ lc_tools.append(execute_query)
84
+
85
+ # Get Confidence Tool
86
+ @tool
87
+ def get_confidence(area: str = "overall"):
88
+ """Calculate confidence score for the current plan or specific area."""
89
+ state = state_dict
90
+ user_intent = state.get("user_intent", {})
91
+ pipeline_plan = state.get("pipeline_plan", {})
92
+
93
+ # This would implement a real confidence scoring system
94
+ # For demo, we'll return simulated confidence scores
95
+
96
+ completeness = len(user_intent) / 5 # Simulate based on intent completeness
97
+ clarity = 0.7 if pipeline_plan and "description" in pipeline_plan else 0.3
98
+ feasibility = 0.85 # High by default for demo
99
+
100
+ if area == "intent":
101
+ return {"confidence": round(completeness * 100, 1), "area": "intent"}
102
+ elif area == "plan":
103
+ return {"confidence": round(clarity * 100, 1), "area": "plan"}
104
+ elif area == "feasibility":
105
+ return {"confidence": round(feasibility * 100, 1), "area": "feasibility"}
106
+ else:
107
+ overall = (completeness + clarity + feasibility) / 3
108
+ return {"confidence": round(overall * 100, 1), "area": "overall"}
109
+ lc_tools.append(get_confidence)
110
+
111
+ # Create Tool Agent Node
112
+ # In the updated LangGraph, we don't use ToolNode directly
113
+ # We'll create a tool handler function instead
114
+ def tool_handler(state: Dict[str, Any]) -> Dict[str, Any]:
115
+ """Handle tool calls from the agent workflow."""
116
+ # This would normally implement logic to determine which tool to call
117
+ # based on the agent's request, but for this demo we'll use a simpler approach
118
+ # In production, you would parse agent messages to identify tool calls
119
+
120
+ # Just return the state unmodified - tools are actually called via the LangChain integration
121
+ # in each agent's implementation
122
+ return state
123
+
124
+ nodes["tools"] = tool_handler
125
 
126
  # Create the state graph
127
  workflow = StateGraph(AgentState)