jzou19950715 committed on
Commit
962e9ea
·
verified ·
1 Parent(s): 48a1160

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -288
app.py CHANGED
@@ -3,8 +3,6 @@ Advanced Data Analysis Assistant with Interactive Visualizations
3
  Integrates smolagents, GPT-4, and interactive Plotly visualizations.
4
  """
5
 
6
- import base64
7
- import io
8
  import json
9
  import os
10
  from dataclasses import dataclass
@@ -12,14 +10,20 @@ from pathlib import Path
12
  from typing import Any, Dict, List, Optional, Union, Tuple
13
 
14
  import gradio as gr
15
- import numpy as np
16
  import pandas as pd
17
- import plotly.express as px
18
- import plotly.graph_objects as go
19
- from plotly.subplots import make_subplots
20
- import seaborn as sns
21
- from smolagents import CodeAgent, LiteLLMModel, tool
22
- from datetime import datetime, timedelta
 
 
 
 
 
 
 
23
 
24
  # Constants
25
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
@@ -41,131 +45,25 @@ class DataPreprocessor:
41
  @staticmethod
42
  def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
43
  """Preprocess the dataframe and return metadata."""
 
 
 
 
 
44
  metadata = {
45
  "original_shape": df.shape,
46
  "missing_values": df.isnull().sum().to_dict(),
47
  "dtypes": df.dtypes.astype(str).to_dict(),
48
- "numeric_columns": df.select_dtypes(include=[np.number]).columns.tolist(),
49
  "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
50
- "temporal_columns": []
51
  }
52
 
53
- # Handle date/time columns
54
- for col in df.columns:
55
- try:
56
- pd.to_datetime(df[col])
57
- metadata["temporal_columns"].append(col)
58
- df[col] = pd.to_datetime(df[col])
59
- except:
60
- continue
61
-
62
  # Handle missing values
63
  df = df.fillna(method='ffill').fillna(method='bfill')
64
 
65
  return df, metadata
66
 
67
- class CodeExecutionEnvironment:
68
- """Safe environment for executing analysis code."""
69
-
70
- def __init__(self, visualization_config: Optional[VisualizationConfig] = None):
71
- self.viz_config = visualization_config or VisualizationConfig()
72
- self.globals = {
73
- 'pd': pd,
74
- 'np': np,
75
- 'px': px,
76
- 'go': go,
77
- 'make_subplots': make_subplots,
78
- 'sns': sns
79
- }
80
- self.locals = {}
81
-
82
- def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
83
- """Execute code and capture outputs including visualizations."""
84
- if df is not None:
85
- self.globals['df'] = df
86
-
87
- output_buffer = io.StringIO()
88
- import sys
89
- sys.stdout = output_buffer
90
-
91
- result = {
92
- 'output': '',
93
- 'plotly_html': [],
94
- 'error': None,
95
- 'dataframe_updates': None
96
- }
97
-
98
- try:
99
- exec(code, self.globals, self.locals)
100
-
101
- # Capture Plotly figures
102
- for var_name, value in self.locals.items():
103
- if isinstance(value, (go.Figure, px.Figure)):
104
- # Apply visualization config
105
- value.update_layout(
106
- width=self.viz_config.width,
107
- height=self.viz_config.height,
108
- template=self.viz_config.template,
109
- showgrid=self.viz_config.show_grid
110
- )
111
- html = value.to_html(
112
- include_plotlyjs=True,
113
- full_html=False,
114
- config={'displayModeBar': True}
115
- )
116
- result['plotly_html'].append(html)
117
-
118
- # Capture DataFrame updates
119
- if 'df' in self.locals and id(self.locals['df']) != id(df):
120
- result['dataframe_updates'] = self.locals['df']
121
-
122
- result['output'] = output_buffer.getvalue()
123
-
124
- except Exception as e:
125
- result['error'] = f"Error executing code: {str(e)}"
126
-
127
- finally:
128
- sys.stdout = sys.__stdout__
129
- output_buffer.close()
130
-
131
- return result
132
-
133
- class AnalysisHistory:
134
- """Manages analysis history and persistence."""
135
-
136
- def __init__(self, history_file: str = HISTORY_FILE):
137
- self.history_file = history_file
138
- self.history = self._load_history()
139
-
140
- def _load_history(self) -> List[Dict]:
141
- if os.path.exists(self.history_file):
142
- try:
143
- with open(self.history_file, 'r') as f:
144
- return json.load(f)
145
- except:
146
- return []
147
- return []
148
-
149
- def add_entry(self, query: str, result: str) -> None:
150
- """Add new analysis entry to history."""
151
- entry = {
152
- 'timestamp': datetime.now().isoformat(),
153
- 'query': query,
154
- 'result': result
155
- }
156
- self.history.append(entry)
157
-
158
- with open(self.history_file, 'w') as f:
159
- json.dump(self.history, f)
160
-
161
- def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
162
- """Get recent analysis entries."""
163
- return sorted(
164
- self.history,
165
- key=lambda x: x['timestamp'],
166
- reverse=True
167
- )[:limit]
168
-
169
  class DataAnalysisAssistant:
170
  """Enhanced data analysis assistant with visualization capabilities."""
171
 
@@ -174,12 +72,17 @@ class DataAnalysisAssistant:
174
  model_id=DEFAULT_MODEL,
175
  api_key=api_key
176
  )
177
- self.code_env = CodeExecutionEnvironment()
178
  self.history = AnalysisHistory()
179
 
180
- # Initialize agent with tools
181
  self.agent = CodeAgent(
182
  model=self.model,
 
 
 
 
 
 
183
  additional_authorized_imports=[
184
  'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
185
  'seaborn', 'scipy', 'statsmodels'
@@ -188,29 +91,34 @@ class DataAnalysisAssistant:
188
 
189
  def analyze(self, df: pd.DataFrame, query: str) -> str:
190
  """Perform analysis with interactive visualizations."""
191
- # Preprocess data
192
- df, metadata = DataPreprocessor.preprocess_dataframe(df)
193
-
194
- # Create context for the agent
195
- context = self._create_analysis_context(df, metadata, query)
196
-
197
  try:
198
- # Get analysis plan
199
- response = self.agent.run(context, additional_args={"df": df})
 
 
 
200
 
201
- # Extract and execute code blocks
202
- results = self._execute_analysis(response, df)
203
 
204
  # Save to history
205
  self.history.add_entry(query, str(response))
206
 
207
- return self._format_results(response, results)
208
 
209
  except Exception as e:
210
  return f"Analysis failed: {str(e)}"
211
 
212
  def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
213
  """Create detailed context for analysis."""
 
 
 
 
 
 
 
 
214
  return f"""
