rairo committed on
Commit
3a9cf4b
·
verified ·
1 Parent(s): bb94e23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -80
app.py CHANGED
@@ -9,7 +9,7 @@ from langchain.chains import ConversationalRetrievalChain
9
  import os
10
  import pandas as pd
11
  from pandasai import SmartDataframe, SmartDatalake
12
- from pandasai.responses.response_parser import ResponseParser
13
  from pandasai.llm import GoogleGemini
14
  import plotly.graph_objects as go
15
  from PIL import Image
@@ -17,33 +17,67 @@ import io
17
  import base64
18
 
19
  class StreamLitResponse(ResponseParser):
20
- def __init__(self,context) -> None:
21
- super().__init__(context)
22
- def format_dataframe(self,result):
23
- st.dataframe(result['value'])
24
- return
25
- def format_plot(self,result):
26
- st.image(result['value'])
27
- return
28
- def format_other(self, result):
29
- st.write(result['value'])
30
- return
31
-
32
-
33
- load_dotenv() # Load environment variables at the beginning
34
- GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY') #Use .get to handle if the variable is not present
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if not GOOGLE_API_KEY:
37
  st.error("GOOGLE_API_KEY environment variable not set.")
38
  st.stop()
39
 
40
-
41
  def generateResponse(prompt, dfs):
 
42
  llm = GoogleGemini(api_key=GOOGLE_API_KEY)
43
- pandas_agent = SmartDataframe(dfs,config={"llm":llm, "response_parser":StreamLitResponse})
44
- answer = pandas_agent.chat(prompt)
45
- return answer
46
-
 
 
 
 
47
  # Processing pdfs
48
  def get_pdf_text(pdf_docs):
49
  text = ""
@@ -72,46 +106,6 @@ def get_vectorstore(text_chunks):
72
  vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
73
  return vectorstore
74
 
75
- #handle user input
76
- def handle_userinput(question, pdf_vectorstore, dfs):
77
- if pdf_vectorstore and st.session_state.conversation:
78
- response = st.session_state.conversation({"question": question})
79
- st.session_state.chat_history.append({"role": "user", "content": question})
80
- assistant_response = response.get('answer')
81
-
82
- if assistant_response: # Check if assistant_response is not None or empty
83
- st.session_state.chat_history.append({"role": "assistant", "content": assistant_response}) # Directly add string
84
-
85
- st.rerun()
86
-
87
- elif dfs: # PandasAI
88
- assistant_response = generateResponse(question, dfs) # Get the single response
89
-
90
- st.session_state.chat_history.append({"role": "user", "content": question})
91
-
92
- if assistant_response: # Check if assistant_response is not None or empty
93
- if isinstance(assistant_response, dict) and 'value' in assistant_response:
94
- content_type = assistant_response.get('type')
95
- content_value = assistant_response.get('value')
96
-
97
- if content_type == "dataframe":
98
- st.session_state.chat_history.append({"role": "assistant", "content": "DataFrame"})
99
- st.session_state.chat_history.append({"role": "assistant", "dataframe": content_value})
100
- elif content_type == "plot":
101
- st.session_state.chat_history.append({"role": "assistant", "content": "Plot"})
102
- st.session_state.chat_history.append({"role": "assistant", "plot": content_value})
103
- else: # Text or other
104
- st.session_state.chat_history.append({"role": "assistant", "content": assistant_response})
105
-
106
- else: # Text or other (including None if that's what it is)
107
- st.session_state.chat_history.append({"role": "assistant", "content": assistant_response})
108
-
109
- st.rerun()
110
- return # Exit early after PandasAI processing
111
-
112
- else:
113
- st.write("Please upload and process your documents/data first.")
114
-
115
  def get_conversation_chain(vectorstore):
116
  llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp')
117
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
@@ -122,9 +116,99 @@ def get_conversation_chain(vectorstore):
122
  )
123
  return conversation_chain
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def main():
126
  st.set_page_config(page_title="Chat with PDFs and Data", page_icon=":books:")
127
 
 
128
  if "conversation" not in st.session_state:
129
  st.session_state.conversation = None
130
  if "chat_history" not in st.session_state:
@@ -136,36 +220,24 @@ def main():
136
 
137
  st.title("Chat with PDFs and Data :books: :bar_chart:")
138
 
139
- # Chat display
140
  for message in st.session_state.chat_history:
141
  with st.chat_message(message["role"]):
142
- if "dataframe" in message:
143
- st.dataframe(message["dataframe"])
144
- elif "plot" in message:
145
- if isinstance(message["plot"], Image.Image):
146
- st.image(message["plot"])
147
- elif isinstance(message["plot"], go.Figure):
148
- st.plotly_chart(message["plot"])
149
- elif isinstance(message["plot"], bytes):
150
- try:
151
- image = Image.open(io.BytesIO(message["plot"]))
152
- st.image(image)
153
- except Exception as e:
154
- st.error(f"Error displaying image: {e}")
155
- else:
156
- st.write("Unsupported plot format")
157
- else:
158
- st.write(message["content"])
159
 
 
160
  user_question = st.chat_input("Ask a question about your documents or data:")
