jzou19950715 committed on
Commit
b90c312
·
verified ·
1 Parent(s): eeebaa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -248
app.py CHANGED
@@ -1,237 +1,29 @@
1
- import base64
2
- import io
3
- import os
4
- from dataclasses import dataclass
5
- from typing import Any, Callable, Dict, List, Optional, Union
6
- import json
7
-
8
- import gradio as gr
9
- import numpy as np
10
- import pandas as pd
11
- from bokeh.plotting import figure
12
- from bokeh.layouts import column, row, layout
13
- from bokeh.models import ColumnDataSource, HoverTool, BoxSelectTool, WheelZoomTool, ResetTool
14
- from bokeh.embed import components
15
- from bokeh.resources import CDN
16
- from litellm import completion
17
-
18
- class VisualizationEngine:
19
- """Engine for creating interactive Bokeh visualizations"""
20
-
21
- def __init__(self):
22
- self.width = 600
23
- self.height = 400
24
- self.tools = "pan,box_zoom,wheel_zoom,reset,save,hover"
25
-
26
- def create_scatter(self, df: pd.DataFrame, x_col: str, y_col: str,
27
- color_col: Optional[str] = None, title: str = "") -> str:
28
- """Create an interactive scatter plot"""
29
- source = ColumnDataSource(df)
30
-
31
- p = figure(width=self.width, height=self.height, title=title, tools=self.tools)
32
-
33
- if color_col and color_col in df.columns:
34
- colors = df[color_col].astype('category').cat.codes
35
- scatter = p.scatter(x_col, y_col, source=source, color={'field': color_col, 'transform': 'category10'})
36
- else:
37
- scatter = p.scatter(x_col, y_col, source=source)
38
-
39
- p.xaxis.axis_label = x_col
40
- p.yaxis.axis_label = y_col
41
-
42
- hover = p.select(dict(type=HoverTool))
43
- hover.tooltips = [(col, f"@{col}") for col in [x_col, y_col] + ([color_col] if color_col else [])]
44
-
45
- script, div = components(p)
46
- return f"{CDN.render()}\n{div}\n{script}"
47
-
48
- def create_line(self, df: pd.DataFrame, x_col: str, y_cols: List[str], title: str = "") -> str:
49
- """Create an interactive line plot"""
50
- source = ColumnDataSource(df)
51
-
52
- p = figure(width=self.width, height=self.height, title=title, tools=self.tools)
53
-
54
- for y_col in y_cols:
55
- p.line(x_col, y_col, line_width=2, source=source, legend_label=y_col)
56
-
57
- p.xaxis.axis_label = x_col
58
- p.yaxis.axis_label = "Values"
59
- p.legend.click_policy = "hide"
60
-
61
- hover = p.select(dict(type=HoverTool))
62
- hover.tooltips = [(col, f"@{col}") for col in [x_col] + y_cols]
63
-
64
- script, div = components(p)
65
- return f"{CDN.render()}\n{div}\n{script}"
66
-
67
- def create_bar(self, df: pd.DataFrame, x_col: str, y_col: str, title: str = "") -> str:
68
- """Create an interactive bar plot"""
69
- source = ColumnDataSource(df)
70
-
71
- p = figure(width=self.width, height=self.height, title=title,
72
- tools=self.tools, x_range=df[x_col].unique().tolist())
73
-
74
- p.vbar(x=x_col, top=y_col, width=0.9, source=source)
75
-
76
- p.xaxis.axis_label = x_col
77
- p.yaxis.axis_label = y_col
78
- p.xgrid.grid_line_color = None
79
-
80
- hover = p.select(dict(type=HoverTool))
81
- hover.tooltips = [(x_col, f"@{x_col}"), (y_col, f"@{y_col}")]
82
-
83
- script, div = components(p)
84
- return f"{CDN.render()}\n{div}\n{script}"
85
-
86
- class AnalysisSession:
87
- """Maintains state and history for the analysis session"""
88
-
89
- def __init__(self):
90
- self.data: Optional[pd.DataFrame] = None
91
- self.chat_history: List[Dict[str, str]] = []
92
- self.viz_engine = VisualizationEngine()
93
-
94
- def add_message(self, role: str, content: str):
95
- """Add a message to the chat history"""
96
- self.chat_history.append({"role": role, "content": content})
97
-
98
- def get_context(self) -> str:
99
- """Get the current analysis context"""
100
- if self.data is None:
101
- return "No data loaded yet."
102
-
103
- context = f"""
104
- Current DataFrame Info:
105
- - Shape: {self.data.shape}
106
- - Columns: {', '.join(self.data.columns)}
107
- - Numeric columns: {', '.join(self.data.select_dtypes(include=[np.number]).columns)}
108
- - Categorical columns: {', '.join(self.data.select_dtypes(include=['object', 'category']).columns)}
109
- """
110
- return context
111
-
112
- class AnalysisAgent:
113
- """Enhanced agent with interactive visualization and chat capabilities"""
114
-
115
- def __init__(
116
- self,
117
- model_id: str = "gpt-4o-mini",
118
- temperature: float = 0.7,
119
- ):
120
- self.model_id = model_id
121
- self.temperature = temperature
122
- self.session = AnalysisSession()
123
-
124
- def process_query(self, query: str) -> str:
125
- """Process a user query and generate response with visualizations"""
126
- context = self.session.get_context()
127
-
128
- messages = [
129
- {"role": "system", "content": self._get_system_prompt()},
130
- *self.session.chat_history[-5:], # Include last 5 messages for context
131
- {"role": "user", "content": f"{context}\n\nUser query: {query}"}
132
- ]
133
-
134
- try:
135
- response = completion(
136
- model=self.model_id,
137
- messages=messages,
138
- temperature=self.temperature,
139
- )
140
- analysis = response.choices[0].message.content
141
-
142
- # Extract and execute any code blocks
143
- visualizations = []
144
- code_blocks = self._extract_code(analysis)
145
-
146
- for code in code_blocks:
147
- try:
148
- # Execute code and capture visualization commands
149
- result = self._execute_visualization(code)
150
- if result:
151
- visualizations.append(result)
152
- except Exception as e:
153
- visualizations.append(f"Error creating visualization: {str(e)}")
154
-
155
- # Add messages to chat history
156
- self.session.add_message("user", query)
157
- self.session.add_message("assistant", analysis)
158
-
159
- # Combine analysis and visualizations
160
- return analysis + "\n\n" + "\n".join(visualizations)
161
-
162
- except Exception as e:
163
- return f"Error: {str(e)}"
164
-
165
- def _execute_visualization(self, code: str) -> Optional[str]:
166
- """Execute visualization code and return HTML output"""
167
- try:
168
- # Create a safe namespace with necessary libraries
169
- namespace = {
170
- 'df': self.session.data,
171
- 'np': np,
172
- 'pd': pd,
173
- 'viz': self.session.viz_engine
174
- }
175
-
176
- # Execute the code
177
- exec(code, namespace)
178
-
179
- # Look for visualization result
180
- for var in namespace.values():
181
- if isinstance(var, str) and ('<script' in var or '<div' in var):
182
- return var
183
-
184
- return None
185
-
186
- except Exception as e:
187
- return f"Error executing visualization: {str(e)}"
188
-
189
- def _get_system_prompt(self) -> str:
190
- """Get system prompt with visualization capabilities"""
191
- return """You are a data analysis assistant with interactive visualization capabilities.
192
-
193
- Available visualizations:
194
- 1. Scatter plots (viz.create_scatter)
195
- 2. Line plots (viz.create_line)
196
- 3. Bar plots (viz.create_bar)
197
-
198
- The following variables are available:
199
- - df: pandas DataFrame with the current data
200
- - viz: visualization engine with plotting methods
201
- - np: numpy library
202
- - pd: pandas library
203
-
204
- When analyzing data:
205
- 1. First understand and explain the data
206
- 2. Create relevant visualizations using the viz engine
207
- 3. Provide insights based on the visualizations
208
- 4. Ask follow-up questions when appropriate
209
- 5. Use markdown for formatting
210
-
211
- Example visualization code:
212
- ```python
213
- # Create scatter plot
214
- html = viz.create_scatter(df, 'column1', 'column2', title='Analysis')
215
- print(html)
216
-
217
- # Create line plot
218
- html = viz.create_line(df, 'date_column', ['value1', 'value2'], title='Trends')
219
- print(html)
220
- ```
221
- """
222
-
223
- @staticmethod
224
- def _extract_code(text: str) -> List[str]:
225
- """Extract Python code blocks from markdown"""
226
- import re
227
- pattern = r'```python\n(.*?)```'
228
- return re.findall(pattern, text, re.DOTALL)
229
-
230
  def create_interface():
