julkarnaeen commited on
Commit
27514ca
·
verified ·
1 Parent(s): 1a07f7f

Update databot.py

Browse files
Files changed (1) hide show
  1. databot.py +309 -417
databot.py CHANGED
@@ -1,418 +1,310 @@
1
- import os
2
- import re
3
- import time
4
- import sys
5
- import json
6
- from dotenv import load_dotenv # pyre-ignore[21]
7
- from sqlalchemy import create_engine, text# pyre-ignore[21]
8
- from openai import OpenAI as OpenAIClient
9
-
10
- # Load credentials from .env
11
- load_dotenv()
12
-
13
- # Path to security config
14
- CONFIG_PATH = os.path.join(os.path.dirname(__file__), "data_access_config.json")
15
-
16
-
17
- class DataBot:
18
- def __init__(self):
19
- # 1. Database Connection Details
20
- self.db_user = os.getenv("DB_USER")
21
- self.db_pass = os.getenv("DB_PASSWORD")
22
- self.db_host = os.getenv("DB_HOST", "51.89.104.26")
23
- self.db_name = os.getenv("DB_NAME", "dev_poly")
24
- self.port = "3306"
25
-
26
- # 2. Initialize OpenAI client directly
27
- self.client = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
28
- self.model = os.getenv("LLM_MODEL", "gpt-4o")
29
-
30
- # 3. Load data access security config
31
- self.access_config: dict = self._load_access_config()
32
-
33
- # 4. Establish MySQL Connection with timeouts
34
- self.engine = create_engine(
35
- f"mysql+pymysql://{self.db_user}:{self.db_pass}@{self.db_host}:{self.port}/{self.db_name}?charset=utf8",
36
- connect_args={
37
- "connect_timeout": 30,
38
- "read_timeout": 60,
39
- "write_timeout": 60,
40
- },
41
- pool_pre_ping=True,
42
- pool_recycle=300,
43
- )
44
-
45
- # 5. Cache the schema once at startup (filtered by access config)
46
- print("Loading database schema...")
47
- raw_schema = self._load_raw_schema_with_retry()
48
- self.schema_info = self._filter_schema(raw_schema)
49
- print(f"Loaded schema for {len(self.schema_info)} accessible tables "
50
- f"(filtered from {len(raw_schema)} total).")
51
-
52
- # ── Security Config ──────────────────────────────────────────────
53
-
54
- def _load_access_config(self) -> dict:
55
- """Load data access security configuration."""
56
- try:
57
- with open(CONFIG_PATH, "r", encoding="utf-8") as f:
58
- config = json.load(f)
59
- print("Loaded data access security config.")
60
- return config
61
- except FileNotFoundError:
62
- print("WARNING: data_access_config.json not found! All tables accessible.")
63
- return {}
64
- except json.JSONDecodeError as e:
65
- print(f"WARNING: Invalid config JSON: {e}. All tables accessible.")
66
- return {}
67
-
68
- def _is_table_allowed(self, table_name):
69
- """Check if a table is allowed based on the access config."""
70
- if not self.access_config:
71
- return True # No config = allow all (backward compatible)
72
-
73
- table_lower = table_name.lower()
74
-
75
- # First check restricted (takes priority)
76
- for prefix in self.access_config.get("restricted_table_prefixes", []):
77
- if table_lower.startswith(prefix.lower()):
78
- return False
79
-
80
- # Then check allowed
81
- for prefix in self.access_config.get("allowed_table_prefixes", []):
82
- if table_lower.startswith(prefix.lower()):
83
- return True
84
-
85
- # Default: deny (whitelist approach)
86
- return False
87
-
88
- def _is_column_allowed(self, column_name):
89
- """Check if a column is allowed based on restricted column patterns."""
90
- if not self.access_config:
91
- return True
92
-
93
- col_lower = column_name.lower()
94
- for restricted_col in self.access_config.get("restricted_columns", []):
95
- if restricted_col.lower() == col_lower:
96
- return False
97
-
98
- return True
99
-
100
- def _filter_schema(self, raw_schema: dict[str, list[str]]) -> dict[str, list[str]]:
101
- """Filter the raw schema to remove restricted tables and columns."""
102
- filtered: dict[str, list[str]] = {}
103
- blocked_tables: list[str] = []
104
-
105
- for table, columns in raw_schema.items():
106
- if not self._is_table_allowed(table):
107
- blocked_tables.append(table)
108
- continue
109
-
110
- # Filter out restricted columns
111
- safe_columns = []
112
- for col_entry in columns:
113
- # col_entry format: "column_name (type)"
114
- col_name = col_entry.split(" (")[0].strip()
115
- if self._is_column_allowed(col_name):
116
- safe_columns.append(col_entry)
117
-
118
- if safe_columns:
119
- filtered[table] = safe_columns
120
-
121
- if blocked_tables:
122
- print(f" → Blocked {len(blocked_tables)} restricted tables.")
123
-
124
- return filtered
125
-
126
- # ── SQL Security Validation ──────────────────────────────────────
127
-
128
- def _validate_sql_security(self, sql):
129
- """
130
- Validate that a generated SQL query doesn't reference restricted tables or columns.
131
- Returns (is_safe, reason) tuple.
132
- """
133
- if not self.access_config:
134
- return True, ""
135
-
136
- sql_upper = sql.upper()
137
-
138
- # Check for write operations (extra safety)
139
- write_ops = ["INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE"]
140
- for op in write_ops:
141
- if re.search(rf'\b{op}\b', sql_upper):
142
- return False, f"Write operation '{op}' is not allowed."
143
-
144
- sql_lower = sql.lower()
145
-
146
- # Check for restricted tables
147
- for prefix in self.access_config.get("restricted_table_prefixes", []):
148
- # Look for table references with word boundaries
149
- pattern = rf'\b{re.escape(prefix.lower())}\w*\b'
150
- if re.search(pattern, sql_lower):
151
- return False, f"Query references restricted data ('{prefix}*' tables). Access denied."
152
-
153
- # Check for restricted columns in SELECT
154
- for restricted_col in self.access_config.get("restricted_columns", []):
155
- pattern = rf'\b{re.escape(restricted_col.lower())}\b'
156
- if re.search(pattern, sql_lower):
157
- return False, f"Query references restricted column '{restricted_col}'. Access denied."
158
-
159
- return True, ""
160
-
161
- # ── Core Schema Loading ──────────────────────────────────────────
162
-
163
- def _load_raw_schema_with_retry(self, max_retries=3, delay=5) -> dict[str, list[str]]:
164
- """Try to load the schema with retry logic for connection failures."""
165
- for attempt in range(1, max_retries + 1):
166
- try:
167
- return self._load_raw_schema()
168
- except Exception as e:
169
- if attempt < max_retries:
170
- print(f" ✗ Connection attempt {attempt}/{max_retries} failed. "
171
- f"Retrying in {delay}s...")
172
- time.sleep(delay)
173
- else:
174
- print(f"\n{'='*60}")
175
- print(f"ERROR: Cannot connect to MySQL server")
176
- print(f" Host: {self.db_host}:{self.port}")
177
- print(f" Database: {self.db_name}")
178
- print(f" Tried {max_retries} times.")
179
- print(f"")
180
- print(f" Possible causes:")
181
- print(f" 1. MySQL service is down on the server")
182
- print(f" 2. Firewall is blocking port {self.port}")
183
- print(f" 3. Wrong host/credentials in .env")
184
- print(f"{'='*60}")
185
- raise SystemExit(1) from e
186
- return {} # unreachable, but satisfies type checker
187
-
188
- def _load_raw_schema(self) -> dict[str, list[str]]:
189
- """Load ALL table names and columns from INFORMATION_SCHEMA (one fast query)."""
190
- query = text("""
191
- SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE
192
- FROM INFORMATION_SCHEMA.COLUMNS
193
- WHERE TABLE_SCHEMA = :db_name
194
- ORDER BY TABLE_NAME, ORDINAL_POSITION
195
- """)
196
- schema: dict[str, list[str]] = {}
197
- with self.engine.connect() as conn:
198
- result = conn.execute(query, {"db_name": self.db_name})
199
- for row in result:
200
- table = row[0]
201
- col_name = row[1]
202
- col_type = row[2]
203
- if table not in schema:
204
- schema[table] = []
205
- schema[table].append(f"{col_name} ({col_type})")
206
- return schema
207
-
208
- # ── LLM Pipeline ─────────────────────────────────────────────────
209
-
210
- def _pick_relevant_tables(self, question, max_tables=5):
211
- """Use GPT to quickly pick relevant tables from ALLOWED table names only."""
212
- table_names: list[str] = list(self.schema_info.keys())
213
- table_list = ", ".join(table_names)
214
-
215
- response = self.client.chat.completions.create(
216
- model=self.model,
217
- temperature=0,
218
- max_tokens=200,
219
- messages=[{
220
- "role": "system",
221
- "content": (
222
- "You are a bilingual database assistant. You understand questions in English and French "
223
- "and pick the most relevant tables.\n"
224
- "You can ONLY pick tables from the provided list. "
225
- "Return ONLY comma-separated table names from the list, nothing else."
226
- )
227
- }, {
228
- "role": "user",
229
- "content": (
230
- f"Here are the ONLY accessible MySQL table names from the dev_poly ERP database:\n"
231
- f"{table_list}\n\n"
232
- f"Question: \"{question}\"\n\n"
233
- f"Pick the {max_tables} most relevant tables to answer this question. "
234
- f"Return ONLY comma-separated table names, nothing else. "
235
- f"You may ONLY use tables from the list above."
236
- )
237
- }]
238
- )
239
- content = response.choices[0].message.content or ""
240
- suggested = [t.strip().strip("'\"` ") for t in content.split(",")]
241
- valid = [t for t in suggested if t in self.schema_info]
242
- fallback: list[str] = table_names[:5]
243
- return valid if valid else fallback
244
-
245
- def _build_schema_context(self, tables: list[str]):
246
- """Build a compact schema string for the selected tables."""
247
- parts: list[str] = []
248
- for table in tables:
249
- if table in self.schema_info:
250
- cols = ", ".join(self.schema_info[table])
251
- parts.append(f"Table '{table}': {cols}")
252
- return "\n".join(parts)
253
-
254
- def _generate_sql(self, question, schema_context):
255
- """Ask GPT to generate a SELECT query with security constraints."""
256
- response = self.client.chat.completions.create(
257
- model=self.model,
258
- temperature=0,
259
- max_tokens=500,
260
- messages=[{
261
- "role": "system",
262
- "content": (
263
- "You are DataBot, a bilingual SQL expert for the dev_poly ERP MySQL database. "
264
- "You understand questions in English and French.\n\n"
265
- "TASK: Generate ONLY ONE single SELECT statement based on the user's question.\n"
266
- "- No INSERT/UPDATE/DELETE. Only SELECT queries.\n"
267
- "- Do NOT generate multiple queries. Combine data using JOINs or subqueries into ONE query.\n"
268
- "- Do NOT use semicolons.\n"
269
- "- Return ONLY the raw SQL query, no explanation, no markdown.\n"
270
- "- Only use the tables and columns provided in the schema context below.\n"
271
- "- NEVER reference tables or columns not in the provided schema.\n\n"
272
- "SPECIAL CASES:\n"
273
- "- If the user's message is NOT a database question (e.g. greetings, chitchat, "
274
- "general knowledge questions unrelated to the database), return ONLY the text: NOT_A_QUERY\n"
275
- "- If the user explicitly asks for passwords, bank account numbers, or identity document "
276
- "numbers, return ONLY the text: SECURITY_BLOCK\n\n"
277
- "IMPORTANT: If the user asks a legitimate business question (like counts, totals, lists) "
278
- "about ANY topic and the schema provides relevant tables, generate the SQL. "
279
- "Do NOT block queries just because they mention employees, staff, or people — "
280
- "the schema you receive has already been filtered for security."
281
- )
282
- }, {
283
- "role": "user",
284
- "content": f"Schema:\n{schema_context}\n\nQuestion: {question}"
285
- }]
286
- )
287
- raw_sql = response.choices[0].message.content
288
- if not raw_sql:
289
- return "NOT_A_QUERY"
290
- sql = raw_sql.strip()
291
-
292
- # Check for special LLM responses
293
- if "SECURITY_BLOCK" in sql.upper():
294
- return "SECURITY_BLOCK"
295
- if "NOT_A_QUERY" in sql.upper():
296
- return "NOT_A_QUERY"
297
-
298
- # Clean up any markdown formatting
299
- sql = sql.replace("```sql", "").replace("```", "").strip()
300
- # Safety: if GPT returned multiple statements, keep only the first one
301
- if ";" in sql:
302
- sql = sql.split(";")[0].strip()
303
- return sql
304
-
305
- def _execute_sql(self, sql):
306
- """Execute the SQL and return results."""
307
- with self.engine.connect() as conn:
308
- result = conn.execute(text(sql))
309
- columns = list(result.keys())
310
- rows = [dict(zip(columns, row)) for row in result.fetchall()]
311
- return columns, rows
312
-
313
- def _summarize_results(self, question, sql, columns, rows):
314
- """Ask GPT to summarize the results in a well-structured, insightful way."""
315
- # Limit rows to avoid token overflow
316
- display_rows = rows[:50]
317
- total_count = len(rows)
318
- truncated = total_count > 50
319
-
320
- result_text = f"Columns: {columns}\nRows ({total_count} total"
321
- if truncated:
322
- result_text += f", showing first 50"
323
- result_text += "):\n"
324
- for row in display_rows:
325
- result_text += str(row) + "\n"
326
-
327
- response = self.client.chat.completions.create(
328
- model=self.model,
329
- temperature=0.3,
330
- max_tokens=2000,
331
- messages=[{
332
- "role": "system",
333
- "content": (
334
- "You are DataBot, an intelligent ERP database assistant for the dev_poly system. "
335
- "Your job is to answer the user's question based on the SQL query results provided.\n\n"
336
- "RESPONSE GUIDELINES:\n"
337
- "1. **Answer the question directly first** — start with a clear, direct answer to what "
338
- "the user asked. Don't start with 'Based on the query results...' or similar filler.\n"
339
- "2. **Be specific with numbers** — always include exact counts, totals, amounts, "
340
- "and percentages where relevant. Round monetary values to 2 decimal places.\n"
341
- "3. **Use structured formatting** when presenting multiple items:\n"
342
- " - Use numbered lists or bullet points for lists of items\n"
343
- " - Use simple text tables for comparisons (align columns with spaces)\n"
344
- " - Bold important values or key findings using **bold**\n"
345
- "4. **Add brief insights** — after presenting the data, add 1-2 sentences of "
346
- "business-relevant observations if applicable (e.g., trends, outliers, notable patterns).\n"
347
- "5. **Handle empty results gracefully** — if there are 0 rows, say so clearly and "
348
- "suggest possible reasons (e.g., 'No matching records found. This could mean the "
349
- "date range has no activity, or the filter criteria may be too narrow.').\n"
350
- "6. **Keep it conversational but professional** — write as a knowledgeable business "
351
- "analyst would speak, not as a robotic data dump.\n"
352
- "7. **If results are truncated** (showing partial data), mention that more records exist "
353
- "and the summary covers the displayed portion.\n"
354
- "8. **Match the user's language** — always reply in the same language the user "
355
- "used in their question. If they asked in French, respond entirely in French. "
356
- "If in English, respond in English. Only these two languages are supported.\n\n"
357
- "SECURITY RULES:\n"
358
- "- NEVER include personal data: phone numbers, email addresses, passwords, salaries, "
359
- "bank account numbers, identity document numbers, or home addresses.\n"
360
- "- If the results contain such data, omit it and note that it is restricted.\n"
361
- "- Focus on business-relevant information: counts, totals, trends, entity names, and statuses."
362
- )
363
- }, {
364
- "role": "user",
365
- "content": (
366
- f"Question: {question}\n"
367
- f"SQL executed: {sql}\n"
368
- f"Results:\n{result_text}"
369
- )
370
- }]
371
- )
372
- return (response.choices[0].message.content or "").strip()
373
-
374
- # ── Main Entry Point ─────────────────────────────────────────────
375
-
376
- def ask(self, question):
377
- """Processes a natural language question and returns an answer."""
378
- try:
379
- # Step 1: Pick relevant tables (fast LLM call, only from allowed tables)
380
- relevant_tables = self._pick_relevant_tables(question)
381
- print(f" → Tables: {', '.join(relevant_tables)}")
382
-
383
- # Step 2: Build schema context from cache (instant, no DB query)
384
- schema_context = self._build_schema_context(relevant_tables)
385
-
386
- # Step 3: Generate SQL (fast LLM call with security prompt)
387
- sql = self._generate_sql(question, schema_context)
388
-
389
- # Step 3a: Check for non-database questions
390
- if sql == "NOT_A_QUERY":
391
- return ("Hello! I'm DataBot, your ERP database assistant. "
392
- "Ask me questions about your business data and I'll "
393
- "query the database to find the answer for you.")
394
-
395
- # Step 3b: Check if the LLM blocked the query
396
- if sql == "SECURITY_BLOCK":
397
- return ("I'm sorry, but I cannot provide that information. "
398
- "Your request involves sensitive or personal data "
399
- "(such as salaries, passwords, phone numbers, or identity details) "
400
- "which I am not authorized to access.")
401
-
402
- print(f" → SQL: {sql}")
403
-
404
- # Step 3c: Validate SQL doesn't reference restricted tables/columns
405
- is_safe, reason = self._validate_sql_security(sql)
406
- if not is_safe:
407
- print(f" → SECURITY BLOCK: {reason}")
408
- return ("I'm sorry, but I cannot execute that query. "
409
- f"Security check: {reason}")
410
-
411
- # Step 4: Execute SQL (one fast query)
412
- columns, rows = self._execute_sql(sql)
413
-
414
- # Step 5: Summarize results (fast LLM call)
415
- return self._summarize_results(question, sql, columns, rows)
416
-
417
- except Exception as e:
418
  return f"DataBot Error: {str(e)}"
 
