import os
import re
import time
import json

from dotenv import load_dotenv  # pyre-ignore[21]
from sqlalchemy import create_engine, text  # pyre-ignore[21]
from openai import OpenAI as OpenAIClient

load_dotenv()

# Config file paths
BASE_DIR = os.path.dirname(__file__)


def _load_json(path, name):
    """Load a JSON config file.

    Returns the parsed object, or {} (after printing a warning) when the file
    is missing or contains invalid JSON — callers treat an empty config as
    "feature disabled / use defaults".
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f" ✗ {name}: {e}")
        return {}


class DataBot:
    """Natural-language question → validated read-only SQL → summarized answer.

    Pipeline (see ask()): pick relevant tables with an LLM, generate SQL with
    an LLM, validate it (security + complexity), enforce a row LIMIT, execute
    against MySQL/MariaDB, then summarize the result set with an LLM.
    Configuration comes from four JSON files next to this module plus
    environment variables (credentials, overrides).
    """

    def __init__(self):
        print("Loading configurations...")
        self.db_cfg = _load_json(os.path.join(BASE_DIR, "db_config.json"), "db_config")
        self.ai_cfg = _load_json(os.path.join(BASE_DIR, "ai_config.json"), "ai_config")
        self.prompts = _load_json(os.path.join(BASE_DIR, "prompts_config.json"), "prompts_config")
        self.access_cfg = _load_json(
            os.path.join(BASE_DIR, "data_access_config.json"), "data_access_config"
        )

        # Query limits
        ql = self.db_cfg.get("query_limits", {})
        self.MAX_ROWS = ql.get("max_rows", 100)
        self.MAX_QUERY_TIME = ql.get("max_query_time_seconds", 30)
        self.MAX_JOIN_TABLES = ql.get("max_join_tables", 3)

        # Pre-cache restricted columns as a lowercase set (used on every column check)
        self._restricted_cols = {
            c.lower() for c in self.access_cfg.get("restricted_columns", [])
        }

        # AI model
        self.model = self.ai_cfg.get("model", os.getenv("LLM_MODEL", "gpt-4o"))
        self.client = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))

        # Database engine — env vars override JSON config for host/name/port.
        # NOTE(review): the hard-coded fallback host/database below look like a
        # dev environment leak; consider removing the literals from code.
        conn_cfg = self.db_cfg.get("connection", {})
        timeouts = self.db_cfg.get("timeouts", {})
        pool = self.db_cfg.get("pool", {})
        self.db_user = os.getenv("DB_USER")
        self.db_pass = os.getenv("DB_PASSWORD")
        self.db_host = os.getenv("DB_HOST", conn_cfg.get("host", "51.89.104.26"))
        self.db_name = os.getenv("DB_NAME", conn_cfg.get("database", "dev_poly"))
        # FIX: honour DB_PORT from the environment, consistent with DB_HOST/DB_NAME
        # (previously only db_config.json was consulted for the port).
        self.port = os.getenv("DB_PORT", conn_cfg.get("port", "3306"))
        self.engine = create_engine(
            f"mysql+pymysql://{self.db_user}:{self.db_pass}@{self.db_host}:{self.port}"
            f"/{self.db_name}?charset={conn_cfg.get('charset', 'utf8')}",
            connect_args={
                "connect_timeout": timeouts.get("connect_timeout", 30),
                "read_timeout": timeouts.get("read_timeout", 60),
                "write_timeout": timeouts.get("write_timeout", 60),
            },
            pool_pre_ping=pool.get("pool_pre_ping", True),
            pool_recycle=pool.get("pool_recycle", 300),
        )

        # Load and filter schema
        print("Loading database schema...")
        schema_cfg = self.db_cfg.get("schema_loading", {})
        raw = self._load_schema(
            schema_cfg.get("max_retries", 3), schema_cfg.get("retry_delay_seconds", 5)
        )
        self.schema_info = self._filter_schema(raw)
        print(f"Loaded {len(self.schema_info)} accessible tables (from {len(raw)} total).")

    # ── Schema ────────────────────────────────────────────────────────
    def _load_schema(self, retries=3, delay=5):
        """Read table/column metadata from INFORMATION_SCHEMA.

        Returns {table_name: ["col (type)", ...]}. Retries transient
        connection failures; exits the process (SystemExit) when all
        attempts are exhausted, since nothing works without a schema.
        """
        for attempt in range(1, retries + 1):
            try:
                schema = {}
                with self.engine.connect() as conn:
                    rows = conn.execute(text(
                        "SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE "
                        "FROM INFORMATION_SCHEMA.COLUMNS "
                        "WHERE TABLE_SCHEMA = :db ORDER BY TABLE_NAME, ORDINAL_POSITION"
                    ), {"db": self.db_name})
                    for r in rows:
                        schema.setdefault(r[0], []).append(f"{r[1]} ({r[2]})")
                return schema
            except Exception as e:
                if attempt < retries:
                    print(f" ✗ Attempt {attempt}/{retries} failed, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    print(f"ERROR: Cannot connect to {self.db_host}:{self.port}/{self.db_name}")
                    raise SystemExit(1) from e
        return {}

    def _filter_schema(self, raw):
        """Drop restricted tables and columns from the raw schema map."""
        if not self.access_cfg:
            return raw
        filtered = {}
        blocked = 0
        for table, cols in raw.items():
            if not self._table_allowed(table):
                blocked += 1
                continue
            # Column entries look like "name (type)" — check only the name part.
            safe = [c for c in cols if self._column_allowed(c.split(" (")[0].strip())]
            if safe:
                filtered[table] = safe
        if blocked:
            print(f" → Blocked {blocked} restricted tables.")
        return filtered

    def _table_allowed(self, name):
        """True when the table name passes the prefix allow/deny lists.

        Deny-prefixes win over allow-prefixes; a name matching neither list
        is denied (allow-list semantics).
        """
        if not self.access_cfg:
            return True
        t = name.lower()
        for p in self.access_cfg.get("restricted_table_prefixes", []):
            if t.startswith(p.lower()):
                return False
        for p in self.access_cfg.get("allowed_table_prefixes", []):
            if t.startswith(p.lower()):
                return True
        return False

    def _column_allowed(self, name):
        """True when the column name is not in the restricted-columns set."""
        if not self.access_cfg:
            return True
        return name.lower() not in self._restricted_cols

    # ── Security & Limits ─────────────────────────────────────────────
    def _validate_security(self, sql):
        """Reject write operations and references to restricted tables/columns.

        Returns (ok, reason). FIX: the write-operation check now runs
        unconditionally — previously a missing/empty data_access_config.json
        disabled ALL security validation, including the DML/DDL ban.
        """
        sql_up = sql.upper()
        for op in ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE"):
            if re.search(rf'\b{op}\b', sql_up):
                return False, f"Write operation '{op}' is not allowed."
        # Table/column restrictions only exist when an access config is loaded.
        if not self.access_cfg:
            return True, ""
        sql_lo = sql.lower()
        for prefix in self.access_cfg.get("restricted_table_prefixes", []):
            if re.search(rf'\b{re.escape(prefix.lower())}\w*\b', sql_lo):
                return False, f"Restricted data ('{prefix}*' tables). Access denied."
        for col in self.access_cfg.get("restricted_columns", []):
            if re.search(rf'\b{re.escape(col.lower())}\b', sql_lo):
                return False, f"Restricted column '{col}'. Access denied."
        return True, ""

    def _validate_complexity(self, sql):
        """Reject queries that are too expensive or unfocused. Returns (ok, reason)."""
        sql_up = sql.upper()
        if "CROSS JOIN" in sql_up:
            return False, "CROSS JOIN is not allowed."
        if len(re.findall(r'\bJOIN\b', sql_up)) > self.MAX_JOIN_TABLES:
            return False, f"Too many JOINs (max {self.MAX_JOIN_TABLES}). Simplify your question."
        # SELECT * is banned, but COUNT(*) is fine.
        if re.search(r'SELECT\s+\*', sql_up) and not re.search(r'SELECT\s+COUNT\s*\(\s*\*\s*\)', sql_up):
            return False, "SELECT * is not allowed. Specific columns must be selected."
        has_where = bool(re.search(r'\bWHERE\b', sql_up))
        has_agg = bool(re.search(r'SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up))
        has_group = bool(re.search(r'\bGROUP\s+BY\b', sql_up))
        if not has_where and not has_agg and not has_group:
            return False, "No WHERE clause or aggregation. Add filters to your question."
        return True, ""

    def _enforce_limit(self, sql):
        """Cap or append LIMIT so at most MAX_ROWS rows come back.

        Pure aggregates without GROUP BY return a single row, so they are
        left untouched.
        """
        sql_up = sql.upper().strip()
        # Skip pure aggregates without GROUP BY
        if re.search(r'^SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up) and not re.search(r'\bGROUP\s+BY\b', sql_up):
            return sql
        m = re.search(r'\bLIMIT\s+(\d+)', sql_up)
        if m:
            if int(m.group(1)) > self.MAX_ROWS:
                sql = re.sub(r'\bLIMIT\s+\d+', f'LIMIT {self.MAX_ROWS}', sql, flags=re.IGNORECASE)
            return sql
        return f"{sql.rstrip()} LIMIT {self.MAX_ROWS}"

    # ── Prompt Helper ─────────────────────────────────────────────────
    def _prompt(self, key, **kw):
        """Fetch prompt template `key` and fill its {placeholders} from kw.

        Warns (and degrades gracefully) when the template is missing or has
        an unfilled placeholder.
        """
        t = self.prompts.get(key, "")
        if not t:
            print(f" ✗ WARNING: prompt '{key}' not found in prompts_config.json")
            return ""
        try:
            return t.format(**kw)
        except KeyError as e:
            print(f" ✗ WARNING: missing placeholder {e} in prompt '{key}'")
            return t

    # ── LLM Pipeline ─────────────────────────────────────────────────
    def _pick_tables(self, question):
        """Ask the LLM which tables are relevant to the question.

        Falls back to the first max_tables accessible tables when the model
        returns nothing usable.
        """
        cfg = self.ai_cfg.get("table_picker", {})
        max_t = cfg.get("max_tables", 5)
        names = list(self.schema_info.keys())
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0),
            max_tokens=cfg.get("max_tokens", 200),
            messages=[
                {"role": "system", "content": self._prompt("table_picker_system")},
                {"role": "user", "content": self._prompt(
                    "table_picker_user",
                    db_name=self.db_name,
                    table_list=", ".join(names),
                    question=question,
                    max_tables=max_t,
                )},
            ]
        )
        picked = [
            t.strip().strip("'\"` ")
            for t in (resp.choices[0].message.content or "").split(",")
        ]
        # Only keep names that actually exist in the filtered schema.
        valid = [t for t in picked if t in self.schema_info]
        return valid or names[:max_t]

    def _generate_sql(self, question, schema_ctx):
        """Ask the LLM for a single SQL statement (or a refusal sentinel).

        Returns "SECURITY_BLOCK" / "NOT_A_QUERY" pass-throughs, otherwise the
        SQL with markdown fences stripped and truncated at the first ';'.
        """
        cfg = self.ai_cfg.get("sql_generator", {})
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0),
            max_tokens=cfg.get("max_tokens", 500),
            messages=[
                {"role": "system", "content": self._prompt(
                    "sql_generator_system",
                    db_name=self.db_name,
                    max_rows=self.MAX_ROWS,
                    max_join_tables=self.MAX_JOIN_TABLES,
                )},
                {"role": "user", "content": self._prompt(
                    "sql_generator_user", schema_context=schema_ctx, question=question
                )},
            ]
        )
        sql = (resp.choices[0].message.content or "").strip()
        if "SECURITY_BLOCK" in sql.upper():
            return "SECURITY_BLOCK"
        if "NOT_A_QUERY" in sql.upper():
            return "NOT_A_QUERY"
        sql = sql.replace("```sql", "").replace("```", "").strip()
        if ";" in sql:
            sql = sql.split(";")[0].strip()
        return sql

    def _execute(self, sql):
        """Run the validated SQL and return (column_names, rows_as_dicts).

        Fetches MAX_ROWS + 1 rows to detect (and report) truncation without
        pulling the whole result set.
        """
        with self.engine.connect() as conn:
            # Try setting a query timeout (MariaDB vs MySQL have different syntax).
            # FIX: coerce the config value through int() so a malformed config
            # entry cannot inject SQL text into the SET statement.
            try:
                conn.execute(text(f"SET SESSION max_statement_time = {int(self.MAX_QUERY_TIME)}"))
            except Exception:
                try:
                    conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {int(self.MAX_QUERY_TIME) * 1000}"))
                except Exception:
                    pass  # Neither supported — LIMIT and row cap still protect us
            result = conn.execute(text(sql))
            cols = list(result.keys())
            batch = result.fetchmany(self.MAX_ROWS + 1)
            rows = [dict(zip(cols, r)) for r in batch[:self.MAX_ROWS]]
            if len(batch) > self.MAX_ROWS:
                print(f" → Capped at {self.MAX_ROWS} rows")
            return cols, rows

    def _summarize(self, question, sql, cols, rows):
        """Ask the LLM to turn the result set into a natural-language answer."""
        cfg = self.ai_cfg.get("summarizer", {})
        max_disp = cfg.get("max_display_rows", 50)
        shown = rows[:max_disp]
        result_text = f"Columns: {cols}\nRows ({len(rows)} total"
        if len(rows) > max_disp:
            result_text += f", showing first {max_disp}"
        result_text += "):\n" + "\n".join(str(r) for r in shown)
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0.3),
            max_tokens=cfg.get("max_tokens", 2000),
            messages=[
                {"role": "system", "content": self._prompt("summarizer_system", db_name=self.db_name)},
                {"role": "user", "content": self._prompt(
                    "summarizer_user", question=question, sql=sql, result_text=result_text
                )},
            ]
        )
        return (resp.choices[0].message.content or "").strip()

    # ── Main Entry ────────────────────────────────────────────────────
    def ask(self, question):
        """Answer a natural-language question about the database.

        Full pipeline: pick tables → generate SQL → validate → limit →
        execute → summarize. Any unexpected failure is returned as an error
        string rather than raised, so callers always get text back.
        """
        try:
            tables = self._pick_tables(question)
            print(f" → Tables: {', '.join(tables)}")
            schema_ctx = "\n".join(
                f"Table '{t}': {', '.join(self.schema_info[t])}"
                for t in tables if t in self.schema_info
            )
            sql = self._generate_sql(question, schema_ctx)
            responses = self.prompts.get("responses", {})
            if sql == "NOT_A_QUERY":
                return responses.get("not_a_query", "I'm DataBot. Ask me about your business data.")
            if sql == "SECURITY_BLOCK":
                return responses.get("security_block", "Access denied: sensitive data requested.")
            print(f" → SQL: {sql}")
            ok, reason = self._validate_security(sql)
            if not ok:
                print(f" → BLOCKED: {reason}")
                return responses.get("security_check_fail", "Query blocked: {reason}").format(reason=reason)
            ok, reason = self._validate_complexity(sql)
            if not ok:
                print(f" → BLOCKED: {reason}")
                return responses.get("complexity_fail", "Query too complex: {reason}").format(reason=reason)
            sql = self._enforce_limit(sql)
            print(f" → Final: {sql}")
            cols, rows = self._execute(sql)
            return self._summarize(question, sql, cols, rows)
        except Exception as e:
            return f"DataBot Error: {str(e)}"