jzou19950715 committed on
Commit
ada0c12
·
verified ·
1 Parent(s): ee4c228

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -517
app.py CHANGED
@@ -1,586 +1,281 @@
1
- import base64
2
- import io
3
  import os
4
- from dataclasses import dataclass
5
- from typing import Any, Callable, Dict, List, Optional, Union
6
  import json
7
 
8
  import gradio as gr
9
- import numpy as np
10
  import pandas as pd
11
- from bokeh.plotting import figure
12
- from bokeh.layouts import column, row, layout
13
- from bokeh.models import (
14
- ColumnDataSource,
15
- HoverTool,
16
- BoxSelectTool,
17
- WheelZoomTool,
18
- ResetTool,
19
- Legend,
20
- LegendItem
21
- )
22
- from bokeh.embed import file_html
23
- from bokeh.resources import CDN
24
  from litellm import completion
25
 
26
class VisualizationEngine:
    """Engine for creating interactive Bokeh visualizations.

    Each ``create_*`` method renders a single figure and returns it as a
    standalone HTML document string built by ``file_html`` with the CDN
    resources stored on the instance.
    """

    # Colours cycled through by line series and colour-coded scatter groups.
    _PALETTE = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    def __init__(self):
        self.width = 800
        self.height = 500
        self.tools = "pan,box_zoom,wheel_zoom,reset,save,hover"
        self.cdn = CDN

    def _new_figure(self, title: str, **figure_kwargs):
        """Return an empty figure using the engine's size and toolbar."""
        return figure(width=self.width, height=self.height,
                      title=title, tools=self.tools, **figure_kwargs)

    @staticmethod
    def _tooltip(col) -> tuple:
        """Return a ``(label, field)`` hover entry for data column *col*.

        The field is wrapped in braces because Bokeh only resolves column
        names containing spaces when written as ``@{name}``.
        """
        return (str(col), f"@{{{col}}}")

    def _render(self, p, x_label: str, y_label: str, tooltips: list) -> str:
        """Apply the shared title/axis/hover styling and return the HTML."""
        p.title.text_font_size = '16pt'
        p.xaxis.axis_label = x_label
        p.yaxis.axis_label = y_label
        p.axis.axis_label_text_font_size = '12pt'

        # The "hover" entry in self.tools created the HoverTool selected here.
        hover = p.select(dict(type=HoverTool))
        hover.tooltips = tooltips
        hover.mode = 'mouse'

        return file_html(p, self.cdn)

    def create_scatter(self, df: pd.DataFrame, x_col: str, y_col: str,
                       color_col: Optional[str] = None, title: str = "") -> str:
        """Create an interactive scatter plot.

        Args:
            df: Data to plot.
            x_col: Column used for the x axis.
            y_col: Column used for the y axis.
            color_col: Optional column whose distinct values colour the
                points; ignored for colouring when absent from ``df``.
            title: Plot title.

        Returns:
            Standalone HTML for the figure.
        """
        data = df
        color = 'navy'
        if color_col and color_col in df.columns:
            # A colour spec needs a real mapper object; the previous string
            # 'category10' was not a valid transform.
            from bokeh.transform import factor_cmap

            data = df.copy()
            # Categorical mappers operate on string factors.
            data[color_col] = data[color_col].astype(str)
            factors = sorted(data[color_col].unique())
            palette = [self._PALETTE[i % len(self._PALETTE)]
                       for i in range(len(factors))]
            color = factor_cmap(color_col, palette=palette, factors=factors)

        p = self._new_figure(title)
        p.scatter(x_col, y_col, source=ColumnDataSource(data),
                  color=color, size=8, alpha=0.6)

        tooltip_cols = [x_col, y_col] + ([color_col] if color_col else [])
        return self._render(p, x_col, y_col,
                            [self._tooltip(col) for col in tooltip_cols])

    def create_line(self, df: pd.DataFrame, x_col: str, y_cols: List[str],
                    title: str = "") -> str:
        """Create an interactive line plot with one line per entry of *y_cols*.

        Args:
            df: Data to plot.
            x_col: Column used for the x axis.
            y_cols: Columns drawn as separate, legend-toggleable lines.
            title: Plot title.

        Returns:
            Standalone HTML for the figure.
        """
        source = ColumnDataSource(df)
        p = self._new_figure(title)

        for i, y_col in enumerate(y_cols):
            p.line(x_col, y_col, line_width=2, source=source,
                   legend_label=y_col,
                   color=self._PALETTE[i % len(self._PALETTE)])

        # Clicking a legend entry hides the corresponding line.
        p.legend.click_policy = "hide"
        p.legend.location = "top_right"

        return self._render(p, x_col, "Values",
                            [self._tooltip(col) for col in [x_col] + y_cols])

    def create_bar(self, df: pd.DataFrame, x_col: str, y_col: str,
                   title: str = "", color: str = "#1f77b4") -> str:
        """Create an interactive bar plot.

        Args:
            df: Data to plot.
            x_col: Category column (values are rendered as strings).
            y_col: Column giving the bar heights.
            title: Plot title.
            color: Fill colour of the bars.

        Returns:
            Standalone HTML for the figure.
        """
        # A categorical x_range only accepts string factors, so stringify a
        # copy of the category column instead of failing on numeric labels.
        data = df.copy()
        data[x_col] = data[x_col].astype(str)

        p = self._new_figure(title, x_range=data[x_col].unique().tolist())
        p.vbar(x=x_col, top=y_col, width=0.9,
               source=ColumnDataSource(data), color=color, alpha=0.8)

        p.xgrid.grid_line_color = None
        p.xaxis.major_label_orientation = 0.7

        return self._render(p, x_col, y_col,
                            [self._tooltip(x_col), self._tooltip(y_col)])

    def create_histogram(self, df: pd.DataFrame, column: str, bins: int = 30,
                         title: str = "") -> str:
        """Create an interactive histogram of *column* (missing values dropped).

        Args:
            df: Data to plot.
            column: Column whose distribution is shown.
            bins: Number of bins passed to ``np.histogram``.
            title: Plot title.

        Returns:
            Standalone HTML for the figure.
        """
        counts, edges = np.histogram(df[column].dropna(), bins=bins)
        source = ColumnDataSource(pd.DataFrame({
            'count': counts,
            'left': edges[:-1],
            'right': edges[1:],
        }))

        p = self._new_figure(title)
        p.quad(top='count', bottom=0, left='left', right='right',
               source=source, fill_color="#1f77b4", line_color="white",
               alpha=0.8)

        return self._render(p, column, 'Count', [
            ('Range', '@left{0.00} to @right{0.00}'),
            ('Count', '@count'),
        ])
