Spaces:
Sleeping
Sleeping
File size: 5,034 Bytes
d0fbfac 8642c86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import pymysql
import sqlite3
import os
from dotenv import load_dotenv
from urllib.parse import urlparse, unquote
class Database:
    """Small connection helper over either a SQLite file or a MySQL server.

    The backend is chosen from the ``DB_URI`` environment variable (loaded via
    ``load_dotenv``). A URI whose scheme contains ``"sqlite"`` selects SQLite;
    any other URI is treated as MySQL, e.g.
    ``mysql://user:password@host:port/dbname``. Every public method opens a
    fresh connection and closes it before returning.
    """

    def __init__(self):
        load_dotenv()
        self.db_uri = os.getenv("DB_URI")
        # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file
        # This prevents the "NoneType is not iterable" crash on Hugging Face
        if not self.db_uri:
            print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
            self.db_uri = "sqlite:///./demo.db"
        self._parse_uri(self.db_uri)

    def _parse_uri(self, uri):
        """Set ``type`` and the backend-specific connection attributes from *uri*."""
        self.parsed = urlparse(uri)
        if "sqlite" in self.parsed.scheme:
            self.type = "sqlite"
            # 'sqlite:///./demo.db' parses to path '/./demo.db'; fall back to the
            # demo file when the URI carries no path at all.
            self.db_path = self.parsed.path if self.parsed.path else "./demo.db"
            # Drop the leading '/' of '/./...' so relative paths stay relative.
            if self.db_path.startswith("/."):
                self.db_path = self.db_path[1:]
        else:
            self.type = "mysql"
            self.host = self.parsed.hostname
            self.port = self.parsed.port or 3306
            # urlparse leaves credentials percent-encoded and returns None when
            # they are absent; unquote(None) would raise TypeError.
            username = self.parsed.username
            self.user = unquote(username) if username is not None else None
            password = self.parsed.password
            # "" matches PyMySQL's own default for "no password".
            self.password = unquote(password) if password is not None else ""
            self.db_name = self.parsed.path[1:]

    def get_connection(self):
        """Open and return a new connection for the configured backend.

        SQLite connections use ``sqlite3.Row`` rows; MySQL connections use
        ``DictCursor`` so rows come back as dicts. The caller must close it.
        """
        if self.type == "sqlite":
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row  # Allows accessing columns by name
            return conn
        return pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.db_name,
            port=self.port,
            cursorclass=pymysql.cursors.DictCursor,
        )

    def run_query(self, query):
        """Execute *query* and return its rows as a list of column->value dicts.

        Statements that produce no rows return ``[]``. Any exception is caught
        and reported as a one-element list ``["Error: <message>"]`` instead of
        being raised.

        NOTE(review): no commit is issued, so with the sqlite3/PyMySQL default
        (non-autocommit) transaction settings, INSERT/UPDATE/DELETE effects are
        rolled back when the connection closes — confirm this is intended.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(query)
                # Normalise both backends (sqlite3.Row / DictCursor dicts, and
                # PyMySQL's () for row-less statements) to a list of dicts.
                return [dict(row) for row in cursor.fetchall()]
            finally:
                cursor.close()
        except Exception as e:
            return [f"Error: {e}"]
        finally:
            conn.close()

    def get_tables(self):
        """Return the user table names (supports both SQLite & MySQL).

        SQLite's reserved internal tables (names starting with ``sqlite_``,
        e.g. ``sqlite_sequence``) are excluded. On error, the message is printed
        and ``[]`` is returned.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if self.type == "sqlite":
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [
                    row[0]
                    for row in cursor.fetchall()
                    if not row[0].startswith("sqlite_")
                ]
            cursor.execute("SHOW TABLES")
            # Each DictCursor row has a single key such as "Tables_in_<db>".
            return [next(iter(row.values())) for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching tables: {e}")
            return []
        finally:
            conn.close()

    def _quote_identifier(self, name):
        """Return *name* quoted as a SQL identifier for the active backend."""
        if self.type == "sqlite":
            return '"' + str(name).replace('"', '""') + '"'
        return "`" + str(name).replace("`", "``") + "`"

    def get_table_schema(self, table_name):
        """Return ``["<column> (<type>)", ...]`` for *table_name*.

        The name is quoted as an identifier, so names with spaces or reserved
        words work and cannot inject SQL. On error, the message is printed and
        ``[]`` is returned.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            quoted = self._quote_identifier(table_name)
            if self.type == "sqlite":
                cursor.execute(f"PRAGMA table_info({quoted})")
                # Row layout: (cid, name, type, notnull, dflt_value, pk);
                # index access works for both tuples and sqlite3.Row.
                return [f"{row[1]} ({row[2]})" for row in cursor.fetchall()]
            cursor.execute(f"DESCRIBE {quoted}")
            return [f"{row['Field']} ({row['Type']})" for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching schema for {table_name}: {e}")
            return []
        finally:
            conn.close()

    def get_schema(self):
        """Generate a full text schema of the database for the AI.

        Format per table: ``"Table: <name>\\nColumns:\\n"``, then one
        ``" - <column> (<type>)\\n"`` line per column, then a blank line.
        """
        sections = []
        for table in self.get_tables():
            lines = [f"Table: {table}\n", "Columns:\n"]
            lines.extend(f" - {col}\n" for col in self.get_table_schema(table))
            lines.append("\n")
            sections.append("".join(lines))
        return "".join(sections)