jzou19950715 committed on
Commit
3950cb3
·
verified ·
1 Parent(s): 0a37365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -156
app.py CHANGED
@@ -1,167 +1,276 @@
 
 
1
  import os
 
 
2
  import json
3
- from typing import Optional, Dict
4
 
5
  import gradio as gr
 
6
  import pandas as pd
 
 
 
 
 
7
  from litellm import completion
8
 
9
- from components.analysis import DataAnalyzer
10
- from components.statistical import StatisticalAnalyzer
11
- from components.visualization import D3Visualizer
12
-
13
def parse_gpt_response(response: str) -> Dict:
    """Safely parse a GPT response into an analysis request.

    The JSON object is taken to be everything between the first ``{`` and
    the last ``}`` of the response, so Markdown code fences and prose before
    or after the object are ignored.

    Args:
        response: Raw text returned by the model.

    Returns:
        The decoded JSON object. When the response is not a string, contains
        no decodable JSON, or decodes to something other than an object, a
        default "distribution" request is returned instead.
    """
    # Built per call so callers can mutate the result without affecting
    # later fallbacks.
    fallback = {
        "analysis_type": "distribution",
        "params": {"column": "all"},
        "explanation": "Performing basic distribution analysis as fallback."
    }

    if not isinstance(response, str):
        return fallback

    cleaned_response = response.strip()

    # Slice out the outermost {...} span; this also discards ```json fences
    # and any text the model wrote around the object.
    start = cleaned_response.find("{")
    end = cleaned_response.rfind("}") + 1
    if start >= 0 and end > start:
        cleaned_response = cleaned_response[start:end]

    try:
        parsed = json.loads(cleaned_response)
    except json.JSONDecodeError:
        return fallback

    # Callers index the result by key, so only a JSON object is usable.
    return parsed if isinstance(parsed, dict) else fallback
