cryogenic22 commited on
Commit
4db5880
·
verified ·
1 Parent(s): 9524b20

Create data_agent.py

Browse files
Files changed (1) hide show
  1. agents/data_agent.py +364 -0
agents/data_agent.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import sqlite3
4
+ import pandas as pd
5
+ from typing import Dict, List, Any, Tuple
6
+ from langchain_anthropic import ChatAnthropic
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from pydantic import BaseModel, Field
10
+ import time
11
+ import re
12
+
13
class DataRequest(BaseModel):
    """Structure for a data request.

    Describes what data an analysis needs: which tables, optional column
    subset, filters, time window and grouping, plus a free-text purpose
    that is forwarded to the SQL-generating LLM prompt.
    """

    # Unique identifier; downstream pipelines/data sources are keyed by it.
    request_id: str
    # Natural-language description of the data needed.
    description: str
    # Database tables the query should draw from.
    tables: List[str]
    # The following are optional refinements; the original declared them as
    # `List[str] = None` etc., which is ill-typed (a non-optional annotation
    # with a None default) -- Optional[...] is the correct, equivalent form.
    columns: Optional[List[str]] = None
    filters: Optional[Dict[str, Any]] = None
    time_period: Optional[Dict[str, str]] = None
    groupby: Optional[List[str]] = None
    # Why the data is needed (passed verbatim into the LLM prompt).
    purpose: str
23
+
24
class DataPipeline(BaseModel):
    """Structure for a data pipeline.

    Records how a result set was produced: the generated SQL, the source
    tables, the output schema, and hints for downstream visualization.
    """

    pipeline_id: str
    name: str
    # The SQL query that was executed to produce the result.
    sql: str
    description: str
    # Comma-joined list of source table names.
    data_source: str
    # Column name -> dtype string of the result DataFrame.
    # NOTE(review): "schema" shadows a BaseModel attribute on some pydantic
    # versions -- confirm it does not warn/raise on the installed version.
    schema: Dict[str, str]
    # Optional fields were declared `List[str] = None` in the original,
    # which is ill-typed; Optional[...] is the equivalent correct form.
    transformations: Optional[List[str]] = None
    output_table: str
    purpose: str
    visualization_hints: Optional[List[str]] = None
36
+
37
class DataSource(BaseModel):
    """Structure for a data source"""
    # Unique identifier; mirrors the originating DataRequest.request_id
    # (see DataAgent.get_data_for_analysis).
    source_id: str
    # Human-readable name (set from the request description).
    name: str
    content: Any # This will be the pandas DataFrame
    # Column name -> pandas dtype string describing `content`.
    schema: Dict[str, str]
43
+
44
class DataAgent:
    """Agent responsible for data acquisition and transformation.

    Translates natural-language data requests into SQL via Claude, executes
    the SQL against a local SQLite database with pandas, and can apply
    LLM-generated pandas transformations to the results.

    Can be used as a context manager so the database connection is always
    released: ``with DataAgent(...) as agent: ...``.
    """

    def __init__(self, db_path: str = "data/pharma_db.sqlite"):
        """Initialize the data agent with database connection.

        Args:
            db_path: Path to the SQLite database file.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set in the environment.
        """
        # Set up database connection.
        # NOTE(review): sqlite3 connections may only be used from the
        # creating thread (default check_same_thread=True) -- confirm
        # single-threaded use by callers.
        self.db_path = db_path
        self.db_connection = sqlite3.connect(db_path)

        # Set up Claude API client
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")

        self.llm = ChatAnthropic(
            model="claude-3-haiku-20240307",
            anthropic_api_key=api_key,
            temperature=0.1
        )

        # Create SQL generation prompt
        self.sql_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert SQL developer specializing in pharmaceutical data analysis.
Your task is to translate natural language data requests into precise SQL queries suitable for a SQLite database.

For each request, generate a SQL query that:
1. Retrieves only the necessary data for the analysis
2. Uses appropriate JOINs to connect related tables
3. Applies filters correctly
4. Includes relevant aggregations and groupings
5. Is optimized for performance

Format your response as follows:
```sql
-- Your SQL query here
SELECT ...
FROM ...
WHERE ...
```

Explain your approach after the SQL block, describing:
- Why you selected specific tables and columns
- How the query addresses the analytical requirements
- Any assumptions you made

