jzou19950715 commited on
Commit
38571e1
·
verified ·
1 Parent(s): 5e23ce2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +296 -335
app.py CHANGED
@@ -1,325 +1,297 @@
1
- import os
2
- from typing import List, Optional, Tuple, Dict, Any
3
  import base64
4
  import io
 
 
 
5
 
6
  import gradio as gr
7
- import pandas as pd
8
  import numpy as np
 
 
 
9
  import plotly.graph_objects as go
 
10
  from litellm import completion
11
 
12
- class DataAnalyzer:
13
- """Handles data analysis and visualization"""
 
14
 
15
  def __init__(self):
16
- self.data: Optional[pd.DataFrame] = None
17
- self.width = 800
18
- self.height = 500
19
-
20
- def create_histogram(self, column: str, bins: int = 30, title: str = "") -> str:
21
- """Create histogram with Plotly"""
22
- if self.data is None:
23
- raise ValueError("No data loaded")
24
-
25
- fig = go.Figure()
26
-
27
- fig.add_trace(go.Histogram(
28
- x=self.data[column],
29
- nbinsx=bins,
30
- name=column
31
- ))
32
-
33
- fig.update_layout(
34
- title=title,
35
- xaxis_title=column,
36
- yaxis_title="Count",
37
- width=self.width,
38
- height=self.height,
39
- template="plotly_white"
40
- )
41
-
42
- return fig.to_html(include_plotlyjs=True, full_html=False)
43
-
44
- def create_scatter(self, x_col: str, y_col: str, color_col: Optional[str] = None,
45
- title: str = "") -> str:
46
- """Create scatter plot with Plotly"""
47
- if self.data is None:
48
- raise ValueError("No data loaded")
49
-
50
- fig = go.Figure()
51
-
52
- if color_col:
53
- for category in self.data[color_col].unique():
54
- mask = self.data[color_col] == category
55
- fig.add_trace(go.Scatter(
56
- x=self.data[mask][x_col],
57
- y=self.data[mask][y_col],
58
- mode='markers',
59
- name=str(category),
60
- text=self.data[mask][color_col]
61
- ))
62
- else:
63
- fig.add_trace(go.Scatter(
64
- x=self.data[x_col],
65
- y=self.data[y_col],
66
- mode='markers'
67
- ))
68
-
69
- fig.update_layout(
70
- title=title,
71
- xaxis_title=x_col,
72
- yaxis_title=y_col,
73
- width=self.width,
74
- height=self.height,
75
- template="plotly_white",
76
- hovermode='closest'
77
- )
78
-
79
- return fig.to_html(include_plotlyjs=True, full_html=False)
80
-
81
- def create_box(self, x_col: str, y_col: str, title: str = "") -> str:
82
- """Create box plot with Plotly"""
83
- if self.data is None:
84
- raise ValueError("No data loaded")
85
-
86
- fig = go.Figure()
87
-
88
- for category in self.data[x_col].unique():
89
- fig.add_trace(go.Box(
90
- y=self.data[self.data[x_col] == category][y_col],
91
- name=str(category),
92
- boxpoints='all',
93
- jitter=0.3,
94
- pointpos=-1.8
95
- ))
96
-
97
- fig.update_layout(
98
- title=title,
99
- yaxis_title=y_col,
100
- xaxis_title=x_col,
101
- width=self.width,
102
- height=self.height,
103
- template="plotly_white",
104
- showlegend=False
105
- )
106
 
107
- return fig.to_html(include_plotlyjs=True, full_html=False)
108
-
109
- def create_line(self, x_col: str, y_col: str, color_col: Optional[str] = None,
110
- title: str = "") -> str:
111
- """Create line plot with Plotly"""
112
- if self.data is None:
113
- raise ValueError("No data loaded")
114
 
115
- fig = go.Figure()
116
-
117
- if color_col:
118
- for category in self.data[color_col].unique():
119
- mask = self.data[color_col] == category
120
- fig.add_trace(go.Scatter(
121
- x=self.data[mask][x_col],
122
- y=self.data[mask][y_col],
123
- mode='lines+markers',
124
- name=str(category)
125
- ))
126
- else:
127
- fig.add_trace(go.Scatter(
128
- x=self.data[x_col],
129
- y=self.data[y_col],
130
- mode='lines+markers'
131
- ))
132
-
133
- fig.update_layout(
134
- title=title,
135
- xaxis_title=x_col,
136
- yaxis_title=y_col,
137
- width=self.width,
138
- height=self.height,
139
- template="plotly_white",
140
- hovermode='x unified'
141
- )
142
 
