jzou19950715 commited on
Commit
927667b
·
verified ·
1 Parent(s): 224e4cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -220
app.py CHANGED
@@ -12,7 +12,6 @@ import plotly.express as px
12
  import plotly.graph_objects as go
13
  from plotly.subplots import make_subplots
14
  from litellm import completion
15
-
16
  class DataAnalyzer:
17
  """Handles data analysis and visualization"""
18
 
@@ -20,133 +19,133 @@ class DataAnalyzer:
20
  self.data: Optional[pd.DataFrame] = None
21
  self.width = 800
22
  self.height = 500
23
- self.template = "plotly_white"
24
-
25
- def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
26
- title: str = "") -> go.Figure:
27
- """Create scatter plot"""
28
- fig = px.scatter(
29
- self.data,
30
- x=x_col,
31
- y=y_col,
32
- color=color_col,
 
 
 
 
 
33
  title=title,
34
- template=self.template,
 
 
35
  height=self.height,
36
- width=self.width
37
  )
38
- fig.update_layout(
39
- hovermode='closest',
40
- showlegend=True if color_col else False
41
- )
42
- return fig
43
 
44
- def create_line(self, x_col: str, y_cols: List[str], title: str = "") -> go.Figure:
45
- """Create line plot"""
 
 
 
 
46
  fig = go.Figure()
47
 
48
- for y_col in y_cols:
49
- fig.add_trace(
50
- go.Scatter(
51
- x=self.data[x_col],
52
- y=self.data[y_col],
53
- name=y_col,
54
- mode='lines+markers'
55
- )
56
- )
 
 
 
 
 
 
 
57
 
58
  fig.update_layout(
59
  title=title,
60
- template=self.template,
61
- height=self.height,
62
  width=self.width,
63
- hovermode='x unified',
64
- showlegend=True
65
- )
66
- return fig
67
-
68
- def create_bar(self, x_col: str, y_col: str, color_col: Optional[str] = None,
69
- title: str = "") -> go.Figure:
70
- """Create bar plot"""
71
- fig = px.bar(
72
- self.data,
73
- x=x_col,
74
- y=y_col,
75
- color=color_col,
76
- title=title,
77
- template=self.template,
78
  height=self.height,
79
- width=self.width
80
- )
81
- fig.update_layout(
82
- hovermode='closest',
83
- showlegend=True if color_col else False
84
  )
85
- return fig
 
86
 
87
- def create_histogram(self, column: str, bins: int = 30,
88
- title: str = "") -> go.Figure:
89
- """Create histogram"""
90
- fig = px.histogram(
91
- self.data,
92
- x=column,
93
- nbins=bins,
 
 
 
 
 
 
 
 
 
 
 
94
  title=title,
95
- template=self.template,
96
- height=self.height,
97
  width=self.width,
98
- marginal="box" # Add box plot on the margin
99
- )
100
- fig.update_layout(
101
- hovermode='closest',
102
  showlegend=False
103
  )
104
- return fig
 
105
 
106
- def create_box(self, x_col: str, y_col: str, color_col: Optional[str] = None,
107
- title: str = "") -> go.Figure:
108
- """Create box plot"""
109
- fig = px.box(
110
- self.data,
111
- x=x_col,
112
- y=y_col,
113
- color=color_col,
114
- title=title,
115
- template=self.template,
116
- height=self.height,
117
- width=self.width,
118
- points="all" # Show all points
119
- )
 
 
 
 
 
 
 
 
 
 
120
  fig.update_layout(
121
- hovermode='closest',
122
- showlegend=True if color_col else False
123
- )
124
- return fig
125
-
126
- def create_correlation_matrix(self, title: str = "") -> go.Figure:
127
- """Create correlation matrix"""
128
- # Get numeric columns
129
- numeric_cols = self.data.select_dtypes(include=[np.number]).columns
130
- corr_matrix = self.data[numeric_cols].corr()
131
-
132
- fig = px.imshow(
133
- corr_matrix,
134
  title=title,
135
- template=self.template,
136
- height=self.height,
137
  width=self.width,
138
- color_continuous_scale="RdBu",
139
- aspect="auto",
140
- labels=dict(color="Correlation")
141
  )
142
 
143
- # Update layout for better readability
144
- fig.update_traces(text=corr_matrix.round(2), texttemplate="%{text}")
145
- fig.update_layout(
146
- xaxis_title="",
147
- yaxis_title=""
148
- )
149
- return fig
150
 
