|
|
import os |
|
|
import sys |
|
|
import gradio as gr |
|
|
from dotenv import load_dotenv |
|
|
import tempfile |
|
|
import pandas as pd |
|
|
import sqlite3 |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
from backend.main import DocumentAssistant |
|
|
from backend.db import SimpleDB |
|
|
from backend.vector_db import ChromaVectorDB |
|
|
from backend.query_engine import QueryEngine |
|
|
from backend.document_parser import SimpleDocumentParser |
|
|
|
|
|
|
|
|
# Load environment variables from .env BEFORE constructing any component:
# ChromaVectorDB below reads CHROMA_DB_PATH via os.getenv, and the Settings
# tab reads GROQ_API_KEY. (Previously load_dotenv() ran after construction,
# so values from .env were silently ignored.)
load_dotenv()

# Core backend components shared by all Gradio callbacks below.
db = SimpleDB()
vector_db = ChromaVectorDB(os.getenv("CHROMA_DB_PATH", "./data/chroma_db"))
query_engine = QueryEngine()
document_parser = SimpleDocumentParser()
document_assistant = DocumentAssistant()

# SQLite database that receives uploaded CSV data (separate from the vector
# store). Ensure its parent directory exists before any connection attempt.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
|
|
|
|
|
|
|
# Few-shot chat prompt instructing the LLM to emit SQLite-compatible SQL for
# a natural-language question ({question} is the only template variable).
# NOTE(review): query_prompt is not referenced anywhere in this file's visible
# code — confirm it is used elsewhere before removing.
# NOTE(review): the FULL OUTER JOIN example below uses RIGHT JOIN, which
# SQLite before 3.39 does not support — consider rewriting that example with
# two LEFT JOINs (swapped table order) if older SQLite must be targeted.
query_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", """
You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
Follow SQLite-specific conventions, as shown in the examples below:

Example 1:
Question: "What is the average fare for trips over 10 miles?"
SQL Query: SELECT AVG(fare_amount) FROM taxi_data WHERE trip_distance > 10;

Example 2:
Question: "How many trips were taken in each month?"
SQL Query: SELECT strftime('%m', pickup_datetime) AS month, COUNT(*) AS trip_count FROM taxi_data GROUP BY month;

Example 3:
Question: "What is the total fare amount for each driver (medallion) per day?"
SQL Query: SELECT DATE(pickup_datetime) AS date, medallion, SUM(fare_amount) AS total_fare FROM taxi_data GROUP BY date, medallion;

SQLite-Specific Conventions:

1. Date and Time Extraction:
- Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
- Example: `SELECT strftime('%Y', pickup_datetime) FROM taxi_data;`

2. String Length:
- Instead of `CHAR_LENGTH(column)`, use `LENGTH(column)`.
- Example: `SELECT LENGTH(passenger_name) FROM taxi_data;`

3. Regular Expressions:
- SQLite does not support `REGEXP`. Use `LIKE` for simple patterns or avoid regular expressions.
- Example: `SELECT * FROM taxi_data WHERE passenger_name LIKE 'A%';`

4. Window Functions:
- For row numbering, use `ROW_NUMBER()` if supported, or simulate with joins.
- Example: `SELECT id, ROW_NUMBER() OVER (ORDER BY pickup_datetime) AS row_num FROM taxi_data;`

5. Data Type Casting:
- Use `CAST(column AS TYPE)`, but note that SQLite supports limited types.
- Example: `SELECT CAST(fare_amount AS INTEGER) FROM taxi_data;`

6. Full Outer Join Workaround:
- SQLite doesn't support `FULL OUTER JOIN`. Combine `LEFT JOIN` and `UNION` for a similar effect.
- Example:
```
SELECT a.*, b.*
FROM table_a a
LEFT JOIN table_b b ON a.id = b.id
UNION
SELECT a.*, b.*
FROM table_a a
RIGHT JOIN table_b b ON a.id = b.id;
```

