File size: 5,619 Bytes
e4905e2
 
 
 
 
 
29b12bd
e4905e2
 
87fe343
e4905e2
 
 
 
 
 
 
 
9ea17da
e4905e2
29b12bd
371f00a
 
e4905e2
9ea17da
e4905e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371f00a
e4905e2
 
 
 
 
 
 
 
6568623
 
120e648
6568623
 
120e648
 
 
87fe343
120e648
6568623
120e648
 
 
 
 
 
6568623
9ea17da
6568623
 
 
 
 
 
e4905e2
6568623
 
 
 
 
 
9ea17da
 
6568623
9ea17da
 
 
 
 
6568623
 
 
9ea17da
 
6568623
9ea17da
6568623
9ea17da
6568623
 
9ea17da
6568623
9ea17da
 
 
 
5bb5631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
SQL Generator agent for pharmaceutical data management.
This agent converts a data pipeline plan into executable SQL queries.
"""

import re
import time
from typing import Dict, Any, List

from anthropic.types import MessageParam

def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Agent that generates SQL queries based on the pipeline plan.

    Reads the conversation history and pipeline plan from *state*, asks Claude
    to produce SQL for each pipeline step, and returns only the state deltas
    (new messages, extracted queries, and the next agent to run).

    Args:
        anthropic_client: The Anthropic client for calling Claude API
        state: Current state of the agent workflow

    Returns:
        Updated state values only (not the entire state)
    """
    history = state.get("messages", [])
    pipeline_plan = state.get("pipeline_plan", {})

    # Instructions for Claude, passed via the top-level `system` parameter.
    system_message = """
    You are an AI assistant specializing in generating SQL for pharmaceutical data pipelines.
    Your job is to transform the pipeline plan into executable SQL queries.
    
    For each step in the pipeline:
    1. Create an appropriate SQL query
    2. Include comments explaining what each query does
    3. Ensure proper joins, aggregations, and transformations
    4. Consider data quality needs
    
    Format each query clearly and tag your complete response with SQL_COMPLETE.
    """

    # Plan description plus the schema the model is allowed to reference.
    context = f"""
    Pipeline plan: {pipeline_plan.get('description', 'No plan provided')}
    
    Available tables:
    - Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA
    - Staging data: STG_SALES, STG_HCP, STG_PRODUCT
    - Analytics-ready data: ARD_SALES_PERFORMANCE, ARD_HCP_ENGAGEMENT, ARD_MARKET_ANALYSIS
    - Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING
    """

    # Keep only well-formed {role, content} dicts and convert them for the API.
    api_messages = [
        MessageParam(role=entry["role"], content=entry["content"])
        for entry in history
        if isinstance(entry, dict) and "role" in entry and "content" in entry
    ]

    # Final user turn carries the task plus the schema context.
    api_messages.append(
        MessageParam(
            role="user", 
            content=f"Based on this pipeline plan, generate the SQL queries needed. {context}"
        )
    )

    try:
        # System prompt goes in its own parameter, not the message list.
        response = anthropic_client.messages.create(
            model="claude-3-7-sonnet-20250219",
            system=system_message,
            messages=api_messages,
            max_tokens=3000
        )

        agent_response = response.content[0].text

        # The model signals completion by emitting the SQL_COMPLETE tag.
        sql_complete = "SQL_COMPLETE" in agent_response
        clean_response = agent_response.replace("SQL_COMPLETE", "").strip()

        # Return only the STATE UPDATES; the new message list entry is merged
        # with the existing history downstream (via operator.add).
        updates: Dict[str, Any] = {
            "messages": [{"role": "assistant", "content": clean_response}],
            "current_agent": "executor_agent" if sql_complete else "sql_generator_agent",
        }
        if sql_complete:
            updates["sql_queries"] = _extract_sql_queries(agent_response)
        return updates

    except Exception as e:
        # Boundary handler: report the failure as a state update rather than raising.
        print(f"Error in sql_generator_agent: {str(e)}")
        return {
            "messages": [{"role": "assistant", "content": f"I encountered an error: {str(e)}"}],
            "current_agent": "sql_generator_agent"
        }

def _extract_sql_queries(response: str) -> List[Dict[str, Any]]:
    """
    Extract SQL queries from the agent's response.
    
    Args:
        response: The text response containing SQL queries
    
    Returns:
        List of dictionaries containing query information
    """
    # Remove the SQL_COMPLETE tag if present
    clean_response = response.replace("SQL_COMPLETE", "")
    
    # Extract SQL code blocks
    # This is a simple extraction that looks for ```sql ... ``` blocks
    # In production, you would want a more robust parser
    sql_blocks = []
    current_pos = 0
    
    while True:
        start_marker = "```sql"
        end_marker = "```"
        
        start_pos = clean_response.find(start_marker, current_pos)
        if start_pos == -1:
            break
            
        # Find the end of this code block
        end_pos = clean_response.find(end_marker, start_pos + len(start_marker))
        if end_pos == -1:
            break
            
        # Extract the SQL query
        sql_content = clean_response[start_pos + len(start_marker):end_pos].strip()
        
        # Add to our list
        sql_blocks.append(sql_content)
        
        # Move position forward
        current_pos = end_pos + len(end_marker)
    
    # Convert to query objects
    sql_queries = []
    for i, sql in enumerate(sql_blocks):
        # Try to extract a purpose comment from the SQL
        purpose = "Data transformation"  # Default
        lines = sql.split('\n')
        if lines and lines[0].strip().startswith('--'):
            purpose = lines[0].strip()[2:].strip()
        
        sql_queries.append({
            "name": f"Query {i+1}",
            "sql": sql,
            "purpose": purpose,
            "created_at": time.time()
        })
    
    return sql_queries