jzou19950715 committed on
Commit
c8ef941
·
verified ·
1 Parent(s): ca77b8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -159
app.py CHANGED
@@ -11,265 +11,213 @@ from plotly.subplots import make_subplots
11
  from litellm import completion
12
 
13
  class DataAnalyzer:
14
- """Handles data analysis and visualization"""
15
-
16
  def __init__(self):
17
  self.data: Optional[pd.DataFrame] = None
18
 
19
- def create_visualization(self, plot_type: str, **kwargs) -> go.Figure:
20
- """Create different types of plotly visualizations"""
21
  if self.data is None:
22
  raise ValueError("No data loaded")
23
 
24
  if plot_type == "scatter":
25
  fig = px.scatter(
26
- self.data, x=kwargs.get('x'), y=kwargs.get('y'),
27
- color=kwargs.get('color'),
 
28
  title=kwargs.get('title', 'Scatter Plot'),
29
- labels=kwargs.get('labels', {}),
30
- trendline=kwargs.get('trendline'),
31
  )
32
-
33
  elif plot_type == "line":
34
  fig = px.line(
35
- self.data, x=kwargs.get('x'), y=kwargs.get('y'),
36
- color=kwargs.get('color'),
37
- title=kwargs.get('title', 'Line Plot'),
38
- labels=kwargs.get('labels', {})
39
  )
40
-
41
  elif plot_type == "bar":
42
  fig = px.bar(
43
- self.data, x=kwargs.get('x'), y=kwargs.get('y'),
44
- color=kwargs.get('color'),
45
- title=kwargs.get('title', 'Bar Plot'),
46
- labels=kwargs.get('labels', {})
47
  )
48
-
49
  elif plot_type == "histogram":
50
  fig = px.histogram(
51
- self.data, x=kwargs.get('x'),
52
- nbins=kwargs.get('bins', 30),
53
- title=kwargs.get('title', 'Histogram'),
54
- marginal=kwargs.get('marginal', 'box')
55
  )
56
-
57
  elif plot_type == "box":
58
  fig = px.box(
59
- self.data, x=kwargs.get('x'), y=kwargs.get('y'),
60
- color=kwargs.get('color'),
 
61
  title=kwargs.get('title', 'Box Plot')
62
  )
63
-
64
- elif plot_type == "violin":
65
- fig = px.violin(
66
- self.data, x=kwargs.get('x'), y=kwargs.get('y'),
67
- color=kwargs.get('color'),
68
- box=True,
69
- title=kwargs.get('title', 'Violin Plot')
70
- )
71
-
72
- elif plot_type == "correlation":
73
- corr = self.data.select_dtypes(include=[np.number]).corr()
74
- fig = px.imshow(
75
- corr,
76
- title=kwargs.get('title', 'Correlation Matrix'),
77
- color_continuous_scale="RdBu"
78
- )
79
-
80
  else:
81
  raise ValueError(f"Unknown plot type: {plot_type}")
82
-
83
- # Update layout for better interactivity
84
- fig.update_layout(
85
- hovermode='x unified',
86
- template='plotly_white',
87
- height=500,
88
- )
89
-
90
  return fig
91
 
92
  class ChatAnalyzer:
93
- """Handles chat-based analysis with visualization"""
94
-
95
  def __init__(self):
96
  self.analyzer = DataAnalyzer()
97
- self.chat_history: List[Tuple[str, str]] = []
98
 
99
- def process_file(self, file: gr.File) -> str:
100
- """Process uploaded file"""
101
  try:
102
  if file.name.endswith('.csv'):
103
  self.analyzer.data = pd.read_csv(file.name)
104
  elif file.name.endswith(('.xlsx', '.xls')):
105
  self.analyzer.data = pd.read_excel(file.name)
106
  else:
107
- return "Error: Please upload a CSV or Excel file."
 
 
 
 
108
 
109
- info = f"""
110
- Successfully loaded data with shape: {self.analyzer.data.shape}
111
- Columns: {', '.join(self.analyzer.data.columns)}
112
- """
113
- return info
114
 
115
  except Exception as e:
116
- return f"Error loading file: {str(e)}"
 
117
 
118
- def analyze(self, message: str, api_key: str) -> Tuple[str, List[go.Figure]]:
119
- """Analyze data based on user message"""
120
  if self.analyzer.data is None:
