File size: 13,341 Bytes
4db5880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b144e5a
4db5880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import json
import os
import re
import sqlite3
import time
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

class DataRequest(BaseModel):
    """Structure for a data request.

    Describes the data an analysis needs: which tables/columns to pull,
    how to filter and group them, and why the data is being requested.
    The description/purpose fields are fed verbatim to the SQL-generation
    LLM prompt in DataAgent.generate_sql.
    """

    # Unique identifier; used to key pipelines and data sources downstream.
    request_id: str
    # Human-readable summary of the request.
    description: str
    # Database tables the generated query should draw from.
    tables: List[str]
    # Optional narrowing of the request; None means "not specified".
    # Annotated Optional so the None defaults are valid (required by
    # Pydantic v2, which rejects None defaults on non-optional fields).
    columns: Optional[List[str]] = None
    filters: Optional[Dict[str, Any]] = None
    time_period: Optional[Dict[str, str]] = None
    groupby: Optional[List[str]] = None
    # Analytical goal, passed through to the LLM prompt for context.
    purpose: str

class DataPipeline(BaseModel):
    """Structure for a data pipeline.

    Records how a result set was produced: the generated SQL, its source
    tables, the resulting column schema, and hints for downstream
    visualization. Built by DataAgent.create_data_pipeline.
    """

    pipeline_id: str
    name: str
    # The SQL text extracted from the LLM response and executed.
    sql: str
    description: str
    # Comma-joined list of the request's source tables.
    data_source: str
    # Column name -> dtype string of the result DataFrame.
    # NOTE(review): `schema` shadows pydantic's BaseModel.schema() helper;
    # name kept for compatibility with existing callers.
    schema: Dict[str, str]
    # Optional fields annotated Optional so the None defaults are valid
    # (required by Pydantic v2).
    transformations: Optional[List[str]] = None
    output_table: str
    purpose: str
    visualization_hints: Optional[List[str]] = None

class DataSource(BaseModel):
    """Structure for a data source"""
    # Matches the originating DataRequest.request_id.
    source_id: str
    # Human-readable name (the request's description).
    name: str
    content: Any  # This will be the pandas DataFrame
    # Column name -> dtype string, derived from the DataFrame.
    schema: Dict[str, str]

class DataAgent:
    """Agent responsible for data acquisition and transformation.

    Wraps a SQLite database plus two LLM chains: one that turns
    natural-language data requests into SQL, and one that generates pandas
    transformation code. Results are packaged as DataPipeline/DataSource
    objects for downstream analysis.
    """
    
    def __init__(self, db_path: str = "data/pharma_db.sqlite") -> None:
        """Initialize the data agent with database connection.

        Args:
            db_path: Path to the SQLite database file.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set in the environment.
        """
        # Set up database connection
        # NOTE(review): sqlite3 connections are bound to the creating thread
        # by default (check_same_thread=True) — confirm single-threaded use.
        self.db_path = db_path
        self.db_connection = sqlite3.connect(db_path)
        
        # Set up Claude API client
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        
        # Low temperature keeps SQL/code generation near-deterministic.
        # NOTE(review): model name is pinned here; consider making it a
        # constructor parameter if other models are needed.
        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=api_key,
            temperature=0.1
        )
        
        # Create SQL generation prompt.
        # The system message embeds the full database schema so the LLM can
        # resolve table/column names; keep it in sync with the actual DB.
        self.sql_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert SQL developer specializing in pharmaceutical data analysis.
Your task is to translate natural language data requests into precise SQL queries suitable for a SQLite database.

For each request, generate a SQL query that:
1. Retrieves only the necessary data for the analysis
2. Uses appropriate JOINs to connect related tables
3. Applies filters correctly
4. Includes relevant aggregations and groupings
5. Is optimized for performance

Format your response as follows:
```sql
-- Your SQL query here
SELECT ...
FROM ...
WHERE ...
```

Explain your approach after the SQL block, describing:
- Why you selected specific tables and columns
- How the query addresses the analytical requirements
- Any assumptions you made

