jzou19950715 committed on
Commit
829203b
·
verified ·
1 Parent(s): cce2b52

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +190 -80
tools.py CHANGED
@@ -4,14 +4,49 @@ Integrates smolagents, GPT-4, and interactive Plotly visualizations.
4
  """
5
 
6
  import json
 
7
  import os
8
- from dataclasses import dataclass
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Union, Tuple
11
  from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
13
  import gradio as gr
14
  import pandas as pd
 
15
  from smolagents import CodeAgent, LiteLLMModel
16
 
17
  from tools import (
@@ -22,13 +57,26 @@ from tools import (
22
  validate_dataframe,
23
  get_numeric_columns,
24
  get_temporal_columns,
25
- AnalysisError
26
  )
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Constants
29
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
30
  DEFAULT_MODEL = "gpt-4o-mini"
31
  HISTORY_FILE = "analysis_history.json"
 
32
 
33
  @dataclass
34
  class VisualizationConfig:
@@ -38,48 +86,54 @@ class VisualizationConfig:
38
  template: str = "plotly_white"
39
  show_grid: bool = True
40
  interactive: bool = True
41
-
42
- class DataPreprocessor:
43
- """Handles data preprocessing and validation."""
44
 
45
- @staticmethod
46
- def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
47
- """Preprocess the dataframe and return metadata."""
48
- # First validate the dataframe
49
- is_valid, error_msg = validate_dataframe(df)
50
- if not is_valid:
51
- raise ValueError(error_msg)
52
-
53
- metadata = {
54
- "original_shape": df.shape,
55
- "missing_values": df.isnull().sum().to_dict(),
56
- "dtypes": df.dtypes.astype(str).to_dict(),
57
- "numeric_columns": get_numeric_columns(df),
58
- "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
59
- "temporal_columns": get_temporal_columns(df)
60
- }
61
-
62
- # Handle missing values
63
- df = df.fillna(method='ffill').fillna(method='bfill')
64
-
65
- return df, metadata
66
 
67
  class AnalysisHistory:
68
  """Manages analysis history and persistence."""
69
 
70
  def __init__(self, history_file: str = HISTORY_FILE):
71
- self.history_file = history_file
72
- self.history = self._load_history()
73
 
74
  def _load_history(self) -> List[Dict]:
75
- if os.path.exists(self.history_file):
 
76
  try:
77
- with open(self.history_file, 'r') as f:
78
  return json.load(f)
79
- except:
 
 
 
 
80
  return []
81
  return []
82
-
 
 
 
 
 
 
 
 
83
  def add_entry(self, query: str, result: str) -> None:
84
  """Add new analysis entry to history."""
85
  entry = {
@@ -88,10 +142,8 @@ class AnalysisHistory:
88
  'result': result
89
  }
90
  self.history.append(entry)
91
-
92
- with open(self.history_file, 'w') as f:
93
- json.dump(self.history, f)
94
-
95
  def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
96
  """Get recent analysis entries."""
97
  return sorted(
@@ -99,16 +151,56 @@ class AnalysisHistory:
99
  key=lambda x: x['timestamp'],
100
  reverse=True
101
  )[:limit]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  class DataAnalysisAssistant:
104
  """Enhanced data analysis assistant with visualization capabilities."""
105
 
106
  def __init__(self, api_key: str):
 
 
 
107
  self.model = LiteLLMModel(
108
  model_id=DEFAULT_MODEL,
109
  api_key=api_key
110
  )
111
  self.history = AnalysisHistory()
 
112
 
113
  self.agent = CodeAgent(
114
  model=self.model,
@@ -123,18 +215,16 @@ class DataAnalysisAssistant:
123
  'seaborn', 'scipy', 'statsmodels'
124
  ],
125
  )
126
-
 
127
  def analyze(self, df: pd.DataFrame, query: str) -> str:
128
  """Perform analysis with interactive visualizations."""
129
- try:
130
- df, metadata = DataPreprocessor.preprocess_dataframe(df)
131
- context = self._create_analysis_context(df, metadata, query)
132
- response = self.agent.run(context, additional_args={"df": df})
133
- self.history.add_entry(query, str(response))
134
- return self._format_results(response)
135
- except Exception as e:
136
- return f"Analysis failed: {str(e)}"
137
-
138
  def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
139
  """Create detailed context for analysis."""
140
  tools_description = """
