cryogenic22 committed on
Commit
371f00a
·
verified ·
1 Parent(s): e1d35ba

Update agents/sql_generator.py

Browse files
Files changed (1) hide show
  1. agents/sql_generator.py +21 -80
agents/sql_generator.py CHANGED
@@ -20,8 +20,8 @@ def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, An
20
  Updated state
21
  """
22
  # Get current messages and pipeline plan
23
- messages = state["messages"]
24
- pipeline_plan = state["pipeline_plan"]
25
 
26
  # Add agent-specific instructions
27
  system_message = """
@@ -39,7 +39,7 @@ def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, An
39
 
40
  # Prepare context for Claude
41
  context = f"""
42
- Pipeline plan: {pipeline_plan['description']}
43
 
44
  Available tables:
45
  - Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA
@@ -48,17 +48,28 @@ def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, An
48
  - Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING
49
  """
50
 
51
- # Prepare prompt for Claude
52
- prompt_messages = [
53
- *[MessageParam(role=m["role"], content=m["content"]) for m in messages],
54
- MessageParam(role="user", content=f"Based on this pipeline plan, generate the SQL queries needed. {context}")
55
- ]
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Call Claude API
58
  response = anthropic_client.messages.create(
59
  model="claude-3-7-sonnet-20250219",
60
  system=system_message,
61
- messages=prompt_messages,
62
  max_tokens=3000
63
  )
64
 
@@ -76,74 +87,4 @@ def sql_generator_agent(anthropic_client, state: Dict[str, Any]) -> Dict[str, An
76
  sql_queries = _extract_sql_queries(agent_response)
77
 
78
  new_state["sql_queries"] = sql_queries
79
- new_state["current_agent"] = "executor_agent"
80
- else:
81
- # Need more information or work, stay with SQL generator agent
82
- new_state["current_agent"] = "sql_generator_agent"
83
-
84
- # Add agent's response to messages
85
- new_messages = add_messages(state, [
86
- {"role": "assistant", "content": agent_response.replace("SQL_COMPLETE", "").strip()}
87
- ])
88
- new_state["messages"] = new_messages
89
-
90
- return new_state
91
-
92
- def _extract_sql_queries(response: str) -> List[Dict[str, Any]]:
93
- """
94
- Extract SQL queries from the agent's response.
95
-
96
- Args:
97
- response: The text response containing SQL queries
98
-
99
- Returns:
100
- List of dictionaries containing query information
101
- """
102
- # Remove the SQL_COMPLETE tag if present
103
- clean_response = response.replace("SQL_COMPLETE", "")
104
-
105
- # Extract SQL code blocks
106
- # This is a simple extraction that looks for ```sql ... ``` blocks
107
- # In production, you would want a more robust parser
108
- sql_blocks = []
109
- current_pos = 0
110
-
111
- while True:
112
- start_marker = "```sql"
113
- end_marker = "```"
114
-
115
- start_pos = clean_response.find(start_marker, current_pos)
116
- if start_pos == -1:
117
- break
118
-
119
- # Find the end of this code block
120
- end_pos = clean_response.find(end_marker, start_pos + len(start_marker))
121
- if end_pos == -1:
122
- break
123
-
124
- # Extract the SQL query
125
- sql_content = clean_response[start_pos + len(start_marker):end_pos].strip()
126
-
127
- # Add to our list
128
- sql_blocks.append(sql_content)
129
-
130
- # Move position forward
131
- current_pos = end_pos + len(end_marker)
132
-
133
- # Convert to query objects
134
- sql_queries = []
135
- for i, sql in enumerate(sql_blocks):
136
- # Try to extract a purpose comment from the SQL
137
- purpose = "Data transformation" # Default
138
- lines = sql.split('\n')
139
- if lines and lines[0].strip().startswith('--'):
140
- purpose = lines[0].strip()[2:].strip()
141
-
142
- sql_queries.append({
143
- "name": f"Query {i+1}",
144
- "sql": sql,
145
- "purpose": purpose,
146
- "created_at": time.time()
147
- })
148
-
149
- return sql_queries
 
20
  Updated state
21
  """
22
  # Get current messages and pipeline plan
23
+ messages = state.get("messages", [])
24
+ pipeline_plan = state.get("pipeline_plan", {})
25
 
26
  # Add agent-specific instructions
27
  system_message = """
 
39
 
40
  # Prepare context for Claude
41
  context = f"""
42
+ Pipeline plan: {pipeline_plan.get('description', 'No plan provided')}
43
 
44
  Available tables:
45
  - Raw data: RAW_SALES_TRANSACTIONS, RAW_HCP_DATA, RAW_PRODUCT_DATA
 
48
  - Data products: DP_SALES_DASHBOARD, DP_HCP_TARGETING
49
  """
50
 
51
+ # Convert messages to the format expected by Anthropic API
52
+ anthropic_messages = []
53
+ for msg in messages:
54
+ if isinstance(msg, dict) and "role" in msg and "content" in msg:
55
+ anthropic_messages.append(MessageParam(
56
+ role=msg["role"],
57
+ content=msg["content"]
58
+ ))
59
+
60
+ # Add final user message with context
61
+ anthropic_messages.append(
62
+ MessageParam(
63
+ role="user",
64
+ content=f"Based on this pipeline plan, generate the SQL queries needed. {context}"
65
+ )
66
+ )
67
 
68
  # Call Claude API
69
  response = anthropic_client.messages.create(
70
  model="claude-3-7-sonnet-20250219",
71
  system=system_message,
72
+ messages=anthropic_messages,
73
  max_tokens=3000
74
  )
75
 
 
87
  sql_queries = _extract_sql_queries(agent_response)
88
 
89
  new_state["sql_queries"] = sql_queries
90
+ new_st