Spaces:
Runtime error
Runtime error
File size: 8,659 Bytes
909cddd 2cb3f69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
import os
import json
import psycopg2
from psycopg2 import pool
import sqlite3
from dotenv import load_dotenv
import logging
import csv
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Root logger: timestamped INFO-level messages for the whole module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# PostgreSQL connection pool, configured entirely from environment variables.
# NOTE(review): created at import time — a missing/bad DB config makes the module
# fail on import. Functions below borrow connections via getconn()/putconn().
db_pool = psycopg2.pool.SimpleConnectionPool(
    minconn=1,
    maxconn=5,
    user=os.getenv("db_user"),
    password=os.getenv("db_password"),
    host=os.getenv("db_host"),
    dbname=os.getenv("db_name"),
    port=os.getenv("db_port", 5432)  # default PostgreSQL port when unset
)

# Path of the local SQLite file used for query logging (overridable via DB_PATH).
db_path = os.getenv("DB_PATH", "./query_logs.db")
# PostgreSQL: Function to fetch schema and save to schema.json
def fetch_and_save_schema():
    """Introspect the PostgreSQL ``public`` schema and cache it in schema.json.

    Returns:
        dict: ``{table_name: {"comment", "columns", "foreign_keys"}}`` on success,
        or ``{"error": <message>}`` on failure (kept for backward compatibility).
    """
    conn = None
    try:
        logging.info("Fetching schema from the database...")
        conn = db_pool.getconn()
        cursor = conn.cursor()
        # All user tables in the public schema, with their COMMENT (may be NULL).
        cursor.execute("""
            SELECT table_name,
                   obj_description(('public.' || table_name)::regclass) AS table_comment
            FROM information_schema.tables
            WHERE table_schema = 'public';
        """)
        tables = cursor.fetchall()
        schema_info = {}
        for table_name, table_comment in tables:
            schema_info[table_name] = {
                "comment": table_comment,
                "columns": [],
                "foreign_keys": [],
            }
            # Column names/types/comments for this table. Parameterized (%s)
            # instead of f-string interpolation to rule out SQL injection.
            cursor.execute("""
                SELECT c.column_name,
                       c.data_type,
                       col_description(('public.' || c.table_name)::regclass, c.ordinal_position) AS column_comment
                FROM information_schema.columns c
                WHERE c.table_name = %s;
            """, (table_name,))
            for column_name, data_type, column_comment in cursor.fetchall():
                schema_info[table_name]["columns"].append({
                    "name": column_name,
                    "data_type": data_type,
                    "comment": column_comment,
                })
            # Outgoing foreign keys: local column -> referenced table/column.
            cursor.execute("""
                SELECT kcu.column_name,
                       ccu.table_name AS foreign_table_name,
                       ccu.column_name AS foreign_column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                  ON ccu.constraint_name = tc.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = %s;
            """, (table_name,))
            for column_name, foreign_table_name, foreign_column_name in cursor.fetchall():
                schema_info[table_name]["foreign_keys"].append({
                    "column": column_name,
                    "references": {
                        "table": foreign_table_name,
                        "column": foreign_column_name,
                    },
                })
        cursor.close()
        # Cache the schema so fetch_schema_info() can serve it without a DB trip.
        with open("schema.json", "w") as schema_file:
            json.dump(schema_info, schema_file, indent=2)
        logging.info("Schema fetched and saved to schema.json.")
        return schema_info
    except Exception as e:
        logging.error("Error fetching schema: %s", e)
        return {"error": str(e)}
    finally:
        # Always return the connection to the pool — the original leaked it
        # whenever any query raised before putconn().
        if conn is not None:
            db_pool.putconn(conn)
# PostgreSQL: Function to execute SQL query
def execute_sql_query(sql_query):
    """Execute *sql_query* against PostgreSQL.

    Returns:
        list: ``[column_names] + rows`` (suitable for a Gradio Dataframe),
        or the error message as a plain string on failure.
    """
    conn = None
    try:
        conn = db_pool.getconn()
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        # Header row comes from the cursor's result-set description.
        column_names = [desc[0] for desc in cursor.description]
        cursor.close()
        return [column_names] + result
    except Exception as e:
        logging.error("Error executing SQL query: %s", e)
        return str(e)
    finally:
        # Return the connection to the pool even when execution fails
        # (the original leaked it on any exception).
        if conn is not None:
            db_pool.putconn(conn)
