Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import psycopg2 | |
| from psycopg2 import pool | |
| import sqlite3 | |
| from dotenv import load_dotenv | |
| import logging | |
| import csv | |
# Load environment variables from a local .env file (DB credentials, paths).
load_dotenv()

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# PostgreSQL connection pool setup.
# NOTE(review): the pool is created at import time, so a missing/invalid
# .env makes this module fail on import — confirm that is intended.
db_pool = psycopg2.pool.SimpleConnectionPool(
    minconn=1,
    maxconn=5,
    user=os.getenv("db_user"),
    password=os.getenv("db_password"),
    host=os.getenv("db_host"),
    dbname=os.getenv("db_name"),
    port=os.getenv("db_port", 5432)  # env value is a str, default is int; psycopg2 accepts both
)

# SQLite DB path (local query log); overridable via DB_PATH.
db_path = os.getenv("DB_PATH", "./query_logs.db")
| # PostgreSQL: Function to fetch schema and save to schema.json | |
def fetch_and_save_schema():
    """Introspect the PostgreSQL ``public`` schema and dump it to schema.json.

    For every table in the ``public`` schema, collects the table comment,
    its columns (name, type, comment) and its foreign-key relationships,
    writes the result to ``schema.json`` and returns it as a dict.

    Returns:
        dict: the schema info on success, or ``{"error": "<message>"}`` on
        failure (errors are logged, never raised).
    """
    conn = None
    cursor = None
    try:
        logging.info("Fetching schema from the database...")
        conn = db_pool.getconn()
        cursor = conn.cursor()
        # All table names and their comments in the public schema.
        cursor.execute("""
            SELECT table_name,
                   obj_description(('public.' || table_name)::regclass) AS table_comment
            FROM information_schema.tables
            WHERE table_schema = 'public';
        """)
        tables = cursor.fetchall()

        schema_info = {}
        for table_name, table_comment in tables:
            schema_info[table_name] = {
                "comment": table_comment,
                "columns": [],
                "foreign_keys": []
            }

            # Column details and comments.  Parameterized (%s) instead of the
            # original f-string interpolation: even catalog-sourced names
            # should never be spliced into SQL text.
            cursor.execute("""
                SELECT
                    c.column_name,
                    c.data_type,
                    col_description(('public.' || c.table_name)::regclass,
                                    c.ordinal_position) AS column_comment
                FROM information_schema.columns c
                WHERE c.table_name = %s;
            """, (table_name,))
            for column_name, data_type, column_comment in cursor.fetchall():
                schema_info[table_name]["columns"].append({
                    "name": column_name,
                    "data_type": data_type,
                    "comment": column_comment
                })

            # Foreign-key relationships for this table.
            cursor.execute("""
                SELECT
                    kcu.column_name,
                    ccu.table_name AS foreign_table_name,
                    ccu.column_name AS foreign_column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                    ON ccu.constraint_name = tc.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = %s;
            """, (table_name,))
            for column_name, foreign_table_name, foreign_column_name in cursor.fetchall():
                schema_info[table_name]["foreign_keys"].append({
                    "column": column_name,
                    "references": {
                        "table": foreign_table_name,
                        "column": foreign_column_name
                    }
                })

        # Save the schema to a JSON file.
        with open("schema.json", "w") as schema_file:
            json.dump(schema_info, schema_file, indent=2)
        logging.info("Schema fetched and saved to schema.json.")
        return schema_info
    except Exception as e:
        logging.error(f"Error fetching schema: {e}")
        return {"error": str(e)}
    finally:
        # The original released the cursor/connection only on the happy path,
        # leaking a pooled connection on every failure.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            db_pool.putconn(conn)
| # PostgreSQL: Function to execute SQL query | |
def execute_sql_query(sql_query):
    """Execute *sql_query* against PostgreSQL and return header + rows.

    The SQL is executed verbatim (this module backs a text-to-SQL app); the
    caller is responsible for ensuring the statement is safe to run.

    Returns:
        list: ``[column_names] + rows`` (Gradio Dataframe format) on
        success, or the error message string on failure (original contract).
    """
    conn = None
    cursor = None
    try:
        conn = db_pool.getconn()
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        # Column names come from the cursor description of the last query.
        column_names = [desc[0] for desc in cursor.description]
        return [column_names] + result
    except Exception as e:
        logging.error(f"Error executing SQL query: {e}")
        return str(e)
    finally:
        # The original only returned the connection to the pool on success,
        # so every failed query leaked one of the (max 5) pooled connections.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            db_pool.putconn(conn)