161
 
162
  if user_question:
163
  handle_userinput(user_question, st.session_state.vectorstore, st.session_state.dfs)
164
 
 
165
  with st.sidebar:
166
  st.subheader("Your files")
167
  uploaded_files = st.file_uploader(
168
- "Upload PDFs, CSVs, or Excel files (up to 3)", accept_multiple_files=True, key="file_uploader"
 
 
169
  )
170
 
171
  if st.button("Process"):
@@ -175,6 +247,7 @@ def main():
175
  pdf_uploaded = False
176
  data_uploaded = False
177
 
 
178
  for uploaded_file in uploaded_files:
179
  file_extension = uploaded_file.name.split(".")[-1].lower()
180
 
@@ -204,6 +277,7 @@ def main():
204
  st.error(f"Error reading {uploaded_file.name}: {e}")
205
  st.stop()
206
 
 
207
  if pdf_docs:
208
  raw_text = get_pdf_text(pdf_docs)
209
  text_chunks = get_text_chunks(raw_text)
@@ -213,6 +287,7 @@ def main():
213
  st.session_state.vectorstore = None
214
  st.session_state.conversation = None
215
 
 
216
  if dfs:
217
  st.session_state.dfs = dfs
218
  else:
 
9
  import os
10
  import pandas as pd
11
  from pandasai import SmartDataframe, SmartDatalake
12
+ from pandasai.responses.response_parser import ResponseParser
13
  from pandasai.llm import GoogleGemini
14
  import plotly.graph_objects as go
15
  from PIL import Image
 
17
  import base64
18
 
19
  class StreamLitResponse(ResponseParser):
20
+ def __init__(self, context):
21
+ super().__init__(context)
22
+
23
+ def format_dataframe(self, result):
24
+ """Enhanced DataFrame rendering with type identifier"""
25
+ return {
26
+ 'type': 'dataframe',
27
+ 'value': result['value']
28
+ }
29
+
30
+ def format_plot(self, result):
31
+ """Enhanced plot rendering with type identifier"""
32
+ try:
33
+ image = result['value']
34
+
35
+ # Convert image to base64 for consistent storage
36
+ if isinstance(image, Image.Image):
37
+ buffered = io.BytesIO()
38
+ image.save(buffered, format="PNG")
39
+ base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
40
+ elif isinstance(image, bytes):
41
+ base64_image = base64.b64encode(image).decode('utf-8')
42
+ elif isinstance(image, str) and os.path.exists(image):
43
+ with open(image, "rb") as f:
44
+ base64_image = base64.b64encode(f.read()).decode('utf-8')
45
+ else:
46
+ return {'type': 'text', 'value': "Unsupported image format"}
47
+
48
+ return {
49
+ 'type': 'plot',
50
+ 'value': base64_image
51
+ }
52
+ except Exception as e:
53
+ return {'type': 'text', 'value': f"Error processing plot: {e}"}
54
+
55
+ def format_other(self, result):
56
+ """Handle other types of responses"""
57
+ return {
58
+ 'type': 'text',
59
+ 'value': str(result['value'])
60
+ }
61
+
62
+ # Load environment variables
63
+ load_dotenv()
64
+ GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
65
 
66
  if not GOOGLE_API_KEY:
67
  st.error("GOOGLE_API_KEY environment variable not set.")
68
  st.stop()
69
 
 
70
  def generateResponse(prompt, dfs):
71
+ """Generate response using PandasAI"""
72
  llm = GoogleGemini(api_key=GOOGLE_API_KEY)
73
+ pandas_agent = SmartDatalake(dfs, config={
74
+ "llm": llm,
75
+ "response_parser": StreamLitResponse
76
+ })
77
+ return pandas_agent.chat(prompt)
78
+
79
+ # Other utility functions remain the same as in the original code
80
+ # (get_pdf_text, get_text_chunks, get_vectorstore, get_conversation_chain)
81
  # Processing pdfs
82
  def get_pdf_text(pdf_docs):
83
  text = ""
 
106
  vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
107
  return vectorstore
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def get_conversation_chain(vectorstore):
110
  llm = ChatGoogleGenerativeAI(model='gemini-2.0-flash-exp')
111
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
116
  )
117
  return conversation_chain
118
 
