|
|
import os |
|
|
import sys |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_groq import ChatGroq |
|
|
import plotly.express as px |
|
|
import time |
|
|
import plotly.io as pio |
|
|
import traceback |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import speech_recognition as sr |
|
|
from gtts import gTTS |
|
|
import re |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Make the project root importable so the `backend` package resolves when
# this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# RAG assistant used for non-CSV (document) questions.
document_assistant = DocumentAssistant()

# Shared Groq LLM used for both SQL generation and result interpretation.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,  # deterministic output, important for SQL generation
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite database file that holds uploaded CSV data (always table 'data_tab').
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Tracks which upload (CSV vs document) chat queries should run against.
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}

# NOTE(review): appears unused anywhere in this file — candidate for removal.
current_plot = None
|
|
|
|
|
|
|
|
# Prompt that turns a natural-language question (with the real column list
# injected by process_text_query) into a SQLite-compatible query against
# the fixed 'data_tab' table.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name

Question: {question}

SQL Query:
""")
|
|
|
|
|
|
|
|
# Prompt that turns the executed SQL plus a pandas describe() summary into
# a concise natural-language answer for the chat window.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)
|
|
|
|
|
|
|
|
# Prompt intended to generate SQL tailored to a specific visualization type.
# NOTE(review): not referenced anywhere in this file — confirm whether it is
# used elsewhere before removing.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.

Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
   - Year: strftime('%Y', date_column)
   - Month: strftime('%m', date_column)
   - Day: strftime('%d', date_column)
   - Hour: strftime('%H', date_column)

2. For histograms and binning:
   - Use: CAST((column / bin_size) AS INT) * bin_size
   - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin

3. For percentiles and statistics:
   - SQLite doesn't have built-in percentile functions
   - Use simple aggregations (MIN, MAX, AVG, COUNT) instead

4. For time series:
   - Group by date parts using strftime()
   - Example: strftime('%Y-%m-%d', pickup_datetime) AS day

5. Always use 'data_tab' as the table name

Question: {question}
Visualization type: {viz_type}

SQL Query:
""")
|
|
|
|
|
def _make_figure(query, result_df):
    """Build a Plotly figure for *result_df*, choosing a chart type from
    keywords found in *query*.

    Returns a Plotly Figure, or None when the frame is unsuitable for
    plotting (fewer than two columns, no numeric column, or a single row).
    """
    lowered = query.lower()
    # Keyword heuristics for chart-type selection. 'distribution' and
    # 'correlation' deliberately appear in two lists; the first matching
    # branch below wins.
    is_pie_chart = any(word in lowered for word in ['pie chart', 'pie graph', 'distribution'])
    is_histogram = any(word in lowered for word in ['histogram', 'distribution of', 'frequency'])
    is_heatmap = any(word in lowered for word in ['heatmap', 'heat map', 'correlation'])
    is_scatter = any(word in lowered for word in ['scatter', 'relationship between', 'correlation'])

    if len(result_df.columns) < 2:
        print("Not enough columns for visualization")
        return None

    numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
    if len(numeric_cols) < 1 or len(result_df) <= 1:
        print("Not enough numeric columns or data points for visualization")
        return None

    if is_pie_chart and len(result_df) <= 20:
        category_col = result_df.columns[0]
        value_col = numeric_cols[0] if len(numeric_cols) > 0 else result_df.columns[1]
        fig = px.pie(result_df, names=category_col, values=value_col,
                     title="Distribution Analysis",
                     hole=0.3)
    elif is_histogram and len(numeric_cols) > 0:
        fig = px.histogram(result_df, x=numeric_cols[0],
                           title=f"Distribution of {numeric_cols[0]}",
                           nbins=20)
    elif is_heatmap and len(numeric_cols) >= 2:
        if len(result_df.columns) == len(numeric_cols) and len(numeric_cols) > 2:
            # Every column is numeric: treat the frame itself as the matrix.
            fig = px.imshow(result_df,
                            title="Correlation Heatmap",
                            color_continuous_scale='RdBu_r',
                            aspect="auto")
        else:
            corr_df = result_df[numeric_cols].corr()
            fig = px.imshow(corr_df,
                            title="Correlation Heatmap",
                            color_continuous_scale='RdBu_r',
                            aspect="auto")
    elif is_scatter and len(numeric_cols) >= 2:
        fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1],
                         title=f"Relationship between {numeric_cols[0]} and {numeric_cols[1]}",
                         opacity=0.7)
    elif 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns or any('date' in col.lower() for col in result_df.columns):
        # Looks like a time series: line-plot up to three numeric columns.
        fig = px.line(result_df, x=result_df.columns[0], y=numeric_cols[:3],
                      title="Time Series Analysis",
                      markers=True)
    else:
        fig = px.bar(result_df, x=result_df.columns[0], y=numeric_cols[0],
                     title="Data Visualization")

    fig.update_layout(
        autosize=True,
        width=900,
        height=600,
        margin=dict(l=50, r=50, b=100, t=100, pad=4),
        template="plotly_white",
        font=dict(size=14)
    )
    return fig


