jzou19950715 committed on
Commit
344d561
·
verified ·
1 Parent(s): e68b5c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +275 -134
app.py CHANGED
@@ -1,175 +1,279 @@
1
  """
2
- Enhanced Data Analysis Assistant using smolagents for more powerful analysis capabilities.
 
3
  """
4
 
5
  import base64
6
  import io
 
7
  import os
8
  from dataclasses import dataclass
9
- from typing import Any, Dict, List, Optional, Union
10
  from pathlib import Path
 
11
 
12
  import gradio as gr
13
- import pandas as pd
14
  import numpy as np
 
15
  import plotly.express as px
16
  import plotly.graph_objects as go
17
  from plotly.subplots import make_subplots
18
- import matplotlib.pyplot as plt
19
  import seaborn as sns
20
-
21
- from smolagents import CodeAgent, tool
22
 
23
  # Constants
24
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
25
- DEFAULT_MODEL = "gpt-4o-mini"
 
26
 
27
@tool
def create_plotly_visualization(df: pd.DataFrame, plot_type: str, x: str, y: str,
                                color: Optional[str] = None, title: Optional[str] = None) -> str:
    """Create an interactive Plotly visualization.

    Args:
        df: DataFrame to visualize
        plot_type: Type of plot (scatter, line, bar, box)
        x: Column for x-axis
        y: Column for y-axis
        color: Optional column for color encoding
        title: Optional plot title

    Returns:
        HTML string of the plot
    """
    # NOTE(review): the docstring is kept word-for-word because the @tool
    # decorator presumably reads it to describe the tool -- confirm before
    # rewording it.

    # Map each supported plot type onto the plotly.express builder that draws it.
    builders = {
        "scatter": px.scatter,
        "line": px.line,
        "bar": px.bar,
        "box": px.box,
    }

    builder = builders.get(plot_type)
    if builder is None:
        raise ValueError(f"Unsupported plot type: {plot_type}")

    figure = builder(df, x=x, y=y, color=color, title=title)
    return figure.to_html(include_plotlyjs=True, full_html=False)
55
 
56
@tool
def calculate_statistics(df: pd.DataFrame, columns: List[str]) -> Dict[str, Any]:
    """Calculate basic statistics for specified columns.

    Args:
        df: DataFrame to analyze
        columns: List of columns to analyze

    Returns:
        Dictionary of statistics
    """
    # NOTE(review): the docstring is kept word-for-word because the @tool
    # decorator presumably reads it to describe the tool -- confirm before
    # rewording it.

    def summarize(series):
        # One summary record per numeric column.
        return {
            "mean": series.mean(),
            "median": series.median(),
            "std": series.std(),
            "min": series.min(),
            "max": series.max(),
            "missing": series.isna().sum(),
        }

    # Columns that are not numeric are left out of the result entirely.
    return {
        name: summarize(df[name])
        for name in columns
        if pd.api.types.is_numeric_dtype(df[name])
    }
79
 
