jzou19950715 commited on
Commit
cce2b52
·
verified ·
1 Parent(s): f3139c4

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +270 -315
tools.py CHANGED
@@ -1,339 +1,294 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
-
4
  """
5
- Analysis and visualization tools for data analysis assistant.
6
- Provides a collection of tools for data analysis, statistical computations,
7
- and interactive visualizations using Plotly.
8
  """
9
 
10
- import logging
11
- from typing import Any, Dict, List, Optional, Tuple, Union
12
- from datetime import datetime
13
  from pathlib import Path
 
 
14
 
15
- import numpy as np
16
  import pandas as pd
17
- import plotly.express as px
18
- import plotly.graph_objects as go
19
- from plotly.subplots import make_subplots
20
- import seaborn as sns
21
- from scipy import stats
22
- from smolagents import tool
23
-
24
- # Configure logging
25
- logging.basicConfig(
26
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
27
- level=logging.INFO
28
  )
29
- logger = logging.getLogger(__name__)
30
-
31
- class AnalysisError(Exception):
32
- """Custom exception for analysis errors."""
33
- pass
34
-
35
- @tool
36
- def create_time_series_plot(
37
- df: pd.DataFrame,
38
- time_column: str,
39
- value_column: str,
40
- title: Optional[str] = None
41
- ) -> Dict[str, Any]:
42
- """
43
- Create an interactive time series plot.
44
 
45
- Args:
46
- df: Input DataFrame
47
- time_column: Name of the time column
48
- value_column: Name of the value column to plot
49
- title: Optional title for the plot
50
-
51
- Returns:
52
- Dict containing the plotly figure and stats
53
- """
54
- try:
55
- # Validate inputs
56
- if time_column not in df.columns or value_column not in df.columns:
57
- raise AnalysisError(f"Columns {time_column} or {value_column} not found in DataFrame")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Create plot
60
- fig = px.line(
61
- df,
62
- x=time_column,
63
- y=value_column,
64
- title=title or f"{value_column} over Time",
65
- template="plotly_white"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
-
68
- # Add hover data
69
- fig.update_traces(
70
- hovertemplate=(
71
- f"{time_column}: %{{x}}<br>"
72
- f"{value_column}: %{{y:.2f}}<br>"
73
- "<extra></extra>"
74
- )
 
 
 
 
 
 
75
  )
76
-
77
- # Calculate basic stats
78
- stats_dict = {
79
- "mean": df[value_column].mean(),
80
- "std": df[value_column].std(),
81
- "min": df[value_column].min(),
82
- "max": df[value_column].max()
83
- }
84
-
85
- return {"figure": fig, "stats": stats_dict}
86
-
87
- except Exception as e:
88
- logger.error(f"Error in create_time_series_plot: {str(e)}")
89
- raise AnalysisError(f"Failed to create time series plot: {str(e)}")
90
-
91
- @tool
92
- def create_correlation_heatmap(df: pd.DataFrame, numeric_only: bool = True) -> Dict[str, Any]:
93
- """
94
- Create an interactive correlation heatmap.
95
-
96
- Args:
97
- df: Input DataFrame
98
- numeric_only: Whether to include only numeric columns
99
-
100
- Returns:
101
- Dict containing the plotly figure and correlation matrix
102
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  try:
104
- # Select numeric columns if requested
105
- if numeric_only:
106
- df = df.select_dtypes(include=[np.number])
107
-
108
- # Calculate correlation matrix
109
- corr_matrix = df.corr()
110
-
111
- # Create heatmap
112
- fig = go.Figure(data=go.Heatmap(
113
- z=corr_matrix,
114
- x=corr_matrix.columns,
115
- y=corr_matrix.columns,
116
- colorscale='RdBu',
117
- zmid=0,
118
- text=np.round(corr_matrix, 2),
119
- texttemplate='%{text:.2f}',
120
- textfont={"size": 10},
121
- hoverongaps=False
122
- ))
123
-
124
- # Update layout
125
- fig.update_layout(
126
- title="Correlation Heatmap",
127
- template="plotly_white",
128
- width=800,
129
- height=800
130
- )
131
-
132
- return {
133
- "figure": fig,
134
- "correlation_matrix": corr_matrix.to_dict()
135
- }
136
-
137
  except Exception as e:
138
- logger.error(f"Error in create_correlation_heatmap: {str(e)}")
139
- raise AnalysisError(f"Failed to create correlation heatmap: {str(e)}")
140
-
141
- @tool
142
- def create_statistical_summary(df: pd.DataFrame, column: str) -> Dict[str, Any]:
143
- """
144
- Create statistical summary with visualization for a column.
145
-
146
- Args:
147
- df: Input DataFrame
148
- column: Column name to analyze
149
-
150
- Returns:
151
- Dict containing summary statistics and visualization
152
- """
153
- try:
154
- if column not in df.columns:
155
- raise AnalysisError(f"Column {column} not found in DataFrame")
156
-
157
- # Calculate summary statistics
158
- summary_stats = df[column].describe().to_dict()
159
 
160
- # Additional statistics
161
- if pd.api.types.is_numeric_dtype(df[column]):
162
- summary_stats.update({
163
- "skewness": stats.skew(df[column].dropna()),
164
- "kurtosis": stats.kurtosis(df[column].dropna())
165
- })
166
-
167
- # Create distribution plot
168
- fig = make_subplots(rows=2, cols=1)
 
169
 
170
- # Add histogram
171
- fig.add_trace(
172
- go.Histogram(
173
- x=df[column],
174
- name="Distribution",
175
- nbinsx=30
176
- ),
177
- row=1, col=1
178
- )
179
-
180
- # Add box plot
181
- fig.add_trace(
182
- go.Box(
183
- y=df[column],
184
- name="Box Plot"
185
- ),
186
- row=2, col=1
187
- )
188
-
189
- # Update layout
190
- fig.update_layout(
191
- title=f"Statistical Analysis of {column}",
192
- showlegend=False,
193
- template="plotly_white",
194
- height=800
195
- )
196
-
197
- return {
198
- "figure": fig,
199
- "stats": summary_stats
200
- }
201
-
202
  except Exception as e:
203
- logger.error(f"Error in create_statistical_summary: {str(e)}")
204
- raise AnalysisError(f"Failed to create statistical summary: {str(e)}")
205
-
206
- @tool
207
- def detect_outliers(
208
- df: pd.DataFrame,
209
- column: str,
210
- method: str = "zscore",
211
- threshold: float = 3.0
212
- ) -> Dict[str, Any]:
213
- """
214
- Detect outliers in a column using various methods.
215
-
216
- Args:
217
- df: Input DataFrame
218
- column: Column to analyze
219
- method: Detection method ('zscore' or 'iqr')
220
- threshold: Threshold for outlier detection
221
-
222
- Returns:
223
- Dict containing outlier indices and visualization
 
 
224
  """
225
- try:
226
- if column not in df.columns:
227
- raise AnalysisError(f"Column {column} not found in DataFrame")
228
-
229
- values = df[column].dropna()
230
 
231
- if method == "zscore":
232
- z_scores = np.abs(stats.zscore(values))
233
- outlier_mask = z_scores > threshold
234
- elif method == "iqr":
235
- Q1 = values.quantile(0.25)
236
- Q3 = values.quantile(0.75)
237
- IQR = Q3 - Q1
238
- outlier_mask = (values < (Q1 - threshold * IQR)) | (values > (Q3 + threshold * IQR))
239
- else:
240
- raise AnalysisError(f"Unknown outlier detection method: {method}")
241
-
242
- # Create visualization
243
- fig = go.Figure()
244
 
245
- # Add main scatter plot
246
- fig.add_trace(
247
- go.Scatter(
248
- x=df.index[~outlier_mask],
249
- y=values[~outlier_mask],
250
- mode='markers',
251
- name='Normal Points',
252
- marker=dict(color='blue')
253
- )
254
- )
255
-
256
- # Add outliers
257
- fig.add_trace(
258
- go.Scatter(
259
- x=df.index[outlier_mask],
260
- y=values[outlier_mask],
261
- mode='markers',
262
- name='Outliers',
263
- marker=dict(color='red')
264
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  )
266
-
267
- fig.update_layout(
268
- title=f"Outlier Detection for {column}",
269
- template="plotly_white",
270
- showlegend=True
 
 
 
 
271
  )
272
-
273
- return {
274
- "figure": fig,
275
- "outlier_indices": df.index[outlier_mask].tolist(),
276
- "outlier_count": sum(outlier_mask)
277
- }
278
-
279
- except Exception as e:
280
- logger.error(f"Error in detect_outliers: {str(e)}")
281
- raise AnalysisError(f"Failed to detect outliers: {str(e)}")
282
-
283
- # Additional utility functions
284
- def validate_dataframe(df: pd.DataFrame) -> Tuple[bool, str]:
285
- """
286
- Validate DataFrame for analysis.
287
-
288
- Args:
289
- df: Input DataFrame
290
-
291
- Returns:
292
- Tuple of (is_valid, error_message)
293
- """
294
- if df is None:
295
- return False, "DataFrame is None"
296
-
297
- if df.empty:
298
- return False, "DataFrame is empty"
299
 
300
- if df.columns.duplicated().any():
301
- return False, "DataFrame contains duplicate column names"
302
-
303
- return True, ""
304
-
305
- def get_numeric_columns(df: pd.DataFrame) -> List[str]:
306
- """Get list of numeric columns from DataFrame."""
307
- return df.select_dtypes(include=[np.number]).columns.tolist()
308
-
309
- def get_temporal_columns(df: pd.DataFrame) -> List[str]:
310
- """Get list of temporal columns from DataFrame."""
311
- temporal_cols = []
312
- for col in df.columns:
313
- try:
314
- pd.to_datetime(df[col])
315
- temporal_cols.append(col)
316
- except:
317
- continue
318
- return temporal_cols
319
 
320
  if __name__ == "__main__":
321
- # Example usage and testing
322
- logging.info("Running tools.py tests...")
323
-
324
- # Create sample data
325
- dates = pd.date_range(start='2023-01-01', periods=100, freq='D')
326
- df = pd.DataFrame({
327
- 'date': dates,
328
- 'value': np.random.normal(100, 10, 100),
329
- 'category': np.random.choice(['A', 'B', 'C'], 100)
330
- })
331
-
332
- # Test time series plot
333
- try:
334
- result = create_time_series_plot(df, 'date', 'value')
335
- logging.info("Time series plot created successfully")
336
- except Exception as e:
337
- logging.error(f"Time series plot test failed: {str(e)}")
338
-
339
- # Add more tests as needed
 
 
 
 
1
  """
2
+ Advanced Data Analysis Assistant with Interactive Visualizations
3
+ Integrates smolagents, GPT-4, and interactive Plotly visualizations.
 
4
  """
5
 
6
+ import json
7
+ import os
8
+ from dataclasses import dataclass
9
  from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Union, Tuple
11
+ from datetime import datetime
12
 
13
+ import gradio as gr
14
  import pandas as pd
15
+ from smolagents import CodeAgent, LiteLLMModel
16
+
17
+ from tools import (
18
+ create_time_series_plot,
19
+ create_correlation_heatmap,
20
+ create_statistical_summary,
21
+ detect_outliers,
22
+ validate_dataframe,
23
+ get_numeric_columns,
24
+ get_temporal_columns,
25
+ AnalysisError
26
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Constants
29
+ SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
30
+ DEFAULT_MODEL = "gpt-4o-mini"
31
+ HISTORY_FILE = "analysis_history.json"
32
+
33
+ @dataclass
34
+ class VisualizationConfig:
35
+ """Configuration for visualizations."""
36
+ width: int = 800
37
+ height: int = 500
38
+ template: str = "plotly_white"
39
+ show_grid: bool = True
40
+ interactive: bool = True
41
+
42
+ class DataPreprocessor:
43
+ """Handles data preprocessing and validation."""
44
+
45
+ @staticmethod
46
+ def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
47
+ """Preprocess the dataframe and return metadata."""
48
+ # First validate the dataframe
49
+ is_valid, error_msg = validate_dataframe(df)
50
+ if not is_valid:
51
+ raise ValueError(error_msg)
52
+
53
+ metadata = {
54
+ "original_shape": df.shape,
55
+ "missing_values": df.isnull().sum().to_dict(),
56
+ "dtypes": df.dtypes.astype(str).to_dict(),
57
+ "numeric_columns": get_numeric_columns(df),
58
+ "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
59
+ "temporal_columns": get_temporal_columns(df)
60
+ }
61
+
62
+ # Handle missing values
63
+ df = df.fillna(method='ffill').fillna(method='bfill')
64
+
65
+ return df, metadata
66
 
67
+ class AnalysisHistory:
68
+ """Manages analysis history and persistence."""
69
+
70
+ def __init__(self, history_file: str = HISTORY_FILE):
71
+ self.history_file = history_file
72
+ self.history = self._load_history()
73
+
74
+ def _load_history(self) -> List[Dict]:
75
+ if os.path.exists(self.history_file):
76
+ try:
77
+ with open(self.history_file, 'r') as f:
78
+ return json.load(f)
79
+ except:
80
+ return []
81
+ return []
82
+
83
+ def add_entry(self, query: str, result: str) -> None:
84
+ """Add new analysis entry to history."""
85
+ entry = {
86
+ 'timestamp': datetime.now().isoformat(),
87
+ 'query': query,
88
+ 'result': result
89
+ }
90
+ self.history.append(entry)
91
+
92
+ with open(self.history_file, 'w') as f:
93
+ json.dump(self.history, f)
94
+
95
+ def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
96
+ """Get recent analysis entries."""
97
+ return sorted(
98
+ self.history,
99
+ key=lambda x: x['timestamp'],
100
+ reverse=True
101
+ )[:limit]
102
+
103
+ class DataAnalysisAssistant:
104
+ """Enhanced data analysis assistant with visualization capabilities."""
105
+
106
+ def __init__(self, api_key: str):
107
+ self.model = LiteLLMModel(
108
+ model_id=DEFAULT_MODEL,
109
+ api_key=api_key
110
  )
111
+ self.history = AnalysisHistory()
112
+
113
+ self.agent = CodeAgent(
114
+ model=self.model,
115
+ tools=[
116
+ create_time_series_plot,
117
+ create_correlation_heatmap,
118
+ create_statistical_summary,
119
+ detect_outliers
120
+ ],
121
+ additional_authorized_imports=[
122
+ 'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
123
+ 'seaborn', 'scipy', 'statsmodels'
124
+ ],
125
  )
126
+
127
+ def analyze(self, df: pd.DataFrame, query: str) -> str:
128
+ """Perform analysis with interactive visualizations."""
129
+ try:
130
+ df, metadata = DataPreprocessor.preprocess_dataframe(df)
131
+ context = self._create_analysis_context(df, metadata, query)
132
+ response = self.agent.run(context, additional_args={"df": df})
133
+ self.history.add_entry(query, str(response))
134
+ return self._format_results(response)
135
+ except Exception as e:
136
+ return f"Analysis failed: {str(e)}"
137
+
138
+ def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
139
+ """Create detailed context for analysis."""
140
+ tools_description = """
141
+ Available analysis tools:
142
+ - create_time_series_plot: Create interactive time series visualizations
143
+ - create_correlation_heatmap: Generate correlation analysis with heatmap
144
+ - create_statistical_summary: Compute statistical summaries with visualizations
145
+ - detect_outliers: Identify and visualize outliers
146
+ """
147
+
148
+ return f"""
149
+ Analyze the following data with interactive visualizations.
150
+
151
+ DataFrame Information:
152
+ - Shape: {metadata['original_shape']}
153
+ - Numeric columns: {', '.join(metadata['numeric_columns'])}
154
+ - Categorical columns: {', '.join(metadata['categorical_columns'])}
155
+ - Temporal columns: {', '.join(metadata['temporal_columns'])}
156
+
157
+ {tools_description}
158
+
159
+ User Query: {query}
160
+
161
+ Guidelines:
162
+ 1. Use the provided analysis tools for visualizations
163
+ 2. Include clear titles and labels
164
+ 3. Handle errors gracefully
165
+ 4. Chain multiple analyses when needed
166
+ 5. Provide insights along with visualizations
167
+
168
+ The DataFrame is available as 'df'.
169
+ """
170
+
171
+ def _format_results(self, response: str) -> str:
172
+ """Format analysis results with visualizations."""
173
+ return f'<div class="analysis-text">{response}</div>'
174
+
175
+ def process_file(file: gr.File) -> Optional[pd.DataFrame]:
176
+ """Process uploaded file into DataFrame."""
177
+ if not file:
178
+ return None
179
+
180
  try:
181
+ file_path = Path(file.name)
182
+ if file_path.suffix == '.csv':
183
+ return pd.read_csv(file_path)
184
+ elif file_path.suffix in ('.xlsx', '.xls'):
185
+ return pd.read_excel(file_path)
186
+ else:
187
+ raise ValueError(f"Unsupported file type: {file_path.suffix}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  except Exception as e:
189
+ raise RuntimeError(f"Error reading file: {str(e)}")
190
+
191
+ def analyze_data(
192
+ file: gr.File,
193
+ query: str,
194
+ api_key: str,
195
+ ) -> str:
196
+ """Main analysis function for Gradio interface."""
197
+ if not api_key:
198
+ return "Error: Please provide an API key"
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ if not file:
201
+ return "Error: Please upload a data file"
202
+
203
+ try:
204
+ df = process_file(file)
205
+ if df is None:
206
+ return "Error: Could not process file"
207
+
208
+ assistant = DataAnalysisAssistant(api_key)
209
+ return assistant.analyze(df, query)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  except Exception as e:
212
+ return f"Error: {str(e)}"
213
+
214
+ def create_interface():
215
+ """Create enhanced Gradio interface."""
216
+ css = """
217
+ .plot-container {
218
+ margin: 20px 0;
219
+ padding: 15px;
220
+ border: 1px solid #e0e0e0;
221
+ border-radius: 8px;
222
+ background: white;
223
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
224
+ }
225
+ .analysis-text {
226
+ margin: 20px 0;
227
+ line-height: 1.6;
228
+ }
229
+ .error {
230
+ color: red;
231
+ padding: 10px;
232
+ margin: 10px 0;
233
+ border-left: 4px solid red;
234
+ }
235
  """
236
+
237
+ with gr.Blocks(css=css) as interface:
238
+ gr.Markdown("""
239
+ # Advanced Data Analysis Assistant
 
240
 
241
+ Upload your data and get AI-powered analysis with interactive visualizations.
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ **Features:**
244
+ - Interactive Plotly visualizations
245
+ - gpt-4o-mini powered analysis
246
+ - Time series analysis
247
+ - Statistical insights
248
+ - Natural language queries
249
+
250
+ **Required:** OpenAI API key
251
+ """)
252
+
253
+ with gr.Row():
254
+ with gr.Column():
255
+ file = gr.File(
256
+ label="Upload Data File",
257
+ file_types=SUPPORTED_FILE_TYPES
258
+ )
259
+ query = gr.Textbox(
260
+ label="What would you like to analyze?",
261
+ placeholder="e.g., Analyze trends and patterns in the data with interactive visualizations",
262
+ lines=3
263
+ )
264
+ api_key = gr.Textbox(
265
+ label="OpenAI API Key",
266
+ placeholder="Your API key",
267
+ type="password"
268
+ )
269
+ analyze_btn = gr.Button("Analyze")
270
+
271
+ with gr.Column():
272
+ output = gr.HTML(label="Analysis Results")
273
+
274
+ analyze_btn.click(
275
+ analyze_data,
276
+ inputs=[file, query, api_key],
277
+ outputs=output
278
  )
279
+
280
+ gr.Examples(
281
+ examples=[
282
+ [None, "Show trends over time with interactive visualizations"],
283
+ [None, "Create a comprehensive analysis of relationships between variables"],
284
+ [None, "Analyze distributions and statistical patterns"],
285
+ [None, "Generate financial metrics and performance indicators"],
286
+ ],
287
+ inputs=[file, query]
288
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
+ return interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  if __name__ == "__main__":
293
+ interface = create_interface()
294
+ interface.launch()