SVashishta1 committed on
Commit
8de36f9
·
1 Parent(s): fbbf665
Files changed (1) hide show
  1. app.py +110 -23
app.py CHANGED
@@ -6,7 +6,12 @@ import tempfile
6
  import pandas as pd
7
  import sqlite3
8
  from langchain_core.prompts import ChatPromptTemplate
9
- #test
 
 
 
 
 
10
  # Add parent directory to path to import backend modules
11
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
 
@@ -27,8 +32,16 @@ document_parser = SimpleDocumentParser()
27
  # Initialize DocumentAssistant
28
  document_assistant = DocumentAssistant()
29
 
30
- # Load environment variables
31
- load_dotenv()
 
 
 
 
 
 
 
 
32
 
33
  # Database path for CSV data
34
  DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
@@ -94,15 +107,23 @@ query_prompt = ChatPromptTemplate.from_messages(
94
  ]
95
  )
96
 
 
 
 
 
 
 
 
 
97
  def process_text_query(query, history):
98
  """Process a text query and update chat history"""
99
  if not query:
100
  return "", history
101
 
102
  # Check if this looks like an SQL query for CSV data
103
- if any(keyword in query.lower() for keyword in ['sql', 'query', 'table', 'select', 'from', 'where', 'group by']):
104
  try:
105
- # Try to execute as SQL query against CSV data
106
  conn = sqlite3.connect(DB_PATH)
107
  cursor = conn.cursor()
108
 
@@ -111,36 +132,81 @@ def process_text_query(query, history):
111
  tables = [row[0] for row in cursor.fetchall()]
112
 
113
  if tables:
114
- # Generate a response that includes table info
115
  table_info = []
116
  for table in tables:
117
  cursor.execute(f"PRAGMA table_info({table});")
118
  columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
119
  table_info.append(f"Table '{table}' has columns: {', '.join(columns)}")
120
 
121
- # Use the assistant to generate a response that includes SQL info
122
- context = f"The database contains the following tables:\n" + "\n".join(table_info)
123
- response = document_assistant.process_query(f"{context}\n\nUser query: {query}")
124
 
125
- # Update history with message format
126
- history.append({"role": "user", "content": query})
127
- history.append({"role": "assistant", "content": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  else:
129
- # No tables found
130
- history.append({"role": "user", "content": query})
131
- history.append({"role": "assistant", "content": "No CSV data has been uploaded yet. Please upload a CSV file first."})
132
 
133
  conn.close()
134
  except Exception as e:
135
  # Fall back to regular document query
136
  response = document_assistant.process_query(query)
137
- history.append({"role": "user", "content": query})
138
- history.append({"role": "assistant", "content": response})
139
  else:
140
  # Process regular document query
141
  response = document_assistant.process_query(query)
142
- history.append({"role": "user", "content": query})
143
- history.append({"role": "assistant", "content": response})
 
 
144
 
145
  return "", history
146
 
@@ -164,9 +230,21 @@ def process_file_upload(files):
164
  # Load CSV into SQLite
165
  conn = sqlite3.connect(DB_PATH)
166
  load_csv_to_sqlite(file_path, conn, table_name)
 
 
 
 
 
 
 
 
 
 
167
  conn.close()
168
 
169
  file_info.append(f"CSV data loaded into table: {table_name}")
 
 
170
 
171
  # Also index with document assistant for text search
172
  result = document_assistant.upload_document(file_path)
@@ -239,14 +317,23 @@ def list_documents():
239
  cursor = conn.cursor()
240
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
241
  tables = cursor.fetchall()
242
- conn.close()
243
 
244
  if tables:
245
  doc_list.append("\nCSV data tables:")
246
  for table in tables:
247
- doc_list.append(f"- {table[0]}")
248
- except:
249
- pass
 
 
 
 
 
 
 
 
 
 
250
 
251
  return "\n".join(doc_list)
252
 
 
6
  import pandas as pd
7
  import sqlite3
8
  from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_groq import ChatGroq
10
+ import plotly.express as px
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
  # Add parent directory to path to import backend modules
16
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
 
 
32
# Assistant that handles plain (non-SQL) document queries.
document_assistant = DocumentAssistant()

# LLM backend: Groq-hosted llama3-8b-8192, deterministic output (temperature=0),
# retried up to twice on transient failures. The key is read from the
# environment (populated earlier by load_dotenv()).
_groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=_groq_api_key,
)

# SQLite file that holds uploaded CSV data, kept next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(_APP_DIR, "data", "csv_data.db")
 
107
  ]
108
  )
109
 
