jzou19950715 committed on
Commit
9bb3afe
·
verified ·
1 Parent(s): 927667b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -180
app.py CHANGED
@@ -1,17 +1,14 @@
1
- import base64
2
- import io
3
  import os
4
- from dataclasses import dataclass
5
  from typing import List, Optional, Tuple, Dict, Any
6
- import json
 
7
 
8
  import gradio as gr
9
  import pandas as pd
10
  import numpy as np
11
- import plotly.express as px
12
  import plotly.graph_objects as go
13
- from plotly.subplots import make_subplots
14
  from litellm import completion
 
15
  class DataAnalyzer:
16
  """Handles data analysis and visualization"""
17
 
@@ -42,7 +39,6 @@ class DataAnalyzer:
42
  template="plotly_white"
43
  )
44
 
45
- # Convert to HTML string
46
  return fig.to_html(include_plotlyjs=True, full_html=False)
47
 
48
  def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
@@ -89,12 +85,11 @@ class DataAnalyzer:
89
 
90
  fig = go.Figure()
91
 
92
- # Create box plot for each category
93
  for category in self.data[x_col].unique():
94
  fig.add_trace(go.Box(
95
  y=self.data[self.data[x_col] == category][y_col],
96
  name=str(category),
97
- boxpoints='all', # show all points
98
  jitter=0.3,
99
  pointpos=-1.8
100
  ))
@@ -189,65 +184,65 @@ class ChatAnalyzer:
189
  return self.history
190
 
191
  def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
192
- """Process chat message and generate visualizations"""
193
- if self.analyzer.data is None:
194
- return [(message, "Please upload a data file first.")], ""
195
-
196
- if not api_key:
197
- return [(message, "Please provide an OpenAI API key.")], ""
198
-
199
- try:
200
- os.environ["OPENAI_API_KEY"] = api_key
201
-
202
- # Get data context
203
- context = self._get_data_context()
204
-
205
- # Get AI response
206
- completion_response = completion(
207
- model="gpt-4o-mini",
208
- messages=[
209
- {"role": "system", "content": self._get_system_prompt()},
210
- {"role": "user", "content": f"{context}\n\nUser question: {message}"}
211
- ],
212
- temperature=0.7
213
- )
214
-
215
- analysis = completion_response.choices[0].message.content
216
-
217
- # Create visualizations
218
- plots_html = ""
219
  try:
220
- # Extract code blocks
221
- import re
222
- code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- for code in code_blocks:
225
- # Create namespace for execution
226
- namespace = {
227
- 'analyzer': self.analyzer,
228
- 'df': self.analyzer.data,
229
- 'print': lambda x: x
230
- }
231
 
232
- # Execute the code
233
- try:
234
- result = eval(code, namespace)
235
- if isinstance(result, str) and ('<div' in result or '<script' in result):
236
- plots_html += f'<div class="plot-container">{result}</div>'
237
- except:
238
- exec(code, namespace)
 
 
 
 
 
 
 
 
239
 
 
 
 
 
 
 
 
 
240
  except Exception as e:
241
- analysis += f"\n\nError creating visualization: {str(e)}"
242
-
243
- # Update chat history
244
- self.history.append((message, analysis))
245
-
246
- return self.history, plots_html
247
-
248
- except Exception as e:
249
- self.history.append((message, f"Error: {str(e)}"))
250
- return self.history, ""
251
 
252
  def _get_data_context(self) -> str:
253
  """Get current data context for AI"""
@@ -270,17 +265,15 @@ class ChatAnalyzer:
270
  {stats}
271
 
272
  Available visualization functions:
273
- - analyzer.create_scatter(x_col, y_col, color_col, title)
274
- - analyzer.create_line(x_col, y_cols, title)
275
- - analyzer.create_bar(x_col, y_col, color_col, title)
276
  - analyzer.create_histogram(column, bins, title)
277
- - analyzer.create_box(x_col, y_col, color_col, title)
278
- - analyzer.create_correlation_matrix(title)
 
279
  """
280
 
281
  def _get_system_prompt(self) -> str:
282
- """Get system prompt for AI"""
283
- return """You are a data analysis assistant specialized in creating interactive visualizations.
284
 
285
  Available visualization functions:
