import json
import os
import re
import sqlite3
import time
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


# System prompt for SQL generation. It must not contain literal braces:
# ChatPromptTemplate would treat them as template variables.
_SQL_SYSTEM_PROMPT = """You are an expert SQL developer specializing in pharmaceutical data analysis.
Your task is to translate natural language data requests into precise SQL queries suitable for a SQLite database.

For each request, generate a SQL query that:
1. Retrieves only the necessary data for the analysis
2. Uses appropriate JOINs to connect related tables
3. Applies filters correctly
4. Includes relevant aggregations and groupings
5. Is optimized for performance

Format your response as follows:
```sql
-- Your SQL query here
SELECT ...
FROM ...
WHERE ...
```

Explain your approach after the SQL block, describing:
- Why you selected specific tables and columns
- How the query addresses the analytical requirements
- Any assumptions you made

The database schema includes these tables and columns:
- sales: sale_id, sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin
- products: product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price
- regions: region_id, region_name, country, division, population
- territories: territory_id, territory_name, region_id, sales_rep_id
- prescribers: prescriber_id, name, specialty, practice_type, territory_id, decile
- pharmacies: pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume
- competitor_products: competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id
- marketing_campaigns: campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend
- market_events: event_id, event_date, event_type, description, affected_products, affected_regions, impact_score
- sales_targets: target_id, product_id, region_id, period, target_units, target_revenue
- distribution_centers: dc_id, dc_name, region_id, inventory_capacity
- inventory: inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply
- external_factors: factor_id, date, region_id, factor_type, factor_value, description
"""

# System prompt for pandas transformation code (no literal braces allowed either).
_TRANSFORM_SYSTEM_PROMPT = """You are an expert data engineer specializing in pharmaceutical data transformation.
Your task is to generate Python code using pandas to transform the data based on the requirements.

For each transformation request:
1. Generate clear, efficient pandas code
2. Include appropriate data cleaning steps
3. Apply necessary transformations (normalization, feature engineering, etc.)
4. Add comments explaining key steps
5. Handle potential edge cases and missing data

Format your response with a code block:
```python
# Transformation code
import pandas as pd
import numpy as np

def transform_data(df):
    # Your transformation code here
    return transformed_df
```

After the code block, explain your transformation approach and any assumptions.
"""

# Human turn for transformations; the braces are the template variables.
_TRANSFORM_HUMAN_PROMPT = """
Here is the data description:
{data_description}

Transformation needed:
{transformation_request}

Schema of the input data:
{input_schema}

Please generate the pandas code to perform this transformation.
"""


class DataRequest(BaseModel):
    """A natural-language request for data, translated into SQL by ``DataAgent``."""

    request_id: str
    description: str
    tables: List[str]
    # Optional[...] so an explicit ``None`` validates: pydantic v2 rejects
    # ``None`` for a bare ``List[str]`` even when the default is ``None``.
    columns: Optional[List[str]] = None
    filters: Optional[Dict[str, Any]] = None
    time_period: Optional[Dict[str, str]] = None
    groupby: Optional[List[str]] = None
    purpose: str


class DataPipeline(BaseModel):
    """Record of a generated SQL query and the shape of its result."""

    pipeline_id: str
    name: str
    sql: str
    description: str
    data_source: str  # comma-separated names of the requested tables
    # Column name -> pandas dtype string of the query result.
    # NOTE(review): ``schema`` shadows ``BaseModel.schema`` (pydantic v2 warns at
    # class creation); kept unchanged for backward compatibility with callers.
    schema: Dict[str, str]
    transformations: Optional[List[str]] = None
    output_table: str
    purpose: str
    visualization_hints: Optional[List[str]] = None


class DataSource(BaseModel):
    """A query result packaged for downstream analysis."""

    source_id: str
    name: str
    content: Any  # the pandas DataFrame produced by the pipeline
    schema: Dict[str, str]  # column name -> pandas dtype string