143
- return fig.to_html(include_plotlyjs=True, full_html=False)
144
-
145
- class ChatAnalyzer:
146
- """Handles chat-based analysis with visualization"""
147
-
148
- def __init__(self):
149
- self.analyzer = DataAnalyzer()
150
- self.history: List[Tuple[str, str]] = []
151
-
152
- def process_file(self, file: gr.File) -> List[Tuple[str, str]]:
153
- """Process uploaded file and initialize analyzer"""
154
  try:
155
- if file.name.endswith('.csv'):
156
- self.analyzer.data = pd.read_csv(file.name)
157
- elif file.name.endswith(('.xlsx', '.xls')):
158
- self.analyzer.data = pd.read_excel(file.name)
159
- else:
160
- return [("System", "Error: Please upload a CSV or Excel file.")]
161
-
162
- # Convert date columns to datetime
163
- date_cols = self.analyzer.data.select_dtypes(include=['object']).columns
164
- for col in date_cols:
165
- try:
166
- self.analyzer.data[col] = pd.to_datetime(self.analyzer.data[col])
167
- except:
168
- continue
169
 
170
- info = f"""Data loaded successfully!
171
- Shape: {self.analyzer.data.shape}
172
- Columns: {', '.join(self.analyzer.data.columns)}
 
 
 
 
 
 
173
 
174
- Numeric columns: {', '.join(self.analyzer.data.select_dtypes(include=[np.number]).columns)}
175
- Date columns: {', '.join(self.analyzer.data.select_dtypes(include=['datetime64']).columns)}
176
- Categorical columns: {', '.join(self.analyzer.data.select_dtypes(include=['object']).columns)}
177
- """
 
 
 
 
 
 
 
 
178
 
179
- self.history = [("System", info)]
180
- return self.history
181
 
182
  except Exception as e:
183
- self.history = [("System", f"Error loading file: {str(e)}")]
184
- return self.history
185
-
186
- def chat(self, message: str, api_key: str) -> Tuple[List[Tuple[str, str]], str]:
187
- """Process chat message and generate visualizations"""
188
- if self.analyzer.data is None:
189
- return [(message, "Please upload a data file first.")], ""
190
 
191
- if not api_key:
192
- return [(message, "Please provide an OpenAI API key.")], ""
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  try:
195
- os.environ["OPENAI_API_KEY"] = api_key
196
-
197
- # Get data context
198
- context = self._get_data_context()
199
-
200
- # Get AI response
201
- completion_response = completion(
202
- model="gpt-4o-mini",
203
- messages=[
204
- {"role": "system", "content": self._get_system_prompt()},
205
- {"role": "user", "content": f"{context}\n\nUser question: {message}"}
206
- ],
207
- temperature=0.7
208
  )
 
209
 
210
- analysis = completion_response.choices[0].message.content
 
211
 
212
- # Create visualizations
213
- plots_html = ""
214
- try:
215
- # Extract code blocks
216
- import re
217
- code_blocks = re.findall(r'```python\n(.*?)```', analysis, re.DOTALL)
218
-
219
- for code in code_blocks:
220
- # Create namespace for execution
221
- namespace = {
222
- 'analyzer': self.analyzer,
223
- 'df': self.analyzer.data,
224
- 'print': lambda x: x
225
- }
226
 
227
- # Execute the code
228
- try:
229
- result = eval(code, namespace)
230
- if isinstance(result, str) and ('<div' in result or '<script' in result):
231
- plots_html += f'<div class="plot-container">{result}</div>'
232
- except:
233
- exec(code, namespace)
234
-
235
- except Exception as e:
236
- analysis += f"\n\nError creating visualization: {str(e)}"
237
-
238
- # Update chat history
239
- self.history.append((message, analysis))
240
-
241
- return self.history, plots_html
242
 
243
  except Exception as e:
244
- self.history.append((message, f"Error: {str(e)}"))
245
- return self.history, ""
246
 
