Vanshcc committed on
Commit
a8441ef
·
verified ·
1 Parent(s): a00009c

Upload 15 files

Browse files
app.py CHANGED
@@ -47,28 +47,21 @@ GROQ_MODELS = [
47
  # Database types
48
  DB_TYPES = {
49
  "MySQL": "mysql",
50
- "PostgreSQL": "postgresql",
51
- "SQLite": "sqlite"
52
  }
53
 
54
 
55
  def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
56
  """Create a custom database configuration from user input."""
57
- db_config = DatabaseConfig.__new__(DatabaseConfig)
58
-
59
- # Set database type
60
- db_config.db_type = DatabaseType(db_type)
61
-
62
- # Set connection parameters
63
- db_config.host = kwargs.get("host", "")
64
- db_config.port = kwargs.get("port", 3306 if db_type == "mysql" else 5432)
65
- db_config.database = kwargs.get("database", "")
66
- db_config.username = kwargs.get("username", "")
67
- db_config.password = kwargs.get("password", "")
68
- db_config.ssl_ca = kwargs.get("ssl_ca", None)
69
- db_config.sqlite_path = kwargs.get("sqlite_path", "./chatbot.db")
70
-
71
- return db_config
72
 
73
 
74
  def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
@@ -143,10 +136,7 @@ def render_database_config():
143
  # Show current environment config
144
  current_db_type = config.database.db_type.value.upper()
145
  st.info(f"📌 Using {current_db_type} from environment")
146
- if config.database.is_sqlite:
147
- st.caption(f"Path: {config.database.sqlite_path}")
148
- else:
149
- st.caption(f"Host: {config.database.host}")
150
  return None
151
 
152
  else:
@@ -162,21 +152,7 @@ def render_database_config():
162
  )
163
  db_type = DB_TYPES[db_type_label]
164
 
165
- if db_type == "sqlite":
166
- # SQLite only needs file path
167
- sqlite_path = st.text_input(
168
- "Database File Path",
169
- value="./chatbot.db",
170
- key="sqlite_path_input",
171
- help="Path to SQLite database file (will be created if doesn't exist)"
172
- )
173
-
174
- return {
175
- "db_type": db_type,
176
- "sqlite_path": sqlite_path
177
- }
178
-
179
- else:
180
  # MySQL or PostgreSQL
181
  col1, col2 = st.columns([3, 1])
182
  with col1:
@@ -411,11 +387,7 @@ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
411
  # Validate custom params
412
  db_type = custom_db_params.get("db_type", "mysql")
413
 
414
- if db_type == "sqlite":
415
- if not custom_db_params.get("sqlite_path"):
416
- st.error("Please provide SQLite database path.")
417
- return False
418
- else:
419
  if not all([custom_db_params.get("host"),
420
  custom_db_params.get("database"),
421
  custom_db_params.get("username")]):
@@ -482,6 +454,11 @@ def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
482
 
483
  st.session_state.llm = llm
484
  st.session_state.initialized = True
 
 
 
 
 
485
 
486
  # Create memory with appropriate connection
487
  db_conn = st.session_state.custom_db_connection or get_db()
@@ -553,7 +530,7 @@ def render_schema_explorer():
553
  def render_chat_interface():
554
  """Render the main chat interface."""
555
  st.title("🤖 OnceDataBot")
556
- st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL | SQLite • Powered by Groq (FREE!)")
557
 
558
  # Schema explorer
559
  render_schema_explorer()
 
47
  # Database types
48
  DB_TYPES = {
49
  "MySQL": "mysql",
50
+ "PostgreSQL": "postgresql"
 
51
  }
52
 
53
 
54
  def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
55
  """Create a custom database configuration from user input."""
