"""Portable smoke requests for NL2SQL Copilot. - Ensures a demo SQLite DB exists under /tmp/nl2sql_dbs/smoke_demo.sqlite - Uploads it to the API - Runs a few representative queries - Exits non-zero on failure (so Make/CI can trust it) Env: API_BASE: base URL of API (default: http://127.0.0.1:8000) API_KEY: API key header value (default: dev-key) """ from __future__ import annotations import json import os import time from pathlib import Path import re import requests API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000").rstrip("/") API_KEY = os.getenv("API_KEY", "dev-key") DB_DIR = Path("/tmp/nl2sql_dbs") DB_PATH = DB_DIR / "smoke_demo.sqlite" _DML_DDL_SQL_RE = re.compile( r"\b(delete|update|insert|drop|alter|truncate|create|replace)\b", re.IGNORECASE ) def _is_select_only_sql(sql: str | None) -> bool: if not sql: return False s = sql.strip().lower() if not s.startswith("select"): return False return _DML_DDL_SQL_RE.search(sql) is None def _ensure_demo_db(path: Path) -> None: """Delegate to scripts/smoke_run.py if available; otherwise fail.""" # Your repo already has scripts/smoke_run.py which creates the DB deterministically. from smoke_run import ensure_demo_db # type: ignore ensure_demo_db(path) def _upload_db_and_get_id(path: Path) -> str: url = f"{API_BASE}/api/v1/nl2sql/upload_db" headers = {"X-API-Key": API_KEY} with path.open("rb") as f: resp = requests.post(url, headers=headers, files={"file": f}, timeout=30) if resp.status_code != 200: raise RuntimeError(f"Upload failed: {resp.status_code} {resp.text[:400]}") data = resp.json() db_id = data.get("db_id") if not db_id: raise RuntimeError(f"Invalid upload response: {data}") return str(db_id) def _run_query(db_id: str, query: str) -> dict: url = f"{API_BASE}/api/v1/nl2sql" headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"} payload = {"db_id": db_id, "query": query} t0 = time.time() timeout_s = float(os.getenv("SMOKE_TIMEOUT", "180")) try: resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s) except requests.exceptions.ReadTimeout: # One retry to smooth over transient provider/LLM slowness. time.sleep(2) resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s) dt_ms = int(round((time.time() - t0) * 1000)) out: dict = {} try: out = resp.json() except Exception: out = {"raw": resp.text} return {"status": resp.status_code, "latency_ms": dt_ms, "body": out} def _get_error_code(body: dict) -> str | None: """Extract error.code from the API response shape if present.""" try: err = body.get("error") if isinstance(err, dict): code = err.get("code") return str(code) if code is not None else None except Exception: return None return None def _is_expected_block(status: int, body: dict, allowed_codes: set[str]) -> bool: """Return True if this looks like an intentional safety rejection.""" if status == 200: return False code = _get_error_code(body) return code in allowed_codes def main() -> int: DB_DIR.mkdir(parents=True, exist_ok=True) try: _ensure_demo_db(DB_PATH) except Exception as e: print(f"āŒ Failed to create demo DB: {e}") return 2 try: db_id = _upload_db_and_get_id(DB_PATH) except Exception as e: print(f"āŒ Failed to upload demo DB: {e}") return 3 checks = [ ("List the first 10 artists.", True), ("Which customer spent the most based on total invoice amount?", True), ("SELECT * FROM Invoice;", False), # must be blocked (full scan without LIMIT) ] ok_all = True for q, should_succeed in checks: r = _run_query(db_id=db_id, query=q) status = r["status"] body = r["body"] print(f"\nQuery: {q}") print(f"HTTP {status} | {r['latency_ms']} ms") print(json.dumps(body, indent=2)[:800]) if should_succeed: if status != 200: ok_all = False else: allowed = { "LLM_BAD_OUTPUT", "PIPELINE_CRASH", # e.g. full_scan_without_limit guardrail "SAFETY_NON_SELECT", "SAFETY_MULTI_STATEMENT", } if status != 200: if not _is_expected_block( status=status, body=body, allowed_codes=allowed ): ok_all = False else: # Accept safe refusal: 200 but SQL must be SELECT-only. sql = body.get("sql") if isinstance(body, dict) else None if not _is_select_only_sql(sql): ok_all = False if ok_all: print("\nāœ… demo-smoke passed") return 0 print("\nāŒ demo-smoke failed (see output above)") return 4 if __name__ == "__main__": raise SystemExit(main())