jzou19950715 commited on
Commit
ca192a7
·
verified ·
1 Parent(s): 47e9852

Delete tools.py

Browse files
Files changed (1) hide show
  1. tools.py +0 -404
tools.py DELETED
@@ -1,404 +0,0 @@
1
- """
2
- Advanced Data Analysis Assistant with Interactive Visualizations
3
- Integrates smolagents, GPT-4, and interactive Plotly visualizations.
4
- """
5
-
6
- import json
7
- import logging
8
- import os
9
- import sys
10
- import subprocess
11
- from dataclasses import dataclass, asdict
12
- from datetime import datetime
13
- from pathlib import Path
14
- from typing import Any, Dict, List, Optional, Tuple, Union
15
- from functools import wraps
16
-
17
# Set up logging
# NOTE(review): this runs at import time and attaches a stream handler to the
# root logger; a later basicConfig call (e.g. in the __main__ guard) is a
# silent no-op unless it passes force=True.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
23
-
24
# Auto-install required packages
def install_missing_packages():
    """Install any required third-party package that is not importable.

    Maps each pip distribution name to the module name used for the import
    probe. The two differ for scikit-learn (imported as ``sklearn``); the
    original code probed ``__import__('scikit-learn')``, which always raises
    ImportError and therefore forced a pip reinstall on every startup.
    """
    # pip name -> importable module name
    required_packages = {
        'gradio': 'gradio',
        'pandas': 'pandas',
        'smolagents': 'smolagents',
        'plotly': 'plotly',
        'numpy': 'numpy',
        'scikit-learn': 'sklearn',
        'seaborn': 'seaborn',
        'openpyxl': 'openpyxl',  # For Excel support
    }

    for pip_name, module_name in required_packages.items():
        try:
            __import__(module_name)
        except ImportError:
            logger.info(f"Installing {pip_name}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

install_missing_packages()
45
-
46
- # Now import the installed packages
47
- import gradio as gr
48
- import pandas as pd
49
- import numpy as np
50
- from smolagents import CodeAgent, LiteLLMModel
51
-
52
- from tools import (
53
- create_time_series_plot,
54
- create_correlation_heatmap,
55
- create_statistical_summary,
56
- detect_outliers,
57
- validate_dataframe,
58
- get_numeric_columns,
59
- get_temporal_columns,
60
- )
61
-
62
# Custom Exceptions
class AnalysisError(Exception):
    """Root of this module's exception hierarchy; all domain errors derive from it."""
66
-
67
class DataValidationError(AnalysisError):
    """Raised when an uploaded file or DataFrame fails validation."""
70
-
71
class APIKeyError(AnalysisError):
    """Raised when the OpenAI API key is missing or otherwise unusable."""
74
-
75
# Constants
# File extensions accepted by the uploader (checked in process_file).
SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
# LiteLLM model identifier used by DataAnalysisAssistant.
DEFAULT_MODEL = "gpt-4o-mini"
# JSON file where AnalysisHistory persists past analyses.
HISTORY_FILE = "analysis_history.json"
# Upload size cap in bytes.
MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
80
-
81
@dataclass
class VisualizationConfig:
    """Default Plotly rendering options shared by the assistant's charts."""

    width: int = 800                # figure width in pixels
    height: int = 500               # figure height in pixels
    template: str = "plotly_white"  # Plotly layout template name
    show_grid: bool = True          # whether axis grid lines are drawn
    interactive: bool = True        # interactive vs. static rendering

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain field-name -> value dict."""
        return asdict(self)
93
-
94
def error_handler(func):
    """Decorator that converts exceptions raised by *func* into error strings.

    Domain errors (AnalysisError subclasses) are logged and reported as
    analysis errors; anything else is logged with a traceback and reported
    as unexpected. On success the wrapped function's result passes through
    untouched.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except AnalysisError as exc:
            logger.error(f"Analysis error: {str(exc)}")
            return f"Analysis error: {str(exc)}"
        except Exception as exc:
            logger.exception("Unexpected error occurred")
            return f"An unexpected error occurred: {str(exc)}"
    return wrapper
107
-
108
class AnalysisHistory:
    """Persists (timestamp, query, result) records as a JSON list on disk."""

    def __init__(self, history_file: str = HISTORY_FILE):
        self.history_file = Path(history_file)
        self.history: List[Dict] = self._load_history()

    def _load_history(self) -> List[Dict]:
        """Read history from disk; any failure yields an empty history."""
        if not self.history_file.exists():
            return []
        try:
            with self.history_file.open('r') as fh:
                return json.load(fh)
        except json.JSONDecodeError as exc:
            logger.error(f"Error loading history file: {exc}")
        except Exception:
            logger.exception("Unexpected error loading history")
        return []

    def _save_history(self) -> None:
        """Write history to disk, logging (never raising) on failure."""
        try:
            with self.history_file.open('w') as fh:
                json.dump(self.history, fh, indent=2)
        except Exception as exc:
            logger.error(f"Error saving history: {exc}")

    def add_entry(self, query: str, result: str) -> None:
        """Append a timestamped record and persist immediately."""
        self.history.append({
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'result': result,
        })
        self._save_history()

    def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
        """Return up to *limit* entries, newest first."""
        ordered = sorted(self.history, key=lambda entry: entry['timestamp'], reverse=True)
        return ordered[:limit]

    def clear_history(self) -> None:
        """Drop all entries and persist the now-empty list."""
        self.history = []
        self._save_history()
159
-
160
class DataPreprocessor:
    """Handles data preprocessing and validation."""

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Validate *df*, collect metadata, and fill missing values.

        Returns:
            A tuple of (filled copy of df, metadata dict). Metadata describes
            the raw, pre-fill data (shape, NaN counts, dtypes, column groups,
            memory usage in MB).

        Raises:
            DataValidationError: if the DataFrame is empty or fails
                validate_dataframe.
        """
        if df.empty:
            raise DataValidationError("DataFrame is empty")

        # First validate the dataframe
        is_valid, error_msg = validate_dataframe(df)
        if not is_valid:
            raise DataValidationError(error_msg)

        # Generate metadata (computed before filling so NaN counts are real)
        metadata = {
            "original_shape": df.shape,
            "missing_values": df.isnull().sum().to_dict(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "numeric_columns": get_numeric_columns(df),
            "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
            "temporal_columns": get_temporal_columns(df),
            "memory_usage": df.memory_usage(deep=True).sum() / (1024 * 1024)  # MB
        }

        # Handle missing values: forward-fill, then back-fill leading NaNs.
        # fillna(method=...) is deprecated since pandas 2.1 — use ffill/bfill.
        df = df.copy()  # Avoid modifying original
        df = df.ffill().bfill()

        return df, metadata
190
-
191
class DataAnalysisAssistant:
    """Enhanced data analysis assistant with visualization capabilities.

    Wraps a smolagents CodeAgent around a LiteLLM model and the plotting
    tools imported at the top of this file.
    """

    def __init__(self, api_key: str):
        # Raises APIKeyError when no key is supplied; the key is forwarded
        # verbatim to LiteLLM.
        if not api_key:
            raise APIKeyError("API key is required")

        self.model = LiteLLMModel(
            model_id=DEFAULT_MODEL,
            api_key=api_key
        )
        # Per-assistant persistent history and default chart settings.
        self.history = AnalysisHistory()
        self.viz_config = VisualizationConfig()

        # Agent equipped with the four analysis tools; the extra imports are
        # the modules the agent-generated code is permitted to use.
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                create_time_series_plot,
                create_correlation_heatmap,
                create_statistical_summary,
                detect_outliers
            ],
            additional_authorized_imports=[
                'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
                'seaborn', 'scipy', 'statsmodels'
            ],
        )

    @error_handler
    def analyze(self, df: pd.DataFrame, query: str) -> str:
        """Perform analysis with interactive visualizations.

        Preprocesses *df*, runs the agent on a generated prompt (the frame is
        passed to the agent as 'df'), records the query/response pair in
        history, and returns the HTML-wrapped response. error_handler turns
        exceptions into error strings.
        """
        df, metadata = DataPreprocessor.preprocess_dataframe(df)
        context = self._create_analysis_context(df, metadata, query)
        response = self.agent.run(context, additional_args={"df": df})
        self.history.add_entry(query, str(response))
        return self._format_results(response)

    def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
        """Create detailed context (prompt text) for analysis.

        Combines DataFrame metadata, the tool catalogue, the user query and
        usage guidelines into one prompt string.
        """
        tools_description = """
        Available analysis tools:
        - create_time_series_plot: Create interactive time series visualizations
        - create_correlation_heatmap: Generate correlation analysis with heatmap
        - create_statistical_summary: Compute statistical summaries with visualizations
        - detect_outliers: Identify and visualize outliers
        """

        return f"""
        Analyze the following data with interactive visualizations.

        DataFrame Information:
        - Shape: {metadata['original_shape']}
        - Numeric columns: {', '.join(metadata['numeric_columns'])}
        - Categorical columns: {', '.join(metadata['categorical_columns'])}
        - Temporal columns: {', '.join(metadata['temporal_columns'])}

        {tools_description}

        User Query: {query}

        Guidelines:
        1. Use the provided analysis tools for visualizations
        2. Include clear titles and labels
        3. Handle errors gracefully
        4. Chain multiple analyses when needed
        5. Provide insights along with visualizations

        The DataFrame is available as 'df'.
        """

    def _format_results(self, response: str) -> str:
        """Format analysis results with visualizations.

        Wraps the agent response in the div styled by the interface CSS.
        """
        return f'<div class="analysis-text">{response}</div>'
264
-
265
@error_handler
def process_file(file: gr.File) -> Optional[pd.DataFrame]:
    """Process an uploaded file into a DataFrame.

    Validates presence, size (MAX_FILE_SIZE) and extension
    (SUPPORTED_FILE_TYPES) before reading with pandas.

    Raises:
        DataValidationError: for a missing, oversized, unsupported,
            or unreadable file.
    """
    if not file:
        raise DataValidationError("No file provided")

    file_path = Path(file.name)
    if file_path.stat().st_size > MAX_FILE_SIZE:
        raise DataValidationError(f"File size exceeds maximum limit of {MAX_FILE_SIZE/1024/1024}MB")

    # Compare extensions case-insensitively so e.g. "DATA.CSV" is accepted.
    suffix = file_path.suffix.lower()
    if suffix not in SUPPORTED_FILE_TYPES:
        raise DataValidationError(f"Unsupported file type: {file_path.suffix}")

    try:
        if suffix == '.csv':
            return pd.read_csv(file_path)
        else:  # .xlsx or .xls
            return pd.read_excel(file_path)
    except Exception as e:
        # Chain the underlying pandas error for easier debugging.
        raise DataValidationError(f"Error reading file: {str(e)}") from e
285
-
286
@error_handler
def analyze_data(
    file: gr.File,
    query: str,
    api_key: str,
) -> str:
    """Main analysis function for the Gradio interface.

    Validates the inputs, loads the uploaded file into a DataFrame, and
    delegates the analysis to a freshly constructed DataAnalysisAssistant.
    error_handler converts any raised error into a user-facing string.
    """
    if not api_key:
        raise APIKeyError("Please provide an API key")
    if not file:
        raise DataValidationError("Please upload a data file")

    frame = process_file(file)
    if frame is None:
        raise DataValidationError("Could not process file")

    return DataAnalysisAssistant(api_key).analyze(frame, query)
305
-
306
def create_interface() -> gr.Blocks:
    """Create enhanced Gradio interface.

    Builds the upload/query/API-key form, wires the Analyze button to
    analyze_data, and registers example queries. Returns the (unlaunched)
    Blocks app.
    """
    # Inline CSS: styles for plot containers, the analysis text emitted by
    # DataAnalysisAssistant._format_results, and error boxes.
    css = """
    .plot-container {
        margin: 20px 0;
        padding: 15px;
        border: 1px solid #e0e0e0;
        border-radius: 8px;
        background: white;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .analysis-text {
        margin: 20px 0;
        line-height: 1.6;
        font-size: 16px;
    }
    .error {
        color: #721c24;
        background-color: #f8d7da;
        padding: 10px;
        margin: 10px 0;
        border-left: 4px solid #f5c6cb;
        border-radius: 4px;
    }
    """

    with gr.Blocks(css=css) as interface:
        # Header / feature summary shown above the form.
        gr.Markdown("""
        # Advanced Data Analysis Assistant

        Upload your data and get AI-powered analysis with interactive visualizations.

        **Features:**
        - Interactive Plotly visualizations
        - GPT-4 powered analysis
        - Time series analysis
        - Statistical insights
        - Natural language queries

        **Required:** OpenAI API key
        """)

        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                file = gr.File(
                    label="Upload Data File",
                    file_types=SUPPORTED_FILE_TYPES
                )
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Analyze trends and patterns in the data with interactive visualizations",
                    lines=3
                )
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    placeholder="Your API key",
                    type="password"  # mask the key in the UI
                )
                analyze_btn = gr.Button("Analyze")

            # Right column: HTML output from analyze_data.
            with gr.Column():
                output = gr.HTML(label="Analysis Results")

        # Button click runs the full pipeline.
        analyze_btn.click(
            analyze_data,
            inputs=[file, query, api_key],
            outputs=output
        )

        # Example prompts; file and API key must still be supplied by the user.
        gr.Examples(
            examples=[
                [None, "Show trends over time with interactive visualizations", None],
                [None, "Create a comprehensive analysis of relationships between variables", None],
                [None, "Analyze distributions and statistical patterns", None],
                [None, "Generate financial metrics and performance indicators", None],
            ],
            inputs=[file, query, api_key]
        )

    return interface
386
-
387
- if __name__ == "__main__":
388
- # Configure logging for production
389
- logging.basicConfig(
390
- filename='analysis_assistant.log',
391
- level=logging.INFO,
392
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
393
- )
394
-
395
- try:
396
- interface = create_interface()
397
- interface.launch(
398
- server_name="0.0.0.0",
399
- server_port=7860,
400
- share=True
401
- )
402
- except Exception as e:
403
- logger.exception("Failed to launch interface")
404
- raise