80
@tool
def correlation_analysis(df: pd.DataFrame, threshold: float = 0.5) -> str:
    """Generate correlation analysis with interactive heatmap.

    Args:
        df: DataFrame to analyze
        threshold: Correlation threshold to highlight

    Returns:
        HTML string of the correlation heatmap
    """
    # NOTE(review): the docstring is kept word-for-word because the @tool
    # decorator presumably reads it to describe the tool -- confirm before
    # rewording it.
    # NOTE(review): `threshold` is accepted but never read below, so nothing is
    # actually highlighted -- decide whether to implement or drop it.

    # Only numeric columns take part in the correlation matrix.
    numeric_only = df.select_dtypes(include=[np.number])
    matrix = numeric_only.corr()

    heatmap = go.Heatmap(
        z=matrix,
        x=matrix.columns,
        y=matrix.columns,
        colorscale='RdBu',
    )
    figure = go.Figure(data=heatmap)
    figure.update_layout(title="Correlation Heatmap", height=600)

    return figure.to_html(include_plotlyjs=True, full_html=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
class DataAnalysisAssistant:
    """Data analysis assistant that delegates the analysis to a smolagents CodeAgent."""

    def __init__(self, api_key: str, model_id: str = DEFAULT_MODEL):
        """Store the API key in the environment and build the agent.

        Args:
            api_key: Key exported as ``OPENAI_API_KEY`` for the current process.
            model_id: Value handed to ``CodeAgent`` as ``model``.
        """
        os.environ["OPENAI_API_KEY"] = api_key

        agent_tools = [
            create_plotly_visualization,
            calculate_statistics,
            correlation_analysis,
        ]
        allowed_imports = [
            "pandas",
            "numpy",
            "plotly.express",
            "plotly.graph_objects",
            "seaborn",
        ]
        self.agent = CodeAgent(
            tools=agent_tools,
            model=model_id,
            additional_authorized_imports=allowed_imports,
        )

    def analyze(self, df: pd.DataFrame, query: str) -> str:
        """Describe the DataFrame to the agent and return its answer as text.

        Args:
            df: DataFrame to analyze; passed to the agent as ``df``.
            query: User's analysis request.

        Returns:
            ``str()`` of the agent's result, or an ``Analysis failed: ...``
            message if the agent raises.
        """
        # Pre-render the variable parts of the prompt.
        column_list = ', '.join(df.columns)
        dtype_listing = chr(10).join(
            [f' • {col}: {dtype}' for col, dtype in df.dtypes.items()]
        )

        context = f"""
        Available DataFrame (as 'df'):
        - Shape: {df.shape}
        - Columns: {column_list}
        - Data Types:
        {dtype_listing}

        User Query: {query}

        Please provide:
        1. Data insights and findings
        2. Interactive visualizations where appropriate
        3. Statistical analysis
        4. Clear explanations

        You can use these tools:
        - create_plotly_visualization: Creates interactive Plotly plots
        - calculate_statistics: Provides statistical summaries
        - correlation_analysis: Generates correlation heatmaps
        """

        try:
            answer = self.agent.run(context, additional_args={"df": df})
        except Exception as exc:
            return f"Analysis failed: {str(exc)}"
        return str(answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
169
  """Process uploaded file into DataFrame."""
170
  if not file:
171
  return None
172
-
173
  try:
174
  file_path = Path(file.name)
175
  if file_path.suffix == '.csv':
@@ -181,19 +285,25 @@ def process_file(file: gr.File) -> Optional[pd.DataFrame]:
181
  except Exception as e:
182
  raise RuntimeError(f"Error reading file: {str(e)}")
183
 
184
- def analyze_data(file: gr.File, query: str, api_key: str) -> str:
 
 
 
 
185
  """Main analysis function for Gradio interface."""
186
  if not api_key:
187
  return "Error: Please provide an API key"
188
-
189
  if not file:
190
  return "Error: Please upload a data file"
191
-
192
  try:
 
193
  df = process_file(file)
194
  if df is None:
195
  return "Error: Could not process file"
196
-
 
197
  assistant = DataAnalysisAssistant(api_key)
198
  return assistant.analyze(df, query)
199
 
@@ -201,7 +311,7 @@ def analyze_data(file: gr.File, query: str, api_key: str) -> str:
201
  return f"Error: {str(e)}"
202
 
203
  def create_interface():
204
- """Create Gradio interface."""
205
  css = """
206
  .plot-container {
207
  margin: 20px 0;
@@ -209,14 +319,34 @@ def create_interface():
209
  border: 1px solid #e0e0e0;
210
  border-radius: 8px;
211
  background: white;
 
 
 
 
 
 
 
 
 
 
 
212
  }
213
  """
214
 
215
  with gr.Blocks(css=css) as interface:
216
  gr.Markdown("""
217
- # Enhanced Data Analysis Assistant
 
 
218
 
219
- Powered by smolagents for more intelligent analysis
 
 
 
 
 
 
 
220
  """)
221
 
222
  with gr.Row():
@@ -227,12 +357,12 @@ def create_interface():
227
  )
228
  query = gr.Textbox(
229
  label="What would you like to analyze?",
230
- placeholder="e.g., Show relationships between variables with interactive plots",
231
  lines=3
232
  )
233
  api_key = gr.Textbox(
234
- label="API Key",
235
- placeholder="Your OpenAI API key",
236
  type="password"
237
  )
238
  analyze_btn = gr.Button("Analyze")
@@ -246,6 +376,17 @@ def create_interface():
246
  outputs=output
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
249
  return interface
250
 
251
  if __name__ == "__main__":
 
1
  """
2
+ Advanced Data Analysis Assistant with Interactive Visualizations
3
+ Integrates smolagents, GPT-4, and interactive Plotly visualizations.
4
  """
5
 
6
  import base64
7
  import io
8
+ import json
9
  import os
10
  from dataclasses import dataclass
 
11
  from pathlib import Path
12
+ from typing import Any, Dict, List, Optional, Union, Tuple
13
 
14
  import gradio as gr
 
15
  import numpy as np
16
+ import pandas as pd
17
  import plotly.express as px
18
  import plotly.graph_objects as go
19
  from plotly.subplots import make_subplots
 
20
  import seaborn as sns
21
+ from smolagents import CodeAgent, LiteLLMModel, tool
22
+ from datetime import datetime, timedelta
23
 
24
  # Constants
25
  SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
26
+ DEFAULT_MODEL = "gpt-4"
27
+ HISTORY_FILE = "analysis_history.json"
28
 
29
@dataclass
class VisualizationConfig:
    """Display settings for rendered figures."""

    # Figure size.
    width: int = 800
    height: int = 500
    # Name of the Plotly layout template.
    template: str = "plotly_white"
    # Whether grid lines are drawn.
    show_grid: bool = True
    # Interactivity flag.
    interactive: bool = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
class DataPreprocessor:
    """Handles data preprocessing and validation."""

    @staticmethod
    def preprocess_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Return a cleaned copy of ``df`` together with descriptive metadata.

        The input frame is never modified. Non-numeric columns whose values
        parse as dates are converted to datetimes, and remaining gaps are
        forward-filled and then back-filled.

        Args:
            df: Raw DataFrame as loaded from the uploaded file.

        Returns:
            ``(cleaned_df, metadata)``. ``metadata`` holds ``original_shape``,
            ``missing_values``, ``dtypes``, ``numeric_columns``,
            ``categorical_columns`` and ``temporal_columns``; all but the last
            describe the frame *before* any conversion or filling.
        """
        import warnings

        # Work on a copy so the caller's DataFrame is left untouched.
        df = df.copy()

        metadata = {
            "original_shape": df.shape,
            "missing_values": df.isnull().sum().to_dict(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "numeric_columns": df.select_dtypes(include=[np.number]).columns.tolist(),
            "categorical_columns": df.select_dtypes(include=['object']).columns.tolist(),
            "temporal_columns": []
        }

        # Handle date/time columns
        for col in df.columns:
            series = df[col]
            if not isinstance(series, pd.Series):
                # Duplicate column labels select a DataFrame; leave those alone.
                continue

            if pd.api.types.is_datetime64_any_dtype(series):
                metadata["temporal_columns"].append(col)
                continue

            # pd.to_datetime happily reads numbers as epoch offsets, which would
            # turn every numeric column into bogus timestamps -- never probe them.
            if pd.api.types.is_numeric_dtype(series):
                continue

            try:
                with warnings.catch_warnings():
                    # Probing free-form text triggers format-inference warnings.
                    warnings.simplefilter("ignore")
                    parsed = pd.to_datetime(series)
            except (ValueError, TypeError, OverflowError):
                continue

            # A column with no parseable value at all is not a date column.
            if parsed.notna().any():
                df[col] = parsed
                metadata["temporal_columns"].append(col)

        # Handle missing values (fillna(method=...) is deprecated in pandas 2.x)
        df = df.ffill().bfill()

        return df, metadata
 
 
66
 
67
class CodeExecutionEnvironment:
    """Runs generated analysis code and collects its printed output and figures.

    SECURITY: the code is run with ``exec`` in this process and is NOT
    sandboxed. Restricting the visible globals does not stop it from importing
    modules or touching the file system -- only feed it code you are prepared
    to run with the application's own privileges.
    """

    def __init__(self, visualization_config: Optional[VisualizationConfig] = None):
        """Set up the namespaces shared by every ``execute`` call.

        Args:
            visualization_config: Layout settings applied to captured figures;
                defaults to ``VisualizationConfig()``.
        """
        self.viz_config = visualization_config or VisualizationConfig()
        self.globals = {
            'pd': pd,
            'np': np,
            'px': px,
            'go': go,
            'make_subplots': make_subplots,
            'sns': sns
        }
        # Names assigned by executed code persist here between calls.
        self.locals = {}

    def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
        """Execute code and capture outputs including visualizations.

        Args:
            code: Python source to run.
            df: Optional DataFrame exposed to the code as ``df``.

        Returns:
            Dict with ``output`` (captured stdout, also kept when the code
            fails), ``plotly_html`` (HTML of every Plotly figure found in the
            local namespace), ``error`` (message or ``None``) and
            ``dataframe_updates`` (the ``df`` rebound by the code, or ``None``).
        """
        from contextlib import redirect_stdout

        if df is not None:
            self.globals['df'] = df

        result = {
            'output': '',
            'plotly_html': [],
            'error': None,
            'dataframe_updates': None
        }
        output_buffer = io.StringIO()

        try:
            # redirect_stdout puts back whatever stream was active before the
            # call, instead of forcing sys.__stdout__ over a host redirection.
            with redirect_stdout(output_buffer):
                exec(code, self.globals, self.locals)

            # Capture Plotly figures. plotly.express has no Figure class of its
            # own; its builders return graph_objects.Figure.
            for value in self.locals.values():
                if isinstance(value, go.Figure):
                    # Apply visualization config
                    value.update_layout(
                        width=self.viz_config.width,
                        height=self.viz_config.height,
                        template=self.viz_config.template
                    )
                    # Grid visibility is an axis property, not a layout one.
                    value.update_xaxes(showgrid=self.viz_config.show_grid)
                    value.update_yaxes(showgrid=self.viz_config.show_grid)
                    result['plotly_html'].append(
                        value.to_html(
                            include_plotlyjs=True,
                            full_html=False,
                            config={'displayModeBar': True}
                        )
                    )

            # Capture DataFrame updates: only when the code rebound `df` to a
            # different object than the one passed in.
            if 'df' in self.locals and self.locals['df'] is not df:
                result['dataframe_updates'] = self.locals['df']

        except Exception as e:
            result['error'] = f"Error executing code: {str(e)}"

        finally:
            # Keep whatever was printed, even if the code failed part-way.
            result['output'] = output_buffer.getvalue()
            output_buffer.close()

        return result
132
+
133
class AnalysisHistory:
    """Manages analysis history and persistence in a JSON file."""

    def __init__(self, history_file: str = HISTORY_FILE):
        """Load any existing history from ``history_file``.

        Args:
            history_file: Path of the JSON file holding the entry list.
        """
        self.history_file = history_file
        self.history = self._load_history()

    def _load_history(self) -> List[Dict]:
        """Read the history file, returning ``[]`` when it is missing or unusable."""
        try:
            with open(self.history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file, or content that is not valid JSON text.
            return []

        if not isinstance(data, list):
            return []
        # Anything that is not an entry dict would break add/sort later on.
        return [entry for entry in data if isinstance(entry, dict)]

    def add_entry(self, query: str, result: str) -> None:
        """Add new analysis entry to history and persist the whole list.

        Args:
            query: The user's request.
            result: Text of the analysis produced for it.

        Raises:
            OSError: If the history file cannot be written.
        """
        entry = {
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'result': result
        }
        self.history.append(entry)

        # Write to a sibling file first and swap it in, so an interrupted write
        # cannot leave a truncated history file behind.
        temp_path = f"{self.history_file}.tmp"
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f)
        os.replace(temp_path, self.history_file)

    def get_recent_analyses(self, limit: int = 5) -> List[Dict]:
        """Get recent analysis entries, newest first.

        Args:
            limit: Maximum number of entries to return; values below 1 yield
                an empty list.

        Returns:
            Up to ``limit`` entries ordered by descending ``timestamp``.
        """
        if limit <= 0:
            return []
        return sorted(
            self.history,
            # Entries without a timestamp sort as the oldest.
            key=lambda x: str(x.get('timestamp', '')),
            reverse=True
        )[:limit]
168
 
169
class DataAnalysisAssistant:
    """Enhanced data analysis assistant with visualization capabilities."""

    def __init__(self, api_key: str):
        """Create the model, execution environment, history and agent.

        Args:
            api_key: API key handed to ``LiteLLMModel``.
        """
        self.model = LiteLLMModel(
            model_id=DEFAULT_MODEL,
            api_key=api_key
        )
        self.code_env = CodeExecutionEnvironment()
        self.history = AnalysisHistory()

        # Initialize agent with tools
        self.agent = CodeAgent(
            model=self.model,
            additional_authorized_imports=[
                'pandas', 'numpy', 'plotly.express', 'plotly.graph_objects',
                'seaborn', 'scipy', 'statsmodels'
            ],
        )

    def analyze(self, df: pd.DataFrame, query: str) -> str:
        """Perform analysis with interactive visualizations.

        Args:
            df: DataFrame to analyze.
            query: User's analysis request.

        Returns:
            HTML combining the agent's answer with the output and figures of
            any code blocks it contained, or an ``Analysis failed: ...``
            message if the agent or the formatting step raises.
        """
        # Preprocess data
        df, metadata = DataPreprocessor.preprocess_dataframe(df)

        # Create context for the agent
        context = self._create_analysis_context(df, metadata, query)

        try:
            # Get analysis plan
            response = self.agent.run(context, additional_args={"df": df})

            # Extract and execute code blocks
            results = self._execute_analysis(response, df)

            formatted = self._format_results(response, results)

        except Exception as e:
            return f"Analysis failed: {str(e)}"

        # Save to history. A bookkeeping failure must not throw away an
        # analysis that already succeeded, so it is reported alongside it.
        try:
            self.history.add_entry(query, str(response))
        except (OSError, TypeError, ValueError) as e:
            from html import escape
            formatted += (
                f'\n<div class="error">Could not save analysis history: '
                f'{escape(str(e))}</div>'
            )

        return formatted

    def _create_analysis_context(self, df: pd.DataFrame, metadata: Dict, query: str) -> str:
        """Create detailed context for analysis.

        Args:
            df: The preprocessed DataFrame (not read here; the prompt is built
                from ``metadata``).
            metadata: Must provide ``original_shape``, ``numeric_columns``,
                ``categorical_columns`` and ``temporal_columns``.
            query: User's analysis request.

        Returns:
            The prompt text handed to the agent.
        """
        def listing(labels) -> str:
            # Column labels are not always strings (e.g. integer headers).
            return ', '.join(str(label) for label in labels)

        return f"""
        Analyze the following data with interactive visualizations.

        DataFrame Information:
        - Shape: {metadata['original_shape']}
        - Numeric columns: {listing(metadata['numeric_columns'])}
        - Categorical columns: {listing(metadata['categorical_columns'])}
        - Temporal columns: {listing(metadata['temporal_columns'])}

        User Query: {query}

        Guidelines:
        1. Use Plotly for interactive visualizations
        2. Store figures in variables named 'fig'
        3. Include clear titles and labels
        4. Add hover information
        5. Use color effectively
        6. Handle errors gracefully

        The DataFrame is available as 'df'.
        """

    def _execute_analysis(self, response: str, df: pd.DataFrame) -> List[Dict]:
        """Execute code blocks from analysis.

        SECURITY: every fenced ``python`` block in the model's answer is run
        through ``self.code_env.execute``; treat the answer as untrusted code.

        Args:
            response: The agent's answer; converted with ``str()`` before scanning.
            df: DataFrame made available to each block.

        Returns:
            One result dict per code block, in order of appearance.
        """
        import re

        # Extract code blocks
        code_blocks = re.findall(r'```python\n(.*?)```', str(response), re.DOTALL)

        return [self.code_env.execute(code, df) for code in code_blocks]

    def _format_results(self, response: str, results: List[Dict]) -> str:
        """Format analysis results with visualizations.

        Model text, captured stdout and error messages are HTML-escaped before
        being embedded; the figure HTML produced by Plotly is inserted as-is.

        Args:
            response: The agent's answer.
            results: Result dicts with ``error``, ``output`` and ``plotly_html``.

        Returns:
            HTML fragments joined by newlines.
        """
        from html import escape

        output_parts = []

        # Add analysis text (fence markers removed, content escaped)
        analysis_text = str(response).replace("```python", "").replace("```", "")
        output_parts.append(f'<div class="analysis-text">{escape(analysis_text)}</div>')

        # Add execution results
        for result in results:
            if result['error']:
                output_parts.append(
                    f'<div class="error">{escape(str(result["error"]))}</div>'
                )
                continue

            if result['output']:
                output_parts.append(f'<pre>{escape(str(result["output"]))}</pre>')
            for figure_html in result['plotly_html']:
                output_parts.append(
                    f'<div class="plot-container">{figure_html}</div>'
                )

        return "\n".join(output_parts)
271
 
272
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
273
  """Process uploaded file into DataFrame."""
274
  if not file:
275
  return None
276
+
277
  try:
278
  file_path = Path(file.name)
279
  if file_path.suffix == '.csv':
 
285
  except Exception as e:
286
  raise RuntimeError(f"Error reading file: {str(e)}")
287
 
288
+ def analyze_data(
289
+ file: gr.File,
290
+ query: str,
291
+ api_key: str,
292
+ ) -> str:
293
  """Main analysis function for Gradio interface."""
294
  if not api_key:
295
  return "Error: Please provide an API key"
296
+
297
  if not file:
298
  return "Error: Please upload a data file"
299
+
300
  try:
301
+ # Process file
302
  df = process_file(file)
303
  if df is None:
304
  return "Error: Could not process file"
305
+
306
+ # Create assistant and run analysis
307
  assistant = DataAnalysisAssistant(api_key)
308
  return assistant.analyze(df, query)
309
 
 
311
  return f"Error: {str(e)}"
312
 
313
  def create_interface():
314
+ """Create enhanced Gradio interface."""
315
  css = """
316
  .plot-container {
317
  margin: 20px 0;
 
319
  border: 1px solid #e0e0e0;
320
  border-radius: 8px;
321
  background: white;
322
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
323
+ }
324
+ .analysis-text {
325
+ margin: 20px 0;
326
+ line-height: 1.6;
327
+ }
328
+ .error {
329
+ color: red;
330
+ padding: 10px;
331
+ margin: 10px 0;
332
+ border-left: 4px solid red;
333
  }
334
  """
335
 
336
  with gr.Blocks(css=css) as interface:
337
  gr.Markdown("""
338
+ # Advanced Data Analysis Assistant
339
+
340
+ Upload your data and get AI-powered analysis with interactive visualizations.
341
 
342
+ **Features:**
343
+ - Interactive Plotly visualizations
344
+ - GPT-4 powered analysis
345
+ - Time series analysis
346
+ - Statistical insights
347
+ - Natural language queries
348
+
349
+ **Required:** OpenAI API key
350
  """)
351
 
352
  with gr.Row():
 
357
  )
358
  query = gr.Textbox(
359
  label="What would you like to analyze?",
360
+ placeholder="e.g., Analyze trends and patterns in the data with interactive visualizations",
361
  lines=3
362
  )
363
  api_key = gr.Textbox(
364
+ label="OpenAI API Key",
365
+ placeholder="Your API key",
366
  type="password"
367
  )
368
  analyze_btn = gr.Button("Analyze")
 
376
  outputs=output
377
  )
378
 
379
+ # Add examples
380
+ gr.Examples(
381
+ examples=[
382
+ [None, "Show trends over time with interactive visualizations"],
383
+ [None, "Create a comprehensive analysis of relationships between variables"],
384
+ [None, "Analyze distributions and statistical patterns"],
385
+ [None, "Generate financial metrics and performance indicators"],
386
+ ],
387
+ inputs=[file, query]
388
+ )
389
+
390
  return interface
391
 
392
  if __name__ == "__main__":