LalitChaudhari3 commited on
Commit
d0fbfac
·
verified ·
1 Parent(s): fe7157e

Update src/db_connector.py

Browse files
Files changed (1) hide show
  1. src/db_connector.py +133 -83
src/db_connector.py CHANGED
@@ -1,84 +1,134 @@
1
- import pymysql
2
- import os
3
- from dotenv import load_dotenv
4
- from urllib.parse import urlparse, unquote
5
-
6
- # ✅ FIX 1: Class renamed to 'Database' (matches api_server.py)
7
- class Database:
8
- def __init__(self):
9
- load_dotenv()
10
- db_uri = os.getenv("DB_URI")
11
-
12
- if not db_uri:
13
- raise ValueError(" DB_URI is missing from .env file")
14
-
15
- parsed = urlparse(db_uri)
16
- self.host = parsed.hostname
17
- self.port = parsed.port or 3306
18
- self.user = parsed.username
19
- self.password = unquote(parsed.password)
20
- self.db_name = parsed.path[1:]
21
-
22
- def get_connection(self):
23
- return pymysql.connect(
24
- host=self.host,
25
- user=self.user,
26
- password=self.password,
27
- database=self.db_name,
28
- port=self.port,
29
- cursorclass=pymysql.cursors.DictCursor
30
- )
31
-
32
- # FIX 2: Method renamed to 'run_query' (matches api_server.py)
33
- def run_query(self, query):
34
- conn = self.get_connection()
35
- try:
36
- with conn.cursor() as cursor:
37
- cursor.execute(query)
38
- return cursor.fetchall()
39
- except Exception as e:
40
- return [f"Error: {e}"]
41
- finally:
42
- conn.close()
43
-
44
- def get_tables(self):
45
- """Returns a list of all table names."""
46
- conn = self.get_connection()
47
- try:
48
- with conn.cursor() as cursor:
49
- cursor.execute("SHOW TABLES")
50
- return [list(row.values())[0] for row in cursor.fetchall()]
51
- except Exception as e:
52
- return []
53
- finally:
54
- conn.close()
55
-
56
- def get_table_schema(self, table_name):
57
- """Returns column details for a specific table."""
58
- conn = self.get_connection()
59
- try:
60
- with conn.cursor() as cursor:
61
- cursor.execute(f"DESCRIBE {table_name}")
62
- columns = []
63
- for row in cursor.fetchall():
64
- columns.append(f"{row['Field']} ({row['Type']})")
65
- return columns
66
- except Exception:
67
- return []
68
- finally:
69
- conn.close()
70
-
71
- # ✅ FIX 3: Added 'get_schema()' (no args) for the RAG system
72
- def get_schema(self):
73
- """Generates a full text schema of the database for the AI."""
74
- tables = self.get_tables()
75
- schema_text = ""
76
-
77
- for table in tables:
78
- columns = self.get_table_schema(table)
79
- schema_text += f"Table: {table}\nColumns:\n"
80
- for col in columns:
81
- schema_text += f" - {col}\n"
82
- schema_text += "\n"
83
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return schema_text
 