184
-
185
class DataAnalyzer:
    """Helper class for common data analysis tasks."""

    @staticmethod
    def get_summary_stats(df: pd.DataFrame) -> pd.DataFrame:
        """Get summary statistics via ``DataFrame.describe``.

        A frame without columns yields an empty frame instead of the
        ``ValueError`` that ``describe`` raises in that case.
        """
        if df.shape[1] == 0:
            return pd.DataFrame()
        return df.describe()

    @staticmethod
    def get_missing_values(df: pd.DataFrame) -> pd.DataFrame:
        """Get missing-value information for columns that have any.

        Returns:
            A frame with ``column``, ``missing_count`` and
            ``missing_percentage`` (rounded to two decimals), restricted to
            columns with at least one missing value.
        """
        null_counts = df.isnull().sum()  # computed once, reused below
        row_count = len(df)
        if row_count:
            percentages = (null_counts / row_count * 100).round(2)
        else:
            # Avoid 0/0 on an empty frame; nothing can be missing there.
            percentages = null_counts.astype(float)

        missing = pd.DataFrame({
            'column': df.columns,
            'missing_count': null_counts,
            'missing_percentage': percentages,
        })
        return missing[missing['missing_count'] > 0]

    @staticmethod
    def get_correlation_matrix(df: pd.DataFrame) -> pd.DataFrame:
        """Get the correlation matrix of the numerical columns."""
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        return df[numeric_cols].corr()
208
 
