rairo commited on
Commit
0f73c76
·
verified ·
1 Parent(s): 168ac5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -17
app.py CHANGED
@@ -11,19 +11,78 @@ import pandas as pd
11
  from pandasai import SmartDataframe, SmartDatalake
12
  from pandasai.responses.response_parser import ResponseParser
13
  from pandasai.llm import GoogleGemini
 
 
 
 
14
 
15
  class StreamLitResponse(ResponseParser):
16
- def __init__(self,context) -> None:
17
- super().__init__(context)
18
- def format_dataframe(self,result):
19
- st.dataframe(result['value'])
20
- return
21
- def format_plot(self,result):
22
- st.image(result['value'], use_container_width=True)
23
- return
24
- def format_other(self, result):
25
- st.write(result['value'])
26
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  load_dotenv() # Load environment variables at the beginning
@@ -113,7 +172,7 @@ def main():
113
  with st.sidebar:
114
  st.subheader("Your files")
115
  uploaded_files = st.file_uploader(
116
- "Upload PDFs, CSVs, or Excel files (up to 3)", accept_multiple_files=True
117
  )
118
 
119
  if st.button("Process"):
@@ -127,20 +186,22 @@ def main():
127
  file_extension = uploaded_file.name.split(".")[-1].lower()
128
 
129
  if file_extension == "pdf":
130
- if data_uploaded: # Check if data was already uploaded
131
  if st.session_state.dfs:
132
  st.session_state.dfs = None
133
  data_uploaded = False
134
- st.warning("Switching to PDF mode. Data files removed.") # Inform user
 
135
  pdf_docs.append(uploaded_file)
136
  pdf_uploaded = True
137
  elif file_extension in ["csv", "xlsx", "xls"]:
138
- if pdf_uploaded: # Check if PDF was already uploaded
139
  if st.session_state.vectorstore:
140
  st.session_state.vectorstore = None
141
  st.session_state.conversation = None
142
  pdf_uploaded = False
143
- st.warning("Switching to Data mode. PDF files removed.") # Inform user
 
144
  try:
145
  if file_extension == 'csv':
146
  df = pd.read_csv(uploaded_file)
@@ -166,6 +227,15 @@ def main():
166
  else:
167
  st.session_state.dfs = None
168
 
 
 
 
 
 
 
 
 
 
 
169
  if __name__ == "__main__":
170
  main()
171
-
 
11
  from pandasai import SmartDataframe, SmartDatalake
12
  from pandasai.responses.response_parser import ResponseParser
13
  from pandasai.llm import GoogleGemini
14
+ import plotly.graph_objects as go
15
+ from PIL import Image
16
+ import io
17
+ import base64
18
 
19
  class StreamLitResponse(ResponseParser):
20
+ def __init__(self, context):
21
+ super().__init__(context)
22
+
23
+ def format_dataframe(self, result):
24
+ st.dataframe(result['value'])
25
+ return
26
+
27
+ def format_plot(self, result):
28
+ try:
29
+ image = result['value']
30
+
31
+ if isinstance(image, Image.Image): # PIL Image
32
+ img_bytes = io.BytesIO()
33
+ image.save(img_bytes, format="PNG")
34
+ img_bytes = img_bytes.getvalue()
35
+ encoded = base64.b64encode(img_bytes).decode("ascii")
36
+
37
+ st.image(image) # Display the image
38
+ fig = go.Figure(data=[go.Image(source=f'data:image/png;base64,{encoded}')])
39
+ fig.update_layout(
40
+ margin=dict(l=0, r=0, b=0, t=0),
41
+ xaxis=dict(visible=False),
42
+ yaxis=dict(visible=False),
43
+ )
44
+ st.plotly_chart(fig)
45
+
46
+ elif isinstance(image, bytes): # Bytes
47
+ encoded = base64.b64encode(image).decode("ascii")
48
+ fig = go.Figure(data=[go.Image(source=f'data:image/png;base64,{encoded}')])
49
+ fig.update_layout(
50
+ margin=dict(l=0, r=0, b=0, t=0),
51
+ xaxis=dict(visible=False),
52
+ yaxis=dict(visible=False),
53
+ )
54
+ st.plotly_chart(fig)
55
+
56
+ elif isinstance(image, str) and os.path.exists(image): # File Path
57
+ with open(image, "rb") as f:
58
+ encoded = base64.b64encode(f.read()).decode("ascii")
59
+ fig = go.Figure(data=[go.Image(source=f'data:image/png;base64,{encoded}')])
60
+ fig.update_layout(
61
+ margin=dict(l=0, r=0, b=0, t=0),
62
+ xaxis=dict(visible=False),
63
+ yaxis=dict(visible=False),
64
+ )
65
+ st.plotly_chart(fig)
66
+
67
+ elif isinstance(image, str): # Base64 encoded string
68
+ fig = go.Figure(data=[go.Image(source=f'data:image/png;base64,{image}')])
69
+ fig.update_layout(
70
+ margin=dict(l=0, r=0, b=0, t=0),
71
+ xaxis=dict(visible=False),
72
+ yaxis=dict(visible=False),
73
+ )
74
+ st.plotly_chart(fig)
75
+
76
+ else:
77
+ st.write("Unsupported Image format")
78
+
79
+ except Exception as e:
80
+ st.image(image)
81
+ st.write(f"Error displaying image: {e}")
82
+
83
+ def format_other(self, result):
84
+ st.write(result['value'])
85
+ return
86
 
87
 
88
  load_dotenv() # Load environment variables at the beginning
 
172
  with st.sidebar:
173
  st.subheader("Your files")
174
  uploaded_files = st.file_uploader(
175
+ "Upload PDFs, CSVs, or Excel files (up to 3)", accept_multiple_files=True, key="file_uploader" # Add a key
176
  )
177
 
178
  if st.button("Process"):
 
186
  file_extension = uploaded_file.name.split(".")[-1].lower()
187
 
188
  if file_extension == "pdf":
189
+ if data_uploaded:
190
  if st.session_state.dfs:
191
  st.session_state.dfs = None
192
  data_uploaded = False
193
+ st.warning("Switching to PDF mode. Data files removed.")
194
+ st.experimental_rerun() # Rerun to clear file uploader
195
  pdf_docs.append(uploaded_file)
196
  pdf_uploaded = True
197
  elif file_extension in ["csv", "xlsx", "xls"]:
198
+ if pdf_uploaded:
199
  if st.session_state.vectorstore:
200
  st.session_state.vectorstore = None
201
  st.session_state.conversation = None
202
  pdf_uploaded = False
203
+ st.warning("Switching to Data mode. PDF files removed.")
204
+ st.rerun() # Rerun to clear file uploader
205
  try:
206
  if file_extension == 'csv':
207
  df = pd.read_csv(uploaded_file)
 
227
  else:
228
  st.session_state.dfs = None
229
 
230
+ # Example of creating a bar chart with Plotly Express (replace with your actual data)
231
+ if st.session_state.dfs:
232
+ for df_ in st.session_state.dfs: #Iterate over dataframes
233
+ try:
234
+ fig = px.bar(df_, x=df_.columns[0], y=df_.columns[1], title="Example Bar Chart") #Assumes first two columns are x and y
235
+ fig.update_layout(xaxis_tickangle=-45) # Rotate x-axis labels for better readability
236
+ st.plotly_chart(fig)
237
+ except Exception as e:
238
+ st.write(f"Error plotting chart for dataframe: {e}") #Handle exceptions when plotting
239
+
240
  if __name__ == "__main__":
241
  main()