56
+ return DatabaseConfig(
57
+ db_type=DatabaseType(db_type),
58
+ host=kwargs.get("host", ""),
59
+ port=kwargs.get("port", 3306 if db_type == "mysql" else 5432),
60
+ database=kwargs.get("database", ""),
61
+ username=kwargs.get("username", ""),
62
+ password=kwargs.get("password", ""),
63
+ ssl_ca=kwargs.get("ssl_ca", None)
64
+ )
 
 
 
 
 
 
65
 
66
 
67
  def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
 
136
  # Show current environment config
137
  current_db_type = config.database.db_type.value.upper()
138
  st.info(f"📌 Using {current_db_type} from environment")
139
+ st.caption(f"Host: {config.database.host}")
 
 
 
140
  return None
141
 
142
  else:
 
152
  )
153
  db_type = DB_TYPES[db_type_label]
154
 
155
+ if True: # MySQL or PostgreSQL (SQLite removed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # MySQL or PostgreSQL
157
  col1, col2 = st.columns([3, 1])
158
  with col1:
 
387
  # Validate custom params
388
  db_type = custom_db_params.get("db_type", "mysql")
389
 
390
+ if True:
 
 
 
 
391
  if not all([custom_db_params.get("host"),
392
  custom_db_params.get("database"),
393
  custom_db_params.get("username")]):
 
454
 
455
  st.session_state.llm = llm
456
  st.session_state.initialized = True
457
+ st.session_state.indexed = False # Reset index status on new connection
458
+
459
+ # Clear RAG index to ensure no data from previous DB connection persists
460
+ if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
461
+ chatbot.rag_engine.clear_index()
462
 
463
  # Create memory with appropriate connection
464
  db_conn = st.session_state.custom_db_connection or get_db()
 
530
  def render_chat_interface():
531
  """Render the main chat interface."""
532
  st.title("🤖 OnceDataBot")
533
+ st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")
534
 
535
  # Schema explorer
536
  render_schema_explorer()
chatbot.py CHANGED
@@ -106,7 +106,8 @@ YOUR RESPONSE:"""
106
  if not self._schema_initialized:
107
  raise RuntimeError("Chatbot not initialized. Call initialize() first.")
108
 
109
- schema = get_schema()
 
110
  total_docs = 0
111
 
112
  for table_name, table_info in schema.tables.items():
@@ -117,17 +118,39 @@ YOUR RESPONSE:"""
117
  pk = table_info.primary_keys[0] if table_info.primary_keys else None
118
  cols_to_select = text_cols + ([pk] if pk else [])
119
 
120
- query = f"SELECT {', '.join(cols_to_select)} FROM {table_name} LIMIT 1000"
 
 
 
 
 
 
121
 
122
  try:
 
 
123
  rows = self.db.execute_query(query)
124
  docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
125
  total_docs += docs
126
 
127
  if progress_callback:
128
  progress_callback(table_name, docs)
129
-
130
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  logger.warning(f"Failed to index {table_name}: {e}")
132
 
133
  self.rag_engine.save()
@@ -146,7 +169,8 @@ YOUR RESPONSE:"""
146
  error="Configure LLM client first")
147
 
148
  try:
149
- schema = get_schema()
 
150
  schema_context = schema.to_context_string()
151
 
152
  # Check for memory commands
@@ -384,7 +408,7 @@ YOUR RESPONSE:"""
384
  """Get a summary of the database schema."""
385
  if not self._schema_initialized:
386
  return "Schema not loaded."
387
- return get_schema().to_context_string()
388
 
389
 
390
  def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
 
106
  if not self._schema_initialized:
107
  raise RuntimeError("Chatbot not initialized. Call initialize() first.")
108
 
109
+ # Use the instance's introspector which might be patched for custom DB
110
+ schema = self.introspector.introspect()
111
  total_docs = 0
112
 
113
  for table_name, table_info in schema.tables.items():
 
118
  pk = table_info.primary_keys[0] if table_info.primary_keys else None
119
  cols_to_select = text_cols + ([pk] if pk else [])
120
 
