SVashishta1
Error Fix
fbbf665
raw
history blame
13.4 kB
import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
#test
# Add parent directory to path to import backend modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant
from backend.db import SimpleDB
from backend.vector_db import ChromaVectorDB
from backend.query_engine import QueryEngine
from backend.document_parser import SimpleDocumentParser
# Load environment variables from a .env file *before* constructing any
# component: ChromaVectorDB below reads CHROMA_DB_PATH from the environment,
# and the backend classes may read further settings when instantiated.
load_dotenv()
# Initialize components
db = SimpleDB()
vector_db = ChromaVectorDB(os.getenv("CHROMA_DB_PATH", "./data/chroma_db"))
query_engine = QueryEngine()
# Initialize the document parser
document_parser = SimpleDocumentParser()
# Initialize DocumentAssistant
document_assistant = DocumentAssistant()
# SQLite database that holds one table per uploaded CSV file, stored next to
# this script under ./data/; the directory is created up front so that
# sqlite3.connect() can create the file on first use.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
# Few-shot prompt asking the LLM to translate a natural-language question
# (supplied through the "{question}" human message) into a single SQLite
# query. The examples are written against a `taxi_data` table.
# NOTE: ChatPromptTemplate treats curly braces as template variables, so the
# system text below must not contain literal braces.
query_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", """
You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
Follow SQLite-specific conventions, as shown in the examples below:
Example 1:
Question: "What is the average fare for trips over 10 miles?"
SQL Query: SELECT AVG(fare_amount) FROM taxi_data WHERE trip_distance > 10;
Example 2:
Question: "How many trips were taken in each month?"
SQL Query: SELECT strftime('%m', pickup_datetime) AS month, COUNT(*) AS trip_count FROM taxi_data GROUP BY month;
Example 3:
Question: "What is the total fare amount for each driver (medallion) per day?"
SQL Query: SELECT DATE(pickup_datetime) AS date, medallion, SUM(fare_amount) AS total_fare FROM taxi_data GROUP BY date, medallion;
SQLite-Specific Conventions:
1. Date and Time Extraction:
- Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
- Example: `SELECT strftime('%Y', pickup_datetime) FROM taxi_data;`
2. String Length:
- Instead of `CHAR_LENGTH(column)`, use `LENGTH(column)`.
- Example: `SELECT LENGTH(passenger_name) FROM taxi_data;`
3. Regular Expressions:
- SQLite has no built-in `REGEXP` implementation (the operator only works if a `regexp()` function has been registered). Use `LIKE` for simple patterns or avoid regular expressions.
- Example: `SELECT * FROM taxi_data WHERE passenger_name LIKE 'A%';`
4. Window Functions:
- For row numbering, use `ROW_NUMBER()` if supported, or simulate with joins.
- Example: `SELECT id, ROW_NUMBER() OVER (ORDER BY pickup_datetime) AS row_num FROM taxi_data;`
5. Data Type Casting:
- Use `CAST(column AS TYPE)`, but note that SQLite supports limited types.
- Example: `SELECT CAST(fare_amount AS INTEGER) FROM taxi_data;`
6. Full Outer Join Workaround:
- SQLite versions before 3.39.0 support neither `RIGHT JOIN` nor `FULL OUTER JOIN`. Emulate a full outer join by combining two `LEFT JOIN`s with `UNION`, swapping the table order in the second query.
- Example:
```
SELECT a.*, b.*
FROM table_a a
LEFT JOIN table_b b ON a.id = b.id
UNION
SELECT a.*, b.*
FROM table_b b
LEFT JOIN table_a a ON a.id = b.id;
```
Use these examples and guidelines to generate an SQL query compatible with SQLite syntax for the question provided.
"""),
        ("human", "{question}"),
    ]
)
# Substrings that route a chat message to the CSV/SQLite-aware path.
# NOTE(review): this is a plain substring test, so everyday words such as
# "from" or "where" also trigger the CSV path; kept as-is for compatibility.
_SQL_HINT_KEYWORDS = ('sql', 'query', 'table', 'select', 'from', 'where', 'group by')


def _describe_csv_tables():
    """Describe every table in the CSV SQLite database at ``DB_PATH``.

    Returns:
        A list with one ``"Table '<name>' has columns: col (TYPE), ..."`` line
        per table. The list is empty when no table exists yet.

    Raises:
        sqlite3.Error: If the database cannot be opened or inspected.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]
        descriptions = []
        for table in tables:
            # Quote the identifier: table names are derived from uploaded file
            # names and may contain characters (e.g. '-') invalid when bare.
            quoted = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted});")
            columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
            descriptions.append(f"Table '{table}' has columns: {', '.join(columns)}")
        return descriptions
    finally:
        # Always release the connection, including on errors.
        conn.close()


def process_text_query(query, history):
    """Answer a chat message and append the exchange to the chat history.

    Messages containing any of ``_SQL_HINT_KEYWORDS`` are treated as questions
    about uploaded CSV data. The schema of every table in ``DB_PATH`` is
    prepended as context before asking ``document_assistant``. If no table
    exists, a fixed "upload a CSV first" reply is used instead. Any failure on
    that path falls back to a plain document query.

    Args:
        query: The user's message. An empty message leaves history unchanged.
        history: Chat history in Gradio "messages" format (a list of
            ``{"role": ..., "content": ...}`` dicts), or ``None`` for an empty
            chat (e.g. after the chatbot has been cleared).

    Returns:
        ``("", history)``: an empty string that clears the input textbox, and
        the updated history list.
    """
    # A cleared/empty Chatbot can arrive as None; start a fresh history.
    if history is None:
        history = []
    if not query:
        return "", history

    lowered = query.lower()
    if any(keyword in lowered for keyword in _SQL_HINT_KEYWORDS):
        try:
            table_info = _describe_csv_tables()
            if table_info:
                context = "The database contains the following tables:\n" + "\n".join(table_info)
                response = document_assistant.process_query(f"{context}\n\nUser query: {query}")
            else:
                response = "No CSV data has been uploaded yet. Please upload a CSV file first."
        except Exception:
            # Best-effort: any failure on the CSV-aware path (DB inspection or
            # the context-enriched query) falls back to a plain document query.
            response = document_assistant.process_query(query)
    else:
        response = document_assistant.process_query(query)

    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": response})
    return "", history
def process_file_upload(files):
    """Index uploaded files; CSV files are additionally loaded into SQLite.

    Args:
        files: Value of the multi-file ``gr.File`` component. This is a list
            of uploads, each either a path string or an object exposing the
            path as ``.name``. It is ``None``/empty when nothing was uploaded.

    Returns:
        A newline-separated status report with one or more lines per file.
        A failure on one file is reported in the text and does not stop the
        remaining files from being processed.
    """
    if not files:
        return "No files uploaded"
    file_info = []
    for file in files:
        # Uploads may be plain path strings or objects carrying the path in `.name`.
        file_path = getattr(file, "name", file)
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()
        if file_ext == '.csv':
            # Special handling for CSV files - load into SQLite
            try:
                # Table name: file name without extension, spaces -> underscores, lowercased.
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
                conn = sqlite3.connect(DB_PATH)
                try:
                    load_csv_to_sqlite(file_path, conn, table_name)
                finally:
                    # Close the connection even if loading the CSV fails.
                    conn.close()
                file_info.append(f"CSV data loaded into table: {table_name}")
                # Also index with document assistant for text search
                result = document_assistant.upload_document(file_path)
                file_info.append(f"Also indexed for text search: {result['message']}")
            except Exception as e:
                file_info.append(f"Error loading CSV {file_name}: {str(e)}")
        else:
            # Process and index the document; report errors per file, like the CSV branch.
            try:
                result = document_assistant.upload_document(file_path)
                file_info.append(f"{result['message']} ({result['chunks']} chunks)")
            except Exception as e:
                file_info.append(f"Error processing {file_name}: {e}")
    return "\n".join(file_info)
def process_voice_input(audio_path):
    """Placeholder speech-to-text handler for the voice input component.

    Args:
        audio_path: Path of the recorded audio file, or ``None`` if nothing
            was recorded.

    Returns:
        A status string rather than a transcript. No voice backend is wired
        in, so any recording yields the "not available" notice.
    """
    nothing_recorded = audio_path is None
    return "No audio recorded" if nothing_recorded else "Voice transcription is not available"