@@ -172,22 +262,28 @@ class DataAnalysisAssistant:
172
  """Format analysis results with visualizations."""
173
  return f'<div class="analysis-text">{response}</div>'
174
 
 
175
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
176
  """Process uploaded file into DataFrame."""
177
  if not file:
178
- return None
179
-
 
 
 
 
 
 
 
180
  try:
181
- file_path = Path(file.name)
182
  if file_path.suffix == '.csv':
183
  return pd.read_csv(file_path)
184
- elif file_path.suffix in ('.xlsx', '.xls'):
185
  return pd.read_excel(file_path)
186
- else:
187
- raise ValueError(f"Unsupported file type: {file_path.suffix}")
188
  except Exception as e:
189
- raise RuntimeError(f"Error reading file: {str(e)}")
190
 
 
191
  def analyze_data(
192
  file: gr.File,
193
  query: str,
@@ -195,23 +291,19 @@ def analyze_data(
195
  ) -> str:
196
  """Main analysis function for Gradio interface."""
197
  if not api_key:
198
- return "Error: Please provide an API key"
199
-
200
  if not file:
201
- return "Error: Please upload a data file"
202
-
203
- try:
204
- df = process_file(file)
205
- if df is None:
206
- return "Error: Could not process file"
207
-
208
- assistant = DataAnalysisAssistant(api_key)
209
- return assistant.analyze(df, query)
210
-
211
- except Exception as e:
212
- return f"Error: {str(e)}"
213
 
214
- def create_interface():
215
  """Create enhanced Gradio interface."""
216
  css = """
217
  .plot-container {
@@ -225,12 +317,15 @@ def create_interface():
225
  .analysis-text {
226
  margin: 20px 0;
227
  line-height: 1.6;
 
228
  }
229
  .error {
230
- color: red;
 
231
  padding: 10px;
232
  margin: 10px 0;
233
- border-left: 4px solid red;
 
234
  }
235
  """
236
 
@@ -242,7 +337,7 @@ def create_interface():
242
 
243
  **Features:**
244
  - Interactive Plotly visualizations
245
- - gpt-4o-mini powered analysis
246
  - Time series analysis
247
  - Statistical insights
248
  - Natural language queries
@@ -279,16 +374,31 @@ def create_interface():
279
 
280
  gr.Examples(
281
  examples=[
282
- [None, "Show trends over time with interactive visualizations"],
283
- [None, "Create a comprehensive analysis of relationships between variables"],
284
- [None, "Analyze distributions and statistical patterns"],
285
- [None, "Generate financial metrics and performance indicators"],
286
  ],
287
- inputs=[file, query]
288
  )
289
 
290
  return interface
291
 
292
  if __name__ == "__main__":
293
- interface = create_interface()
294
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import json
7
+ import logging
8
  import os
9
+ import sys
10
+ import subprocess
11
+ from dataclasses import dataclass, asdict
12
  from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Optional, Tuple, Union
15
+ from functools import wraps
16
+
17
+ # Set up logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Auto-install required packages
25
+ def install_missing_packages():
26
+ required_packages = [
27
+ 'gradio',
28
+ 'pandas',
29
+ 'smolagents',
30
+ 'plotly',
31
+ 'numpy',
32
+ 'scikit-learn',
33
+ 'seaborn',
34
+ 'openpyxl' # For Excel support
35
+ ]
36
+
37
+ for package in required_packages:
38
+ try:
39
+ __import__(package)
40
+ except ImportError:
41
+ logger.info(f"Installing {package}...")
42
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
43
+
44
+ install_missing_packages()
45
 
46
+ # Now import the installed packages
47
  import gradio as gr
48
  import pandas as pd
49
+ import numpy as np
50
  from smolagents import CodeAgent, LiteLLMModel
51
 
52
  from tools import (
 
57
  validate_dataframe,
58
  get_numeric_columns,
59
  get_temporal_columns,
 
60
  )
61
 
62
# Custom Exceptions
class AnalysisError(Exception):
    """Base exception for analysis errors."""


class DataValidationError(AnalysisError):
    """Exception for data validation errors."""


class APIKeyError(AnalysisError):
    """Exception for API key related errors."""
74
+
75
  # Constants
76
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
77
  DEFAULT_MODEL = "gpt-4o-mini"
78
  HISTORY_FILE = "analysis_history.json"
79
+ MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
80
 
81
  @dataclass
82
  class VisualizationConfig:
 
86
  template: str = "plotly_white"
87
  show_grid: bool = True
88
  interactive: bool = True
 
 
 
89
 
90
+ def to_dict(self) -> Dict[str, Any]:
91
+ """Convert config to dictionary."""
92
+ return asdict(self)
93
+
94
def error_handler(func):
    """Decorator for handling errors gracefully."""

    @wraps(func)
    def safe_call(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except AnalysisError as exc:
            # Known, expected failures: log and surface the same message.
            message = f"Analysis error: {str(exc)}"
            logger.error(message)
            return message
        except Exception as exc:
            # Anything else: log the traceback, return a generic message.
            logger.exception("Unexpected error occurred")
            return f"An unexpected error occurred: {str(exc)}"

    return safe_call
 
 
 
 
107
 
108
class AnalysisHistory:
    """Manages analysis history and persistence."""

    def __init__(self, history_file: str = HISTORY_FILE):
        self.history_file = Path(history_file)
        self.history: List[Dict] = self._load_history()

    def _load_history(self) -> List[Dict]:
        """Load history from file with error handling."""
        # Guard clause: nothing persisted yet.
        if not self.history_file.exists():
            return []
        try:
            with self.history_file.open('r') as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            logger.error(f"Error loading history file: {e}")
        except Exception as e:
            logger.exception("Unexpected error loading history")
        return []

    def _save_history(self) -> None:
        """Save history to file with error handling."""
        try:
            with self.history_file.open('w') as f:
                json.dump(self.history, f, indent=2)
        except Exception as e:
            # Best-effort persistence: never let a save failure break analysis.
            logger.error(f"Error saving history: {e}")

    def add_entry(self, query: str, result: str) -> None:
        """Add new analysis entry to history."""
        # NOTE(review): the 'timestamp'/'query' keys were elided in the diff
        # view and are reconstructed from get_recent_analyses' sort key —
        # confirm against the original source.
        entry = {
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'result': result
        }
        self.history.append(entry)
        self._save_history()

    def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
        """Get recent analysis entries."""
        newest_first = sorted(self.history, key=lambda x: x['timestamp'], reverse=True)
        return newest_first[:limit]

    def clear_history(self) -> None:
        """Clear analysis history."""
        self.history = []
        self._save_history()
159
+
160
class DataPreprocessor:
    """Handles data preprocessing and validation."""

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Preprocess the dataframe and return metadata.

        Args:
            df: Input data to validate and clean.

        Returns:
            A ``(cleaned_df, metadata)`` tuple; metadata describes the
            ORIGINAL frame (shape, dtypes, missing values, column roles).

        Raises:
            DataValidationError: If the frame is empty or fails validation.
        """
        if df.empty:
            raise DataValidationError("DataFrame is empty")

        # First validate the dataframe
        is_valid, error_msg = validate_dataframe(df)
        if not is_valid:
            raise DataValidationError(error_msg)

        # Generate metadata
        metadata = {
            "original_shape": df.shape,
            "missing_values": df.isnull().sum().to_dict(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "numeric_columns": get_numeric_columns(df),
            "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
            "temporal_columns": get_temporal_columns(df),
            "memory_usage": df.memory_usage(deep=True).sum() / (1024 * 1024)  # MB
        }

        # Handle missing values. fillna(method='ffill'/'bfill') is deprecated
        # (removed in pandas 3.x); .ffill()/.bfill() are the supported
        # equivalents and already return a new frame, so no explicit .copy()
        # is needed to avoid mutating the caller's data.
        df = df.ffill().bfill()

        return df, metadata
190
 
191
  class DataAnalysisAssistant:
192
  """Enhanced data analysis assistant with visualization capabilities."""
193
 
194
  def __init__(self, api_key: str):
195
+ if not api_key:
196
+ raise APIKeyError("API key is required")
197
+
198
  self.model = LiteLLMModel(
199
  model_id=DEFAULT_MODEL,
200
  api_key=api_key
201
  )
202
  self.history = AnalysisHistory()
203
+ self.viz_config = VisualizationConfig()
204
 
205
  self.agent = CodeAgent(
206
  model=self.model,
 
215
  'seaborn', 'scipy', 'statsmodels'
216
  ],
217
  )
218
+
219
+ @error_handler
220
  def analyze(self, df: pd.DataFrame, query: str) -> str:
221
  """Perform analysis with interactive visualizations."""
222
+ df, metadata = DataPreprocessor.preprocess_dataframe(df)
223
+ context = self._create_analysis_context(df, metadata, query)
224
+ response = self.agent.run(context, additional_args={"df": df})
225
+ self.history.add_entry(query, str(response))
226
+ return self._format_results(response)
227
+
 
 
 
228
  def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
229
  """Create detailed context for analysis."""
230
  tools_description = """
 
262
  """Format analysis results with visualizations."""
263
  return f'<div class="analysis-text">{response}</div>'
264
 
265
@error_handler
def process_file(file: gr.File) -> Optional[pd.DataFrame]:
    """Process uploaded file into DataFrame.

    Args:
        file: Gradio file wrapper; ``file.name`` is a path on disk.

    Returns:
        The parsed DataFrame (error_handler converts raised errors into
        error-message strings for the UI).

    Raises:
        DataValidationError: If no file is given, it is too large, has an
            unsupported extension, or cannot be parsed.
    """
    if not file:
        raise DataValidationError("No file provided")

    file_path = Path(file.name)
    if file_path.stat().st_size > MAX_FILE_SIZE:
        raise DataValidationError(f"File size exceeds maximum limit of {MAX_FILE_SIZE/1024/1024}MB")

    # Compare case-insensitively so ".CSV"/".XLSX" uploads are accepted too.
    suffix = file_path.suffix.lower()
    if suffix not in SUPPORTED_FILE_TYPES:
        raise DataValidationError(f"Unsupported file type: {file_path.suffix}")

    try:
        if suffix == '.csv':
            return pd.read_csv(file_path)
        else:  # .xlsx or .xls
            return pd.read_excel(file_path)
    except Exception as e:
        # Chain the original parse error so the root cause stays visible.
        raise DataValidationError(f"Error reading file: {str(e)}") from e
285
 
286
@error_handler
def analyze_data(
    file: gr.File,
    query: str,
    api_key: str
) -> str:
    """Main analysis function for Gradio interface.

    Returns the formatted analysis HTML, or an error-message string
    (produced by the error_handler decorator when a validation step raises).
    """
    if not api_key:
        raise APIKeyError("Please provide an API key")

    if not file:
        raise DataValidationError("Please upload a data file")

    df = process_file(file)
    # process_file is itself wrapped in error_handler, so on failure it
    # returns an error *string* instead of raising. The original `df is None`
    # check missed that and passed the string on to the assistant; treat
    # anything that is not a DataFrame as a failed load.
    if not isinstance(df, pd.DataFrame):
        raise DataValidationError(df if isinstance(df, str) else "Could not process file")

    assistant = DataAnalysisAssistant(api_key)
    return assistant.analyze(df, query)
 
 
 
 
305
 
306
+ def create_interface() -> gr.Blocks:
307
  """Create enhanced Gradio interface."""
308
  css = """
309
  .plot-container {
 
317
  .analysis-text {
318
  margin: 20px 0;
319
  line-height: 1.6;
320
+ font-size: 16px;
321
  }
322
  .error {
323
+ color: #721c24;
324
+ background-color: #f8d7da;
325
  padding: 10px;
326
  margin: 10px 0;
327
+ border-left: 4px solid #f5c6cb;
328
+ border-radius: 4px;
329
  }
330
  """
331
 
 
337
 
338
  **Features:**
339
  - Interactive Plotly visualizations
340
+ - GPT-4 powered analysis
341
  - Time series analysis
342
  - Statistical insights
343
  - Natural language queries
 
374
 
375
  gr.Examples(
376
  examples=[
377
+ [None, "Show trends over time with interactive visualizations", None],
378
+ [None, "Create a comprehensive analysis of relationships between variables", None],
379
+ [None, "Analyze distributions and statistical patterns", None],
380
+ [None, "Generate financial metrics and performance indicators", None],
381
  ],
382
+ inputs=[file, query, api_key]
383
  )
384
 
385
  return interface
386
 
387
if __name__ == "__main__":
    # Reconfigure logging for production. logging.basicConfig is a silent
    # no-op once the root logger has handlers — and it was already configured
    # at import time above — so force=True is required for this file handler
    # to actually take effect.
    logging.basicConfig(
        filename='analysis_assistant.log',
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True
    )

    try:
        interface = create_interface()
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True
        )
    except Exception:
        # Log the full traceback, then re-raise so the process exits nonzero.
        logger.exception("Failed to launch interface")
        raise