247
- def _get_data_context(self) -> str:
248
- """Get current data context for AI"""
249
- df = self.analyzer.data
250
- numeric_cols = df.select_dtypes(include=[np.number]).columns
251
- date_cols = df.select_dtypes(include=['datetime64']).columns
252
- categorical_cols = df.select_dtypes(include=['object']).columns
253
-
254
- # Get basic statistics
255
- stats = df[numeric_cols].describe().to_string() if len(numeric_cols) > 0 else "No numeric columns"
256
-
257
- return f"""
258
- Data Information:
259
- - Shape: {df.shape}
260
- - Numeric columns: {', '.join(numeric_cols)}
261
- - Date columns: {', '.join(date_cols)}
262
- - Categorical columns: {', '.join(categorical_cols)}
263
 
264
- Basic Statistics:
265
- {stats}
266
 
267
- Available visualization functions:
268
- - analyzer.create_histogram(column, bins, title)
269
- - analyzer.create_scatter(x_col, y_col, color_col, title)
270
- - analyzer.create_box(x_col, y_col, title)
271
- - analyzer.create_line(x_col, y_col, color_col, title)
272
- """
273
-
274
- def _get_system_prompt(self) -> str:
275
- """Get system prompt for AI"""
276
- return """You are a data analysis assistant specialized in creating interactive visualizations.
277
 
278
- Available visualization functions:
279
- 1. create_histogram(column, bins, title) - For distribution analysis
280
- 2. create_scatter(x_col, y_col, color_col, title) - For relationship analysis
281
- 3. create_box(x_col, y_col, title) - For categorical comparisons
282
- 4. create_line(x_col, y_col, color_col, title) - For trend analysis
 
283
 
284
- Example usage:
285
  ```python
286
- # Create histogram
287
- result = analyzer.create_histogram(
288
- column='Salary',
289
- bins=20,
290
- title='Salary Distribution'
291
- )
292
- print(result)
293
 
294
- # Create scatter plot with time series
295
- result = analyzer.create_scatter(
296
- x_col='Date',
297
- y_col='Salary',
298
- color_col='Title',
299
- title='Salary Trends by Title'
300
- )
301
- print(result)
302
 
303
- # Create box plot
304
- result = analyzer.create_box(
305
- x_col='Title',
306
- y_col='Salary',
307
- title='Salary Distribution by Title'
308
- )
309
- print(result)
310
  ```
 
 
 
 
 
 
 
 
311
 
312
- Always wrap code in Python code blocks and use print() to display the visualizations.
313
- Provide analysis and insights about what the visualizations show."""
314
 
315
- def create_interface():
316
- """Create Gradio interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
- analyzer = ChatAnalyzer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- # Custom CSS
321
  css = """
322
- .container { max-width: 1200px; margin: auto; }
323
  .plot-container {
324
  margin: 20px 0;
325
  padding: 15px;
@@ -328,81 +300,70 @@ def create_interface():
328
  background: white;
329
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
330
  }
331
- .chat-message {
332
- margin-bottom: 15px;
333
- padding: 10px;
334
- border-radius: 8px;
335
- background: #f8f9fa;
336
- }
337
  """
338
 
339
- with gr.Blocks(css=css) as demo:
340
  gr.Markdown("""
341
- # Interactive Data Analysis Chat
 
 
342
 