121
- return "Please upload a data file first.", []
122
 
123
  if not api_key:
124
- return "Please provide an OpenAI API key.", []
125
 
126
  try:
127
  os.environ["OPENAI_API_KEY"] = api_key
128
 
129
- # Prepare context for AI
130
  context = self._get_data_context()
131
 
132
  # Get AI response
133
- messages = [
134
- {"role": "system", "content": self._get_system_prompt()},
135
- {"role": "user", "content": f"{context}\n\nUser request: {message}"}
136
- ]
137
-
138
- response = completion(
139
  model="gpt-4o-mini",
140
- messages=messages,
 
 
 
141
  temperature=0.7
142
  )
143
 
144
- analysis = response.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- # Extract visualization commands and create plots
147
- figures = self._create_visualizations(analysis)
148
 
149
- return analysis, figures
150
 
151
  except Exception as e:
152
- return f"Error during analysis: {str(e)}", []
 
153
 
154
  def _get_data_context(self) -> str:
155
- """Get current data context"""
156
  df = self.analyzer.data
157
  numeric_cols = df.select_dtypes(include=[np.number]).columns
158
  categorical_cols = df.select_dtypes(include=['object', 'category']).columns
159
 
160
  return f"""
161
- Available Data Information:
162
  - Shape: {df.shape}
163
  - Numeric columns: {', '.join(numeric_cols)}
164
  - Categorical columns: {', '.join(categorical_cols)}
165
-
166
- Basic Statistics:
167
- {df.describe().to_string()}
168
  """
169
 
170
  def _get_system_prompt(self) -> str:
171
- """Get system prompt"""
172
- return """You are a data analysis assistant specialized in creating interactive visualizations using Plotly.
173
-
174
- Available plot types:
175
- 1. scatter - for relationships between variables
176
- 2. line - for trends over time
177
- 3. bar - for comparisons between categories
178
- 4. histogram - for distributions
179
- 5. box - for statistical summaries
180
- 6. violin - for distribution comparisons
181
- 7. correlation - for correlation matrix
182
 
183
- When creating visualizations:
184
- 1. Specify the plot type and required parameters
185
- 2. Provide insights about the visualization
186
- 3. Suggest follow-up analyses
187
- 4. Use markdown for formatting
188
-
189
- Example command format:
190
  ```python
191
  # Create scatter plot
192
- plot = viz.create_visualization("scatter", x="column1", y="column2", title="Analysis")
193
- print(plot)
 
 
 
 
194
  ```
195
- """
196
-
197
- def _create_visualizations(self, analysis: str) -> List[go.Figure]:
198
- """Extract and create visualizations from analysis"""
199
- figures = []
200
- viz = self.analyzer
201
-
202
- try:
203
- # Execute visualization commands in the analysis
204
- exec_globals = {
205
- 'viz': viz,
206
- 'print': lambda x: figures.append(x) if isinstance(x, go.Figure) else None
207
- }
208
-
209
- # Extract and execute code blocks
210
- import re
211
- code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
212
-
213
- for code in code_blocks:
214
- exec(code, exec_globals)
215
-
216
- except Exception as e:
217
- print(f"Error creating visualizations: {str(e)}")
218
-
219
- return figures
220
 
221
  def create_interface():
222
- """Create Gradio interface"""
223
-
224
  analyzer = ChatAnalyzer()
225
 
226
- def chat(message: str, api_key: str) -> Tuple[List[Tuple[str, str]], List[gr.Plot]]:
227
- """Handle chat interaction"""
228
- response, figures = analyzer.analyze(message, api_key)
229
-
230
- # Update chat history
231
- analyzer.chat_history.append((message, response))
232
-
233
- # Convert figures to Gradio plots
234
- plots = [gr.Plot(fig) for fig in figures]
235
-
236
- return analyzer.chat_history, plots
237
-
238
  with gr.Blocks() as demo:
239
  gr.Markdown("""
240
  # Interactive Data Analysis Chat
241
- Upload your data and chat with AI to create interactive visualizations!
242
  """)
243
 
244
  with gr.Row():
245
  with gr.Column(scale=1):
246
  file = gr.File(label="Upload Data (CSV or Excel)")
