File size: 18,430 Bytes
4418db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO
import json
from typing import Dict, List, Optional, Union, Any
import tempfile
from .base import LLMTool
class DataAnalysisTool(LLMTool):
    """LLM tool that loads tabular data files and runs exploratory analysis.

    Supports basic statistics, correlation analysis, matplotlib/seaborn
    visualizations, and LLM-generated insights over CSV, Excel, JSON,
    Parquet and SQLite inputs. Analysis methods operate on the dataframe
    most recently loaded via ``load_data``.
    """
    name: str = "Data Analysis Tool"
    description: str = "A tool that can analyze data files (CSV, Excel, etc.) and provide insights. It can generate statistics, visualizations, and exploratory data analysis."
    arg: str = "Either a file path or a JSON object with parameters for analysis. If providing a path, supply the full path to the data file. If providing parameters, use the format: {'file_path': 'path/to/file', 'analysis_type': 'basic|correlation|visualization', 'columns': ['col1', 'col2'], 'target': 'target_column'}"
    # Path of the most recently loaded file; None until load_data succeeds.
    _current_file: Optional[str] = None
    # The currently loaded dataframe shared by all analysis methods.
    _df: Optional[pd.DataFrame] = None
    def load_data(self, file_path: str) -> str:
        """Load a data file into the shared dataframe (``self._df``).

        Supported formats: .csv, .xlsx/.xls, .json, .parquet, and .sql
        (treated as a SQLite database expected to contain a table named
        ``main_table``).

        Args:
            file_path: Full path to the data file; the extension selects
                the reader.

        Returns:
            A success message with the loaded shape and column names, or an
            error/unsupported-format message (this tool reports errors as
            strings rather than raising).
        """
        try:
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext == '.csv':
                self._df = pd.read_csv(file_path)
            elif file_ext in ['.xlsx', '.xls']:
                self._df = pd.read_excel(file_path)
            elif file_ext == '.json':
                self._df = pd.read_json(file_path)
            elif file_ext == '.parquet':
                self._df = pd.read_parquet(file_path)
            elif file_ext == '.sql':
                # For SQL files, we expect a SQLite database with a table
                # conventionally named "main_table".
                import sqlite3
                conn = sqlite3.connect(file_path)
                try:
                    self._df = pd.read_sql("SELECT * FROM main_table", conn)
                finally:
                    # Close even when read_sql raises so the database handle
                    # is never leaked (the original leaked it on failure).
                    conn.close()
            else:
                return f"Unsupported file format: {file_ext}. Supported formats: .csv, .xlsx, .xls, .json, .parquet, .sql"
            # Record the source path only after a successful load.
            self._current_file = file_path
            return f"Successfully loaded data from {file_path}. Shape: {self._df.shape}. Columns: {', '.join(self._df.columns.tolist())}"
        except Exception as e:
            return f"Error loading data: {str(e)}"
    def generate_basic_stats(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
        """Generate basic statistics for the dataframe or specified columns.

        Args:
            columns: Optional subset of column names to describe; names not
                present in the dataframe are silently dropped.

        Returns:
            A dict with overall shape, column list, ``describe()`` stats,
            per-column null counts and per-categorical-column unique counts,
            or an error string (no data loaded / no valid columns / failure).
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            if columns:
                # Keep only columns that actually exist in the dataframe.
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                df_subset = self._df[valid_columns]
            else:
                df_subset = self._df
            # Unbox numpy scalars (.item()) into Python builtins so run()
            # can serialize the result with the stdlib json module; string
            # values (describe() of object columns) pass through unchanged.
            numeric_stats = {
                col: {stat: (val.item() if hasattr(val, "item") else val)
                      for stat, val in col_stats.items()}
                for col, col_stats in df_subset.describe().to_dict().items()
            }
            null_counts = {col: int(cnt) for col, cnt in df_subset.isnull().sum().items()}
            categorical_columns = df_subset.select_dtypes(include=['object', 'category']).columns
            unique_counts = {col: int(df_subset[col].nunique()) for col in categorical_columns}
            stats = {
                "shape": self._df.shape,
                "columns": self._df.columns.tolist(),
                "numeric_stats": numeric_stats,
                "null_counts": null_counts,
                "unique_counts": unique_counts,
            }
            return stats
        except Exception as e:
            return f"Error generating basic statistics: {str(e)}"
    def generate_correlation_analysis(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
        """Generate correlation analysis for numeric columns.

        Args:
            columns: Optional subset of column names; names that are not
                numeric or not present are silently dropped.

        Returns:
            A dict with the full Pearson correlation matrix and a list of
            (col1, col2, |corr|) tuples for pairs whose absolute correlation
            exceeds 0.7, sorted strongest-first; or an error string.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            numeric_df = self._df.select_dtypes(include=[np.number])
            if columns:
                # Filter to only include numeric columns that were specified.
                valid_columns = [col for col in columns if col in numeric_df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} are numeric or exist in the dataframe."
                numeric_df = numeric_df[valid_columns]
            if numeric_df.empty:
                return "No numeric columns found in the dataset for correlation analysis."
            corr = numeric_df.corr()
            # Cast numpy float64 values to plain floats so run() can
            # json.dumps the result (np.float64 is not JSON serializable).
            corr_matrix = {col: {other: float(v) for other, v in row.items()}
                           for col, row in corr.to_dict().items()}
            # Mask to the strict upper triangle (k=1) so each pair is
            # considered exactly once and self-correlations are excluded.
            abs_corr = corr.abs()
            upper_tri = abs_corr.where(np.triu(np.ones(abs_corr.shape), k=1).astype(bool))
            high_corr = [(col1, col2, float(upper_tri.loc[col1, col2]))
                         for col1 in upper_tri.index
                         for col2 in upper_tri.columns
                         if upper_tri.loc[col1, col2] > 0.7]
            high_corr.sort(key=lambda x: x[2], reverse=True)
            return {"correlation_matrix": corr_matrix, "high_correlations": high_corr}
        except Exception as e:
            return f"Error generating correlation analysis: {str(e)}"
    def generate_visualization(self, viz_type: str, columns: Optional[List[str]] = None, target: Optional[str] = None) -> str:
        """Generate visualization based on the specified type and columns.

        Renders a single matplotlib/seaborn figure and saves it as a PNG
        in a temporary file whose path is returned in the success message.

        Args:
            viz_type: One of 'histogram', 'scatter', 'correlation',
                'boxplot' or 'pairplot'.
            columns: Columns to plot. When omitted, numeric-column defaults
                are chosen for histogram (first 4), boxplot (first 5) and
                pairplot (first 4); scatter requires at least two columns.
            target: Optional column used to color scatter points or as the
                pairplot hue.

        Returns:
            "Visualization saved to: <path>" on success, otherwise an
            explanatory error string (this tool reports errors as strings
            rather than raising).
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            # Reserve a temp .png path; delete=False keeps the file on disk
            # after the handle closes so the caller can read it later.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                output_path = tmp.name
            plt.figure(figsize=(10, 6))
            # Handle different visualization types
            if viz_type == 'histogram':
                if not columns or len(columns) == 0:
                    # If no columns specified, use all numeric columns
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if not numeric_cols:
                        return "No numeric columns found for histogram."
                    # Limit to 4 columns for readability
                    columns = numeric_cols[:4]
                # Filter to valid columns
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                # Overlay one semi-transparent histogram per numeric column;
                # non-numeric columns in the list are silently skipped.
                for col in valid_columns:
                    if pd.api.types.is_numeric_dtype(self._df[col]):
                        plt.hist(self._df[col].dropna(), alpha=0.5, label=col)
                plt.legend()
                plt.title(f"Histogram of {', '.join(valid_columns)}")
                plt.tight_layout()
            elif viz_type == 'scatter':
                if not columns or len(columns) < 2:
                    return "Scatter plot requires at least two columns."
                # Check if columns exist
                if columns[0] not in self._df.columns or columns[1] not in self._df.columns:
                    return f"One or more of the specified columns {columns[:2]} do not exist in the dataframe."
                # Create scatter plot
                x_col, y_col = columns[0], columns[1]
                plt.scatter(self._df[x_col], self._df[y_col], alpha=0.5)
                plt.xlabel(x_col)
                plt.ylabel(y_col)
                plt.title(f"Scatter Plot: {x_col} vs {y_col}")
                # Color by target if provided
                # NOTE(review): when a target is given, a second scatter is
                # drawn on top of the uncolored one above — presumably
                # intentional layering; confirm the base layer is wanted.
                if target and target in self._df.columns:
                    if pd.api.types.is_numeric_dtype(self._df[target]):
                        scatter = plt.scatter(self._df[x_col], self._df[y_col], 
                                            c=self._df[target], alpha=0.5)
                        plt.colorbar(scatter, label=target)
                    else:
                        # For categorical targets, create multiple scatters
                        categories = self._df[target].unique()
                        for category in categories:
                            mask = self._df[target] == category
                            plt.scatter(self._df.loc[mask, x_col], self._df.loc[mask, y_col], alpha=0.5, label=str(category))
                        plt.legend()
                plt.tight_layout()
            elif viz_type == 'correlation':
                # Generate correlation heatmap
                numeric_df = self._df.select_dtypes(include=[np.number])
                if columns:
                    # Filter to valid numeric columns
                    valid_columns = [col for col in columns if col in numeric_df.columns]
                    if not valid_columns:
                        return f"None of the specified columns {columns} are numeric or exist in the dataframe."
                    numeric_df = numeric_df[valid_columns]
                if numeric_df.empty:
                    return "No numeric columns found for correlation heatmap."
                sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', linewidths=0.5)
                plt.title("Correlation Heatmap")
                plt.tight_layout()
            elif viz_type == 'boxplot':
                if not columns or len(columns) == 0:
                    # If no columns specified, use all numeric columns
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if not numeric_cols:
                        return "No numeric columns found for boxplot."
                    # Limit to 5 columns for readability
                    columns = numeric_cols[:5]
                # Filter to valid columns
                valid_columns = [col for col in columns if col in self._df.columns]
                if not valid_columns:
                    return f"None of the specified columns {columns} exist in the dataframe."
                # Create boxplot
                self._df[valid_columns].boxplot()
                plt.title("Boxplot of Selected Columns")
                plt.xticks(rotation=45)
                plt.tight_layout()
            elif viz_type == 'pairplot':
                # Create a pair plot for multiple columns
                if not columns or len(columns) < 2:
                    # Use first 4 numeric columns if not specified
                    numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
                    if len(numeric_cols) < 2:
                        return "Not enough numeric columns for a pairplot."
                    columns = numeric_cols[:min(4, len(numeric_cols))]
                # Filter to valid columns
                valid_columns = [col for col in columns if col in self._df.columns]
                if len(valid_columns) < 2:
                    return f"Not enough valid columns in {columns} for a pairplot."
                # Use seaborn pairplot
                plt.close()  # Close previous figure
                # Create a temporary directory for the visualization
                # NOTE(review): this allocates a SECOND temp file, leaving the
                # one created at the top of this method unused on disk —
                # confirm whether the first should be reused/removed.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                    output_path = tmp.name
                if target and target in self._df.columns:
                    g = sns.pairplot(self._df[valid_columns + [target]], hue=target, height=2.5)
                else:
                    g = sns.pairplot(self._df[valid_columns], height=2.5)
                plt.suptitle("Pair Plot of Selected Features", y=1.02)
                plt.tight_layout()
            else:
                return f"Unsupported visualization type: {viz_type}. Supported types: histogram, scatter, correlation, boxplot, pairplot"
            # Save whatever figure the chosen branch drew, then release it.
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()
            return f"Visualization saved to: {output_path}"
        except Exception as e:
            return f"Error generating visualization: {str(e)}"
    def generate_data_insights(self) -> str:
        """Generate LLM-powered insights about the loaded dataframe.

        Builds a prompt from a 5-row sample plus schema/statistics metadata
        and sends it to a chat model: OpenRouter (driven by the
        OPENROUTER_API_KEY and MODEL_NAME environment variables) when
        configured, otherwise the tool's default client (``self.client``)
        with gpt-4, which is also the fallback if OpenRouter fails.

        Returns:
            The model's analysis text, or an error string.
        """
        if self._df is None:
            return "No data loaded. Please load data first."
        try:
            # Get a sample and info about the data to send to the LLM.
            df_sample = self._df.head(5).to_string()
            # Unbox numpy scalars into builtins: json.dumps below cannot
            # serialize np.int64/np.float64 and would otherwise raise,
            # aborting this method before the prompt is ever built.
            numeric_desc = self._df.describe().to_dict() if not self._df.select_dtypes(include=[np.number]).empty else {}
            df_info = {
                "shape": self._df.shape,
                "columns": self._df.columns.tolist(),
                "dtypes": {col: str(self._df[col].dtype) for col in self._df.columns},
                "missing_values": {col: int(cnt) for col, cnt in self._df.isnull().sum().items()},
                "numeric_stats": {
                    col: {stat: (val.item() if hasattr(val, "item") else val)
                          for stat, val in col_stats.items()}
                    for col, col_stats in numeric_desc.items()
                },
            }

            prompt = f"""
            Analyze this dataset and provide key insights.
            
            Dataset Sample:
            {df_sample}
            
            Dataset Info:
            {json.dumps(df_info, indent=2)}
            
            Your task:
            1. Identify the dataset type and potential use cases
            2. Summarize the basic characteristics (rows, columns, data types)
            3. Highlight key statistics and distributions
            4. Point out missing data patterns if any
            5. Suggest potential relationships or correlations worth exploring
            6. Recommend next steps for deeper analysis
            7. Note any data quality issues or anomalies
            
            Provide a comprehensive but concise analysis with actionable insights.
            """

            # The same payload is used for every provider path.
            messages = [
                {"role": "system", "content": "You are a data science expert specializing in exploratory data analysis and deriving insights from datasets."},
                {"role": "user", "content": prompt}
            ]
            openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
            model_name = os.environ.get("MODEL_NAME", "gpt-4")  # Default to gpt-4 if MODEL_NAME is not set
            try:
                if openrouter_api_key:
                    print(f"Using OpenRouter with model: {model_name} for data insights")
                    # Imported here: `OpenAI` was previously referenced without
                    # any import, so this branch always raised NameError and
                    # silently fell through to the gpt-4 fallback.
                    from openai import OpenAI
                    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=openrouter_api_key)
                    response = client.chat.completions.create(
                        model=model_name,
                        messages=messages,
                        max_tokens=3000)
                else:  # Fall back to default OpenAI client
                    print("OpenRouter API key not found, using default OpenAI client with gpt-4")
                    response = self.client.chat.completions.create(
                        model="gpt-4",
                        messages=messages,
                        max_tokens=3000)
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error with OpenRouter: {e}")
                print("Falling back to default OpenAI client with gpt-4")
                try:
                    response = self.client.chat.completions.create(
                        model="gpt-4",
                        messages=messages,
                        max_tokens=3000)
                    return response.choices[0].message.content
                except Exception as e2:
                    return f"Error generating data insights with fallback model: {str(e2)}"
        except Exception as e:
            return f"Error analyzing data for insights: {str(e)}"
    def run(self, prompt: Union[str, Dict]) -> str:
        """Entry point: dispatch a file path or parameter dict to an analysis.

        A plain (non-JSON) string is treated as a file path and simply
        loaded. A JSON string or dict may contain: ``file_path``,
        ``analysis_type`` ('basic' | 'correlation' | 'visualization' |
        'insights'), ``columns``, ``target`` and ``viz_type``. When
        ``analysis_type`` is omitted, LLM insights are generated.

        Args:
            prompt: Raw tool argument — file path, JSON string, or dict.

        Returns:
            A JSON string (basic/correlation), a status message
            (load/visualization), model text (insights), or an error string.
        """
        print(f"Calling Data Analysis Tool with prompt: {prompt}")
        try:
            # A string prompt is either JSON parameters or a bare file path.
            if isinstance(prompt, str):
                try:
                    params = json.loads(prompt)
                except json.JSONDecodeError:  # Treat as file path
                    return self.load_data(prompt)
            else:
                params = prompt
            # Load (or reload) the requested file before any analysis.
            if 'file_path' in params:
                file_path = params['file_path']
                load_result = self.load_data(file_path)
                if "Successfully" not in load_result:
                    return load_result
            # If no analysis type is specified, generate insights.
            if 'analysis_type' not in params:
                return self.generate_data_insights()
            analysis_type = params['analysis_type'].lower()
            columns = params.get('columns', None)
            target = params.get('target', None)
            if analysis_type == 'basic':
                stats = self.generate_basic_stats(columns)
                # default=str guards against non-JSON-native values (e.g.
                # numpy scalars from pandas) crashing serialization.
                return json.dumps(stats, indent=2, default=str)
            elif analysis_type == 'correlation':
                corr_analysis = self.generate_correlation_analysis(columns)
                return json.dumps(corr_analysis, indent=2, default=str)
            elif analysis_type == 'visualization':
                viz_type = params.get('viz_type', 'histogram')
                return self.generate_visualization(viz_type, columns, target)
            elif analysis_type == 'insights':
                return self.generate_data_insights()
            else:
                return f"Unsupported analysis type: {analysis_type}. Supported types: basic, correlation, visualization, insights"
        except Exception as e:
            return f"Error executing data analysis: {str(e)}"