# Source: Hugging Face upload (commit 2cb3f69) — web-listing header removed.
import os
import json
import psycopg2
from psycopg2 import pool
import sqlite3
from dotenv import load_dotenv
import logging
import csv
# Load environment variables from a local .env file (if present).
load_dotenv()
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# PostgreSQL connection pool setup
# NOTE(review): the pool is created at import time; os.getenv returns None
# for unset variables, so a missing .env would make this fail on import —
# confirm the env vars are guaranteed by deployment.
db_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=5,
user=os.getenv("db_user"),
password=os.getenv("db_password"),
host=os.getenv("db_host"),
dbname=os.getenv("db_name"),
port=os.getenv("db_port", 5432)
)
# SQLite DB path
# Local SQLite file used only for query logging (separate from Postgres).
db_path = os.getenv("DB_PATH", "./query_logs.db")
# PostgreSQL: Function to fetch schema and save to schema.json
def fetch_and_save_schema():
    """Fetch the public-schema structure from PostgreSQL and save it to schema.json.

    For every table in the ``public`` schema, collects the table comment,
    per-column name/type/comment, and foreign-key relationships.

    Returns:
        dict: ``{table_name: {"comment", "columns", "foreign_keys"}}`` on
        success, or ``{"error": <message>}`` on failure.
    """
    conn = None
    try:
        logging.info("Fetching schema from the database...")
        conn = db_pool.getconn()
        cursor = conn.cursor()
        # All table names and their comments in the public schema.
        cursor.execute("""
            SELECT table_name,
                   obj_description(('public.' || table_name)::regclass) AS table_comment
            FROM information_schema.tables
            WHERE table_schema = 'public';
        """)
        tables = cursor.fetchall()

        schema_info = {}
        for table_name, table_comment in tables:
            schema_info[table_name] = {
                "comment": table_comment,
                "columns": [],
                "foreign_keys": []
            }
            # Column details and comments for this table.
            # Parameterized (%s) instead of f-string interpolation to avoid
            # SQL injection / quoting issues with unusual table names.
            cursor.execute("""
                SELECT
                    c.column_name,
                    c.data_type,
                    col_description(('public.' || c.table_name)::regclass,
                                    c.ordinal_position) AS column_comment
                FROM information_schema.columns c
                WHERE c.table_name = %s;
            """, (table_name,))
            for column_name, data_type, column_comment in cursor.fetchall():
                schema_info[table_name]["columns"].append({
                    "name": column_name,
                    "data_type": data_type,
                    "comment": column_comment
                })
            # Foreign-key relationships for this table.
            cursor.execute("""
                SELECT
                    kcu.column_name,
                    ccu.table_name AS foreign_table_name,
                    ccu.column_name AS foreign_column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                  ON ccu.constraint_name = tc.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY'
                  AND tc.table_name = %s;
            """, (table_name,))
            for column_name, foreign_table_name, foreign_column_name in cursor.fetchall():
                schema_info[table_name]["foreign_keys"].append({
                    "column": column_name,
                    "references": {
                        "table": foreign_table_name,
                        "column": foreign_column_name
                    }
                })
        cursor.close()

        # Save the schema to a JSON file.
        with open("schema.json", "w") as schema_file:
            json.dump(schema_info, schema_file, indent=2)
        logging.info("Schema fetched and saved to schema.json.")
        return schema_info
    except Exception as e:
        logging.error(f"Error fetching schema: {e}")
        return {"error": str(e)}
    finally:
        # Always hand the connection back to the pool — the original leaked
        # it whenever an exception fired before putconn().
        if conn is not None:
            db_pool.putconn(conn)
# PostgreSQL: Function to execute SQL query
def execute_sql_query(sql_query):
    """Execute a SQL statement on the pooled PostgreSQL connection.

    Args:
        sql_query: SQL text to run. NOTE(review): executed verbatim — callers
            must ensure it is trusted/validated upstream.

    Returns:
        list: header row of column names followed by the data rows (the
        layout Gradio's Dataframe expects); on failure, the error message
        as a plain string (kept for backward compatibility).
    """
    conn = None
    try:
        conn = db_pool.getconn()
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        # Column names come from the cursor description of the executed query.
        column_names = [desc[0] for desc in cursor.description]
        cursor.close()
        return [column_names] + result
    except Exception as e:
        logging.error(f"Error executing SQL query: {e}")
        return str(e)
    finally:
        # Return the connection to the pool even when execution fails —
        # the original leaked it on any exception.
        if conn is not None:
            db_pool.putconn(conn)
