victor7246 commited on
Commit
c5a694c
·
verified ·
1 Parent(s): c62acd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -68
app.py CHANGED
@@ -92,6 +92,9 @@ if __name__ == '__main__':
92
  if "messages" not in st.session_state:
93
  st.session_state.messages = []
94
 
 
 
 
95
  for message in st.session_state.messages:
96
  with st.chat_message(message["role"]):
97
  st.markdown(message["content"])
@@ -112,6 +115,7 @@ if __name__ == '__main__':
112
  st.session_state.messages.append({"role": "user", "content": question})
113
 
114
  if 'yes' in q_relevant.lower():
 
115
  with st.status("Retrieving results..."):
116
  #top_table_names = table_search(question, topk=1)['table'].tolist()
117
  questions = extract_question_list(llm, question)
@@ -189,86 +193,89 @@ if __name__ == '__main__':
189
  else:
190
  with st.chat_message("assistant"):
191
  st.markdown("Looks like this question is not related to the database, but a generic. Do you want me to answer it from the table? Otherwise I will use my own knowledge.")
192
- if st.button("Yes"):
193
- question = st.session_state.messages[-1]['question']
194
- with st.status("Retrieving results..."):
195
- #top_table_names = table_search(question, topk=1)['table'].tolist()
196
- questions = extract_question_list(llm, question)
197
-
198
- if type(questions) == list:
199
- responses = []
200
- for q in questions:
201
- top_table_names = extract_table_name(llm, q) #[extract_table_name(llm, q)]
202
- print (top_table_names)
203
- #history = st.session_state['questions']
204
- history = st.session_state['history']
205
- try:
206
- db_chain._call(inputs={'query': q, 'history': history, \
207
- 'table_names_to_use': top_table_names})
208
- except:
209
- pass
210
-
211
- if db_chain.intermediate_steps.get("result",'') != '':
212
- response = db_chain.intermediate_steps.get("result",'')
213
- elif db_chain.intermediate_steps.get("sql_data",'') != '':
214
- out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
215
- response = tabulate(out, headers='keys', tablefmt='psql')
216
- else:
217
- response = ""
218
-
219
- if "SQLQuery" in response or "Answer:" in response:
220
- response = ""
221
-
222
- responses.append(response)
223
- st.session_state['history'].append(db_chain.intermediate_steps.get("result",''))
224
-
225
- response = "\n\n".join(responses)
226
- if response == "":
227
- response = "Sorry I may not have answer to this question."
228
- else:
229
- top_table_names = extract_table_name(llm, question) #[extract_table_name(llm, question)]
230
  print (top_table_names)
231
  #history = st.session_state['questions']
232
  history = st.session_state['history']
233
- #try:
234
- db_chain._call(inputs={'query': question, 'history': history, \
235
  'table_names_to_use': top_table_names})
236
- #except:
237
- # pass
238
-
239
  if db_chain.intermediate_steps.get("result",'') != '':
240
- #st.markdown("Answer to your question is - " + db_chain.intermediate_steps.get("result",''))
241
- #print (db_chain.intermediate_steps.get("result",''))
242
- #st.markdown("The SQL query is - " + db_chain.intermediate_steps['sql_cmd'])
243
- response = "Answer to your question is - " + db_chain.intermediate_steps.get("result",'')
244
  elif db_chain.intermediate_steps.get("sql_data",'') != '':
245
  out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
246
- #st.markdown("Here is your result in a table format")
247
- #st.table(out)
248
- #st.markdown("The SQL query is - " + db_chain.intermediate_steps['sql_cmd'])
249
  response = tabulate(out, headers='keys', tablefmt='psql')
250
- elif db_chain.intermediate_steps.get("sql_cmd_unchecked",'') == '':
251
- #print (db_chain)
252
- #st.markdown("Sorry I cannot answer that. Please try again later.")
253
- response = "Sorry I may not have answer to this question."
254
  else:
255
- #st.markdown("Sorry I cannot answer that. Please try again later.")
256
- response = "Sorry I may not have answer to this question."
257
-
258
  if "SQLQuery" in response or "Answer:" in response:
259
- response = "Sorry I may not have answer to this question."
260
 
 
261
  st.session_state['history'].append(db_chain.intermediate_steps.get("result",''))