286
  1. create_histogram(column, bins, title) - For distribution analysis
@@ -298,7 +291,7 @@ result = analyzer.create_histogram(
298
  )
299
  print(result)
300
 
301
- # Create scatter plot
302
  result = analyzer.create_scatter(
303
  x_col='Date',
304
  y_col='Salary',
@@ -316,7 +309,7 @@ result = analyzer.create_box(
316
  print(result)
317
  ```
318
 
319
- Always wrap code in Python code blocks and use print() to display the visualizations.
320
  Provide analysis and insights about what the visualizations show."""
321
 
322
  def create_interface():
@@ -324,7 +317,7 @@ def create_interface():
324
 
325
  analyzer = ChatAnalyzer()
326
 
327
- # Custom CSS for better visualization display
328
  css = """
329
  .container { max-width: 1200px; margin: auto; }
330
  .plot-container {
@@ -341,157 +334,75 @@ def create_interface():
341
  border-radius: 8px;
342
  background: #f8f9fa;
343
  }
344
- .title {
345
- text-align: center;
346
- margin-bottom: 20px;
347
- }
348
- .footer {
349
- text-align: center;
350
- margin-top: 20px;
351
- font-size: 0.9em;
352
- color: #666;
353
- }
354
  """
355
 
356
- with gr.Blocks(css=css, title="Interactive Data Analysis Chat") as demo:
357
  gr.Markdown("""
358
  # Interactive Data Analysis Chat
359
 
360
  Upload your data and chat with AI to analyze it! Features:
361
- - Interactive visualizations with zoom, pan, and hover capabilities
362
- - Natural language analysis and insights
363
- - Statistical analysis and summaries
364
- - Trend detection and pattern analysis
365
-
366
- Start by uploading a CSV or Excel file.
367
  """)
368
 
369
  with gr.Row():
370
  with gr.Column(scale=1):
371
  file = gr.File(
372
  label="Upload Data (CSV or Excel)",
373
- file_types=[".csv", ".xlsx", ".xls"],
374
- elem_classes="file-upload"
375
  )
376
  api_key = gr.Textbox(
377
  label="OpenAI API Key",
378
  type="password",
379
- placeholder="Enter your API key",
380
- elem_classes="api-input"
381
  )
382
 
383
  with gr.Column(scale=2):
384
  chatbot = gr.Chatbot(
385
  height=400,
386
- elem_classes="chat-message",
387
- show_label=False
 
 
 
 
388
  )
389
- with gr.Row():
390
- message = gr.Textbox(
391
- label="Ask about your data",
392
- placeholder="e.g., Show me trends in the data",
393
- lines=2,
394
- elem_classes="message-input",
395
- scale=4
396
- )
397
- send = gr.Button(
398
- "Send",
399
- scale=1,
400
- elem_classes="send-button"
401
- )
402
 
403
  # Plot output area
404
  plot_output = gr.HTML(
405
  label="Visualizations",
406
- elem_classes="plot-container",
407
- visible=True # Always show container even when empty
408
  )
409
 
410
  # Event handlers
411
  file.change(
412
- fn=analyzer.process_file,
413
  inputs=[file],
414
- outputs=[chatbot],
415
- api_name="upload"
416
- )
417
-
418
- # Handle both click and enter key
419
- msg_handler = send.click(
420
- fn=analyzer.chat,
421
- inputs=[message, api_key],
422
- outputs=[chatbot, plot_output],
423
- api_name="chat"
424
  )
425
 
426
- message.submit(
427
- fn=analyzer.chat,
428
  inputs=[message, api_key],
429
  outputs=[chatbot, plot_output]
430
  )
431
 
432
- # Clear message after sending
433
- msg_handler.then(
434
- fn=lambda: "",
435
- inputs=[],
436
- outputs=[message]
437
- )
438
-
439
  # Example queries
440
  gr.Examples(
441
  examples=[
442
- ["Show me a scatter plot of the main numerical variables and explain any patterns you see"],
443
- ["Create a correlation analysis with heatmap and highlight the strongest relationships"],
444
- ["Show the distribution of values using histograms and describe the shapes"],
445
- ["Create box plots to analyze categories and identify any outliers"],
446
- ["Show trends over time using line plots and explain the patterns"],
447
- ["Generate a comprehensive analysis with multiple visualizations"],
448
- ["Compare the distribution across different categories using appropriate plots"],
449
- ["Identify and visualize any seasonal patterns or cycles in the data"],
450
  ],
451
- inputs=[message],
452
- label="Example Analysis Queries"
453
- )
454
-
455
- # Tips section
456
- gr.Markdown("""
457
- ### Tips for better analysis:
458
- 1. **Data Preparation**: Upload clean CSV or Excel files with clear column names
459
- 2. **Specific Questions**: Ask clear, specific questions about your data
460
- 3. **Interactive Features**: Use zoom, pan, and hover on visualizations
461
- 4. **Follow-up Questions**: Ask for deeper analysis of interesting patterns
462
- 5. **Multiple Views**: Request different visualization types for better insights
463
-
464
- ### Available Visualization Types:
465
- - Scatter plots for relationships
466
- - Line plots for trends
467
- - Bar charts for comparisons
468
- - Histograms for distributions
469
- - Box plots for statistical summaries
470
- - Correlation matrices for relationship analysis
471
- """)
472
-
473
- # Footer
474
- gr.Markdown("""
475
- <div class="footer">
476
- Built with Gradio • Powered by OpenAI • Interactive Visualizations
477
- </div>
478
- """)
479
-
480
- # Theme customization
481
- demo.theme = gr.themes.Soft(
482
- primary_hue="blue",
483
- secondary_hue="gray",
484
- neutral_hue="gray",
485
- text_size=gr.themes.sizes.text_md
486
  )
487
 
488
  return demo
489
 
490
  if __name__ == "__main__":
491
  demo = create_interface()
492
- demo.launch(
493
- share=False, # Set to True to create a public link
494
- debug=True, # Set to False in production
495
- show_error=True, # Show detailed error messages
496
- server_port=7860 # Specify port number
497
- )
 
 
 
1
  import os
 
2
  from typing import List, Optional, Tuple, Dict, Any
3
+ import base64
4
+ import io
5
 
6
  import gradio as gr
7
  import pandas as pd
8
  import numpy as np
 
9
  import plotly.graph_objects as go
 
10
  from litellm import completion
11
+
12
  class DataAnalyzer:
13
  """Handles data analysis and visualization"""
14
 
 
39
  template="plotly_white"
40
  )
41
 
 
42
  return fig.to_html(include_plotlyjs=True, full_html=False)
43
 
44
  def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
 
85
 
86
  fig = go.Figure()
87
 
 
88
  for category in self.data[x_col].unique():
89
  fig.add_trace(go.Box(
90
  y=self.data[self.data[x_col] == category][y_col],
91
  name=str(category),
92
+ boxpoints='all',
93
  jitter=0.3,
94
  pointpos=-1.8
95
  ))
 
184
  return self.history
185
 
186
  def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
187
+ """Process chat message and generate visualizations"""
188
+ if self.analyzer.data is None:
189
+ return [(message, "Please upload a data file first.")], ""
190
+
191
+ if not api_key:
192
+ return [(message, "Please provide an OpenAI API key.")], ""
193
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  try:
195
+ os.environ["OPENAI_API_KEY"] = api_key
196
+
197
+ # Get data context
198
+ context = self._get_data_context()
199
+
200
+ # Get AI response
201
+ completion_response = completion(
202
+ model="gpt-4o-mini",
203
+ messages=[
204
+ {"role": "system", "content": self._get_system_prompt()},
205
+ {"role": "user", "content": f"{context}\n\nUser question: {message}"}
206
+ ],
207
+ temperature=0.7
208
+ )
209
+
210
+ analysis = completion_response.choices[0].message.content
211
 
212
+ # Create visualizations
213
+ plots_html = ""
214
+ try:
215
+ # Extract code blocks
216
+ import re
217
+ code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
 
218
 
219
+ for code in code_blocks:
220
+ # Create namespace for execution
221
+ namespace = {
222
+ 'analyzer': self.analyzer,
223
+ 'df': self.analyzer.data,
224
+ 'print': lambda x: x
225
+ }
226
+
227
+ # Execute the code
228
+ try:
229
+ result = eval(code, namespace)
230
+ if isinstance(result, str) and ('<div' in result or '<script' in result):
231
+ plots_html += f'<div class="plot-container">{result}</div>'
232
+ except:
233
+ exec(code, namespace)
234
 
235
+ except Exception as e:
236
+ analysis += f"\n\nError creating visualization: {str(e)}"
237
+
238
+ # Update chat history
239
+ self.history.append((message, analysis))
240
+
241
+ return self.history, plots_html
242
+
243
  except Exception as e:
244
+ self.history.append((message, f"Error: {str(e)}"))
245
+ return self.history, ""
 
 
 
 
 
 
 
 
246
 
247
  def _get_data_context(self) -> str:
248
  """Get current data context for AI"""
 
265
  {stats}
266
 
267
  Available visualization functions:
 
 
 
268
  - analyzer.create_histogram(column, bins, title)
269
+ - analyzer.create_scatter(x_col, y_col, color_col, title)
270
+ - analyzer.create_box(x_col, y_col, title)
271
+ - analyzer.create_line(x_col, y_col, color_col, title)
272
  """
273
 
274
  def _get_system_prompt(self) -> str:
275
+ """Get system prompt for AI"""
276
+ return """You are a data analysis assistant specialized in creating interactive visualizations.
277
 
278
  Available visualization functions:
279
  1. create_histogram(column, bins, title) - For distribution analysis
 
291
  )
292
  print(result)
293
 
294
+ # Create scatter plot with time series
295
  result = analyzer.create_scatter(
296
  x_col='Date',
297
  y_col='Salary',
 
309
  print(result)
310
  ```
311
 
312
+ Always wrap code in Python code blocks and use print() to display the visualizations.
313
  Provide analysis and insights about what the visualizations show."""
314
 
315
  def create_interface():
 
317
 
318
  analyzer = ChatAnalyzer()
319
 
320
+ # Custom CSS
321
  css = """
322
  .container { max-width: 1200px; margin: auto; }
323
  .plot-container {
 
334
  border-radius: 8px;
335
  background: #f8f9fa;
336
  }
 
 
 
 
 
 
 
 
 
 
337
  """
338
 
339
+ with gr.Blocks(css=css) as demo:
340
  gr.Markdown("""
341
  # Interactive Data Analysis Chat
342
 
343
  Upload your data and chat with AI to analyze it! Features:
344
+ - Interactive visualizations
345
+ - Natural language analysis
346
+ - Statistical insights
347
+ - Trend detection
 
 
348
  """)
349
 
350
  with gr.Row():
351
  with gr.Column(scale=1):
352
  file = gr.File(
353
  label="Upload Data (CSV or Excel)",
354
+ file_types=[".csv", ".xlsx", ".xls"]
 
355
  )
356
  api_key = gr.Textbox(
357
  label="OpenAI API Key",
358
  type="password",
359
+ placeholder="Enter your API key"
 
360
  )
361
 
362
  with gr.Column(scale=2):
363
  chatbot = gr.Chatbot(
364
  height=400,
365
+ elem_classes="chat-message"
366
+ )
367
+ message = gr.Textbox(
368
+ label="Ask about your data",
369
+ placeholder="e.g., Show me trends in the data",
370
+ lines=2
371
  )
372
+ send = gr.Button("Send")
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  # Plot output area
375
  plot_output = gr.HTML(
376
  label="Visualizations",
377
+ elem_classes="plot-container"
 
378
  )
379
 
380
  # Event handlers
381
  file.change(
382
+ analyzer.process_file,
383
  inputs=[file],
384
+ outputs=[chatbot]
 
 
 
 
 
 
 
 
 
385
  )
386
 
387
+ send.click(
388
+ analyzer.chat,
389
  inputs=[message, api_key],
390
  outputs=[chatbot, plot_output]
391
  )
392
 
 
 
 
 
 
 
 
393
  # Example queries
394
  gr.Examples(
395
  examples=[
396
+ ["Show me a histogram of salary distribution"],
397
+ ["Create a scatter plot of salary trends over time"],
398
+ ["Show me box plots of salaries by title"],
399
+ ["Analyze the trends and patterns in the data"],
 
 
 
 
400
  ],
401
+ inputs=message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  )
403
 
404
  return demo
405
 
406
  if __name__ == "__main__":
407
  demo = create_interface()
408
+ demo.launch()