SVashishta1 commited on
Commit
58eb965
·
1 Parent(s): 1822204

Feature: Add download visualization feature and chat without document upload

Browse files
Files changed (1) hide show
  1. app.py +72 -12
app.py CHANGED
@@ -203,7 +203,7 @@ def process_text_query(query, history):
203
  viz_type = vtype
204
  break
205
 
206
- # Check if we're in CSV context
207
  if current_context["file_type"] == "csv" and current_context["table_name"]:
208
  try:
209
  # Connect to the database
@@ -318,8 +318,8 @@ def process_text_query(query, history):
318
  history[-1][1] = error_msg
319
  return error_msg, history
320
 
321
- else:
322
- # Handle non-CSV queries (document queries)
323
  try:
324
  response = document_assistant.process_query(query)
325
  history[-1][1] = response
@@ -328,6 +328,26 @@ def process_text_query(query, history):
328
  error_msg = f"Error processing query: {str(e)}"
329
  history[-1][1] = error_msg
330
  return error_msg, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
  def process_file_upload(files):
333
  """Process uploaded files and index them"""
@@ -866,8 +886,10 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
866
  with gr.Row():
867
  clear_viz_btn = gr.Button("🗑️ Clear Visualization")
868
  save_viz_btn = gr.Button("💾 Save Visualization")
 
869
 
870
  save_status = gr.Textbox(label="Save Status", visible=False)
 
871
 
872
  # Add information about capabilities
873
  gr.Markdown("""
@@ -886,18 +908,50 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
886
  return "No visualization to save", gr.update(visible=True)
887
 
888
  try:
889
- # Create a unique filename
890
- timestamp = time.strftime("%Y%m%d_%H%M%S")
891
- filename = f"visualization_{timestamp}.html"
892
- filepath = os.path.join(DATA_DIR, filename)
893
 
894
- # Save the visualization
895
- with open(filepath, "w") as f:
896
- f.write(viz_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
 
898
- return f"Visualization saved as {filename}", gr.update(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
899
  except Exception as e:
900
- return f"Error saving visualization: {str(e)}", gr.update(visible=True)
 
901
 
902
  clear_viz_btn.click(
903
  clear_visualization,
@@ -909,6 +963,12 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
909
  inputs=[current_visualization],
910
  outputs=[save_status, save_status]
911
  )
 
 
 
 
 
 
912
 
913
  # Update the process_text_query function to handle visualizations
914
  def process_text_query_with_visualization(query, history, current_viz):
 
203
  viz_type = vtype
204
  break
205
 
206
+ # Check if we're in CSV context or have documents loaded
207
  if current_context["file_type"] == "csv" and current_context["table_name"]:
208
  try:
209
  # Connect to the database
 
318
  history[-1][1] = error_msg
319
  return error_msg, history
320
 
321
+ elif document_assistant.get_all_documents():
322
+ # Handle document queries
323
  try:
324
  response = document_assistant.process_query(query)
325
  history[-1][1] = response
 
328
  error_msg = f"Error processing query: {str(e)}"
329
  history[-1][1] = error_msg
330
  return error_msg, history
331
+
332
+ else:
333
+ # Handle general queries with LLM when no documents are loaded
334
+ try:
335
+ # Create a general knowledge context prompt
336
+ general_prompt = ChatPromptTemplate.from_messages([
337
+ ("system", "You are a helpful assistant that provides clear, informative responses. Use your knowledge to answer the user's question concisely."),
338
+ ("human", "{question}")
339
+ ])
340
+
341
+ # Get response from LLM
342
+ response = llm.invoke(general_prompt.format(question=query)).content
343
+
344
+ # Add the response to history
345
+ history[-1][1] = response
346
+ return response, history
347
+ except Exception as e:
348
+ error_msg = f"Error processing query: {str(e)}"
349
+ history[-1][1] = error_msg
350
+ return error_msg, history
351
 
352
  def process_file_upload(files):
353
  """Process uploaded files and index them"""
 
886
  with gr.Row():
887
  clear_viz_btn = gr.Button("🗑️ Clear Visualization")
888
  save_viz_btn = gr.Button("💾 Save Visualization")
889
+ download_btn = gr.Button("📥 Download Visualization")
890
 
891
  save_status = gr.Textbox(label="Save Status", visible=False)
892
+ download_img = gr.Image(visible=False, type="pil", label="Download Image")
893
 
894
  # Add information about capabilities
895
  gr.Markdown("""
 
908
  return "No visualization to save", gr.update(visible=True)
909
 
910
  try:
911
+ # Extract the base64 image data from the HTML
912
+ img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html)
 
 
913
 
914
+ if img_data_match:
915
+ # Get the base64 data
916
+ img_data = img_data_match.group(1)
917
+
918
+ # Create a downloadable file
919
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
920
+ filename = f"visualization_{timestamp}.png"
921
+
922
+ # Return success message with download link
923
+ return f"Visualization ready for download", gr.update(visible=True)
924
+ else:
925
+ return "Could not extract image data", gr.update(visible=True)
926
+ except Exception as e:
927
+ return f"Error preparing visualization for download: {str(e)}", gr.update(visible=True)
928
+
929
+ def download_visualization(viz_html):
930
+ if not viz_html:
931
+ return None
932
+
933
+ try:
934
+ # Extract the base64 image data from the HTML
935
+ img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html)
936
 
937
+ if img_data_match:
938
+ # Get the base64 data
939
+ img_data = img_data_match.group(1)
940
+
941
+ # Convert base64 to image
942
+ import base64
943
+ from io import BytesIO
944
+ from PIL import Image
945
+
946
+ image_data = base64.b64decode(img_data)
947
+ image = Image.open(BytesIO(image_data))
948
+
949
+ return image, gr.update(visible=True)
950
+ else:
951
+ return None, gr.update(visible=False)
952
  except Exception as e:
953
+ print(f"Error downloading visualization: {str(e)}")
954
+ return None, gr.update(visible=False)
955
 
956
  clear_viz_btn.click(
957
  clear_visualization,
 
963
  inputs=[current_visualization],
964
  outputs=[save_status, save_status]
965
  )
966
+
967
+ download_btn.click(
968
+ download_visualization,
969
+ inputs=[current_visualization],
970
+ outputs=[download_img, download_img]
971
+ )
972
 
973
  # Update the process_text_query function to handle visualizations
974
  def process_text_query_with_visualization(query, history, current_viz):