262
-
263
- with st.chat_message("assistant"):
264
- st.markdown(response)
265
- # Add assistant response to chat history
266
- st.session_state.messages.append({"role": "assistant", "content": response})
267
- elif st.button("No"):
268
- question = st.session_state.messages[-1]['question']
269
- response = llm.invoke(question).content
270
- with st.chat_message("assistant"):
271
- st.markdown(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  if st.button("Reset Chat History"):
274
  #st.session_state['questions'] = []
 
92
  if "messages" not in st.session_state:
93
  st.session_state.messages = []
94
 
95
+ if "last_message_failed" not in st.session_state:
96
+ st.session_state.last_message_failed = False
97
+
98
  for message in st.session_state.messages:
99
  with st.chat_message(message["role"]):
100
  st.markdown(message["content"])
 
115
  st.session_state.messages.append({"role": "user", "content": question})
116
 
117
  if 'yes' in q_relevant.lower():
118
+ st.session_state.last_message_failed = False
119
  with st.status("Retrieving results..."):
120
  #top_table_names = table_search(question, topk=1)['table'].tolist()
121
  questions = extract_question_list(llm, question)
 
193
  else:
194
  with st.chat_message("assistant"):
195
  st.markdown("Looks like this question is not related to the database, but a generic. Do you want me to answer it from the table? Otherwise I will use my own knowledge.")
196
+ st.session_state.last_message_failed = True
197
+
198
+ if st.session_state.last_message_failed == True:
199
+ if st.button("Yes"):
200
+ question = st.session_state.messages[-1]['question']
201
+ with st.status("Retrieving results..."):
202
+ #top_table_names = table_search(question, topk=1)['table'].tolist()
203
+ questions = extract_question_list(llm, question)
204
+
205
+ if type(questions) == list:
206
+ responses = []
207
+ for q in questions:
208
+ top_table_names = extract_table_name(llm, q) #[extract_table_name(llm, q)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  print (top_table_names)
210
  #history = st.session_state['questions']
211
  history = st.session_state['history']
212
+ try:
213
+ db_chain._call(inputs={'query': q, 'history': history, \
214
  'table_names_to_use': top_table_names})
215
+ except:
216
+ pass
217
+
218
  if db_chain.intermediate_steps.get("result",'') != '':
219
+ response = db_chain.intermediate_steps.get("result",'')
 
 
 
220
  elif db_chain.intermediate_steps.get("sql_data",'') != '':
221
  out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
 
 
 
222
  response = tabulate(out, headers='keys', tablefmt='psql')
 
 
 
 
223
  else:
224
+ response = ""
225
+
 
226
  if "SQLQuery" in response or "Answer:" in response:
227
+ response = ""
228
 
229
+ responses.append(response)
230
  st.session_state['history'].append(db_chain.intermediate_steps.get("result",''))
231
+
232
+ response = "\n\n".join(responses)
233
+ if response == "":
234
+ response = "Sorry I may not have answer to this question."
235
+ else:
236
+ top_table_names = extract_table_name(llm, question) #[extract_table_name(llm, question)]
237
+ print (top_table_names)
238
+ #history = st.session_state['questions']
239
+ history = st.session_state['history']
240
+ #try:
241
+ db_chain._call(inputs={'query': question, 'history': history, \
242
+ 'table_names_to_use': top_table_names})
243
+ #except:
244
+ # pass
245
+
246
+ if db_chain.intermediate_steps.get("result",'') != '':
247
+ #st.markdown("Answer to your question is - " + db_chain.intermediate_steps.get("result",''))
248
+ #print (db_chain.intermediate_steps.get("result",''))
249
+ #st.markdown("The SQL query is - " + db_chain.intermediate_steps['sql_cmd'])
250
+ response = "Answer to your question is - " + db_chain.intermediate_steps.get("result",'')
251
+ elif db_chain.intermediate_steps.get("sql_data",'') != '':
252
+ out = pd.DataFrame.from_dict(db_chain.intermediate_steps['sql_data'])
253
+ #st.markdown("Here is your result in a table format")
254
+ #st.table(out)
255
+ #st.markdown("The SQL query is - " + db_chain.intermediate_steps['sql_cmd'])
256
+ response = tabulate(out, headers='keys', tablefmt='psql')
257
+ elif db_chain.intermediate_steps.get("sql_cmd_unchecked",'') == '':
258
+ #print (db_chain)
259
+ #st.markdown("Sorry I cannot answer that. Please try again later.")
260
+ response = "Sorry I may not have answer to this question."
261
+ else:
262
+ #st.markdown("Sorry I cannot answer that. Please try again later.")
263
+ response = "Sorry I may not have answer to this question."
264
+
265
+ if "SQLQuery" in response or "Answer:" in response:
266
+ response = "Sorry I may not have answer to this question."
267
+
268
+ st.session_state['history'].append(db_chain.intermediate_steps.get("result",''))
269
+
270
+ with st.chat_message("assistant"):
271
+ st.markdown(response)
272
+ # Add assistant response to chat history
273
+ st.session_state.messages.append({"role": "assistant", "content": response})
274
+ elif st.button("No"):
275
+ question = st.session_state.messages[-1]['question']
276
+ response = llm.invoke(question).content
277
+ with st.chat_message("assistant"):
278
+ st.markdown(response)
279
 
280
  if st.button("Reset Chat History"):
281
  #st.session_state['questions'] = []