# Databot / databot.py
# (Hugging Face page residue preserved for provenance: uploaded by
#  julkarnaeen, "Update databot.py", commit 27514ca verified)
import os
import re
import time
import json
from dotenv import load_dotenv # pyre-ignore[21]
from sqlalchemy import create_engine, text # pyre-ignore[21]
from openai import OpenAI as OpenAIClient
load_dotenv()
# Config file paths
BASE_DIR = os.path.dirname(__file__)
def _load_json(path, name):
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f" βœ— {name}: {e}")
return {}
class DataBot:
    """Natural-language question answering over a MySQL database.

    Pipeline (see ask()): an LLM picks candidate tables, an LLM generates
    SQL, the SQL is validated against security and complexity rules, a row
    LIMIT is enforced, the query is executed, and an LLM summarizes the
    result for the user.
    """

    def __init__(self):
        """Load JSON configs, build the OpenAI client and DB engine, cache the schema.

        Connects to the database immediately; exits the process (via
        _load_schema) if the schema cannot be loaded after retries.
        """
        print("Loading configurations...")
        self.db_cfg = _load_json(os.path.join(BASE_DIR, "db_config.json"), "db_config")
        self.ai_cfg = _load_json(os.path.join(BASE_DIR, "ai_config.json"), "ai_config")
        self.prompts = _load_json(os.path.join(BASE_DIR, "prompts_config.json"), "prompts_config")
        self.access_cfg = _load_json(os.path.join(BASE_DIR, "data_access_config.json"), "data_access_config")

        # Query limits
        ql = self.db_cfg.get("query_limits", {})
        self.MAX_ROWS = ql.get("max_rows", 100)
        self.MAX_QUERY_TIME = ql.get("max_query_time_seconds", 30)
        self.MAX_JOIN_TABLES = ql.get("max_join_tables", 3)

        # Pre-cache restricted columns as a lowercase set (used on every column check)
        self._restricted_cols = {c.lower() for c in self.access_cfg.get("restricted_columns", [])}

        # AI model
        self.model = self.ai_cfg.get("model", os.getenv("LLM_MODEL", "gpt-4o"))
        self.client = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))

        # Database engine: env vars win over config-file values for host/name.
        conn_cfg = self.db_cfg.get("connection", {})
        timeouts = self.db_cfg.get("timeouts", {})
        pool = self.db_cfg.get("pool", {})
        self.db_user = os.getenv("DB_USER")
        self.db_pass = os.getenv("DB_PASSWORD")
        self.db_host = os.getenv("DB_HOST", conn_cfg.get("host", "51.89.104.26"))
        self.db_name = os.getenv("DB_NAME", conn_cfg.get("database", "dev_poly"))
        self.port = conn_cfg.get("port", "3306")
        self.engine = create_engine(
            f"mysql+pymysql://{self.db_user}:{self.db_pass}@{self.db_host}:{self.port}/{self.db_name}?charset={conn_cfg.get('charset', 'utf8')}",
            connect_args={
                "connect_timeout": timeouts.get("connect_timeout", 30),
                "read_timeout": timeouts.get("read_timeout", 60),
                "write_timeout": timeouts.get("write_timeout", 60),
            },
            pool_pre_ping=pool.get("pool_pre_ping", True),
            pool_recycle=pool.get("pool_recycle", 300),
        )

        # Load and filter schema
        print("Loading database schema...")
        schema_cfg = self.db_cfg.get("schema_loading", {})
        raw = self._load_schema(schema_cfg.get("max_retries", 3), schema_cfg.get("retry_delay_seconds", 5))
        self.schema_info = self._filter_schema(raw)
        print(f"Loaded {len(self.schema_info)} accessible tables (from {len(raw)} total).")

    # -- Schema -------------------------------------------------------

    def _load_schema(self, retries=3, delay=5):
        """Read table/column metadata from INFORMATION_SCHEMA, with retries.

        Returns:
            dict mapping table name -> list of "column (type)" strings.

        Raises:
            SystemExit: when every attempt fails (startup cannot proceed
            without a schema).
        """
        for attempt in range(1, retries + 1):
            try:
                schema = {}
                with self.engine.connect() as conn:
                    rows = conn.execute(text(
                        "SELECT TABLE_NAME, COLUMN_NAME, COLUMN_TYPE "
                        "FROM INFORMATION_SCHEMA.COLUMNS "
                        "WHERE TABLE_SCHEMA = :db ORDER BY TABLE_NAME, ORDINAL_POSITION"
                    ), {"db": self.db_name})
                    for r in rows:
                        schema.setdefault(r[0], []).append(f"{r[1]} ({r[2]})")
                return schema
            except Exception as e:
                if attempt < retries:
                    print(f" βœ— Attempt {attempt}/{retries} failed, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    print(f"ERROR: Cannot connect to {self.db_host}:{self.port}/{self.db_name}")
                    raise SystemExit(1) from e
        return {}

    def _filter_schema(self, raw):
        """Drop restricted tables and columns from the raw schema.

        Tables with no surviving columns are removed entirely; a count of
        blocked tables is logged. With no access config, returns raw as-is.
        """
        if not self.access_cfg:
            return raw
        filtered = {}
        blocked = 0
        for table, cols in raw.items():
            if not self._table_allowed(table):
                blocked += 1
                continue
            # Column entries look like "name (type)"; check only the name part.
            safe = [c for c in cols if self._column_allowed(c.split(" (")[0].strip())]
            if safe:
                filtered[table] = safe
        if blocked:
            print(f" β†’ Blocked {blocked} restricted tables.")
        return filtered

    def _table_allowed(self, name):
        """Allow-list check: restricted prefixes deny, allowed prefixes permit,
        anything matching neither list is denied (default-deny)."""
        if not self.access_cfg:
            return True
        t = name.lower()
        for p in self.access_cfg.get("restricted_table_prefixes", []):
            if t.startswith(p.lower()):
                return False
        for p in self.access_cfg.get("allowed_table_prefixes", []):
            if t.startswith(p.lower()):
                return True
        return False

    def _column_allowed(self, name):
        """Return False for columns in the pre-cached restricted set."""
        if not self.access_cfg:
            return True
        return name.lower() not in self._restricted_cols

    # -- Security & Limits --------------------------------------------

    def _validate_security(self, sql):
        """Reject write operations and references to restricted tables/columns.

        Returns:
            (ok, reason): ok is True when the SQL passes; reason explains
            the first rule that failed. NOTE(review): regex word-boundary
            matching can false-positive on column/table names that merely
            contain a restricted word — acceptable for a deny-by-default bot.
        """
        if not self.access_cfg:
            return True, ""
        sql_up = sql.upper()
        for op in ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "CREATE"):
            if re.search(rf'\b{op}\b', sql_up):
                return False, f"Write operation '{op}' is not allowed."
        sql_lo = sql.lower()
        for prefix in self.access_cfg.get("restricted_table_prefixes", []):
            if re.search(rf'\b{re.escape(prefix.lower())}\w*\b', sql_lo):
                return False, f"Restricted data ('{prefix}*' tables). Access denied."
        for col in self.access_cfg.get("restricted_columns", []):
            if re.search(rf'\b{re.escape(col.lower())}\b', sql_lo):
                return False, f"Restricted column '{col}'. Access denied."
        return True, ""

    def _validate_complexity(self, sql):
        """Reject queries likely to be expensive or unbounded.

        Blocks CROSS JOIN, more than MAX_JOIN_TABLES joins, bare SELECT *
        (COUNT(*) is exempt), and queries with no WHERE / aggregate / GROUP BY.
        """
        sql_up = sql.upper()
        if "CROSS JOIN" in sql_up:
            return False, "CROSS JOIN is not allowed."
        if len(re.findall(r'\bJOIN\b', sql_up)) > self.MAX_JOIN_TABLES:
            return False, f"Too many JOINs (max {self.MAX_JOIN_TABLES}). Simplify your question."
        if re.search(r'SELECT\s+\*', sql_up) and not re.search(r'SELECT\s+COUNT\s*\(\s*\*\s*\)', sql_up):
            return False, "SELECT * is not allowed. Specific columns must be selected."
        has_where = bool(re.search(r'\bWHERE\b', sql_up))
        has_agg = bool(re.search(r'SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up))
        has_group = bool(re.search(r'\bGROUP\s+BY\b', sql_up))
        if not has_where and not has_agg and not has_group:
            return False, "No WHERE clause or aggregation. Add filters to your question."
        return True, ""

    def _enforce_limit(self, sql):
        """Append or cap LIMIT at MAX_ROWS; pure aggregates are left alone."""
        sql_up = sql.upper().strip()
        # Skip pure aggregates without GROUP BY
        if re.search(r'^SELECT\s+(COUNT|SUM|AVG|MIN|MAX)\s*\(', sql_up) and not re.search(r'\bGROUP\s+BY\b', sql_up):
            return sql
        m = re.search(r'\bLIMIT\s+(\d+)', sql_up)
        if m:
            if int(m.group(1)) > self.MAX_ROWS:
                sql = re.sub(r'\bLIMIT\s+\d+', f'LIMIT {self.MAX_ROWS}', sql, flags=re.IGNORECASE)
            return sql
        return f"{sql.rstrip()} LIMIT {self.MAX_ROWS}"

    # -- Prompt Helper ------------------------------------------------

    def _prompt(self, key, **kw):
        """Fetch a prompt template by key and format it with kw.

        Returns "" for a missing template; returns the raw template when a
        placeholder is missing (logged, not raised) so one bad config entry
        doesn't kill the pipeline.
        """
        t = self.prompts.get(key, "")
        if not t:
            print(f" βœ— WARNING: prompt '{key}' not found in prompts_config.json")
            return ""
        try:
            return t.format(**kw)
        except KeyError as e:
            print(f" βœ— WARNING: missing placeholder {e} in prompt '{key}'")
            return t

    # -- LLM Pipeline -------------------------------------------------

    def _pick_tables(self, question):
        """Ask the LLM which tables are relevant to the question.

        Only names that exist in the filtered schema survive; when nothing
        valid comes back, fall back to the first max_tables schema tables.
        """
        cfg = self.ai_cfg.get("table_picker", {})
        max_t = cfg.get("max_tables", 5)
        names = list(self.schema_info.keys())
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0),
            max_tokens=cfg.get("max_tokens", 200),
            messages=[
                {"role": "system", "content": self._prompt("table_picker_system")},
                {"role": "user", "content": self._prompt("table_picker_user",
                    db_name=self.db_name, table_list=", ".join(names),
                    question=question, max_tables=max_t)},
            ]
        )
        picked = [t.strip().strip("'\"` ") for t in (resp.choices[0].message.content or "").split(",")]
        valid = [t for t in picked if t in self.schema_info]
        return valid or names[:max_t]

    def _generate_sql(self, question, schema_ctx):
        """Generate SQL for the question, or a sentinel string.

        Returns "SECURITY_BLOCK" / "NOT_A_QUERY" when the model flags the
        request; otherwise the first statement with code fences stripped.
        """
        cfg = self.ai_cfg.get("sql_generator", {})
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0),
            max_tokens=cfg.get("max_tokens", 500),
            messages=[
                {"role": "system", "content": self._prompt("sql_generator_system",
                    db_name=self.db_name, max_rows=self.MAX_ROWS, max_join_tables=self.MAX_JOIN_TABLES)},
                {"role": "user", "content": self._prompt("sql_generator_user",
                    schema_context=schema_ctx, question=question)},
            ]
        )
        sql = (resp.choices[0].message.content or "").strip()
        if "SECURITY_BLOCK" in sql.upper():
            return "SECURITY_BLOCK"
        if "NOT_A_QUERY" in sql.upper():
            return "NOT_A_QUERY"
        sql = sql.replace("```sql", "").replace("```", "").strip()
        # Keep only the first statement; anything after ';' is discarded.
        if ";" in sql:
            sql = sql.split(";")[0].strip()
        return sql

    def _execute(self, sql):
        """Run the SQL with a best-effort server-side timeout and a row cap.

        Returns:
            (cols, rows): column names and up to MAX_ROWS dict rows.
        """
        with self.engine.connect() as conn:
            # Try setting query timeout (MariaDB vs MySQL have different syntax)
            try:
                conn.execute(text(f"SET SESSION max_statement_time = {self.MAX_QUERY_TIME}"))
            except Exception:
                try:
                    conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {self.MAX_QUERY_TIME * 1000}"))
                except Exception:
                    pass  # Neither supported β€” LIMIT and row cap still protect us
            result = conn.execute(text(sql))
            cols = list(result.keys())
            # Fetch one extra row just to detect (and log) truncation.
            batch = result.fetchmany(self.MAX_ROWS + 1)
            rows = [dict(zip(cols, r)) for r in batch[:self.MAX_ROWS]]
            if len(batch) > self.MAX_ROWS:
                print(f" β†’ Capped at {self.MAX_ROWS} rows")
            return cols, rows

    def _summarize(self, question, sql, cols, rows):
        """Ask the LLM for a natural-language summary of the query result."""
        cfg = self.ai_cfg.get("summarizer", {})
        max_disp = cfg.get("max_display_rows", 50)
        shown = rows[:max_disp]
        result_text = f"Columns: {cols}\nRows ({len(rows)} total"
        if len(rows) > max_disp:
            result_text += f", showing first {max_disp}"
        result_text += "):\n" + "\n".join(str(r) for r in shown)
        resp = self.client.chat.completions.create(
            model=self.model,
            temperature=cfg.get("temperature", 0.3),
            max_tokens=cfg.get("max_tokens", 2000),
            messages=[
                {"role": "system", "content": self._prompt("summarizer_system", db_name=self.db_name)},
                {"role": "user", "content": self._prompt("summarizer_user",
                    question=question, sql=sql, result_text=result_text)},
            ]
        )
        return (resp.choices[0].message.content or "").strip()

    # -- Main Entry ---------------------------------------------------

    def ask(self, question):
        """Answer a natural-language question; never raises (errors become text)."""
        try:
            tables = self._pick_tables(question)
            print(f" β†’ Tables: {', '.join(tables)}")
            schema_ctx = "\n".join(
                f"Table '{t}': {', '.join(self.schema_info[t])}"
                for t in tables if t in self.schema_info
            )
            sql = self._generate_sql(question, schema_ctx)
            responses = self.prompts.get("responses", {})
            if sql == "NOT_A_QUERY":
                return responses.get("not_a_query", "I'm DataBot. Ask me about your business data.")
            if sql == "SECURITY_BLOCK":
                return responses.get("security_block", "Access denied: sensitive data requested.")
            print(f" β†’ SQL: {sql}")
            ok, reason = self._validate_security(sql)
            if not ok:
                print(f" β†’ BLOCKED: {reason}")
                return responses.get("security_check_fail", "Query blocked: {reason}").format(reason=reason)
            ok, reason = self._validate_complexity(sql)
            if not ok:
                print(f" β†’ BLOCKED: {reason}")
                return responses.get("complexity_fail", "Query too complex: {reason}").format(reason=reason)
            sql = self._enforce_limit(sql)
            print(f" β†’ Final: {sql}")
            cols, rows = self._execute(sql)
            return self._summarize(question, sql, cols, rows)
        except Exception as e:
            return f"DataBot Error: {str(e)}"