# Capture artifact from the original paste (kept for provenance):
# Spaces:
# Runtime error
# Runtime error
import os
import re
import time
import json

from dotenv import load_dotenv  # pyre-ignore[21]
from sqlalchemy import create_engine, text  # pyre-ignore[21]
from openai import OpenAI as OpenAIClient

# Pull DB credentials / API keys from a local .env file into the environment.
load_dotenv()

# All JSON config files are resolved relative to this module's directory.
BASE_DIR = os.path.dirname(__file__)
| def _load_json(path, name): | |
| try: | |
| with open(path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| except (FileNotFoundError, json.JSONDecodeError) as e: | |
| print(f" β {name}: {e}") | |
| return {} | |
class DataBot:
    """Natural-language question-answering over a MySQL database.

    Loads JSON configs, builds an OpenAI client and a SQLAlchemy engine,
    then caches a filtered (access-controlled) view of the DB schema.
    """

    def __init__(self):
        # Local import: only needed here, keeps the edit self-contained.
        from urllib.parse import quote_plus

        print("Loading configurations...")
        self.db_cfg = _load_json(os.path.join(BASE_DIR, "db_config.json"), "db_config")
        self.ai_cfg = _load_json(os.path.join(BASE_DIR, "ai_config.json"), "ai_config")
        self.prompts = _load_json(os.path.join(BASE_DIR, "prompts_config.json"), "prompts_config")
        self.access_cfg = _load_json(os.path.join(BASE_DIR, "data_access_config.json"), "data_access_config")

        # Query limits: row cap, statement timeout, join budget.
        ql = self.db_cfg.get("query_limits", {})
        self.MAX_ROWS = ql.get("max_rows", 100)
        self.MAX_QUERY_TIME = ql.get("max_query_time_seconds", 30)
        self.MAX_JOIN_TABLES = ql.get("max_join_tables", 3)

        # Pre-cache restricted columns as a lowercase set (checked per column).
        self._restricted_cols = {c.lower() for c in self.access_cfg.get("restricted_columns", [])}

        # AI model and client.
        self.model = self.ai_cfg.get("model", os.getenv("LLM_MODEL", "gpt-4o"))
        self.client = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))

        # Database engine — environment variables override config-file values.
        conn_cfg = self.db_cfg.get("connection", {})
        timeouts = self.db_cfg.get("timeouts", {})
        pool = self.db_cfg.get("pool", {})
        self.db_user = os.getenv("DB_USER")
        self.db_pass = os.getenv("DB_PASSWORD")
        self.db_host = os.getenv("DB_HOST", conn_cfg.get("host", "51.89.104.26"))
        self.db_name = os.getenv("DB_NAME", conn_cfg.get("database", "dev_poly"))
        self.port = conn_cfg.get("port", "3306")
        # FIX: URL-quote the credentials — a password containing '@', ':' or
        # '/' would otherwise corrupt the connection URL.
        user_q = quote_plus(self.db_user or "")
        pass_q = quote_plus(self.db_pass or "")
        self.engine = create_engine(
            f"mysql+pymysql://{user_q}:{pass_q}@{self.db_host}:{self.port}/{self.db_name}"
            f"?charset={conn_cfg.get('charset', 'utf8')}",
            connect_args={
                "connect_timeout": timeouts.get("connect_timeout", 30),
                "read_timeout": timeouts.get("read_timeout", 60),
                "write_timeout": timeouts.get("write_timeout", 60),
            },
            pool_pre_ping=pool.get("pool_pre_ping", True),
            pool_recycle=pool.get("pool_recycle", 300),
        )

        # Load and filter the schema once up front so every request reuses it.
        print("Loading database schema...")
        schema_cfg = self.db_cfg.get("schema_loading", {})
        raw = self._load_schema(schema_cfg.get("max_retries", 3), schema_cfg.get("retry_delay_seconds", 5))
        self.schema_info = self._filter_schema(raw)
        print(f"Loaded {len(self.schema_info)} accessible tables (from {len(raw)} total).")
