LalitChaudhari3 commited on
Commit
bfc9e9b
·
verified ·
1 Parent(s): 77ad74c

Update src/db_connector.py

Browse files
Files changed (1) hide show
  1. src/db_connector.py +7 -27
src/db_connector.py CHANGED
@@ -9,20 +9,15 @@ class Database:
9
  load_dotenv()
10
  self.db_uri = os.getenv("DB_URI")
11
 
12
- # 🛡️ SAFETY FIX: If DB_URI is missing, default to a local SQLite demo file
13
- # This prevents the "NoneType is not iterable" crash on Hugging Face
14
  if not self.db_uri:
15
  print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
16
  self.db_uri = "sqlite:///./demo.db"
17
 
18
  self.parsed = urlparse(self.db_uri)
19
 
20
- # Determine Database Type
21
  if "sqlite" in self.parsed.scheme:
22
  self.type = "sqlite"
23
- # Extract path (remove 'sqlite:///')
24
  self.db_path = self.parsed.path if self.parsed.path else "./demo.db"
25
- # Fix absolute paths if needed
26
  if self.db_path.startswith("/."): self.db_path = self.db_path[1:]
27
  else:
28
  self.type = "mysql"
@@ -34,12 +29,10 @@ class Database:
34
 
35
  def get_connection(self):
36
  if self.type == "sqlite":
37
- # Connect to SQLite File
38
  conn = sqlite3.connect(self.db_path, check_same_thread=False)
39
- conn.row_factory = sqlite3.Row # Allows accessing columns by name
40
  return conn
41
  else:
