import os
import json
import tempfile
from io import StringIO
from typing import Any, Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from .base import LLMTool


class DataAnalysisTool(LLMTool):
    """LLM tool that loads tabular data files and produces analyses.

    Supports basic descriptive statistics, correlation analysis, several
    matplotlib/seaborn visualizations (saved to temp PNG files), and
    LLM-generated exploratory insights. The most recently loaded DataFrame
    is cached on the instance and reused across calls.
    """

    name: str = "Data Analysis Tool"
    description: str = "A tool that can analyze data files (CSV, Excel, etc.) and provide insights. It can generate statistics, visualizations, and exploratory data analysis."
    arg: str = "Either a file path or a JSON object with parameters for analysis. If providing a path, supply the full path to the data file. If providing parameters, use the format: {'file_path': 'path/to/file', 'analysis_type': 'basic|correlation|visualization', 'columns': ['col1', 'col2'], 'target': 'target_column'}"

    # Path to the currently loaded data file.
    # BUG FIX: was annotated `str` while defaulting to None.
    _current_file: Optional[str] = None
    # The currently loaded dataframe (None until load_data succeeds).
    _df: Optional[pd.DataFrame] = None

    def load_data(self, file_path: str) -> str:
        """Load data from *file_path* into ``self._df``.

        The format is chosen by file extension (.csv, .xlsx/.xls, .json,
        .parquet, or .sql for a SQLite database containing ``main_table``).
        Returns a human-readable success or error message.
        """
        try:
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext == '.csv':
                self._df = pd.read_csv(file_path)
            elif file_ext in ['.xlsx', '.xls']:
                self._df = pd.read_excel(file_path)
            elif file_ext == '.json':
                self._df = pd.read_json(file_path)
            elif file_ext == '.parquet':
                self._df = pd.read_parquet(file_path)
            elif file_ext == '.sql':
                # For SQL files, we expect a SQLite database
                import sqlite3
                conn = sqlite3.connect(file_path)
                self._df = pd.read_sql("SELECT * FROM main_table", conn)
                conn.close()
            else:
                return f"Unsupported file format: {file_ext}. Supported formats: .csv, .xlsx, .xls, .json, .parquet, .sql"

            self._current_file = file_path
            # BUG FIX: the original f-string contained a raw line break
            # (a syntax error); use an explicit \n escape instead.
            return (
                f"Successfully loaded data from {file_path}. Shape: {self._df.shape}. \n"
                f"Columns: {', '.join(self._df.columns.tolist())}"
            )
        except Exception as e:
            return f"Error loading data: {str(e)}"

    def generate_basic_stats(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
        """Generate basic statistics for the dataframe or specified columns.

        Returns a dict with shape, column names, describe() output, null
        counts, and unique-value counts for categorical columns — or an
        error string if no data is loaded / the columns are invalid.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            if columns:
                # Filter to only include columns that exist in the dataframe
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                df_subset = self._df[valid_columns]
            else:
                df_subset = self._df

            numeric_stats = df_subset.describe().to_dict()
            null_counts = df_subset.isnull().sum().to_dict()
            categorical_columns = df_subset.select_dtypes(include=['object', 'category']).columns
            unique_counts = {col: df_subset[col].nunique() for col in categorical_columns}

            stats = {
                "shape": self._df.shape,
                "columns": self._df.columns.tolist(),
                "numeric_stats": numeric_stats,
                "null_counts": null_counts,
                "unique_counts": unique_counts,
            }
            return stats
        except Exception as e:
            return f"Error generating basic statistics: {str(e)}"

    def generate_correlation_analysis(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
        """Generate correlation analysis for numeric columns.

        Returns the full correlation matrix plus pairs whose absolute
        correlation exceeds 0.7 (sorted descending), or an error string.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            numeric_df = self._df.select_dtypes(include=[np.number])
            if columns:
                # Filter to only include numeric columns that were specified
                valid_columns = [col for col in columns if col in numeric_df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} are numeric or exist in the dataframe."
                numeric_df = numeric_df[valid_columns]

            if numeric_df.empty:
                return "No numeric columns found in the dataset for correlation analysis."

            corr_matrix = numeric_df.corr().to_dict()
            corr_df = numeric_df.corr().abs()
            # Keep only the strict upper triangle so each pair appears once.
            upper_tri = corr_df.where(np.triu(np.ones(corr_df.shape), k=1).astype(bool))
            high_corr = [
                (col1, col2, upper_tri.loc[col1, col2])
                for col1 in upper_tri.index
                for col2 in upper_tri.columns
                if upper_tri.loc[col1, col2] > 0.7
            ]
            high_corr.sort(key=lambda x: x[2], reverse=True)
            return {"correlation_matrix": corr_matrix, "high_correlations": high_corr}
        except Exception as e:
            return f"Error generating correlation analysis: {str(e)}"

    def generate_visualization(self, viz_type: str, columns: Optional[List[str]] = None,
                               target: Optional[str] = None) -> str:
        """Generate a visualization and save it to a temporary PNG file.

        viz_type: one of 'histogram', 'scatter', 'correlation', 'boxplot',
        'pairplot'. Returns the saved file path or an error message.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            # Allocate a temp file path for the rendered image (closed
            # immediately; matplotlib writes to it via savefig below).
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                output_path = tmp.name

            plt.figure(figsize=(10, 6))

            if viz_type == 'histogram':
                if not columns or len(columns) == 0:
                    # If no columns specified, use all numeric columns
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if not numeric_cols:
                        return "No numeric columns found for histogram."
                    # Limit to 4 columns for readability
                    columns = numeric_cols[:4]
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                for col in valid_columns:
                    if pd.api.types.is_numeric_dtype(self._df[col]):
                        plt.hist(self._df[col].dropna(), alpha=0.5, label=col)
                plt.legend()
                plt.title(f"Histogram of {', '.join(valid_columns)}")
                plt.tight_layout()

            elif viz_type == 'scatter':
                if not columns or len(columns) < 2:
                    return "Scatter plot requires at least two columns."
                if columns[0] not in self._df.columns or columns[1] not in self._df.columns:
                    return f"One or more of the specified columns {columns[:2]} do not exist in the dataframe."
                x_col, y_col = columns[0], columns[1]
                plt.scatter(self._df[x_col], self._df[y_col], alpha=0.5)
                plt.xlabel(x_col)
                plt.ylabel(y_col)
                plt.title(f"Scatter Plot: {x_col} vs {y_col}")
                # Color by target if provided
                if target and target in self._df.columns:
                    if pd.api.types.is_numeric_dtype(self._df[target]):
                        scatter = plt.scatter(self._df[x_col], self._df[y_col],
                                              c=self._df[target], alpha=0.5)
                        plt.colorbar(scatter, label=target)
                    else:
                        # For categorical targets, create multiple scatters
                        categories = self._df[target].unique()
                        for category in categories:
                            mask = self._df[target] == category
                            plt.scatter(self._df.loc[mask, x_col], self._df.loc[mask, y_col],
                                        alpha=0.5, label=str(category))
                        plt.legend()
                plt.tight_layout()

            elif viz_type == 'correlation':
                # Generate correlation heatmap
                numeric_df = self._df.select_dtypes(include=[np.number])
                if columns:
                    valid_columns = [col for col in columns if col in numeric_df.columns]
                    if not valid_columns:
                        return f"None of the specified columns {columns} are numeric or exist in the dataframe."
                    numeric_df = numeric_df[valid_columns]
                if numeric_df.empty:
                    return "No numeric columns found for correlation heatmap."
                sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', linewidths=0.5)
                plt.title("Correlation Heatmap")
                plt.tight_layout()

            elif viz_type == 'boxplot':
                if not columns or len(columns) == 0:
                    # If no columns specified, use all numeric columns
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if not numeric_cols:
                        return "No numeric columns found for boxplot."
                    # Limit to 5 columns for readability
                    columns = numeric_cols[:5]
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                self._df[valid_columns].boxplot()
                plt.title("Boxplot of Selected Columns")
                plt.xticks(rotation=45)
                plt.tight_layout()

            elif viz_type == 'pairplot':
                if not columns or len(columns) < 2:
                    # Use first 4 numeric columns if not specified
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if len(numeric_cols) < 2:
                        return "Not enough numeric columns for a pairplot."
                    columns = numeric_cols[:min(4, len(numeric_cols))]
                valid_columns = [col for col in columns if col in self._df.columns]
                if len(valid_columns) < 2:
                    return f"Not enough valid columns in {columns} for a pairplot."
                # pairplot creates its own figure; discard the empty one above.
                plt.close()
                # BUG FIX: the original allocated a SECOND temp file here,
                # leaking the first one on every pairplot call. Reuse it.
                if target and target in self._df.columns:
                    sns.pairplot(self._df[valid_columns + [target]], hue=target, height=2.5)
                else:
                    sns.pairplot(self._df[valid_columns], height=2.5)
                plt.suptitle("Pair Plot of Selected Features", y=1.02)
                plt.tight_layout()

            else:
                return f"Unsupported visualization type: {viz_type}. Supported types: histogram, scatter, correlation, boxplot, pairplot"

            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()
            return f"Visualization saved to: {output_path}"
        except Exception as e:
            return f"Error generating visualization: {str(e)}"

    def generate_data_insights(self) -> str:
        """Generate AI-powered insights about the data.

        Builds a prompt from a head() sample plus schema/stats metadata and
        sends it to an LLM — via OpenRouter when OPENROUTER_API_KEY is set
        (model from MODEL_NAME, defaulting to gpt-4), otherwise via the
        tool's default client (``self.client``, provided by the base class).
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            # Get a sample and info about the data to send to the LLM
            df_sample = self._df.head(5).to_string()
            df_info = {
                "shape": self._df.shape,
                "columns": self._df.columns.tolist(),
                "dtypes": {col: str(self._df[col].dtype) for col in self._df.columns},
                "missing_values": self._df.isnull().sum().to_dict(),
                "numeric_stats": self._df.describe().to_dict()
                if not self._df.select_dtypes(include=[np.number]).empty else {},
            }

            prompt = f"""
Analyze this dataset and provide key insights.

Dataset Sample:
{df_sample}

Dataset Info:
{json.dumps(df_info, indent=2)}

Your task:
1. Identify the dataset type and potential use cases
2. Summarize the basic characteristics (rows, columns, data types)
3. Highlight key statistics and distributions
4. Point out missing data patterns if any
5. Suggest potential relationships or correlations worth exploring
6. Recommend next steps for deeper analysis
7. Note any data quality issues or anomalies

Provide a comprehensive but concise analysis with actionable insights.
"""

            system_message = "You are a data science expert specializing in exploratory data analysis and deriving insights from datasets."
            openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
            # Default to gpt-4 if MODEL_NAME is not set
            model_name = os.environ.get("MODEL_NAME", "gpt-4")
            try:
                if openrouter_api_key:
                    print(f"Using OpenRouter with model: {model_name} for data insights")
                    # BUG FIX: OpenAI was referenced but never imported,
                    # raising NameError on this path. Import lazily so the
                    # openai package is only required when actually used.
                    from openai import OpenAI
                    client = OpenAI(base_url="https://openrouter.ai/api/v1",
                                    api_key=openrouter_api_key)
                    response = client.chat.completions.create(
                        model=model_name,
                        messages=[
                            {"role": "system", "content": system_message},
                            {"role": "user", "content": prompt},
                        ],
                        max_tokens=3000)
                else:
                    # Fall back to default OpenAI client
                    print("OpenRouter API key not found, using default OpenAI client with gpt-4")
                    response = self.client.chat.completions.create(
                        model="gpt-4",
                        messages=[
                            {"role": "system", "content": system_message},
                            {"role": "user", "content": prompt},
                        ],
                        max_tokens=3000)
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error with OpenRouter: {e}")
                print("Falling back to default OpenAI client with gpt-4")
                try:
                    response = self.client.chat.completions.create(
                        model="gpt-4",
                        messages=[
                            {"role": "system", "content": system_message},
                            {"role": "user", "content": prompt},
                        ],
                        max_tokens=3000)
                    return response.choices[0].message.content
                except Exception as e2:
                    return f"Error generating data insights with fallback model: {str(e2)}"
        except Exception as e:
            return f"Error analyzing data for insights: {str(e)}"

    def run(self, prompt: Union[str, Dict]) -> str:
        """Run the data analysis tool.

        *prompt* is either a bare file path (loads the file and reports its
        shape), or a JSON object / dict with 'file_path' plus optional
        'analysis_type' ('basic', 'correlation', 'visualization',
        'insights'), 'columns', 'target', and 'viz_type' keys.
        """
        print(f"Calling Data Analysis Tool with prompt: {prompt}")
        try:
            # If prompt is a string, try to parse it as JSON or treat it as a file path
            if isinstance(prompt, str):
                try:
                    params = json.loads(prompt)
                except json.JSONDecodeError:
                    # Treat as file path
                    return self.load_data(prompt)
            else:
                params = prompt

            # Handle different parameter options
            if 'file_path' in params:
                file_path = params['file_path']
                # Load the data first
                load_result = self.load_data(file_path)
                if "Successfully" not in load_result:
                    return load_result

            # If no analysis type is specified, generate insights
            if 'analysis_type' not in params:
                return self.generate_data_insights()

            analysis_type = params['analysis_type'].lower()
            columns = params.get('columns', None)
            target = params.get('target', None)

            if analysis_type == 'basic':
                stats = self.generate_basic_stats(columns)
                return json.dumps(stats, indent=2)
            elif analysis_type == 'correlation':
                corr_analysis = self.generate_correlation_analysis(columns)
                return json.dumps(corr_analysis, indent=2)
            elif analysis_type == 'visualization':
                viz_type = params.get('viz_type', 'histogram')
                return self.generate_visualization(viz_type, columns, target)
            elif analysis_type == 'insights':
                return self.generate_data_insights()
            else:
                return f"Unsupported analysis type: {analysis_type}. Supported types: basic, correlation, visualization, insights"
        except Exception as e:
            return f"Error executing data analysis: {str(e)}"