Spaces:
Runtime error
Runtime error
import json
import os
import re
import sqlite3
import time
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
class DataRequest(BaseModel):
    """Structure for a data request.

    Describes one natural-language data need that DataAgent turns into a
    SQL query against the pharma SQLite database.
    """
    request_id: str     # unique id; also used as the key for result dicts
    description: str    # human-readable statement of the data need
    tables: List[str]   # source tables the generated query should draw from
    # Optional narrowing of the request; None means "not specified".
    # Annotated Optional so the None defaults are type-correct (the original
    # `List[str] = None` is rejected by type checkers / strict pydantic).
    columns: Optional[List[str]] = None
    filters: Optional[Dict[str, Any]] = None
    time_period: Optional[Dict[str, str]] = None
    groupby: Optional[List[str]] = None
    purpose: str        # why the data is needed (forwarded to the LLM prompt)
class DataPipeline(BaseModel):
    """Structure for a data pipeline.

    Records how one result set was produced: the SQL that was run, where
    the data came from, and the shape of what came back.
    """
    pipeline_id: str    # "pipeline_<request_id>"
    name: str
    sql: str            # the executed SQL statement
    description: str
    data_source: str    # comma-joined list of source tables
    # NOTE(review): the field name "schema" shadows pydantic's
    # BaseModel.schema() helper (an error on pydantic v1, a warning on v2)
    # — confirm the pydantic version in use. Kept as-is because callers
    # read `pipeline.schema` directly.
    schema: Dict[str, str]  # column name -> pandas dtype string
    # Optional metadata; None means "not specified". Annotated Optional so
    # the None defaults are type-correct (the original `List[str] = None`
    # is rejected by type checkers / strict pydantic).
    transformations: Optional[List[str]] = None
    output_table: str
    purpose: str
    visualization_hints: Optional[List[str]] = None
class DataSource(BaseModel):
    """Structure for a data source.

    Bundles the result of one executed data request so downstream
    consumers receive the data together with its column types.
    """
    source_id: str  # matches the originating DataRequest.request_id
    name: str       # human-readable description (the request's description)
    content: Any  # This will be the pandas DataFrame
    schema: Dict[str, str]  # column name -> pandas dtype string
class DataAgent:
    """Agent responsible for data acquisition and transformation.

    Owns a SQLite connection plus two LLM chains: one that translates
    natural-language data requests into SQL, and one that generates
    pandas transformation code.
    """
    def __init__(self, db_path: str = "data/pharma_db.sqlite"):
        """Initialize the data agent with database connection.

        Args:
            db_path: Path to the pharma SQLite database file.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set in the environment.
        """
        # Set up database connection
        self.db_path = db_path
        self.db_connection = sqlite3.connect(db_path)
        # Set up Claude API client
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        # Low temperature: SQL/code generation should be near-deterministic.
        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=api_key,
            temperature=0.1
        )
        # Create SQL generation prompt. The system message embeds the full
        # database schema so the model can choose tables and joins itself.
        self.sql_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert SQL developer specializing in pharmaceutical data analysis.