| # ββ Schema ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _load_schema(self, retries=3, delay=5): | |
| for attempt in range(1, retries + 1): | |
| try: | |
| schema = {} | |
| with self.engine.connect() as conn: | |
| rows = conn.execute(text( | |
| "SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE " | |
| "FROM INFORMATION_SCHEMA.COLUMNS " | |
| "WHERE TABLE_SCHEMA = :db ORDER BY TABLE_NAME, ORDINAL_POSITION" | |
| ), {"db": self.db_name}) | |
| for r in rows: | |
| schema.setdefault(r[0], []).append(f"{r[1]} ({r[2]})") | |
| return schema | |
| except Exception as e: | |
| if attempt < retries: | |
| print(f" β Attempt {attempt}/{retries} failed, retrying in {delay}s...") | |
| time.sleep(delay) | |
| else: | |
| print(f"ERROR: Cannot connect to {self.db_host}:{self.port}/{self.db_name}") | |
| raise SystemExit(1) from e | |
| return {} | |
| def _filter_schema(self, raw): | |
| if not self.access_cfg: | |
| return raw | |
| filtered = {} | |
| blocked = 0 | |
| for table, cols in raw.items(): | |
| if not self._table_allowed(table): | |
| blocked += 1 | |
| continue | |
| safe = [c for c in cols if self._column_allowed(c.split(" (")[0].strip())] | |
| if safe: | |
| filtered[table] = safe | |
| if blocked: | |
| print(f" β Blocked {blocked} restricted tables.") | |
| return filtered | |
| def _table_allowed(self, name): | |
| if not self.access_cfg: | |
| return True | |
| t = name.lower() | |
| for p in self.access_cfg.get("restricted_table_prefixes", []): | |
| if t.startswith(p.lower()): | |
| return False | |
| for p in self.access_cfg.get("allowed_table_prefixes", []): | |
| if t.startswith(p.lower()): | |
| return True | |
| return False | |
| def _column_allowed(self, name): | |
| if not self.access_cfg: | |
| return True | |
| return name.lower() not in self._restricted_cols | |
| # ββ Security & Limits βββββββββββββββββββββββββββββββββββββββββββββ | |
| def _validate_security(self, sql): | |
| if not self.access_cfg: | |
| return True, "" | |
| sql_up = sql.upper() | |
| for op in ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE"): | |
| if re.search(rf'\b{op}\b', sql_up): | |
| return False, f"Write operation '{op}' is not allowed." | |
| sql_lo = sql.lower() | |
| for prefix in self.access_cfg.get("restricted_table_prefixes", []): | |
| if re.search(rf'\b{re.escape(prefix.lower())}\w*\b', sql_lo): | |
| return False, f"Restricted data ('{prefix}*' tables). Access denied." | |
| for col in self.access_cfg.get("restricted_columns", []): | |
| if re.search(rf'\b{re.escape(col.lower())}\b', sql_lo): | |
| return False, f"Restricted column '{col}'. Access denied." | |
| return True, "" | |
| def _validate_complexity(self, sql): | |
| sql_up = sql.upper() | |
| if "CROSS JOIN" in sql_up: | |
| return False, "CROSS JOIN is not allowed." | |
| if len(re.findall(r'\bJOIN\b', sql_up)) > self.MAX_JOIN_TABLES: | |
| return False, f"Too many JOINs (max {self.MAX_JOIN_TABLES}). Simplify your question." | |
| if re.search(r'SELECT\s+\*', sql_up) and not re.search(r'SELECT\s+COUNT\s*\(\s*\*\s*\)', sql_up): | |
| return False, "SELECT * is not allowed. Specific columns must be selected." | |
| has_where = bool(re.search(r'\bWHERE\b', sql_up)) | |
| has_agg = bool(re.search(r'SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up)) | |
| has_group = bool(re.search(r'\bGROUP\s+BY\b', sql_up)) | |
| if not has_where and not has_agg and not has_group: | |
| return False, "No WHERE clause or aggregation. Add filters to your question." | |
| return True, "" | |
| def _enforce_limit(self, sql): | |
| sql_up = sql.upper().strip() | |
| # Skip pure aggregates without GROUP BY | |
| if re.search(r'^SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up) and not re.search(r'\bGROUP\s+BY\b', sql_up): | |
| return sql | |
| m = re.search(r'\bLIMIT\s+(\d+)', sql_up) | |
| if m: | |
| if int(m.group(1)) > self.MAX_ROWS: | |
| sql = re.sub(r'\bLIMIT\s+\d+', f'LIMIT {self.MAX_ROWS}', sql, flags=re.IGNORECASE) | |
| return sql | |
| return f"{sql.rstrip()} LIMIT {self.MAX_ROWS}" | |
| # ββ Prompt Helper βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _prompt(self, key, **kw): | |
| t = self.prompts.get(key, "") | |
| if not t: | |
| print(f" β WARNING: prompt '{key}' not found in prompts_config.json") | |
| return "" | |
| try: | |
| return t.format(**kw) | |
| except KeyError as e: | |
| print(f" β WARNING: missing placeholder {e} in prompt '{key}'") | |
| return t | |
| # ββ LLM Pipeline βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _pick_tables(self, question): | |
| cfg = self.ai_cfg.get("table_picker", {}) | |
| max_t = cfg.get("max_tables", 5) | |
| names = list(self.schema_info.keys()) | |
| resp = self.client.chat.completions.create( | |
| model=self.model, | |
| temperature=cfg.get("temperature", 0), | |
| max_tokens=cfg.get("max_tokens", 200), | |
| messages=[ | |
| {"role": "system", "content": self._prompt("table_picker_system")}, | |
| {"role": "user", "content": self._prompt("table_picker_user", | |
| db_name=self.db_name, table_list=", ".join(names), | |
| question=question, max_tables=max_t)}, | |
| ] | |
| ) | |
| picked = [t.strip().strip("'\"` ") for t in (resp.choices[0].message.content or "").split(",")] | |
| valid = [t for t in picked if t in self.schema_info] | |
| return valid or names[:max_t] | |
| def _generate_sql(self, question, schema_ctx): | |
| cfg = self.ai_cfg.get("sql_generator", {}) | |
| resp = self.client.chat.completions.create( | |
| model=self.model, | |
| temperature=cfg.get("temperature", 0), | |
| max_tokens=cfg.get("max_tokens", 500), | |
| messages=[ | |
| {"role": "system", "content": self._prompt("sql_generator_system", | |
| db_name=self.db_name, max_rows=self.MAX_ROWS, max_join_tables=self.MAX_JOIN_TABLES)}, | |
| {"role": "user", "content": self._prompt("sql_generator_user", | |
| schema_context=schema_ctx, question=question)}, | |
| ] | |
| ) | |
| sql = (resp.choices[0].message.content or "").strip() | |
| if "SECURITY_BLOCK" in sql.upper(): | |
| return "SECURITY_BLOCK" | |
| if "NOT_A_QUERY" in sql.upper(): | |
| return "NOT_A_QUERY" | |
| sql = sql.replace("```sql", "").replace("```", "").strip() | |
| if ";" in sql: | |
| sql = sql.split(";")[0].strip() | |
| return sql | |
| def _execute(self, sql): | |
| with self.engine.connect() as conn: | |
| # Try setting query timeout (MariaDB vs MySQL have different syntax) | |
| try: | |
| conn.execute(text(f"SET SESSION max_statement_time = {self.MAX_QUERY_TIME}")) | |
| except Exception: | |
| try: | |
| conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {self.MAX_QUERY_TIME * 1000}")) | |
| except Exception: | |
| pass # Neither supported β LIMIT and row cap still protect us | |
| result = conn.execute(text(sql)) | |
| cols = list(result.keys()) | |
| batch = result.fetchmany(self.MAX_ROWS + 1) | |
| rows = [dict(zip(cols, r)) for r in batch[:self.MAX_ROWS]] | |
| if len(batch) > self.MAX_ROWS: | |
| print(f" β Capped at {self.MAX_ROWS} rows") | |
| return cols, rows | |
| def _summarize(self, question, sql, cols, rows): | |
| cfg = self.ai_cfg.get("summarizer", {}) | |
| max_disp = cfg.get("max_display_rows", 50) | |
| shown = rows[:max_disp] | |
| result_text = f"Columns: {cols}\nRows ({len(rows)} total" | |
| if len(rows) > max_disp: | |
| result_text += f", showing first {max_disp}" | |
| result_text += "):\n" + "\n".join(str(r) for r in shown) | |
| resp = self.client.chat.completions.create( | |
| model=self.model, | |
| temperature=cfg.get("temperature", 0.3), | |
| max_tokens=cfg.get("max_tokens", 2000), | |
| messages=[ | |
| {"role": "system", "content": self._prompt("summarizer_system", db_name=self.db_name)}, | |
| {"role": "user", "content": self._prompt("summarizer_user", | |
| question=question, sql=sql, result_text=result_text)}, | |
| ] | |
| ) | |
| return (resp.choices[0].message.content or "").strip() | |
| # ββ Main Entry ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def ask(self, question): | |
| try: | |
| tables = self._pick_tables(question) | |
| print(f" β Tables: {', '.join(tables)}") | |
| schema_ctx = "\n".join( | |
| f"Table '{t}': {', '.join(self.schema_info[t])}" | |
| for t in tables if t in self.schema_info | |
| ) | |
| sql = self._generate_sql(question, schema_ctx) | |
| responses = self.prompts.get("responses", {}) | |
| if sql == "NOT_A_QUERY": | |
| return responses.get("not_a_query", "I'm DataBot. Ask me about your business data.") | |
| if sql == "SECURITY_BLOCK": | |
| return responses.get("security_block", "Access denied: sensitive data requested.") | |
| print(f" β SQL: {sql}") | |
| ok, reason = self._validate_security(sql) | |
| if not ok: | |
| print(f" β BLOCKED: {reason}") | |
| return responses.get("security_check_fail", "Query blocked: {reason}").format(reason=reason) | |
| ok, reason = self._validate_complexity(sql) | |
| if not ok: | |
| print(f" β BLOCKED: {reason}") | |
| return responses.get("complexity_fail", "Query too complex: {reason}").format(reason=reason) | |
| sql = self._enforce_limit(sql) | |
| print(f" β Final: {sql}") | |
| cols, rows = self._execute(sql) | |
| return self._summarize(question, sql, cols, rows) | |
| except Exception as e: | |
| return f"DataBot Error: {str(e)}" |