42
- # Connect to MySQL Server
43
  return pymysql.connect(
44
  host=self.host,
45
  user=self.user,
@@ -52,36 +45,34 @@ class Database:
52
  def run_query(self, query):
53
  conn = self.get_connection()
54
  try:
55
- # MySQL Logic
56
  if self.type == "mysql":
57
  with conn.cursor() as cursor:
58
  cursor.execute(query)
59
  return cursor.fetchall()
60
-
61
- # SQLite Logic
62
  else:
63
  cursor = conn.cursor()
64
  cursor.execute(query)
65
- # Convert SQLite rows to list of dicts to match MySQL format
66
  items = [dict(row) for row in cursor.fetchall()]
67
  return items
68
-
69
  except Exception as e:
70
  return [f"Error: {e}"]
71
  finally:
72
  conn.close()
73
 
74
  def get_tables(self):
75
- """Returns a list of all table names (supports both SQLite & MySQL)."""
76
  conn = self.get_connection()
77
  try:
78
  cursor = conn.cursor()
79
  if self.type == "sqlite":
80
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
81
- return [row[0] for row in cursor.fetchall()] # row[0] for standard cursor, row['name'] for Row factory
82
  else:
83
  cursor.execute("SHOW TABLES")
84
- return [list(row.values())[0] for row in cursor.fetchall()]
 
 
 
 
85
  except Exception as e:
86
  print(f"Error fetching tables: {e}")
87
  return []
@@ -89,46 +80,35 @@ class Database:
89
  conn.close()
90
 
91
  def get_table_schema(self, table_name):
92
- """Returns column details for a specific table."""
93
  conn = self.get_connection()
94
  columns = []
95
  try:
96
  cursor = conn.cursor()
97
-
98
  if self.type == "sqlite":
99
- # SQLite Schema Query
100
  cursor.execute(f"PRAGMA table_info({table_name})")
101
  rows = cursor.fetchall()
102
- # Row format: (cid, name, type, notnull, dflt_value, pk)
103
  for row in rows:
104
- # Handle both tuple and Row object access
105
  col_name = row['name'] if isinstance(row, sqlite3.Row) else row[1]
106
  col_type = row['type'] if isinstance(row, sqlite3.Row) else row[2]
107
  columns.append(f"{col_name} ({col_type})")
108
  else:
109
- # MySQL Schema Query
110
  cursor.execute(f"DESCRIBE {table_name}")
111
  rows = cursor.fetchall()
112
  for row in rows:
113
  columns.append(f"{row['Field']} ({row['Type']})")
114
-
115
  return columns
116
  except Exception as e:
117
- print(f"Error fetching schema for {table_name}: {e}")
118
  return []
119
  finally:
120
  conn.close()
121
 
122
  def get_schema(self):
123
- """Generates a full text schema of the database for the AI."""
124
  tables = self.get_tables()
125
  schema_text = ""
126
-
127
  for table in tables:
128
  columns = self.get_table_schema(table)
129
  schema_text += f"Table: {table}\nColumns:\n"
130
  for col in columns:
131
  schema_text += f" - {col}\n"
132
  schema_text += "\n"
133
-
134
  return schema_text
 
9
  load_dotenv()
10
  self.db_uri = os.getenv("DB_URI")
11
 
 
 
12
  if not self.db_uri:
13
  print("⚠️ WARNING: DB_URI not found. Defaulting to 'sqlite:///./demo.db'")
14
  self.db_uri = "sqlite:///./demo.db"
15
 
16
  self.parsed = urlparse(self.db_uri)
17
 
 
18
  if "sqlite" in self.parsed.scheme:
19
  self.type = "sqlite"
 
20
  self.db_path = self.parsed.path if self.parsed.path else "./demo.db"
 
21
  if self.db_path.startswith("/."): self.db_path = self.db_path[1:]
22
  else:
23
  self.type = "mysql"
 
29
 
30
  def get_connection(self):
31
  if self.type == "sqlite":
 
32
  conn = sqlite3.connect(self.db_path, check_same_thread=False)
33
+ conn.row_factory = sqlite3.Row
34
  return conn
35
  else:
 
36
  return pymysql.connect(
37
  host=self.host,
38
  user=self.user,
 
45
  def run_query(self, query):
46
  conn = self.get_connection()
47
  try:
 
48
  if self.type == "mysql":
49
  with conn.cursor() as cursor:
50
  cursor.execute(query)
51
  return cursor.fetchall()
 
 
52
  else:
53
  cursor = conn.cursor()
54
  cursor.execute(query)
 
55
  items = [dict(row) for row in cursor.fetchall()]
56
  return items
 
57
  except Exception as e:
58
  return [f"Error: {e}"]
59
  finally:
60
  conn.close()
61
 
62
  def get_tables(self):
 
63
  conn = self.get_connection()
64
  try:
65
  cursor = conn.cursor()
66
  if self.type == "sqlite":
67
  cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
68
+ tables = [row[0] for row in cursor.fetchall()]
69
  else:
70
  cursor.execute("SHOW TABLES")
71
+ tables = [list(row.values())[0] for row in cursor.fetchall()]
72
+
73
+ # 🔍 DEBUG: PRINT TABLES TO LOGS
74
+ print(f"\n🔍 FOUND TABLES IN DATABASE: {tables}\n")
75
+ return tables
76
  except Exception as e:
77
  print(f"Error fetching tables: {e}")
78
  return []
 
80
  conn.close()
81
 
82
  def get_table_schema(self, table_name):
 
83
  conn = self.get_connection()
84
  columns = []
85
  try:
86
  cursor = conn.cursor()
 
87
  if self.type == "sqlite":
 
88
  cursor.execute(f"PRAGMA table_info({table_name})")
89
  rows = cursor.fetchall()
 
90
  for row in rows:
 
91
  col_name = row['name'] if isinstance(row, sqlite3.Row) else row[1]
92
  col_type = row['type'] if isinstance(row, sqlite3.Row) else row[2]
93
  columns.append(f"{col_name} ({col_type})")
94
  else:
 
95
  cursor.execute(f"DESCRIBE {table_name}")
96
  rows = cursor.fetchall()
97
  for row in rows:
98
  columns.append(f"{row['Field']} ({row['Type']})")
 
99
  return columns
100
  except Exception as e:
 
101
  return []
102
  finally:
103
  conn.close()
104
 
105
  def get_schema(self):
 
106
  tables = self.get_tables()
107
  schema_text = ""
 
108
  for table in tables:
109
  columns = self.get_table_schema(table)
110
  schema_text += f"Table: {table}\nColumns:\n"
111
  for col in columns:
112
  schema_text += f" - {col}\n"
113
  schema_text += "\n"
 
114
  return schema_text