SVashishta1
commited on
Commit
·
58eb965
1
Parent(s):
1822204
Feature: Add download visualization feature and chat without document upload
Browse files
app.py
CHANGED
|
@@ -203,7 +203,7 @@ def process_text_query(query, history):
|
|
| 203 |
viz_type = vtype
|
| 204 |
break
|
| 205 |
|
| 206 |
-
# Check if we're in CSV context
|
| 207 |
if current_context["file_type"] == "csv" and current_context["table_name"]:
|
| 208 |
try:
|
| 209 |
# Connect to the database
|
|
@@ -318,8 +318,8 @@ def process_text_query(query, history):
|
|
| 318 |
history[-1][1] = error_msg
|
| 319 |
return error_msg, history
|
| 320 |
|
| 321 |
-
|
| 322 |
-
# Handle
|
| 323 |
try:
|
| 324 |
response = document_assistant.process_query(query)
|
| 325 |
history[-1][1] = response
|
|
@@ -328,6 +328,26 @@ def process_text_query(query, history):
|
|
| 328 |
error_msg = f"Error processing query: {str(e)}"
|
| 329 |
history[-1][1] = error_msg
|
| 330 |
return error_msg, history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
def process_file_upload(files):
|
| 333 |
"""Process uploaded files and index them"""
|
|
@@ -866,8 +886,10 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
|
|
| 866 |
with gr.Row():
|
| 867 |
clear_viz_btn = gr.Button("🗑️ Clear Visualization")
|
| 868 |
save_viz_btn = gr.Button("💾 Save Visualization")
|
|
|
|
| 869 |
|
| 870 |
save_status = gr.Textbox(label="Save Status", visible=False)
|
|
|
|
| 871 |
|
| 872 |
# Add information about capabilities
|
| 873 |
gr.Markdown("""
|
|
@@ -886,18 +908,50 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
|
|
| 886 |
return "No visualization to save", gr.update(visible=True)
|
| 887 |
|
| 888 |
try:
|
| 889 |
-
#
|
| 890 |
-
|
| 891 |
-
filename = f"visualization_{timestamp}.html"
|
| 892 |
-
filepath = os.path.join(DATA_DIR, filename)
|
| 893 |
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
|
| 898 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
except Exception as e:
|
| 900 |
-
|
|
|
|
| 901 |
|
| 902 |
clear_viz_btn.click(
|
| 903 |
clear_visualization,
|
|
@@ -909,6 +963,12 @@ with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
|
|
| 909 |
inputs=[current_visualization],
|
| 910 |
outputs=[save_status, save_status]
|
| 911 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
|
| 913 |
# Update the process_text_query function to handle visualizations
|
| 914 |
def process_text_query_with_visualization(query, history, current_viz):
|
|
|
|
| 203 |
viz_type = vtype
|
| 204 |
break
|
| 205 |
|
| 206 |
+
# Check if we're in CSV context or have documents loaded
|
| 207 |
if current_context["file_type"] == "csv" and current_context["table_name"]:
|
| 208 |
try:
|
| 209 |
# Connect to the database
|
|
|
|
| 318 |
history[-1][1] = error_msg
|
| 319 |
return error_msg, history
|
| 320 |
|
| 321 |
+
elif document_assistant.get_all_documents():
|
| 322 |
+
# Handle document queries
|
| 323 |
try:
|
| 324 |
response = document_assistant.process_query(query)
|
| 325 |
history[-1][1] = response
|
|
|
|
| 328 |
error_msg = f"Error processing query: {str(e)}"
|
| 329 |
history[-1][1] = error_msg
|
| 330 |
return error_msg, history
|
| 331 |
+
|
| 332 |
+
else:
|
| 333 |
+
# Handle general queries with LLM when no documents are loaded
|
| 334 |
+
try:
|
| 335 |
+
# Create a general knowledge context prompt
|
| 336 |
+
general_prompt = ChatPromptTemplate.from_messages([
|
| 337 |
+
("system", "You are a helpful assistant that provides clear, informative responses. Use your knowledge to answer the user's question concisely."),
|
| 338 |
+
("human", "{question}")
|
| 339 |
+
])
|
| 340 |
+
|
| 341 |
+
# Get response from LLM
|
| 342 |
+
response = llm.invoke(general_prompt.format(question=query)).content
|
| 343 |
+
|
| 344 |
+
# Add the response to history
|
| 345 |
+
history[-1][1] = response
|
| 346 |
+
return response, history
|
| 347 |
+
except Exception as e:
|
| 348 |
+
error_msg = f"Error processing query: {str(e)}"
|
| 349 |
+
history[-1][1] = error_msg
|
| 350 |
+
return error_msg, history
|
| 351 |
|
| 352 |
def process_file_upload(files):
|
| 353 |
"""Process uploaded files and index them"""
|
|
|
|
| 886 |
with gr.Row():
|
| 887 |
clear_viz_btn = gr.Button("🗑️ Clear Visualization")
|
| 888 |
save_viz_btn = gr.Button("💾 Save Visualization")
|
| 889 |
+
download_btn = gr.Button("📥 Download Visualization")
|
| 890 |
|
| 891 |
save_status = gr.Textbox(label="Save Status", visible=False)
|
| 892 |
+
download_img = gr.Image(visible=False, type="pil", label="Download Image")
|
| 893 |
|
| 894 |
# Add information about capabilities
|
| 895 |
gr.Markdown("""
|
|
|
|
| 908 |
return "No visualization to save", gr.update(visible=True)
|
| 909 |
|
| 910 |
try:
|
| 911 |
+
# Extract the base64 image data from the HTML
|
| 912 |
+
img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html)
|
|
|
|
|
|
|
| 913 |
|
| 914 |
+
if img_data_match:
|
| 915 |
+
# Get the base64 data
|
| 916 |
+
img_data = img_data_match.group(1)
|
| 917 |
+
|
| 918 |
+
# Create a downloadable file
|
| 919 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 920 |
+
filename = f"visualization_{timestamp}.png"
|
| 921 |
+
|
| 922 |
+
# Return success message with download link
|
| 923 |
+
return f"Visualization ready for download", gr.update(visible=True)
|
| 924 |
+
else:
|
| 925 |
+
return "Could not extract image data", gr.update(visible=True)
|
| 926 |
+
except Exception as e:
|
| 927 |
+
return f"Error preparing visualization for download: {str(e)}", gr.update(visible=True)
|
| 928 |
+
|
| 929 |
+
def download_visualization(viz_html):
|
| 930 |
+
if not viz_html:
|
| 931 |
+
return None
|
| 932 |
+
|
| 933 |
+
try:
|
| 934 |
+
# Extract the base64 image data from the HTML
|
| 935 |
+
img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html)
|
| 936 |
|
| 937 |
+
if img_data_match:
|
| 938 |
+
# Get the base64 data
|
| 939 |
+
img_data = img_data_match.group(1)
|
| 940 |
+
|
| 941 |
+
# Convert base64 to image
|
| 942 |
+
import base64
|
| 943 |
+
from io import BytesIO
|
| 944 |
+
from PIL import Image
|
| 945 |
+
|
| 946 |
+
image_data = base64.b64decode(img_data)
|
| 947 |
+
image = Image.open(BytesIO(image_data))
|
| 948 |
+
|
| 949 |
+
return image, gr.update(visible=True)
|
| 950 |
+
else:
|
| 951 |
+
return None, gr.update(visible=False)
|
| 952 |
except Exception as e:
|
| 953 |
+
print(f"Error downloading visualization: {str(e)}")
|
| 954 |
+
return None, gr.update(visible=False)
|
| 955 |
|
| 956 |
clear_viz_btn.click(
|
| 957 |
clear_visualization,
|
|
|
|
| 963 |
inputs=[current_visualization],
|
| 964 |
outputs=[save_status, save_status]
|
| 965 |
)
|
| 966 |
+
|
| 967 |
+
download_btn.click(
|
| 968 |
+
download_visualization,
|
| 969 |
+
inputs=[current_visualization],
|
| 970 |
+
outputs=[download_img, download_img]
|
| 971 |
+
)
|
| 972 |
|
| 973 |
# Update the process_text_query function to handle visualizations
|
| 974 |
def process_text_query_with_visualization(query, history, current_viz):
|