jzou19950715 committed on
Commit
356d9f3
·
verified ·
1 Parent(s): 54d2c65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -279
app.py CHANGED
@@ -1,283 +1,207 @@
 
 
 
 
1
  import base64
2
  import io
3
  import os
4
  from dataclasses import dataclass
5
- from typing import Any, Callable, Dict, List, Optional, Union
 
6
 
7
  import gradio as gr
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
  import pandas as pd
11
- import seaborn as sns
12
  import plotly.express as px
13
  import plotly.graph_objects as go
14
  from plotly.subplots import make_subplots
15
- from litellm import completion
 
16
 
17
- class CodeEnvironment:
18
- """Safe environment for executing code with data analysis capabilities"""
 
 
 
 
 
 
 
 
19
 
20
- def __init__(self):
21
- # Initialize libraries in globals
22
- self.globals = {
23
- 'pd': pd,
24
- 'np': np,
25
- 'plt': plt,
26
- 'sns': sns,
27
- 'px': px,
28
- 'go': go,
29
- 'make_subplots': make_subplots
30
- }
31
- self.locals = {}
32
-
33
- def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
34
- """Execute code and capture both static and interactive outputs"""
35
- if df is not None:
36
- self.globals['df'] = df
37
-
38
- # Capture output
39
- output_buffer = io.StringIO()
40
- # Redirect stdout to capture print statements
41
- import sys
42
- sys.stdout = output_buffer
 
 
 
 
 
 
 
 
43
 