209
class AnalysisSession:
    """Maintains state and history for the analysis session."""

    def __init__(self):
        # Loaded dataset; stays None until a file has been processed.
        self.data: Optional[pd.DataFrame] = None
        # Chat turns as {"role": ..., "content": ...} dictionaries.
        self.chat_history: List[Dict[str, str]] = []
        self.viz_engine = VisualizationEngine()
        self.analyzer = DataAnalyzer()

    def add_message(self, role: str, content: str):
        """Append one message to the chat history."""
        self.chat_history.append({"role": role, "content": content})

    def get_context(self) -> str:
        """Describe the loaded data (shape, column kinds, missing values).

        Returns:
            ``"No data loaded yet."`` when no data is present, otherwise a
            multi-line summary string.
        """
        if self.data is None:
            return "No data loaded yet."

        # Column labels may be non-strings (e.g. integers from header-less
        # files); str.join would raise TypeError on those.
        numeric_cols = [
            str(col)
            for col in self.data.select_dtypes(include=[np.number]).columns
        ]
        categorical_cols = [
            str(col)
            for col in self.data.select_dtypes(
                include=['object', 'category']).columns
        ]

        missing_info = self.analyzer.get_missing_values(self.data)
        if missing_info.empty:
            missing_summary = "No missing values found."
        else:
            missing_summary = "\n".join(
                f"- {row['column']}: {row['missing_count']} "
                f"({row['missing_percentage']}%)"
                for _, row in missing_info.iterrows()
            )

        context = f"""
        Current DataFrame Info:
        - Shape: {self.data.shape}
        - Numeric columns: {', '.join(numeric_cols)}
        - Categorical columns: {', '.join(categorical_cols)}

        Missing Values:
        {missing_summary}
        """
        return context
