SVashishta1 commited on
Commit
5facdeb
·
1 Parent(s): 92d1d2a
Files changed (1) hide show
  1. app.py +137 -116
app.py CHANGED
@@ -6,6 +6,8 @@ import tempfile
6
  import pandas as pd
7
  import sqlite3
8
  from langchain_core.prompts import ChatPromptTemplate
 
 
9
 
10
  # Load environment variables
11
  load_dotenv()
@@ -36,59 +38,36 @@ os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
36
 
37
  # Define the prompt with examples
38
  query_prompt = ChatPromptTemplate.from_messages([
39
- ("system", """
40
- You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
41
- Follow SQLite-specific conventions, as shown in the examples below:
42
-
43
- Example 1:
44
- Question: "What is the average fare for trips over 10 miles?"
45
- SQL Query: SELECT AVG(fare_amount) FROM taxi_data WHERE trip_distance > 10;
46
-
47
- Example 2:
48
- Question: "How many trips were taken in each month?"
49
- SQL Query: SELECT strftime('%m', pickup_datetime) AS month, COUNT(*) AS trip_count FROM taxi_data GROUP BY month;
50
-
51
- Example 3:
52
- Question: "What is the total fare amount for each driver (medallion) per day?"
53
- SQL Query: SELECT DATE(pickup_datetime) AS date, medallion, SUM(fare_amount) AS total_fare FROM taxi_data GROUP BY date, medallion;
54
-
55
- SQLite-Specific Conventions:
56
-
57
- 1. Date and Time Extraction:
58
- - Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
59
- - Example: `SELECT strftime('%Y', pickup_datetime) FROM taxi_data;`
60
 
61
- 2. String Length:
62
- - Instead of `CHAR_LENGTH(column)`, use `LENGTH(column)`.
63
- - Example: `SELECT LENGTH(passenger_name) FROM taxi_data;`
 
 
64
 
65
- 3. Regular Expressions:
66
- - SQLite does not support `REGEXP`. Use `LIKE` for simple patterns or avoid regular expressions.
67
- - Example: `SELECT * FROM taxi_data WHERE passenger_name LIKE 'A%';`
 
 
 
 
 
 
 
68
 
69
- 4. Window Functions:
70
- - For row numbering, use `ROW_NUMBER()` if supported, or simulate with joins.
71
- - Example: `SELECT id, ROW_NUMBER() OVER (ORDER BY pickup_datetime) AS row_num FROM taxi_data;`
72
 
73
- 5. Data Type Casting:
74
- - Use `CAST(column AS TYPE)`, but note that SQLite supports limited types.
75
- - Example: `SELECT CAST(fare_amount AS INTEGER) FROM taxi_data;`
76
 
77
- 6. Full Outer Join Workaround:
78
- - SQLite doesn’t support `FULL OUTER JOIN`. Combine `LEFT JOIN` and `UNION` for a similar effect.
79
- - Example:
80
- ```
81
- SELECT a.*, b.*
82
- FROM table_a a
83
- LEFT JOIN table_b b ON a.id = b.id
84
- UNION
85
- SELECT a.*, b.*
86
- FROM table_a a
87
- RIGHT JOIN table_b b ON a.id = b.id;
88
- ```
89
 
90
- Use these examples and guidelines to generate an SQL query compatible with SQLite syntax for the question provided.
91
- """),
92
  ("human", "{question}")
93
  ])
94
 
@@ -100,88 +79,88 @@ interpret_prompt = ChatPromptTemplate.from_messages(
100
  ]
101
  )
102
 
 
 
 
 
 
 
 
103
  def process_text_query(query, history):
104
  """Process a text query and update chat history"""
105
  if not query:
106
  return "", history
107
 
108
- # First, check if we have any CSV data loaded
 
 
 
 
109
  try:
110
- conn = sqlite3.connect(DB_PATH)
111
- cursor = conn.cursor()
112
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
113
- tables = [row[0] for row in cursor.fetchall()]
114
-
115
- if tables:
116
- # Get table schema information
117
- table_info = []
118
- for table in tables:
119
- cursor.execute(f"PRAGMA table_info({table});")
120
- columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
121
- table_info.append(f"Table '{table}' has columns: {', '.join(columns)}")
122
 
123
- # For questions about specific values, aggregations, or data analysis
124
- if any(word in query.lower() for word in [
125
- 'what is', 'how many', 'highest', 'lowest', 'maximum', 'minimum',
126
- 'average', 'mean', 'sum', 'total', 'count', 'tip', 'fare', 'amount'
127
- ]):
128
  try:
129
- # Generate SQL query
130
- context = f"The database contains the following tables:\n{chr(10).join(table_info)}\n\nQuestion: {query}"
131
- sql_query = query_engine.generate_response(query_prompt.format(question=context))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # Execute query
134
  result_df = pd.read_sql_query(sql_query, conn)
135
 
136
- # Format results
137
- if len(result_df) > 10:
138
- data_str = f"{result_df.head(10).to_string()}\n... (showing 10 of {len(result_df)} rows)"
139
  else:
140
- data_str = result_df.to_string()
 
141
 
142
- # Generate response
143
- response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
144
- if not result_df.empty:
145
- response += f"**Results:**\n```\n{data_str}\n```\n\n"
146
-
147
- # Add interpretation
148
- interpret_prompt = f"""
149
- Question: {query}
150
- SQL Query: {sql_query}
151
- Results: {data_str}
152
-
153
- Please provide a clear, concise answer to the question based on these results.
154
- """
155
- interpretation = query_engine.generate_response(interpret_prompt)
156
- response += f"**Answer:**\n{interpretation}"
157
- else:
158
- response += "No results found."
159
 
160
- history.append({"role": "user", "content": query})
161
- history.append({"role": "assistant", "content": response})
162
- return "", history
163
 
164
  except Exception as e:
165
- print(f"SQL Error: {str(e)}")
166
- # Fall back to document query if SQL fails
167
- response = document_assistant.process_query(query)
168
  else:
169
- # For non-data analysis questions, use document query
170
- response = document_assistant.process_query(query)
171
- else:
172
- # No tables found, use document query
173
- response = document_assistant.process_query(query)
174
 
175
- conn.close()
176
-
 
 
 
 
 
 
177
  except Exception as e:
178
- print(f"Database Error: {str(e)}")
179
- # Fall back to document query if database access fails
180
- response = document_assistant.process_query(query)
181
 
182
- # Update history
183
  history.append({"role": "user", "content": query})
184
  history.append({"role": "assistant", "content": response})
 
185
  return "", history
186
 
187
  def process_file_upload(files):
@@ -189,6 +168,15 @@ def process_file_upload(files):
189
  if not files:
190
  return "No files uploaded"
191
 
 
 
 
 
 
 
 
 
 
192
  file_info = []
193
  for file in files:
194
  file_path = file.name
@@ -196,16 +184,22 @@ def process_file_upload(files):
196
  file_ext = os.path.splitext(file_name)[1].lower()
197
 
198
  if file_ext == '.csv':
199
- # Special handling for CSV files - load into SQLite
200
  try:
201
- # Create table name from filename (remove extension, replace spaces with underscores)
202
  table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
203
 
204
  # Load CSV into SQLite
205
  conn = sqlite3.connect(DB_PATH)
206
  load_csv_to_sqlite(file_path, conn, table_name)
207
 
208
- # Get column info for the table
 
 
 
 
 
 
 
209
  cursor = conn.cursor()
210
  cursor.execute(f"PRAGMA table_info({table_name});")
211
  columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
@@ -220,15 +214,24 @@ def process_file_upload(files):
220
  file_info.append(f"Columns: {', '.join(columns)}")
221
  file_info.append(f"Rows: {row_count}")
222
 
223
- # Also index with document assistant for text search
224
- result = document_assistant.upload_document(file_path)
225
- file_info.append(f"Also indexed for text search: {result['message']}")
226
  except Exception as e:
227
  file_info.append(f"Error loading CSV {file_name}: {str(e)}")
 
228
  else:
229
- # Process and index the document
230
- result = document_assistant.upload_document(file_path)
231
- file_info.append(f"{result['message']} ({result['chunks']} chunks)")
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  return "\n".join(file_info)
234
 
@@ -311,6 +314,16 @@ def list_documents():
311
 
312
  return "\n".join(doc_list)
313
 
 
 
 
 
 
 
 
 
 
 
314
  # Create Gradio interface
315
  with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
316
  gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
@@ -331,6 +344,7 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
331
  with gr.Row():