215
  Analyze the following data with interactive visualizations.
216
 
@@ -220,174 +128,39 @@ class DataAnalysisAssistant:
220
  - Categorical columns: {', '.join(metadata['categorical_columns'])}
221
  - Temporal columns: {', '.join(metadata['temporal_columns'])}
222
 
 
 
223
  User Query: {query}
224
 
225
  Guidelines:
226
- 1. Use Plotly for interactive visualizations
227
- 2. Store figures in variables named 'fig'
228
- 3. Include clear titles and labels
229
- 4. Add hover information
230
- 5. Use color effectively
231
- 6. Handle errors gracefully
232
 
233
  The DataFrame is available as 'df'.
234
  """
235
-
236
- def _execute_analysis(self, response: str, df: pd.DataFrame) -> List[Dict]:
237
- """Execute code blocks from analysis."""
238
- import re
239
- results = []
240
-
241
- # Extract code blocks
242
- code_blocks = re.findall(r'```python\n(.*?)```', str(response), re.DOTALL)
243
-
244
- for code in code_blocks:
245
- result = self.code_env.execute(code, df)
246
- results.append(result)
247
-
248
- return results
249
-
250
- def _format_results(self, response: str, results: List[Dict]) -> str:
251
  """Format analysis results with visualizations."""
252
- output_parts = []
253
-
254
- # Add analysis text
255
- analysis_text = str(response).replace("```python", "").replace("```", "")
256
- output_parts.append(f'<div class="analysis-text">{analysis_text}</div>')
257
-
258
- # Add execution results
259
- for result in results:
260
- if result['error']:
261
- output_parts.append(f'<div class="error">{result["error"]}</div>')
262
- else:
263
- if result['output']:
264
- output_parts.append(f'<pre>{result["output"]}</pre>')
265
- for html in result['plotly_html']:
266
- output_parts.append(
267
- f'<div class="plot-container">{html}</div>'
268
- )
269
-
270
- return "\n".join(output_parts)
271
 
272
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
273
  """Process uploaded file into DataFrame."""
274
- if not file:
275
- return None
276
-
277
- try:
278
- file_path = Path(file.name)
279
- if file_path.suffix == '.csv':
280
- return pd.read_csv(file_path)
281
- elif file_path.suffix in ('.xlsx', '.xls'):
282
- return pd.read_excel(file_path)
283
- else:
284
- raise ValueError(f"Unsupported file type: {file_path.suffix}")
285
- except Exception as e:
286
- raise RuntimeError(f"Error reading file: {str(e)}")
287
 
288
- def analyze_data(
289
- file: gr.File,
290
- query: str,
291
- api_key: str,
292
- ) -> str:
293
  """Main analysis function for Gradio interface."""
294
- if not api_key:
295
- return "Error: Please provide an API key"
296
-
297
- if not file:
298
- return "Error: Please upload a data file"
299
-
300
- try:
301
- # Process file
302
- df = process_file(file)
303
- if df is None:
304
- return "Error: Could not process file"
305
-
306
- # Create assistant and run analysis
307
- assistant = DataAnalysisAssistant(api_key)
308
- return assistant.analyze(df, query)
309
-
310
- except Exception as e:
311
- return f"Error: {str(e)}"
312
 
313
  def create_interface():
314
  """Create enhanced Gradio interface."""
315
- css = """
316
- .plot-container {
317
- margin: 20px 0;
318
- padding: 15px;
319
- border: 1px solid #e0e0e0;
320
- border-radius: 8px;
321
- background: white;
322
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
323
- }
324
- .analysis-text {
325
- margin: 20px 0;
326
- line-height: 1.6;
327
- }
328
- .error {
329
- color: red;
330
- padding: 10px;
331
- margin: 10px 0;
332
- border-left: 4px solid red;
333
- }
334
- """
335
-
336
- with gr.Blocks(css=css) as interface:
337
- gr.Markdown("""
338
- # Advanced Data Analysis Assistant
339
-
340
- Upload your data and get AI-powered analysis with interactive visualizations.
341
-
342
- **Features:**
343
- - Interactive Plotly visualizations
344
- - gpt-4o-mini powered analysis
345
- - Time series analysis
346
- - Statistical insights
347
- - Natural language queries
348
-
349
- **Required:** OpenAI API key
350
- """)
351
-
352
- with gr.Row():
353
- with gr.Column():
354
- file = gr.File(
355
- label="Upload Data File",
356
- file_types=SUPPORTED_FILE_TYPES
357
- )
358
- query = gr.Textbox(
359
- label="What would you like to analyze?",
360
- placeholder="e.g., Analyze trends and patterns in the data with interactive visualizations",
361
- lines=3
362
- )
363
- api_key = gr.Textbox(
364
- label="OpenAI API Key",
365
- placeholder="Your API key",
366
- type="password"
367
- )
368
- analyze_btn = gr.Button("Analyze")
369
-
370
- with gr.Column():
371
- output = gr.HTML(label="Analysis Results")
372
-
373
- analyze_btn.click(
374
- analyze_data,
375
- inputs=[file, query, api_key],
376
- outputs=output
377
- )
378
-
379
- # Add examples
380
- gr.Examples(
381
- examples=[
382
- [None, "Show trends over time with interactive visualizations"],
383
- [None, "Create a comprehensive analysis of relationships between variables"],
384
- [None, "Analyze distributions and statistical patterns"],
385
- [None, "Generate financial metrics and performance indicators"],
386
- ],
387
- inputs=[file, query]
388
- )
389
-
390
- return interface
391
 
392
  if __name__ == "__main__":
393
  interface = create_interface()
 
3
  Integrates smolagents, GPT-4, and interactive Plotly visualizations.
4
  """