121
+ # Quote table name based on DB specific rules to handle case sensitivity and special chars
122
+ if self.db.db_type.value == "mysql":
123
+ quoted_table = f"`{table_name}`"
124
+ else:
125
+ quoted_table = f'"{table_name}"'
126
+
127
+ query = f"SELECT {', '.join(cols_to_select)} FROM {quoted_table} LIMIT 1000"
128
 
129
  try:
130
+ # Try the primary query
131
+ query = f"SELECT {', '.join(cols_to_select)} FROM {quoted_table} LIMIT 1000"
132
  rows = self.db.execute_query(query)
133
  docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
134
  total_docs += docs
135
 
136
  if progress_callback:
137
  progress_callback(table_name, docs)
138
+
139
  except Exception as e:
140
+ # Fallback mechanism for PostgreSQL if table not found (often due to schema issues)
141
+ if self.db.db_type.value == "postgresql" and "UndefinedTable" in str(e):
142
+ try:
143
+ logger.warning(f"Initial query failed for {table_name}, trying 'public' schema prefix...")
144
+ fallback_query = f"SELECT {', '.join(cols_to_select)} FROM public.\"{table_name}\" LIMIT 1000"
145
+ rows = self.db.execute_query(fallback_query)
146
+ docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
147
+ total_docs += docs
148
+ if progress_callback:
149
+ progress_callback(table_name, docs)
150
+ continue # Success with fallback
151
+ except Exception as e2:
152
+ logger.error(f"Fallback query also failed for {table_name}: {e2}")
153
+
154
  logger.warning(f"Failed to index {table_name}: {e}")
155
 
156
  self.rag_engine.save()
 
169
  error="Configure LLM client first")
170
 
171
  try:
172
+ # Use instance introspector
173
+ schema = self.introspector.introspect()
174
  schema_context = schema.to_context_string()
175
 
176
  # Check for memory commands
 
408
  """Get a summary of the database schema."""
409
  if not self._schema_initialized:
410
  return "Schema not loaded."
411
+ return self.introspector.introspect().to_context_string()
412
 
413
 
414
  def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
config.py CHANGED
@@ -24,7 +24,6 @@ class DatabaseType(Enum):
24
  """Supported database types."""
25
  MYSQL = "mysql"
26
  POSTGRESQL = "postgresql"
27
- SQLITE = "sqlite"
28
 
29
 
30
  class LLMProvider(Enum):
@@ -43,11 +42,11 @@ class EmbeddingProvider(Enum):
43
  @dataclass
44
  class DatabaseConfig:
45
  """
46
- Database configuration supporting MySQL, PostgreSQL, and SQLite.
47
 
48
  All sensitive values are loaded from environment variables.
49
  """
50
- # Database type (mysql, postgresql, sqlite)
51
  db_type: DatabaseType = field(
52
  default_factory=lambda: DatabaseType(os.getenv("DB_TYPE", "mysql").lower())
53
  )
@@ -62,17 +61,10 @@ class DatabaseConfig:
62
  # SSL configuration
63
  ssl_ca: Optional[str] = field(default_factory=lambda: os.getenv("DB_SSL_CA", os.getenv("MYSQL_SSL_CA", None)))
64
 
65
- # SQLite-specific: path to database file
66
- sqlite_path: str = field(default_factory=lambda: os.getenv("SQLITE_PATH", "./chatbot.db"))
67
-
68
  @property
69
  def connection_string(self) -> str:
70
  """Generate SQLAlchemy connection string based on database type."""
71
- if self.db_type == DatabaseType.SQLITE:
72
- # SQLite uses file path
73
- return f"sqlite:///{self.sqlite_path}"
74
-
75
- elif self.db_type == DatabaseType.POSTGRESQL:
76
  # PostgreSQL connection string
77
  base_url = f"postgresql+psycopg2://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
78
  if self.ssl_ca:
@@ -88,12 +80,8 @@ class DatabaseConfig:
88
 