1
+ import pymysql
2
+ import sqlite3
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from urllib.parse import urlparse, unquote
6
+
7
+ class Database:
8
+ def __init__(self):
9
+ load_dotenv()
10
+ self.db_uri = os.getenv("DB_URI")
11
+
12
+ # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file
13
+ # This prevents the "NoneType is not iterable" crash on Hugging Face
14
+ if not self.db_uri:
15
+ print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
16
+ self.db_uri = "sqlite:///./demo.db"
17
+
18
+ self.parsed = urlparse(self.db_uri)
19
+
20
+ # Determine Database Type
21
+ if "sqlite" in self.parsed.scheme:
22
+ self.type = "sqlite"
23
+ # Extract path (remove 'sqlite:///')
24
+ self.db_path = self.parsed.path if self.parsed.path else "./demo.db"
25
+ # Fix absolute paths if needed
26
+ if self.db_path.startswith("/."): self.db_path = self.db_path[1:]
27
+ else:
28
+ self.type = "mysql"
29
+ self.host = self.parsed.hostname
30
+ self.port = self.parsed.port or 3306
31
+ self.user = self.parsed.username
32
+ self.password = unquote(self.parsed.password)
33
+ self.db_name = self.parsed.path[1:]
34
+
35
+ def get_connection(self):
36
+ if self.type == "sqlite":
37
+ # Connect to SQLite File
38
+ conn = sqlite3.connect(self.db_path, check_same_thread=False)
39
+ conn.row_factory = sqlite3.Row # Allows accessing columns by name
40
+ return conn
41
+ else:
42
+ # Connect to MySQL Server
43
+ return pymysql.connect(
44
+ host=self.host,
45
+ user=self.user,
46
+ password=self.password,
47
+ database=self.db_name,
48
+ port=self.port,
49
+ cursorclass=pymysql.cursors.DictCursor
50
+ )
51
+
52
+ def run_query(self, query):
53
+ conn = self.get_connection()
54
+ try:
55
+ # MySQL Logic
56
+ if self.type == "mysql":
57
+ with conn.cursor() as cursor:
58
+ cursor.execute(query)
59
+ return cursor.fetchall()
60
+
61
+ # SQLite Logic
62
+ else:
63
+ cursor = conn.cursor()
64
+ cursor.execute(query)
65
+ # Convert SQLite rows to list of dicts to match MySQL format
66
+ items = [dict(row) for row in cursor.fetchall()]
67
+ return items
68
+
69
+ except Exception as e:
70
+ return [f"Error: {e}"]
71
+ finally:
72
+ conn.close()
73
+
74
+ def get_tables(self):
75
+ """Returns a list of all table names (supports both SQLite & MySQL)."""
76
+ conn = self.get_connection()
77
+ try:
78
+ cursor = conn.cursor()
79
+ if self.type == "sqlite":
80
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
81
+ return [row[0] for row in cursor.fetchall()] # row[0] for standard cursor, row['name'] for Row factory
82
+ else:
83
+ cursor.execute("SHOW TABLES")
84
+ return [list(row.values())[0] for row in cursor.fetchall()]
85
+ except Exception as e:
86
+ print(f"Error fetching tables: {e}")
87
+ return []
88
+ finally:
89
+ conn.close()
90
+
91
+ def get_table_schema(self, table_name):
92
+ """Returns column details for a specific table."""
93
+ conn = self.get_connection()
94
+ columns = []
95
+ try:
96
+ cursor = conn.cursor()
97
+
98
+ if self.type == "sqlite":
99
+ # SQLite Schema Query
100
+ cursor.execute(f"PRAGMA table_info({table_name})")
101
+ rows = cursor.fetchall()
102
+ # Row format: (cid, name, type, notnull, dflt_value, pk)
103
+ for row in rows:
104
+ # Handle both tuple and Row object access
105
+ col_name = row['name'] if isinstance(row, sqlite3.Row) else row[1]
106
+ col_type = row['type'] if isinstance(row, sqlite3.Row) else row[2]
107
+ columns.append(f"{col_name} ({col_type})")
108
+ else:
109
+ # MySQL Schema Query
110
+ cursor.execute(f"DESCRIBE {table_name}")
111
+ rows = cursor.fetchall()
112
+ for row in rows:
113
+ columns.append(f"{row['Field']} ({row['Type']})")
114
+
115
+ return columns
116
+ except Exception as e:
117
+ print(f"Error fetching schema for {table_name}: {e}")
118
+ return []
119
+ finally:
120
+ conn.close()
121
+
122
+ def get_schema(self):
123
+ """Generates a full text schema of the database for the AI."""
124
+ tables = self.get_tables()
125
+ schema_text = ""
126
+
127
+ for table in tables:
128
+ columns = self.get_table_schema(table)
129
+ schema_text += f"Table: {table}\nColumns:\n"
130
+ for col in columns:
131
+ schema_text += f" - {col}\n"
132
+ schema_text += "\n"
133
+
134
  return schema_text