import os
import sys
import sqlite3

import gradio as gr
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate

# Load environment variables before anything reads them
load_dotenv()

# Make the project root importable so the `backend` package resolves
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant
from backend.db import SimpleDB
from backend.vector_db import ChromaVectorDB
from backend.query_engine import QueryEngine
from backend.document_parser import SimpleDocumentParser

# Backend components
db = SimpleDB()
vector_db = ChromaVectorDB(os.getenv("CHROMA_DB_PATH", "./data/chroma_db"))
query_engine = QueryEngine()
document_parser = SimpleDocumentParser()
document_assistant = DocumentAssistant()

# SQLite database that holds uploaded CSV data
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Prompt for translating a natural-language question into a SQLite query
query_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an SQL expert. Generate an appropriate SQL query using SQLite syntax for the question provided. The query should be executable and return exactly what was asked for.

For questions about maximum/highest values, use MAX().
For minimum/lowest values, use MIN().
For averages, use AVG().
For counts, use COUNT().
For sums, use SUM().

For visualization queries:
1. For trends over time:
   - Group by appropriate time unit (day, month, year)
   - Include relevant aggregations (AVG, COUNT, SUM)
2. For distributions:
   - Group by the value being distributed
   - Include COUNT or frequency
3. For comparisons:
   - Include multiple measures
   - Order appropriately

Examples:
1. Question: "Plot tip amount trends by month"
SQL: SELECT strftime('%Y-%m', pickup_datetime) as month, AVG(tip_amount) as avg_tip, COUNT(*) as count FROM data_tab GROUP BY month ORDER BY month;

2. Question: "Show distribution of fare amounts"
SQL: SELECT fare_amount, COUNT(*) as frequency FROM data_tab GROUP BY fare_amount ORDER BY fare_amount;

3. Question: "What is the highest tip_amount in the dataset?"
SQL: SELECT MAX(tip_amount) as highest_tip FROM data_tab;

Generate only the SQL query, nothing else. Make sure to use the correct table name from the context provided."""),
    ("human", "{question}")
])

# Prompt for interpreting query results back into a natural-language answer
interpret_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an experienced data analyst. Examine the following data and provide a clear analysis. Base your analysis solely on the provided data."),
    ("human", "Question: {question}\n\nSQL Query: {sql_query}\n\nData:\n{data}"),
])
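

# A minimal sketch of how the prompts above can be chained with a chat model.
# Everything here is illustrative: it assumes the langchain-groq package and a
# GROQ_API_KEY in the environment, and the model name is an arbitrary
# placeholder, not something this project pins.
def generate_sql_sketch(question: str) -> str:
    """Illustrative only: produce SQL text for a natural-language question."""
    from langchain_groq import ChatGroq  # assumed optional dependency

    llm = ChatGroq(model="llama-3.1-8b-instant")  # placeholder model name
    # LCEL: pipe the prompt into the model, then read the message content
    return (query_prompt | llm).invoke({"question": question}).content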


# Tracks the most recently uploaded file so queries are routed correctly
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}


def process_text_query(query, history):
    """Process a text query and update the chat history."""
    if not query:
        return "", history

    # Heuristic: treat the question as a visualization request if it mentions
    # plotting or trends
    is_plot_query = any(word in query.lower() for word in [
        'plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'trends'
    ])

    try:
        if current_context["file_type"] == "csv":
            conn = sqlite3.connect(DB_PATH)
            try:
                if is_plot_query:
                    try:
                        # These queries assume taxi-style columns
                        # (pickup_datetime, tip_amount) in the uploaded CSV
                        if 'trend' in query.lower():
                            sql_query = f"""
                                SELECT strftime('%Y-%m', pickup_datetime) as month,
                                       AVG(tip_amount) as avg_tip,
                                       COUNT(*) as count,
                                       SUM(tip_amount) as total_tip
                                FROM {current_context['table_name']}
                                GROUP BY month
                                ORDER BY month;
                            """
                        else:
                            sql_query = f"""
                                SELECT tip_amount, COUNT(*) as frequency
                                FROM {current_context['table_name']}
                                GROUP BY tip_amount
                                ORDER BY tip_amount;
                            """

                        result_df = pd.read_sql_query(sql_query, conn)

                        if 'trend' in query.lower():
                            fig = px.line(result_df, x='month', y=['avg_tip', 'total_tip'],
                                          title='Tip Trends Over Time')
                        else:
                            fig = px.bar(result_df, x='tip_amount', y='frequency',
                                         title='Distribution of Tip Amounts')

                        plot_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
                        response = f"**Analysis:**\n\nHere's the visualization of the data:\n\n<div>{plot_html}</div>"
                    except Exception as e:
                        response = f"Error creating visualization: {str(e)}"
                else:
                    # TODO: route non-visualization questions through
                    # query_prompt / interpret_prompt
                    response = "Text-to-SQL answering for CSV data is not implemented yet."
            finally:
                conn.close()
        elif current_context["file_type"] == "pdf":
            response = document_assistant.process_query(query)
        else:
            response = "Please upload a file first."
    except Exception as e:
        response = f"Error processing query: {str(e)}"

    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": response})

    return "", history


def process_file_upload(files):
    """Process uploaded files and index them."""
    if not files:
        return "No files uploaded"

    global current_context

    # Reset the context; it is repopulated from the last file processed
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                # Derive a SQLite-friendly table name from the file name
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()

                conn = sqlite3.connect(DB_PATH)
                try:
                    load_csv_to_sqlite(file_path, conn, table_name)

                    current_context = {
                        "file_type": "csv",
                        "file_name": file_name,
                        "table_name": table_name
                    }

                    # Collect schema information for the status report
                    cursor = conn.cursor()
                    cursor.execute(f"PRAGMA table_info({table_name});")
                    columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]

                    cursor.execute(f"SELECT COUNT(*) FROM {table_name};")
                    row_count = cursor.fetchone()[0]

                    cursor.execute(f"SELECT * FROM {table_name} LIMIT 5;")
                    sample_rows = cursor.fetchall()
                finally:
                    conn.close()

                file_info.append("✅ CSV File Successfully Loaded")
                file_info.append(f"📊 Table Name: {table_name}")
                file_info.append(f"📈 Total Rows: {row_count:,}")
                file_info.append("\n📋 Columns:")
                for col in columns:
                    file_info.append(f"  • {col}")

                if sample_rows:
                    file_info.append("\n📝 Sample Data (first 5 rows):")
                    sample_df = pd.DataFrame(sample_rows, columns=[col.split(' ')[0] for col in columns])
                    file_info.append(f"```\n{sample_df.to_string()}\n```")

            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")

        else:
            # Non-CSV files go through the document pipeline
            try:
                result = document_assistant.upload_document(file_path)

                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }

                file_info.append("✅ Document Successfully Processed")
                file_info.append(f"📄 File: {file_name}")
                file_info.append(f"📑 Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)


def process_voice_input(audio_path):
    """Process voice input and return transcribed text."""
    if audio_path is None:
        return "No audio recorded"

    # No speech-to-text backend is wired up in this build
    return "Voice transcription is not available"
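

# A minimal sketch of a transcription backend for process_voice_input.
# It assumes the openai-whisper package is installed; that is not a dependency
# of this project, so the import stays inside the function.
def transcribe_sketch(audio_path: str) -> str:
    """Illustrative only: transcribe an audio file with a local Whisper model."""
    import whisper  # assumed optional dependency

    model = whisper.load_model("base")  # small general-purpose model
    return model.transcribe(audio_path)["text"]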


def text_to_speech_output(history):
    """Speak the most recent assistant message from the chat history."""
    if not history:
        return None

    # Find the most recent assistant message
    last_message = None
    for msg in reversed(history):
        if msg["role"] == "assistant":
            last_message = msg["content"]
            break

    if not last_message:
        return None

    # No text-to-speech backend is wired up in this build
    return None
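

# A minimal sketch of a TTS backend for text_to_speech_output. It assumes the
# gTTS package (not a dependency of this project); the returned path matches
# what the gr.Audio output component expects.
def synthesize_speech_sketch(message: str) -> str:
    """Illustrative only: write `message` to a temporary MP3 and return its path."""
    import tempfile
    from gtts import gTTS  # assumed optional dependency

    tmp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
    tmp.close()  # close the handle so gTTS can write to the path
    gTTS(message).save(tmp.name)
    return tmp.name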


def load_csv_to_sqlite(file_path, conn, table_name):
    """Load CSV data into the SQLite database in chunks."""
    chunksize = 1000
    for i, chunk in enumerate(pd.read_csv(file_path, chunksize=chunksize)):
        # Best-effort conversion of date/time-looking columns
        for col in chunk.columns:
            if 'date' in col.lower() or 'time' in col.lower():
                try:
                    chunk[col] = pd.to_datetime(chunk[col], errors='coerce')
                except Exception:
                    pass

        # Replace the table on the first chunk, append on the rest
        if_exists = 'replace' if i == 0 else 'append'
        chunk.to_sql(table_name, conn, if_exists=if_exists, index=False)
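

# Illustrative helper showing how the loader is typically driven; the CSV path
# and table name here are placeholders, not files shipped with the project.
def load_csv_example(csv_path: str = "trips.csv") -> None:
    """Illustrative only: load one CSV into the app database by hand."""
    conn = sqlite3.connect(DB_PATH)
    try:
        load_csv_to_sqlite(csv_path, conn, "trips")
    finally:
        conn.close()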


def list_documents():
    """List all loaded CSV tables and indexed documents."""
    info_list = []

    # CSV tables stored in SQLite
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        if tables:
            info_list.append("📊 CSV Data Tables:")
            for table in tables:
                cursor.execute(f"PRAGMA table_info({table[0]});")
                columns = [col[1] for col in cursor.fetchall()]

                cursor.execute(f"SELECT COUNT(*) FROM {table[0]};")
                row_count = cursor.fetchone()[0]

                # Show a few distinct values for well-known categorical columns
                sample_info = []
                for col in ['vendor_id', 'rate_code', 'payment_type']:
                    if col in columns:
                        cursor.execute(f"SELECT DISTINCT {col} FROM {table[0]} LIMIT 5;")
                        unique_vals = [str(row[0]) for row in cursor.fetchall()]
                        if unique_vals:
                            sample_info.append(f"{col}: {', '.join(unique_vals)}")

                info_list.append(f"\n🔹 Table: {table[0]}")
                info_list.append(f"  - Rows: {row_count:,}")
                info_list.append(f"  - Columns: {len(columns)}")
                if sample_info:
                    info_list.append("  - Sample values:")
                    for info in sample_info:
                        info_list.append(f"    • {info}")

        conn.close()
    except Exception as e:
        info_list.append(f"Error accessing CSV data: {str(e)}")

    # Documents indexed by the document assistant
    docs = document_assistant.get_all_documents()
    if docs:
        info_list.append("\n📄 Indexed Documents:")
        for doc in docs:
            info_list.append(f"- {doc['filename']} (ID: {doc['id']})")

    if not info_list:
        return "No data or documents loaded yet"

    return "\n".join(info_list)


def clear_context():
    """Clear the current file context; returning None also clears the chatbot."""
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return None


# Gradio UI
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(height=400, type="messages")

        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("🎤")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")

        audio_output = gr.Audio(label="Voice Response", type="filepath")

        # Hidden until the microphone button is pressed
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        clear_btn.click(lambda: None, None, chatbot, queue=False)

        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )

        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )

        tts_btn = gr.Button("🔊 Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

        clear_context_btn.click(
            clear_context,
            inputs=[],
            outputs=[chatbot]
        )

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        upload_button = gr.Button("Process & Index Documents")
        upload_output = gr.Textbox(label="Upload Status")

        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )

        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")

        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")
        settings_status = gr.Textbox(label="Status")

        def save_settings(key):
            # Only affects the current process; the key is not persisted
            os.environ["GROQ_API_KEY"] = key
            return "Settings saved!"

        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[settings_status]
        )


if __name__ == "__main__":
    demo.launch()