35
-
36
def analyze_data(
    file: gr.File,
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Process a user request and return the analysis as an HTML string.

    The uploaded file is loaded into a DataFrame, the model is asked which
    analysis to run, the request is handed to ``DataAnalyzer.analyze_data``
    and the explanation, statistics and visualization are combined into one
    HTML fragment.

    Args:
        file: Uploaded file object; only its ``name`` attribute (a path
            ending in .csv, .xlsx or .xls) is used.
        query: Natural-language description of the desired analysis.
        api_key: OpenAI API key; exported as ``OPENAI_API_KEY``.
        temperature: Sampling temperature passed to the model.

    Returns:
        An HTML fragment with the results, or a message starting with
        "Error" when input is missing or processing fails. No exception is
        propagated to the caller.
    """
    import html
    import traceback

    if not api_key:
        return "Error: Please provide an API key."

    if not file:
        return "Error: Please upload a file."

    try:
        # completion() is not given the key explicitly, so it is exported
        # through the environment before the call.
        os.environ["OPENAI_API_KEY"] = api_key

        # Load data (extension match is case-insensitive).
        lowered_name = file.name.lower()
        if lowered_name.endswith('.csv'):
            df = pd.read_csv(file.name)
        elif lowered_name.endswith(('.xlsx', '.xls')):
            df = pd.read_excel(file.name)
        else:
            return "Error: Unsupported file type."

        # Initialize analyzers
        analyzer = DataAnalyzer()

        # Build context. Column labels are not guaranteed to be strings, so
        # convert them before joining.
        column_types = "\n".join(
            f"- {col}: {dtype}" for col, dtype in df.dtypes.items()
        )
        file_info = (
            f"File: {file.name}\n"
            f"Shape: {df.shape}\n"
            f"Columns: {', '.join(map(str, df.columns))}\n"
            "\n"
            "Column Types:\n"
            f"{column_types}\n"
        )

        # Get analysis request from the model
        messages = [
            {
                "role": "system",
                "content": """You are a data analysis assistant.
                Interpret the user's query and provide analysis details in JSON format.

                Return ONLY a JSON object with these fields:
                {
                    "analysis_type": "distribution" or "forecast" or "correlation",
                    "params": {"column": "column_name", ...},
                    "explanation": "why this analysis is appropriate"
                }

                For timeseries data, prefer 'forecast' type.
                For multiple columns, prefer 'correlation' type.
                For single column analysis, prefer 'distribution' type.
                """
            },
            {
                "role": "user",
                "content": f"{file_info}\n\nUser request: {query}"
            }
        ]

        response = completion(
            model="gpt-4o-mini",
            messages=messages,
            temperature=temperature
        )

        # Parse response and perform analysis
        analysis_request = parse_gpt_response(response.choices[0].message.content)

        # The model output is untrusted: tolerate missing or malformed fields
        # instead of failing with KeyError / IndexError.
        params = analysis_request.get("params")
        if not isinstance(params, dict):
            params = {}
        if "column" not in params:
            numeric_columns = df.select_dtypes(include=['number']).columns
            if len(numeric_columns) == 0:
                return "Error: No numeric columns found to analyze."
            params["column"] = numeric_columns[0]

        analysis_type = analysis_request.get("analysis_type", "distribution")
        explanation = analysis_request.get("explanation", "")

        result = analyzer.analyze_data(df, analysis_type, params)

        # Combine results into HTML. Model text and statistics are escaped;
        # the visualization is inserted verbatim because it is markup
        # produced by the analyzer.
        html_output = f"""
        <div class="analysis-container">
            <div class="explanation">
                <h2>Analysis Explanation</h2>
                <p>{html.escape(str(explanation))}</p>
            </div>

            <div class="results">
                <h2>Statistical Results</h2>
                <pre>{html.escape(str(result.get('statistics', '')))}</pre>
            </div>

            <div class="visualization">
                <h2>Interactive Visualization</h2>
                {result.get('visualization', '')}
            </div>
        </div>
        """

        return html_output

    except Exception as e:
        error_details = traceback.format_exc()
        message = f"Error occurred: {str(e)}\n\nDetails:\n{error_details}"
        # The result is rendered as HTML: escape it and keep the line breaks.
        return f"<pre>{html.escape(message)}</pre>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def create_interface():
149
- """Create Gradio interface"""
 
 
150
 
151
- with gr.Blocks(title="Interactive Data Analysis") as interface:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  gr.Markdown("""
153
  # Interactive Data Analysis Assistant
154
 
155
- Upload your data and get interactive visualizations with statistical analysis.
156
-
157
- **Features:**
158
- - Interactive D3 visualizations
159
- - Statistical analysis
160
- - Probability distributions
161
- - Time series forecasting
162
- - Correlation analysis
163
 
164
- **Note**: Requires your own OpenAI API key.
165
  """)
166
 
167
  with gr.Row():
@@ -170,43 +279,37 @@ def create_interface():
170
  label="Upload Data File",
171
  file_types=[".csv", ".xlsx", ".xls"]
172
  )
173
- query = gr.Textbox(
174
- label="What would you like to analyze?",
175
- placeholder="e.g., Show distribution of values with statistics",
176
- lines=3
177
- )
178
  api_key = gr.Textbox(
179
- label="API Key (Required)",
180
- placeholder="Your OpenAI API key",
181
  type="password"
182
  )
183
- temperature = gr.Slider(
184
- label="Temperature",
185
- minimum=0.0,
186
- maximum=1.0,
187
- value=0.7,
188
- step=0.1
189
  )
190
  analyze_btn = gr.Button("Analyze")
191
 
192
  with gr.Column():
193
- output = gr.HTML(label="Output")
194
 
 
 
195
  analyze_btn.click(
196
- analyze_data,
197
- inputs=[file, query, api_key, temperature],
198
- outputs=output
199
  )
200
 
 
201
  gr.Examples(
202
  examples=[
203
- [None, "Show me the distribution of values and calculate statistics"],
204
- [None, "Create a 10-period probability cone forecast"],
205
- [None, "Analyze correlations between variables"],
206
- [None, "Test if the data follows a normal distribution"],
207
- [None, "Show the data distribution with confidence intervals"],
208
  ],
209
- inputs=[file, query]
210
  )
211
 
212
  return interface
 
1
+ import base64
2
+ import io
3
  import os
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
  import json
 
7
 
8
  import gradio as gr
9
+ import numpy as np
10
  import pandas as pd
11
+ from bokeh.plotting import figure
12
+ from bokeh.layouts import column, row, layout
13
+ from bokeh.models import ColumnDataSource, HoverTool, BoxSelectTool, WheelZoomTool, ResetTool
14
+ from bokeh.embed import components
15
+ from bokeh.resources import CDN
16
  from litellm import completion
17
 
18
class VisualizationEngine:
    """Engine for creating interactive Bokeh visualizations.

    Each ``create_*`` method builds a Bokeh figure from a DataFrame and
    returns one HTML string made of the Bokeh CDN resource tags, the plot
    ``<div>`` and the plot ``<script>``.
    """

    def __init__(self):
        # Default figure size (pixels) and toolbar shared by every plot.
        self.width = 600
        self.height = 400
        self.tools = "pan,box_zoom,wheel_zoom,reset,save,hover"

    @staticmethod
    def _tooltips(columns: List[str]) -> List[tuple]:
        """Return ``(label, field)`` hover tooltip pairs for *columns*.

        The ``@{name}`` form is used because a bare ``@name`` reference
        stops at the first space in a column name.
        """
        return [(str(col), f"@{{{col}}}") for col in columns]

    @staticmethod
    def _render(plot) -> str:
        """Return the embeddable HTML (resources, div, script) for *plot*."""
        script, div = components(plot)
        return f"{CDN.render()}\n{div}\n{script}"

    def _new_figure(self, title: str, **kwargs):
        """Create a figure with the engine's size and tool defaults."""
        return figure(width=self.width, height=self.height, title=title,
                      tools=self.tools, **kwargs)

    def create_scatter(self, df: pd.DataFrame, x_col: str, y_col: str,
                       color_col: Optional[str] = None, title: str = "") -> str:
        """Create an interactive scatter plot.

        Args:
            df: Data to plot.
            x_col: Column used for the x coordinates.
            y_col: Column used for the y coordinates.
            color_col: Optional column whose values select the marker color;
                ignored when it is not a column of *df*.
            title: Plot title.

        Returns:
            HTML string embedding the plot.
        """
        use_color = bool(color_col) and color_col in df.columns
        p = self._new_figure(title)

        if use_color:
            from bokeh.palettes import Category10
            from bokeh.transform import factor_cmap

            # Categorical color mapping needs string factors; work on a copy
            # so the caller's DataFrame is left untouched.
            data = df.copy()
            data[color_col] = data[color_col].astype(str)
            factors = sorted(data[color_col].unique())

            # Cycle the 10-color palette so every factor receives a color.
            base = Category10[10]
            palette = [base[i % len(base)] for i in range(len(factors))]

            source = ColumnDataSource(data)
            p.scatter(x_col, y_col, source=source,
                      color=factor_cmap(color_col, palette=palette, factors=factors))
        else:
            source = ColumnDataSource(df)
            p.scatter(x_col, y_col, source=source)

        p.xaxis.axis_label = x_col
        p.yaxis.axis_label = y_col

        hover = p.select(dict(type=HoverTool))
        hover.tooltips = self._tooltips(
            [x_col, y_col] + ([color_col] if use_color else [])
        )

        return self._render(p)

    def create_line(self, df: pd.DataFrame, x_col: str, y_cols: List[str], title: str = "") -> str:
        """Create an interactive line plot with one line per column in *y_cols*.

        Args:
            df: Data to plot.
            x_col: Column used for the x coordinates.
            y_cols: Columns drawn as separate lines; each gets a legend entry.
            title: Plot title.

        Returns:
            HTML string embedding the plot.
        """
        source = ColumnDataSource(df)
        p = self._new_figure(title)

        for y_col in y_cols:
            p.line(x_col, y_col, line_width=2, source=source, legend_label=y_col)

        p.xaxis.axis_label = x_col
        p.yaxis.axis_label = "Values"
        # Clicking a legend entry toggles the matching line.
        p.legend.click_policy = "hide"

        hover = p.select(dict(type=HoverTool))
        hover.tooltips = self._tooltips([x_col] + list(y_cols))

        return self._render(p)

    def create_bar(self, df: pd.DataFrame, x_col: str, y_col: str, title: str = "") -> str:
        """Create an interactive bar plot.

        Args:
            df: Data to plot.
            x_col: Column providing the bar categories.
            y_col: Column providing the bar heights.
            title: Plot title.

        Returns:
            HTML string embedding the plot.
        """
        # A categorical x_range only accepts string factors, so the category
        # column is converted on a copy of the data.
        data = df.copy()
        data[x_col] = data[x_col].astype(str)
        source = ColumnDataSource(data)

        p = self._new_figure(title, x_range=data[x_col].unique().tolist())

        p.vbar(x=x_col, top=y_col, width=0.9, source=source)

        p.xaxis.axis_label = x_col
        p.yaxis.axis_label = y_col
        p.xgrid.grid_line_color = None

        hover = p.select(dict(type=HoverTool))
        hover.tooltips = self._tooltips([x_col, y_col])

        return self._render(p)
85
+
86
class AnalysisSession:
    """Maintains state and history for the analysis session."""

    def __init__(self):
        # Currently loaded dataset; None until a file has been loaded.
        self.data: Optional[pd.DataFrame] = None
        # Conversation so far as {"role": ..., "content": ...} dicts.
        self.chat_history: List[Dict[str, str]] = []
        self.viz_engine = VisualizationEngine()

    def add_message(self, role: str, content: str):
        """Append a message with the given role and content to the chat history."""
        self.chat_history.append({"role": role, "content": content})

    def get_context(self) -> str:
        """Return a text summary of the loaded data.

        Returns:
            "No data loaded yet." when no DataFrame is set; otherwise a
            multi-line description listing the shape, all columns, the
            numeric columns and the object/category columns.
        """
        if self.data is None:
            return "No data loaded yet."

        def joined(labels) -> str:
            # Column labels may be ints or other non-string objects.
            return ', '.join(map(str, labels))

        numeric = self.data.select_dtypes(include=[np.number]).columns
        categorical = self.data.select_dtypes(include=['object', 'category']).columns

        context = (
            "\n"
            "Current DataFrame Info:\n"
            f"- Shape: {self.data.shape}\n"
            f"- Columns: {joined(self.data.columns)}\n"
            f"- Numeric columns: {joined(numeric)}\n"
            f"- Categorical columns: {joined(categorical)}\n"
        )
        return context
111
+
112
class AnalysisAgent:
    """Enhanced agent with interactive visualization and chat capabilities.

    The agent sends the user's query together with a summary of the loaded
    data to the model, executes any Python code blocks found in the reply
    and appends the HTML they produce to the reply text.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
    ):
        self.model_id = model_id
        self.temperature = temperature
        self.session = AnalysisSession()

    def process_query(self, query: str) -> str:
        """Process a user query and generate a response with visualizations.

        Args:
            query: The user's question about the loaded data.

        Returns:
            The model's reply followed by the output of every executed code
            block, or a string starting with "Error: " when the request
            fails. No exception is propagated to the caller.
        """
        context = self.session.get_context()

        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            *self.session.chat_history[-5:],  # Include last 5 messages for context
            {"role": "user", "content": f"{context}\n\nUser query: {query}"}
        ]

        try:
            response = completion(
                model=self.model_id,
                messages=messages,
                temperature=self.temperature,
            )
            # The content may be None; treat that as an empty reply.
            analysis = response.choices[0].message.content or ""

            # Extract and execute any code blocks
            visualizations = []
            for code in self._extract_code(analysis):
                try:
                    result = self._execute_visualization(code)
                    if result:
                        visualizations.append(result)
                except Exception as e:
                    visualizations.append(f"Error creating visualization: {str(e)}")

            # Only the raw query (not the data summary) is kept in history.
            self.session.add_message("user", query)
            self.session.add_message("assistant", analysis)

            # Combine analysis and visualizations
            return analysis + "\n\n" + "\n".join(visualizations)

        except Exception as e:
            return f"Error: {str(e)}"

    def _execute_visualization(self, code: str) -> Optional[str]:
        """Execute visualization code and return its HTML output.

        The code runs with ``df``, ``np``, ``pd`` and ``viz`` in scope.
        Anything it prints is captured; if the printed text contains HTML it
        is returned, so a block that prints several plots yields all of
        them. Otherwise the first string variable that looks like HTML is
        returned.

        Args:
            code: Python source to execute.

        Returns:
            The HTML produced by the code, ``None`` when it produced none,
            or an "Error executing visualization: ..." message when the code
            raised an exception.
        """
        import contextlib
        import io

        def looks_like_html(text: str) -> bool:
            return '<script' in text or '<div' in text

        try:
            namespace = {
                'df': self.session.data,
                'np': np,
                'pd': pd,
                'viz': self.session.viz_engine
            }

            # SECURITY: exec() runs model-generated code with the full
            # privileges of this process; the namespace above is NOT a
            # sandbox. Only use with trusted models and inputs.
            printed = io.StringIO()
            with contextlib.redirect_stdout(printed):
                exec(code, namespace)

            # The system prompt tells the model to print() its HTML, so the
            # captured output is checked first.
            output = printed.getvalue()
            if looks_like_html(output):
                return output.strip()

            # Fall back to an HTML string left in a variable.
            for var in namespace.values():
                if isinstance(var, str) and looks_like_html(var):
                    return var

            return None

        except Exception as e:
            return f"Error executing visualization: {str(e)}"

    def _get_system_prompt(self) -> str:
        """Get the system prompt describing the visualization capabilities."""
        return """You are a data analysis assistant with interactive visualization capabilities.

        Available visualizations:
        1. Scatter plots (viz.create_scatter)
        2. Line plots (viz.create_line)
        3. Bar plots (viz.create_bar)

        The following variables are available:
        - df: pandas DataFrame with the current data
        - viz: visualization engine with plotting methods
        - np: numpy library
        - pd: pandas library

        When analyzing data:
        1. First understand and explain the data
        2. Create relevant visualizations using the viz engine
        3. Provide insights based on the visualizations
        4. Ask follow-up questions when appropriate
        5. Use markdown for formatting

        Example visualization code:
        ```python
        # Create scatter plot
        html = viz.create_scatter(df, 'column1', 'column2', title='Analysis')
        print(html)

        # Create line plot
        html = viz.create_line(df, 'date_column', ['value1', 'value2'], title='Trends')
        print(html)
        ```
        """

    @staticmethod
    def _extract_code(text: str) -> List[str]:
        """Extract Python code blocks from markdown.

        Fences tagged ``python`` or ``py`` are recognised in any letter
        case, with optional spaces around the tag and either ``\\n`` or
        ``\\r\\n`` after it.

        Args:
            text: Markdown text that may contain fenced code blocks.

        Returns:
            The contents of each Python code block, in order of appearance.
        """
        import re
        pattern = r'```[ \t]*(?:python|py)[ \t]*\r?\n(.*?)```'
        return re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
229
 
230
  def create_interface():
231
+ """Create Gradio interface with chat capabilities"""
232
+
233
+ agent = AnalysisAgent()
234
 
235
+ def process_file(file: gr.File) -> str:
236
+ """Process uploaded file and initialize session"""
237
+ try:
238
+ if file.name.endswith('.csv'):
239
+ agent.session.data = pd.read_csv(file.name)
240
+ elif file.name.endswith(('.xlsx', '.xls')):
241
+ agent.session.data = pd.read_excel(file.name)
242
+ else:
243
+ return "Error: Unsupported file type"
244
+
245
+ return f"Successfully loaded data: {agent.session.get_context()}"
246
+ except Exception as e:
247
+ return f"Error loading file: {str(e)}"
248
+
249
+ def analyze(file: gr.File, query: str, api_key: str) -> str:
250
+ """Process analysis query"""
251
+ if not api_key:
252
+ return "Error: Please provide an API key."
253
+
254
+ if not file:
255
+ return "Error: Please upload a file."
256
+
257
+ try:
258
+ os.environ["OPENAI_API_KEY"] = api_key
259
+ return agent.process_query(query)
260
+ except Exception as e:
261
+ return f"Error: {str(e)}"
262
+
263
+ with gr.Blocks(title="Interactive Data Analysis Assistant") as interface:
264
  gr.Markdown("""
265
  # Interactive Data Analysis Assistant
266
 
267
+ Upload your data file and chat with the AI to analyze it. Features:
268
+ - Interactive visualizations
269
+ - Natural language analysis
270
+ - Follow-up questions
271
+ - Statistical insights
 
 
 
272
 
273
+ **Note**: Requires OpenAI API key
274
  """)
275
 
276
  with gr.Row():
 
279
  label="Upload Data File",
280
  file_types=[".csv", ".xlsx", ".xls"]
281
  )
 
 
 
 
 
282
  api_key = gr.Textbox(
283
+ label="API Key",
 
284
  type="password"
285
  )
286
+ chat_input = gr.Textbox(
287
+ label="Ask about your data",
288
+ placeholder="e.g., Show me the relationship between variables",
289
+ lines=3
 
 
290
  )
291
  analyze_btn = gr.Button("Analyze")
292
 
293
  with gr.Column():
294
+ chat_output = gr.HTML(label="Analysis & Visualizations")
295
 
296
+ # Set up event handlers
297
+ file.change(process_file, inputs=[file], outputs=[chat_output])
298
  analyze_btn.click(
299
+ analyze,
300
+ inputs=[file, chat_input, api_key],
301
+ outputs=[chat_output]
302
  )
303
 
304
+ # Example queries
305
  gr.Examples(
306
  examples=[
307
+ [None, "Show me the distribution of numerical variables"],
308
+ [None, "Create a correlation analysis with interactive visualizations"],
309
+ [None, "What are the main trends in the data?"],
310
+ [None, "Can you identify any interesting patterns?"],
 
311
  ],
312
+ inputs=[file, chat_input]
313
  )
314
 
315
  return interface