jzou19950715 commited on
Commit
707e5d0
·
verified ·
1 Parent(s): b90c312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +519 -47
app.py CHANGED
@@ -1,29 +1,448 @@
1
- def create_interface():
2
- """Create Gradio interface with proper HTML rendering"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- agent = AnalysisAgent()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def format_html_output(content: str) -> str:
7
- """Format the output to properly render HTML in Gradio"""
8
- # Split content into text and HTML parts
9
- parts = content.split('<!DOCTYPE html>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- if len(parts) == 1:
12
- # No HTML content
13
- return f'<div style="padding: 20px;">{content}</div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  formatted_parts = []
16
  for i, part in enumerate(parts):
17
- if i == 0:
18
- # Text content
19
- if part.strip():
20
- formatted_parts.append(f'<div style="padding: 20px;">{part}</div>')
21
- else:
22
- # HTML visualization
23
- formatted_parts.append(f'<!DOCTYPE html>{part}')
24
 
25
  return '\n'.join(formatted_parts)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def process_file(file: gr.File) -> str:
28
  """Process uploaded file and initialize session"""
29
  try:
@@ -32,39 +451,64 @@ def create_interface():
32
  elif file.name.endswith(('.xlsx', '.xls')):
33
  agent.session.data = pd.read_excel(file.name)
34
  else:
35
- return "Error: Unsupported file type"
36
-
37
- return format_html_output(f"Successfully loaded data: {agent.session.get_context()}")
 
 
 
 
 
 
38
  except Exception as e:
39
- return format_html_output(f"Error loading file: {str(e)}")
 
 
40
 
41
- def analyze(file: gr.File, query: str, api_key: str) -> str:
42
- """Process analysis query"""
43
  if not api_key:
44
- return format_html_output("Error: Please provide an API key.")
 
 
 
45
 
46
  if not file:
47
- return format_html_output("Error: Please upload a file.")
 
 
 
48
 
49
  try:
50
  os.environ["OPENAI_API_KEY"] = api_key
51
  result = agent.process_query(query)
52
- return format_html_output(result)
 
 
 
 
 
 
53
  except Exception as e:
54
- return format_html_output(f"Error: {str(e)}")
 
 
 
55
 
 
56
  with gr.Blocks(css="""
57
- .plot-container { margin: 20px 0; }
58
- .bokeh-plot { margin: 20px auto; }
 
59
  """) as interface:
60
  gr.Markdown("""
61
  # Interactive Data Analysis Assistant
62
 
63
  Upload your data file and chat with the AI to analyze it. Features:
64
- - Interactive visualizations
65
- - Natural language analysis
66
- - Follow-up questions
67
- - Statistical insights
68
 
69
  **Note**: Requires OpenAI API key
