import csv
import json
import logging
import os
import sqlite3

import psycopg2
from psycopg2 import pool
from dotenv import load_dotenv

# Load configuration from a local .env file (DB credentials, paths).
load_dotenv()

# Module-wide logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# PostgreSQL connection pool (1-5 connections), configured via environment.
db_pool = psycopg2.pool.SimpleConnectionPool(
    minconn=1,
    maxconn=5,
    user=os.getenv("db_user"),
    password=os.getenv("db_password"),
    host=os.getenv("db_host"),
    dbname=os.getenv("db_name"),
    port=os.getenv("db_port", 5432),
)

# Path of the local SQLite database used for query logging.
db_path = os.getenv("DB_PATH", "./query_logs.db")


def fetch_and_save_schema():
    """Introspect the PostgreSQL 'public' schema and persist it to schema.json.

    For every table, collects the table comment, its columns (name, data
    type, comment) and its foreign-key relationships.

    Returns:
        dict: the schema information on success, or ``{"error": <message>}``
        on failure (mirrors the original error convention).
    """
    conn = None
    try:
        logging.info("Fetching schema from the database...")
        conn = db_pool.getconn()
        cursor = conn.cursor()
        try:
            # All public tables plus their comments.
            cursor.execute("""
                SELECT table_name,
                       obj_description(('public.' || table_name)::regclass) AS table_comment
                FROM information_schema.tables
                WHERE table_schema = 'public';
            """)
            tables = cursor.fetchall()

            schema_info = {}
            for table_name, table_comment in tables:
                schema_info[table_name] = {
                    "comment": table_comment,
                    "columns": [],
                    "foreign_keys": [],
                }

                # Column details and comments. Parameterized (%s) instead of
                # string interpolation to avoid SQL injection / quoting bugs.
                cursor.execute("""
                    SELECT c.column_name,
                           c.data_type,
                           col_description(('public.' || c.table_name)::regclass,
                                           ordinal_position) AS column_comment
                    FROM information_schema.columns c
                    WHERE c.table_name = %s;
                """, (table_name,))
                for column_name, data_type, column_comment in cursor.fetchall():
                    schema_info[table_name]["columns"].append({
                        "name": column_name,
                        "data_type": data_type,
                        "comment": column_comment,
                    })

                # Foreign-key relationships for the table.
                cursor.execute("""
                    SELECT kcu.column_name,
                           ccu.table_name AS foreign_table_name,
                           ccu.column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                      ON tc.constraint_name = kcu.constraint_name
                     AND tc.table_schema = kcu.table_schema
                    JOIN information_schema.constraint_column_usage AS ccu
                      ON ccu.constraint_name = tc.constraint_name
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                      AND tc.table_name = %s;
                """, (table_name,))
                for column_name, foreign_table_name, foreign_column_name in cursor.fetchall():
                    schema_info[table_name]["foreign_keys"].append({
                        "column": column_name,
                        "references": {
                            "table": foreign_table_name,
                            "column": foreign_column_name,
                        },
                    })
        finally:
            cursor.close()

        # Persist the collected schema for later offline use.
        with open("schema.json", "w") as schema_file:
            json.dump(schema_info, schema_file, indent=2)

        logging.info("Schema fetched and saved to schema.json.")
        return schema_info
    except Exception as e:
        logging.error(f"Error fetching schema: {e}")
        return {"error": str(e)}
    finally:
        # Always return the connection to the pool, even on failure
        # (the original leaked it on exceptions).
        if conn is not None:
            db_pool.putconn(conn)


def execute_sql_query(sql_query):
    """Run an arbitrary SQL query against PostgreSQL.

    NOTE: ``sql_query`` is executed verbatim — the caller is responsible
    for ensuring it does not contain untrusted input.

    Args:
        sql_query: the SQL text to execute.

    Returns:
        list: ``[column_names] + rows`` (Gradio Dataframe format) on
        success, or the error message string on failure.
    """
    conn = None
    try:
        conn = db_pool.getconn()
        cursor = conn.cursor()
        try:
            cursor.execute(sql_query)
            result = cursor.fetchall()
            # Header row built from the cursor description.
            column_names = [desc[0] for desc in cursor.description]
        finally:
            cursor.close()
        return [column_names] + result
    except Exception as e:
        logging.error(f"Error executing SQL query: {e}")
        return str(e)
    finally:
        if conn is not None:
            db_pool.putconn(conn)


def initialize_local_db():
    """Create the SQLite ``query_logs`` table if it does not exist."""
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS query_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                natural_language_query TEXT,
                reformulated_query TEXT,
                generated_sql TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
        ''')
        conn.commit()
        cursor.close()
    finally:
        conn.close()


def save_query_to_local_db(nl_query, reformulated_query, sql_query):
    """Append one (NL query, reformulation, generated SQL) row to the log.

    Errors are logged and swallowed (best-effort logging, as before).
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                '''
                INSERT INTO query_logs (natural_language_query, reformulated_query, generated_sql)
                VALUES (?, ?, ?);
                ''',
                (nl_query, reformulated_query, sql_query),
            )
            conn.commit()
            cursor.close()
        finally:
            conn.close()
    except Exception as e:
        logging.error(f"Error saving query: {e}")


def get_last_50_saved_queries():
    """Return the 50 most recent query-log rows (newest first).

    Returns:
        list[tuple]: rows of (natural_language_query, reformulated_query,
        generated_sql, created_at), or the error message string on failure.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at "
                "FROM query_logs ORDER BY created_at DESC LIMIT 50;"
            )
            rows = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()
        return rows
    except Exception as e:
        logging.error(f"Error retrieving saved queries: {e}")
        return str(e)


def export_saved_queries_to_csv(file_path="./saved_queries.csv"):
    """Dump every saved query to a CSV file.

    Args:
        file_path: destination path for the CSV (default ./saved_queries.csv).

    Returns:
        str: the CSV file path on success, or the error message on failure.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at "
                "FROM query_logs ORDER BY created_at DESC;"
            )
            rows = cursor.fetchall()
            cursor.close()
        finally:
            conn.close()

        with open(file_path, 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Natural Language Query', 'Reformulated Query', 'Generated SQL', 'Timestamp'])
            csv_writer.writerows(rows)
        return file_path
    except Exception as e:
        logging.error(f"Error exporting queries to CSV: {e}")
        return str(e)


def show_last_50_saved_queries():
    """Alias kept for backward compatibility; identical to
    :func:`get_last_50_saved_queries` (the original duplicated the body)."""
    return get_last_50_saved_queries()


def reset_sqlite_db():
    """Drop and recreate the ``query_logs`` table (destroys all saved logs).

    Maintenance helper — not called automatically; see the commented-out
    invocation below.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("DROP TABLE IF EXISTS query_logs;")
        logging.info("Dropped query_logs table")
        cursor.execute('''
            CREATE TABLE query_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                natural_language_query TEXT,
                reformulated_query TEXT,
                generated_sql TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
        ''')
        logging.info("Recreated the query_logs table.")
        conn.commit()
        cursor.close()
    finally:
        conn.close()


# Uncomment the following line to reset the SQLite database when you run this script
# reset_sqlite_db()


def fetch_schema_info():
    """Load the cached schema from schema.json.

    Returns:
        dict: the schema written by :func:`fetch_and_save_schema`, or an
        empty dict if the file is missing/unreadable.
    """
    try:
        with open("schema.json", "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info("Schema loaded from schema.json")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from schema.json: {e}")
        return {}