151
  class ChatAnalyzer:
152
  """Handles chat-based analysis with visualization"""
@@ -190,94 +189,65 @@ class ChatAnalyzer:
190
  return self.history
191
 
192
  def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
193
- """Process chat message and generate visualizations"""
194
- if self.analyzer.data is None:
195
- return [(message, "Please upload a data file first.")], ""
196
-
197
- if not api_key:
198
- return [(message, "Please provide an OpenAI API key.")], ""
199
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  try:
201
- os.environ["OPENAI_API_KEY"] = api_key
202
-
203
- # Get data context
204
- context = self._get_data_context()
205
-
206
- # Get AI response
207
- completion_response = completion(
208
- model="gpt-4o-mini",
209
- messages=[
210
- {"role": "system", "content": self._get_system_prompt()},
211
- {"role": "user", "content": f"{context}\n\nUser question: {message}"}
212
- ],
213
- temperature=0.7
214
- )
215
-
216
- analysis = completion_response.choices[0].message.content
217
 
218
- # Create visualizations
219
- plot_output = ""
220
- try:
221
- # Extract code blocks
222
- import re
223
- code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
 
224
 
225
- for code in code_blocks:
226
- # Create namespace for execution
227
- namespace = {
228
- 'df': self.analyzer.data,
229
- 'px': px,
230
- 'go': go,
231
- 'pd': pd,
232
- 'np': np,
233
- 'analyzer': self.analyzer
234
- }
235
-
236
- # Execute the code
237
  exec(code, namespace)
238
-
239
- # Look for figure object in namespace
240
- for var in namespace.values():
241
- if isinstance(var, (go.Figure, px.Figure)):
242
- try:
243
- # Try interactive HTML first
244
- html = var.to_html(
245
- include_plotlyjs=True,
246
- full_html=False,
247
- config={
248
- 'displayModeBar': True,
249
- 'responsive': True
250
- }
251
- )
252
- plot_output += f'''
253
- <div class="plot-container">
254
- <div style="overflow-x: auto;">{html}</div>
255
- </div>
256
- '''
257
- except Exception as e:
258
- # Fallback to static image
259
- buffer = io.BytesIO()
260
- var.write_image(buffer, format='png')
261
- buffer.seek(0)
262
- image = base64.b64encode(buffer.read()).decode()
263
- plot_output += f'''
264
- <div class="plot-container">
265
- <img src="data:image/png;base64,{image}"
266
- style="max-width: 100%; height: auto;">
267
- </div>
268
- '''
269
 
270
- except Exception as e:
271
- analysis += f"\n\nError creating visualization: {str(e)}"
272
-
273
- # Update chat history
274
- self.history.append((message, analysis))
275
-
276
- return self.history, plot_output
277
-
278
  except Exception as e:
279
- self.history.append((message, f"Error: {str(e)}"))
280
- return self.history, ""
 
 
 
 
 
 
 
 
281
 
282
  def _get_data_context(self) -> str:
283
  """Get current data context for AI"""
@@ -309,51 +279,45 @@ class ChatAnalyzer:
309
  """
310
 
311
  def _get_system_prompt(self) -> str:
312
- """Get system prompt for AI"""
313
- return """You are a data analysis assistant specialized in creating interactive visualizations.
314
 
315
  Available visualization functions:
316
- 1. analyzer.create_scatter(x_col, y_col, color_col, title)
317
- 2. analyzer.create_line(x_col, y_cols, title)
318
- 3. analyzer.create_bar(x_col, y_col, color_col, title)
319
- 4. analyzer.create_histogram(column, bins, title)
320
- 5. analyzer.create_box(x_col, y_col, color_col, title)
321
- 6. analyzer.create_correlation_matrix(title)
322
-
323
- When analyzing data:
324
- 1. First understand the data type and relationships
325
- 2. Choose appropriate visualizations
326
- 3. Provide insights and analysis
327
- 4. Suggest follow-up analyses
328
 
329
  Example usage:
330
  ```python
331
- # Create scatter plot
332
- fig = analyzer.create_scatter(
333
- x_col='Date',
334
- y_col='Value',
335
- color_col='Category',
336
- title='Value Trends by Category'
337
  )
338
- print(fig)
339
 
340
- # Create multiple visualizations
341
- fig1 = analyzer.create_histogram(
342
- column='Value',
343
- bins=30,
344
- title='Value Distribution'
 
345
  )
346
- print(fig1)
347
 
348
- fig2 = analyzer.create_box(
349
- x_col='Category',
350
- y_col='Value',
351
- title='Value Distribution by Category'
 
352
  )
353
- print(fig2)
354
  ```
