data_pipeline_agent / agents /sql_generator.py
cryogenic22's picture
Update agents/sql_generator.py
9ea17da verified
"""
SQL Generator agent for pharmaceutical data management.
This agent converts a data pipeline plan into executable SQL queries.
"""
import re
import time
from typing import Dict, Any, List

from anthropic.types import MessageParam
def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Agent that generates SQL queries based on the pipeline plan.

    Sends the usable conversation history, followed by a user message holding
    the pipeline plan description and the list of available tables, to Claude.
    The reply is reported back as a state update and counts as finished only
    when it contains the SQL_COMPLETE tag.

    Args:
        anthropic_client: The Anthropic client for calling Claude API
        state: Current state of the agent workflow. Only "messages" and
            "pipeline_plan" are read; a missing or None value for either one
            is treated as empty.

    Returns:
        Updated state values only (not the entire state):
        - "messages": one-element list with the assistant reply (SQL_COMPLETE
          tag removed) or an error notice
        - "sql_queries": the extracted queries; present only when the reply
          was tagged complete
        - "current_agent": "executor_agent" once complete, otherwise
          "sql_generator_agent"
    """
    # `or` fallbacks rather than .get() defaults: a key that is present but set
    # to None would otherwise crash below, outside the try block.
    messages = state.get("messages") or []
    pipeline_plan = state.get("pipeline_plan") or {}

    # System message as a string
    system_message = (
        "\n"
        "You are an AI assistant specializing in generating SQL for pharmaceutical data pipelines.\n"
        "Your job is to transform the pipeline plan into executable SQL queries.\n"
        "For each step in the pipeline:\n"
        "1. Create an appropriate SQL query\n"
        "2. Include comments explaining what each query does\n"
        "3. Ensure proper joins, aggregations, and transformations\n"
        "4. Consider data quality needs\n"
        "Format each query clearly and tag your complete response with SQL_COMPLETE.\n"
    )

    # Prepare context for Claude
    plan_description = pipeline_plan.get('description', 'No plan provided')
    context = (
        "\n"
        f"Pipeline plan: {plan_description}\n"
        "Available tables:\n"
        "- Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA\n"
        "- Staging data: STG_SALES, STG_HCP, STG_PRODUCT\n"
        "- Analytics-ready data: ARD_SALES_PERFORMANCE, ARD_HCP_ENGAGEMENT, ARD_MARKET_ANALYSIS\n"
        "- Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING\n"
    )

    # Format messages for the Anthropic API. Only dict entries carrying both
    # "role" and "content" are forwarded; anything else is dropped. The entries
    # are plain dicts (MessageParam is used purely as a type annotation here).
    formatted_messages: List[MessageParam] = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in messages
        if isinstance(msg, dict) and "role" in msg and "content" in msg
    ]

    # Add final user message with context
    formatted_messages.append(
        {
            "role": "user",
            "content": f"Based on this pipeline plan, generate the SQL queries needed. {context}",
        }
    )

    try:
        # Call Claude API with system parameter separately
        response = anthropic_client.messages.create(
            model="claude-3-7-sonnet-20250219",
            system=system_message,
            messages=formatted_messages,
            max_tokens=3000,
        )

        # Extract the response text from the first content block
        agent_response = response.content[0].text

        # The tag is the model's signal that the SQL is finished; it is
        # stripped from the text that goes back into the conversation.
        sql_complete = "SQL_COMPLETE" in agent_response
        clean_response = agent_response.replace("SQL_COMPLETE", "").strip()

        # Return only the STATE UPDATES, not the entire state. The single new
        # message is meant to be appended to the existing list by the workflow.
        result: Dict[str, Any] = {
            "messages": [{"role": "assistant", "content": clean_response}],
        }

        if sql_complete:
            # Hand the parsed queries over to the executor
            result["sql_queries"] = _extract_sql_queries(agent_response)
            result["current_agent"] = "executor_agent"
        else:
            # Not finished yet: stay on this agent
            result["current_agent"] = "sql_generator_agent"
        return result
    except Exception as e:
        # Agent boundary: report the failure as a message instead of raising,
        # and return only state updates.
        print(f"Error in sql_generator_agent: {str(e)}")
        return {
            "messages": [{"role": "assistant", "content": f"I encountered an error: {str(e)}"}],
            "current_agent": "sql_generator_agent"
        }
def _extract_sql_queries(response: str) -> List[Dict[str, Any]]:
    """
    Extract SQL queries from the agent's response.

    Looks for fenced code blocks tagged ``sql`` (the tag is matched
    case-insensitively). A fence that is never closed is ignored, and blocks
    that contain only whitespace are skipped so no empty query is returned.

    Args:
        response: The text response containing SQL queries

    Returns:
        List of dictionaries containing query information, in order of
        appearance. Each has "name" ("Query 1", "Query 2", ...), "sql" (the
        block text, stripped), "purpose" (text of a leading ``--`` comment on
        the first line, else "Data transformation") and "created_at" (the
        time.time() value taken when the entry was built).
    """
    # Remove the SQL_COMPLETE tag if present
    clean_response = response.replace("SQL_COMPLETE", "")

    # Non-greedy match so each opening fence pairs with the nearest closing
    # fence; DOTALL lets the body span lines.
    # This is still a simple extraction; in production, you would want a more
    # robust parser.
    fence_matches = re.finditer(
        r"```sql(.*?)```", clean_response, re.DOTALL | re.IGNORECASE
    )
    stripped_bodies = (match.group(1).strip() for match in fence_matches)
    sql_blocks = [body for body in stripped_bodies if body]

    # Convert to query objects
    default_purpose = "Data transformation"
    sql_queries = []
    for number, sql in enumerate(sql_blocks, start=1):
        # Use a leading SQL comment as the purpose; fall back to the default
        # when there is none or the comment is empty.
        purpose = default_purpose
        first_line = sql.split('\n', 1)[0].strip()
        if first_line.startswith('--'):
            purpose = first_line.lstrip('-').strip() or default_purpose
        sql_queries.append({
            "name": f"Query {number}",
            "sql": sql,
            "purpose": purpose,
            "created_at": time.time()
        })
    return sql_queries