343
- Upload your data and chat with AI to analyze it! Features:
344
- - Interactive visualizations
345
- - Natural language analysis
346
- - Statistical insights
347
- - Trend detection
 
 
348
  """)
349
 
350
  with gr.Row():
351
- with gr.Column(scale=1):
352
  file = gr.File(
353
- label="Upload Data (CSV or Excel)",
354
  file_types=[".csv", ".xlsx", ".xls"]
355
  )
356
- api_key = gr.Textbox(
357
- label="OpenAI API Key",
358
- type="password",
359
- placeholder="Enter your API key"
360
  )
361
-
362
- with gr.Column(scale=2):
363
- chatbot = gr.Chatbot(
364
- height=400,
365
- elem_classes="chat-message"
366
  )
367
- message = gr.Textbox(
368
- label="Ask about your data",
369
- placeholder="e.g., Show me trends in the data",
370
- lines=2
 
 
371
  )
372
- send = gr.Button("Send")
373
-
374
- # Plot output area
375
- plot_output = gr.HTML(
376
- label="Visualizations",
377
- elem_classes="plot-container"
378
- )
379
-
380
- # Event handlers
381
- file.change(
382
- analyzer.process_file,
383
- inputs=[file],
384
- outputs=[chatbot]
385
- )
386
 
387
- send.click(
388
- analyzer.chat,
389
- inputs=[message, api_key],
390
- outputs=[chatbot, plot_output]
391
  )
392
 
393
- # Example queries
394
  gr.Examples(
395
  examples=[
396
- ["Show me a histogram of salary distribution"],
397
- ["Create a scatter plot of salary trends over time"],
398
- ["Show me box plots of salaries by title"],
399
- ["Analyze the trends and patterns in the data"],
 
400
  ],
401
- inputs=message
402
  )
403
 
404
- return demo
405
 
406
  if __name__ == "__main__":
407
- demo = create_interface()
408
- demo.launch()
 
 
 
1
  import base64
2
  import io
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
 
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
  import numpy as np
10
+ import pandas as pd
11
+ import seaborn as sns
12
+ import plotly.express as px
13
  import plotly.graph_objects as go
14
+ from plotly.subplots import make_subplots
15
  from litellm import completion
16
 
17
+
18
+ class CodeEnvironment:
19
+ """Enhanced environment for executing code with both static and interactive visualization capabilities"""
20
 
21
  def __init__(self):
22
+ self.globals = {
23
+ 'pd': pd,
24
+ 'np': np,
25
+ 'plt': plt,
26
+ 'sns': sns,
27
+ 'px': px,
28
+ 'go': go,
29
+ 'make_subplots': make_subplots
30
+ }
31
+ self.locals = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
34
+ """Execute code and capture both static and interactive outputs"""
35
+ if df is not None:
36
+ self.globals['df'] = df
 
 
 
37
 
38
+ # Capture output
39
+ output_buffer = io.StringIO()
40
+ result = {
41
+ 'output': '',
42
+ 'figures': [], # For base64 static images
43
+ 'interactive': [], # For Plotly HTML
44
+ 'error': None
45
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
47
  try:
48
+ # Execute code
49
+ exec(code, self.globals, self.locals)
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # Capture matplotlib figures (static)
52
+ for i in plt.get_fignums():
53
+ fig = plt.figure(i)
54
+ buf = io.BytesIO()
55
+ fig.savefig(buf, format='png')
56
+ buf.seek(0)
57
+ img_str = base64.b64encode(buf.read()).decode()
58
+ result['figures'].append(f"data:image/png;base64,{img_str}")
59
+ plt.close(fig)
60
 
61
+ # Capture Plotly figures (interactive)
62
+ for var in list(self.locals.values()):
63
+ if isinstance(var, (go.Figure, px.Figure)):
64
+ html = var.to_html(
65
+ include_plotlyjs=True,
66
+ full_html=False,
67
+ config={
68
+ 'displayModeBar': True,
69
+ 'responsive': True
70
+ }
71
+ )
72
+ result['interactive'].append(html)
73
 
74
+ # Get printed output
75
+ result['output'] = output_buffer.getvalue()
76
 
77
  except Exception as e:
78
+ result['error'] = str(e)
 
 
 
 
 
 
79
 
80
+ finally:
81
+ output_buffer.close()
82
 
83
+ return result
84
+
85
+
86
+ @dataclass
87
+ class Tool:
88
+ """Tool for data analysis"""
89
+ name: str
90
+ description: str
91
+ func: Callable
92
+
93
+
94
+ class AnalysisAgent:
95
+ """Enhanced agent with interactive visualization capabilities"""
96
+
97
+ def __init__(
98
+ self,
99
+ model_id: str = "gpt-4o-mini",
100
+ temperature: float = 0.7,
101
+ ):
102
+ self.model_id = model_id
103
+ self.temperature = temperature
104
+ self.tools: List[Tool] = []
105
+ self.code_env = CodeEnvironment()
106
+
107
+ def add_tool(self, name: str, description: str, func: Callable) -> None:
108
+ """Add a tool to the agent"""
109
+ self.tools.append(Tool(name=name, description=description, func=func))
110
+
111
+ def run(self, prompt: str, df: pd.DataFrame = None) -> str:
112
+ """Run analysis with enhanced visualization support"""
113
+ messages = [
114
+ {"role": "system", "content": self._get_system_prompt()},
115
+ {"role": "user", "content": prompt}
116
+ ]
117
+
118
  try:
119
+ # Get response from model
120
+ response = completion(
121
+ model=self.model_id,
122
+ messages=messages,
123
+ temperature=self.temperature,
 
 
 
 
 
 
 
 
124
  )
125
+ analysis = response.choices[0].message.content
126
 
127
+ # Extract code blocks
128
+ code_blocks = self._extract_code(analysis)
129
 
130
+ # Execute code and capture results
131
+ results = []
132
+ for code in code_blocks:
133
+ result = self.code_env.execute(code, df)
134
+ if result['error']:
135
+ results.append(f"Error executing code: {result['error']}")
136
+ else:
137
+ # Add output text
138
+ if result['output']:
139
+ results.append(result['output'])
 
 
 
 
140
 
141
+ # Add interactive plots
142
+ for plot in result['interactive']:
143
+ results.append(f"<div class='plot-container'>{plot}</div>")
144
+
145
+ # Add static figures as fallback
146
+ for fig in result['figures']:
147
+ results.append(f"![Figure]({fig})")
148
+
149
+ # Combine analysis and results
150
+ return analysis + "\n\n" + "\n".join(results)
 
 
 
 
 
151
 
152
  except Exception as e:
153
+ return f"Error: {str(e)}"
 
154
 
155
+ def _get_system_prompt(self) -> str:
156
+ """Get enhanced system prompt with interactive visualization capabilities"""
157
+ tools_desc = "\n".join([
158
+ f"- {tool.name}: {tool.description}"
159
+ for tool in self.tools
160
+ ])
 
 
 
 
 
 
 
 
 
 
161
 
162
+ return f"""You are a data analysis assistant with interactive visualization capabilities.
 