355
 
356
- Always wrap code in Python code blocks and print the figures to display them."""
 
357
 
358
  def create_interface():
359
  """Create Gradio interface"""
 
12
  import plotly.graph_objects as go
13
  from plotly.subplots import make_subplots
14
  from litellm import completion
 
15
  class DataAnalyzer:
16
  """Handles data analysis and visualization"""
17
 
 
19
  self.data: Optional[pd.DataFrame] = None
20
  self.width = 800
21
  self.height = 500
22
+
23
+ def create_histogram(self, column: str, bins: int = 30, title: str = "") -> str:
24
+ """Create histogram with Plotly"""
25
+ if self.data is None:
26
+ raise ValueError("No data loaded")
27
+
28
+ fig = go.Figure()
29
+
30
+ fig.add_trace(go.Histogram(
31
+ x=self.data[column],
32
+ nbinsx=bins,
33
+ name=column
34
+ ))
35
+
36
+ fig.update_layout(
37
  title=title,
38
+ xaxis_title=column,
39
+ yaxis_title="Count",
40
+ width=self.width,
41
  height=self.height,
42
+ template="plotly_white"
43
  )
44
+
45
+ # Convert to HTML string
46
+ return fig.to_html(include_plotlyjs=True, full_html=False)
 
 
47
 
48
+ def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
49
+ title: str = "") -> str:
50
+ """Create scatter plot with Plotly"""
51
+ if self.data is None:
52
+ raise ValueError("No data loaded")
53
+
54
  fig = go.Figure()
55
 
56
+ if color_col:
57
+ for category in self.data[color_col].unique():
58
+ mask = self.data[color_col] == category
59
+ fig.add_trace(go.Scatter(
60
+ x=self.data[mask][x_col],
61
+ y=self.data[mask][y_col],
62
+ mode='markers',
63
+ name=str(category),
64
+ text=self.data[mask][color_col]
65
+ ))
66
+ else:
67
+ fig.add_trace(go.Scatter(
68
+ x=self.data[x_col],
69
+ y=self.data[y_col],
70
+ mode='markers'
71
+ ))
72
 
73
  fig.update_layout(
74
  title=title,
75
+ xaxis_title=x_col,
76
+ yaxis_title=y_col,
77
  width=self.width,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  height=self.height,
79
+ template="plotly_white",
80
+ hovermode='closest'
 
 
 
81
  )
82
+
83
+ return fig.to_html(include_plotlyjs=True, full_html=False)
84
 
85
+ def create_box(self, x_col: str, y_col: str, title: str = "") -> str:
86
+ """Create box plot with Plotly"""
87
+ if self.data is None:
88
+ raise ValueError("No data loaded")
89
+
90
+ fig = go.Figure()
91
+
92
+ # Create box plot for each category
93
+ for category in self.data[x_col].unique():
94
+ fig.add_trace(go.Box(
95
+ y=self.data[self.data[x_col] == category][y_col],
96
+ name=str(category),
97
+ boxpoints='all', # show all points
98
+ jitter=0.3,
99
+ pointpos=-1.8
100
+ ))
101
+
102
+ fig.update_layout(
103
  title=title,
104
+ yaxis_title=y_col,
105
+ xaxis_title=x_col,
106
  width=self.width,
107
+ height=self.height,
108
+ template="plotly_white",
 
 
109
  showlegend=False
110
  )
111
+
112
+ return fig.to_html(include_plotlyjs=True, full_html=False)
113
 
114
+ def create_line(self, x_col: str, y_col: str, color_col: Optional[str] = None,
115
+ title: str = "") -> str:
116
+ """Create line plot with Plotly"""
117
+ if self.data is None:
118
+ raise ValueError("No data loaded")
119
+
120
+ fig = go.Figure()
121
+
122
+ if color_col:
123
+ for category in self.data[color_col].unique():
124
+ mask = self.data[color_col] == category
125
+ fig.add_trace(go.Scatter(
126
+ x=self.data[mask][x_col],
127
+ y=self.data[mask][y_col],
128
+ mode='lines+markers',
129
+ name=str(category)
130
+ ))
131
+ else:
132
+ fig.add_trace(go.Scatter(
133
+ x=self.data[x_col],
134
+ y=self.data[y_col],
135
+ mode='lines+markers'
136
+ ))
137
+
138
  fig.update_layout(
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  title=title,
140
+ xaxis_title=x_col,
141
+ yaxis_title=y_col,
142
  width=self.width,
143
+ height=self.height,
144
+ template="plotly_white",
145
+ hovermode='x unified'
146
  )
147
 
148
+ return fig.to_html(include_plotlyjs=True, full_html=False)
 
 
 
 
 
 
149
 
150
  class ChatAnalyzer:
151
  """Handles chat-based analysis with visualization"""
 
189
  return self.history
190
 
191
  def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
192
+ """Process chat message and generate visualizations"""
193
+ if self.analyzer.data is None:
194
+ return [(message, "Please upload a data file first.")], ""
195
+
196
+ if not api_key:
197
+ return [(message, "Please provide an OpenAI API key.")], ""
198
+
199
+ try:
200
+ os.environ["OPENAI_API_KEY"] = api_key
201
+
202
+ # Get data context
203
+ context = self._get_data_context()
204
+
205
+ # Get AI response
206
+ completion_response = completion(
207
+ model="gpt-4o-mini",
208
+ messages=[
209
+ {"role": "system", "content": self._get_system_prompt()},
210
+ {"role": "user", "content": f"{context}\n\nUser question: {message}"}
211
+ ],
212
+ temperature=0.7
213
+ )
214
+
215
+ analysis = completion_response.choices[0].message.content
216
+
217
+ # Create visualizations
218
+ plots_html = ""
219
  try:
220
+ # Extract code blocks
221
+ import re
222
+ code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ for code in code_blocks:
225
+ # Create namespace for execution
226
+ namespace = {
227
+ 'analyzer': self.analyzer,
228
+ 'df': self.analyzer.data,
229
+ 'print': lambda x: x
230
+ }
231
 
232
+ # Execute the code
233
+ try:
234
+ result = eval(code, namespace)
235
+ if isinstance(result, str) and ('<div' in result or '<script' in result):
236
+ plots_html += f'<div class="plot-container">{result}</div>'
237
+ except:
 
 
 
 
 
 
238
  exec(code, namespace)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
 
 
 
 
 
 
 
 
240
  except Exception as e:
241
+ analysis += f"\n\nError creating visualization: {str(e)}"
242
+
243
+ # Update chat history
244
+ self.history.append((message, analysis))
245
+
246
+ return self.history, plots_html
247
+
248
+ except Exception as e:
249
+ self.history.append((message, f"Error: {str(e)}"))
250
+ return self.history, ""
251
 
252
  def _get_data_context(self) -> str:
253
  """Get current data context for AI"""
 
279
  """