5
 
 
 
6
  import json
7
  import os
8
  from dataclasses import dataclass
 
10
  from typing import Any, Dict, List, Optional, Union, Tuple
11
 
12
  import gradio as gr
 
13
  import pandas as pd
14
+ from smolagents import CodeAgent, LiteLLMModel
15
+
16
+ # Import our custom tools
17
+ from tools import (
18
+ create_time_series_plot,
19
+ create_correlation_heatmap,
20
+ create_statistical_summary,
21
+ detect_outliers,
22
+ validate_dataframe,
23
+ get_numeric_columns,
24
+ get_temporal_columns,
25
+ AnalysisError
26
+ )
27
 
28
  # Constants
29
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
 
45
  @staticmethod
46
  def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
47
  """Preprocess the dataframe and return metadata."""
48
+ # First validate the dataframe
49
+ is_valid, error_msg = validate_dataframe(df)
50
+ if not is_valid:
51
+ raise ValueError(error_msg)
52
+
53
  metadata = {
54
  "original_shape": df.shape,
55
  "missing_values": df.isnull().sum().to_dict(),
56
  "dtypes": df.dtypes.astype(str).to_dict(),
57
+ "numeric_columns": get_numeric_columns(df),
58
  "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
59
+ "temporal_columns": get_temporal_columns(df)
60
  }