The database schema includes these tables and columns:
- sales: sale_id, sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin
- products: product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price
- regions: region_id, region_name, country, division, population
- territories: territory_id, territory_name, region_id, sales_rep_id
- prescribers: prescriber_id, name, specialty, practice_type, territory_id, decile
- pharmacies: pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume
- competitor_products: competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id
- marketing_campaigns: campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend
- market_events: event_id, event_date, event_type, description, affected_products, affected_regions, impact_score
- sales_targets: target_id, product_id, region_id, period, target_units, target_revenue
- distribution_centers: dc_id, dc_name, region_id, inventory_capacity
- inventory: inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply
- external_factors: factor_id, date, region_id, factor_type, factor_value, description
"""),
            ("human", "{request}")
        ])
        
        # Set up the SQL generation chain: prompt -> LLM -> plain string.
        self.sql_chain = (
            self.sql_prompt
            | self.llm
            | StrOutputParser()
        )
        
        # Create transformation prompt.
        # Asks the LLM to emit a ```python block defining transform_data(df);
        # extract_python_from_response / transform_data rely on that shape.
        self.transform_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert data engineer specializing in pharmaceutical data transformation.
Your task is to generate Python code using pandas to transform the data based on the requirements.

For each transformation request:
1. Generate clear, efficient pandas code
2. Include appropriate data cleaning steps
3. Apply necessary transformations (normalization, feature engineering, etc.)
4. Add comments explaining key steps
5. Handle potential edge cases and missing data

Format your response with a code block:
```python
# Transformation code
import pandas as pd
import numpy as np

def transform_data(df):
    # Your transformation code here
    
    return transformed_df
```

After the code block, explain your transformation approach and any assumptions.
"""),
            ("human", """
Here is the data description:
{data_description}

Transformation needed:
{transformation_request}

Schema of the input data:
{input_schema}

Please generate the pandas code to perform this transformation.
""")
        ])
        
        # Set up the transformation chain: prompt -> LLM -> plain string.
        self.transform_chain = (
            self.transform_prompt
            | self.llm
            | StrOutputParser()
        )
    
    def execute_sql(self, sql: str) -> pd.DataFrame:
        """Execute SQL query and return results as DataFrame.

        Args:
            sql: SQL text to run against the agent's SQLite connection.

        Returns:
            Query results as a pandas DataFrame.

        Raises:
            Exception: Re-raises any execution error after logging the SQL.
        """
        try:
            start_time = time.time()
            df = pd.read_sql_query(sql, self.db_connection)
            end_time = time.time()
            print(f"SQL execution time: {end_time - start_time:.2f} seconds")
            print(f"Retrieved {len(df)} rows")
            return df
        except Exception as e:
            # Log the failing SQL for debugging before propagating.
            print(f"SQL execution error: {e}")
            print(f"Failed SQL: {sql}")
            raise
    
    def extract_sql_from_response(self, response: str) -> str:
        """Extract SQL query from LLM response.

        Tries, in order: a ```sql fenced block, any generic fenced block,
        then a bare SELECT...FROM pattern. Returns "" if nothing matches.
        """
        # Extract SQL between ```sql and ``` markers
        sql_match = re.search(r'```sql\s*(.*?)\s*```', response, re.DOTALL)
        if sql_match:
            return sql_match.group(1).strip()
        
        # If not found with sql tag, try generic code block
        # NOTE(review): a generic block may contain non-SQL text; the error
        # path in execute_sql will surface that at run time.
        sql_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL)
        if sql_match:
            return sql_match.group(1).strip()
        
        # If no code blocks, look for SQL keywords
        sql_pattern = r'(?i)(SELECT[\s\S]+?FROM[\s\S]+?(WHERE|GROUP BY|ORDER BY|LIMIT|$)[\s\S]*)'
        sql_match = re.search(sql_pattern, response)
        if sql_match:
            return sql_match.group(0).strip()
        
        # If all else fails, return empty string
        return ""
    
    def extract_python_from_response(self, response: str) -> str:
        """Extract Python code from LLM response.

        Tries a ```python fenced block first, then any generic fenced
        block. Returns "" if no code block is found.
        """
        # Extract Python between ```python and ``` markers
        python_match = re.search(r'```python\s*(.*?)\s*```', response, re.DOTALL)
        if python_match:
            return python_match.group(1).strip()
        
        # If not found with python tag, try generic code block
        python_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL)
        if python_match:
            return python_match.group(1).strip()
        
        # If all else fails, return empty string
        return ""
    
    def generate_sql(self, request: DataRequest) -> Tuple[str, str]:
        """Generate SQL for data request.

        Args:
            request: Structured description of the data needed.

        Returns:
            Tuple of (extracted SQL query, full raw LLM response). The SQL
            may be "" if extraction failed.
        """
        print(f"Data Agent: Generating SQL for request: {request.description}")
        
        # Format the request for the prompt; optional sections are omitted
        # (rendered as empty lines) when the corresponding field is None.
        request_text = f"""
Data Request: {request.description}

Tables needed: {', '.join(request.tables)}

{f"Columns needed: {', '.join(request.columns)}" if request.columns else ""}

{f"Filters: {json.dumps(request.filters)}" if request.filters else ""}

{f"Time period: {json.dumps(request.time_period)}" if request.time_period else ""}

{f"Group by: {', '.join(request.groupby)}" if request.groupby else ""}

Purpose: {request.purpose}

Please generate a SQL query for this request.
"""
        
        # Generate SQL
        response = self.sql_chain.invoke({"request": request_text})
        
        # Extract SQL query
        sql_query = self.extract_sql_from_response(response)
        
        return sql_query, response
    
    def create_data_pipeline(self, request: DataRequest) -> Tuple[DataPipeline, pd.DataFrame]:
        """Create data pipeline and execute it.

        Generates SQL for the request, runs it, and wraps the result in a
        DataPipeline describing how it was produced.

        Returns:
            Tuple of (pipeline metadata, result DataFrame).
        """
        # Generate SQL
        sql_query, response = self.generate_sql(request)
        
        # Execute SQL to get data
        result_df = self.execute_sql(sql_query)
        
        # Create schema description
        schema = {col: str(result_df[col].dtype) for col in result_df.columns}
        
        # Create pipeline object
        pipeline = DataPipeline(
            pipeline_id=f"pipeline_{request.request_id}",
            name=f"Pipeline for {request.description}",
            sql=sql_query,
            description=request.description,
            data_source=", ".join(request.tables),
            schema=schema,
            output_table=f"result_{request.request_id}",
            purpose=request.purpose,
            # Heuristic: any column name containing "date" suggests a
            # time-series plot; otherwise default to a comparison chart.
            visualization_hints=["time_series"] if "date" in " ".join(result_df.columns).lower() else ["comparison"]
        )
        
        return pipeline, result_df
    
    def transform_data(self, df: pd.DataFrame, transformation_request: str) -> Tuple[pd.DataFrame, str]:
        """Transform data using pandas based on request.

        Asks the LLM for pandas code defining transform_data(df) and runs
        it. On any failure (no code, no function, execution error) the
        ORIGINAL DataFrame is returned unchanged alongside the response.

        Returns:
            Tuple of (transformed-or-original DataFrame, raw LLM response).
        """
        print(f"Data Agent: Transforming data based on request")
        
        # Create schema description
        schema = {col: str(df[col].dtype) for col in df.columns}
        
        # Format the request for the prompt
        request_text = {
            "data_description": f"Data with {len(df)} rows and {len(df.columns)} columns.",
            "transformation_request": transformation_request,
            "input_schema": json.dumps(schema, indent=2)
        }
        
        # Generate transformation code
        response = self.transform_chain.invoke(request_text)
        
        # Extract Python code
        python_code = self.extract_python_from_response(response)
        
        # Execute transformation (with safety checks)
        if not python_code:
            print("Warning: No transformation code generated.")
            return df, response
        
        try:
            # Create a local namespace with access to pandas and numpy
            local_namespace = {
                "pd": pd,
                "np": __import__("numpy"),
                "df": df.copy()
            }
            
            # Extract the function definition from the code
            # SECURITY: exec of LLM-generated code is arbitrary code
            # execution. Acceptable only because the code originates from
            # our own prompt, not user input — do not feed untrusted text
            # here; consider sandboxing if that ever changes.
            exec(python_code, local_namespace)
            
            # Look for a transform_data function in the namespace
            if "transform_data" in local_namespace:
                # Pass a copy so the generated code cannot mutate the
                # caller's DataFrame in place.
                transformed_df = local_namespace["transform_data"](df.copy())
                return transformed_df, response
            else:
                print("Warning: No transform_data function found in generated code.")
                return df, response
        except Exception as e:
            # Best-effort: fall back to the untransformed data on failure.
            print(f"Transformation execution error: {e}")
            return df, response
    
    def get_data_for_analysis(self, data_requests: List[DataRequest]) -> Dict[str, DataSource]:
        """Process multiple data requests and return results.

        Runs create_data_pipeline for each request sequentially.

        Returns:
            Mapping of request_id -> DataSource holding the result frame.
        """
        data_sources = {}
        
        for request in data_requests:
            # Create data pipeline
            pipeline, result_df = self.create_data_pipeline(request)
            
            # Create data source object
            data_source = DataSource(
                source_id=request.request_id,
                name=request.description,
                content=result_df,
                schema=pipeline.schema
            )
            
            # Store data source
            data_sources[request.request_id] = data_source
        
        return data_sources
    
    def close(self) -> None:
        """Close database connection (safe to call if never opened)."""
        if hasattr(self, 'db_connection') and self.db_connection:
            self.db_connection.close()

# For testing
if __name__ == "__main__":
    # Provide a placeholder key only when none is configured. The previous
    # code unconditionally overwrote any real ANTHROPIC_API_KEY from the
    # environment with "your_api_key_here", breaking the live calls below.
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")
    
    agent = DataAgent(db_path="data/pharma_db.sqlite")
    try:
        # Example data request
        request = DataRequest(
            request_id="drx_sales_trend",
            description="Monthly sales of DrugX by region over the past year",
            tables=["sales", "regions", "products"],
            filters={"product_id": "DRX"},
            time_period={"start": "2023-01-01", "end": "2023-12-31"},
            groupby=["region_id", "year_month"],
            purpose="Analyze sales trend of DrugX by region"
        )
        
        pipeline, df = agent.create_data_pipeline(request)
        print(f"Generated SQL:\n{pipeline.sql}")
        print(f"Result shape: {df.shape}")
        print(df.head())
    finally:
        # Always release the SQLite connection, even if the pipeline fails.
        agent.close()