246
-
247
class AnalysisAgent:
    """Enhanced agent with interactive visualization and chat capabilities."""

    def __init__(
        self,
        model_id: str = "gpt-4",
        temperature: float = 0.7,
    ):
        self.model_id = model_id
        self.temperature = temperature
        self.session = AnalysisSession()

    def process_query(self, query: str) -> str:
        """Process a user query and return the response as HTML.

        The model's reply is scanned for ```python blocks; each block is
        executed to obtain a visualization, and the text and visualizations
        are interleaved in their original order.

        Args:
            query: The user's question about the loaded data.

        Returns:
            Formatted HTML, or ``"Error: ..."`` if anything fails.
        """
        context = self.session.get_context()

        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            *self.session.chat_history[-5:],  # last 5 messages for context
            {"role": "user", "content": f"{context}\n\nUser query: {query}"},
        ]

        try:
            response = completion(
                model=self.model_id,
                messages=messages,
                temperature=self.temperature,
            )
            analysis = response.choices[0].message.content

            # One entry per code block (None when a block produced nothing)
            # so that positions stay aligned with the blocks in the text.
            # _execute_visualization already converts failures to messages.
            visualizations = [
                self._execute_visualization(code)
                for code in self._extract_code(analysis)
            ]

            self.session.add_message("user", query)
            self.session.add_message("assistant", analysis)

            return self._format_response(analysis, visualizations)

        except Exception as e:
            return f"Error: {str(e)}"

    def _execute_visualization(self, code: str) -> Optional[str]:
        """Execute visualization code and return its HTML output, if any.

        Returns:
            The first string left in the execution namespace that contains
            ``<script`` or ``<div``, an error message string if execution
            raised, or ``None`` when no such string was produced.
        """
        try:
            namespace = {
                'df': self.session.data,
                'np': np,
                'pd': pd,
                'viz': self.session.viz_engine,
                'analyzer': self.session.analyzer,
            }

            # SECURITY: this runs model-generated code with full interpreter
            # privileges; the namespace above is NOT a sandbox.
            exec(code, namespace)

            for var in namespace.values():
                if isinstance(var, str) and ('<script' in var or '<div' in var):
                    return var

            return None

        except Exception as e:
            return f"Error executing visualization: {str(e)}"

    def _format_response(self, analysis: str, visualizations: List[str]) -> str:
        """Interleave analysis text with the visualizations of its code blocks.

        Args:
            analysis: Model reply containing markdown and ```python blocks.
            visualizations: One entry per code block, in order; falsy entries
                (blocks without output) are skipped.

        Returns:
            HTML with ``analysis-text`` and ``visualization`` ``<div>``s.
        """
        # Splitting on a capturing pattern alternates text (even indices)
        # and code-block contents (odd indices).
        parts = self._extract_code(analysis, keep_markdown=True)

        formatted_parts = []
        for i, part in enumerate(parts):
            if i % 2 == 0:
                formatted_parts.append(f'<div class="analysis-text">{part}</div>')
                continue
            block_index = i // 2
            if block_index < len(visualizations) and visualizations[block_index]:
                viz = visualizations[block_index]
                formatted_parts.append(f'<div class="visualization">{viz}</div>')

        return '\n'.join(formatted_parts)

    def _get_system_prompt(self) -> str:
        """Get system prompt with visualization capabilities."""
        return """You are a data analysis assistant with interactive visualization capabilities.

        Available visualizations:
        1. Scatter plots (viz.create_scatter)
        - x_col: x-axis column name
        - y_col: y-axis column name
        - color_col: optional column for color coding
        - title: plot title

        2. Line plots (viz.create_line)
        - x_col: x-axis column name
        - y_cols: list of column names for multiple lines
        - title: plot title

        3. Bar plots (viz.create_bar)
        - x_col: category column name
        - y_col: value column name
        - title: plot title
        - color: optional bar color

        4. Histograms (viz.create_histogram)
        - column: column to analyze
        - bins: number of bins
        - title: plot title

        Analysis tools:
        - analyzer.get_summary_stats(df): Get summary statistics
        - analyzer.get_correlation_matrix(df): Get correlation matrix
        - analyzer.get_missing_values(df): Get missing values information

        When analyzing data:
        1. First understand and explain the data
        2. Create relevant visualizations using the viz engine
        3. Provide insights based on the visualizations
        4. Ask follow-up questions when appropriate
        5. Use markdown for formatting

        Example visualization code:
        ```python
        # Create scatter plot
        html = viz.create_scatter(df, 'column1', 'column2', title='Analysis')
        print(html)

        # Create line plot with multiple series
        html = viz.create_line(df, 'date_column', ['value1', 'value2'], title='Trends')
        print(html)

        # Create histogram
        html = viz.create_histogram(df, 'numeric_column', bins=30, title='Distribution')
        print(html)
        ```
        """

    @staticmethod
    def _extract_code(text: str, keep_markdown: bool = False) -> List[str]:
        """Extract Python code blocks from markdown.

        Args:
            text: Markdown possibly containing ```python fenced blocks.
            keep_markdown: When true, return the text split around the
                blocks (text and code alternating) instead of only the code.
        """
        import re
        pattern = r'```python\n(.*?)```'
        if keep_markdown:
            return re.split(pattern, text, flags=re.DOTALL)
        return re.findall(pattern, text, re.DOTALL)