1
+ import os
2
+ import re
3
+ import time
4
+ import json
5
+ from dotenv import load_dotenv # pyre-ignore[21]
6
+ from sqlalchemy import create_engine, text # pyre-ignore[21]
7
+ from openai import OpenAI as OpenAIClient
8
+
9
+ load_dotenv()
10
+
11
+ # Config file paths
12
+ BASE_DIR = os.path.dirname(__file__)
13
+
14
+
15
+ def _load_json(path, name):
16
+ try:
17
+ with open(path, "r", encoding="utf-8") as f:
18
+ return json.load(f)
19
+ except (FileNotFoundError, json.JSONDecodeError) as e:
20
+ print(f" ✗ {name}: {e}")
21
+ return {}
22
+
23
+
24
+ class DataBot:
25
+ def __init__(self):
26
+ print("Loading configurations...")
27
+ self.db_cfg = _load_json(os.path.join(BASE_DIR, "db_config.json"), "db_config")
28
+ self.ai_cfg = _load_json(os.path.join(BASE_DIR, "ai_config.json"), "ai_config")
29
+ self.prompts = _load_json(os.path.join(BASE_DIR, "prompts_config.json"), "prompts_config")
30
+ self.access_cfg = _load_json(os.path.join(BASE_DIR, "data_access_config.json"), "data_access_config")
31
+
32
+ # Query limits
33
+ ql = self.db_cfg.get("query_limits", {})
34
+ self.MAX_ROWS = ql.get("max_rows", 100)
35
+ self.MAX_QUERY_TIME = ql.get("max_query_time_seconds", 30)
36
+ self.MAX_JOIN_TABLES = ql.get("max_join_tables", 3)
37
+
38
+ # Pre-cache restricted columns as a lowercase set (used on every column check)
39
+ self._restricted_cols = {c.lower() for c in self.access_cfg.get("restricted_columns", [])}
40
+
41
+ # AI model
42
+ self.model = self.ai_cfg.get("model", os.getenv("LLM_MODEL", "gpt-4o"))
43
+ self.client = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
44
+
45
+ # Database engine
46
+ conn_cfg = self.db_cfg.get("connection", {})
47
+ timeouts = self.db_cfg.get("timeouts", {})
48
+ pool = self.db_cfg.get("pool", {})
49
+ self.db_user = os.getenv("DB_USER")
50
+ self.db_pass = os.getenv("DB_PASSWORD")
51
+ self.db_host = os.getenv("DB_HOST", conn_cfg.get("host", "51.89.104.26"))
52
+ self.db_name = os.getenv("DB_NAME", conn_cfg.get("database", "dev_poly"))
53
+ self.port = conn_cfg.get("port", "3306")
54
+
55
+ self.engine = create_engine(
56
+ f"mysql+pymysql://{self.db_user}:{self.db_pass}@{self.db_host}:{self.port}/{self.db_name}?charset={conn_cfg.get('charset', 'utf8')}",
57
+ connect_args={
58
+ "connect_timeout": timeouts.get("connect_timeout", 30),
59
+ "read_timeout": timeouts.get("read_timeout", 60),
60
+ "write_timeout": timeouts.get("write_timeout", 60),
61
+ },
62
+ pool_pre_ping=pool.get("pool_pre_ping", True),
63
+ pool_recycle=pool.get("pool_recycle", 300),
64
+ )
65
+
66
+ # Load and filter schema
67
+ print("Loading database schema...")
68
+ schema_cfg = self.db_cfg.get("schema_loading", {})
69
+ raw = self._load_schema(schema_cfg.get("max_retries", 3), schema_cfg.get("retry_delay_seconds", 5))
70
+ self.schema_info = self._filter_schema(raw)
71
+ print(f"Loaded {len(self.schema_info)} accessible tables (from {len(raw)} total).")
72
+
73
+ # ── Schema ────────────────────────────────────────────────────────
74
+
75
+ def _load_schema(self, retries=3, delay=5):
76
+ for attempt in range(1, retries + 1):
77
+ try:
78
+ schema = {}
79
+ with self.engine.connect() as conn:
80
+ rows = conn.execute(text(
81
+ "SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE "
82
+ "FROM INFORMATION_SCHEMA.COLUMNS "
83
+ "WHERE TABLE_SCHEMA = :db ORDER BY TABLE_NAME, ORDINAL_POSITION"
84
+ ), {"db": self.db_name})
85
+ for r in rows:
86
+ schema.setdefault(r[0], []).append(f"{r[1]} ({r[2]})")
87
+ return schema
88
+ except Exception as e:
89
+ if attempt < retries:
90
+ print(f" ✗ Attempt {attempt}/{retries} failed, retrying in {delay}s...")
91
+ time.sleep(delay)
92
+ else:
93
+ print(f"ERROR: Cannot connect to {self.db_host}:{self.port}/{self.db_name}")
94
+ raise SystemExit(1) from e
95
+ return {}
96
+
97
+ def _filter_schema(self, raw):
98
+ if not self.access_cfg:
99
+ return raw
100
+ filtered = {}
101
+ blocked = 0
102
+ for table, cols in raw.items():
103
+ if not self._table_allowed(table):
104
+ blocked += 1
105
+ continue
106
+ safe = [c for c in cols if self._column_allowed(c.split(" (")[0].strip())]
107
+ if safe:
108
+ filtered[table] = safe
109
+ if blocked:
110
+ print(f" → Blocked {blocked} restricted tables.")
111
+ return filtered
112
+
113
+ def _table_allowed(self, name):
114
+ if not self.access_cfg:
115
+ return True
116
+ t = name.lower()
117
+ for p in self.access_cfg.get("restricted_table_prefixes", []):
118
+ if t.startswith(p.lower()):
119
+ return False
120
+ for p in self.access_cfg.get("allowed_table_prefixes", []):
121
+ if t.startswith(p.lower()):
122
+ return True
123
+ return False
124
+
125
+ def _column_allowed(self, name):
126
+ if not self.access_cfg:
127
+ return True
128
+ return name.lower() not in self._restricted_cols
129
+
130
+ # ── Security & Limits ─────────────────────────────────────────────
131
+
132
+ def _validate_security(self, sql):
133
+ if not self.access_cfg:
134
+ return True, ""
135
+ sql_up = sql.upper()
136
+ for op in ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE"):
137
+ if re.search(rf'\b{op}\b', sql_up):
138
+ return False, f"Write operation '{op}' is not allowed."
139
+ sql_lo = sql.lower()
140
+ for prefix in self.access_cfg.get("restricted_table_prefixes", []):
141
+ if re.search(rf'\b{re.escape(prefix.lower())}\w*\b', sql_lo):
142
+ return False, f"Restricted data ('{prefix}*' tables). Access denied."
143
+ for col in self.access_cfg.get("restricted_columns", []):
144
+ if re.search(rf'\b{re.escape(col.lower())}\b', sql_lo):
145
+ return False, f"Restricted column '{col}'. Access denied."
146
+ return True, ""
147
+
148
+ def _validate_complexity(self, sql):
149
+ sql_up = sql.upper()
150
+ if "CROSS JOIN" in sql_up:
151
+ return False, "CROSS JOIN is not allowed."
152
+ if len(re.findall(r'\bJOIN\b', sql_up)) > self.MAX_JOIN_TABLES:
153
+ return False, f"Too many JOINs (max {self.MAX_JOIN_TABLES}). Simplify your question."
154
+ if re.search(r'SELECT\s+\*', sql_up) and not re.search(r'SELECT\s+COUNT\s*\(\s*\*\s*\)', sql_up):
155
+ return False, "SELECT * is not allowed. Specific columns must be selected."
156
+ has_where = bool(re.search(r'\bWHERE\b', sql_up))
157
+ has_agg = bool(re.search(r'SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up))
158
+ has_group = bool(re.search(r'\bGROUP\s+BY\b', sql_up))
159
+ if not has_where and not has_agg and not has_group:
160
+ return False, "No WHERE clause or aggregation. Add filters to your question."
161
+ return True, ""
162
+
163
+ def _enforce_limit(self, sql):
164
+ sql_up = sql.upper().strip()
165
+ # Skip pure aggregates without GROUP BY
166
+ if re.search(r'^SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up) and not re.search(r'\bGROUP\s+BY\b', sql_up):
167
+ return sql
168
+ m = re.search(r'\bLIMIT\s+(\d+)', sql_up)
169
+ if m:
170
+ if int(m.group(1)) > self.MAX_ROWS:
171
+ sql = re.sub(r'\bLIMIT\s+\d+', f'LIMIT {self.MAX_ROWS}', sql, flags=re.IGNORECASE)
172
+ return sql
173
+ return f"{sql.rstrip()} LIMIT {self.MAX_ROWS}"
174
+
175
+ # ── Prompt Helper ─────────────────────────────────────────────────
176
+
177
+ def _prompt(self, key, **kw):
178
+ t = self.prompts.get(key, "")
179
+ if not t:
180
+ print(f" ✗ WARNING: prompt '{key}' not found in prompts_config.json")
181
+ return ""
182
+ try:
183
+ return t.format(**kw)
184
+ except KeyError as e:
185
+ print(f" ✗ WARNING: missing placeholder {e} in prompt '{key}'")
186
+ return t
187
+
188
+ # ── LLM Pipeline ─────────────────────────────────────────────────
189
+
190
+ def _pick_tables(self, question):
191
+ cfg = self.ai_cfg.get("table_picker", {})
192
+ max_t = cfg.get("max_tables", 5)
193
+ names = list(self.schema_info.keys())
194
+ resp = self.client.chat.completions.create(
195
+ model=self.model,
196
+ temperature=cfg.get("temperature", 0),
197
+ max_tokens=cfg.get("max_tokens", 200),
198
+ messages=[
199
+ {"role": "system", "content": self._prompt("table_picker_system")},
200
+ {"role": "user", "content": self._prompt("table_picker_user",
201
+ db_name=self.db_name, table_list=", ".join(names),
202
+ question=question, max_tables=max_t)},
203
+ ]
204
+ )
205
+ picked = [t.strip().strip("'\"` ") for t in (resp.choices[0].message.content or "").split(",")]
206
+ valid = [t for t in picked if t in self.schema_info]
207
+ return valid or names[:max_t]
208
+
209
+ def _generate_sql(self, question, schema_ctx):
210
+ cfg = self.ai_cfg.get("sql_generator", {})
211
+ resp = self.client.chat.completions.create(
212
+ model=self.model,
213
+ temperature=cfg.get("temperature", 0),
214
+ max_tokens=cfg.get("max_tokens", 500),
215
+ messages=[
216
+ {"role": "system", "content": self._prompt("sql_generator_system",
217
+ db_name=self.db_name, max_rows=self.MAX_ROWS, max_join_tables=self.MAX_JOIN_TABLES)},
218
+ {"role": "user", "content": self._prompt("sql_generator_user",
219
+ schema_context=schema_ctx, question=question)},
220
+ ]
221
+ )
222
+ sql = (resp.choices[0].message.content or "").strip()
223
+ if "SECURITY_BLOCK" in sql.upper():
224
+ return "SECURITY_BLOCK"
225
+ if "NOT_A_QUERY" in sql.upper():
226
+ return "NOT_A_QUERY"
227
+ sql = sql.replace("```sql", "").replace("```", "").strip()
228
+ if ";" in sql:
229
+ sql = sql.split(";")[0].strip()
230
+ return sql
231
+
232
+ def _execute(self, sql):
233
+ with self.engine.connect() as conn:
234
+ # Try setting query timeout (MariaDB vs MySQL have different syntax)
235
+ try:
236
+ conn.execute(text(f"SET SESSION max_statement_time = {self.MAX_QUERY_TIME}"))
237
+ except Exception:
238
+ try:
239
+ conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {self.MAX_QUERY_TIME * 1000}"))
240
+ except Exception:
241
+ pass # Neither supported — LIMIT and row cap still protect us
242
+ result = conn.execute(text(sql))
243
+ cols = list(result.keys())
244
+ batch = result.fetchmany(self.MAX_ROWS + 1)
245
+ rows = [dict(zip(cols, r)) for r in batch[:self.MAX_ROWS]]
246
+ if len(batch) > self.MAX_ROWS:
247
+ print(f" → Capped at {self.MAX_ROWS} rows")
248
+ return cols, rows
249
+
250
+ def _summarize(self, question, sql, cols, rows):
251
+ cfg = self.ai_cfg.get("summarizer", {})
252
+ max_disp = cfg.get("max_display_rows", 50)
253
+ shown = rows[:max_disp]
254
+ result_text = f"Columns: {cols}\nRows ({len(rows)} total"
255
+ if len(rows) > max_disp:
256
+ result_text += f", showing first {max_disp}"
257
+ result_text += "):\n" + "\n".join(str(r) for r in shown)
258
+
259
+ resp = self.client.chat.completions.create(
260
+ model=self.model,
261
+ temperature=cfg.get("temperature", 0.3),
262
+ max_tokens=cfg.get("max_tokens", 2000),
263
+ messages=[
264
+ {"role": "system", "content": self._prompt("summarizer_system", db_name=self.db_name)},
265
+ {"role": "user", "content": self._prompt("summarizer_user",
266
+ question=question, sql=sql, result_text=result_text)},
267
+ ]
268
+ )
269
+ return (resp.choices[0].message.content or "").strip()
270
+
271
+ # ── Main Entry ────────────────────────────────────────────────────
272
+
273
+ def ask(self, question):
274
+ try:
275
+ tables = self._pick_tables(question)
276
+ print(f" → Tables: {', '.join(tables)}")
277
+
278
+ schema_ctx = "\n".join(
279
+ f"Table '{t}': {', '.join(self.schema_info[t])}"
280
+ for t in tables if t in self.schema_info
281
+ )
282
+
283
+ sql = self._generate_sql(question, schema_ctx)
284
+
285
+ responses = self.prompts.get("responses", {})
286
+ if sql == "NOT_A_QUERY":
287
+ return responses.get("not_a_query", "I'm DataBot. Ask me about your business data.")
288
+ if sql == "SECURITY_BLOCK":
289
+ return responses.get("security_block", "Access denied: sensitive data requested.")
290
+
291
+ print(f" → SQL: {sql}")
292
+
293
+ ok, reason = self._validate_security(sql)
294
+ if not ok:
295
+ print(f" → BLOCKED: {reason}")
296
+ return responses.get("security_check_fail", "Query blocked: {reason}").format(reason=reason)
297
+
298
+ ok, reason = self._validate_complexity(sql)
299
+ if not ok:
300
+ print(f" → BLOCKED: {reason}")
301
+ return responses.get("complexity_fail", "Query too complex: {reason}").format(reason=reason)
302
+
303
+ sql = self._enforce_limit(sql)
304
+ print(f" → Final: {sql}")
305
+
306
+ cols, rows = self._execute(sql)
307
+ return self._summarize(question, sql, cols, rows)
308
+
309
+ except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  return f"DataBot Error: {str(e)}"