119
+ def render_chat_message(message):
120
+ """Render different types of chat messages"""
121
+ if "dataframe" in message:
122
+ st.dataframe(message["dataframe"])
123
+ elif "plot" in message:
124
+ try:
125
+ # Handle base64 encoded images
126
+ plot_data = message["plot"]
127
+ if isinstance(plot_data, str):
128
+ st.image(f"data:image/png;base64,{plot_data}")
129
+ elif isinstance(plot_data, Image.Image):
130
+ st.image(plot_data)
131
+ elif isinstance(plot_data, go.Figure):
132
+ st.plotly_chart(plot_data)
133
+ elif isinstance(plot_data, bytes):
134
+ image = Image.open(io.BytesIO(plot_data))
135
+ st.image(image)
136
+ else:
137
+ st.write("Unsupported plot format")
138
+ except Exception as e:
139
+ st.error(f"Error rendering plot: {e}")
140
+
141
+ # Always render text content
142
+ if "content" in message:
143
+ st.markdown(message["content"])
144
+
145
+ def handle_userinput(question, pdf_vectorstore, dfs):
146
+ """Enhanced input handling with robust content processing"""
147
+ try:
148
+ if pdf_vectorstore and st.session_state.conversation:
149
+ # PDF/Vector search mode
150
+ response = st.session_state.conversation({"question": question})
151
+ st.session_state.chat_history.append({
152
+ "role": "user",
153
+ "content": question
154
+ })
155
+
156
+ assistant_response = response.get('answer', '')
157
+ st.session_state.chat_history.append({
158
+ "role": "assistant",
159
+ "content": assistant_response
160
+ })
161
+
162
+ elif dfs:
163
+ # PandasAI data analysis mode
164
+ st.session_state.chat_history.append({
165
+ "role": "user",
166
+ "content": question
167
+ })
168
+
169
+ # Generate response with PandasAI
170
+ result = generateResponse(question, dfs)
171
+
172
+ # Handle different response types
173
+ if isinstance(result, dict):
174
+ response_type = result.get('type', 'text')
175
+ response_value = result.get('value')
176
+
177
+ if response_type == 'dataframe':
178
+ st.session_state.chat_history.append({
179
+ "role": "assistant",
180
+ "content": "Here's the DataFrame analysis:",
181
+ "dataframe": response_value
182
+ })
183
+ elif response_type == 'plot':
184
+ st.session_state.chat_history.append({
185
+ "role": "assistant",
186
+ "content": "Here's the visualization:",
187
+ "plot": response_value
188
+ })
189
+ else:
190
+ st.session_state.chat_history.append({
191
+ "role": "assistant",
192
+ "content": str(response_value)
193
+ })
194
+ else:
195
+ st.session_state.chat_history.append({
196
+ "role": "assistant",
197
+ "content": str(result)
198
+ })
199
+
200
+ else:
201
+ st.write("Please upload and process your documents/data first.")
202
+
203
+ st.rerun()
204
+
205
+ except Exception as e:
206
+ st.error(f"Error processing input: {e}")
207
+
208
  def main():
209
  st.set_page_config(page_title="Chat with PDFs and Data", page_icon=":books:")
210
 
211
+ # Initialize session state variables
212
  if "conversation" not in st.session_state:
213
  st.session_state.conversation = None
214
  if "chat_history" not in st.session_state:
 
220
 
221
  st.title("Chat with PDFs and Data :books: :bar_chart:")
222
 
223
+ # Chat display with enhanced rendering
224
  for message in st.session_state.chat_history:
225
  with st.chat_message(message["role"]):
226
+ render_chat_message(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
+ # Chat input
229
  user_question = st.chat_input("Ask a question about your documents or data:")
230
 
231
  if user_question:
232
  handle_userinput(user_question, st.session_state.vectorstore, st.session_state.dfs)
233
 
234
+ # Sidebar for file upload
235
  with st.sidebar:
236
  st.subheader("Your files")
237
  uploaded_files = st.file_uploader(
238
+ "Upload PDFs, CSVs, or Excel files (up to 3)",
239
+ accept_multiple_files=True,
240
+ key="file_uploader"
241
  )
242
 
243
  if st.button("Process"):
 
247
  pdf_uploaded = False
248
  data_uploaded = False
249
 
250
+ # File processing logic remains the same as in the original code
251
  for uploaded_file in uploaded_files:
252
  file_extension = uploaded_file.name.split(".")[-1].lower()
253
 
 
277
  st.error(f"Error reading {uploaded_file.name}: {e}")
278
  st.stop()
279
 
280
+ # Set up vectorstore and conversation chain for PDFs
281
  if pdf_docs:
282
  raw_text = get_pdf_text(pdf_docs)
283
  text_chunks = get_text_chunks(raw_text)
 
287
  st.session_state.vectorstore = None
288
  st.session_state.conversation = None
289
 
290
+ # Set up DataFrames for PandasAI
291
  if dfs:
292
  st.session_state.dfs = dfs
293
  else: