SVashishta1
committed on
Commit
·
5facdeb
1
Parent(s):
92d1d2a
Error Fix
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ import tempfile
|
|
| 6 |
import pandas as pd
|
| 7 |
import sqlite3
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Load environment variables
|
| 11 |
load_dotenv()
|
|
@@ -36,59 +38,36 @@ os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
|
| 36 |
|
| 37 |
# Define the prompt with examples
|
| 38 |
query_prompt = ChatPromptTemplate.from_messages([
|
| 39 |
-
("system", """
|
| 40 |
-
You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
|
| 41 |
-
Follow SQLite-specific conventions, as shown in the examples below:
|
| 42 |
-
|
| 43 |
-
Example 1:
|
| 44 |
-
Question: "What is the average fare for trips over 10 miles?"
|
| 45 |
-
SQL Query: SELECT AVG(fare_amount) FROM taxi_data WHERE trip_distance > 10;
|
| 46 |
-
|
| 47 |
-
Example 2:
|
| 48 |
-
Question: "How many trips were taken in each month?"
|
| 49 |
-
SQL Query: SELECT strftime('%m', pickup_datetime) AS month, COUNT(*) AS trip_count FROM taxi_data GROUP BY month;
|
| 50 |
-
|
| 51 |
-
Example 3:
|
| 52 |
-
Question: "What is the total fare amount for each driver (medallion) per day?"
|
| 53 |
-
SQL Query: SELECT DATE(pickup_datetime) AS date, medallion, SUM(fare_amount) AS total_fare FROM taxi_data GROUP BY date, medallion;
|
| 54 |
-
|
| 55 |
-
SQLite-Specific Conventions:
|
| 56 |
-
|
| 57 |
-
1. Date and Time Extraction:
|
| 58 |
-
- Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
|
| 59 |
-
- Example: `SELECT strftime('%Y', pickup_datetime) FROM taxi_data;`
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
- Example: `SELECT CAST(fare_amount AS INTEGER) FROM taxi_data;`
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
- Example:
|
| 80 |
-
```
|
| 81 |
-
SELECT a.*, b.*
|
| 82 |
-
FROM table_a a
|
| 83 |
-
LEFT JOIN table_b b ON a.id = b.id
|
| 84 |
-
UNION
|
| 85 |
-
SELECT a.*, b.*
|
| 86 |
-
FROM table_a a
|
| 87 |
-
RIGHT JOIN table_b b ON a.id = b.id;
|
| 88 |
-
```
|
| 89 |
|
| 90 |
-
|
| 91 |
-
"""),
|
| 92 |
("human", "{question}")
|
| 93 |
])
|
| 94 |
|
|
@@ -100,88 +79,88 @@ interpret_prompt = ChatPromptTemplate.from_messages(
|
|
| 100 |
]
|
| 101 |
)
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
def process_text_query(query, history):
|
| 104 |
"""Process a text query and update chat history"""
|
| 105 |
if not query:
|
| 106 |
return "", history
|
| 107 |
|
| 108 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
try:
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
tables = [row[0] for row in cursor.fetchall()]
|
| 114 |
-
|
| 115 |
-
if tables:
|
| 116 |
-
# Get table schema information
|
| 117 |
-
table_info = []
|
| 118 |
-
for table in tables:
|
| 119 |
-
cursor.execute(f"PRAGMA table_info({table});")
|
| 120 |
-
columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
|
| 121 |
-
table_info.append(f"Table '{table}' has columns: {', '.join(columns)}")
|
| 122 |
|
| 123 |
-
|
| 124 |
-
if any(word in query.lower() for word in [
|
| 125 |
-
'what is', 'how many', 'highest', 'lowest', 'maximum', 'minimum',
|
| 126 |
-
'average', 'mean', 'sum', 'total', 'count', 'tip', 'fare', 'amount'
|
| 127 |
-
]):
|
| 128 |
try:
|
| 129 |
-
#
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
# Execute query
|
| 134 |
result_df = pd.read_sql_query(sql_query, conn)
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
else:
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
-
#
|
| 143 |
-
|
| 144 |
-
if not result_df.empty:
|
| 145 |
-
response += f"**Results:**\n```\n{data_str}\n```\n\n"
|
| 146 |
-
|
| 147 |
-
# Add interpretation
|
| 148 |
-
interpret_prompt = f"""
|
| 149 |
-
Question: {query}
|
| 150 |
-
SQL Query: {sql_query}
|
| 151 |
-
Results: {data_str}
|
| 152 |
-
|
| 153 |
-
Please provide a clear, concise answer to the question based on these results.
|
| 154 |
-
"""
|
| 155 |
-
interpretation = query_engine.generate_response(interpret_prompt)
|
| 156 |
-
response += f"**Answer:**\n{interpretation}"
|
| 157 |
-
else:
|
| 158 |
-
response += "No results found."
|
| 159 |
|
| 160 |
-
|
| 161 |
-
history.append({"role": "assistant", "content": response})
|
| 162 |
-
return "", history
|
| 163 |
|
| 164 |
except Exception as e:
|
| 165 |
-
|
| 166 |
-
# Fall back to document query if SQL fails
|
| 167 |
-
response = document_assistant.process_query(query)
|
| 168 |
else:
|
| 169 |
-
#
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
# No tables found, use document query
|
| 173 |
-
response = document_assistant.process_query(query)
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
except Exception as e:
|
| 178 |
-
|
| 179 |
-
# Fall back to document query if database access fails
|
| 180 |
-
response = document_assistant.process_query(query)
|
| 181 |
|
| 182 |
-
# Update history
|
| 183 |
history.append({"role": "user", "content": query})
|
| 184 |
history.append({"role": "assistant", "content": response})
|
|
|
|
| 185 |
return "", history
|
| 186 |
|
| 187 |
def process_file_upload(files):
|
|
@@ -189,6 +168,15 @@ def process_file_upload(files):
|
|
| 189 |
if not files:
|
| 190 |
return "No files uploaded"
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
file_info = []
|
| 193 |
for file in files:
|
| 194 |
file_path = file.name
|
|
@@ -196,16 +184,22 @@ def process_file_upload(files):
|
|
| 196 |
file_ext = os.path.splitext(file_name)[1].lower()
|
| 197 |
|
| 198 |
if file_ext == '.csv':
|
| 199 |
-
# Special handling for CSV files - load into SQLite
|
| 200 |
try:
|
| 201 |
-
# Create table name from filename
|
| 202 |
table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
|
| 203 |
|
| 204 |
# Load CSV into SQLite
|
| 205 |
conn = sqlite3.connect(DB_PATH)
|
| 206 |
load_csv_to_sqlite(file_path, conn, table_name)
|
| 207 |
|
| 208 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
cursor = conn.cursor()
|
| 210 |
cursor.execute(f"PRAGMA table_info({table_name});")
|
| 211 |
columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
|
|
@@ -220,15 +214,24 @@ def process_file_upload(files):
|
|
| 220 |
file_info.append(f"Columns: {', '.join(columns)}")
|
| 221 |
file_info.append(f"Rows: {row_count}")
|
| 222 |
|
| 223 |
-
# Also index with document assistant for text search
|
| 224 |
-
result = document_assistant.upload_document(file_path)
|
| 225 |
-
file_info.append(f"Also indexed for text search: {result['message']}")
|
| 226 |
except Exception as e:
|
| 227 |
file_info.append(f"Error loading CSV {file_name}: {str(e)}")
|
|
|
|
| 228 |
else:
|
| 229 |
-
# Process
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
return "\n".join(file_info)
|
| 234 |
|
|
@@ -311,6 +314,16 @@ def list_documents():
|
|
| 311 |
|
| 312 |
return "\n".join(doc_list)
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
# Create Gradio interface
|
| 315 |
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
| 316 |
gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
|
|
@@ -331,6 +344,7 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
|
| 331 |
with gr.Row():
|
| 332 |
submit_btn = gr.Button("Submit")
|
| 333 |
clear_btn = gr.Button("Clear")
|
|
|
|
| 334 |
|
| 335 |
audio_output = gr.Audio(label="Voice Response", type="filepath")
|
| 336 |
|
|
@@ -375,6 +389,13 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
|
| 375 |
inputs=[chatbot],
|
| 376 |
outputs=[audio_output]
|
| 377 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
with gr.Tab("Document Upload"):
|
| 380 |
file_upload = gr.File(
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
import sqlite3
|
| 8 |
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
+
import plotly.express as px
|
| 10 |
+
import plotly.io as pio
|
| 11 |
|
| 12 |
# Load environment variables
|
| 13 |
load_dotenv()
|
|
|
|
| 38 |
|
| 39 |
# Define the prompt with examples
|
| 40 |
query_prompt = ChatPromptTemplate.from_messages([
|
| 41 |
+
("system", """You are an SQL expert. Generate an appropriate SQL query using SQLite syntax for the question provided. The query should be executable and return exactly what was asked for.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
For questions about maximum/highest values, use MAX().
|
| 44 |
+
For minimum/lowest values, use MIN().
|
| 45 |
+
For averages, use AVG().
|
| 46 |
+
For counts, use COUNT().
|
| 47 |
+
For sums, use SUM().
|
| 48 |
|
| 49 |
+
For visualization queries:
|
| 50 |
+
1. For trends over time:
|
| 51 |
+
- Group by appropriate time unit (day, month, year)
|
| 52 |
+
- Include relevant aggregations (AVG, COUNT, SUM)
|
| 53 |
+
2. For distributions:
|
| 54 |
+
- Group by the value being distributed
|
| 55 |
+
- Include COUNT or frequency
|
| 56 |
+
3. For comparisons:
|
| 57 |
+
- Include multiple measures
|
| 58 |
+
- Order appropriately
|
| 59 |
|
| 60 |
+
Examples:
|
| 61 |
+
1. Question: "Plot tip amount trends by month"
|
| 62 |
+
SQL: SELECT strftime('%Y-%m', pickup_datetime) as month, AVG(tip_amount) as avg_tip, COUNT(*) as count FROM data_tab GROUP BY month ORDER BY month;
|
| 63 |
|
| 64 |
+
2. Question: "Show distribution of fare amounts"
|
| 65 |
+
SQL: SELECT fare_amount, COUNT(*) as frequency FROM data_tab GROUP BY fare_amount ORDER BY fare_amount;
|
|
|
|
| 66 |
|
| 67 |
+
3. Question: "What is the highest tip_amount in the dataset?"
|
| 68 |
+
SQL: SELECT MAX(tip_amount) as highest_tip FROM data_tab;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
Generate only the SQL query, nothing else. Make sure to use the correct table name from the context provided."""),
|
|
|
|
| 71 |
("human", "{question}")
|
| 72 |
])
|
| 73 |
|
|
|
|
| 79 |
]
|
| 80 |
)
|
| 81 |
|
| 82 |
+
# Module-level state describing the most recently uploaded file; process_text_query reads it to decide how to route a query
|
| 83 |
+
current_context = {
|
| 84 |
+
"file_type": None, # 'csv' or 'pdf' or None
|
| 85 |
+
"file_name": None,
|
| 86 |
+
"table_name": None
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
def process_text_query(query, history):
|
| 90 |
"""Process a text query and update chat history"""
|
| 91 |
if not query:
|
| 92 |
return "", history
|
| 93 |
|
| 94 |
+
# Check if query is about visualization
|
| 95 |
+
is_plot_query = any(word in query.lower() for word in [
|
| 96 |
+
'plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'trends'
|
| 97 |
+
])
|
| 98 |
+
|
| 99 |
try:
|
| 100 |
+
if current_context["file_type"] == "csv":
|
| 101 |
+
conn = sqlite3.connect(DB_PATH)
|
| 102 |
+
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
if is_plot_query:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
try:
|
| 106 |
+
# For visualization queries, we need to get appropriate data
|
| 107 |
+
if 'trend' in query.lower():
|
| 108 |
+
# Example: For trend analysis, group by appropriate time unit
|
| 109 |
+
sql_query = f"""
|
| 110 |
+
SELECT strftime('%Y-%m', pickup_datetime) as month,
|
| 111 |
+
AVG(tip_amount) as avg_tip,
|
| 112 |
+
COUNT(*) as count,
|
| 113 |
+
SUM(tip_amount) as total_tip
|
| 114 |
+
FROM {current_context['table_name']}
|
| 115 |
+
GROUP BY month
|
| 116 |
+
ORDER BY month;
|
| 117 |
+
"""
|
| 118 |
+
else:
|
| 119 |
+
# Default to a general aggregation
|
| 120 |
+
sql_query = f"""
|
| 121 |
+
SELECT tip_amount, COUNT(*) as frequency
|
| 122 |
+
FROM {current_context['table_name']}
|
| 123 |
+
GROUP BY tip_amount
|
| 124 |
+
ORDER BY tip_amount;
|
| 125 |
+
"""
|
| 126 |
|
| 127 |
+
# Execute query and create visualization
|
| 128 |
result_df = pd.read_sql_query(sql_query, conn)
|
| 129 |
|
| 130 |
+
if 'trend' in query.lower():
|
| 131 |
+
fig = px.line(result_df, x='month', y=['avg_tip', 'total_tip'],
|
| 132 |
+
title='Tip Trends Over Time')
|
| 133 |
else:
|
| 134 |
+
fig = px.bar(result_df, x='tip_amount', y='frequency',
|
| 135 |
+
title='Distribution of Tip Amounts')
|
| 136 |
|
| 137 |
+
# Convert plot to HTML
|
| 138 |
+
plot_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
response = f"**Analysis:**\n\nHere's the visualization of the data:\n\n<div>{plot_html}</div>"
|
|
|
|
|
|
|
| 141 |
|
| 142 |
except Exception as e:
|
| 143 |
+
response = f"Error creating visualization: {str(e)}"
|
|
|
|
|
|
|
| 144 |
else:
|
| 145 |
+
# Handle regular SQL queries as before
|
| 146 |
+
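# TODO(review): the regular SQL query handling was not carried over into this branch, so `response` appears to stay unassigned on this path before it is appended to history below — restore the previous handling here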
|
| 147 |
+
pass
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
conn.close()
|
| 150 |
+
|
| 151 |
+
elif current_context["file_type"] == "pdf":
|
| 152 |
+
# Process PDF queries using document_assistant
|
| 153 |
+
response = document_assistant.process_query(query)
|
| 154 |
+
else:
|
| 155 |
+
response = "Please upload a file first."
|
| 156 |
+
|
| 157 |
except Exception as e:
|
| 158 |
+
response = f"Error processing query: {str(e)}"
|
|
|
|
|
|
|
| 159 |
|
| 160 |
+
# Update history with message format
|
| 161 |
history.append({"role": "user", "content": query})
|
| 162 |
history.append({"role": "assistant", "content": response})
|
| 163 |
+
|
| 164 |
return "", history
|
| 165 |
|
| 166 |
def process_file_upload(files):
|
|
|
|
| 168 |
if not files:
|
| 169 |
return "No files uploaded"
|
| 170 |
|
| 171 |
+
global current_context
|
| 172 |
+
|
| 173 |
+
# Clear existing context
|
| 174 |
+
current_context = {
|
| 175 |
+
"file_type": None,
|
| 176 |
+
"file_name": None,
|
| 177 |
+
"table_name": None
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
file_info = []
|
| 181 |
for file in files:
|
| 182 |
file_path = file.name
|
|
|
|
| 184 |
file_ext = os.path.splitext(file_name)[1].lower()
|
| 185 |
|
| 186 |
if file_ext == '.csv':
|
|
|
|
| 187 |
try:
|
| 188 |
+
# Create table name from filename
|
| 189 |
table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
|
| 190 |
|
| 191 |
# Load CSV into SQLite
|
| 192 |
conn = sqlite3.connect(DB_PATH)
|
| 193 |
load_csv_to_sqlite(file_path, conn, table_name)
|
| 194 |
|
| 195 |
+
# Update current context
|
| 196 |
+
current_context = {
|
| 197 |
+
"file_type": "csv",
|
| 198 |
+
"file_name": file_name,
|
| 199 |
+
"table_name": table_name
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
# Get column info
|
| 203 |
cursor = conn.cursor()
|
| 204 |
cursor.execute(f"PRAGMA table_info({table_name});")
|
| 205 |
columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
|
|
|
|
| 214 |
file_info.append(f"Columns: {', '.join(columns)}")
|
| 215 |
file_info.append(f"Rows: {row_count}")
|
| 216 |
|
|
|
|
|
|
|
|
|
|
| 217 |
except Exception as e:
|
| 218 |
file_info.append(f"Error loading CSV {file_name}: {str(e)}")
|
| 219 |
+
|
| 220 |
else:
|
| 221 |
+
# Process PDF or other document types
|
| 222 |
+
try:
|
| 223 |
+
result = document_assistant.upload_document(file_path)
|
| 224 |
+
|
| 225 |
+
# Update current context
|
| 226 |
+
current_context = {
|
| 227 |
+
"file_type": "pdf",
|
| 228 |
+
"file_name": file_name,
|
| 229 |
+
"table_name": None
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
file_info.append(f"{result['message']} ({result['chunks']} chunks)")
|
| 233 |
+
except Exception as e:
|
| 234 |
+
file_info.append(f"Error processing document {file_name}: {str(e)}")
|
| 235 |
|
| 236 |
return "\n".join(file_info)
|
| 237 |
|
|
|
|
| 314 |
|
| 315 |
return "\n".join(doc_list)
|
| 316 |
|
| 317 |
+
def clear_context():
|
| 318 |
+
"""Clear the current context and chat history"""
|
| 319 |
+
global current_context
|
| 320 |
+
current_context = {
|
| 321 |
+
"file_type": None,
|
| 322 |
+
"file_name": None,
|
| 323 |
+
"table_name": None
|
| 324 |
+
}
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
# Create Gradio interface
|
| 328 |
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
| 329 |
gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
|
|
|
|
| 344 |
with gr.Row():
|
| 345 |
submit_btn = gr.Button("Submit")
|
| 346 |
clear_btn = gr.Button("Clear")
|
| 347 |
+
clear_context_btn = gr.Button("Clear Context")
|
| 348 |
|
| 349 |
audio_output = gr.Audio(label="Voice Response", type="filepath")
|
| 350 |
|
|
|
|
| 389 |
inputs=[chatbot],
|
| 390 |
outputs=[audio_output]
|
| 391 |
)
|
| 392 |
+
|
| 393 |
+
# Wire the "Clear Context" button to clear_context; its return value is sent to the chatbot component
|
| 394 |
+
clear_context_btn.click(
|
| 395 |
+
clear_context,
|
| 396 |
+
inputs=[],
|
| 397 |
+
outputs=[chatbot]
|
| 398 |
+
)
|
| 399 |
|
| 400 |
with gr.Tab("Document Upload"):
|
| 401 |
file_upload = gr.File(
|