| # SQLite: Initialize the local SQLite database | |
def initialize_local_db(path=None):
    """Create the SQLite ``query_logs`` table if it does not exist.

    Args:
        path: database file to initialize; defaults to the module-level
            ``db_path``, keeping the original no-argument call working.
    """
    conn = sqlite3.connect(path if path is not None else db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS query_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                natural_language_query TEXT,
                reformulated_query TEXT,
                generated_sql TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
        ''')
        conn.commit()
    finally:
        # Close even if table creation fails (original leaked on error).
        conn.close()
| # SQLite: Function to save the query to the local database | |
def save_query_to_local_db(nl_query, reformulated_query, sql_query, path=None):
    """Persist one (NL query, reformulation, SQL) triple into the local log.

    Best effort: failures are logged and swallowed (original behavior).

    Args:
        nl_query: the user's natural-language query.
        reformulated_query: the reformulated version of the query.
        sql_query: the generated SQL.
        path: database file; defaults to the module-level ``db_path``.
    """
    conn = None
    try:
        conn = sqlite3.connect(path if path is not None else db_path)
        conn.execute(
            '''
            INSERT INTO query_logs (natural_language_query, reformulated_query, generated_sql)
            VALUES (?, ?, ?);
            ''',
            (nl_query, reformulated_query, sql_query),
        )
        conn.commit()
    except Exception as e:
        logging.error(f"Error saving query: {e}")
    finally:
        # Original never closed the connection when execute/commit raised.
        if conn is not None:
            conn.close()
| # SQLite: Function to get the last 50 saved queries | |
def get_last_50_saved_queries(path=None):
    """Return the 50 most recently saved queries, newest first.

    Each row is ``(natural_language_query, reformulated_query,
    generated_sql, created_at)``.

    Args:
        path: database file; defaults to the module-level ``db_path``.

    Returns:
        list[tuple] on success, or the error message string on failure
        (preserving the original error contract).
    """
    conn = None
    try:
        conn = sqlite3.connect(path if path is not None else db_path)
        cursor = conn.execute(
            "SELECT natural_language_query, reformulated_query, generated_sql, created_at FROM query_logs ORDER BY created_at DESC LIMIT 50;"
        )
        return cursor.fetchall()
    except Exception as e:
        logging.error(f"Error retrieving saved queries: {e}")
        return str(e)
    finally:
        # Original leaked the connection whenever the SELECT failed.
        if conn is not None:
            conn.close()
| # SQLite: Function to export saved queries to a CSV file | |
def export_saved_queries_to_csv(file_path="./saved_queries.csv", database=None):
    """Export all saved queries (newest first) to a CSV file.

    Args:
        file_path: destination CSV path.
        database: SQLite database file; defaults to the module-level
            ``db_path``.

    Returns:
        str: *file_path* on success, or the error message string on failure
        (preserving the original error contract).
    """
    conn = None
    try:
        conn = sqlite3.connect(database if database is not None else db_path)
        rows = conn.execute(
            "SELECT natural_language_query, reformulated_query, generated_sql, created_at FROM query_logs ORDER BY created_at DESC;"
        ).fetchall()
        # newline='' is required so csv handles line endings itself.
        with open(file_path, 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Natural Language Query', 'Reformulated Query', 'Generated SQL', 'Timestamp'])
            csv_writer.writerows(rows)
        return file_path
    except Exception as e:
        logging.error(f"Error exporting queries to CSV: {e}")
        return str(e)
    finally:
        # Original leaked the connection when the query or file write failed.
        if conn is not None:
            conn.close()
def show_last_50_saved_queries(path=None):
    """Return the 50 most recently saved queries, newest first.

    NOTE(review): this is a verbatim duplicate of
    ``get_last_50_saved_queries`` — consider removing one and updating its
    callers.  Kept with an identical contract for compatibility.

    Args:
        path: database file; defaults to the module-level ``db_path``.

    Returns:
        list[tuple] on success, or the error message string on failure.
    """
    conn = None
    try:
        conn = sqlite3.connect(path if path is not None else db_path)
        cursor = conn.execute(
            "SELECT natural_language_query, reformulated_query, generated_sql, created_at FROM query_logs ORDER BY created_at DESC LIMIT 50;"
        )
        return cursor.fetchall()
    except Exception as e:
        logging.error(f"Error retrieving saved queries: {e}")
        return str(e)
    finally:
        # Original leaked the connection whenever the SELECT failed.
        if conn is not None:
            conn.close()
| # Function to reset (drop all tables) and recreate the schema | |
def reset_sqlite_db(path=None):
    """Drop and recreate the ``query_logs`` table, erasing all saved queries.

    Destructive maintenance helper, intended to be invoked manually.

    Args:
        path: database file; defaults to the module-level ``db_path``.
    """
    conn = sqlite3.connect(path if path is not None else db_path)
    try:
        conn.execute("DROP TABLE IF EXISTS query_logs;")
        # Use the module's logger rather than bare print(), consistent with
        # the rest of this file.
        logging.info("Dropped query_logs table")
        conn.execute('''CREATE TABLE query_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            natural_language_query TEXT,
            reformulated_query TEXT,
            generated_sql TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );''')
        logging.info("Recreated the query_logs table.")
        conn.commit()
    finally:
        # Close even on failure (original leaked on error).
        conn.close()
| # Uncomment the following line to reset the SQLite database when you run this script | |
| # reset_sqlite_db() | |
def fetch_schema_info(path="schema.json"):
    """Load previously saved schema info from a JSON file.

    Args:
        path: JSON file to read; defaults to ``schema.json``, the file
            written by ``fetch_and_save_schema``.

    Returns:
        dict: the parsed schema, or ``{}`` if the file is missing or
        unreadable (best effort, matching the original behavior).
    """
    try:
        with open(path, "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info(f"Schema loaded from {path}")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from {path}: {e}")
        return {}