File size: 5,034 Bytes
d0fbfac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8642c86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import pymysql
import sqlite3
import os
from dotenv import load_dotenv
from urllib.parse import urlparse, unquote

class Database:
    """Thin connection/query helper for a SQLite file or a MySQL server.

    The backend is chosen from the ``DB_URI`` environment variable (loaded via
    ``load_dotenv``). A URI whose scheme contains ``"sqlite"`` selects SQLite;
    every other scheme is treated as MySQL and connected through ``pymysql``.
    When ``DB_URI`` is unset or empty, ``sqlite:///./demo.db`` is used.

    Every public method opens a fresh connection and closes it before returning.
    """

    def __init__(self):
        load_dotenv()
        self.db_uri = os.getenv("DB_URI")

        # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file
        # This prevents the "NoneType is not iterable" crash on Hugging Face
        if not self.db_uri:
            print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
            self.db_uri = "sqlite:///./demo.db"

        self._parse_uri()

    def _parse_uri(self):
        """Populate the connection attributes from ``self.db_uri``.

        Always sets ``parsed`` and ``type``; additionally sets ``db_path`` for
        SQLite, or ``host``, ``port``, ``user``, ``password`` and ``db_name``
        for MySQL.
        """
        self.parsed = urlparse(self.db_uri)

        if "sqlite" in self.parsed.scheme:
            self.type = "sqlite"
            # SQLAlchemy-style URIs: the path is everything after "sqlite:///", so
            # "sqlite:///demo.db" is relative and "sqlite:////abs/demo.db" is
            # absolute. urlparse keeps that third slash on .path; drop exactly one.
            path = self.parsed.path
            if path.startswith("/"):
                path = path[1:]
            self.db_path = path or "./demo.db"
        else:
            self.type = "mysql"
            self.host = self.parsed.hostname
            self.port = self.parsed.port or 3306
            # urlparse does not percent-decode credentials, and reports them as
            # None when absent (unquote(None) would raise TypeError). An absent
            # password becomes "", which is pymysql's own default.
            username = self.parsed.username
            password = self.parsed.password
            self.user = unquote(username) if username is not None else None
            self.password = unquote(password) if password is not None else ""
            self.db_name = unquote(self.parsed.path[1:])

    def get_connection(self):
        """Open and return a new connection for the configured backend.

        SQLite connections use ``sqlite3.Row`` rows (index or name access);
        MySQL connections use ``DictCursor`` so rows come back as dicts.
        The caller is responsible for closing the connection.
        """
        if self.type == "sqlite":
            # Connect to SQLite File
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row  # Allows accessing columns by name
            return conn
        # Connect to MySQL Server
        return pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            database=self.db_name,
            port=self.port,
            cursorclass=pymysql.cursors.DictCursor,
        )

    def run_query(self, query):
        """Execute *query* and return all fetched rows as dicts.

        On any exception raised while executing or fetching, returns a
        one-element list ``["Error: <message>"]`` instead of raising.
        Connection failures in ``get_connection`` are not caught.

        NOTE(review): no ``commit()`` is issued before the connection is closed,
        so INSERT/UPDATE/DELETE changes are presumably discarded — confirm this
        read-only behaviour is intended.
        """
        conn = self.get_connection()
        try:
            # MySQL Logic
            if self.type == "mysql":
                with conn.cursor() as cursor:
                    cursor.execute(query)
                    return cursor.fetchall()

            # SQLite Logic: convert sqlite3.Row objects to dicts to match MySQL
            cursor = conn.cursor()
            cursor.execute(query)
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            return [f"Error: {e}"]
        finally:
            conn.close()

    def _quote_identifier(self, name):
        """Quote *name* as an SQL identifier for the active backend.

        SQLite uses double quotes and MySQL uses backticks; an embedded quote
        character is escaped by doubling it, so reserved words (e.g. ``order``)
        and names containing spaces or punctuation are handled safely.
        """
        if self.type == "sqlite":
            return '"' + name.replace('"', '""') + '"'
        return "`" + name.replace("`", "``") + "`"

    def get_tables(self):
        """Returns a list of all table names (supports both SQLite & MySQL).

        On SQLite, internal tables whose names start with ``sqlite_`` (reserved
        by SQLite, e.g. ``sqlite_sequence``) are excluded. On error, prints a
        message and returns an empty list.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if self.type == "sqlite":
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [
                    row[0]
                    for row in cursor.fetchall()
                    if not row[0].startswith("sqlite_")
                ]
            cursor.execute("SHOW TABLES")
            # DictCursor rows have a single, server-named key; take its value.
            return [next(iter(row.values())) for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error fetching tables: {e}")
            return []
        finally:
            conn.close()

    def get_table_schema(self, table_name):
        """Returns column details for a specific table.

        Each entry is formatted as ``"<column> (<type>)"``. On error, prints a
        message and returns an empty list.
        """
        conn = self.get_connection()
        columns = []
        try:
            quoted = self._quote_identifier(table_name)
            cursor = conn.cursor()

            if self.type == "sqlite":
                # Row format: (cid, name, type, notnull, dflt_value, pk);
                # sqlite3.Row supports positional indexing as well as names.
                cursor.execute(f"PRAGMA table_info({quoted})")
                for row in cursor.fetchall():
                    columns.append(f"{row[1]} ({row[2]})")
            else:
                # MySQL Schema Query
                cursor.execute(f"DESCRIBE {quoted}")
                for row in cursor.fetchall():
                    columns.append(f"{row['Field']} ({row['Type']})")

            return columns
        except Exception as e:
            print(f"Error fetching schema for {table_name}: {e}")
            return []
        finally:
            conn.close()

    def get_schema(self):
        """Generates a full text schema of the database for the AI.

        For each table the output is ``"Table: <name>\\nColumns:\\n"`` followed
        by one ``"  - <column> (<type>)\\n"`` line per column and a blank line.
        Returns an empty string when there are no tables.
        """
        parts = []
        for table in self.get_tables():
            parts.append(f"Table: {table}\nColumns:\n")
            parts.extend(f"  - {col}\n" for col in self.get_table_schema(table))
            parts.append("\n")
        return "".join(parts)