70
  """)
@@ -73,42 +517,70 @@ def create_interface():
73
  with gr.Column(scale=1):
74
  file = gr.File(
75
  label="Upload Data File",
76
- file_types=[".csv", ".xlsx", ".xls"]
 
77
  )
 
78
  api_key = gr.Textbox(
79
- label="API Key",
80
- type="password"
 
81
  )
 
82
  chat_input = gr.Textbox(
83
  label="Ask about your data",
84
  placeholder="e.g., Show me the relationship between variables",
85
  lines=3
86
  )
87
- analyze_btn = gr.Button("Analyze")
 
 
 
88
 
89
  with gr.Column(scale=2):
90
- chat_output = gr.HTML(
91
  label="Analysis & Visualizations",
92
- elem_classes="plot-container"
93
  )
94
 
95
  # Set up event handlers
96
- file.change(process_file, inputs=[file], outputs=[chat_output])
 
 
 
 
 
97
  analyze_btn.click(
98
  analyze,
99
- inputs=[file, chat_input, api_key],
100
- outputs=[chat_output]
101
  )
102
 
103
  # Example queries
104
  gr.Examples(
105
  examples=[
106
- [None, "Show me the distribution of numerical variables"],
107
- [None, "Create an interactive visualization of the relationships between variables"],
108
- [None, "Analyze trends in the data over time"],
109
- [None, "Compare different categories using interactive charts"],
 
 
110
  ],
111
  inputs=[file, chat_input]
112
  )
 
 
 
 
 
 
 
 
 
 
113
 
114
- return interface
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+ import json
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import pandas as pd
11
+ from bokeh.plotting import figure
12
+ from bokeh.layouts import column, row, layout
13
+ from bokeh.models import (
14
+ ColumnDataSource,
15
+ HoverTool,
16
+ BoxSelectTool,
17
+ WheelZoomTool,
18
+ ResetTool,
19
+ Legend,
20
+ LegendItem
21
+ )
22
+ from bokeh.embed import file_html
23
+ from bokeh.resources import CDN
24
+ from litellm import completion
25
+
26
+ class VisualizationEngine:
27
+ """Engine for creating interactive Bokeh visualizations"""
28
 
29
+ def __init__(self):
30
+ self.width = 800
31
+ self.height = 500
32
+ self.tools = "pan,box_zoom,wheel_zoom,reset,save,hover"
33
+ self.cdn = CDN
34
+
35
+ def create_scatter(self, df: pd.DataFrame, x_col: str, y_col: str,
36
+ color_col: Optional[str] = None, title: str = "") -> str:
37
+ """Create an interactive scatter plot"""
38
+ source = ColumnDataSource(df)
39
+
40
+ p = figure(width=self.width, height=self.height,
41
+ title=title, tools=self.tools)
42
+
43
+ # Add scatter points
44
+ if color_col and color_col in df.columns:
45
+ scatter = p.scatter(
46
+ x_col, y_col,
47
+ source=source,
48
+ color={'field': color_col, 'transform': 'category10'},
49
+ size=8,
50
+ alpha=0.6
51
+ )
52
+ else:
53
+ scatter = p.scatter(
54
+ x_col, y_col,
55
+ source=source,
56
+ color='navy',
57
+ size=8,
58
+ alpha=0.6
59
+ )
60
+
61
+ # Style the plot
62
+ p.title.text_font_size = '16pt'
63
+ p.xaxis.axis_label = x_col
64
+ p.yaxis.axis_label = y_col
65
+ p.axis.axis_label_text_font_size = '12pt'
66
+
67
+ # Add hover tooltip
68
+ hover = p.select(dict(type=HoverTool))
69
+ hover.tooltips = [(col, f"@{col}") for col in [x_col, y_col] + ([color_col] if color_col else [])]
70
+ hover.mode = 'mouse'
71
+
72
+ return file_html(p, self.cdn)
73
 
74
+ def create_line(self, df: pd.DataFrame, x_col: str, y_cols: List[str],
75
+ title: str = "") -> str:
76
+ """Create an interactive line plot"""
77
+ source = ColumnDataSource(df)
78
+
79
+ p = figure(width=self.width, height=self.height,
80
+ title=title, tools=self.tools)
81
+
82
+ # Add lines for each y column
83
+ colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
84
+ '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
85
+ for i, y_col in enumerate(y_cols):
86
+ line = p.line(
87
+ x_col, y_col,
88
+ line_width=2,
89
+ source=source,
90
+ legend_label=y_col,
91
+ color=colors[i % len(colors)]
92
+ )
93
+
94
+ # Style the plot
95
+ p.title.text_font_size = '16pt'
96
+ p.xaxis.axis_label = x_col
97
+ p.yaxis.axis_label = "Values"
98
+ p.axis.axis_label_text_font_size = '12pt'
99
+ p.legend.click_policy = "hide"
100
+ p.legend.location = "top_right"
101
+
102
+ # Add hover tooltip
103
+ hover = p.select(dict(type=HoverTool))
104
+ hover.tooltips = [(col, f"@{col}") for col in [x_col] + y_cols]
105
+ hover.mode = 'mouse'
106
+
107
+ return file_html(p, self.cdn)
108
+
109
+ def create_bar(self, df: pd.DataFrame, x_col: str, y_col: str,
110
+ title: str = "", color: str = "#1f77b4") -> str:
111
+ """Create an interactive bar plot"""
112
+ source = ColumnDataSource(df)
113
+
114
+ p = figure(width=self.width, height=self.height,
115
+ title=title, tools=self.tools,
116
+ x_range=df[x_col].unique().tolist())
117
+
118
+ # Add bars
119
+ p.vbar(
120
+ x=x_col,
121
+ top=y_col,
122
+ width=0.9,
123
+ source=source,
124
+ color=color,
125
+ alpha=0.8
126
+ )
127
+
128
+ # Style the plot
129
+ p.title.text_font_size = '16pt'
130
+ p.xaxis.axis_label = x_col
131
+ p.yaxis.axis_label = y_col
132
+ p.axis.axis_label_text_font_size = '12pt'
133
+ p.xgrid.grid_line_color = None
134
+ p.xaxis.major_label_orientation = 0.7
135
+
136
+ # Add hover tooltip
137
+ hover = p.select(dict(type=HoverTool))
138
+ hover.tooltips = [(x_col, f"@{x_col}"), (y_col, f"@{y_col}")]
139
+ hover.mode = 'mouse'
140
+
141
+ return file_html(p, self.cdn)
142
+
143
+ def create_histogram(self, df: pd.DataFrame, column: str, bins: int = 30,
144
+ title: str = "") -> str:
145
+ """Create an interactive histogram"""
146
+ hist, edges = np.histogram(df[column].dropna(), bins=bins)
147
+ hist_df = pd.DataFrame({
148
+ 'count': hist,
149
+ 'left': edges[:-1],
150
+ 'right': edges[1:]
151
+ })
152
+ source = ColumnDataSource(hist_df)
153
+
154
+ p = figure(width=self.width, height=self.height,
155
+ title=title, tools=self.tools)
156
+
157
+ # Add histogram bars
158
+ p.quad(
159
+ top='count',
160
+ bottom=0,
161
+ left='left',
162
+ right='right',
163
+ source=source,
164
+ fill_color="#1f77b4",
165
+ line_color="white",
166
+ alpha=0.8
167
+ )
168
+
169
+ # Style the plot
170
+ p.title.text_font_size = '16pt'
171
+ p.xaxis.axis_label = column
172
+ p.yaxis.axis_label = 'Count'
173
+ p.axis.axis_label_text_font_size = '12pt'
174
+
175
+ # Add hover tooltip
176
+ hover = p.select(dict(type=HoverTool))
177
+ hover.tooltips = [
178
+ ('Range', '@left{0.00} to @right{0.00}'),
179
+ ('Count', '@count')
180
+ ]
181
+ hover.mode = 'mouse'
182
+
183
+ return file_html(p, self.cdn)
184
+
185
+ class DataAnalyzer:
186
+ """Helper class for common data analysis tasks"""
187
+
188
+ @staticmethod
189
+ def get_summary_stats(df: pd.DataFrame) -> pd.DataFrame:
190
+ """Get summary statistics for numerical columns"""
191
+ return df.describe()
192
+
193
+ @staticmethod
194
+ def get_missing_values(df: pd.DataFrame) -> pd.DataFrame:
195
+ """Get missing values information"""
196
+ missing = pd.DataFrame({
197
+ 'column': df.columns,
198
+ 'missing_count': df.isnull().sum(),
199
+ 'missing_percentage': (df.isnull().sum() / len(df) * 100).round(2)
200
+ })
201
+ return missing[missing['missing_count'] > 0]
202
+
203
+ @staticmethod
204
+ def get_correlation_matrix(df: pd.DataFrame) -> pd.DataFrame:
205
+ """Get correlation matrix for numerical columns"""
206
+ numeric_cols = df.select_dtypes(include=[np.number]).columns
207
+ return df[numeric_cols].corr()
208
+
209
+ class AnalysisSession:
210
+ """Maintains state and history for the analysis session"""
211
+
212
+ def __init__(self):
213
+ self.data: Optional[pd.DataFrame] = None
214
+ self.chat_history: List[Dict[str, str]] = []
215
+ self.viz_engine = VisualizationEngine()
216
+ self.analyzer = DataAnalyzer()
217
+
218
+ def add_message(self, role: str, content: str):
219
+ """Add a message to the chat history"""
220
+ self.chat_history.append({"role": role, "content": content})
221
+
222
+ def get_context(self) -> str:
223
+ """Get the current analysis context"""
224
+ if self.data is None:
225
+ return "No data loaded yet."
226
+
227
+ numeric_cols = self.data.select_dtypes(include=[np.number]).columns
228
+ categorical_cols = self.data.select_dtypes(include=['object', 'category']).columns
229
+
230
+ missing_info = self.analyzer.get_missing_values(self.data)
231
+ missing_summary = "\n".join([
232
+ f"- {row['column']}: {row['missing_count']} ({row['missing_percentage']}%)"
233
+ for _, row in missing_info.iterrows()
234
+ ]) if not missing_info.empty else "No missing values found."
235
+
236
+ context = f"""
237
+ Current DataFrame Info:
238
+ - Shape: {self.data.shape}
239
+ - Numeric columns: {', '.join(numeric_cols)}
240
+ - Categorical columns: {', '.join(categorical_cols)}
241
+
242
+ Missing Values:
243
+ {missing_summary}
244
+ """
245
+ return context
246
+
247
+ class AnalysisAgent:
248
+ """Enhanced agent with interactive visualization and chat capabilities"""
249
+
250
+ def __init__(
251
+ self,
252
+ model_id: str = "gpt-4",
253
+ temperature: float = 0.7,
254
+ ):
255
+ self.model_id = model_id
256
+ self.temperature = temperature
257
+ self.session = AnalysisSession()
258
+
259
+ def process_query(self, query: str) -> str:
260
+ """Process a user query and generate response with visualizations"""
261
+ context = self.session.get_context()
262
 
263
+ messages = [
264
+ {"role": "system", "content": self._get_system_prompt()},
265
+ *self.session.chat_history[-5:], # Include last 5 messages for context
266
+ {"role": "user", "content": f"{context}\n\nUser query: {query}"}
267
+ ]
268
+
269
+ try:
270
+ response = completion(
271
+ model=self.model_id,
272
+ messages=messages,
273
+ temperature=self.temperature,
274
+ )
275
+ analysis = response.choices[0].message.content
276
+
277
+ # Extract and execute any code blocks
278
+ visualizations = []
279
+ code_blocks = self._extract_code(analysis)
280
+
281
+ for code in code_blocks:
282
+ try:
283
+ # Execute code and capture visualization commands
284
+ result = self._execute_visualization(code)
285
+ if result:
286
+ visualizations.append(result)
287
+ except Exception as e:
288
+ visualizations.append(f"Error creating visualization: {str(e)}")
289
+
290
+ # Add messages to chat history
291
+ self.session.add_message("user", query)
292
+ self.session.add_message("assistant", analysis)
293
+
294
+ # Format the response with visualizations
295
+ formatted_response = self._format_response(analysis, visualizations)
296
+ return formatted_response
297
+
298
+ except Exception as e:
299
+ return f"Error: {str(e)}"
300
+
301
+ def _execute_visualization(self, code: str) -> Optional[str]:
302
+ """Execute visualization code and return HTML output"""
303
+ try:
304
+ # Create a safe namespace with necessary libraries
305
+ namespace = {
306
+ 'df': self.session.data,
307
+ 'np': np,
308
+ 'pd': pd,
309
+ 'viz': self.session.viz_engine,
310
+ 'analyzer': self.session.analyzer
311
+ }
312
+
313
+ # Execute the code
314
+ exec(code, namespace)
315
+
316
+ # Look for visualization result
317
+ for var in namespace.values():
318
+ if isinstance(var, str) and ('<script' in var or '<div' in var):
319
+ return var
320
+
321
+ return None
322
+
323
+ except Exception as e:
324
+ return f"Error executing visualization: {str(e)}"
325
+
326
+ def _format_response(self, analysis: str, visualizations: List[str]) -> str:
327
+ """Format the response with text and visualizations"""
328
+ # Split analysis into parts (before and after code blocks)
329
+ parts = self._extract_code(analysis, keep_markdown=True)
330
 
331
  formatted_parts = []
332
  for i, part in enumerate(parts):
333
+ if i % 2 == 0: # Text content
334
+ formatted_parts.append(f'<div class="analysis-text">{part}</div>')
335
+ else: # Code block location
336
+ if i//2 < len(visualizations):
337
+ viz = visualizations[i//2]
338
+ formatted_parts.append(f'<div class="visualization">{viz}</div>')
 
339
 
340
  return '\n'.join(formatted_parts)
341
 
342
+ def _get_system_prompt(self) -> str:
343
+ """Get system prompt with visualization capabilities"""
344
+ return """You are a data analysis assistant with interactive visualization capabilities.
345
+
346
+ Available visualizations:
347
+ 1. Scatter plots (viz.create_scatter)
348
+ - x_col: x-axis column name
349
+ - y_col: y-axis column name
350
+ - color_col: optional column for color coding
351
+ - title: plot title
352
+
353
+ 2. Line plots (viz.create_line)
354
+ - x_col: x-axis column name
355
+ - y_cols: list of column names for multiple lines
356
+ - title: plot title
357
+
358
+ 3. Bar plots (viz.create_bar)
359
+ - x_col: category column name
360
+ - y_col: value column name
361
+ - title: plot title
362
+ - color: optional bar color
363
+
364
+ 4. Histograms (viz.create_histogram)
365
+ - column: column to analyze
366
+ - bins: number of bins
367
+ - title: plot title
368
+
369
+ Analysis tools:
370
+ - analyzer.get_summary_stats(df): Get summary statistics
371
+ - analyzer.get_correlation_matrix(df): Get correlation matrix
372
+ - analyzer.get_missing_values(df): Get missing values information
373
+
374
+ When analyzing data:
375
+ 1. First understand and explain the data
376
+ 2. Create relevant visualizations using the viz engine
377
+ 3. Provide insights based on the visualizations
378
+ 4. Ask follow-up questions when appropriate
379
+ 5. Use markdown for formatting
380
+
381
+ Example visualization code:
382
+ ```python
383
+ # Create scatter plot
384
+ html = viz.create_scatter(df, 'column1', 'column2', title='Analysis')
385
+ print(html)
386
+
387
+ # Create line plot with multiple series
388
+ html = viz.create_line(df, 'date_column', ['value1', 'value2'], title='Trends')
389
+ print(html)
390
+
391
+ # Create histogram
392
+ html = viz.create_histogram(df, 'numeric_column', bins=30, title='Distribution')
393
+ print(html)
394
+ ```
395
+ """
396
+
397
+ @staticmethod
398
+ def _extract_code(text: str, keep_markdown: bool = False) -> List[str]:
399
+ """Extract Python code blocks from markdown"""
400
+ import re
401
+ pattern = r'```python\n(.*?)```'
402
+ if keep_markdown:
403
+ return re.split(pattern, text, flags=re.DOTALL)
404
+ return re.findall(pattern, text, re.DOTALL)
405
+
406
+ def create_interface():
407
+ """Create Gradio interface with proper HTML rendering"""
408
+
409
+ agent = AnalysisAgent()
410
+
411
+ def format_html_output(content: str) -> str:
412
+ """Format the output to properly render HTML in Gradio"""
413
+ # Add custom CSS for better visualization display
414
+ css = """
415
+ <style>
416
+ .analysis-text {
417
+ padding: 20px;
418
+ margin: 10px 0;
419
+ background: #f8f9fa;
420
+ border-radius: 8px;
421
+ font-size: 16px;
422
+ }
423
+ .visualization {
424
+ margin: 20px 0;
425
+ padding: 10px;
426
+ border: 1px solid #dee2e6;
427
+ border-radius: 8px;
428
+ background: white;
429
+ }
430
+ .bokeh-plot {
431
+ margin: 0 auto;
432
+ }
433
+ pre {
434
+ background: #f1f3f5;
435
+ padding: 15px;
436
+ border-radius: 5px;
437
+ overflow-x: auto;
438
+ }
439
+ code {
440
+ font-family: 'Courier New', Courier, monospace;
441
+ }
442
+ </style>
443
+ """
444
+ return f"{css}\n{content}"
445
+
446
  def process_file(file: gr.File) -> str:
447
  """Process uploaded file and initialize session"""
448
  try:
 
451
  elif file.name.endswith(('.xlsx', '.xls')):
452
  agent.session.data = pd.read_excel(file.name)
453
  else:
454
+ return format_html_output(
455
+ '<div class="analysis-text">Error: Unsupported file type. Please upload a CSV or Excel file.</div>'
456
+ )
457
+
458
+ # Show initial data summary
459
+ summary = agent.session.get_context()
460
+ return format_html_output(
461
+ f'<div class="analysis-text">Successfully loaded data!\n\n{summary}</div>'
462
+ )
463
  except Exception as e:
464
+ return format_html_output(
465
+ f'<div class="analysis-text">Error loading file: {str(e)}</div>'
466
+ )
467
 
468
+ def analyze(file: gr.File, query: str, api_key: str, chat_history: str) -> tuple:
469
+ """Process analysis query and update chat history"""
470
  if not api_key:
471
+ return (
472
+ format_html_output('<div class="analysis-text">Error: Please provide an API key.</div>'),
473
+ chat_history
474
+ )
475
 
476
  if not file:
477
+ return (
478
+ format_html_output('<div class="analysis-text">Error: Please upload a file.</div>'),
479
+ chat_history
480
+ )
481
 
482
  try:
483
  os.environ["OPENAI_API_KEY"] = api_key
484
  result = agent.process_query(query)
485
+
486
+ # Update chat history
487
+ new_history = chat_history or ""
488
+ new_history += f"\nYou: {query}\nAssistant: {result}\n"
489
+
490
+ return format_html_output(result), new_history
491
+
492
  except Exception as e:
493
+ return (
494
+ format_html_output(f'<div class="analysis-text">Error: {str(e)}</div>'),
495
+ chat_history
496
+ )
497
 
498
+ # Create the Gradio interface
499
  with gr.Blocks(css="""