def process_text_query(query, history):
    """Answer *query* and append the exchange to the chat *history*.

    When a CSV is loaded (see module-level ``current_context``), the query
    is translated to SQL by the LLM, executed against SQLite, interpreted
    by a second LLM pass, and optionally visualized inline. Otherwise the
    query is routed to the document assistant.

    Returns ("", history) so the bound Gradio textbox is cleared.
    """
    if not query:
        return "", history

    history.append({"role": "user", "content": query})
    start_time = time.time()

    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            conn = sqlite3.connect(DB_PATH)
            try:
                # Give the LLM the real column names so the generated SQL
                # references columns that actually exist.
                cursor = conn.cursor()
                cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
                columns = [info[1] for info in cursor.fetchall()]
                columns_str = ", ".join(columns)
                question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"

                ai_msg = query_prompt | llm
                sql_query = ai_msg.invoke({"question": question_with_context}).content.strip()
                print(f"Generated SQL Query: {sql_query}")

                is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend'])

                try:
                    # NOTE(review): the LLM output is executed verbatim; a
                    # hostile "question" could run arbitrary SQL against the
                    # local analysis database.
                    result_df = pd.read_sql_query(sql_query, conn)

                    if not result_df.empty:
                        data_summary = result_df.describe(include='all').to_string()
                        if len(result_df) <= 10:
                            data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
                        else:
                            data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"
                    else:
                        data_summary = "No relevant data found."

                    # Second LLM pass: turn the raw results into prose.
                    answer_chain = interpret_prompt | llm
                    interpretation = answer_chain.invoke({
                        "question": query,
                        "sql_query": sql_query,
                        "data_summary": data_summary
                    }).content.strip()

                    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
                    if not result_df.empty:
                        if len(result_df) > 10:
                            response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
                        else:
                            response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
                    else:
                        response += "**No results found.**\n\n"
                    response += f"**Analysis:**\n{interpretation}"

                    if is_visualization and not result_df.empty:
                        # Visualization is best-effort: any failure keeps the
                        # textual answer intact.
                        try:
                            print("Visualization requested, attempting to create plot...")
                            fig = _make_figure(query, result_df)
                            if fig is not None:
                                # Embed the chart as an inline base64 PNG so it
                                # renders inside the chat message.
                                img_bytes = fig.to_image(format="png", width=900, height=600, scale=2)
                                encoded = base64.b64encode(img_bytes).decode("ascii")
                                img_src = f"data:image/png;base64,{encoded}"
                                response += f"\n\n<img src='{img_src}' width='100%' />"
                                response += "\n\n**A visualization has been generated and is displayed above.**"
                        except Exception as viz_error:
                            print(f"Visualization error: {str(viz_error)}")
                            traceback.print_exc()
                except Exception as e:
                    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"
            finally:
                # Close the connection even when the LLM call or PRAGMA
                # raises (the original leaked it on those paths).
                conn.close()
        except Exception as e:
            response = f"Error processing query: {str(e)}"
    else:
        # No CSV loaded: fall back to the RAG document assistant.
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"

    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"

    history.append({"role": "assistant", "content": response})
    return "", history