89
  def is_configured(self) -> bool:
90
  """Check if all required database settings are configured."""
91
- if self.db_type == DatabaseType.SQLITE:
92
- # SQLite only needs a valid path
93
- return bool(self.sqlite_path)
94
- else:
95
- # MySQL/PostgreSQL need host, database, username, password
96
- return all([self.host, self.database, self.username, self.password])
97
 
98
  @property
99
  def is_mysql(self) -> bool:
@@ -104,11 +92,6 @@ class DatabaseConfig:
104
  def is_postgresql(self) -> bool:
105
  """Check if using PostgreSQL."""
106
  return self.db_type == DatabaseType.POSTGRESQL
107
-
108
- @property
109
- def is_sqlite(self) -> bool:
110
- """Check if using SQLite."""
111
- return self.db_type == DatabaseType.SQLITE
112
 
113
 
114
  @dataclass
@@ -203,9 +186,7 @@ class RAGConfig:
203
  # MySQL types
204
  "TEXT", "MEDIUMTEXT", "LONGTEXT", "TINYTEXT", "VARCHAR", "CHAR",
205
  # PostgreSQL types
206
- "CHARACTER VARYING", "CHARACTER",
207
- # SQLite types (SQLite is flexible but these are common)
208
- "CLOB", "NVARCHAR", "NCHAR"
209
  ])
210
 
211
  # Minimum character length to consider a column for RAG
@@ -257,10 +238,7 @@ class AppConfig:
257
 
258
  if not self.database.is_configured():
259
  db_type = self.database.db_type.value.upper()
260
- if self.database.is_sqlite:
261
- errors.append("SQLite configuration incomplete. Check SQLITE_PATH environment variable.")
262
- else:
263
- errors.append(f"{db_type} configuration incomplete. Check DB_* environment variables.")
264
 
265
  if not self.llm.is_configured():
266
  errors.append(
 
24
  """Supported database types."""
25
  MYSQL = "mysql"
26
  POSTGRESQL = "postgresql"
 
27
 
28
 
29
  class LLMProvider(Enum):
 
42
  @dataclass
43
  class DatabaseConfig:
44
  """
45
+ Database configuration supporting MySQL and PostgreSQL.
46
 
47
  All sensitive values are loaded from environment variables.
48
  """
49
+ # Database type (mysql, postgresql)
50
  db_type: DatabaseType = field(
51
  default_factory=lambda: DatabaseType(os.getenv("DB_TYPE", "mysql").lower())
52
  )
 
61
  # SSL configuration
62
  ssl_ca: Optional[str] = field(default_factory=lambda: os.getenv("DB_SSL_CA", os.getenv("MYSQL_SSL_CA", None)))
63
 
 
 
 
64
  @property
65
  def connection_string(self) -> str:
66
  """Generate SQLAlchemy connection string based on database type."""
67
+ if self.db_type == DatabaseType.POSTGRESQL:
 
 
 
 
68
  # PostgreSQL connection string
69
  base_url = f"postgresql+psycopg2://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
70
  if self.ssl_ca:
 
80
 
81
  def is_configured(self) -> bool:
82
  """Check if all required database settings are configured."""
83
+ # MySQL/PostgreSQL need host, database, username, password
84
+ return all([self.host, self.database, self.username, self.password])
 
 
 
 
85
 
86
  @property
87
  def is_mysql(self) -> bool:
 
92
  def is_postgresql(self) -> bool:
93
  """Check if using PostgreSQL."""
94
  return self.db_type == DatabaseType.POSTGRESQL
 
 
 
 
 
95
 
96
 
97
  @dataclass
 
186
  # MySQL types
187
  "TEXT", "MEDIUMTEXT", "LONGTEXT", "TINYTEXT", "VARCHAR", "CHAR",
188
  # PostgreSQL types
189
+ "CHARACTER VARYING", "CHARACTER"
 
 
190
  ])
191
 
