"""
SQL Generator agent for pharmaceutical data management.

This agent converts a data pipeline plan into executable SQL queries.
"""

import logging
import re
import time
from typing import Dict, Any, List

from anthropic.types import MessageParam

logger = logging.getLogger(__name__)

# Claude model and output-token budget used for SQL generation.
_MODEL = "claude-3-7-sonnet-20250219"
_MAX_TOKENS = 3000

# Tag the model is instructed to append once every query has been produced.
_COMPLETION_TAG = "SQL_COMPLETE"

# The Messages API only accepts these roles inside ``messages``; the system
# prompt travels separately in the ``system`` parameter, so forwarding any
# other role from the workflow history would make the request fail.
_ALLOWED_ROLES = frozenset({"user", "assistant"})

_SYSTEM_MESSAGE = """
You are an AI assistant specializing in generating SQL for pharmaceutical data pipelines.
Your job is to transform the pipeline plan into executable SQL queries.

For each step in the pipeline:
1. Create an appropriate SQL query
2. Include comments explaining what each query does
3. Ensure proper joins, aggregations, and transformations
4. Consider data quality needs

Format each query clearly and tag your complete response with SQL_COMPLETE.
"""

_AVAILABLE_TABLES = (
    "Available tables:\n"
    "- Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA\n"
    "- Staging data: STG_SALES, STG_HCP, STG_PRODUCT\n"
    "- Analytics-ready data: ARD_SALES_PERFORMANCE, ARD_HCP_ENGAGEMENT, ARD_MARKET_ANALYSIS\n"
    "- Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING\n"
)

# Fenced SQL code blocks. The tag is matched case-insensitively (```SQL) and
# ``\w*`` swallows dialect suffixes such as ```sqlite, which would otherwise
# leak into the extracted query text. DOTALL lets the body span lines; the
# lazy group stops at the first closing fence.
_SQL_BLOCK_RE = re.compile(r"```sql\w*(.*?)```", re.DOTALL | re.IGNORECASE)


def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Agent that generates SQL queries based on the pipeline plan.

    Args:
        anthropic_client: Client whose ``messages.create`` is called (the
            Anthropic SDK client).
        state: Current state of the agent workflow. Reads ``messages`` (a list
            of ``{"role": ..., "content": ...}`` dicts; other entries are
            skipped) and ``pipeline_plan`` (a dict whose ``description`` is
            sent to the model).

    Returns:
        Updated state values only (not the entire state):

        - ``messages``: a one-element list holding the assistant reply with the
          completion tag removed.
        - ``current_agent``: ``"executor_agent"`` when the reply contains the
          completion tag, otherwise ``"sql_generator_agent"``.
        - ``sql_queries``: present only when the reply is complete; see
          ``_extract_sql_queries``.

        If the API call raises, or the reply contains no text, the error is
        logged and returned as an assistant message with ``current_agent`` set
        to ``"sql_generator_agent"``.
    """
    # ``or`` also covers keys that exist but hold None.
    messages = state.get("messages") or []
    pipeline_plan = state.get("pipeline_plan") or {}

    formatted_messages = _format_messages(messages, _build_context(pipeline_plan))

    try:
        # The system prompt is passed as its own parameter, not as a message.
        response = anthropic_client.messages.create(
            model=_MODEL,
            system=_SYSTEM_MESSAGE,
            messages=formatted_messages,
            max_tokens=_MAX_TOKENS,
        )
        agent_response = _response_text(response)
        if not agent_response:
            raise ValueError("model response contained no text content")
    except Exception as e:  # workflow boundary: report instead of crashing the graph
        logger.exception("Error in sql_generator_agent")
        return {
            "messages": [{"role": "assistant", "content": f"I encountered an error: {e}"}],
            "current_agent": "sql_generator_agent",
        }

    sql_complete = _COMPLETION_TAG in agent_response
    clean_response = agent_response.replace(_COMPLETION_TAG, "").strip()

    result: Dict[str, Any] = {
        # A single new message; the workflow combines it with the existing
        # history (via operator.add).
        "messages": [{"role": "assistant", "content": clean_response}],
    }

    if sql_complete:
        # NOTE(review): if the tag is present but no ```sql blocks are found,
        # an empty list is handed to the executor.
        result["sql_queries"] = _extract_sql_queries(agent_response)
        result["current_agent"] = "executor_agent"
    else:
        result["current_agent"] = "sql_generator_agent"

    return result


def _build_context(pipeline_plan: Dict[str, Any]) -> str:
    """Render the plan description and the table catalogue sent to the model."""
    description = pipeline_plan.get("description", "No plan provided")
    return f"Pipeline plan:\n{description}\n\n{_AVAILABLE_TABLES}"


def _format_messages(messages: List[Any], context: str) -> List[MessageParam]:
    """
    Convert workflow history into Messages API params and append the request.

    Only dict entries with an allowed role and non-empty content are kept.
    Other roles are rejected by the API, and so are empty content blocks.
    Empty content can come from a previous reply that consisted of the
    completion tag alone.
    """
    formatted = [
        MessageParam(role=msg["role"], content=msg["content"])
        for msg in messages
        if isinstance(msg, dict)
        and msg.get("role") in _ALLOWED_ROLES
        and msg.get("content")
    ]
    formatted.append(
        MessageParam(
            role="user",
            content=f"Based on this pipeline plan, generate the SQL queries needed.\n\n{context}",
        )
    )
    return formatted


def _response_text(response: Any) -> str:
    """
    Concatenate the text of every ``text`` content block in *response*.

    Indexing ``content[0].text`` directly fails on an empty reply or when the
    first block is not a text block, so non-text blocks are skipped instead.
    """
    return "".join(
        block.text
        for block in (getattr(response, "content", None) or [])
        if getattr(block, "type", None) == "text"
    )


def _leading_comment(sql: str) -> str:
    """Return the text of *sql*'s first line if it is a ``--`` comment, else ""."""
    first_line = sql.split("\n", 1)[0].strip()
    if first_line.startswith("--"):
        # lstrip handles "---" banners as well as the usual "-- ".
        return first_line.lstrip("-").strip()
    return ""


def _extract_sql_queries(response: str) -> List[Dict[str, Any]]:
    """
    Extract SQL queries from the agent's response.

    Each fenced ```sql block becomes one query. The tag is matched
    case-insensitively, and suffixed tags such as ```sqlite are accepted.
    Blocks that are blank after stripping are skipped, and an unterminated
    trailing block is ignored.

    Args:
        response: The text response containing SQL queries.

    Returns:
        List of dicts with keys:

        - ``name``: "Query N", numbered from 1.
        - ``sql``: the stripped block body.
        - ``purpose``: the text of a leading ``--`` comment line, or
          "Data transformation" if there is none.
        - ``created_at``: ``time.time()`` taken once for the whole batch.
    """
    # Remove the completion tag so it can never end up inside a query.
    clean_response = response.replace(_COMPLETION_TAG, "")

    blocks = [match.group(1).strip() for match in _SQL_BLOCK_RE.finditer(clean_response)]
    created_at = time.time()

    return [
        {
            "name": f"Query {index}",
            "sql": sql,
            "purpose": _leading_comment(sql) or "Data transformation",
            "created_at": created_at,
        }
        for index, sql in enumerate((b for b in blocks if b), start=1)
    ]