332
  submit_btn = gr.Button("Submit")
333
  clear_btn = gr.Button("Clear")
 
334
 
335
  audio_output = gr.Audio(label="Voice Response", type="filepath")
336
 
@@ -375,6 +389,13 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
375
  inputs=[chatbot],
376
  outputs=[audio_output]
377
  )
 
 
 
 
 
 
 
378
 
379
  with gr.Tab("Document Upload"):
380
  file_upload = gr.File(
 
6
  import pandas as pd
7
  import sqlite3
8
  from langchain_core.prompts import ChatPromptTemplate
9
+ import plotly.express as px
10
+ import plotly.io as pio
11
 
12
  # Load environment variables
13
  load_dotenv()
 
38
 
39
  # Define the prompt with examples
40
  query_prompt = ChatPromptTemplate.from_messages([
41
+ ("system", """You are an SQL expert. Generate an appropriate SQL query using SQLite syntax for the question provided. The query should be executable and return exactly what was asked for.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ For questions about maximum/highest values, use MAX().
44
+ For minimum/lowest values, use MIN().
45
+ For averages, use AVG().
46
+ For counts, use COUNT().
47
+ For sums, use SUM().
48
 
49
+ For visualization queries:
50
+ 1. For trends over time:
51
+ - Group by appropriate time unit (day, month, year)
52
+ - Include relevant aggregations (AVG, COUNT, SUM)
53
+ 2. For distributions:
54
+ - Group by the value being distributed
55
+ - Include COUNT or frequency
56
+ 3. For comparisons:
57
+ - Include multiple measures
58
+ - Order appropriately
59
 
60
+ Examples:
61
+ 1. Question: "Plot tip amount trends by month"
62
+ SQL: SELECT strftime('%Y-%m', pickup_datetime) as month, AVG(tip_amount) as avg_tip, COUNT(*) as count FROM data_tab GROUP BY month ORDER BY month;
63
 
64
+ 2. Question: "Show distribution of fare amounts"
65
+ SQL: SELECT fare_amount, COUNT(*) as frequency FROM data_tab GROUP BY fare_amount ORDER BY fare_amount;
 
66
 
67
+ 3. Question: "What is the highest tip_amount in the dataset?"
68
+ SQL: SELECT MAX(tip_amount) as highest_tip FROM data_tab;
 
 
 
 
 
 
 
 
 
 
69
 
70
+ Generate only the SQL query, nothing else. Make sure to use the correct table name from the context provided."""),
 
71
  ("human", "{question}")
72
  ])
73
 
 
79
  ]
80
  )
81
 
82
+ # Add this as a global variable to track current context
83
+ current_context = {
84
+ "file_type": None, # 'csv' or 'pdf' or None
85
+ "file_name": None,
86
+ "table_name": None
87
+ }
88
+
89
  def process_text_query(query, history):
90
  """Process a text query and update chat history"""
91
  if not query:
92
  return "", history
93
 
94
+ # Check if query is about visualization
95
+ is_plot_query = any(word in query.lower() for word in [
96
+ 'plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'trends'
97
+ ])
98
+
99
  try:
100
+ if current_context["file_type"] == "csv":
101
+ conn = sqlite3.connect(DB_PATH)
102
+ cursor = conn.cursor()
 
 
 
 
 
 
 
 
 
103
 
104
+ if is_plot_query:
 
 
 
 
105
  try:
106
+ # For visualization queries, we need to get appropriate data
107
+ if 'trend' in query.lower():
108
+ # Example: For trend analysis, group by appropriate time unit
109
+ sql_query = f"""
110
+ SELECT strftime('%Y-%m', pickup_datetime) as month,
111
+ AVG(tip_amount) as avg_tip,
112
+ COUNT(*) as count,
113
+ SUM(tip_amount) as total_tip
114
+ FROM {current_context['table_name']}
115
+ GROUP BY month
116
+ ORDER BY month;
117
+ """
118
+ else:
119
+ # Default to a general aggregation
120
+ sql_query = f"""
121
+ SELECT tip_amount, COUNT(*) as frequency
122
+ FROM {current_context['table_name']}
123
+ GROUP BY tip_amount
124
+ ORDER BY tip_amount;
125
+ """
126
 
127
+ # Execute query and create visualization
128
  result_df = pd.read_sql_query(sql_query, conn)