500
+ .container { max-width: 1200px; margin: auto; }
501
+ .analysis-header { margin-bottom: 20px; }
502
+ .file-upload { margin-bottom: 15px; }
503
  """) as interface:
504
  gr.Markdown("""
505
  # Interactive Data Analysis Assistant
506
 
507
  Upload your data file and chat with the AI to analyze it. Features:
508
+ - Interactive visualizations with zoom, pan, and hover capabilities
509
+ - Natural language analysis and insights
510
+ - Statistical analysis and summaries
511
+ - Trend detection and pattern analysis
512
 
513
  **Note**: Requires OpenAI API key
514
  """)
 
517
  with gr.Column(scale=1):
518
  file = gr.File(
519
  label="Upload Data File",
520
+ file_types=[".csv", ".xlsx", ".xls"],
521
+ elem_classes="file-upload"
522
  )
523
+
524
  api_key = gr.Textbox(
525
+ label="OpenAI API Key",
526
+ type="password",
527
+ placeholder="Enter your API key here"
528
  )
529
+
530
  chat_input = gr.Textbox(
531
  label="Ask about your data",
532
  placeholder="e.g., Show me the relationship between variables",
533
  lines=3
534
  )
535
+
536
+ chat_history = gr.State("")
537
+
538
+ analyze_btn = gr.Button("Analyze", variant="primary")
539
 
540
  with gr.Column(scale=2):
541
+ output = gr.HTML(
542
  label="Analysis & Visualizations",
543
+ elem_classes="analysis-output"
544
  )
545
 
546
  # Set up event handlers
547
+ file.change(
548
+ process_file,
549
+ inputs=[file],
550
+ outputs=[output]
551
+ )
552
+
553
  analyze_btn.click(
554
  analyze,
555
+ inputs=[file, chat_input, api_key, chat_history],
556
+ outputs=[output, chat_history]
557
  )
558
 
559
  # Example queries
560
  gr.Examples(
561
  examples=[
562
+ [None, "Show me the distribution of all numerical variables using histograms"],
563
+ [None, "Create an interactive scatter plot matrix of the main variables"],
564
+ [None, "Analyze trends over time and show them with an interactive line plot"],
565
+ [None, "Compare categories using bar plots and provide statistical insights"],
566
+ [None, "Identify and visualize correlations between numerical variables"],
567
+ [None, "Create a dashboard showing key metrics and their distributions"],
568
  ],
569
  inputs=[file, chat_input]
570
  )
571
+
572
+ # Add footer with information
573
+ gr.Markdown("""
574
+ ### Tips for better analysis:
575
+ 1. Upload clean data in CSV or Excel format
576
+ 2. Be specific in your questions
577
+ 3. Use follow-up questions to dive deeper
578
+ 4. Interact with the visualizations using mouse hover, zoom, and pan
579
+ 5. Look for patterns and trends in the interactive plots
580
+ """)
581
 
582
+ return interface
583
+
584
+ if __name__ == "__main__":
585
+ interface = create_interface()
586
+ interface.launch()