192
  # Minimum character length to consider a column for RAG
 
238
 
239
  if not self.database.is_configured():
240
  db_type = self.database.db_type.value.upper()
241
+ errors.append(f"{db_type} configuration incomplete. Check DB_* environment variables.")
 
 
 
242
 
243
  if not self.llm.is_configured():
244
  errors.append(
database/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/database/__pycache__/__init__.cpython-311.pyc and b/database/__pycache__/__init__.cpython-311.pyc differ
 
database/__pycache__/connection.cpython-311.pyc CHANGED
Binary files a/database/__pycache__/connection.cpython-311.pyc and b/database/__pycache__/connection.cpython-311.pyc differ
 
database/__pycache__/schema_introspector.cpython-311.pyc CHANGED
Binary files a/database/__pycache__/schema_introspector.cpython-311.pyc and b/database/__pycache__/schema_introspector.cpython-311.pyc differ
 
database/connection.py CHANGED
@@ -52,26 +52,7 @@ class DatabaseConnection:
52
  """
53
  connect_args = {}
54
 
55
- if self.config.db_type == DatabaseType.SQLITE:
56
- # SQLite-specific settings
57
- # Use StaticPool for SQLite to handle multi-threading
58
- connect_args["check_same_thread"] = False
59
-
60
- engine = create_engine(
61
- self.config.connection_string,
62
- poolclass=StaticPool, # SQLite works best with StaticPool
63
- connect_args=connect_args,
64
- echo=False
65
- )
66
-
67
- # Enable foreign keys for SQLite
68
- @event.listens_for(engine, "connect")
69
- def set_sqlite_pragma(dbapi_connection, connection_record):
70
- cursor = dbapi_connection.cursor()
71
- cursor.execute("PRAGMA foreign_keys=ON")
72
- cursor.close()
73
-
74
- elif self.config.db_type == DatabaseType.POSTGRESQL:
75
  # PostgreSQL-specific settings
76
  if self.config.ssl_ca:
77
  connect_args["sslmode"] = "verify-full"
 
52
  """
53
  connect_args = {}
54
 
55
+ if self.config.db_type == DatabaseType.POSTGRESQL:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # PostgreSQL-specific settings
57
  if self.config.ssl_ca:
58
  connect_args["sslmode"] = "verify-full"
database/schema_introspector.py CHANGED
@@ -42,9 +42,7 @@ class ColumnInfo:
42
  # MySQL
43
  'text', 'mediumtext', 'longtext', 'tinytext', 'varchar', 'char', 'json',
44
  # PostgreSQL
45
- 'character varying', 'character', 'text', 'json', 'jsonb',
46
- # SQLite (column affinity - TEXT)
47
- 'clob', 'nvarchar', 'nchar', 'ntext'
48
  ]
49
  data_type_lower = self.data_type.lower().split('(')[0].strip()
50
  return data_type_lower in text_types
@@ -57,9 +55,7 @@ class ColumnInfo:
57
  'int', 'integer', 'bigint', 'smallint', 'tinyint',
58
  'decimal', 'numeric', 'float', 'double', 'real',
59
  # PostgreSQL specific
60
- 'double precision', 'serial', 'bigserial', 'smallserial',
61
- # SQLite (NUMERIC affinity)
62
- 'bool', 'boolean'
63
  ]
64
  data_type_lower = self.data_type.lower().split('(')[0].strip()
65
  return data_type_lower in numeric_types
@@ -185,10 +181,10 @@ class SchemaIntrospector:
185
  '_chatbot_user_summaries',
186
  'schema_migrations',
187
  'flyway_schema_history',
188
- # SQLite internal tables
189
- 'sqlite_sequence',
190
- 'sqlite_stat1',
191
- 'sqlite_stat4'
192
  }
193
 
194
  def __init__(self, engine: Optional[Engine] = None):
@@ -245,10 +241,7 @@ class SchemaIntrospector:
245
  db_type = self.db.db_type