# SQLite: Initialize the local SQLite database
def initialize_local_db(path=None):
    """Create the query_logs table in the local SQLite database if missing.

    Args:
        path: optional SQLite file path; defaults to the module-level
            ``db_path`` (DB_PATH env var or ./query_logs.db). Added as a
            backward-compatible parameter so callers/tests can target a
            specific database file.
    """
    conn = sqlite3.connect(db_path if path is None else path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS query_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                natural_language_query TEXT,
                reformulated_query TEXT,
                generated_sql TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
        ''')
        conn.commit()
        cursor.close()
    finally:
        # Close even if table creation raises (original leaked the handle).
        conn.close()
# SQLite: Function to save the query to the local database
def save_query_to_local_db(nl_query, reformulated_query, sql_query, path=None):
    """Persist one generated query triple to the query_logs table.

    Args:
        nl_query: the user's natural-language question.
        reformulated_query: the reformulated version of the question.
        sql_query: the SQL that was generated for it.
        path: optional SQLite file path; defaults to module-level ``db_path``.

    Errors are logged and swallowed on purpose — logging must never break
    the main query flow (preserves the original best-effort behavior).
    """
    try:
        conn = sqlite3.connect(db_path if path is None else path)
        try:
            conn.execute(
                '''
                INSERT INTO query_logs (natural_language_query, reformulated_query, generated_sql)
                VALUES (?, ?, ?);
                ''',
                (nl_query, reformulated_query, sql_query),
            )
            conn.commit()
        finally:
            # Close even on insert failure (original leaked the handle).
            conn.close()
    except Exception as e:
        logging.error(f"Error saving query: {e}")
# SQLite: Function to get the last 50 saved queries
def get_last_50_saved_queries(path=None):
    """Return the 50 most recent query_logs rows, newest first.

    Args:
        path: optional SQLite file path; defaults to module-level ``db_path``.

    Returns:
        list[tuple]: (natural_language_query, reformulated_query,
        generated_sql, created_at) rows; on failure, the error message as a
        plain string (kept for backward compatibility with callers that
        display it directly).
    """
    try:
        conn = sqlite3.connect(db_path if path is None else path)
        try:
            cursor = conn.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at FROM query_logs ORDER BY created_at DESC LIMIT 50;"
            )
            return cursor.fetchall()
        finally:
            # Close even if the query raises (original leaked the handle).
            conn.close()
    except Exception as e:
        logging.error(f"Error retrieving saved queries: {e}")
        return str(e)
# SQLite: Function to export saved queries to a CSV file
def export_saved_queries_to_csv(file_path="./saved_queries.csv", path=None):
    """Export every saved query (newest first) to a CSV file.

    Args:
        file_path: destination CSV path.
        path: optional SQLite file path; defaults to module-level ``db_path``.

    Returns:
        str: ``file_path`` on success; otherwise the error message as a
        plain string (kept for backward compatibility).
    """
    try:
        conn = sqlite3.connect(db_path if path is None else path)
        try:
            rows = conn.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at FROM query_logs ORDER BY created_at DESC;"
            ).fetchall()
        finally:
            # Close even if the query raises (original leaked the handle).
            conn.close()
        # Write the results to the CSV file, header first.
        with open(file_path, 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Natural Language Query', 'Reformulated Query', 'Generated SQL', 'Timestamp'])
            csv_writer.writerows(rows)
        return file_path
    except Exception as e:
        logging.error(f"Error exporting queries to CSV: {e}")
        return str(e)
def show_last_50_saved_queries():
    """Return the 50 most recent query_logs rows, newest first.

    NOTE(review): this was a byte-for-byte duplicate of
    get_last_50_saved_queries(); kept as a thin alias for backward
    compatibility with existing callers instead of duplicating the logic.
    """
    return get_last_50_saved_queries()
# Function to reset (drop all tables) and recreate the schema
def reset_sqlite_db(path=None):
    """Drop and recreate the query_logs table — DESTROYS all saved queries.

    Args:
        path: optional SQLite file path; defaults to module-level ``db_path``.
    """
    conn = sqlite3.connect(db_path if path is None else path)
    try:
        cursor = conn.cursor()
        # Drop the query_logs table.
        cursor.execute("DROP TABLE IF EXISTS query_logs;")
        # Use the module's logging setup instead of bare print() for
        # consistency with the rest of the file.
        logging.info("Dropped query_logs table")
        # Recreate the query_logs table with the original schema.
        cursor.execute('''CREATE TABLE query_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            natural_language_query TEXT,
            reformulated_query TEXT,
            generated_sql TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );''')
        logging.info("Recreated the query_logs table.")
        conn.commit()
        cursor.close()
    finally:
        # Close even if DROP/CREATE raises (original leaked the handle).
        conn.close()
# Uncomment the following line to reset the SQLite database when you run this script
# reset_sqlite_db()
def fetch_schema_info(path="schema.json"):
    """Load the schema description previously saved by fetch_and_save_schema().

    Args:
        path: JSON file to read; defaults to "schema.json" (the file
            fetch_and_save_schema writes), so existing callers are unchanged.

    Returns:
        dict: the parsed schema, or {} when the file is missing or invalid.
    """
    try:
        with open(path, "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info("Schema loaded from schema.json")
        return schema_info
    except Exception as e:
        logging.error(f"Error loading schema from schema.json: {e}")
        return {}