247
- api_key = gr.Textbox(label="OpenAI API Key", type="password")
248
-
 
 
 
 
249
  with gr.Column(scale=2):
250
  chatbot = gr.Chatbot(height=400)
251
-
252
- with gr.Row():
253
- message = gr.Textbox(label="Ask about your data", lines=2)
254
- send = gr.Button("Send")
 
255
 
256
  # Plot output area
257
- plot_output = gr.Plot(visible=False)
 
 
 
 
 
 
 
258
 
259
- # Set up event handlers
260
- file.change(analyzer.process_file, inputs=[file], outputs=[chatbot])
261
  send.click(
262
- chat,
263
  inputs=[message, api_key],
264
  outputs=[chatbot, plot_output]
265
  )
266
 
 
267
  gr.Examples(
268
  examples=[
269
- ["Show me a scatter plot of the main numerical variables"],
270
- ["Create a correlation matrix of all numerical columns"],
271
- ["Analyze the distribution of each variable"],
272
- ["Show trends over time if there's temporal data"],
273
  ],
274
  inputs=message
275
  )
 
11
  from litellm import completion
12
 
13
  class DataAnalyzer:
 
 
14
  def __init__(self):
15
  self.data: Optional[pd.DataFrame] = None
16
 
17
+ def create_plot(self, plot_type: str, **kwargs) -> go.Figure:
 
18
  if self.data is None:
19
  raise ValueError("No data loaded")
20
 
21
  if plot_type == "scatter":
22
  fig = px.scatter(
23
+ self.data,
24
+ x=kwargs.get('x'),
25
+ y=kwargs.get('y'),
26
  title=kwargs.get('title', 'Scatter Plot'),
27
+ color=kwargs.get('color')
 
28
  )
 
29
  elif plot_type == "line":
30
  fig = px.line(
31
+ self.data,
32
+ x=kwargs.get('x'),
33
+ y=kwargs.get('y'),
34
+ title=kwargs.get('title', 'Line Plot')
35
  )
 
36
  elif plot_type == "bar":
37
  fig = px.bar(
38
+ self.data,
39
+ x=kwargs.get('x'),
40
+ y=kwargs.get('y'),
41
+ title=kwargs.get('title', 'Bar Plot')
42
  )
 
43
  elif plot_type == "histogram":
44
  fig = px.histogram(
45
+ self.data,
46
+ x=kwargs.get('x'),
47
+ title=kwargs.get('title', 'Distribution')
 
48
  )
 
49
  elif plot_type == "box":
50
  fig = px.box(
51
+ self.data,
52
+ x=kwargs.get('x'),
53
+ y=kwargs.get('y'),
54
  title=kwargs.get('title', 'Box Plot')
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  else:
57
  raise ValueError(f"Unknown plot type: {plot_type}")
58
+
 
 
 
 
 
 
 
59
  return fig
60
 
61
  class ChatAnalyzer:
 
 
62
  def __init__(self):
63
  self.analyzer = DataAnalyzer()
64
+ self.history: List[Tuple[str, str]] = []
65
 
66
+ def process_file(self, file: gr.File) -> List[Tuple[str, str]]:
 
67
  try:
68
  if file.name.endswith('.csv'):
69
  self.analyzer.data = pd.read_csv(file.name)
70
  elif file.name.endswith(('.xlsx', '.xls')):
71
  self.analyzer.data = pd.read_excel(file.name)
72
  else:
73
+ return [("System", "Error: Please upload a CSV or Excel file.")]
74
+
75
+ info = f"""Data loaded successfully!
76
+ Shape: {self.analyzer.data.shape}
77
+ Columns: {', '.join(self.analyzer.data.columns)}"""
78
 
79
+ self.history = [("System", info)]
80
+ return self.history
 
 
 
81
 
82
  except Exception as e:
83
+ self.history = [("System", f"Error loading file: {str(e)}")]
84
+ return self.history
85
 
86
+ def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], List[gr.Plot]]:
 
87
  if self.analyzer.data is None:
88
+ return [(message, "Please upload a data file first.")], []
89
 
90
  if not api_key:
91
+ return [(message, "Please provide an OpenAI API key.")], []
92
 
93
  try:
94
  os.environ["OPENAI_API_KEY"] = api_key
95
 
96
+ # Get data context
97
  context = self._get_data_context()