61
 
 
 
 
 
 
 
 
 
 
62
  # Handle missing values
63
  df = df.fillna(method='ffill').fillna(method='bfill')
64
 
65
  return df, metadata
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  class DataAnalysisAssistant:
68
  """Enhanced data analysis assistant with visualization capabilities."""
69
 
 
72
  model_id=DEFAULT_MODEL,
73
  api_key=api_key
74
  )
 
75
  self.history = AnalysisHistory()
76
 
77
+ # Initialize agent with tools and our custom analysis tools
78
  self.agent = CodeAgent(
79
  model=self.model,
80
+ tools=[
81
+ create_time_series_plot,
82
+ create_correlation_heatmap,
83
+ create_statistical_summary,
84
+ detect_outliers
85
+ ],
86
  additional_authorized_imports=[
87
  'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
88
  'seaborn', 'scipy', 'statsmodels'
 
91
 
92
  def analyze(self, df: pd.DataFrame, query: str) -> str:
93
  """Perform analysis with interactive visualizations."""
 
 
 
 
 
 
94
  try:
95
+ # Preprocess data
96
+ df, metadata = DataPreprocessor.preprocess_dataframe(df)
97
+
98
+ # Create context for the agent
99
+ context = self._create_analysis_context(df, metadata, query)
100
 
101
+ # Get analysis plan and execute
102
+ response = self.agent.run(context, additional_args={"df": df})
103
 
104
  # Save to history
105
  self.history.add_entry(query, str(response))
106
 
107
+ return self._format_results(response)
108
 
109
  except Exception as e:
110
  return f"Analysis failed: {str(e)}"
111
 
112
  def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
113
  """Create detailed context for analysis."""
114
+ tools_description = """
115
+ Available analysis tools:
116
+ - create_time_series_plot: Create interactive time series visualizations
117
+ - create_correlation_heatmap: Generate correlation analysis with heatmap
118
+ - create_statistical_summary: Compute statistical summaries with visualizations
119
+ - detect_outliers: Identify and visualize outliers
120
+ """
121
+
122
  return f"""
123
  Analyze the following data with interactive visualizations.
124
 
 
128
  - Categorical columns: {', '.join(metadata['categorical_columns'])}
129
  - Temporal columns: {', '.join(metadata['temporal_columns'])}
130
 
131
+ {tools_description}
132
+
133
  User Query: {query}
134
 
135
  Guidelines:
136
+ 1. Use the provided analysis tools for visualizations
137
+ 2. Include clear titles and labels
138
+ 3. Handle errors gracefully
139
+ 4. Chain multiple analyses when needed
140
+ 5. Provide insights along with visualizations
 
141
 
142
  The DataFrame is available as 'df'.
143
  """
144
+
145
+ def _format_results(self, response: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  """Format analysis results with visualizations."""
147
+ return f'<div class="analysis-text">{response}</div>'
148
+
149
+ class AnalysisHistory:
150
+ """Manages analysis history and persistence."""
151
+ [Previous AnalysisHistory implementation remains the same]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
154
  """Process uploaded file into DataFrame."""
155
+ [Previous process_file implementation remains the same]
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ def analyze_data(file: gr.File, query: str, api_key: str) -> str:
 
 
 
 
158
  """Main analysis function for Gradio interface."""
159
+ [Previous analyze_data implementation remains the same]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  def create_interface():
162
  """Create enhanced Gradio interface."""
163
+ [Previous create_interface implementation remains the same]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  if __name__ == "__main__":
166
  interface = create_interface()