Your task is to translate natural language data requests into precise SQL queries suitable for a SQLite database.
For each request, generate a SQL query that:
1. Retrieves only the necessary data for the analysis
2. Uses appropriate JOINs to connect related tables
3. Applies filters correctly
4. Includes relevant aggregations and groupings
5. Is optimized for performance
Format your response as follows:
```sql
-- Your SQL query here
SELECT ...
FROM ...
WHERE ...
```
Explain your approach after the SQL block, describing:
- Why you selected specific tables and columns
- How the query addresses the analytical requirements
- Any assumptions you made
The database schema includes these tables and columns:
- sales: sale_id, sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin
- products: product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price
- regions: region_id, region_name, country, division, population
- territories: territory_id, territory_name, region_id, sales_rep_id
- prescribers: prescriber_id, name, specialty, practice_type, territory_id, decile
- pharmacies: pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume
- competitor_products: competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id
- marketing_campaigns: campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend
- market_events: event_id, event_date, event_type, description, affected_products, affected_regions, impact_score
- sales_targets: target_id, product_id, region_id, period, target_units, target_revenue
- distribution_centers: dc_id, dc_name, region_id, inventory_capacity
- inventory: inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply
- external_factors: factor_id, date, region_id, factor_type, factor_value, description
"""),
            ("human", "{request}")
        ])
        # Set up the SQL generation chain (prompt -> Claude -> plain string)
        self.sql_chain = (
            self.sql_prompt
            | self.llm
            | StrOutputParser()
        )
        # Create transformation prompt. Asks the model for a
        # `transform_data(df)` function that transform_data() will exec.
        self.transform_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert data engineer specializing in pharmaceutical data transformation.
Your task is to generate Python code using pandas to transform the data based on the requirements.
For each transformation request:
1. Generate clear, efficient pandas code
2. Include appropriate data cleaning steps
3. Apply necessary transformations (normalization, feature engineering, etc.)
4. Add comments explaining key steps
5. Handle potential edge cases and missing data
Format your response with a code block:
```python
# Transformation code
import pandas as pd
import numpy as np
def transform_data(df):
    # Your transformation code here
    return transformed_df
```
After the code block, explain your transformation approach and any assumptions.
"""),
            ("human", """
Here is the data description:
{data_description}
Transformation needed:
{transformation_request}
Schema of the input data:
{input_schema}
Please generate the pandas code to perform this transformation.
""")
        ])
        # Set up the transformation chain (prompt -> Claude -> plain string)
        self.transform_chain = (
            self.transform_prompt
            | self.llm
            | StrOutputParser()
        )
| def execute_sql(self, sql: str) -> pd.DataFrame: | |
| """Execute SQL query and return results as DataFrame""" | |
| try: | |
| start_time = time.time() | |
| df = pd.read_sql_query(sql, self.db_connection) | |
| end_time = time.time() | |
| print(f"SQL execution time: {end_time - start_time:.2f} seconds") | |
| print(f"Retrieved {len(df)} rows") | |
| return df | |
| except Exception as e: | |
| print(f"SQL execution error: {e}") | |
| print(f"Failed SQL: {sql}") | |
| raise | |
| def extract_sql_from_response(self, response: str) -> str: | |
| """Extract SQL query from LLM response""" | |
| # Extract SQL between ```sql and ``` markers | |
| sql_match = re.search(r'```sql\s*(.*?)\s*```', response, re.DOTALL) | |
| if sql_match: | |
| return sql_match.group(1).strip() | |
| # If not found with sql tag, try generic code block | |
| sql_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL) | |
| if sql_match: | |
| return sql_match.group(1).strip() | |
| # If no code blocks, look for SQL keywords | |
| sql_pattern = r'(?i)(SELECT[\s\S]+?FROM[\s\S]+?(WHERE|GROUP BY|ORDER BY|LIMIT|$)[\s\S]*)' | |
| sql_match = re.search(sql_pattern, response) | |
| if sql_match: | |
| return sql_match.group(0).strip() | |
| # If all else fails, return empty string | |
| return "" | |
| def extract_python_from_response(self, response: str) -> str: | |
| """Extract Python code from LLM response""" | |
| # Extract Python between ```python and ``` markers | |
| python_match = re.search(r'```python\s*(.*?)\s*```', response, re.DOTALL) | |
| if python_match: | |
| return python_match.group(1).strip() | |
| # If not found with python tag, try generic code block | |
| python_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL) | |
| if python_match: | |
| return python_match.group(1).strip() | |
| # If all else fails, return empty string | |
| return "" | |
    def generate_sql(self, request: DataRequest) -> Tuple[str, str]:
        """Generate SQL for data request.

        Returns:
            Tuple of (extracted SQL string, full raw LLM response). The
            SQL string is "" when no SQL could be extracted.
        """
        print(f"Data Agent: Generating SQL for request: {request.description}")
        # Format the request for the prompt. Each optional field renders
        # as an empty line when unset, keeping the prompt's overall shape.
        request_text = f"""
Data Request: {request.description}
Tables needed: {', '.join(request.tables)}
{f"Columns needed: {', '.join(request.columns)}" if request.columns else ""}
{f"Filters: {json.dumps(request.filters)}" if request.filters else ""}
{f"Time period: {json.dumps(request.time_period)}" if request.time_period else ""}
{f"Group by: {', '.join(request.groupby)}" if request.groupby else ""}
Purpose: {request.purpose}
Please generate a SQL query for this request.
"""
        # Generate SQL
        response = self.sql_chain.invoke({"request": request_text})
        # Extract SQL query
        sql_query = self.extract_sql_from_response(response)
        return sql_query, response