110
# Prompt used to turn an executed SQL result set into a natural-language
# analysis. Fills three slots: the user's question, the generated SQL, and a
# (possibly truncated) text dump of the result rows.
_INTERPRET_SYSTEM = (
    "You are an experienced data analyst. Examine the following data and "
    "provide a clear analysis. Base your analysis solely on the provided data."
)
_INTERPRET_HUMAN = "Question: {question}\n\nSQL Query: {sql_query}\n\nData:\n{data}"
interpret_prompt = ChatPromptTemplate.from_messages(
    [("system", _INTERPRET_SYSTEM), ("human", _INTERPRET_HUMAN)]
)
117
+
118
def process_text_query(query, history):
    """Process a text query and append the exchange to the chat history.

    If the query looks data-related (contains SQL/analytics keywords), an LLM
    generates a SQL query against the uploaded CSV tables; the query is
    executed, optionally charted, and its results interpreted by the LLM.
    Otherwise — or if anything in the SQL path fails — the query falls back
    to the regular document assistant.

    Args:
        query: Raw user input; an empty string is a no-op.
        history: Chat history as a list of {"role", "content"} dicts;
            mutated in place.

    Returns:
        Tuple of ("", history) — the empty string clears the input box.
    """
    if not query:
        return "", history

    # Heuristic: treat the query as data-related if it mentions any
    # SQL/analytics keyword. Deliberately broad; misfires are caught by the
    # outer except and fall back to the document assistant.
    sql_keywords = ('sql', 'query', 'table', 'select', 'from', 'where',
                    'group by', 'data', 'csv', 'average', 'count', 'sum',
                    'max', 'min')
    if any(keyword in query.lower() for keyword in sql_keywords):
        conn = None
        try:
            conn = sqlite3.connect(DB_PATH)
            cursor = conn.cursor()

            # Discover which CSV-backed tables exist.
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [row[0] for row in cursor.fetchall()]

            if tables:
                # Describe each table's schema so the LLM can write valid SQL.
                table_info = []
                for table in tables:
                    cursor.execute(f"PRAGMA table_info({table});")
                    columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                    table_info.append(f"Table '{table}' has columns: {', '.join(columns)}")

                schema_text = "\n".join(table_info)
                question_with_context = (
                    f"The database contains the following tables:\n{schema_text}\n\n{query}"
                )

                # Ask the LLM to translate the question into SQL.
                sql_query = (query_prompt | llm).invoke(
                    {"question": question_with_context}
                ).content.strip()

                try:
                    result_df = pd.read_sql_query(sql_query, conn)

                    # Render a bar chart only for simple two-column results
                    # with at least one numeric column.
                    plot_html = None
                    if not result_df.empty and len(result_df.columns) == 2:
                        numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
                        if numeric_cols:
                            x_col = (result_df.columns[0]
                                     if result_df.columns[0] not in numeric_cols
                                     else result_df.columns[1])
                            fig = px.bar(result_df, x=x_col, y=numeric_cols[0],
                                         title="Query Results")
                            plot_html = fig.to_html(full_html=False)

                    # Cap the data shown to the LLM (and the user) at 10 rows.
                    if len(result_df) > 10:
                        data_str = f"{result_df.head(10).to_string()}\n... (showing 10 of {len(result_df)} rows)"
                    else:
                        data_str = result_df.to_string()

                    # Ask the LLM to interpret the result set.
                    interpretation = (interpret_prompt | llm).invoke({
                        "question": query,
                        "sql_query": sql_query,
                        "data": data_str
                    }).content.strip()

                    # Assemble the markdown response.
                    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
                    if not result_df.empty:
                        response += f"**Results:**\n```\n{data_str}\n```\n\n"
                    else:
                        response += "**No results found.**\n\n"
                    response += f"**Analysis:**\n{interpretation}"
                    if plot_html:
                        response += f"\n\n<div>{plot_html}</div>"

                except Exception as e:
                    # The generated SQL failed; surface the query and error.
                    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"
            else:
                response = "No CSV data has been uploaded yet. Please upload a CSV file first."
        except Exception:
            # DB open or LLM call failed — fall back to a plain document query.
            response = document_assistant.process_query(query)
        finally:
            # BUGFIX: previously conn.close() sat on the success path only, so
            # the connection leaked whenever an exception reached the outer
            # handler. Always close it.
            if conn is not None:
                conn.close()
    else:
        # Not data-related: regular document query.
        response = document_assistant.process_query(query)

    # Record the exchange in {"role", "content"} message format.
    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": response})

    return "", history
212
 
 
230
  # Load CSV into SQLite
231
  conn = sqlite3.connect(DB_PATH)
232
  load_csv_to_sqlite(file_path, conn, table_name)
233
+
234
+ # Get column info for the table
235
+ cursor = conn.cursor()
236
+ cursor.execute(f"PRAGMA table_info({table_name});")
237
+ columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
238
+
239
+ # Get row count
240
+ cursor.execute(f"SELECT COUNT(*) FROM {table_name};")
241
+ row_count = cursor.fetchone()[0]
242
+
243
  conn.close()
244
 
245
  file_info.append(f"CSV data loaded into table: {table_name}")
246
+ file_info.append(f"Columns: {', '.join(columns)}")
247
+ file_info.append(f"Rows: {row_count}")
248
 
249
  # Also index with document assistant for text search
250
  result = document_assistant.upload_document(file_path)
 
317
  cursor = conn.cursor()
318
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
319
  tables = cursor.fetchall()
 
320
 
321
  if tables:
322
  doc_list.append("\nCSV data tables:")
323
  for table in tables:
324
+ # Get column info
325
+ cursor.execute(f"PRAGMA table_info({table[0]});")
326
+ columns = [col[1] for col in cursor.fetchall()]
327
+
328
+ # Get row count
329
+ cursor.execute(f"SELECT COUNT(*) FROM {table[0]};")
330
+ row_count = cursor.fetchone()[0]
331
+
332
+ doc_list.append(f"- {table[0]} ({row_count} rows, {len(columns)} columns)")
333
+
334
+ conn.close()
335
+ except Exception as e:
336
+ doc_list.append(f"Error listing CSV tables: {str(e)}")
337
 
338
  return "\n".join(doc_list)
339