163
 
164
+ Available tools:
165
+ {tools_desc}
166
+
167
+ Capabilities:
168
+ - Data analysis (pandas, numpy)
169
+ - Interactive visualization (plotly)
170
+ - Static visualization (matplotlib, seaborn)
171
+ - Statistical analysis (scipy)
172
+ - Machine learning (sklearn)
 
173
 
174
+ When writing code:
175
+ - Prefer Plotly for interactive visualizations
176
+ - Use matplotlib/seaborn for static plots when appropriate
177
+ - Create clear visualizations with proper labels
178
+ - Include explanatory text
179
+ - Handle errors gracefully
180
 
181
+ Example Plotly usage:
182
  ```python
183
+ # Create interactive scatter plot
184
+ fig = px.scatter(df, x='column1', y='column2',
185
+ color='category',
186
+ title='Interactive Analysis')
187
+ fig.update_layout(height=600)
188
+ fig.show()
 
189
 
190
+ # Create interactive time series
191
+ fig = px.line(df, x='date', y='value',
192
+ color='category',
193
+ title='Time Series Analysis')
194
+ fig.update_layout(height=600)
195
+ fig.show()
196
+ ```
 
197
 
198
+ Example Matplotlib usage:
199
+ ```python
200
+ # Create static plot
201
+ plt.figure(figsize=(10, 6))
202
+ sns.boxplot(data=df, x='category', y='value')
203
+ plt.title('Distribution Analysis')
204
+ plt.show()
205
  ```