280
 
281
  def _get_system_prompt(self) -> str:
282
+ """Get system prompt for AI"""
283
+ return """You are a data analysis assistant specialized in creating interactive visualizations.
284
 
285
  Available visualization functions:
286
+ 1. create_histogram(column, bins, title) - For distribution analysis
287
+ 2. create_scatter(x_col, y_col, color_col, title) - For relationship analysis
288
+ 3. create_box(x_col, y_col, title) - For categorical comparisons
289
+ 4. create_line(x_col, y_col, color_col, title) - For trend analysis
 
 
 
 
 
 
 
 
290
 
291
  Example usage:
292
  ```python
293
+ # Create histogram
294
+ result = analyzer.create_histogram(
295
+ column='Salary',
296
+ bins=20,
297
+ title='Salary Distribution'
 
298
  )
299
+ print(result)
300
 
301
+ # Create scatter plot
302
+ result = analyzer.create_scatter(
303
+ x_col='Date',
304
+ y_col='Salary',
305
+ color_col='Title',
306
+ title='Salary Trends by Title'
307
  )
308
+ print(result)
309
 
310
+ # Create box plot
311
+ result = analyzer.create_box(
312
+ x_col='Title',
313
+ y_col='Salary',
314
+ title='Salary Distribution by Title'
315
  )
316
+ print(result)
317
  ```
318
 
319
+ Always wrap code in Python code blocks and use print() to display the visualizations.
320
+ Provide analysis and insights about what the visualizations show."""
321
 
322
  def create_interface():
323
  """Create Gradio interface"""