98
 
99
  # Get AI response
100
+ completion_response = completion(
 
 
 
 
 
101
  model="gpt-4o-mini",
102
+ messages=[
103
+ {"role": "system", "content": self._get_system_prompt()},
104
+ {"role": "user", "content": f"{context}\n\nUser question: {message}"}
105
+ ],
106
  temperature=0.7
107
  )
108
 
109
+ analysis = completion_response.choices[0].message.content
110
+
111
+ # Create visualizations
112
+ figures = []
113
+ try:
114
+ # Execute any visualization commands in the analysis
115
+ exec_globals = {
116
+ 'analyzer': self.analyzer,
117
+ 'df': self.analyzer.data,
118
+ 'px': px,
119
+ 'go': go,
120
+ 'print': lambda x: figures.append(x) if isinstance(x, go.Figure) else None
121
+ }
122
+
123
+ # Extract code blocks
124
+ import re
125
+ code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
126
+
127
+ for code in code_blocks:
128
+ exec(code, exec_globals)
129
+
130
+ except Exception as e:
131
+ analysis += f"\n\nError creating visualization: {str(e)}"
132
+
133
+ # Update chat history
134
+ self.history.append((message, analysis))
135
 
136
+ # Convert figures to Gradio plots
137
+ plots = [gr.Plot(fig) for fig in figures]
138
 
139
+ return self.history, plots
140
 
141
  except Exception as e:
142
+ self.history.append((message, f"Error: {str(e)}"))
143
+ return self.history, []
144
 
145
  def _get_data_context(self) -> str:
 
146
  df = self.analyzer.data
147
  numeric_cols = df.select_dtypes(include=[np.number]).columns
148
  categorical_cols = df.select_dtypes(include=['object', 'category']).columns
149
 
150
  return f"""
151
+ Data Information:
152
  - Shape: {df.shape}
153
  - Numeric columns: {', '.join(numeric_cols)}
154
  - Categorical columns: {', '.join(categorical_cols)}
 
 
 
155
  """
156
 
157
  def _get_system_prompt(self) -> str:
158
+ return """You are a data analysis assistant. To create visualizations, use Python code blocks with Plotly.
 
 
 
 
 
 
 
 
 
 
159
 
160
+ Example commands:
 
 
 
 
 
 
161
  ```python
162
  # Create scatter plot
163
+ fig = px.scatter(df, x='column1', y='column2', title='Analysis')
164
+ print(fig)
165
+
166
+ # Create histogram
167
+ fig = px.histogram(df, x='column', title='Distribution')
168
+ print(fig)
169
  ```
170
+
171
+ Provide analysis and insights along with visualizations."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  def create_interface():
 
 
174
  analyzer = ChatAnalyzer()
175
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  with gr.Blocks() as demo:
177
  gr.Markdown("""
178
  # Interactive Data Analysis Chat
179
+ Upload your data and chat with AI to analyze it!
180
  """)
181
 
182
  with gr.Row():
183
  with gr.Column(scale=1):
184
  file = gr.File(label="Upload Data (CSV or Excel)")
185
+ api_key = gr.Textbox(
186
+ label="OpenAI API Key",
187
+ type="password",
188
+ placeholder="Enter your API key"
189
+ )
190
+
191
  with gr.Column(scale=2):
192
  chatbot = gr.Chatbot(height=400)
193
+ message = gr.Textbox(
194
+ label="Ask about your data",
195
+ placeholder="e.g., Show me a scatter plot of X vs Y"
196
+ )
197
+ send = gr.Button("Send")
198
 
199
  # Plot output area
200
+ plot_output = gr.Plot(label="Visualization")
201
+
202
+ # Event handlers
203
+ file.change(
204
+ analyzer.process_file,
205
+ inputs=[file],
206
+ outputs=[chatbot]
207
+ )
208
 
 
 
209
  send.click(
210
+ analyzer.chat,
211
  inputs=[message, api_key],
212
  outputs=[chatbot, plot_output]
213
  )
214
 
215
+ # Example queries
216
  gr.Examples(
217
  examples=[
218
+ ["Show me a scatter plot of the numerical variables"],
219
+ ["Create a histogram of the distribution"],
220
+ ["Analyze trends in the data"],
 
221
  ],
222
  inputs=message
223
  )