Spaces:
Runtime error
Runtime error
File size: 5,619 Bytes
e4905e2 29b12bd e4905e2 87fe343 e4905e2 9ea17da e4905e2 29b12bd 371f00a e4905e2 9ea17da e4905e2 371f00a e4905e2 6568623 120e648 6568623 120e648 87fe343 120e648 6568623 120e648 6568623 9ea17da 6568623 e4905e2 6568623 9ea17da 6568623 9ea17da 6568623 9ea17da 6568623 9ea17da 6568623 9ea17da 6568623 9ea17da 6568623 9ea17da 5bb5631 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
"""
SQL Generator agent for pharmaceutical data management.
This agent converts a data pipeline plan into executable SQL queries.
"""
import re
import time
from typing import Dict, Any, List

from anthropic.types import MessageParam
def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Agent that generates SQL queries based on the pipeline plan.

    Replays the dict-shaped entries of ``state["messages"]`` (those carrying
    both a "role" and a "content" key) to Claude, followed by one extra user
    message that embeds the pipeline plan description and the list of
    available tables. The reply is considered finished when it contains the
    literal tag ``SQL_COMPLETE``.

    Args:
        anthropic_client: The Anthropic client for calling Claude API
        state: Current state of the agent workflow. Only "messages" and
            "pipeline_plan" are read; a missing key and an explicit ``None``
            are both treated as empty.

    Returns:
        Updated state values only (not the entire state):
          - "messages": a one-element list holding the assistant reply with
            the ``SQL_COMPLETE`` tag removed, or an error notice if the call
            failed.
          - "current_agent": "executor_agent" when the reply carried the
            tag, otherwise "sql_generator_agent" (also on error).
          - "sql_queries": present only when the reply carried the tag; the
            queries parsed from the reply.
    """
    # Get current messages and pipeline plan. `or` (instead of a .get default)
    # also covers keys that exist but hold None; those would otherwise raise
    # here, before the try block below gets a chance to report the failure.
    messages = state.get("messages") or []
    pipeline_plan = state.get("pipeline_plan") or {}

    # System message as a string
    system_message = """
You are an AI assistant specializing in generating SQL for pharmaceutical data pipelines.
Your job is to transform the pipeline plan into executable SQL queries.
For each step in the pipeline:
1. Create an appropriate SQL query
2. Include comments explaining what each query does
3. Ensure proper joins, aggregations, and transformations
4. Consider data quality needs
Format each query clearly and tag your complete response with SQL_COMPLETE.
"""

    # Prepare context for Claude
    context = f"""
Pipeline plan: {pipeline_plan.get('description', 'No plan provided')}
Available tables:
- Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA
- Staging data: STG_SALES, STG_HCP, STG_PRODUCT
- Analytics-ready data: ARD_SALES_PERFORMANCE, ARD_HCP_ENGAGEMENT, ARD_MARKET_ANALYSIS
- Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING
"""

    # Format messages for the Anthropic API. Entries that are not dicts, or
    # that lack "role"/"content", are skipped rather than forwarded.
    formatted_messages = [
        MessageParam(role=msg["role"], content=msg["content"])
        for msg in messages
        if isinstance(msg, dict) and "role" in msg and "content" in msg
    ]

    # Add final user message with context
    formatted_messages.append(
        MessageParam(
            role="user",
            content=f"Based on this pipeline plan, generate the SQL queries needed. {context}",
        )
    )

    try:
        # Call Claude API with system parameter separately
        response = anthropic_client.messages.create(
            model="claude-3-7-sonnet-20250219",
            system=system_message,
            messages=formatted_messages,
            max_tokens=3000,
        )

        # Extract the response. This reads the first content block only and
        # assumes it is a text block; if it is not, the resulting exception
        # is reported through the handler below.
        agent_response = response.content[0].text

        # Check if SQL generation is complete
        sql_complete = "SQL_COMPLETE" in agent_response

        # Clean response text
        clean_response = agent_response.replace("SQL_COMPLETE", "").strip()

        # Return only the STATE UPDATES, not the entire state.
        # Add a new message to the list (will be combined with existing via operator.add)
        result: Dict[str, Any] = {
            "messages": [{"role": "assistant", "content": clean_response}],
        }

        # Extract SQL queries from the response if complete
        if sql_complete:
            result["sql_queries"] = _extract_sql_queries(agent_response)
            result["current_agent"] = "executor_agent"
        else:
            result["current_agent"] = "sql_generator_agent"

        return result
    except Exception as e:
        # Handle any errors - return only state updates.
        # NOTE(review): control is handed back to this same agent on error;
        # confirm the surrounding workflow bounds the number of retries.
        print(f"Error in sql_generator_agent: {str(e)}")
        return {
            "messages": [{"role": "assistant", "content": f"I encountered an error: {str(e)}"}],
            "current_agent": "sql_generator_agent",
        }
def _extract_sql_queries(response: str) -> List[Dict[str, Any]]:
"""
Extract SQL queries from the agent's response.
Args:
response: The text response containing SQL queries
Returns:
List of dictionaries containing query information
"""
# Remove the SQL_COMPLETE tag if present
clean_response = response.replace("SQL_COMPLETE", "")
# Extract SQL code blocks
# This is a simple extraction that looks for ```sql ... ``` blocks
# In production, you would want a more robust parser
sql_blocks = []
current_pos = 0
while True:
start_marker = "```sql"
end_marker = "```"
start_pos = clean_response.find(start_marker, current_pos)
if start_pos == -1:
break
# Find the end of this code block
end_pos = clean_response.find(end_marker, start_pos + len(start_marker))
if end_pos == -1:
break
# Extract the SQL query
sql_content = clean_response[start_pos + len(start_marker):end_pos].strip()
# Add to our list
sql_blocks.append(sql_content)
# Move position forward
current_pos = end_pos + len(end_marker)
# Convert to query objects
sql_queries = []
for i, sql in enumerate(sql_blocks):
# Try to extract a purpose comment from the SQL
purpose = "Data transformation" # Default
lines = sql.split('\n')
if lines and lines[0].strip().startswith('--'):
purpose = lines[0].strip()[2:].strip()
sql_queries.append({
"name": f"Query {i+1}",
"sql": sql,
"purpose": purpose,
"created_at": time.time()
})
return sql_queries |