246
 
247
  try:
248
- if db_type.value == "sqlite":
249
- # For SQLite, return the database file name
250
- return self.db.config.sqlite_path.split('/')[-1]
251
- elif db_type.value == "postgresql":
252
  result = self.db.execute_query("SELECT current_database() as db_name")
253
  return result[0]['db_name'] if result else "unknown"
254
  else: # MySQL
@@ -266,18 +259,7 @@ class SchemaIntrospector:
266
  db_type = self.db.db_type
267
 
268
  try:
269
- if db_type.value == "sqlite":
270
- query = """
271
- SELECT name as table_name
272
- FROM sqlite_master
273
- WHERE type='table'
274
- AND name NOT LIKE 'sqlite_%'
275
- ORDER BY name
276
- """
277
- result = self.db.execute_query(query)
278
- return [row['table_name'] for row in result]
279
-
280
- elif db_type.value == "postgresql":
281
  query = """
282
  SELECT table_name
283
  FROM information_schema.tables
@@ -351,24 +333,7 @@ class SchemaIntrospector:
351
  db_type = self.db.db_type
352
 
353
  try:
354
- if db_type.value == "sqlite":
355
- query = f"PRAGMA table_info('{table_name}')"
356
- result = self.db.execute_query(query)
357
-
358
- columns = []
359
- for row in result:
360
- columns.append(ColumnInfo(
361
- name=row['name'],
362
- data_type=row['type'] or 'TEXT', # SQLite columns can have no type
363
- is_nullable=row['notnull'] == 0,
364
- is_primary_key=row['pk'] == 1,
365
- max_length=None,
366
- default_value=row['dflt_value'],
367
- comment=None # SQLite doesn't support column comments
368
- ))
369
- return columns
370
-
371
- elif db_type.value == "postgresql":
372
  query = """
373
  SELECT
374
  column_name,
@@ -438,17 +403,12 @@ class SchemaIntrospector:
438
  db_type = self.db.db_type
439
 
440
  try:
441
- if db_type.value == "sqlite":
442
- query = f"PRAGMA table_info('{table_name}')"
443
- result = self.db.execute_query(query)
444
- return [row['name'] for row in result if row['pk'] > 0]
445
-
446
- elif db_type.value == "postgresql":
447
  query = """
448
  SELECT a.attname as column_name
449
  FROM pg_index i
450
  JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
451
- WHERE i.indrelid = :table_name::regclass
452
  AND i.indisprimary
453
  """
454
  result = self.db.execute_query(query, {"table_name": table_name})
@@ -475,15 +435,7 @@ class SchemaIntrospector:
475
  db_type = self.db.db_type
476
 
477
  try:
478
- if db_type.value == "sqlite":
479
- query = f"PRAGMA foreign_key_list('{table_name}')"
480
- result = self.db.execute_query(query)
481
- return {
482
- row['from']: f"{row['table']}.{row['to']}"
483
- for row in result
484
- }
485
-
486
- elif db_type.value == "postgresql":
487
  query = """
488
  SELECT
489
  kcu.column_name,
@@ -534,13 +486,7 @@ class SchemaIntrospector:
534
  db_type = self.db.db_type
535
 
536
  try:
537
- if db_type.value == "sqlite":
538
- # SQLite doesn't have stats table, use max rowid for estimation
539
- query = f"SELECT MAX(rowid) as row_count FROM \"{table_name}\""
540
- result = self.db.execute_query(query)
541
- return result[0]['row_count'] if result and result[0]['row_count'] else 0
542
-
543
- elif db_type.value == "postgresql":
544
  # Use pg_stat_user_tables for fast estimation
545
  query = """
546
  SELECT n_live_tup as row_count
@@ -569,13 +515,9 @@ class SchemaIntrospector:
569
  db_type = self.db.db_type
570
 
571
  try:
