Spaces:
Sleeping
Sleeping
| import pymysql | |
| import sqlite3 | |
| import os | |
| from dotenv import load_dotenv | |
| from urllib.parse import urlparse, unquote | |
class Database:
    """Small connection helper for either a SQLite file or a MySQL server.

    The target is configured from the ``DB_URI`` environment variable
    (optionally loaded from a ``.env`` file). A URI whose scheme contains
    ``"sqlite"`` selects SQLite; any other scheme is treated as MySQL.
    Every public method opens a fresh connection and closes it before
    returning.
    """

    def __init__(self):
        """Read ``DB_URI`` (after loading ``.env``) and parse it.

        If ``DB_URI`` is unset or empty, fall back to the local SQLite file
        ``./demo.db``. This avoids a crash on a ``None`` URI, e.g. when the
        app runs on Hugging Face without configuration.

        Raises:
            ValueError: if the URI carries a non-numeric port (raised by
                ``urlparse(...).port``).
        """
        load_dotenv()
        db_uri = os.getenv("DB_URI")
        if not db_uri:
            print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
            db_uri = "sqlite:///./demo.db"
        self._configure_from_uri(db_uri)

    def _configure_from_uri(self, db_uri):
        """Set the connection attributes (``type``, path/host/credentials) from ``db_uri``."""
        self.db_uri = db_uri
        self.parsed = urlparse(db_uri)
        if "sqlite" in self.parsed.scheme:
            self.type = "sqlite"
            # The URI path component, e.g. "/./demo.db" for "sqlite:///./demo.db".
            self.db_path = self.parsed.path if self.parsed.path else "./demo.db"
            # Drop the leading "/" of "/./x" or "/../x" so the path stays relative.
            # Any other path is used unchanged (e.g. "sqlite:///x.db" -> "/x.db").
            if self.db_path.startswith("/."):
                self.db_path = self.db_path[1:]
        else:
            self.type = "mysql"
            self.host = self.parsed.hostname
            self.port = self.parsed.port or 3306
            # urlparse does not percent-decode credentials. Both are None when
            # absent, and unquote(None) would raise TypeError, so guard first.
            username = self.parsed.username
            password = self.parsed.password
            self.user = unquote(username) if username is not None else None
            self.password = unquote(password) if password is not None else None
            # Path is "/<dbname>"; an empty path gives "" (no database selected).
            self.db_name = unquote(self.parsed.path[1:])

    def _quote_identifier(self, name):
        """Return ``name`` quoted as an SQL identifier for the active back end.

        SQLite uses double quotes and MySQL uses backticks. An embedded quote
        character is escaped by doubling it. This lets names containing
        spaces, hyphens, reserved words or quote characters be used safely
        without breaking out of the statement. ``name`` must be a bare table
        name, not a schema-qualified one.
        """
        quote = '"' if self.type == "sqlite" else "`"
        return quote + str(name).replace(quote, quote * 2) + quote

    def get_connection(self):
        """Open and return a new connection; the caller is responsible for closing it.

        SQLite connections use ``sqlite3.Row`` so columns can be read by name.
        MySQL connections use ``DictCursor`` so rows come back as dicts.
        """
        if self.type == "sqlite":
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row  # Allows accessing columns by name
            return conn
        return pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.db_name,
            port=self.port,
            cursorclass=pymysql.cursors.DictCursor
        )

    def run_query(self, query):
        """Execute ``query`` and return the fetched rows as a list of dicts.

        Nothing is raised. Any failure, including failure to connect, is
        returned as a one-element list ``["Error: <message>"]``.

        NOTE(review): no commit is issued before the connection is closed, so
        INSERT/UPDATE/DELETE statements appear not to be persisted. Confirm
        that this read-only behaviour is intended.
        """
        try:
            conn = self.get_connection()
        except Exception as e:
            return [f"Error: {e}"]
        try:
            if self.type == "mysql":
                with conn.cursor() as cursor:
                    cursor.execute(query)
                    return cursor.fetchall()
            cursor = conn.cursor()
            cursor.execute(query)
            # Convert sqlite3.Row objects to plain dicts to match the MySQL format.
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            return [f"Error: {e}"]
        finally:
            conn.close()

    def get_tables(self):
        """Return a list of all table names (supports both SQLite & MySQL).

        Query errors are printed and yield ``[]``; connection errors propagate.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if self.type == "sqlite":
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [row[0] for row in cursor.fetchall()]
            cursor.execute("SHOW TABLES")
            # Each DictCursor row holds a single column; take its lone value.
            return [next(iter(row.values())) for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching tables: {e}")
            return []
        finally:
            conn.close()

    def get_table_schema(self, table_name):
        """Return ``["<column> (<type>)", ...]`` for ``table_name``.

        The name is quoted as an identifier, so it may contain spaces,
        reserved words or quote characters. Query errors are printed and
        yield ``[]``; connection errors propagate.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            quoted = self._quote_identifier(table_name)
            if self.type == "sqlite":
                cursor.execute(f"PRAGMA table_info({quoted})")
                columns = []
                # Row format: (cid, name, type, notnull, dflt_value, pk)
                for row in cursor.fetchall():
                    # Handle both tuple and Row object access
                    col_name = row['name'] if isinstance(row, sqlite3.Row) else row[1]
                    col_type = row['type'] if isinstance(row, sqlite3.Row) else row[2]
                    columns.append(f"{col_name} ({col_type})")
                return columns
            cursor.execute(f"DESCRIBE {quoted}")
            return [f"{row['Field']} ({row['Type']})" for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching schema for {table_name}: {e}")
            return []
        finally:
            conn.close()

    def get_schema(self):
        """Build a plain-text description of every table and its columns (for the AI prompt).

        Format per table: ``"Table: <name>\\nColumns:\\n"``, one ``" - <col>\\n"``
        line per column, then a blank line.
        """
        parts = []
        for table in self.get_tables():
            parts.append(f"Table: {table}\nColumns:\n")
            parts.extend(f" - {col}\n" for col in self.get_table_schema(table))
            parts.append("\n")
        return "".join(parts)