44
- result = {
45
- 'output': '',
46
- 'figures': [], # For matplotlib figures
47
- 'plotly_html': [], # For Plotly figures
48
- 'error': None
49
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- try:
52
- # Execute code
53
- exec(code, self.globals, self.locals)
54
-
55
- # Capture matplotlib figures
56
- for i in plt.get_fignums():
57
- fig = plt.figure(i)
58
- buf = io.BytesIO()
59
- fig.savefig(buf, format='png')
60
- buf.seek(0)
61
- img_str = base64.b64encode(buf.read()).decode()
62
- result['figures'].append(f"data:image/png;base64,{img_str}")
63
- plt.close(fig)
64
-
65
- # Capture Plotly figures
66
- if 'fig' in self.locals:
67
- if isinstance(self.locals['fig'], (go.Figure, px.Figure)):
68
- # Convert Plotly figure to HTML
69
- html = self.locals['fig'].to_html(
70
- include_plotlyjs=True,
71
- full_html=False,
72
- config={'displayModeBar': True}
73
- )
74
- result['plotly_html'].append(html)
75
-
76
- # Get printed output
77
- result['output'] = output_buffer.getvalue()
78
-
79
- except Exception as e:
80
- result['error'] = str(e)
81
-
82
- finally:
83
- # Reset stdout
84
- sys.stdout = sys.__stdout__
85
- output_buffer.close()
86
-
87
- return result
88
-
89
- @dataclass
90
- class Tool:
91
- """Tool for data analysis"""
92
- name: str
93
- description: str
94
- func: Callable
95
 
96
- class AnalysisAgent:
97
- """Agent that can analyze data and execute code"""
98
 
99
- def __init__(
100
- self,
101
- model_id: str = "gpt-4o-mini",
102
- temperature: float = 0.7,
103
- ):
104
- self.model_id = model_id
105
- self.temperature = temperature
106
- self.tools: List[Tool] = []
107
- self.code_env = CodeEnvironment()
 
 
 
 
 
 
 
 
 
 
108
 
109
- def run(self, prompt: str, df: pd.DataFrame = None) -> str:
110
- """Run analysis with code execution"""
111
- messages = [
112
- {"role": "system", "content": self._get_system_prompt()},
113
- {"role": "user", "content": prompt}
114
- ]
115
 
116
- try:
117
- # Get response from model
118
- response = completion(
119
- model=self.model_id,
120
- messages=messages,
121
- temperature=self.temperature,
122
- )
123
- analysis = response.choices[0].message.content
124
-
125
- # Extract code blocks
126
- code_blocks = self._extract_code(analysis)
127
 
128
- # Execute code and capture results
129
- results = []
130
- for code in code_blocks:
131
- result = self.code_env.execute(code, df)
132
- if result['error']:
133
- results.append(f"Error executing code: {result['error']}")
134
- else:
135
- # Add output text
136
- if result['output']:
137
- results.append(result['output'])
138
-
139
- # Add Plotly interactive visualizations
140
- for html in result['plotly_html']:
141
- results.append(f'<div class="plot-container">{html}</div>')
142
-
143
- # Add static matplotlib figures as fallback
144
- for fig in result['figures']:
145
- results.append(f'<img src="{fig}" style="max-width: 100%; height: auto;">')
146
-
147
- # Combine analysis and results
148
- return f'<div class="analysis-text">{analysis}</div>' + "\n\n" + "\n".join(results)
149
-
150
- except Exception as e:
151
- return f"Error: {str(e)}"
152
-
153
- def _get_system_prompt(self) -> str:
154
- """Get system prompt with tools and capabilities"""
155
- tools_desc = "\n".join([
156
- f"- {tool.name}: {tool.description}"
157
- for tool in self.tools
158
- ])
159
 
160
- return """You are a data analysis assistant with interactive visualization capabilities.
161
-
162
- When analyzing data, use Plotly for interactive visualizations. Here are examples:
163
-
164
- ```python
165
- # Create interactive scatter plot
166
- import plotly.express as px
167
- fig = px.scatter(df, x='Date', y='Salary', color='Title')
168
- fig.show() # This will be captured and displayed
169
-
170
- # Create interactive box plot
171
- fig = px.box(df, x='Title', y='Salary')
172
- fig.show()
173
-
174
- # Create interactive time series
175
- fig = px.line(df, x='Date', y='Salary', color='Title')
176
- fig.show()
177
- ```
178
-
179
- Remember to:
180
- 1. Always store Plotly figures in a variable named 'fig'
181
- 2. Use fig.show() to display the plot
182
- 3. Create clear labels and titles
183
- 4. Include hover information
184
- 5. Use colors effectively
185
-
186
- For static visualizations, you can still use matplotlib:
187
- ```python
188
- import matplotlib.pyplot as plt
189
- plt.figure(figsize=(10, 6))
190
- plt.plot(df['Date'], df['Salary'])
191
- plt.show()
192
- ```
193
- """
194
-
195
- @staticmethod
196
- def _extract_code(text: str) -> List[str]:
197
- """Extract Python code blocks from markdown"""
198
- import re
199
- pattern = r'```python\n(.*?)```'
200
- return re.findall(pattern, text, re.DOTALL)
201
-
202
 
203
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
204
- """Process uploaded file into DataFrame"""
205
  if not file:
206
  return None
207
-
208
  try:
209
- if file.name.endswith('.csv'):
210
- return pd.read_csv(file.name)
211
- elif file.name.endswith(('.xlsx', '.xls')):
212
- return pd.read_excel(file.name)
 
 
 
213
  except Exception as e:
214
- print(f"Error reading file: {str(e)}")
215
- return None
216
 
217
-
218
- def analyze_data(
219
- file: gr.File,
220
- query: str,
221
- api_key: str,
222
- temperature: float = 0.7,
223
- ) -> str:
224
- """Process user request and generate enhanced analysis"""
225
-
226
  if not api_key:
227
- return "Error: Please provide an API key."
228
-
229
  if not file:
230
- return "Error: Please upload a file."
231
-
232
  try:
233
- # Set up environment
234
- os.environ["OPENAI_API_KEY"] = api_key
235
-
236
- # Create agent
237
- agent = AnalysisAgent(
238
- model_id="gpt-4o-mini",
239
- temperature=temperature
240
- )
241
-
242
- # Process file
243
  df = process_file(file)
244
  if df is None:
245
- return "Error: Could not process file."
246
-
247
- # Build context
248
- file_info = f"""
249
- File: {file.name}
250
- Shape: {df.shape}
251
- Columns: {', '.join(df.columns)}
252
-
253
- Column Types:
254
- {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])}
255
- """
256
-
257
- # Run analysis
258
- prompt = f"""
259
- {file_info}
260
-
261
- The data is loaded in a pandas DataFrame called 'df'.
262
-
263
- User request: {query}
264
 
265
- Please analyze the data and provide:
266
- 1. Key insights and findings
267
- 2. Interactive visualizations where appropriate
268
- 3. Statistical summaries when relevant
269
- 4. Clear explanations of patterns and trends
270
- """
271
-
272
- return agent.run(prompt, df=df)
273
 
274
  except Exception as e:
275
- return f"Error occurred: {str(e)}"
276
-
277
 
278
  def create_interface():
279
- """Create enhanced Gradio interface"""
280
-
281
  css = """
282
  .plot-container {
283
  margin: 20px 0;
@@ -285,70 +209,43 @@ def create_interface():
285
  border: 1px solid #e0e0e0;
286
  border-radius: 8px;
287
  background: white;
288
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
289
  }
290
  """
291
 
292
- with gr.Blocks(title="AI Data Analysis Assistant", css=css) as interface:
293
  gr.Markdown("""
294
- # AI Data Analysis Assistant
295
-
296
- Upload your data file and get AI-powered analysis with interactive visualizations.
297
 
298
- **Features:**
299
- - Interactive data visualization
300
- - Statistical analysis
301
- - Machine learning capabilities
302
- - Natural language interaction
303
-
304
- **Note**: Requires your own OpenAI API key.
305
  """)
306
 
307
  with gr.Row():
308
  with gr.Column():
309
  file = gr.File(
310
  label="Upload Data File",
311
- file_types=[".csv", ".xlsx", ".xls"]
312
  )
313
  query = gr.Textbox(
314
  label="What would you like to analyze?",
315
- placeholder="e.g., Create interactive visualizations showing relationships between variables",
316
  lines=3
317
  )
318
  api_key = gr.Textbox(
319
- label="API Key (Required)",
320
- placeholder="Your API key",
321
  type="password"
322
  )
323
- temperature = gr.Slider(
324
- label="Temperature",
325
- minimum=0.0,
326
- maximum=1.0,
327
- value=0.7,
328
- step=0.1
329
- )
330
  analyze_btn = gr.Button("Analyze")
331
 
332
  with gr.Column():
333
- output = gr.HTML(label="Output") # Changed to HTML for interactive plots
334
 
335
  analyze_btn.click(
336
  analyze_data,
337
- inputs=[file, query, api_key, temperature],
338
  outputs=output
339
  )
340
 
341
- gr.Examples(
342
- examples=[
343
- [None, "Create interactive visualizations showing relationships between variables"],
344
- [None, "Show the distribution of values with interactive plots"],
345
- [None, "Create an interactive correlation analysis"],
346
- [None, "Show trends over time with interactive charts"],
347
- [None, "Generate a comprehensive analysis with multiple visualizations"],
348
- ],
349
- inputs=[file, query]
350
- )
351
-
352
  return interface
353
 
354
  if __name__ == "__main__":
 
1
+ """
2
+ Enhanced Data Analysis Assistant using smolagents for more powerful analysis capabilities.
3
+ """
4
+
5
  import base64
6
  import io
7
  import os
8
  from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from pathlib import Path
11
 
12
  import gradio as gr
 
 
13
  import pandas as pd
14
+ import numpy as np
15
  import plotly.express as px
16
  import plotly.graph_objects as go
17
  from plotly.subplots import make_subplots
18
+ import matplotlib.pyplot as plt
19
+ import seaborn as sns
20
 
21
+ from smolagents import CodeAgent, tool
22
+
23
+ # Constants
24
+ SUPPORTED_FILE_TYPES = [".csv", ".xlsx", ".xls"]
25
+ DEFAULT_MODEL = "gpt-4o-mini"
26
+
27
@tool
def create_plotly_visualization(df: pd.DataFrame, plot_type: str, x: str, y: str,
                                color: Optional[str] = None, title: Optional[str] = None) -> str:
    """Create an interactive Plotly visualization.

    Args:
        df: DataFrame to visualize
        plot_type: Type of plot (scatter, line, bar, box); case and surrounding whitespace are ignored
        x: Column for x-axis
        y: Column for y-axis
        color: Optional column for color encoding
        title: Optional plot title

    Returns:
        HTML string of the plot

    Raises:
        ValueError: If plot_type is not one of the supported plot types
    """
    # One table instead of an if/elif chain: adding a plot kind is a one-line
    # change and the error message below can list exactly what is supported.
    builders = {
        "scatter": px.scatter,
        "line": px.line,
        "bar": px.bar,
        "box": px.box,
    }

    # Accept "Scatter", " line " and similar spellings instead of rejecting
    # them over casing or stray whitespace.
    kind = str(plot_type).strip().lower()
    builder = builders.get(kind)
    if builder is None:
        supported = ", ".join(sorted(builders))
        raise ValueError(
            f"Unsupported plot type: {plot_type}. Supported types: {supported}"
        )

    fig = builder(df, x=x, y=y, color=color, title=title)

    # full_html=False yields an embeddable <div> fragment rather than a page.
    return fig.to_html(include_plotlyjs=True, full_html=False)
55
+
56
@tool
def calculate_statistics(df: pd.DataFrame, columns: List[str]) -> Dict[str, Any]:
    """Calculate basic statistics for specified columns.

    Args:
        df: DataFrame to analyze
        columns: List of columns to analyze; a single column name is also accepted

    Returns:
        Dictionary mapping each numeric column to its mean, median, std, min, max and missing count (non-numeric columns are omitted)

    Raises:
        KeyError: If any requested column is not present in df
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(columns, str):
        columns = [columns]

    # Fail once, up front, naming every unknown column, rather than raising
    # on the first bad name halfway through the loop.
    unknown = [col for col in columns if col not in df.columns]
    if unknown:
        raise KeyError(f"Columns not found in DataFrame: {unknown}")

    def _native(value: Any) -> Any:
        # numpy scalars expose .item(); unwrap them so the result contains
        # plain Python numbers that print and serialize cleanly.
        return value.item() if hasattr(value, "item") else value

    stats: Dict[str, Any] = {}
    for col in columns:
        series = df[col]
        # Non-numeric columns are skipped on purpose (same as before).
        if not pd.api.types.is_numeric_dtype(series):
            continue
        stats[col] = {
            "mean": _native(series.mean()),
            "median": _native(series.median()),
            "std": _native(series.std()),
            "min": _native(series.min()),
            "max": _native(series.max()),
            "missing": _native(series.isna().sum()),
        }
    return stats
79
+
80
@tool
def correlation_analysis(df: pd.DataFrame, threshold: float = 0.5) -> str:
    """Generate correlation analysis with interactive heatmap.

    Args:
        df: DataFrame to analyze
        threshold: Correlation threshold to highlight; cells whose absolute correlation is at least this value are labelled with the coefficient

    Returns:
        HTML string of the correlation heatmap
    """
    # Only numeric columns take part in the correlation matrix.
    numeric_df = df.select_dtypes(include=[np.number])
    corr = numeric_df.corr()

    values = corr.to_numpy()
    names = list(corr.columns)

    # The threshold parameter was previously accepted but ignored. Use it to
    # label only the strong cells; NaN compares False and so stays unlabelled.
    labels = [
        [f"{r:.2f}" if abs(r) >= threshold else "" for r in row]
        for row in values
    ]

    fig = go.Figure(data=go.Heatmap(
        z=values,
        x=names,
        y=names,
        colorscale='RdBu',
        # Pin the scale to the full correlation range so the diverging
        # palette is centred on zero regardless of the data.
        zmin=-1,
        zmax=1,
        text=labels,
        texttemplate="%{text}",
    ))

    fig.update_layout(
        title="Correlation Heatmap",
        height=600,
    )

    # full_html=False yields an embeddable <div> fragment rather than a page.
    return fig.to_html(include_plotlyjs=True, full_html=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
class DataAnalysisAssistant:
    """Enhanced data analysis assistant using smolagents.

    Wraps a smolagents ``CodeAgent`` configured with this module's plotting,
    statistics and correlation tools.
    """

    def __init__(self, api_key: str, model_id: str = DEFAULT_MODEL):
        """Initialize the assistant with API key and model.

        Args:
            api_key: API key forwarded to the model backend.
            model_id: Identifier of the chat model to use.
        """
        # Local import: keeps this class self-contained without touching the
        # file's top-level import block.
        from smolagents import LiteLLMModel

        # CodeAgent needs a callable model object, not a model-name string.
        # Passing the key directly also avoids writing one user's key into
        # the process-wide environment, where another request could read it.
        model = LiteLLMModel(model_id=model_id, api_key=api_key)

        self.agent = CodeAgent(
            tools=[
                create_plotly_visualization,
                calculate_statistics,
                correlation_analysis
            ],
            model=model,
            additional_authorized_imports=[
                "pandas",
                "numpy",
                "plotly.express",
                "plotly.graph_objects",
                "seaborn",
            ]
        )

    @staticmethod
    def _build_context(df: pd.DataFrame, query: str) -> str:
        """Render the prompt describing the DataFrame and the user's request."""
        # Column labels are not guaranteed to be strings (e.g. integer
        # headers), so stringify before joining.
        column_list = ', '.join(str(col) for col in df.columns)
        dtype_lines = chr(10).join(
            f'  • {col}: {dtype}' for col, dtype in df.dtypes.items()
        )
        return f"""
        Available DataFrame (as 'df'):
        - Shape: {df.shape}
        - Columns: {column_list}
        - Data Types:
        {dtype_lines}

        User Query: {query}

        Please provide:
        1. Data insights and findings
        2. Interactive visualizations where appropriate
        3. Statistical analysis
        4. Clear explanations

        You can use these tools:
        - create_plotly_visualization: Creates interactive Plotly plots
        - calculate_statistics: Provides statistical summaries
        - correlation_analysis: Generates correlation heatmaps
        """

    def analyze(self, df: pd.DataFrame, query: str) -> str:
        """Run analysis using the agent.

        Args:
            df: DataFrame to analyze
            query: User's analysis request

        Returns:
            HTML string containing analysis and visualizations, or a message
            starting with "Analysis failed:" if anything raised.
        """
        try:
            # Prompt construction is inside the try so a formatting problem
            # is reported the same way as an agent failure.
            context = self._build_context(df, query)
            result = self.agent.run(context, additional_args={"df": df})
            return str(result)
        except Exception as e:
            return f"Analysis failed: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
169
+ """Process uploaded file into DataFrame."""
170
  if not file:
171
  return None
172
+
173
  try:
174
+ file_path = Path(file.name)
175
+ if file_path.suffix == '.csv':
176
+ return pd.read_csv(file_path)
177
+ elif file_path.suffix in ('.xlsx', '.xls'):
178
+ return pd.read_excel(file_path)
179
+ else:
180
+ raise ValueError(f"Unsupported file type: {file_path.suffix}")
181
  except Exception as e:
182
+ raise RuntimeError(f"Error reading file: {str(e)}")
 
183
 
184
def analyze_data(file: gr.File, query: str, api_key: str) -> str:
    """Main analysis function for Gradio interface.

    Args:
        file: Uploaded data file.
        query: The user's analysis request.
        api_key: API key used to construct the assistant.

    Returns:
        The assistant's output, or a message starting with "Error:" when an
        input is missing or processing fails.
    """
    # Local import: keeps this function self-contained without touching the
    # file's top-level import block.
    import html

    # A key made only of whitespace is as unusable as no key at all.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: Please provide an API key"

    if not file:
        return "Error: Please upload a data file"

    # Avoid spending a model call on an empty request.
    if not query or not query.strip():
        return "Error: Please describe what you would like to analyze"

    try:
        df = process_file(file)
        if df is None:
            return "Error: Could not process file"

        assistant = DataAnalysisAssistant(api_key)
        return assistant.analyze(df, query)

    except Exception as e:
        # The result is shown in an HTML component, so escape exception text
        # (it can contain file names or other user-controlled content).
        return f"Error: {html.escape(str(e))}"
 
202
 
203
  def create_interface():
204
+ """Create Gradio interface."""
 
205
  css = """
206
  .plot-container {
207
  margin: 20px 0;
 
209
  border: 1px solid #e0e0e0;
210
  border-radius: 8px;
211
  background: white;
 
212
  }
213
  """
214
 
215
+ with gr.Blocks(css=css) as interface:
216
  gr.Markdown("""
217
+ # Enhanced Data Analysis Assistant
 
 
218
 
219
+ Powered by smolagents for more intelligent analysis
 
 
 
 
 
 
220
  """)
221
 
222
  with gr.Row():
223
  with gr.Column():
224
  file = gr.File(
225
  label="Upload Data File",
226
+ file_types=SUPPORTED_FILE_TYPES
227
  )
228
  query = gr.Textbox(
229
  label="What would you like to analyze?",
230
+ placeholder="e.g., Show relationships between variables with interactive plots",
231
  lines=3
232
  )
233
  api_key = gr.Textbox(
234
+ label="API Key",
235
+ placeholder="Your OpenAI API key",
236
  type="password"
237
  )
 
 
 
 
 
 
 
238
  analyze_btn = gr.Button("Analyze")
239
 
240
  with gr.Column():
241
+ output = gr.HTML(label="Analysis Results")
242
 
243
  analyze_btn.click(
244
  analyze_data,
245
+ inputs=[file, query, api_key],
246
  outputs=output
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
249
  return interface
250
 
251
  if __name__ == "__main__":