import os
import sqlite3
import csv
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# SQLite DB path from .env for saving results
db_path = os.getenv("db_path", "./query_logs.db")

# Column definitions for the query_logs table. Shared by every place that
# (re)creates the table so the schema cannot drift between them. This is a
# trusted constant, so interpolating it into DDL is safe.
_QUERY_LOGS_COLUMNS = """
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    natural_language_query TEXT,
    reformulated_query TEXT,
    generated_sql TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
"""


def initialize_local_db():
    """Ensure the query_logs table exists with a reformulated_query column.

    * Fresh database (table absent): the table is created with the current
      schema.
    * Legacy table without reformulated_query: the table is rebuilt with the
      current schema and existing rows are copied over (their
      reformulated_query is left NULL). The rebuild runs in one explicit
      transaction, so a failure rolls back and leaves the original table intact.
    * Table already up to date: nothing is changed.

    Raises:
        sqlite3.Error: if any statement fails (after rolling back a rebuild).
    """
    conn = sqlite3.connect(db_path)
    try:
        # Column names are at index 1 of each PRAGMA table_info row.
        columns = [column[1] for column in conn.execute("PRAGMA table_info(query_logs);").fetchall()]

        if not columns:
            # PRAGMA table_info returns no rows when the table does not exist.
            # Previously this fell into the migration branch and crashed
            # copying from the missing table.
            conn.execute(f"CREATE TABLE IF NOT EXISTS query_logs ({_QUERY_LOGS_COLUMNS});")
            conn.commit()
        elif 'reformulated_query' not in columns:
            print("Altering table to add reformulated_query column...")
            # Explicit BEGIN: Python's sqlite3 does not implicitly open a
            # transaction before DDL, so without it each CREATE/DROP/ALTER
            # would autocommit and a mid-way failure would leave a half-migrated DB.
            conn.execute("BEGIN")
            try:
                # Discard any leftover from an earlier interrupted migration
                # instead of silently appending into it.
                conn.execute("DROP TABLE IF EXISTS query_logs_new;")
                conn.execute(f"CREATE TABLE query_logs_new ({_QUERY_LOGS_COLUMNS});")
                # Copy data from old table to new table
                conn.execute('''
                    INSERT INTO query_logs_new (id, natural_language_query, generated_sql, created_at)
                    SELECT id, natural_language_query, generated_sql, created_at FROM query_logs;
                ''')
                conn.execute("DROP TABLE query_logs;")
                conn.execute("ALTER TABLE query_logs_new RENAME TO query_logs;")
                conn.commit()
            except Exception:
                conn.rollback()
                raise
    finally:
        conn.close()


def reset_sqlite_db():
    """Drop the query_logs table and recreate it empty with the current schema.

    Only query_logs is touched; other tables in the database are left alone.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("DROP TABLE IF EXISTS query_logs;")
        print("Dropped query_logs table")

        conn.execute(f"CREATE TABLE query_logs ({_QUERY_LOGS_COLUMNS});")
        print("Recreated the query_logs table.")

        conn.commit()
    finally:
        conn.close()
# Column list returned by every read helper below; kept in one place so the
# paginated view, the last-50 view and the CSV export stay in sync.
_SAVED_QUERIES_SELECT = (
    "SELECT natural_language_query, reformulated_query, generated_sql, created_at "
    "FROM query_logs"
)

# created_at only has one-second resolution, so rows inserted in the same
# second tie. id breaks the tie so ordering (and therefore pagination) is
# deterministic, with newest first.
_NEWEST_FIRST = " ORDER BY created_at DESC, id DESC"


def _fetch_all(sql, params=()):
    """Run a parameterized read query against db_path and return all rows.

    The connection is always closed, even if the query raises.
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(sql, params).fetchall()
    finally:
        conn.close()


def save_query_to_local_db(nl_query, reformulated_query, sql_query):
    """Append one (natural-language, reformulated, generated SQL) row to query_logs.

    Best-effort: any error is printed and swallowed, never raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.execute(
            '''
            INSERT INTO query_logs (natural_language_query, reformulated_query, generated_sql)
            VALUES (?, ?, ?);
            ''',
            (nl_query, reformulated_query, sql_query),
        )
        conn.commit()
    except Exception as e:
        print(f"Error saving query: {e}")
    finally:
        if conn is not None:
            conn.close()


def get_saved_queries(page=1, per_page=10, search_term=None):
    """Return one page of saved queries, newest first, optionally filtered.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        per_page: rows per page.
        search_term: if truthy, keep only rows whose natural_language_query or
            generated_sql contains it (SQL LIKE, so % and _ act as wildcards).

    Returns:
        A list of (natural_language_query, reformulated_query, generated_sql,
        created_at) tuples. On error, returns the error message as a str.
    """
    try:
        per_page = int(per_page)
        offset = max(int(page) - 1, 0) * per_page

        sql = _SAVED_QUERIES_SELECT
        params = []
        if search_term:
            # Bound parameters instead of string interpolation: prevents SQL
            # injection and keeps terms containing quotes from breaking the query.
            sql += " WHERE natural_language_query LIKE ? OR generated_sql LIKE ?"
            pattern = f"%{search_term}%"
            params += [pattern, pattern]
        sql += _NEWEST_FIRST + " LIMIT ? OFFSET ?;"
        params += [per_page, offset]

        return _fetch_all(sql, params)
    except Exception as e:
        return str(e)


def clear_data():
    """Delete every row from query_logs (for testing); the table itself is kept.

    Raises:
        sqlite3.Error: if the delete fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("DELETE FROM query_logs")
        conn.commit()
    finally:
        conn.close()


def get_last_50_saved_queries():
    """Return the 50 most recent saved queries, newest first.

    Returns the same tuple shape as get_saved_queries, or the error message
    as a str on failure.
    """
    try:
        return _fetch_all(_SAVED_QUERIES_SELECT + _NEWEST_FIRST + " LIMIT 50;")
    except Exception as e:
        return str(e)


def export_saved_queries_to_csv(file_path="./saved_queries.csv"):
    """Write every saved query (newest first) to a UTF-8 CSV file with a header row.

    Returns:
        file_path on success, or the error message as a str on failure.
    """
    try:
        rows = _fetch_all(_SAVED_QUERIES_SELECT + _NEWEST_FIRST + ";")

        # Explicit UTF-8: the platform default encoding may not be able to
        # represent arbitrary user queries. newline='' is required by csv.
        with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Natural Language Query', 'Reformulated Query', 'Generated SQL', 'Timestamp'])
            csv_writer.writerows(rows)
        return file_path
    except Exception as e:
        return str(e)


# Uncomment the following line to reset the SQLite database when you run this script
# reset_sqlite_db()