129
 
130
+ if 'trend' in query.lower():
131
+ fig = px.line(result_df, x='month', y=['avg_tip', 'total_tip'],
132
+ title='Tip Trends Over Time')
133
  else:
134
+ fig = px.bar(result_df, x='tip_amount', y='frequency',
135
+ title='Distribution of Tip Amounts')
136
 
137
+ # Convert plot to HTML
138
+ plot_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ response = f"**Analysis:**\n\nHere's the visualization of the data:\n\n<div>{plot_html}</div>"
 
 
141
 
142
  except Exception as e:
143
+ response = f"Error creating visualization: {str(e)}"
 
 
144
  else:
145
+ # Handle regular SQL queries as before
146
+ # ... (keep your existing SQL query handling code here)
147
+ pass
 
 
148
 
149
+ conn.close()
150
+
151
+ elif current_context["file_type"] == "pdf":
152
+ # Process PDF queries using document_assistant
153
+ response = document_assistant.process_query(query)
154
+ else:
155
+ response = "Please upload a file first."
156
+
157
  except Exception as e:
158
+ response = f"Error processing query: {str(e)}"
 
 
159
 
160
+ # Update history with message format
161
  history.append({"role": "user", "content": query})
162
  history.append({"role": "assistant", "content": response})
163
+
164
  return "", history
165
 
166
  def process_file_upload(files):
 
168
  if not files:
169
  return "No files uploaded"
170
 
171
+ global current_context
172
+
173
+ # Clear existing context
174
+ current_context = {
175
+ "file_type": None,
176
+ "file_name": None,
177
+ "table_name": None
178
+ }
179
+
180
  file_info = []
181
  for file in files:
182
  file_path = file.name
 
184
  file_ext = os.path.splitext(file_name)[1].lower()
185
 
186
  if file_ext == '.csv':
 
187
  try:
188
+ # Create table name from filename
189
  table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
190
 
191
  # Load CSV into SQLite
192
  conn = sqlite3.connect(DB_PATH)
193
  load_csv_to_sqlite(file_path, conn, table_name)
194
 
195
+ # Update current context
196
+ current_context = {
197
+ "file_type": "csv",
198
+ "file_name": file_name,
199
+ "table_name": table_name
200
+ }
201
+
202
+ # Get column info
203
  cursor = conn.cursor()
204
  cursor.execute(f"PRAGMA table_info({table_name});")
205
  columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
 
214
  file_info.append(f"Columns: {', '.join(columns)}")
215
  file_info.append(f"Rows: {row_count}")
216
 
 
 
 
217
  except Exception as e:
218
  file_info.append(f"Error loading CSV {file_name}: {str(e)}")
219
+
220
  else:
221
+ # Process PDF or other document types
222
+ try:
223
+ result = document_assistant.upload_document(file_path)
224
+
225
+ # Update current context
226
+ current_context = {
227
+ "file_type": "pdf",
228
+ "file_name": file_name,
229
+ "table_name": None
230
+ }
231
+
232
+ file_info.append(f"{result['message']} ({result['chunks']} chunks)")
233
+ except Exception as e:
234
+ file_info.append(f"Error processing document {file_name}: {str(e)}")
235
 
236
  return "\n".join(file_info)
237
 
 
314
 
315
  return "\n".join(doc_list)
316
 
317
+ def clear_context():
318
+ """Clear the current context and chat history"""
319
+ global current_context
320
+ current_context = {
321
+ "file_type": None,
322
+ "file_name": None,
323
+ "table_name": None
324
+ }
325
+ return None
326
+
327
  # Create Gradio interface
328
  with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
329
  gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
 
344
  with gr.Row():
345
  submit_btn = gr.Button("Submit")
346
  clear_btn = gr.Button("Clear")
347
+ clear_context_btn = gr.Button("Clear Context")
348
 
349
  audio_output = gr.Audio(label="Voice Response", type="filepath")
350
 
 
389
  inputs=[chatbot],
390
  outputs=[audio_output]
391
  )
392
+
393
+ # Add event handler for clear context button
394
+ clear_context_btn.click(
395
+ clear_context,
396
+ inputs=[],
397
+ outputs=[chatbot]
398
+ )
399
 
400
  with gr.Tab("Document Upload"):
401
  file_upload = gr.File(