class DataAgent:
    """Agent responsible for data acquisition and transformation.

    The agent has three jobs:
    - Generate SQL with an LLM and run it against a SQLite database.
    - Optionally apply LLM-generated pandas transformations to the result.
    - Release its database connection when done. Use it as a context manager,
      or call :meth:`close` yourself.
    """

    # Fallback for responses without code fences. It matches from the first
    # SELECT ... FROM up to the first ';', blank line, or end of text. This keeps
    # trailing explanation prose out of the query, because SQLite rejects input
    # that contains more than one statement.
    _BARE_SQL_RE = re.compile(
        r"\bSELECT\b.*?\bFROM\b.*?(?:;|(?=\n[ \t]*\n)|\Z)",
        re.IGNORECASE | re.DOTALL,
    )

    def __init__(self, db_path: str = "data/pharma_db.sqlite"):
        """Build the LLM chains and open the SQLite connection.

        Args:
            db_path: Path to the SQLite database file. ``sqlite3.connect``
                creates an empty database if the file does not exist.

        Raises:
            ValueError: If ``ANTHROPIC_API_KEY`` is not set.
        """
        self.db_path = db_path

        # Validate configuration before touching the database, so a missing key
        # neither leaks an open connection nor creates an empty database file.
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")

        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=api_key,
            temperature=0.1,
        )

        self.sql_prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{request}"),
        ])
        self.sql_chain = self.sql_prompt | self.llm | StrOutputParser()

        self.transform_prompt = ChatPromptTemplate.from_messages([
            ("system", _TRANSFORM_SYSTEM_PROMPT),
            ("human", _TRANSFORM_HUMAN_PROMPT),
        ])
        self.transform_chain = self.transform_prompt | self.llm | StrOutputParser()

        # Opened last: every step above can raise, and nothing would close it.
        self.db_connection = sqlite3.connect(db_path)

    def __enter__(self) -> "DataAgent":
        """Return the agent itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the database connection when leaving the ``with`` block."""
        self.close()

    def execute_sql(self, sql: str) -> pd.DataFrame:
        """Execute *sql* on the agent's connection and return the result.

        Prints the elapsed time and row count on success. On failure, prints the
        error and the failing SQL, then re-raises the original exception.
        """
        try:
            start_time = time.perf_counter()
            df = pd.read_sql_query(sql, self.db_connection)
            elapsed = time.perf_counter() - start_time
        except Exception as e:
            print(f"SQL execution error: {e}")
            print(f"Failed SQL: {sql}")
            raise
        print(f"SQL execution time: {elapsed:.2f} seconds")
        print(f"Retrieved {len(df)} rows")
        return df

    @staticmethod
    def _extract_fenced_block(response: str, language: str) -> str:
        """Return the body of the first fenced code block in *response*, or "".

        Lookup order:
        1. The first block tagged *language*. The match is case-insensitive and
           whole-word, so ``sql`` does not match a ``sqlite`` tag.
        2. Otherwise, the first fenced block of any kind. Its info string on
           the opening fence line (e.g. ``sqlite``, ``py``) is skipped rather
           than returned as code.
        """
        tagged = re.search(
            r"```" + re.escape(language) + r"\b\s*(.*?)\s*```",
            response,
            re.DOTALL | re.IGNORECASE,
        )
        if tagged:
            return tagged.group(1).strip()

        generic = re.search(r"```(?:[\w+.-]*[ \t]*\n)?\s*(.*?)\s*```", response, re.DOTALL)
        if generic:
            return generic.group(1).strip()
        return ""

    def extract_sql_from_response(self, response: str) -> str:
        """Extract the SQL query from an LLM response.

        Lookup order:
        1. A ```sql fenced block.
        2. Any fenced block.
        3. A bare ``SELECT ... FROM ...`` statement.

        Returns "" if nothing is found.
        """
        sql = self._extract_fenced_block(response, "sql")
        if sql:
            return sql
        match = self._BARE_SQL_RE.search(response)
        return match.group(0).strip() if match else ""

    def extract_python_from_response(self, response: str) -> str:
        """Extract Python code from an LLM response.

        Prefers a ```python fenced block, then any fenced block. Returns "" if
        there is neither.
        """
        return self._extract_fenced_block(response, "python")

    def generate_sql(self, request: DataRequest) -> Tuple[str, str]:
        """Ask the LLM for a SQL query that answers *request*.

        Returns:
            ``(sql_query, raw_response)``. ``sql_query`` is "" if no SQL could
            be extracted from the response.
        """
        print(f"Data Agent: Generating SQL for request: {request.description}")

        # Optional fields are only mentioned when set. ``default=str`` stops
        # non-JSON filter values (e.g. ``datetime.date``) from crashing here.
        lines = [
            f"Data Request: {request.description}",
            f"Tables needed: {', '.join(request.tables)}",
        ]
        if request.columns:
            lines.append(f"Columns needed: {', '.join(request.columns)}")
        if request.filters:
            lines.append(f"Filters: {json.dumps(request.filters, default=str)}")
        if request.time_period:
            lines.append(f"Time period: {json.dumps(request.time_period)}")
        if request.groupby:
            lines.append(f"Group by: {', '.join(request.groupby)}")
        lines.append(f"Purpose: {request.purpose}")
        lines.append("")
        lines.append("Please generate a SQL query for this request.")
        request_text = "\n".join(lines)

        response = self.sql_chain.invoke({"request": request_text})
        return self.extract_sql_from_response(response), response

    def create_data_pipeline(self, request: DataRequest) -> Tuple[DataPipeline, pd.DataFrame]:
        """Generate SQL for *request*, run it, and describe the result.

        Returns:
            ``(pipeline, result_df)``.

        Raises:
            ValueError: If no SQL could be extracted from the LLM response.
            Exception: Anything raised by :meth:`execute_sql`.
        """
        sql_query, response = self.generate_sql(request)
        if not sql_query:
            # Fail clearly instead of handing an empty statement to pandas/sqlite.
            raise ValueError(
                f"No SQL query could be extracted for request {request.request_id!r}. "
                f"LLM response:\n{response}"
            )

        result_df = self.execute_sql(sql_query)

        # ``dtypes.items()`` also works when a JOIN yields duplicate column names.
        # In that case ``result_df[col]`` would be a DataFrame, which has no ``.dtype``.
        schema = {str(col): str(dtype) for col, dtype in result_df.dtypes.items()}
        has_date_column = any("date" in str(col).lower() for col in result_df.columns)

        pipeline = DataPipeline(
            pipeline_id=f"pipeline_{request.request_id}",
            name=f"Pipeline for {request.description}",
            sql=sql_query,
            description=request.description,
            data_source=", ".join(request.tables),
            schema=schema,
            output_table=f"result_{request.request_id}",
            purpose=request.purpose,
            visualization_hints=["time_series"] if has_date_column else ["comparison"],
        )
        return pipeline, result_df

    def transform_data(self, df: pd.DataFrame, transformation_request: str) -> Tuple[pd.DataFrame, str]:
        """Transform *df* with pandas code generated by the LLM.

        This is best effort. The original *df* is returned, with a printed
        warning, in any of these cases:
        - no code is generated;
        - the code defines no ``transform_data`` function;
        - the code raises;
        - the function does not return a DataFrame.

        SECURITY: the generated code runs via ``exec`` in this process with full
        builtins. Only use it with trusted requests and models; sandbox it
        before exposing it to untrusted input.

        Returns:
            ``(transformed_or_original_df, raw_llm_response)``.
        """
        import numpy as np  # exposed to the generated code as ``np``

        print("Data Agent: Transforming data based on request")

        schema = {str(col): str(dtype) for col, dtype in df.dtypes.items()}
        prompt_inputs = {
            "data_description": f"Data with {len(df)} rows and {len(df.columns)} columns.",
            "transformation_request": transformation_request,
            "input_schema": json.dumps(schema, indent=2),
        }
        response = self.transform_chain.invoke(prompt_inputs)

        python_code = self.extract_python_from_response(response)
        if not python_code:
            print("Warning: No transformation code generated.")
            return df, response

        # Copies keep the generated code from mutating the caller's DataFrame.
        namespace = {"pd": pd, "np": np, "df": df.copy()}
        try:
            exec(python_code, namespace)  # see SECURITY note in the docstring
            if "transform_data" not in namespace:
                print("Warning: No transform_data function found in generated code.")
                return df, response
            transformed_df = namespace["transform_data"](df.copy())
        except Exception as e:
            print(f"Transformation execution error: {e}")
            return df, response

        if not isinstance(transformed_df, pd.DataFrame):
            print(
                f"Warning: transform_data returned {type(transformed_df).__name__}, "
                "not a DataFrame; keeping the original data."
            )
            return df, response
        return transformed_df, response

    def get_data_for_analysis(self, data_requests: List[DataRequest]) -> Dict[str, DataSource]:
        """Run a pipeline for each request.

        Returns:
            A mapping from ``request_id`` to the resulting ``DataSource``.
        """
        data_sources: Dict[str, DataSource] = {}
        for request in data_requests:
            pipeline, result_df = self.create_data_pipeline(request)
            data_sources[request.request_id] = DataSource(
                source_id=request.request_id,
                name=request.description,
                content=result_df,
                schema=pipeline.schema,
            )
        return data_sources

    def close(self):
        """Close the database connection if one was opened."""
        if hasattr(self, "db_connection") and self.db_connection:
            self.db_connection.close()


# For testing
if __name__ == "__main__":
    # Placeholder only when no key is configured; never overwrite a real key.
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")

    # The context manager closes the connection even if the pipeline raises.
    with DataAgent(db_path="data/pharma_db.sqlite") as agent:
        request = DataRequest(
            request_id="drx_sales_trend",
            description="Monthly sales of DrugX by region over the past year",
            tables=["sales", "regions", "products"],
            filters={"product_id": "DRX"},
            time_period={"start": "2023-01-01", "end": "2023-12-31"},
            groupby=["region_id", "year_month"],
            purpose="Analyze sales trend of DrugX by region",
        )

        pipeline, df = agent.create_data_pipeline(request)

        print(f"Generated SQL:\n{pipeline.sql}")
        print(f"Result shape: {df.shape}")
        print(df.head())