jzou19950715 committed on
Commit
f3177ec
·
verified ·
1 Parent(s): 38571e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -61
app.py CHANGED
@@ -14,11 +14,11 @@ import plotly.graph_objects as go
14
  from plotly.subplots import make_subplots
15
  from litellm import completion
16
 
17
-
18
  class CodeEnvironment:
19
- """Enhanced environment for executing code with both static and interactive visualization capabilities"""
20
 
21
  def __init__(self):
 
22
  self.globals = {
23
  'pd': pd,
24
  'np': np,
@@ -37,10 +37,14 @@ class CodeEnvironment:
37
 
38
  # Capture output
39
  output_buffer = io.StringIO()
 
 
 
 
40
  result = {
41
- 'output': '',
42
- 'figures': [], # For base64 static images
43
- 'interactive': [], # For Plotly HTML
44
  'error': None
45
  }
46
 
@@ -48,7 +52,7 @@ class CodeEnvironment:
48
  # Execute code
49
  exec(code, self.globals, self.locals)
50
 
51
- # Capture matplotlib figures (static)
52
  for i in plt.get_fignums():
53
  fig = plt.figure(i)
54
  buf = io.BytesIO()
@@ -58,18 +62,16 @@ class CodeEnvironment:
58
  result['figures'].append(f"data:image/png;base64,{img_str}")
59
  plt.close(fig)
60
 
61
- # Capture Plotly figures (interactive)
62
- for var in list(self.locals.values()):
63
- if isinstance(var, (go.Figure, px.Figure)):
64
- html = var.to_html(
 
65
  include_plotlyjs=True,
66
  full_html=False,
67
- config={
68
- 'displayModeBar': True,
69
- 'responsive': True
70
- }
71
  )
72
- result['interactive'].append(html)
73
 
74
  # Get printed output
75
  result['output'] = output_buffer.getvalue()
@@ -78,11 +80,12 @@ class CodeEnvironment:
78
  result['error'] = str(e)
79
 
80
  finally:
 
 
81
  output_buffer.close()
82
 
83
  return result
84
 
85
-
86
  @dataclass
87
  class Tool:
88
  """Tool for data analysis"""
@@ -90,9 +93,8 @@ class Tool:
90
  description: str
91
  func: Callable
92
 
93
-
94
  class AnalysisAgent:
95
- """Enhanced agent with interactive visualization capabilities"""
96
 
97
  def __init__(
98
  self,
@@ -104,12 +106,8 @@ class AnalysisAgent:
104
  self.tools: List[Tool] = []
105
  self.code_env = CodeEnvironment()
106
 
107
- def add_tool(self, name: str, description: str, func: Callable) -> None:
108
- """Add a tool to the agent"""
109
- self.tools.append(Tool(name=name, description=description, func=func))
110
-
111
  def run(self, prompt: str, df: pd.DataFrame = None) -> str:
112
- """Run analysis with enhanced visualization support"""
113
  messages = [
114
  {"role": "system", "content": self._get_system_prompt()},
115
  {"role": "user", "content": prompt}
@@ -138,69 +136,58 @@ class AnalysisAgent:
138
  if result['output']:
139
  results.append(result['output'])
140
 
141
- # Add interactive plots
142
- for plot in result['interactive']:
143
- results.append(f"<div class='plot-container'>{plot}</div>")
144
 
145
- # Add static figures as fallback
146
  for fig in result['figures']:
147
- results.append(f"![Figure]({fig})")
148
-
149
  # Combine analysis and results
150
- return analysis + "\n\n" + "\n".join(results)
151
 
152
  except Exception as e:
153
  return f"Error: {str(e)}"
154
 
155
  def _get_system_prompt(self) -> str:
156
- """Get enhanced system prompt with interactive visualization capabilities"""
157
  tools_desc = "\n".join([
158
  f"- {tool.name}: {tool.description}"
159
  for tool in self.tools
160
  ])
161
 
162
- return f"""You are a data analysis assistant with interactive visualization capabilities.
163
-
164
- Available tools:
165
- {tools_desc}
166
-
167
- Capabilities:
168
- - Data analysis (pandas, numpy)
169
- - Interactive visualization (plotly)
170
- - Static visualization (matplotlib, seaborn)
171
- - Statistical analysis (scipy)
172
- - Machine learning (sklearn)
173
 
174
- When writing code:
175
- - Prefer Plotly for interactive visualizations
176
- - Use matplotlib/seaborn for static plots when appropriate
177
- - Create clear visualizations with proper labels
178
- - Include explanatory text
179
- - Handle errors gracefully
180
 
181
- Example Plotly usage:
182
  ```python
183
  # Create interactive scatter plot
184
- fig = px.scatter(df, x='column1', y='column2',
185
- color='category',
186
- title='Interactive Analysis')
187
- fig.update_layout(height=600)
 
 
188
  fig.show()
189
 
190
  # Create interactive time series
191
- fig = px.line(df, x='date', y='value',
192
- color='category',
193
- title='Time Series Analysis')
194
- fig.update_layout(height=600)
195
  fig.show()
196
  ```
197
 
198
- Example Matplotlib usage:
 
 
 
 
 
 
 
199
  ```python
200
- # Create static plot
201
  plt.figure(figsize=(10, 6))
202
- sns.boxplot(data=df, x='category', y='value')
203
- plt.title('Distribution Analysis')
204
  plt.show()
205
  ```
206
  """
 
14
  from plotly.subplots import make_subplots
15
  from litellm import completion
16
 
 
17
  class CodeEnvironment:
18
+ """Safe environment for executing code with data analysis capabilities"""
19
 
20
  def __init__(self):
21
+ # Initialize libraries in globals
22
  self.globals = {
23
  'pd': pd,
24
  'np': np,
 
37
 
38
  # Capture output
39
  output_buffer = io.StringIO()
40
+ # Redirect stdout to capture print statements
41
+ import sys
42
+ sys.stdout = output_buffer
43
+
44
  result = {
45
+ 'output': '',
46
+ 'figures': [], # For matplotlib figures
47
+ 'plotly_html': [], # For Plotly figures
48
  'error': None
49
  }
50
 
 
52
  # Execute code
53
  exec(code, self.globals, self.locals)
54
 
55
+ # Capture matplotlib figures
56
  for i in plt.get_fignums():
57
  fig = plt.figure(i)
58
  buf = io.BytesIO()
 
62
  result['figures'].append(f"data:image/png;base64,{img_str}")
63
  plt.close(fig)
64
 
65
+ # Capture Plotly figures
66
+ if 'fig' in self.locals:
67
+ if isinstance(self.locals['fig'], (go.Figure, px.Figure)):
68
+ # Convert Plotly figure to HTML
69
+ html = self.locals['fig'].to_html(
70
  include_plotlyjs=True,
71
  full_html=False,
72
+ config={'displayModeBar': True}
 
 
 
73
  )
74
+ result['plotly_html'].append(html)
75
 
76
  # Get printed output
77
  result['output'] = output_buffer.getvalue()
 
80
  result['error'] = str(e)
81
 
82
  finally:
83
+ # Reset stdout
84
+ sys.stdout = sys.__stdout__
85
  output_buffer.close()
86
 
87
  return result
88
 
 
89
  @dataclass
90
  class Tool:
91
  """Tool for data analysis"""
 
93
  description: str
94
  func: Callable
95
 
 
96
  class AnalysisAgent:
97
+ """Agent that can analyze data and execute code"""
98
 
99
  def __init__(
100
  self,
 
106
  self.tools: List[Tool] = []
107
  self.code_env = CodeEnvironment()
108
 
 
 
 
 
109
  def run(self, prompt: str, df: pd.DataFrame = None) -> str:
110
+ """Run analysis with code execution"""
111
  messages = [
112
  {"role": "system", "content": self._get_system_prompt()},
113
  {"role": "user", "content": prompt}
 
136
  if result['output']:
137
  results.append(result['output'])
138
 
139
+ # Add Plotly interactive visualizations
140
+ for html in result['plotly_html']:
141
+ results.append(f'<div class="plot-container">{html}</div>')
142
 
143
+ # Add static matplotlib figures as fallback
144
  for fig in result['figures']:
145
+ results.append(f'<img src="{fig}" style="max-width: 100%; height: auto;">')
146
+
147
  # Combine analysis and results
148
+ return f'<div class="analysis-text">{analysis}</div>' + "\n\n" + "\n".join(results)
149
 
150
  except Exception as e:
151
  return f"Error: {str(e)}"
152
 
153
  def _get_system_prompt(self) -> str:
154
+ """Get system prompt with tools and capabilities"""
155
  tools_desc = "\n".join([
156
  f"- {tool.name}: {tool.description}"
157
  for tool in self.tools
158
  ])
159
 
160
+ return """You are a data analysis assistant with interactive visualization capabilities.
 
 
 
 
 
 
 
 
 
 
161
 
162
+ When analyzing data, use Plotly for interactive visualizations. Here are examples:
 
 
 
 
 
163
 
 
164
  ```python
165
  # Create interactive scatter plot
166
+ import plotly.express as px
167
+ fig = px.scatter(df, x='Date', y='Salary', color='Title')
168
+ fig.show() # This will be captured and displayed
169
+
170
+ # Create interactive box plot
171
+ fig = px.box(df, x='Title', y='Salary')
172
  fig.show()
173
 
174
  # Create interactive time series
175
+ fig = px.line(df, x='Date', y='Salary', color='Title')
 
 
 
176
  fig.show()
177
  ```
178
 
179
+ Remember to:
180
+ 1. Always store Plotly figures in a variable named 'fig'
181
+ 2. Use fig.show() to display the plot
182
+ 3. Create clear labels and titles
183
+ 4. Include hover information
184
+ 5. Use colors effectively
185
+
186
+ For static visualizations, you can still use matplotlib:
187
  ```python
188
+ import matplotlib.pyplot as plt
189
  plt.figure(figsize=(10, 6))
190
+ plt.plot(df['Date'], df['Salary'])
 
191
  plt.show()
192
  ```
193
  """