| def create_data_pipeline(self, request: DataRequest) -> Tuple[DataPipeline, pd.DataFrame]: | |
| """Create data pipeline and execute it""" | |
| # Generate SQL | |
| sql_query, response = self.generate_sql(request) | |
| # Execute SQL to get data | |
| result_df = self.execute_sql(sql_query) | |
| # Create schema description | |
| schema = {col: str(result_df[col].dtype) for col in result_df.columns} | |
| # Create pipeline object | |
| pipeline = DataPipeline( | |
| pipeline_id=f"pipeline_{request.request_id}", | |
| name=f"Pipeline for {request.description}", | |
| sql=sql_query, | |
| description=request.description, | |
| data_source=", ".join(request.tables), | |
| schema=schema, | |
| output_table=f"result_{request.request_id}", | |
| purpose=request.purpose, | |
| visualization_hints=["time_series"] if "date" in " ".join(result_df.columns).lower() else ["comparison"] | |
| ) | |
| return pipeline, result_df | |
| def transform_data(self, df: pd.DataFrame, transformation_request: str) -> Tuple[pd.DataFrame, str]: | |
| """Transform data using pandas based on request""" | |
| print(f"Data Agent: Transforming data based on request") | |
| # Create schema description | |
| schema = {col: str(df[col].dtype) for col in df.columns} | |
| # Format the request for the prompt | |
| request_text = { | |
| "data_description": f"Data with {len(df)} rows and {len(df.columns)} columns.", | |
| "transformation_request": transformation_request, | |
| "input_schema": json.dumps(schema, indent=2) | |
| } | |
| # Generate transformation code | |
| response = self.transform_chain.invoke(request_text) | |
| # Extract Python code | |
| python_code = self.extract_python_from_response(response) | |
| # Execute transformation (with safety checks) | |
| if not python_code: | |
| print("Warning: No transformation code generated.") | |
| return df, response | |
| try: | |
| # Create a local namespace with access to pandas and numpy | |
| local_namespace = { | |
| "pd": pd, | |
| "np": __import__("numpy"), | |
| "df": df.copy() | |
| } | |
| # Extract the function definition from the code | |
| exec(python_code, local_namespace) | |
| # Look for a transform_data function in the namespace | |
| if "transform_data" in local_namespace: | |
| transformed_df = local_namespace["transform_data"](df.copy()) | |
| return transformed_df, response | |
| else: | |
| print("Warning: No transform_data function found in generated code.") | |
| return df, response | |
| except Exception as e: | |
| print(f"Transformation execution error: {e}") | |
| return df, response | |
| def get_data_for_analysis(self, data_requests: List[DataRequest]) -> Dict[str, DataSource]: | |
| """Process multiple data requests and return results""" | |
| data_sources = {} | |
| for request in data_requests: | |
| # Create data pipeline | |
| pipeline, result_df = self.create_data_pipeline(request) | |
| # Create data source object | |
| data_source = DataSource( | |
| source_id=request.request_id, | |
| name=request.description, | |
| content=result_df, | |
| schema=pipeline.schema | |
| ) | |
| # Store data source | |
| data_sources[request.request_id] = data_source | |
| return data_sources | |
| def close(self): | |
| """Close database connection""" | |
| if hasattr(self, 'db_connection') and self.db_connection: | |
| self.db_connection.close() | |
# For testing
if __name__ == "__main__":
    # Use a real key from the environment when present; only fall back to
    # the placeholder otherwise (the original unconditionally overwrote a
    # genuine ANTHROPIC_API_KEY with "your_api_key_here").
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")
    agent = DataAgent(db_path="data/pharma_db.sqlite")
    try:
        # Example data request
        request = DataRequest(
            request_id="drx_sales_trend",
            description="Monthly sales of DrugX by region over the past year",
            tables=["sales", "regions", "products"],
            filters={"product_id": "DRX"},
            time_period={"start": "2023-01-01", "end": "2023-12-31"},
            groupby=["region_id", "year_month"],
            purpose="Analyze sales trend of DrugX by region"
        )
        pipeline, df = agent.create_data_pipeline(request)
        print(f"Generated SQL:\n{pipeline.sql}")
        print(f"Result shape: {df.shape}")
        print(df.head())
    finally:
        # Always release the SQLite connection, even if the pipeline fails.
        agent.close()