206
+ """
207
+
208
+ @staticmethod
209
+ def _extract_code(text: str) -> List[str]:
210
+ """Extract Python code blocks from markdown"""
211
+ import re
212
+ pattern = r'```python\n(.*?)```'
213
+ return re.findall(pattern, text, re.DOTALL)
214
 
 
 
215
 
216
+ def process_file(file: gr.File) -> Optional[pd.DataFrame]:
217
+ """Process uploaded file into DataFrame"""
218
+ if not file:
219
+ return None
220
+
221
+ try:
222
+ if file.name.endswith('.csv'):
223
+ return pd.read_csv(file.name)
224
+ elif file.name.endswith(('.xlsx', '.xls')):
225
+ return pd.read_excel(file.name)
226
+ except Exception as e:
227
+ print(f"Error reading file: {str(e)}")
228
+ return None
229
+
230
+
231
+ def analyze_data(
232
+ file: gr.File,
233
+ query: str,
234
+ api_key: str,
235
+ temperature: float = 0.7,
236
+ ) -> str:
237
+ """Process user request and generate enhanced analysis"""
238
 
239
+ if not api_key:
240
+ return "Error: Please provide an API key."
241
+
242
+ if not file:
243
+ return "Error: Please upload a file."
244
+
245
+ try:
246
+ # Set up environment
247
+ os.environ["OPENAI_API_KEY"] = api_key
248
+
249
+ # Create agent
250
+ agent = AnalysisAgent(
251
+ model_id="gpt-4o-mini",
252
+ temperature=temperature
253
+ )
254
+
255
+ # Process file
256
+ df = process_file(file)
257
+ if df is None:
258
+ return "Error: Could not process file."
259
+
260
+ # Build context
261
+ file_info = f"""
262
+ File: {file.name}
263
+ Shape: {df.shape}
264
+ Columns: {', '.join(df.columns)}
265
+
266
+ Column Types:
267
+ {chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])}
268
+ """
269
+
270
+ # Run analysis
271
+ prompt = f"""
272
+ {file_info}
273
+
274
+ The data is loaded in a pandas DataFrame called 'df'.
275
+
276
+ User request: {query}
277
+
278
+ Please analyze the data and provide:
279
+ 1. Key insights and findings
280
+ 2. Interactive visualizations where appropriate
281
+ 3. Statistical summaries when relevant
282
+ 4. Clear explanations of patterns and trends
283
+ """
284
+
285
+ return agent.run(prompt, df=df)
286
+
287
+ except Exception as e:
288
+ return f"Error occurred: {str(e)}"
289
+
290
+
291
+ def create_interface():
292
+ """Create enhanced Gradio interface"""
293
 
 
294
  css = """
 
295
  .plot-container {
296
  margin: 20px 0;
297
  padding: 15px;
 
300
  background: white;
301
  box-shadow: 0 2px 4px rgba(0,0,0,0.1);
302
  }
 
 
 
 
 
 
303
  """
304
 
305
+ with gr.Blocks(title="AI Data Analysis Assistant", css=css) as interface:
306
  gr.Markdown("""
307
+ # AI Data Analysis Assistant
308
+
309
+ Upload your data file and get AI-powered analysis with interactive visualizations.
310
 
311
+ **Features:**
312
+ - Interactive data visualization
313
+ - Statistical analysis
314
+ - Machine learning capabilities
315
+ - Natural language interaction
316
+
317
+ **Note**: Requires your own OpenAI API key.
318
  """)
319
 
320
  with gr.Row():
321
+ with gr.Column():
322
  file = gr.File(
323
+ label="Upload Data File",
324
  file_types=[".csv", ".xlsx", ".xls"]
325
  )
326
+ query = gr.Textbox(
327
+ label="What would you like to analyze?",
328
+ placeholder="e.g., Create interactive visualizations showing relationships between variables",
329
+ lines=3
330
  )
331
+ api_key = gr.Textbox(
332
+ label="API Key (Required)",
333
+ placeholder="Your API key",
334
+ type="password"
 
335
  )
336
+ temperature = gr.Slider(
337
+ label="Temperature",
338
+ minimum=0.0,
339
+ maximum=1.0,
340
+ value=0.7,
341
+ step=0.1
342
  )
343
+ analyze_btn = gr.Button("Analyze")
344
+
345
+ with gr.Column():
346
+ output = gr.HTML(label="Output") # Changed to HTML for interactive plots
 
 
 
 
 
 
 
 
 
 
347
 
348
+ analyze_btn.click(
349
+ analyze_data,
350
+ inputs=[file, query, api_key, temperature],
351
+ outputs=output
352
  )
353
 
 
354
  gr.Examples(
355
  examples=[
356
+ [None, "Create interactive visualizations showing relationships between variables"],
357
+ [None, "Show the distribution of values with interactive plots"],
358
+ [None, "Create an interactive correlation analysis"],
359
+ [None, "Show trends over time with interactive charts"],
360
+ [None, "Generate a comprehensive analysis with multiple visualizations"],
361
  ],
362
+ inputs=[file, query]
363
  )
364
 
365
+ return interface
366
 
367
  if __name__ == "__main__":
368
+ interface = create_interface()
369
+ interface.launch()