# SQLite: Initialize the local SQLite database
def initialize_local_db(path=None):
    """Create the SQLite ``query_logs`` table if it does not exist yet.

    Args:
        path: Optional database file; defaults to the module-level ``db_path``.
    """
    conn = sqlite3.connect(db_path if path is None else path)
    try:
        # Idempotent: CREATE TABLE IF NOT EXISTS is safe to run on every start.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS query_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                natural_language_query TEXT,
                reformulated_query TEXT,
                generated_sql TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
        ''')
        conn.commit()
    finally:
        # Close even if table creation raises (original leaked the connection).
        conn.close()
# SQLite: Function to save the query to the local database
def save_query_to_local_db(nl_query, reformulated_query, sql_query, path=None):
    """Persist one NL->SQL interaction into the local ``query_logs`` table.

    Errors are logged and swallowed deliberately: best-effort logging must not
    break the caller.

    Args:
        nl_query: The user's natural-language question.
        reformulated_query: The reformulated version of the question.
        sql_query: The SQL that was generated for it.
        path: Optional database file; defaults to the module-level ``db_path``.
    """
    try:
        conn = sqlite3.connect(db_path if path is None else path)
        try:
            # Parameterized insert; created_at falls back to its column default.
            conn.execute(
                '''
                INSERT INTO query_logs (natural_language_query, reformulated_query, generated_sql)
                VALUES (?, ?, ?);
                ''',
                (nl_query, reformulated_query, sql_query),
            )
            conn.commit()
        finally:
            # Close even when the insert raises (original leaked the connection).
            conn.close()
    except Exception as e:
        logging.error("Error saving query: %s", e)
# SQLite: Function to get the last 50 saved queries
def get_last_50_saved_queries(path=None):
    """Return the 50 most recent logged queries, newest first.

    Args:
        path: Optional database file; defaults to the module-level ``db_path``.

    Returns:
        list[tuple]: ``(natural_language_query, reformulated_query, generated_sql,
        created_at)`` rows, or the error message string on failure (kept for
        backward compatibility).
    """
    try:
        conn = sqlite3.connect(db_path if path is None else path)
        try:
            return conn.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at "
                "FROM query_logs ORDER BY created_at DESC LIMIT 50;"
            ).fetchall()
        finally:
            # Close even when the SELECT raises (original leaked the connection).
            conn.close()
    except Exception as e:
        logging.error("Error retrieving saved queries: %s", e)
        return str(e)
# SQLite: Function to export saved queries to a CSV file
def export_saved_queries_to_csv(file_path="./saved_queries.csv", db_file=None):
    """Export every logged query to a CSV file.

    Args:
        file_path: Destination CSV path.
        db_file: Optional database file; defaults to the module-level ``db_path``.

    Returns:
        str: ``file_path`` on success, or the error message on failure
        (kept for backward compatibility).
    """
    try:
        conn = sqlite3.connect(db_path if db_file is None else db_file)
        try:
            rows = conn.execute(
                "SELECT natural_language_query, reformulated_query, generated_sql, created_at "
                "FROM query_logs ORDER BY created_at DESC;"
            ).fetchall()
        finally:
            # Release the DB before touching the filesystem; the original kept
            # the connection open across file I/O and leaked it on exception.
            conn.close()
        with open(file_path, 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Natural Language Query', 'Reformulated Query', 'Generated SQL', 'Timestamp'])
            csv_writer.writerows(rows)
        return file_path
    except Exception as e:
        logging.error("Error exporting queries to CSV: %s", e)
        return str(e)
def show_last_50_saved_queries():
    """Return the 50 most recent logged queries, newest first.

    Alias kept for the UI layer; the body was a line-for-line duplicate of
    ``get_last_50_saved_queries``, so it now delegates to the canonical function.
    """
    return get_last_50_saved_queries()
# Function to reset (drop all tables) and recreate the schema
def reset_sqlite_db(path=None):
    """Drop and recreate the ``query_logs`` table, erasing all saved queries.

    Args:
        path: Optional database file; defaults to the module-level ``db_path``.
    """
    conn = sqlite3.connect(db_path if path is None else path)
    try:
        cursor = conn.cursor()
        cursor.execute("DROP TABLE IF EXISTS query_logs;")
        # Use the module-wide logger instead of print() for consistent output.
        logging.info("Dropped query_logs table")
        cursor.execute('''CREATE TABLE query_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            natural_language_query TEXT,
            reformulated_query TEXT,
            generated_sql TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );''')
        logging.info("Recreated the query_logs table.")
        conn.commit()
        cursor.close()
    finally:
        # Close even if DROP/CREATE raises (original leaked the connection).
        conn.close()
# Uncomment the following line to reset the SQLite database when you run this script
# reset_sqlite_db()
def fetch_schema_info(path="schema.json"):
    """Load the cached schema description written by ``fetch_and_save_schema``.

    Args:
        path: Schema cache file to read (default matches the writer's output).

    Returns:
        dict: The parsed schema, or ``{}`` when the file is missing or unreadable.
    """
    try:
        with open(path, "r") as schema_file:
            schema_info = json.load(schema_file)
        logging.info("Schema loaded from schema.json")
        return schema_info
    except Exception as e:
        logging.error("Error loading schema from schema.json: %s", e)
        return {}