|
|
|
|
|
def process_file_upload(files):
    """Process uploaded files: CSVs are loaded into SQLite, everything
    else is handed to the document assistant for indexing.

    Each CSV replaces the single analysis table 'data_tab', so only the
    last CSV in *files* remains queryable. Returns a human-readable,
    newline-joined status report.
    """
    if not files:
        return "No files uploaded"

    global current_context
    # Reset any previously loaded context before ingesting new files.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                conn = sqlite3.connect(DB_PATH)
                try:
                    # Speed up the bulk insert; durability is not needed for
                    # this throwaway analysis table.
                    conn.execute("PRAGMA synchronous = OFF")
                    conn.execute("PRAGMA journal_mode = MEMORY")

                    df = pd.read_csv(file_path)
                    # Always written to 'data_tab': a later CSV replaces an
                    # earlier one.
                    df.to_sql('data_tab', conn, if_exists='replace', index=False)

                    current_context = {
                        "file_type": "csv",
                        "file_name": file_name,
                        "table_name": "data_tab"
                    }

                    cursor = conn.cursor()
                    cursor.execute("PRAGMA table_info(data_tab);")
                    columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]

                    cursor.execute("SELECT COUNT(*) FROM data_tab;")
                    row_count = cursor.fetchone()[0]
                finally:
                    # Close even when read_csv/to_sql raise (the original
                    # leaked the connection on those paths).
                    conn.close()

                file_info.append("✅ CSV File Successfully Loaded")
                file_info.append("📊 Table Name: data_tab")
                file_info.append(f"📁 Source File: {file_name}")
                file_info.append(f"📈 Total Rows: {row_count:,}")
                file_info.append(f"📋 Columns: {', '.join(columns)}")
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
        else:
            # Any non-CSV file goes through the document pipeline.
            try:
                result = document_assistant.upload_document(file_path)

                # NOTE(review): file_type is recorded as "pdf" for every
                # non-CSV upload (.txt/.docx included) — matches original
                # behavior; confirm whether the distinction matters.
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }

                file_info.append("✅ Document Successfully Processed")
                file_info.append(f"📄 File: {file_name}")
                file_info.append(f"🔢 Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)
|
|
|
|
|
def list_documents():
    """List all CSV tables in SQLite and all indexed documents.

    Returns a newline-joined report, or a placeholder message when
    nothing has been loaded yet.
    """
    info_list = []

    # CSV tables stored in the local SQLite database.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            if tables:
                info_list.append("📊 CSV Data Tables:")
                for table in tables:
                    cursor.execute(f"PRAGMA table_info({table[0]});")
                    columns = [col[1] for col in cursor.fetchall()]

                    cursor.execute(f"SELECT COUNT(*) FROM {table[0]};")
                    row_count = cursor.fetchone()[0]

                    info_list.append(f"- {table[0]} ({row_count:,} rows, {len(columns)} columns)")
        finally:
            # Close even when a per-table query raises (the original left
            # the connection open on that path).
            conn.close()
    except Exception as e:
        info_list.append(f"Error accessing CSV data: {str(e)}")

    # Documents indexed by the document assistant.
    docs = document_assistant.get_all_documents()
    if docs:
        info_list.append("\n📄 Indexed Documents:")
        for doc in docs:
            info_list.append(f"- {doc['filename']} (ID: {doc['id']})")

    if not info_list:
        return "No data or documents loaded yet"

    return "\n".join(info_list)
|
|
|
|
|
def clear_context():
    """Reset the active file context.

    Returns None so a bound Gradio output (the chatbot) is cleared too.
    """
    global current_context
    # Fresh context with every slot unset.
    current_context = dict.fromkeys(("file_type", "file_name", "table_name"))
    return None
|
|
|
|
|
def process_voice_input(audio_path):
    """Transcribe a recorded audio file via Google speech recognition.

    Returns the transcribed text, or a human-readable error message when
    nothing was recorded or recognition fails.
    """
    if audio_path is None:
        return "No audio recorded"

    try:
        recognizer = sr.Recognizer()
        # Load the whole recording, then send it off for recognition.
        with sr.AudioFile(audio_path) as source:
            captured = recognizer.record(source)
        return recognizer.recognize_google(captured)
    except sr.UnknownValueError:
        return "Could not understand audio"
    except sr.RequestError as e:
        return f"Error with speech recognition service: {e}"
    except Exception as e:
        return f"Error processing audio: {str(e)}"
|
|
|
|
|
def text_to_speech_output(text):
    """Speak the assistant's most recent chat message.

    Despite its name, *text* is the chat history: a list of
    ``{"role": ..., "content": ...}`` message dicts (the Gradio chatbot
    value). Returns the path of a generated MP3 file, or None when there
    is no assistant message or speech generation fails.
    """
    if not text:
        return None

    # Find the most recent assistant message.
    last_message = None
    for msg in reversed(text):
        if msg["role"] == "assistant":
            last_message = msg["content"]
            break

    if not last_message:
        return None

    try:
        # Strip markup that would sound wrong when read aloud.
        clean_text = re.sub(r'<.*?>', '', last_message)           # HTML tags (e.g. inline <img>)
        clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', clean_text)  # bold markers
        clean_text = re.sub(r'\n\n', ' ', clean_text)             # paragraph breaks
        clean_text = re.sub(r'```.*?```', 'Code block removed for speech.', clean_text, flags=re.DOTALL)

        # gTTS saves to a path, so reserve a temp file name first.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        temp_file.close()

        tts = gTTS(text=clean_text, lang='en', slow=False)
        tts.save(temp_file.name)

        return temp_file.name
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return None
|
|
|
|
|
def create_test_visualization():
    """Return a small hard-coded bar chart to verify plotting works."""
    sample = pd.DataFrame({
        'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
        'Value': [10, 15, 13, 17, 20, 25],
    })

    figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')
    figure.update_layout(autosize=True, width=800, height=500)
    return figure
|
|
|
|
|
def create_test_html_visualization():
    """Return the test bar chart rendered as an embeddable HTML fragment."""
    sample = pd.DataFrame({
        'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
        'Value': [10, 15, 13, 17, 20, 25],
    })

    figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')

    # Fragment only (not a full page); Plotly JS pulled from the CDN.
    return pio.to_html(figure, full_html=False, include_plotlyjs='cdn')
|
|
|
|
|
def flush_databases():
    """Drop every table in the SQLite database and reset ChromaDB.

    Also clears the in-memory file context. Returns a newline-joined
    status report; each store is flushed best-effort so one failure does
    not block the other.
    """
    result = []

    # Drop all SQLite tables.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            for table in tables:
                cursor.execute(f"DROP TABLE IF EXISTS {table[0]};")

            conn.commit()
        finally:
            # Close even when a DROP raises (the original leaked the
            # connection on that path).
            conn.close()

        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")

    # Reset the document assistant's vector store.
    try:
        document_assistant.reset_database()
        result.append("✅ ChromaDB cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    # Forget whichever file was active.
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    return "\n".join(result)
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: three tabs (Chat, Document Upload, Settings) wired to the
# handlers defined above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# π€ AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        # CSS so inline base64 chart images scale inside chat bubbles.
        gr.HTML("""
        <style>
        .chatbot-container img {
            max-width: 100%;
            height: auto;
            display: block;
            margin: 10px 0;
        }
        </style>
        """)

        # type="messages" means the value is a list of role/content dicts,
        # which is what process_text_query and text_to_speech_output expect.
        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")

        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                # Mic button: reveals the hidden voice-input recorder below.
                voice_btn = gr.Button("π€")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")

        audio_output = gr.Audio(label="Voice Response", type="filepath")

        # Hidden until the mic button is pressed.
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        # Both the button and pressing Enter in the textbox submit the query.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        # "Clear" wipes the chat display; "Clear Context" also resets the
        # module-level current_context.
        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])

        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )

        # Recording finishes -> transcription lands in the message textbox.
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )

        # Reads the latest assistant message aloud via gTTS.
        tts_btn = gr.Button("π Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

    with gr.Tab("Document Upload"):
        with gr.Row():
            file_upload = gr.File(
                label="Upload Documents",
                file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
                file_count="multiple"
            )
            flush_db_btn_doc = gr.Button("ποΈ Flush All Databases", variant="stop")

        upload_button = gr.Button("Process & Index Documents")
        upload_output = gr.Textbox(label="Upload Status")

        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )

        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )

        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")

        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        with gr.Row():
            gr.Markdown("## Database Management")
            flush_db_btn = gr.Button("ποΈ Flush All Databases", variant="stop", scale=1)

        flush_result = gr.Textbox(label="Flush Result")

        flush_db_btn.click(
            flush_databases,
            inputs=[],
            outputs=[flush_result]
        )

        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")

        # Stores the key in this process's environment only; it is not
        # persisted to the .env file.
        def save_settings(key):
            os.environ["GROQ_API_KEY"] = key
            return "Settings saved!"

        # NOTE(review): the Textbox constructed inline in outputs=[...] is
        # never placed in the layout, so the "Settings saved!" confirmation
        # may not be visible — verify and consider a pre-declared component.
        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[gr.Textbox(label="Status")]
        )

        gr.Markdown("## Debugging")
        test_viz_btn = gr.Button("Test Visualization")
        test_viz_output = gr.HTML(label="Test Visualization")

        test_viz_btn.click(
            create_test_html_visualization,
            inputs=[],
            outputs=[test_viz_output]
        )


if __name__ == "__main__":
    demo.launch()