def text_to_speech_output(text):
    """Placeholder text-to-speech handler for the "Speak Response" button.

    Args:
        text: The chat history in messages format (a list of dicts with
            ``"role"`` and ``"content"`` keys), or an empty value.

    Returns:
        Always ``None`` (no audio). The latest assistant reply is located,
        but no speech backend is wired in to synthesize it.
    """
    if not text or len(text) < 1:
        return None
    # Newest-first scan for the most recent assistant reply.
    latest_reply = next(
        (entry["content"] for entry in reversed(text) if entry["role"] == "assistant"),
        None,
    )
    # `latest_reply` is where a TTS call would go; nothing to synthesize with yet.
    return None
def load_csv_to_sqlite(file_path, conn, table_name, chunksize=1000):
    """Stream a CSV file into a SQLite table, replacing any existing table.

    The CSV is read ``chunksize`` rows at a time to bound memory use. The
    first chunk replaces ``table_name`` and later chunks are appended.

    Columns whose name contains "date" or "time" (case-insensitive) are
    converted to datetimes. The conversion is applied only when the column is
    not numeric and every non-null value parses; otherwise the column is kept
    unchanged. This prevents wiping out columns such as "candidate" or
    "update_count" whose names merely contain those substrings.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        conn: Open ``sqlite3`` connection; it is not closed here.
        table_name: Destination table name.
        chunksize: Rows per chunk; adjust to memory constraints.
    """
    for i, chunk in enumerate(pd.read_csv(file_path, chunksize=chunksize)):
        for col in chunk.columns:
            lowered = str(col).lower()
            if 'date' not in lowered and 'time' not in lowered:
                continue
            original = chunk[col]
            # Numbers would be interpreted as nanosecond epochs, not dates.
            if pd.api.types.is_numeric_dtype(original):
                continue
            try:
                converted = pd.to_datetime(original, errors='coerce')
            except Exception:
                # Best-effort cleaning: keep the column as-is if parsing blows up.
                continue
            # Adopt the conversion only if no existing value would become NaT.
            if (converted.isna() & original.notna()).any():
                continue
            chunk[col] = converted
        # Load the chunk into the SQLite database
        if_exists = 'replace' if i == 0 else 'append'
        chunk.to_sql(table_name, conn, if_exists=if_exists, index=False)
def list_documents():
    """Return a text listing of indexed documents and CSV-backed SQLite tables.

    Documents come from ``document_assistant.get_all_documents()``; each entry
    is read via its ``'filename'`` and ``'id'`` keys. When none are indexed,
    the listing starts with "No documents indexed yet". CSV tables are read
    from the SQLite database at ``DB_PATH`` on a best-effort basis: if that
    database cannot be read, the table section is omitted.
    """
    docs = document_assistant.get_all_documents()
    if docs:
        doc_list = [f"{doc['filename']} (ID: {doc['id']})" for doc in docs]
    else:
        doc_list = ["No documents indexed yet"]

    # CSV tables are listed even when no regular documents are indexed.
    tables = []
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
    except sqlite3.Error:
        # Best-effort: the CSV database is optional; skip the table section.
        tables = []

    if tables:
        doc_list.append("\nCSV data tables:")
        doc_list.extend(f"- {name}" for name in tables)
    return "\n".join(doc_list)
# Create Gradio interface
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(height=400, type="messages")
        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("🎤")
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        audio_output = gr.Audio(label="Voice Response", type="filepath")
        # Voice input, hidden until the microphone button is pressed
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )
        # Event handlers: both the button and Enter in the textbox submit a query.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        # Reset to an empty message list (not None) so the next
        # process_text_query call receives a list it can append to.
        clear_btn.click(lambda: [], None, chatbot, queue=False)
        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )
        # Add TTS functionality
        tts_btn = gr.Button("🔊 Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        upload_button = gr.Button("Process & Index Documents")
        upload_output = gr.Textbox(label="Upload Status")
        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )
        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")
        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")
        settings_status = gr.Textbox(label="Status")

        def save_settings(key):
            """Store the Groq API key in this process's environment."""
            # NOTE(review): components created at startup (e.g. DocumentAssistant)
            # may already have read GROQ_API_KEY; confirm they re-read it,
            # otherwise this takes effect only for clients created later.
            os.environ["GROQ_API_KEY"] = key
            return "Settings saved!"

        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[settings_status]
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()