import pymysql
import sqlite3
import os
from dotenv import load_dotenv
from urllib.parse import urlparse, unquote


class Database:
    """Connection helper for either a SQLite file or a MySQL server.

    The backend is chosen from the ``DB_URI`` environment variable (after
    ``load_dotenv()``). A URI whose scheme contains ``"sqlite"`` selects
    SQLite; any other URI is treated as MySQL. When ``DB_URI`` is unset,
    the class falls back to the local SQLite file ``./demo.db``.
    """

    def __init__(self):
        load_dotenv()
        self.db_uri = os.getenv("DB_URI")

        # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file.
        # This prevents the "NoneType is not iterable" crash on Hugging Face.
        if not self.db_uri:
            print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
            self.db_uri = "sqlite:///./demo.db"

        self.parsed = urlparse(self.db_uri)

        # Determine database type
        if "sqlite" in self.parsed.scheme:
            self.type = "sqlite"
            self.db_path = self._sqlite_path(self.parsed.path)
        else:
            self.type = "mysql"
            self.host = self.parsed.hostname
            self.port = self.parsed.port or 3306
            # urlparse() leaves credentials percent-encoded, so decode them here.
            # A URI without credentials yields None; unquote(None) would raise,
            # so a missing username stays None and a missing password becomes "".
            username = self.parsed.username
            self.user = unquote(username) if username is not None else None
            self.password = unquote(self.parsed.password or "")
            self.db_name = self.parsed.path[1:]

    @staticmethod
    def _sqlite_path(url_path):
        """Convert the path component of a ``sqlite:///...`` URI to a file path.

        The three slashes after ``sqlite:`` are the URI separator, so exactly
        one leading "/" is stripped from the parsed path:

        - ``sqlite:///./demo.db``   -> ``./demo.db``
        - ``sqlite:///demo.db``     -> ``demo.db`` (relative)
        - ``sqlite:////data/x.db``  -> ``/data/x.db`` (absolute)

        An empty path falls back to ``./demo.db``.
        """
        path = url_path[1:] if url_path.startswith("/") else url_path
        return path or "./demo.db"

    def _quote_identifier(self, name):
        """Quote a table name for interpolation into PRAGMA / DESCRIBE.

        Identifiers cannot be passed as bound parameters, so the name is
        quoted instead. SQLite uses double quotes and MySQL uses backticks.
        Any embedded quote character is doubled.
        """
        name = str(name)
        if self.type == "sqlite":
            return '"' + name.replace('"', '""') + '"'
        return "`" + name.replace("`", "``") + "`"

    def get_connection(self):
        """Open and return a new connection for the configured backend.

        SQLite connections use ``sqlite3.Row`` rows and ``check_same_thread=False``.
        MySQL connections use ``pymysql.cursors.DictCursor``, so rows are dicts.
        The caller is responsible for closing the connection.
        """
        if self.type == "sqlite":
            # Connect to SQLite file
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row  # Allows accessing columns by name
            return conn
        # Connect to MySQL server
        return pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.db_name,
            port=self.port,
            cursorclass=pymysql.cursors.DictCursor,
        )

    def run_query(self, query):
        """Execute ``query`` and return its rows as a list of dicts.

        Any exception raised while executing the query is NOT propagated.
        Instead, a one-element list ``["Error: <message>"]`` is returned.
        Errors from opening the connection itself do propagate.

        NOTE(review): no commit() is issued before the connection is closed,
        so data-modifying statements are presumably discarded. Confirm that
        this read-only behaviour is intended.
        """
        conn = self.get_connection()
        try:
            # MySQL logic
            if self.type == "mysql":
                with conn.cursor() as cursor:
                    cursor.execute(query)
                    return cursor.fetchall()
            # SQLite logic
            cursor = conn.cursor()
            cursor.execute(query)
            # Convert SQLite rows to a list of dicts to match the MySQL format
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            return [f"Error: {e}"]
        finally:
            conn.close()

    def get_tables(self):
        """Returns a list of all table names (supports both SQLite & MySQL).

        On error the exception is printed and an empty list is returned.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if self.type == "sqlite":
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                # sqlite3.Row supports positional access, so row[0] is the name.
                return [row[0] for row in cursor.fetchall()]
            cursor.execute("SHOW TABLES")
            # DictCursor rows are dicts; take the first column's value without
            # depending on its key name.
            return [list(row.values())[0] for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching tables: {e}")
            return []
        finally:
            conn.close()

    def get_table_schema(self, table_name):
        """Returns column details for a specific table.

        Each entry has the form ``"<column name> (<column type>)"``. On error
        the exception is printed and an empty list is returned. The table
        name is quoted as an identifier, so names containing spaces or
        reserved words work and cannot inject extra SQL.
        """
        conn = self.get_connection()
        columns = []
        try:
            cursor = conn.cursor()
            quoted = self._quote_identifier(table_name)
            if self.type == "sqlite":
                # SQLite schema query.
                # Row format: (cid, name, type, notnull, dflt_value, pk).
                # sqlite3.Row and plain tuples both support positional access.
                cursor.execute(f"PRAGMA table_info({quoted})")
                for row in cursor.fetchall():
                    columns.append(f"{row[1]} ({row[2]})")
            else:
                # MySQL schema query (DictCursor rows keyed by column header)
                cursor.execute(f"DESCRIBE {quoted}")
                for row in cursor.fetchall():
                    columns.append(f"{row['Field']} ({row['Type']})")
            return columns
        except Exception as e:
            print(f"Error fetching schema for {table_name}: {e}")
            return []
        finally:
            conn.close()

    def get_schema(self):
        """Generates a full text schema of the database for the AI.

        Format per table::

            Table: <name>
            Columns:
             - <column> (<type>)

        Each table section is followed by a blank line.
        """
        parts = []
        for table in self.get_tables():
            parts.append(f"Table: {table}\nColumns:\n")
            parts.extend(f" - {col}\n" for col in self.get_table_schema(table))
            parts.append("\n")
        return "".join(parts)