405
-
406
def create_interface():
    """Create the Gradio interface with proper HTML rendering.

    Builds a Blocks app with a file upload, API-key box, query box and an
    HTML output pane, wires the upload and "Analyze" button to a single
    shared ``AnalysisAgent``, and returns the (un-launched) interface.
    """
    import html  # used to escape exception text placed into HTML output

    agent = AnalysisAgent()

    def format_html_output(content: str) -> str:
        """Prefix *content* with the CSS used to style analysis output."""
        css = """
        <style>
        .analysis-text {
            padding: 20px;
            margin: 10px 0;
            background: #f8f9fa;
            border-radius: 8px;
            font-size: 16px;
        }
        .visualization {
            margin: 20px 0;
            padding: 10px;
            border: 1px solid #dee2e6;
            border-radius: 8px;
            background: white;
        }
        .bokeh-plot {
            margin: 0 auto;
        }
        pre {
            background: #f1f3f5;
            padding: 15px;
            border-radius: 5px;
            overflow-x: auto;
        }
        code {
            font-family: 'Courier New', Courier, monospace;
        }
        </style>
        """
        return f"{css}\n{content}"

    def process_file(file: gr.File) -> str:
        """Load the uploaded CSV/Excel file into the agent's session."""
        # The change event also fires when the upload is cleared.
        if file is None:
            return format_html_output(
                '<div class="analysis-text">Error: Please upload a file.</div>'
            )
        try:
            if file.name.endswith('.csv'):
                agent.session.data = pd.read_csv(file.name)
            elif file.name.endswith(('.xlsx', '.xls')):
                agent.session.data = pd.read_excel(file.name)
            else:
                return format_html_output(
                    '<div class="analysis-text">Error: Unsupported file type. Please upload a CSV or Excel file.</div>'
                )

            # Show initial data summary
            summary = agent.session.get_context()
            return format_html_output(
                f'<div class="analysis-text">Successfully loaded data!\n\n{summary}</div>'
            )
        except Exception as e:
            # Escape: exception text may contain file contents or markup.
            return format_html_output(
                f'<div class="analysis-text">Error loading file: {html.escape(str(e))}</div>'
            )

    def analyze(file: gr.File, query: str, api_key: str, chat_history: str) -> tuple:
        """Run one analysis query; return (HTML output, updated history)."""
        if not api_key:
            return (
                format_html_output('<div class="analysis-text">Error: Please provide an API key.</div>'),
                chat_history
            )

        if not file:
            return (
                format_html_output('<div class="analysis-text">Error: Please upload a file.</div>'),
                chat_history
            )

        try:
            # NOTE(review): this sets the key process-wide, so it is shared
            # by every user of the running app.
            os.environ["OPENAI_API_KEY"] = api_key
            result = agent.process_query(query)

            new_history = chat_history or ""
            new_history += f"\nYou: {query}\nAssistant: {result}\n"

            return format_html_output(result), new_history

        except Exception as e:
            return (
                format_html_output(f'<div class="analysis-text">Error: {html.escape(str(e))}</div>'),
                chat_history
            )

    with gr.Blocks(css="""
        .container { max-width: 1200px; margin: auto; }
        .analysis-header { margin-bottom: 20px; }
        .file-upload { margin-bottom: 15px; }
    """) as interface:
        gr.Markdown("""
        # Interactive Data Analysis Assistant

        Upload your data file and chat with the AI to analyze it. Features:
        - Interactive visualizations with zoom, pan, and hover capabilities
        - Natural language analysis and insights
        - Statistical analysis and summaries
        - Trend detection and pattern analysis

        **Note**: Requires OpenAI API key
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file = gr.File(
                    label="Upload Data File",
                    file_types=[".csv", ".xlsx", ".xls"],
                    elem_classes="file-upload"
                )

                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="Enter your API key here"
                )

                chat_input = gr.Textbox(
                    label="Ask about your data",
                    placeholder="e.g., Show me the relationship between variables",
                    lines=3
                )

                chat_history = gr.State("")

                analyze_btn = gr.Button("Analyze", variant="primary")

            with gr.Column(scale=2):
                output = gr.HTML(
                    label="Analysis & Visualizations",
                    elem_classes="analysis-output"
                )

        # Set up event handlers
        file.change(
            process_file,
            inputs=[file],
            outputs=[output]
        )

        analyze_btn.click(
            analyze,
            inputs=[file, chat_input, api_key, chat_history],
            outputs=[output, chat_history]
        )

        # Example queries
        gr.Examples(
            examples=[
                [None, "Show me the distribution of all numerical variables using histograms"],
                [None, "Create an interactive scatter plot matrix of the main variables"],
                [None, "Analyze trends over time and show them with an interactive line plot"],
                [None, "Compare categories using bar plots and provide statistical insights"],
                [None, "Identify and visualize correlations between numerical variables"],
                [None, "Create a dashboard showing key metrics and their distributions"],
            ],
            inputs=[file, chat_input]
        )

        # Add footer with information
        gr.Markdown("""
        ### Tips for better analysis:
        1. Upload clean data in CSV or Excel format
        2. Be specific in your questions
        3. Use follow-up questions to dive deeper
        4. Interact with the visualizations using mouse hover, zoom, and pan
        5. Look for patterns and trends in the interactive plots
        """)

    return interface
583
 
584
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    create_interface().launch()
 
 
 
1
  import os
2
+ from typing import List, Optional, Tuple, Dict, Any
 
3
  import json
4
 
5
  import gradio as gr
 
6
  import pandas as pd
7
+ import numpy as np
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from plotly.subplots import make_subplots
 
 
 
 
 
 
 
 
 
11
  from litellm import completion
12
 
13
class DataAnalyzer:
    """Handles data analysis and visualization."""

    # Plot types accepted by create_visualization.
    _PLOT_TYPES = ("scatter", "line", "bar", "histogram", "box", "violin",
                   "correlation")

    def __init__(self):
        # Loaded dataset; stays None until a file has been processed.
        self.data: Optional[pd.DataFrame] = None

    def _check_columns(self, kwargs: dict) -> None:
        """Raise ``ValueError`` if x/y/color name a column that is not loaded.

        Only string names are checked; other values (arrays, non-string
        labels) are passed through to Plotly unchanged.
        """
        for key in ('x', 'y', 'color'):
            value = kwargs.get(key)
            if value is None:
                continue
            names = value if isinstance(value, (list, tuple)) else [value]
            unknown = [name for name in names
                       if isinstance(name, str) and name not in self.data.columns]
            if unknown:
                raise ValueError(
                    f"Unknown column(s) for '{key}': {', '.join(unknown)}"
                )

    def create_visualization(self, plot_type: str, **kwargs) -> "go.Figure":
        """Create different types of plotly visualizations.

        Args:
            plot_type: One of ``scatter``, ``line``, ``bar``, ``histogram``,
                ``box``, ``violin`` or ``correlation``.
            **kwargs: Plot options. ``x``, ``y``, ``color``, ``title`` and
                ``labels`` are used where the plot type supports them;
                ``trendline`` (scatter), ``bins``/``marginal`` (histogram)
                and ``height`` (all, default 500) are also recognised.

        Returns:
            The configured Plotly figure.

        Raises:
            ValueError: If no data is loaded, the plot type is unknown, a
                named column does not exist, or a correlation is requested
                for data without numeric columns.
        """
        if self.data is None:
            raise ValueError("No data loaded")
        if plot_type not in self._PLOT_TYPES:
            raise ValueError(f"Unknown plot type: {plot_type}")

        # Fail early with a readable message instead of a Plotly traceback.
        self._check_columns(kwargs)

        x, y, color = kwargs.get('x'), kwargs.get('y'), kwargs.get('color')

        if plot_type == "scatter":
            fig = px.scatter(
                self.data, x=x, y=y, color=color,
                title=kwargs.get('title', 'Scatter Plot'),
                labels=kwargs.get('labels', {}),
                trendline=kwargs.get('trendline'),
            )
        elif plot_type == "line":
            fig = px.line(
                self.data, x=x, y=y, color=color,
                title=kwargs.get('title', 'Line Plot'),
                labels=kwargs.get('labels', {}),
            )
        elif plot_type == "bar":
            fig = px.bar(
                self.data, x=x, y=y, color=color,
                title=kwargs.get('title', 'Bar Plot'),
                labels=kwargs.get('labels', {}),
            )
        elif plot_type == "histogram":
            fig = px.histogram(
                self.data, x=x,
                nbins=kwargs.get('bins', 30),
                title=kwargs.get('title', 'Histogram'),
                marginal=kwargs.get('marginal', 'box'),
            )
        elif plot_type == "box":
            fig = px.box(
                self.data, x=x, y=y, color=color,
                title=kwargs.get('title', 'Box Plot'),
            )
        elif plot_type == "violin":
            fig = px.violin(
                self.data, x=x, y=y, color=color,
                box=True,
                title=kwargs.get('title', 'Violin Plot'),
            )
        else:  # "correlation"
            numeric = self.data.select_dtypes(include=[np.number])
            if numeric.shape[1] == 0:
                raise ValueError("No numeric columns available for correlation")
            fig = px.imshow(
                numeric.corr(),
                title=kwargs.get('title', 'Correlation Matrix'),
                color_continuous_scale="RdBu",
            )

        # Update layout for better interactivity
        fig.update_layout(
            hovermode='x unified',
            template='plotly_white',
            height=kwargs.get('height', 500),
        )

        return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
class ChatAnalyzer:
    """Handles chat-based analysis with visualization."""

    def __init__(self):
        self.analyzer = DataAnalyzer()
        # (user message, assistant response) pairs shown in the chat widget.
        self.chat_history: List[Tuple[str, str]] = []

    def process_file(self, file: "gr.File") -> str:
        """Load an uploaded CSV/Excel file into the analyzer.

        Args:
            file: Uploaded file object exposing a ``name`` path, or ``None``
                when the upload was cleared.

        Returns:
            A status message: a load summary on success, otherwise an
            ``"Error ..."`` string.
        """
        if file is None:
            return "Error: Please upload a CSV or Excel file."

        try:
            # Match extensions case-insensitively (e.g. "DATA.CSV").
            lowered = file.name.lower()
            if lowered.endswith('.csv'):
                self.analyzer.data = pd.read_csv(file.name)
            elif lowered.endswith(('.xlsx', '.xls')):
                self.analyzer.data = pd.read_excel(file.name)
            else:
                return "Error: Please upload a CSV or Excel file."

            # Column labels may be non-strings; join needs str.
            info = f"""
            Successfully loaded data with shape: {self.analyzer.data.shape}
            Columns: {', '.join(map(str, self.analyzer.data.columns))}
            """
            return info

        except Exception as e:
            return f"Error loading file: {str(e)}"

    def analyze(self, message: str, api_key: str) -> Tuple[str, List["go.Figure"]]:
        """Analyze the loaded data based on a user message.

        Args:
            message: The user's request.
            api_key: OpenAI API key used for this request only.

        Returns:
            ``(text, figures)``: the model's reply (or an error/instruction
            message) and the figures created from its code blocks.
        """
        if self.analyzer.data is None:
            return "Please upload a data file first.", []

        if not api_key:
            return "Please provide an OpenAI API key.", []

        try:
            context = self._get_data_context()

            messages = [
                {"role": "system", "content": self._get_system_prompt()},
                {"role": "user", "content": f"{context}\n\nUser request: {message}"}
            ]

            # The key is passed per call rather than written to os.environ,
            # so it does not leak to other sessions of the same process.
            response = completion(
                model="gpt-4o-mini",
                messages=messages,
                temperature=0.7,
                api_key=api_key,
            )

            analysis = response.choices[0].message.content

            figures = self._create_visualizations(analysis)

            return analysis, figures

        except Exception as e:
            return f"Error during analysis: {str(e)}", []

    def _get_data_context(self) -> str:
        """Summarise the loaded data (shape, column kinds, describe())."""
        df = self.analyzer.data
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns

        return f"""
        Available Data Information:
        - Shape: {df.shape}
        - Numeric columns: {', '.join(map(str, numeric_cols))}
        - Categorical columns: {', '.join(map(str, categorical_cols))}

        Basic Statistics:
        {df.describe().to_string()}
        """

    def _get_system_prompt(self) -> str:
        """Get system prompt."""
        return """You are a data analysis assistant specialized in creating interactive visualizations using Plotly.

        Available plot types:
        1. scatter - for relationships between variables
        2. line - for trends over time
        3. bar - for comparisons between categories
        4. histogram - for distributions
        5. box - for statistical summaries
        6. violin - for distribution comparisons
        7. correlation - for correlation matrix

        When creating visualizations:
        1. Specify the plot type and required parameters
        2. Provide insights about the visualization
        3. Suggest follow-up analyses
        4. Use markdown for formatting

        Example command format:
        ```python
        # Create scatter plot
        plot = viz.create_visualization("scatter", x="column1", y="column2", title="Analysis")
        print(plot)
        ```
        """

    def _create_visualizations(self, analysis: str) -> List["go.Figure"]:
        """Run the ```python blocks in *analysis* and collect their figures.

        Inside the executed code ``viz`` is the analyzer and ``print`` is
        replaced by a collector that keeps every ``go.Figure`` argument.
        Each block runs independently, so a failing block does not stop
        later ones.
        """
        import re

        figures = []

        def capture(*values, **_kwargs):
            # Stand-in for print(): keep figures, drop everything else.
            figures.extend(v for v in values if isinstance(v, go.Figure))

        exec_globals = {
            'viz': self.analyzer,
            'print': capture,
        }

        code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)

        for code in code_blocks:
            try:
                # SECURITY: executes model-generated code with full
                # interpreter privileges; exec_globals is NOT a sandbox.
                exec(code, exec_globals)
            except Exception as e:
                print(f"Error creating visualizations: {str(e)}")

        return figures
220
+
221
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks app (file upload, API-key box, chatbot, message box and
    a plot area) backed by one shared ``ChatAnalyzer`` and returns the
    un-launched app.
    """
    analyzer = ChatAnalyzer()

    def load_file(file):
        """Load the upload and report the status as a chat entry."""
        status = analyzer.process_file(file)
        # The output component is a Chatbot, which needs the list of
        # message pairs rather than a bare string.
        analyzer.chat_history.append((None, status))
        return analyzer.chat_history

    def chat(message: str, api_key: str):
        """Handle one chat turn; return (chat history, plot update)."""
        response, figures = analyzer.analyze(message, api_key)

        analyzer.chat_history.append((message, response))

        # plot_output is a single, initially hidden Plot component: show
        # the first figure when there is one, otherwise hide the area.
        # NOTE(review): additional figures from the same reply are dropped.
        if figures:
            plot_update = gr.update(value=figures[0], visible=True)
        else:
            plot_update = gr.update(visible=False)

        return analyzer.chat_history, plot_update

    with gr.Blocks() as demo:
        gr.Markdown("""
        # Interactive Data Analysis Chat
        Upload your data and chat with AI to create interactive visualizations!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file = gr.File(label="Upload Data (CSV or Excel)")
                api_key = gr.Textbox(label="OpenAI API Key", type="password")

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=400)

        with gr.Row():
            message = gr.Textbox(label="Ask about your data", lines=2)
            send = gr.Button("Send")

        # Plot output area
        plot_output = gr.Plot(visible=False)

        # Set up event handlers
        file.change(load_file, inputs=[file], outputs=[chatbot])
        send.click(
            chat,
            inputs=[message, api_key],
            outputs=[chatbot, plot_output]
        )

        gr.Examples(
            examples=[
                ["Show me a scatter plot of the main numerical variables"],
                ["Create a correlation matrix of all numerical columns"],
                ["Analyze the distribution of each variable"],
                ["Show trends over time if there's temporal data"],
            ],
            inputs=message
        )

    return demo
278
 
279
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    create_interface().launch()