231
- """Create Gradio interface with chat capabilities"""
232
 
233
  agent = AnalysisAgent()
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  def process_file(file: gr.File) -> str:
236
  """Process uploaded file and initialize session"""
237
  try:
@@ -242,25 +34,29 @@ def create_interface():
242
  else:
243
  return "Error: Unsupported file type"
244
 
245
- return f"Successfully loaded data: {agent.session.get_context()}"
246
  except Exception as e:
247
- return f"Error loading file: {str(e)}"
248
 
249
  def analyze(file: gr.File, query: str, api_key: str) -> str:
250
  """Process analysis query"""
251
  if not api_key:
252
- return "Error: Please provide an API key."
253
 
254
  if not file:
255
- return "Error: Please upload a file."
256
 
257
  try:
258
  os.environ["OPENAI_API_KEY"] = api_key
259
- return agent.process_query(query)
 
260
  except Exception as e:
261
- return f"Error: {str(e)}"
262
 
263
- with gr.Blocks(title="Interactive Data Analysis Assistant") as interface:
 
 
 
264
  gr.Markdown("""
265
  # Interactive Data Analysis Assistant
266
 
@@ -274,7 +70,7 @@ def create_interface():
274
  """)
275
 
276
  with gr.Row():
277
- with gr.Column():
278
  file = gr.File(
279
  label="Upload Data File",
280
  file_types=[".csv", ".xlsx", ".xls"]
@@ -290,8 +86,11 @@ def create_interface():
290
  )
