# sales_analytics/agents/data_agent.py
# Last change by cryogenic22: "Update agents/data_agent.py" (commit b144e5a, verified)
import json
import os
import re
import sqlite3
import time
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
class DataRequest(BaseModel):
    """Structure for a data request.

    ``request_id``, ``description``, ``tables`` and ``purpose`` are required.
    The remaining fields are optional refinements; each may be omitted or
    passed explicitly as ``None``.
    """
    request_id: str
    description: str
    tables: List[str]
    # Optional[...] so an explicit ``None`` validates; a bare ``List[str]``
    # annotation with a ``None`` default only tolerates the field being omitted.
    columns: Optional[List[str]] = None
    filters: Optional[Dict[str, Any]] = None
    time_period: Optional[Dict[str, str]] = None
    groupby: Optional[List[str]] = None
    purpose: str
class DataPipeline(BaseModel):
    """Structure for a data pipeline: a generated SQL query plus metadata
    describing where its data comes from and what its result looks like.

    ``transformations`` and ``visualization_hints`` are optional and may be
    omitted or passed explicitly as ``None``.
    """
    pipeline_id: str
    name: str
    sql: str
    description: str
    data_source: str
    # Mapping of result column name -> dtype name.
    # NOTE(review): pydantic v2 may warn that ``schema`` shadows a BaseModel
    # attribute; renaming it would break callers that read ``pipeline.schema``.
    schema: Dict[str, str]
    transformations: Optional[List[str]] = None
    output_table: str
    purpose: str
    visualization_hints: Optional[List[str]] = None
class DataSource(BaseModel):
    """Structure for a data source: a query result plus identifying metadata.

    All four fields are required.
    """
    # Identifier of the source; presumably the originating request id — confirm with callers.
    source_id: str
    # Human-readable name of the source.
    name: str
    # The data itself, intended to be a pandas DataFrame. Typed ``Any`` so
    # pydantic accepts the object without trying to validate it.
    content: Any
    # Mapping of column name -> dtype name for ``content``.
    # NOTE(review): pydantic v2 may warn that ``schema`` shadows a BaseModel attribute.
    schema: Dict[str, str]