Use these examples and guidelines to generate an SQL query compatible with SQLite syntax for the question provided.
"""),
        ("human", "{question}"),
    ]
)
|
|
|
|
|
def process_text_query(query, history):
    """Process a text query and append the exchange to the chat history.

    Args:
        query: The user's question. Empty/None queries are ignored.
        history: Chat history as a list of ``{"role", "content"}`` dicts
            (Gradio "messages" format). May be ``None`` right after the
            Clear button, which resets the chatbot value to ``None``.

    Returns:
        Tuple ``("", updated_history)`` — the empty string clears the
        input textbox.
    """
    # Guard against None history (post-clear) so .append below is safe.
    history = history or []
    if not query:
        return "", history

    response = None
    # Heuristic routing: SQL-ish keywords send the query through the CSV
    # database so the assistant can see the actual table schemas.
    sql_keywords = ('sql', 'query', 'table', 'select', 'from', 'where', 'group by')
    if any(keyword in query.lower() for keyword in sql_keywords):
        try:
            conn = sqlite3.connect(DB_PATH)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                tables = [row[0] for row in cursor.fetchall()]

                if tables:
                    # Describe every table's columns so the assistant can
                    # generate SQL against the real schema.
                    table_info = []
                    for table in tables:
                        cursor.execute(f"PRAGMA table_info({table});")
                        columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                        table_info.append(f"Table '{table}' has columns: {', '.join(columns)}")

                    context = "The database contains the following tables:\n" + "\n".join(table_info)
                    response = document_assistant.process_query(f"{context}\n\nUser query: {query}")
                else:
                    response = "No CSV data has been uploaded yet. Please upload a CSV file first."
            finally:
                # Always release the connection, even if a query above fails
                # (previously the connection leaked on any exception).
                conn.close()
        except Exception:
            # Best-effort: if schema inspection fails for any reason, fall
            # back to answering without database context.
            response = document_assistant.process_query(query)
    else:
        response = document_assistant.process_query(query)

    # Single append site (previously duplicated in three branches).
    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": response})
    return "", history
|
|
|
|
|
def process_file_upload(files):
    """Process uploaded files: load CSVs into SQLite and index all files.

    Args:
        files: List of Gradio file objects (each exposes a ``.name`` temp
            path), or ``None``/empty when nothing was uploaded.

    Returns:
        Newline-joined status string, one or two entries per file.
    """
    if not files:
        return "No files uploaded"

    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                # Derive a SQL-friendly table name from the file name.
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()

                conn = sqlite3.connect(DB_PATH)
                try:
                    load_csv_to_sqlite(file_path, conn, table_name)
                finally:
                    # Close even if the CSV load raises (previously leaked).
                    conn.close()

                file_info.append(f"CSV data loaded into table: {table_name}")

                # CSVs are also indexed in the vector store for text search.
                result = document_assistant.upload_document(file_path)
                file_info.append(f"Also indexed for text search: {result['message']}")
            except Exception as e:
                file_info.append(f"Error loading CSV {file_name}: {str(e)}")
        else:
            # Non-CSV documents go straight to text indexing.
            result = document_assistant.upload_document(file_path)
            file_info.append(f"{result['message']} ({result['chunks']} chunks)")

    return "\n".join(file_info)
|
|
|
|
|
def process_voice_input(audio_path):
    """Return transcribed text for a recorded audio clip.

    Transcription is not wired up yet: whenever audio is present, a fixed
    placeholder message is returned instead of a transcript.
    """
    if audio_path is None:
        return "No audio recorded"

    # No speech-to-text backend configured; report that to the user.
    return "Voice transcription is not available"
|
|
|
|
|
def text_to_speech_output(text):
    """Convert the latest assistant reply in the chat history to speech.

    Args:
        text: Chat history as a list of ``{"role", "content"}`` dicts.

    Returns:
        Always ``None`` for now — audio synthesis is not implemented; the
        function only locates the most recent assistant message.
    """
    if not text:
        return None

    # Scan backwards for the newest assistant message.
    last_message = next(
        (msg["content"] for msg in reversed(text) if msg["role"] == "assistant"),
        None,
    )
    if not last_message:
        return None

    # No TTS backend available; nothing to synthesize yet.
    return None
|
|
|
|
|
def load_csv_to_sqlite(file_path, conn, table_name):
    """Load a CSV file into an SQLite table in bounded-memory chunks.

    The first chunk replaces any existing table of the same name; later
    chunks append, so arbitrarily large files stream through.

    Args:
        file_path: Path to the CSV file.
        conn: Open sqlite3 connection (any connection ``to_sql`` accepts).
        table_name: Destination table name.
    """
    chunksize = 1000
    for i, chunk in enumerate(pd.read_csv(file_path, chunksize=chunksize)):
        # Best-effort datetime parsing for columns that look like dates, so
        # SQLite date functions (strftime, DATE) work on the stored values.
        for col in chunk.columns:
            if 'date' in col.lower() or 'time' in col.lower():
                try:
                    chunk[col] = pd.to_datetime(chunk[col], errors='coerce')
                except (ValueError, TypeError):
                    # Leave the column untouched if it cannot be parsed.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    pass

        if_exists = 'replace' if i == 0 else 'append'
        chunk.to_sql(table_name, conn, if_exists=if_exists, index=False)
|
|
|
|
|
def list_documents():
    """List all indexed documents plus any CSV data tables.

    Returns:
        Human-readable newline-joined listing, or a message when nothing
        has been indexed yet.
    """
    docs = document_assistant.get_all_documents()
    if not docs:
        return "No documents indexed yet"

    doc_list = [f"{doc['filename']} (ID: {doc['id']})" for doc in docs]

    # Also surface any CSV tables loaded into the SQLite database; this is
    # best-effort — a database error must not break the document listing.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
        finally:
            # Close even if the query fails (previously leaked on error).
            conn.close()

        if tables:
            doc_list.append("\nCSV data tables:")
            for table in tables:
                doc_list.append(f"- {table[0]}")
    except sqlite3.Error:
        # Was a bare `except:`; narrowed to database errors only.
        pass

    return "\n".join(doc_list)
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Three tabs: Chat (text/voice Q&A), Document Upload (indexing), and
# Settings (API key). All callbacks are the module-level functions above.
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# π€ AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        # Chat history in Gradio "messages" format (list of role/content dicts).
        chatbot = gr.Chatbot(height=400, type="messages")

        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("π€")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")

        audio_output = gr.Audio(label="Voice Response", type="filepath")

        # Hidden recorder, revealed by the mic button below.
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        # Submit via button or Enter key; both clear the textbox and append
        # the exchange to the chat history.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        # NOTE(review): clearing sets the chatbot value to None (not []), so
        # downstream callbacks receive history=None after a clear — confirm
        # process_text_query tolerates that.
        clear_btn.click(lambda: None, None, chatbot, queue=False)

        # Reveal the hidden audio recorder.
        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )

        # Push the (placeholder) transcription into the message textbox.
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )

        # Speak the latest assistant reply (TTS currently returns None).
        tts_btn = gr.Button("π Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        upload_button = gr.Button("Process & Index Documents")
        upload_output = gr.Textbox(label="Upload Status")

        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )

        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")

        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")

        def save_settings(key):
            # Only affects the current process environment; the key is NOT
            # persisted to .env and is lost on restart.
            os.environ["GROQ_API_KEY"] = key
            return "Settings saved!"

        # NOTE(review): the status Textbox is created inline in `outputs`
        # rather than placed in the layout above — confirm it renders where
        # intended in this Gradio version.
        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[gr.Textbox(label="Status")]
        )
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app with default host/port when run as a script.
    demo.launch()