Spaces:
Runtime error
Runtime error
Update graph/workflow.py
Browse files- graph/workflow.py +75 -11
graph/workflow.py
CHANGED
|
@@ -3,7 +3,9 @@ LangGraph workflow for pharmaceutical data management agents.
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from langgraph.graph import StateGraph, END
|
| 6 |
-
from
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from agents.state import AgentState
|
| 9 |
from agents.understanding import understanding_agent
|
|
@@ -48,16 +50,78 @@ def create_agent_graph(anthropic_client, db):
|
|
| 48 |
state_dict = {} # This will be updated in the Streamlit app
|
| 49 |
|
| 50 |
# Create tools node with database-related tools
|
| 51 |
-
tools
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Create the state graph
|
| 63 |
workflow = StateGraph(AgentState)
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from langgraph.graph import StateGraph, END
|
| 6 |
+
from langchain.tools import StructuredTool
|
| 7 |
+
from langchain_core.tools import tool
|
| 8 |
+
from typing import Dict, Any, List
|
| 9 |
|
| 10 |
from agents.state import AgentState
|
| 11 |
from agents.understanding import understanding_agent
|
|
|
|
| 50 |
state_dict = {} # This will be updated in the Streamlit app
|
| 51 |
|
| 52 |
# Create tools node with database-related tools
|
| 53 |
+
# Convert our custom tools to LangChain StructuredTool format
|
| 54 |
+
lc_tools = []
|
| 55 |
+
|
| 56 |
+
# List Tables Tool
|
| 57 |
+
@tool
|
| 58 |
+
def list_tables():
|
| 59 |
+
"""List available tables in the database, categorized by pipeline stage."""
|
| 60 |
+
return db.get_tables()
|
| 61 |
+
lc_tools.append(list_tables)
|
| 62 |
+
|
| 63 |
+
# Describe Table Tool
|
| 64 |
+
@tool
|
| 65 |
+
def describe_table(table_name: str):
|
| 66 |
+
"""Get the schema of a specific table."""
|
| 67 |
+
query = f"DESCRIBE {table_name}"
|
| 68 |
+
return db.execute_query(query)
|
| 69 |
+
lc_tools.append(describe_table)
|
| 70 |
+
|
| 71 |
+
# Sample Table Tool
|
| 72 |
+
@tool
|
| 73 |
+
def sample_table(table_name: str, rows: int = 5):
|
| 74 |
+
"""Get a sample of rows from a specific table."""
|
| 75 |
+
return db.get_table_sample(table_name, rows)
|
| 76 |
+
lc_tools.append(sample_table)
|
| 77 |
+
|
| 78 |
+
# Execute Query Tool
|
| 79 |
+
@tool
|
| 80 |
+
def execute_query(query: str):
|
| 81 |
+
"""Execute a SQL query on the database."""
|
| 82 |
+
return db.execute_query(query)
|
| 83 |
+
lc_tools.append(execute_query)
|
| 84 |
+
|
| 85 |
+
# Get Confidence Tool
|
| 86 |
+
@tool
|
| 87 |
+
def get_confidence(area: str = "overall"):
|
| 88 |
+
"""Calculate confidence score for the current plan or specific area."""
|
| 89 |
+
state = state_dict
|
| 90 |
+
user_intent = state.get("user_intent", {})
|
| 91 |
+
pipeline_plan = state.get("pipeline_plan", {})
|
| 92 |
+
|
| 93 |
+
# This would implement a real confidence scoring system
|
| 94 |
+
# For demo, we'll return simulated confidence scores
|
| 95 |
+
|
| 96 |
+
completeness = len(user_intent) / 5 # Simulate based on intent completeness
|
| 97 |
+
clarity = 0.7 if pipeline_plan and "description" in pipeline_plan else 0.3
|
| 98 |
+
feasibility = 0.85 # High by default for demo
|
| 99 |
+
|
| 100 |
+
if area == "intent":
|
| 101 |
+
return {"confidence": round(completeness * 100, 1), "area": "intent"}
|
| 102 |
+
elif area == "plan":
|
| 103 |
+
return {"confidence": round(clarity * 100, 1), "area": "plan"}
|
| 104 |
+
elif area == "feasibility":
|
| 105 |
+
return {"confidence": round(feasibility * 100, 1), "area": "feasibility"}
|
| 106 |
+
else:
|
| 107 |
+
overall = (completeness + clarity + feasibility) / 3
|
| 108 |
+
return {"confidence": round(overall * 100, 1), "area": "overall"}
|
| 109 |
+
lc_tools.append(get_confidence)
|
| 110 |
+
|
| 111 |
+
# Create Tool Agent Node
|
| 112 |
+
# In the updated LangGraph, we don't use ToolNode directly
|
| 113 |
+
# We'll create a tool handler function instead
|
| 114 |
+
def tool_handler(state: Dict[str, Any]) -> Dict[str, Any]:
|
| 115 |
+
"""Handle tool calls from the agent workflow."""
|
| 116 |
+
# This would normally implement logic to determine which tool to call
|
| 117 |
+
# based on the agent's request, but for this demo we'll use a simpler approach
|
| 118 |
+
# In production, you would parse agent messages to identify tool calls
|
| 119 |
+
|
| 120 |
+
# Just return the state unmodified - tools are actually called via the LangChain integration
|
| 121 |
+
# in each agent's implementation
|
| 122 |
+
return state
|
| 123 |
+
|
| 124 |
+
nodes["tools"] = tool_handler
|
| 125 |
|
| 126 |
# Create the state graph
|
| 127 |
workflow = StateGraph(AgentState)
|