572
- if db_type.value == "sqlite":
573
- # SQLite doesn't support table comments
574
- return None
575
-
576
- elif db_type.value == "postgresql":
577
  query = """
578
- SELECT obj_description(:table_name::regclass, 'pg_class') as table_comment
579
  """
580
  result = self.db.execute_query(query, {"table_name": table_name})
581
  comment = result[0]['table_comment'] if result else None
 
42
  # MySQL
43
  'text', 'mediumtext', 'longtext', 'tinytext', 'varchar', 'char', 'json',
44
  # PostgreSQL
45
+ 'character varying', 'character', 'text', 'json', 'jsonb'
 
 
46
  ]
47
  data_type_lower = self.data_type.lower().split('(')[0].strip()
48
  return data_type_lower in text_types
 
55
  'int', 'integer', 'bigint', 'smallint', 'tinyint',
56
  'decimal', 'numeric', 'float', 'double', 'real',
57
  # PostgreSQL specific
58
+ 'double precision', 'serial', 'bigserial', 'smallserial'
 
 
59
  ]
60
  data_type_lower = self.data_type.lower().split('(')[0].strip()
61
  return data_type_lower in numeric_types
 
181
  '_chatbot_user_summaries',
182
  'schema_migrations',
183
  'flyway_schema_history',
184
+ # Vector store internal tables
185
+ 'chunks',
186
+ 'embeddings',
187
+ 'vectors'
188
  }
189
 
190
  def __init__(self, engine: Optional[Engine] = None):
 
241
  db_type = self.db.db_type
242
 
243
  try:
244
+ if db_type.value == "postgresql":
 
 
 
245
  result = self.db.execute_query("SELECT current_database() as db_name")
246
  return result[0]['db_name'] if result else "unknown"
247
  else: # MySQL
 
259
  db_type = self.db.db_type
260
 
261
  try:
262
+ if db_type.value == "postgresql":
 
 
 
 
 
 
 
 
 
 
 
263
  query = """
264
  SELECT table_name
265
  FROM information_schema.tables
 
333
  db_type = self.db.db_type
334
 
335
  try:
336
+ if db_type.value == "postgresql":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  query = """
338
  SELECT
339
  column_name,
 
403
  db_type = self.db.db_type
404
 
405
  try:
406
+ if db_type.value == "postgresql":
 
 
 
 
 
407
  query = """
408
  SELECT a.attname as column_name
409
  FROM pg_index i
410
  JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
411
+ WHERE i.indrelid = CAST(:table_name AS regclass)
412
  AND i.indisprimary
413
  """
414
  result = self.db.execute_query(query, {"table_name": table_name})
 
435
  db_type = self.db.db_type
436
 
437
  try:
438
+ if db_type.value == "postgresql":
 
 
 
 
 
 
 
 
439
  query = """
440
  SELECT
441
  kcu.column_name,
 
486
  db_type = self.db.db_type
487
 
488
  try:
489
+ if db_type.value == "postgresql":
 
 
 
 
 
 
490
  # Use pg_stat_user_tables for fast estimation
491
  query = """
492
  SELECT n_live_tup as row_count
 
515
  db_type = self.db.db_type
516
 
517
  try:
518
+ if db_type.value == "postgresql":
 
 
 
 
519
  query = """
520
+ SELECT obj_description(CAST(:table_name AS regclass), 'pg_class') as table_comment
521
  """
522
  result = self.db.execute_query(query, {"table_name": table_name})
523
  comment = result[0]['table_comment'] if result else None
sql/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/sql/__pycache__/__init__.cpython-311.pyc and b/sql/__pycache__/__init__.cpython-311.pyc differ
 
sql/__pycache__/generator.cpython-311.pyc CHANGED
Binary files a/sql/__pycache__/generator.cpython-311.pyc and b/sql/__pycache__/generator.cpython-311.pyc differ
 