class DataAgent:
    """Agent responsible for data acquisition and transformation.

    Uses an Anthropic chat model (through LangChain) to translate
    ``DataRequest`` objects into SQLite queries and natural-language
    transformation requests into pandas code, then runs both locally.
    Can be used as a context manager so the database connection is closed.
    """

    # Fenced ```sql block. ``\b`` stops e.g. ```sqlite being read as "sql" + "ite...".
    _SQL_FENCE_RE = re.compile(r"```sql\b\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
    # Fenced ```python / ```py block.
    _PYTHON_FENCE_RE = re.compile(r"```(?:python|py)\b\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
    # Any fenced block. An optional language tag on the opening line is skipped
    # so it does not leak into the extracted code.
    _ANY_FENCE_RE = re.compile(r"```(?:[\w+-]*[ \t]*\n)?\s*(.*?)\s*```", re.DOTALL)
    # Unfenced "SELECT ... FROM ..." running only up to the first ';', blank
    # line, or end of text, so trailing explanation prose is not captured.
    _BARE_SQL_RE = re.compile(
        r"\bSELECT\b.*?\bFROM\b.*?(?:;|\n[ \t]*\n|\Z)", re.DOTALL | re.IGNORECASE
    )

    def __init__(
        self,
        db_path: str = "data/pharma_db.sqlite",
        model: str = "claude-3-7-sonnet-20250219",
        temperature: float = 0.1,
    ):
        """Initialize the data agent.

        Args:
            db_path: Path of the SQLite database to query.
            model: Anthropic model name passed to ``ChatAnthropic``.
            temperature: Sampling temperature used for SQL/code generation.

        Raises:
            ValueError: If ``ANTHROPIC_API_KEY`` is not set in the environment.
        """
        self.db_path = db_path

        # Validate configuration *before* opening the database, so a missing
        # key does not leave an open connection behind (sqlite3.connect also
        # creates the file if it does not exist yet).
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        self.llm = ChatAnthropic(
            model=model,
            anthropic_api_key=api_key,
            temperature=temperature,
        )

        # Create SQL generation prompt
        self.sql_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert SQL developer specializing in pharmaceutical data analysis.
Your task is to translate natural language data requests into precise SQL queries suitable for a SQLite database.
For each request, generate a SQL query that:
1. Retrieves only the necessary data for the analysis
2. Uses appropriate JOINs to connect related tables
3. Applies filters correctly
4. Includes relevant aggregations and groupings
5. Is optimized for performance
Format your response as follows:
```sql
-- Your SQL query here
SELECT ...
FROM ...
WHERE ...
```
Explain your approach after the SQL block, describing:
- Why you selected specific tables and columns
- How the query addresses the analytical requirements
- Any assumptions you made
The database schema includes these tables and columns:
- sales: sale_id, sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin
- products: product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price
- regions: region_id, region_name, country, division, population
- territories: territory_id, territory_name, region_id, sales_rep_id
- prescribers: prescriber_id, name, specialty, practice_type, territory_id, decile
- pharmacies: pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume
- competitor_products: competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id
- marketing_campaigns: campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend
- market_events: event_id, event_date, event_type, description, affected_products, affected_regions, impact_score
- sales_targets: target_id, product_id, region_id, period, target_units, target_revenue
- distribution_centers: dc_id, dc_name, region_id, inventory_capacity
- inventory: inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply
- external_factors: factor_id, date, region_id, factor_type, factor_value, description
"""),
            ("human", "{request}")
        ])
        # Set up the SQL generation chain (prompt -> model -> plain text)
        self.sql_chain = (
            self.sql_prompt
            | self.llm
            | StrOutputParser()
        )

        # Create transformation prompt
        self.transform_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert data engineer specializing in pharmaceutical data transformation.
Your task is to generate Python code using pandas to transform the data based on the requirements.
For each transformation request:
1. Generate clear, efficient pandas code
2. Include appropriate data cleaning steps
3. Apply necessary transformations (normalization, feature engineering, etc.)
4. Add comments explaining key steps
5. Handle potential edge cases and missing data
Format your response with a code block:
```python
# Transformation code
import pandas as pd
import numpy as np
def transform_data(df):
    # Your transformation code here
    return transformed_df
```
After the code block, explain your transformation approach and any assumptions.
"""),
            ("human", """
Here is the data description:
{data_description}
Transformation needed:
{transformation_request}
Schema of the input data:
{input_schema}
Please generate the pandas code to perform this transformation.
""")
        ])
        # Set up the transformation chain (prompt -> model -> plain text)
        self.transform_chain = (
            self.transform_prompt
            | self.llm
            | StrOutputParser()
        )

        # Connect last: nothing after this point can raise and leak the handle.
        self.db_connection = sqlite3.connect(db_path)

    def __enter__(self) -> "DataAgent":
        """Return the agent itself; the connection is closed on ``__exit__``."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the database connection; exceptions are not suppressed."""
        self.close()

    @staticmethod
    def _describe_schema(df: pd.DataFrame) -> Dict[str, str]:
        """Map each column name (as ``str``) to its dtype name.

        Iterates ``df.dtypes`` instead of indexing ``df[col]``: with a duplicated
        column name (e.g. ``product_id`` from a JOIN) ``df[col]`` is a DataFrame,
        which has no ``.dtype``. For duplicates, the last column's dtype wins.
        """
        return {str(col): str(dtype) for col, dtype in df.dtypes.items()}

    def execute_sql(self, sql: str) -> pd.DataFrame:
        """Execute a SQL query and return the result as a DataFrame.

        Prints the execution time and row count.

        Raises:
            ValueError: If ``sql`` is empty or only whitespace (e.g. nothing
                could be extracted from the model response).
            Exception: Any error from pandas/sqlite is printed together with
                the failing SQL and re-raised unchanged.
        """
        if not sql or not sql.strip():
            raise ValueError("No SQL query to execute (empty query string)")
        start_time = time.perf_counter()
        try:
            df = pd.read_sql_query(sql, self.db_connection)
        except Exception as e:
            print(f"SQL execution error: {e}")
            print(f"Failed SQL: {sql}")
            raise
        elapsed = time.perf_counter() - start_time
        print(f"SQL execution time: {elapsed:.2f} seconds")
        print(f"Retrieved {len(df)} rows")
        return df

    def extract_sql_from_response(self, response: str) -> str:
        """Extract a SQL query from an LLM response.

        Tries, in order: a fenced ```sql block (any case), any fenced block
        (language tag stripped), then a bare ``SELECT ... FROM ...`` statement
        ending at the first ';', blank line, or end of text.

        Returns:
            The stripped query, or ``""`` if nothing was found.
        """
        for pattern in (self._SQL_FENCE_RE, self._ANY_FENCE_RE):
            match = pattern.search(response)
            if match:
                return match.group(1).strip()
        match = self._BARE_SQL_RE.search(response)
        if match:
            return match.group(0).strip()
        return ""

    def extract_python_from_response(self, response: str) -> str:
        """Extract Python code from an LLM response.

        Tries a fenced ```python/```py block first, then any fenced block
        (language tag stripped).

        Returns:
            The stripped code, or ``""`` if no fenced block was found.
        """
        for pattern in (self._PYTHON_FENCE_RE, self._ANY_FENCE_RE):
            match = pattern.search(response)
            if match:
                return match.group(1).strip()
        return ""

    def generate_sql(self, request: DataRequest) -> Tuple[str, str]:
        """Generate SQL for a data request.

        Returns:
            ``(sql_query, raw_response)``. ``sql_query`` is ``""`` when no
            query could be extracted from the model's response.
        """
        print(f"Data Agent: Generating SQL for request: {request.description}")
        # Only include the optional sections that were actually supplied.
        lines = [
            f"Data Request: {request.description}",
            f"Tables needed: {', '.join(request.tables)}",
        ]
        if request.columns:
            lines.append(f"Columns needed: {', '.join(request.columns)}")
        if request.filters:
            lines.append(f"Filters: {json.dumps(request.filters)}")
        if request.time_period:
            lines.append(f"Time period: {json.dumps(request.time_period)}")
        if request.groupby:
            lines.append(f"Group by: {', '.join(request.groupby)}")
        lines.append(f"Purpose: {request.purpose}")
        lines.append("Please generate a SQL query for this request.")
        request_text = "\n".join(lines)

        response = self.sql_chain.invoke({"request": request_text})
        sql_query = self.extract_sql_from_response(response)
        if not sql_query:
            print("Warning: No SQL query found in model response.")
        return sql_query, response

    def create_data_pipeline(self, request: DataRequest) -> Tuple[DataPipeline, pd.DataFrame]:
        """Generate SQL for ``request``, execute it, and describe the result.

        Returns:
            ``(pipeline, result_df)``. The pipeline's ``visualization_hints``
            is ``["time_series"]`` if any column name contains "date"
            (case-insensitive), otherwise ``["comparison"]``.

        Raises:
            ValueError: If no SQL could be generated.
            Exception: Any error raised while executing the query.
        """
        sql_query, _ = self.generate_sql(request)
        result_df = self.execute_sql(sql_query)

        schema = self._describe_schema(result_df)
        has_date_column = any("date" in str(col).lower() for col in result_df.columns)

        pipeline = DataPipeline(
            pipeline_id=f"pipeline_{request.request_id}",
            name=f"Pipeline for {request.description}",
            sql=sql_query,
            description=request.description,
            data_source=", ".join(request.tables),
            schema=schema,
            output_table=f"result_{request.request_id}",
            purpose=request.purpose,
            visualization_hints=["time_series"] if has_date_column else ["comparison"],
        )
        return pipeline, result_df

    def transform_data(self, df: pd.DataFrame, transformation_request: str) -> Tuple[pd.DataFrame, str]:
        """Transform ``df`` with model-generated pandas code.

        The generated code must define ``transform_data(df)``. It is called on
        a copy of ``df``, so the input is never mutated. On any failure (no
        code, no such function, an exception, or a non-DataFrame result) a
        warning is printed and the original ``df`` is returned.

        Returns:
            ``(result_df, raw_response)``.
        """
        print("Data Agent: Transforming data based on request")
        request_text = {
            "data_description": f"Data with {len(df)} rows and {len(df.columns)} columns.",
            "transformation_request": transformation_request,
            "input_schema": json.dumps(self._describe_schema(df), indent=2),
        }
        response = self.transform_chain.invoke(request_text)
        python_code = self.extract_python_from_response(response)
        if not python_code:
            print("Warning: No transformation code generated.")
            return df, response

        try:
            import numpy as np  # only needed to seed the exec namespace

            namespace = {"pd": pd, "np": np, "df": df.copy()}
            # SECURITY: exec() runs model-generated code with the full
            # privileges of this process (files, network, imports). Only use
            # with trusted inputs, or move execution into a sandbox.
            exec(python_code, namespace)
            transform_fn = namespace.get("transform_data")
            if not callable(transform_fn):
                print("Warning: No transform_data function found in generated code.")
                return df, response
            transformed_df = transform_fn(df.copy())
        except Exception as e:
            print(f"Transformation execution error: {e}")
            return df, response

        if not isinstance(transformed_df, pd.DataFrame):
            print(
                "Warning: transform_data returned "
                f"{type(transformed_df).__name__}, not a DataFrame; keeping original data."
            )
            return df, response
        return transformed_df, response

    def get_data_for_analysis(self, data_requests: List[DataRequest]) -> Dict[str, DataSource]:
        """Run every request and return its result keyed by ``request_id``.

        Requests are processed in order. An error in any request propagates
        and aborts the remaining ones.
        """
        data_sources = {}
        for request in data_requests:
            pipeline, result_df = self.create_data_pipeline(request)
            data_sources[request.request_id] = DataSource(
                source_id=request.request_id,
                name=request.description,
                content=result_df,
                schema=pipeline.schema,
            )
        return data_sources

    def close(self):
        """Close the database connection, if one was opened."""
        if hasattr(self, 'db_connection') and self.db_connection:
            self.db_connection.close()
# For testing
if __name__ == "__main__":
    # Manual smoke test: needs a real ANTHROPIC_API_KEY and the SQLite file below.
    # setdefault() keeps a key already exported in the environment; the
    # original unconditional assignment replaced it with the placeholder.
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")
    agent = DataAgent(db_path="data/pharma_db.sqlite")
    try:
        # Example data request
        request = DataRequest(
            request_id="drx_sales_trend",
            description="Monthly sales of DrugX by region over the past year",
            tables=["sales", "regions", "products"],
            filters={"product_id": "DRX"},
            time_period={"start": "2023-01-01", "end": "2023-12-31"},
            groupby=["region_id", "year_month"],
            purpose="Analyze sales trend of DrugX by region"
        )
        pipeline, df = agent.create_data_pipeline(request)
        print(f"Generated SQL:\n{pipeline.sql}")
        print(f"Result shape: {df.shape}")
        print(df.head())
    finally:
        # Always release the database connection, even if generation or the query fails.
        agent.close()