Spaces:
No application file
No application file
import ast
import contextlib
import json
import os
import tempfile
from io import StringIO
from typing import Any, Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from .base import LLMTool
class DataAnalysisTool(LLMTool):
    """Tool that loads a tabular data file and analyses it on request.

    Supported analyses are descriptive statistics, correlation analysis,
    matplotlib/seaborn visualizations saved as PNG files, and LLM-generated
    insights. The most recently loaded DataFrame is kept on the instance so
    later calls can reuse it without passing ``file_path`` again.

    Public methods report failures in their returned value (a message string)
    instead of raising, so the calling agent always receives a reply.
    """

    name: str = "Data Analysis Tool"
    description: str = "A tool that can analyze data files (CSV, Excel, etc.) and provide insights. It can generate statistics, visualizations, and exploratory data analysis."
    arg: str = "Either a file path or a JSON object with parameters for analysis. If providing a path, supply the full path to the data file. If providing parameters, use the format: {'file_path': 'path/to/file', 'analysis_type': 'basic|correlation|visualization|insights', 'viz_type': 'histogram|scatter|correlation|boxplot|pairplot', 'columns': ['col1', 'col2'], 'target': 'target_column'}"
    # Path of the file that produced ``_df``; None until a load succeeds.
    _current_file: Optional[str] = None
    # The currently loaded DataFrame; None until a load succeeds.
    _df: Optional[pd.DataFrame] = None

    @staticmethod
    def _as_column_list(columns):
        """Normalize a column specification to a list of names, or None for "all".

        Callers (typically an LLM) sometimes pass a single column name as a bare
        string; iterating that would yield its individual characters, so it is
        wrapped in a one-element list instead. An empty string means "all".
        """
        if isinstance(columns, str):
            return [columns] if columns else None
        return columns

    @staticmethod
    def _existing_columns(columns, frame: pd.DataFrame) -> list:
        """Return the names in *columns* present in *frame*, in order, without duplicates."""
        return list(dict.fromkeys(col for col in columns if col in frame.columns))

    def load_data(self, file_path: str) -> str:
        """Load a data file into ``self._df``.

        The reader is chosen from the (case-insensitive) file extension:
        ``.csv``, ``.xlsx``/``.xls``, ``.json``, ``.parquet``, or ``.sql``, which
        is opened as a SQLite database and its ``main_table`` table read.

        Args:
            file_path: Path to the data file.

        Returns:
            A success message with the shape and column names, or an error
            message. On failure the previously loaded data (if any) is kept.
        """
        try:
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext == '.csv':
                df = pd.read_csv(file_path)
            elif file_ext in ('.xlsx', '.xls'):
                df = pd.read_excel(file_path)
            elif file_ext == '.json':
                df = pd.read_json(file_path)
            elif file_ext == '.parquet':
                df = pd.read_parquet(file_path)
            elif file_ext == '.sql':
                # For SQL files, we expect a SQLite database
                import sqlite3
                if not os.path.isfile(file_path):
                    # sqlite3.connect would otherwise silently create an empty database file.
                    raise FileNotFoundError(f"No such file: '{file_path}'")
                # closing() releases the connection even when the query fails.
                with contextlib.closing(sqlite3.connect(file_path)) as conn:
                    df = pd.read_sql("SELECT * FROM main_table", conn)
            else:
                return f"Unsupported file format: {file_ext}. Supported formats: .csv, .xlsx, .xls, .json, .parquet, .sql"
        except Exception as e:
            return f"Error loading data: {str(e)}"
        self._df = df
        self._current_file = file_path
        # str() each name: column labels are not always strings (e.g. 0, 1, 2 for list-shaped JSON).
        column_names = ', '.join(str(col) for col in df.columns)
        return f"Successfully loaded data from {file_path}. Shape: {df.shape}. Columns: {column_names}"

    def generate_basic_stats(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
        """Compute summary statistics for the loaded data.

        Args:
            columns: Optional column names to restrict the statistics to. Names
                not present in the DataFrame are ignored.

        Returns:
            A dict with ``shape`` and ``columns`` (always describing the full
            DataFrame), plus ``numeric_stats`` (``DataFrame.describe()``),
            ``null_counts`` and ``unique_counts`` (object/category columns only)
            for the selected columns; or an error message string.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            columns = self._as_column_list(columns)
            if columns:
                valid_columns = self._existing_columns(columns, self._df)
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                df_subset = self._df[valid_columns]
            else:
                df_subset = self._df
            categorical_columns = df_subset.select_dtypes(include=['object', 'category']).columns
            return {
                "shape": self._df.shape,
                "columns": self._df.columns.tolist(),
                "numeric_stats": df_subset.describe().to_dict(),
                "null_counts": df_subset.isnull().sum().to_dict(),
                "unique_counts": {col: df_subset[col].nunique() for col in categorical_columns},
            }
        except Exception as e:
            return f"Error generating basic statistics: {str(e)}"

    def generate_correlation_analysis(self, columns: Optional[List[str]] = None,
                                      threshold: float = 0.7) -> Union[Dict, str]:
        """Compute pairwise correlations between numeric columns.

        Args:
            columns: Optional column names to restrict the analysis to; names
                that are missing or non-numeric are ignored.
            threshold: A pair is reported in ``high_correlations`` when its
                absolute correlation is strictly greater than this value.

        Returns:
            A dict with ``correlation_matrix`` (``DataFrame.corr()`` as a nested
            dict) and ``high_correlations``: ``(col1, col2, |r|)`` tuples for each
            distinct pair above ``threshold``, sorted by ``|r|`` descending; or an
            error message string.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            numeric_df = self._df.select_dtypes(include=[np.number])
            columns = self._as_column_list(columns)
            if columns:
                # Filter to only include numeric columns that were specified
                valid_columns = self._existing_columns(columns, numeric_df)
                if not valid_columns:
                    return f"None of the specified columns {columns} are numeric or exist in the dataframe."
                numeric_df = numeric_df[valid_columns]
            if numeric_df.empty:
                return "No numeric columns found in the dataset for correlation analysis."
            corr = numeric_df.corr()  # computed once, reused for the matrix and the pair scan
            names = corr.columns.tolist()
            abs_values = corr.abs().to_numpy()
            # Strict upper triangle: every unordered pair once, no self-correlations.
            # Positional indexing stays correct even with duplicate column labels.
            row_idx, col_idx = np.triu_indices(len(names), k=1)
            high_corr = [
                (names[i], names[j], abs_values[i, j])
                for i, j in zip(row_idx, col_idx)
                if abs_values[i, j] > threshold  # NaN compares False, so it is skipped
            ]
            high_corr.sort(key=lambda pair: pair[2], reverse=True)
            return {"correlation_matrix": corr.to_dict(), "high_correlations": high_corr}
        except Exception as e:
            return f"Error generating correlation analysis: {str(e)}"

    def generate_visualization(self, viz_type: str, columns: Optional[List[str]] = None,
                               target: Optional[str] = None) -> str:
        """Render a chart of the loaded data and save it as a PNG temporary file.

        Args:
            viz_type: One of ``histogram``, ``scatter``, ``correlation``,
                ``boxplot`` or ``pairplot`` (case-insensitive).
            columns: Column names to plot. ``scatter`` uses the first two; the
                other types fall back to the first few numeric columns when
                omitted.
            target: Optional column used to colour ``scatter`` points or as the
                ``hue`` of a ``pairplot``; ignored when absent from the data.

        Returns:
            ``"Visualization saved to: <path>"`` on success, otherwise a message
            explaining why nothing was saved. The PNG is not deleted by this
            tool; the caller owns it.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        plotters = {
            'histogram': self._plot_histogram,
            'scatter': self._plot_scatter,
            'correlation': self._plot_correlation_heatmap,
            'boxplot': self._plot_boxplot,
            'pairplot': self._plot_pairplot,
        }
        key = viz_type.strip().lower() if isinstance(viz_type, str) else viz_type
        plotter = plotters.get(key)
        if plotter is None:
            return f"Unsupported visualization type: {viz_type}. Supported types: histogram, scatter, correlation, boxplot, pairplot"
        figures_before = set(plt.get_fignums())
        try:
            problem = plotter(self._as_column_list(columns), target)
            if problem is not None:
                return problem
            plt.tight_layout()
            # The temp file is created only once there is something to save, so
            # validation failures never leave empty PNGs behind.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                output_path = tmp.name
            try:
                plt.savefig(output_path, dpi=300, bbox_inches='tight')
            except Exception:
                with contextlib.suppress(OSError):
                    os.remove(output_path)
                raise
            return f"Visualization saved to: {output_path}"
        except Exception as e:
            return f"Error generating visualization: {str(e)}"
        finally:
            # Close every figure opened by this call (on all paths) so repeated
            # tool calls do not accumulate open matplotlib figures.
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)

    def _plot_histogram(self, columns, target) -> Optional[str]:
        """Draw overlaid histograms of numeric columns; ``target`` is ignored.

        Returns None after drawing, or an error message if nothing can be plotted.
        """
        df = self._df
        if not columns:
            # If no columns specified, use all numeric columns
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if not numeric_cols:
                return "No numeric columns found for histogram."
            # Limit to 4 columns for readability
            columns = numeric_cols[:4]
        valid_columns = self._existing_columns(columns, df)
        if not valid_columns:
            return f"None of the specified columns {columns} exist in the dataframe."
        plotted = [col for col in valid_columns if pd.api.types.is_numeric_dtype(df[col])]
        if not plotted:
            return f"None of the specified columns {valid_columns} are numeric."
        plt.figure(figsize=(10, 6))
        for col in plotted:
            plt.hist(df[col].dropna(), alpha=0.5, label=str(col))
        plt.legend()
        plt.title(f"Histogram of {', '.join(str(col) for col in plotted)}")
        return None

    def _plot_scatter(self, columns, target) -> Optional[str]:
        """Draw a scatter of ``columns[0]`` vs ``columns[1]``, optionally coloured by ``target``.

        Returns None after drawing, or an error message if the columns are invalid.
        """
        df = self._df
        if not columns or len(columns) < 2:
            return "Scatter plot requires at least two columns."
        x_col, y_col = columns[0], columns[1]
        if x_col not in df.columns or y_col not in df.columns:
            return f"One or more of the specified columns {columns[:2]} do not exist in the dataframe."
        plt.figure(figsize=(10, 6))
        # Exactly one set of points is drawn: coloured by target when possible, plain otherwise.
        if target and target in df.columns:
            target_values = df[target]
            if pd.api.types.is_numeric_dtype(target_values):
                points = plt.scatter(df[x_col], df[y_col], c=target_values, alpha=0.5)
                plt.colorbar(points, label=target)
            else:
                # One layer per category so each gets a legend entry. NaN never
                # equals itself, so it is matched with isna() instead.
                for category in target_values.unique():
                    mask = target_values.isna() if pd.isna(category) else target_values == category
                    plt.scatter(df.loc[mask, x_col], df.loc[mask, y_col], alpha=0.5, label=str(category))
                plt.legend()
        else:
            plt.scatter(df[x_col], df[y_col], alpha=0.5)
        plt.xlabel(x_col)
        plt.ylabel(y_col)
        plt.title(f"Scatter Plot: {x_col} vs {y_col}")
        return None

    def _plot_correlation_heatmap(self, columns, target) -> Optional[str]:
        """Draw an annotated heatmap of numeric-column correlations; ``target`` is ignored.

        Returns None after drawing, or an error message if there is no numeric data.
        """
        numeric_df = self._df.select_dtypes(include=[np.number])
        if columns:
            # Filter to valid numeric columns
            valid_columns = self._existing_columns(columns, numeric_df)
            if not valid_columns:
                return f"None of the specified columns {columns} are numeric or exist in the dataframe."
            numeric_df = numeric_df[valid_columns]
        if numeric_df.empty:
            return "No numeric columns found for correlation heatmap."
        plt.figure(figsize=(10, 6))
        sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', linewidths=0.5)
        plt.title("Correlation Heatmap")
        return None

    def _plot_boxplot(self, columns, target) -> Optional[str]:
        """Draw box plots of the selected columns; ``target`` is ignored.

        Returns None after drawing, or an error message if no column is usable.
        """
        df = self._df
        if not columns:
            # If no columns specified, use all numeric columns
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if not numeric_cols:
                return "No numeric columns found for boxplot."
            # Limit to 5 columns for readability
            columns = numeric_cols[:5]
        valid_columns = self._existing_columns(columns, df)
        if not valid_columns:
            return f"None of the specified columns {columns} exist in the dataframe."
        plt.figure(figsize=(10, 6))
        df[valid_columns].boxplot()
        plt.title("Boxplot of Selected Columns")
        plt.xticks(rotation=45)
        return None

    def _plot_pairplot(self, columns, target) -> Optional[str]:
        """Draw a seaborn pair plot of the selected columns, using ``target`` as hue if present.

        Returns None after drawing, or an error message if fewer than two columns are usable.
        """
        df = self._df
        if not columns or len(columns) < 2:
            # Use first 4 numeric columns if not specified
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if len(numeric_cols) < 2:
                return "Not enough numeric columns for a pairplot."
            columns = numeric_cols[:4]
        valid_columns = self._existing_columns(columns, df)
        if len(valid_columns) < 2:
            return f"Not enough valid columns in {columns} for a pairplot."
        # sns.pairplot creates (and makes current) its own figure, which
        # suptitle/tight_layout/savefig below then act on.
        if target and target in df.columns:
            # Don't duplicate the target column when it is also being plotted;
            # a duplicated label makes the hue lookup return a DataFrame.
            plot_columns = valid_columns if target in valid_columns else valid_columns + [target]
            sns.pairplot(df[plot_columns], hue=target, height=2.5)
        else:
            sns.pairplot(df[valid_columns], height=2.5)
        plt.suptitle("Pair Plot of Selected Features", y=1.02)
        return None

    def _build_insights_prompt(self) -> str:
        """Build the LLM prompt: a 5-row sample plus a shape/dtype/missing-value/describe() summary."""
        df = self._df
        has_numeric = not df.select_dtypes(include=[np.number]).empty
        df_info = {
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
            "missing_values": df.isnull().sum().to_dict(),
            "numeric_stats": df.describe().to_dict() if has_numeric else {},
        }
        lines = [
            "Analyze this dataset and provide key insights.",
            "",
            "Dataset Sample:",
            df.head(5).to_string(),
            "",
            "Dataset Info:",
            # default=str: this text only feeds a prompt, so stringifying values
            # JSON cannot encode (e.g. timestamps) is preferable to failing.
            json.dumps(df_info, indent=2, default=str),
            "",
            "Your task:",
            "1. Identify the dataset type and potential use cases",
            "2. Summarize the basic characteristics (rows, columns, data types)",
            "3. Highlight key statistics and distributions",
            "4. Point out missing data patterns if any",
            "5. Suggest potential relationships or correlations worth exploring",
            "6. Recommend next steps for deeper analysis",
            "7. Note any data quality issues or anomalies",
            "",
            "Provide a comprehensive but concise analysis with actionable insights.",
        ]
        return "\n".join(lines)

    def generate_data_insights(self) -> str:
        """Ask an LLM for an exploratory-analysis report on the loaded data.

        If the ``OPENROUTER_API_KEY`` environment variable is set, OpenRouter is
        tried first with the model named by ``MODEL_NAME`` (default ``"gpt-4"``).
        If that fails, or no key is set, ``self.client`` is called once with
        ``"gpt-4"``.

        Returns:
            The model's reply, or an error message.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            prompt = self._build_insights_prompt()
        except Exception as e:
            return f"Error analyzing data for insights: {str(e)}"
        messages = [
            {"role": "system", "content": "You are a data science expert specializing in exploratory data analysis and deriving insights from datasets."},
            {"role": "user", "content": prompt},
        ]
        openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
        model_name = os.environ.get("MODEL_NAME", "gpt-4")  # Default to gpt-4 if MODEL_NAME is not set
        if openrouter_api_key:
            print(f"Using OpenRouter with model: {model_name} for data insights")
            try:
                # NOTE(review): `OpenAI` is not imported anywhere in this file, so this
                # raises NameError and always falls back below; presumably
                # `from openai import OpenAI` belongs in the imports — confirm.
                client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=openrouter_api_key)
                response = client.chat.completions.create(model=model_name, messages=messages, max_tokens=3000)
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error with OpenRouter: {e}")
                print("Falling back to default OpenAI client with gpt-4")
        else:
            print("OpenRouter API key not found, using default OpenAI client with gpt-4")
        try:
            # `self.client` is presumably an OpenAI-compatible client supplied by LLMTool — confirm.
            response = self.client.chat.completions.create(model="gpt-4", messages=messages, max_tokens=3000)
            return response.choices[0].message.content
        except Exception as e:
            return f"Error generating data insights with fallback model: {str(e)}"

    @staticmethod
    def _parse_params(text: str) -> Union[Dict, str]:
        """Interpret a string prompt as a parameter dict, falling back to a file path.

        JSON is tried first, then a Python dict literal, because the ``arg``
        description shows single-quoted keys, which are not valid JSON. A quoted
        string literal (e.g. ``'"data.csv"'``) is unwrapped to the path inside.
        Anything else is returned whitespace-stripped, to be used as a path.
        """
        stripped = text.strip()
        try:
            parsed = json.loads(stripped)
        except (ValueError, RecursionError):
            try:
                # literal_eval only evaluates literals; it never executes code.
                parsed = ast.literal_eval(stripped)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                parsed = None
        if isinstance(parsed, (dict, str)):
            return parsed
        return stripped

    def run(self, prompt: Union[str, Dict]) -> str:
        """Entry point called by the agent framework.

        Args:
            prompt: Either a path to a data file (which is only loaded), or
                parameters as a dict, a JSON object string, or a Python dict
                literal string, with keys ``file_path``, ``analysis_type``
                (``basic``, ``correlation``, ``visualization``, ``insights``),
                ``viz_type``, ``columns`` and ``target``.

        Returns:
            The analysis result as a string. ``basic`` and ``correlation``
            results are JSON-encoded. When ``analysis_type`` is absent, an LLM
            insights report is produced. Errors are returned as messages.
        """
        print(f"Calling Data Analysis Tool with prompt: {prompt}")
        try:
            if isinstance(prompt, str):
                params = self._parse_params(prompt)
                if not isinstance(params, dict):
                    # Not a parameter object: treat it as a file path.
                    return self.load_data(params)
            else:
                params = prompt
            if 'file_path' in params:
                load_result = self.load_data(params['file_path'])
                # Match the success prefix exactly; a bare substring test could be
                # fooled by a file path that contains the word.
                if not load_result.startswith("Successfully loaded"):
                    return load_result
            # If no analysis type is specified, generate insights
            if 'analysis_type' not in params:
                return self.generate_data_insights()
            analysis_type = str(params['analysis_type']).strip().lower()
            columns = self._as_column_list(params.get('columns'))
            target = params.get('target')
            if analysis_type in ('basic', 'correlation'):
                if analysis_type == 'basic':
                    result = self.generate_basic_stats(columns)
                else:
                    result = self.generate_correlation_analysis(columns)
                # Error messages are returned as-is rather than as a JSON-quoted string.
                return result if isinstance(result, str) else json.dumps(result, indent=2, default=str)
            if analysis_type == 'visualization':
                viz_type = params.get('viz_type') or 'histogram'
                return self.generate_visualization(viz_type, columns, target)
            if analysis_type == 'insights':
                return self.generate_data_insights()
            return f"Unsupported analysis type: {analysis_type}. Supported types: basic, correlation, visualization, insights"
        except Exception as e:
            return f"Error executing data analysis: {str(e)}"