# PlainSQL-Agent / src/db_connector.py
# (Hugging Face file-view residue: commit d0fbfac "Update src/db_connector.py" by LalitChaudhari3, 5.03 kB)
import pymysql
import sqlite3
import os
from dotenv import load_dotenv
from urllib.parse import urlparse, unquote
class Database:
    """Small connection helper that talks to either a SQLite file or a MySQL server.

    The backend is selected from a database URI: any URI whose scheme
    contains ``sqlite`` uses the ``sqlite3`` module, every other URI is
    handed to ``pymysql``. Each public method opens its own connection and
    closes it before returning.

    Attributes set by ``__init__``:
        db_uri: The URI that was actually used (after defaulting).
        parsed: ``urllib.parse.urlparse`` result for ``db_uri``.
        type: ``"sqlite"`` or ``"mysql"``.
        db_path: (SQLite only) filesystem path of the database file.
        host, port, user, password, db_name: (MySQL only) connection settings.
    """

    def __init__(self, db_uri=None):
        """Resolve the database URI and derive the connection settings.

        Args:
            db_uri: Optional explicit URI. When omitted (the original
                behavior) ``.env`` is loaded and the ``DB_URI`` environment
                variable is read instead.
        """
        if db_uri is None:
            load_dotenv()
            db_uri = os.getenv("DB_URI")

        # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file
        # so the app still starts when no database has been configured.
        if not db_uri:
            print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
            db_uri = "sqlite:///./demo.db"

        self.db_uri = db_uri
        self.parsed = urlparse(self.db_uri)

        if "sqlite" in self.parsed.scheme:
            self.type = "sqlite"
            # urlparse keeps the slash that separates the (empty) host from
            # the path, so drop exactly one:
            #   sqlite:///./demo.db   -> ./demo.db   (relative)
            #   sqlite:///demo.db     -> demo.db     (relative)
            #   sqlite:////data/x.db  -> /data/x.db  (absolute)
            file_path = self.parsed.path
            if file_path.startswith("/"):
                file_path = file_path[1:]
            self.db_path = file_path or "./demo.db"
        else:
            self.type = "mysql"
            self.host = self.parsed.hostname
            self.port = self.parsed.port or 3306
            # Credentials are percent-decoded; both parts are optional in a
            # URI, and unquote() raises TypeError when handed None.
            raw_user = self.parsed.username
            raw_password = self.parsed.password
            self.user = unquote(raw_user) if raw_user is not None else None
            self.password = unquote(raw_password) if raw_password is not None else ""
            self.db_name = self.parsed.path[1:]

    def _quote_identifier(self, name):
        """Return *name* quoted as a single identifier for the active backend.

        SQLite uses double quotes and MySQL uses backticks; an embedded quote
        character is escaped by doubling it. The whole string is treated as
        one identifier, so a dotted ``schema.table`` value is not split.
        """
        text = str(name)
        if self.type == "sqlite":
            return '"' + text.replace('"', '""') + '"'
        return "`" + text.replace("`", "``") + "`"

    def get_connection(self):
        """Open and return a new connection for the configured backend.

        SQLite connections use ``sqlite3.Row`` rows (columns addressable by
        name) and are opened with ``check_same_thread=False``. MySQL
        connections use ``pymysql.cursors.DictCursor``. The caller is
        responsible for closing the connection.
        """
        if self.type == "sqlite":
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            return conn

        return pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.db_name,
            port=self.port,
            cursorclass=pymysql.cursors.DictCursor,
        )

    def run_query(self, query):
        """Execute *query* and return every fetched row.

        Returns:
            For SQLite, a list of dicts (one per row). For MySQL, whatever
            ``cursor.fetchall()`` returns under ``DictCursor``. If executing
            or fetching raises, a one-element list ``["Error: <message>"]``
            is returned instead of raising. Errors from opening the
            connection are not caught here and propagate to the caller.

        NOTE(review): no commit is issued, so write statements are presumably
        not persisted; confirm that read-only use is intended.
        """
        conn = self.get_connection()
        try:
            if self.type == "mysql":
                with conn.cursor() as cursor:
                    cursor.execute(query)
                    return cursor.fetchall()

            cursor = conn.cursor()
            cursor.execute(query)
            # Convert sqlite3.Row objects to plain dicts to match the MySQL shape.
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            return [f"Error: {e}"]
        finally:
            conn.close()

    def get_tables(self):
        """Return a list of all table names (supports both SQLite & MySQL).

        On a query error the message is printed and an empty list is returned.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if self.type == "sqlite":
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                # sqlite3.Row supports positional access as well as by name.
                return [row[0] for row in cursor.fetchall()]

            cursor.execute("SHOW TABLES")
            # DictCursor rows have a single column keyed by a header that
            # depends on the database name, so take the first value.
            return [next(iter(row.values())) for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching tables: {e}")
            return []
        finally:
            conn.close()

    def get_table_schema(self, table_name):
        """Return column descriptions for *table_name* as ``"name (type)"`` strings.

        The table name is quoted as an identifier before being placed in the
        statement, so names containing spaces or reserved words work and the
        value cannot break out of the identifier position. On a query error
        the message is printed and an empty list is returned.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            quoted = self._quote_identifier(table_name)
            if self.type == "sqlite":
                # Row layout: (cid, name, type, notnull, dflt_value, pk)
                cursor.execute(f"PRAGMA table_info({quoted})")
                return [f"{row['name']} ({row['type']})" for row in cursor.fetchall()]

            cursor.execute(f"DESCRIBE {quoted}")
            return [f"{row['Field']} ({row['Type']})" for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching schema for {table_name}: {e}")
            return []
        finally:
            conn.close()

    def get_schema(self):
        """Generate a full text schema of the database for the AI.

        Each table contributes a ``Table:`` line, a ``Columns:`` line, one
        `` - name (type)`` line per column and a trailing blank line.
        """
        sections = []
        for table in self.get_tables():
            column_lines = "".join(
                f" - {col}\n" for col in self.get_table_schema(table)
            )
            sections.append(f"Table: {table}\nColumns:\n{column_lines}\n")
        return "".join(sections)