291
  analyze_btn = gr.Button("Analyze")
292
 
293
- with gr.Column():
294
- chat_output = gr.HTML(label="Analysis & Visualizations")
 
 
 
295
 
296
  # Set up event handlers
297
  file.change(process_file, inputs=[file], outputs=[chat_output])
@@ -305,15 +104,11 @@ def create_interface():
305
  gr.Examples(
306
  examples=[
307
  [None, "Show me the distribution of numerical variables"],
308
- [None, "Create a correlation analysis with interactive visualizations"],
309
- [None, "What are the main trends in the data?"],
310
- [None, "Can you identify any interesting patterns?"],
311
  ],
312
  inputs=[file, chat_input]
313
  )
314
 
315
- return interface
316
-
317
- if __name__ == "__main__":
318
- interface = create_interface()
319
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def create_interface():
2
+ """Create Gradio interface with proper HTML rendering"""
3
 
4
  agent = AnalysisAgent()
5
 
6
+ def format_html_output(content: str) -> str:
7
+ """Format the output to properly render HTML in Gradio"""
8
+ # Split content into text and HTML parts
9
+ parts = content.split('<!DOCTYPE html>')
10
+
11
+ if len(parts) == 1:
12
+ # No HTML content
13
+ return f'<div style="padding: 20px;">{content}</div>'
14
+
15
+ formatted_parts = []
16
+ for i, part in enumerate(parts):
17
+ if i == 0:
18
+ # Text content
19
+ if part.strip():
20
+ formatted_parts.append(f'<div style="padding: 20px;">{part}</div>')
21
+ else:
22
+ # HTML visualization
23
+ formatted_parts.append(f'<!DOCTYPE html>{part}')
24
+
25
+ return '\n'.join(formatted_parts)
26
+
27
  def process_file(file: gr.File) -> str:
28
  """Process uploaded file and initialize session"""
29
  try:
 
34
  else:
35
  return "Error: Unsupported file type"
36
 
37
+ return format_html_output(f"Successfully loaded data: {agent.session.get_context()}")
38
  except Exception as e:
39
+ return format_html_output(f"Error loading file: {str(e)}")
40
 
41
  def analyze(file: gr.File, query: str, api_key: str) -> str:
42
  """Process analysis query"""
43
  if not api_key:
44
+ return format_html_output("Error: Please provide an API key.")
45
 
46
  if not file:
47
+ return format_html_output("Error: Please upload a file.")
48
 
49
  try:
50
  os.environ["OPENAI_API_KEY"] = api_key
51
+ result = agent.process_query(query)
52
+ return format_html_output(result)
53
  except Exception as e:
54
+ return format_html_output(f"Error: {str(e)}")
55
 
56
+ with gr.Blocks(css="""
57
+ .plot-container { margin: 20px 0; }
58
+ .bokeh-plot { margin: 20px auto; }
59
+ """) as interface:
60
  gr.Markdown("""
61
  # Interactive Data Analysis Assistant
62
 
 
70
  """)
71
 
72
  with gr.Row():
73
+ with gr.Column(scale=1):
74
  file = gr.File(
75
  label="Upload Data File",
76
  file_types=[".csv", ".xlsx", ".xls"]
 
86
  )
87
  analyze_btn = gr.Button("Analyze")
88
 
89
+ with gr.Column(scale=2):
90
+ chat_output = gr.HTML(
91
+ label="Analysis & Visualizations",
92
+ elem_classes="plot-container"
93
+ )
94
 
95
  # Set up event handlers
96
  file.change(process_file, inputs=[file], outputs=[chat_output])
 
104
  gr.Examples(
105
  examples=[
106
  [None, "Show me the distribution of numerical variables"],
107
+ [None, "Create an interactive visualization of the relationships between variables"],
108
+ [None, "Analyze trends in the data over time"],
109
+ [None, "Compare different categories using interactive charts"],
110
  ],
111
  inputs=[file, chat_input]
112
  )
113
 
114
+ return interface