sql/__pycache__/validator.cpython-311.pyc CHANGED
Binary files a/sql/__pycache__/validator.cpython-311.pyc and b/sql/__pycache__/validator.cpython-311.pyc differ
 
sql/generator.py CHANGED
@@ -16,8 +16,7 @@ def get_sql_dialect(db_type: str) -> str:
16
  """Get the SQL dialect name for the given database type."""
17
  dialects = {
18
  "mysql": "MySQL",
19
- "postgresql": "PostgreSQL",
20
- "sqlite": "SQLite"
21
  }
22
  return dialects.get(db_type, "SQL")
23
 
 
16
  """Get the SQL dialect name for the given database type."""
17
  dialects = {
18
  "mysql": "MySQL",
19
+ "postgresql": "PostgreSQL"
 
20
  }
21
  return dialects.get(db_type, "SQL")
22
 
sql/validator.py CHANGED
@@ -47,7 +47,7 @@ class SQLValidator:
47
  def set_allowed_tables(self, tables: List[str]):
48
  """Set the whitelist of allowed tables."""
49
  self.allowed_tables = set(tables)
50
-
51
  def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]:
52
  """
53
  Validate SQL query for safety.
@@ -94,7 +94,11 @@ class SQLValidator:
94
  # Extract and validate tables
95
  tables = self._extract_tables(statement)
96
  if self.allowed_tables:
97
- invalid_tables = tables - self.allowed_tables
 
 
 
 
98
  if invalid_tables:
99
  return False, f"Access denied to tables: {invalid_tables}", None
100
 
@@ -109,13 +113,14 @@ class SQLValidator:
109
  sql = str(statement)
110
 
111
  # Use regex to find tables after FROM and JOIN
112
- # Pattern: FROM table_name or JOIN table_name
 
113
  from_pattern = re.compile(
114
- r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)',
115
  re.IGNORECASE
116
  )
117
  join_pattern = re.compile(
118
- r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)',
119
  re.IGNORECASE
120
  )
121
 
@@ -128,7 +133,7 @@ class SQLValidator:
128
  tables.add(match.group(1))
129
 
130
  return tables
131
-
132
  def _ensure_limit(self, sql: str) -> str:
133
  """Ensure the query has a LIMIT clause."""
134
  sql_upper = sql.upper()
 
47
  def set_allowed_tables(self, tables: List[str]):
48
  """Set the whitelist of allowed tables."""
49
  self.allowed_tables = set(tables)
50
+
51
  def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]:
52
  """
53
  Validate SQL query for safety.
 
94
  # Extract and validate tables
95
  tables = self._extract_tables(statement)
96
  if self.allowed_tables:
97
+ # Normalize for comparison (remove quotes, lowercase)
98
+ allowed_norm = {t.lower().replace('"', '').replace('`', '') for t in self.allowed_tables}
99
+ tables_norm = {t.lower().replace('"', '').replace('`', '') for t in tables}
100
+
101
+ invalid_tables = tables_norm - allowed_norm
102
  if invalid_tables:
103
  return False, f"Access denied to tables: {invalid_tables}", None
104
 
 
113
  sql = str(statement)
114
 
115
  # Use regex to find tables after FROM and JOIN
116
+ # Pattern: FROM table_name or JOIN table_name, supporting quotes
117
+ # Matches: FROM table, FROM "table", FROM `table`
118
  from_pattern = re.compile(
119
+ r'\bFROM\s+(?:["`]?)([a-zA-Z0-9_]+)(?:["`]?)',
120
  re.IGNORECASE
121
  )
122
  join_pattern = re.compile(
123
+ r'\bJOIN\s+(?:["`]?)([a-zA-Z0-9_]+)(?:["`]?)',
124
  re.IGNORECASE
125
  )
126
 
 
133
  tables.add(match.group(1))
134
 
135
  return tables
136
+
137
  def _ensure_limit(self, sql: str) -> str:
138
  """Ensure the query has a LIMIT clause."""
139
  sql_upper = sql.upper()