The database schema includes these tables and columns:
- sales: sale_id, sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin
- products: product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price
- regions: region_id, region_name, country, division, population
- territories: territory_id, territory_name, region_id, sales_rep_id
- prescribers: prescriber_id, name, specialty, practice_type, territory_id, decile
- pharmacies: pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume
- competitor_products: competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id
- marketing_campaigns: campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend
- market_events: event_id, event_date, event_type, description, affected_products, affected_regions, impact_score
- sales_targets: target_id, product_id, region_id, period, target_units, target_revenue
- distribution_centers: dc_id, dc_name, region_id, inventory_capacity
- inventory: inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply
- external_factors: factor_id, date, region_id, factor_type, factor_value, description
"""),
            ("human", "{request}")
        ])

        # Set up the SQL generation chain
        self.sql_chain = (
            self.sql_prompt
            | self.llm
            | StrOutputParser()
        )

        # Create transformation prompt
        self.transform_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert data engineer specializing in pharmaceutical data transformation.
Your task is to generate Python code using pandas to transform the data based on the requirements.

For each transformation request:
1. Generate clear, efficient pandas code
2. Include appropriate data cleaning steps
3. Apply necessary transformations (normalization, feature engineering, etc.)
4. Add comments explaining key steps
5. Handle potential edge cases and missing data

Format your response with a code block:
```python
# Transformation code
import pandas as pd
import numpy as np

def transform_data(df):
    # Your transformation code here

    return transformed_df
```

After the code block, explain your transformation approach and any assumptions.
"""),
            ("human", """
Here is the data description:
{data_description}

Transformation needed:
{transformation_request}

Schema of the input data:
{input_schema}

Please generate the pandas code to perform this transformation.
""")
        ])

        # Set up the transformation chain
        self.transform_chain = (
            self.transform_prompt
            | self.llm
            | StrOutputParser()
        )

    def __enter__(self):
        """Enter a context-manager block; returns the agent itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the database connection on context exit; never swallows exceptions."""
        self.close()
        return False

    def execute_sql(self, sql: str) -> pd.DataFrame:
        """Execute SQL query and return results as DataFrame.

        Args:
            sql: The SQL text to run against the SQLite connection.

        Returns:
            The query result as a pandas DataFrame.

        Raises:
            Whatever pandas/sqlite3 raises on a bad query, after logging it.
        """
        try:
            start_time = time.time()
            df = pd.read_sql_query(sql, self.db_connection)
            end_time = time.time()
            print(f"SQL execution time: {end_time - start_time:.2f} seconds")
            print(f"Retrieved {len(df)} rows")
            return df
        except Exception as e:
            # Log the failing SQL for debuggability, then propagate.
            print(f"SQL execution error: {e}")
            print(f"Failed SQL: {sql}")
            raise

    def extract_sql_from_response(self, response: str) -> str:
        """Extract SQL query from LLM response.

        Tries, in order: a ```sql fenced block, any fenced block, then a
        bare SELECT ... FROM ... statement. Returns "" when nothing matches.
        """
        # Extract SQL between ```sql and ``` markers
        sql_match = re.search(r'```sql\s*(.*?)\s*```', response, re.DOTALL)
        if sql_match:
            return sql_match.group(1).strip()

        # If not found with sql tag, try generic code block
        sql_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL)
        if sql_match:
            return sql_match.group(1).strip()

        # If no code blocks, look for SQL keywords
        sql_pattern = r'(?i)(SELECT[\s\S]+?FROM[\s\S]+?(WHERE|GROUP BY|ORDER BY|LIMIT|$)[\s\S]*)'
        sql_match = re.search(sql_pattern, response)
        if sql_match:
            return sql_match.group(0).strip()

        # If all else fails, return empty string
        return ""

    def extract_python_from_response(self, response: str) -> str:
        """Extract Python code from LLM response.

        Tries a ```python fenced block first, then any fenced block.
        Returns "" when neither is present.
        """
        # Extract Python between ```python and ``` markers
        python_match = re.search(r'```python\s*(.*?)\s*```', response, re.DOTALL)
        if python_match:
            return python_match.group(1).strip()

        # If not found with python tag, try generic code block
        python_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL)
        if python_match:
            return python_match.group(1).strip()

        # If all else fails, return empty string
        return ""

    def generate_sql(self, request: DataRequest) -> Tuple[str, str]:
        """Generate SQL for a data request via the LLM.

        Args:
            request: The structured data request to translate.

        Returns:
            A tuple of (extracted SQL query, full raw LLM response).
        """
        print(f"Data Agent: Generating SQL for request: {request.description}")

        # Format the request for the prompt; optional sections are only
        # included when the request actually specifies them.
        request_text = f"""
Data Request: {request.description}

Tables needed: {', '.join(request.tables)}

{f"Columns needed: {', '.join(request.columns)}" if request.columns else ""}

{f"Filters: {json.dumps(request.filters)}" if request.filters else ""}

{f"Time period: {json.dumps(request.time_period)}" if request.time_period else ""}

{f"Group by: {', '.join(request.groupby)}" if request.groupby else ""}

Purpose: {request.purpose}

Please generate a SQL query for this request.
"""

        # Generate SQL
        response = self.sql_chain.invoke({"request": request_text})

        # Extract SQL query
        sql_query = self.extract_sql_from_response(response)

        return sql_query, response

    def create_data_pipeline(self, request: DataRequest) -> Tuple[DataPipeline, pd.DataFrame]:
        """Create a data pipeline for a request and execute it.

        Returns:
            A tuple of (pipeline metadata, result DataFrame).
        """
        # Generate SQL
        sql_query, response = self.generate_sql(request)

        # Execute SQL to get data
        result_df = self.execute_sql(sql_query)

        # Create schema description (column name -> dtype string)
        schema = {col: str(result_df[col].dtype) for col in result_df.columns}

        # Create pipeline object. A column name containing "date" is used as
        # a heuristic that the result is a time series.
        pipeline = DataPipeline(
            pipeline_id=f"pipeline_{request.request_id}",
            name=f"Pipeline for {request.description}",
            sql=sql_query,
            description=request.description,
            data_source=", ".join(request.tables),
            schema=schema,
            output_table=f"result_{request.request_id}",
            purpose=request.purpose,
            visualization_hints=["time_series"] if "date" in " ".join(result_df.columns).lower() else ["comparison"]
        )

        return pipeline, result_df

    def transform_data(self, df: pd.DataFrame, transformation_request: str) -> Tuple[pd.DataFrame, str]:
        """Transform data using LLM-generated pandas code.

        On any failure (no code generated, no transform_data function found,
        or an execution error) the ORIGINAL DataFrame is returned unchanged
        together with the raw LLM response, so callers always get usable data.
        """
        print(f"Data Agent: Transforming data based on request")

        # Create schema description
        schema = {col: str(df[col].dtype) for col in df.columns}

        # Format the request for the prompt
        request_text = {
            "data_description": f"Data with {len(df)} rows and {len(df.columns)} columns.",
            "transformation_request": transformation_request,
            "input_schema": json.dumps(schema, indent=2)
        }

        # Generate transformation code
        response = self.transform_chain.invoke(request_text)

        # Extract Python code
        python_code = self.extract_python_from_response(response)

        # Execute transformation (with safety checks)
        if not python_code:
            print("Warning: No transformation code generated.")
            return df, response

        try:
            # Local import: numpy is only needed when executing generated code.
            import numpy as np

            # Create a local namespace with access to pandas and numpy.
            # The copy protects the caller's DataFrame from in-place edits.
            local_namespace = {
                "pd": pd,
                "np": np,
                "df": df.copy()
            }

            # SECURITY: exec() runs arbitrary LLM-generated Python in-process.
            # Only use with trusted models/prompts, or move execution into a
            # sandboxed subprocess before exposing to untrusted input.
            exec(python_code, local_namespace)

            # Look for a transform_data function in the namespace
            if "transform_data" in local_namespace:
                transformed_df = local_namespace["transform_data"](df.copy())
                return transformed_df, response
            else:
                print("Warning: No transform_data function found in generated code.")
                return df, response
        except Exception as e:
            # Best-effort by design: fall back to the untransformed data.
            print(f"Transformation execution error: {e}")
            return df, response

    def get_data_for_analysis(self, data_requests: List[DataRequest]) -> Dict[str, DataSource]:
        """Process multiple data requests and return results keyed by request_id."""
        data_sources = {}

        for request in data_requests:
            # Create data pipeline
            pipeline, result_df = self.create_data_pipeline(request)

            # Create data source object
            data_source = DataSource(
                source_id=request.request_id,
                name=request.description,
                content=result_df,
                schema=pipeline.schema
            )

            # Store data source
            data_sources[request.request_id] = data_source

        return data_sources

    def close(self):
        """Close the database connection (safe to call more than once)."""
        if getattr(self, 'db_connection', None):
            self.db_connection.close()
            # Drop the reference so a second close() is a clean no-op.
            self.db_connection = None
340
+
341
# For testing
if __name__ == "__main__":
    # Fall back to a placeholder only when no real key is present.
    # (The original unconditionally overwrote any key already set in the
    # environment, which broke real test runs.)
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")

    agent = DataAgent(db_path="data/pharma_db.sqlite")
    try:
        # Example data request
        request = DataRequest(
            request_id="drx_sales_trend",
            description="Monthly sales of DrugX by region over the past year",
            tables=["sales", "regions", "products"],
            filters={"product_id": "DRX"},
            time_period={"start": "2023-01-01", "end": "2023-12-31"},
            groupby=["region_id", "year_month"],
            purpose="Analyze sales trend of DrugX by region"
        )

        pipeline, df = agent.create_data_pipeline(request)
        print(f"Generated SQL:\n{pipeline.sql}")
        print(f"Result shape: {df.shape}")
        print(df.head())
    finally:
        # Always release the SQLite connection, even if the pipeline fails.
        agent.close()