Spaces:
Running
Running
File size: 5,117 Bytes
ddd54ed f0b4004 4596e5b f0b4004 4596e5b f0b4004 ddd54ed f0b4004 ddd54ed f0b4004 ddd54ed f0b4004 4596e5b f0b4004 ddd54ed 4596e5b ddd54ed 4596e5b f0b4004 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""Portable smoke requests for NL2SQL Copilot.
- Ensures a demo SQLite DB exists under /tmp/nl2sql_dbs/smoke_demo.sqlite
- Uploads it to the API
- Runs a few representative queries, including one the API's safety guardrails must reject
- Exits non-zero on failure (so Make/CI can trust it)
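Env:
    API_BASE: base URL of API (default: http://127.0.0.1:8000)
    API_KEY: API key header value (default: dev-key)
    SMOKE_TIMEOUT: per-request timeout in seconds for each NL2SQL query;
        a read timeout is retried once (default: 180)
"""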
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import re
import requests
# Connection settings, overridable via the environment. The trailing "/" is
# stripped so endpoint paths can be appended with a single "/".
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8000").rstrip("/")
API_KEY = os.environ.get("API_KEY", "dev-key")

# Fixed scratch location for the generated demo database.
DB_DIR = Path("/tmp", "nl2sql_dbs")
DB_PATH = DB_DIR.joinpath("smoke_demo.sqlite")
# Keywords that mark a statement as something other than a plain read.
# ``replace`` is only flagged when NOT followed by "(" so the SQLite
# ``replace(X, Y, Z)`` string function passes, while ``REPLACE INTO ...`` and
# ``INSERT OR REPLACE`` are still caught. ``attach``/``detach``/``pragma``/
# ``vacuum`` cover SQLite statements that can touch files or schema state.
_DML_DDL_SQL_RE = re.compile(
    r"\b(delete|update|insert|drop|alter|truncate|create|attach|detach"
    r"|pragma|vacuum|replace(?!\s*\())\b",
    re.IGNORECASE,
)
# A read-only statement must open with SELECT or a common table expression (WITH).
_READ_ONLY_PREFIX_RE = re.compile(r"(select|with)\b", re.IGNORECASE)


def _is_select_only_sql(sql: str | None) -> bool:
    """Return True when *sql* looks like a read-only SELECT statement.

    The statement must begin with ``SELECT`` or ``WITH`` (a CTE). It must also
    contain none of the write/DDL keywords matched by ``_DML_DDL_SQL_RE``. The
    scan is purely lexical, so a keyword appearing as a whole word inside a
    string literal or identifier also causes rejection. This errs on the side
    of "not read-only".

    Args:
        sql: SQL text. None, non-string values, and blank strings are
            rejected rather than raising.

    Returns:
        True if the statement is considered SELECT-only, else False.
    """
    # Response bodies are arbitrary JSON, so "sql" may be a number or an object.
    if not isinstance(sql, str):
        return False
    stripped = sql.strip()
    if not _READ_ONLY_PREFIX_RE.match(stripped):
        return False
    return _DML_DDL_SQL_RE.search(stripped) is None
def _ensure_demo_db(path: Path) -> None:
    """Create the demo SQLite DB at *path* via ``smoke_run.ensure_demo_db``.

    The database is built by ``scripts/smoke_run.py``, imported as the
    top-level module ``smoke_run``. That presumably resolves because this
    script sits next to it in ``scripts/`` (confirm). There is no fallback
    builder here.

    Args:
        path: Destination file for the demo database.

    Raises:
        RuntimeError: ``smoke_run.ensure_demo_db`` cannot be imported. The
            original ImportError is chained as ``__cause__``.
    """
    # Imported lazily so a missing helper is reported by main() as a clean
    # "Failed to create demo DB" error instead of breaking module import.
    try:
        from smoke_run import ensure_demo_db  # type: ignore
    except ImportError as err:
        raise RuntimeError(
            "cannot import ensure_demo_db from smoke_run "
            f"(expected scripts/smoke_run.py on sys.path): {err}"
        ) from err
    ensure_demo_db(path)
def _upload_db_and_get_id(path: Path) -> str:
    """Upload the SQLite file at *path* and return the server-assigned ``db_id``.

    The file is POSTed as multipart field ``file`` to
    ``{API_BASE}/api/v1/nl2sql/upload_db`` with the ``X-API-Key`` header and a
    30 s timeout.

    Args:
        path: SQLite database file to upload.

    Returns:
        The ``db_id`` from the JSON response, converted to ``str``.

    Raises:
        RuntimeError: The status is not 200, the body is not a JSON object, or
            ``db_id`` is missing or empty.
        requests.exceptions.RequestException: Transport-level failure.
        OSError: *path* cannot be opened.
    """
    url = f"{API_BASE}/api/v1/nl2sql/upload_db"
    headers = {"X-API-Key": API_KEY}
    with path.open("rb") as f:
        resp = requests.post(url, headers=headers, files={"file": f}, timeout=30)
    if resp.status_code != 200:
        raise RuntimeError(f"Upload failed: {resp.status_code} {resp.text[:400]}")
    try:
        data = resp.json()
    except ValueError as err:
        # requests' JSON decode errors subclass ValueError.
        raise RuntimeError(
            f"Upload returned non-JSON body: {resp.text[:400]}"
        ) from err
    # Guard against a JSON list/string body, which has no .get().
    db_id = data.get("db_id") if isinstance(data, dict) else None
    if not db_id:
        raise RuntimeError(f"Invalid upload response: {data}")
    return str(db_id)
def _run_query(db_id: str, query: str) -> dict:
    """POST one natural-language *query* against *db_id* and capture the result.

    Sends ``{"db_id": ..., "query": ...}`` to ``{API_BASE}/api/v1/nl2sql``.
    The per-request timeout comes from the ``SMOKE_TIMEOUT`` env var in
    seconds (default 180). A ``ReadTimeout`` is retried exactly once after a
    2 s pause.

    Args:
        db_id: Identifier returned by the upload endpoint.
        query: Natural-language (or raw SQL) question to send.

    Returns:
        ``{"status": HTTP status code, "latency_ms": int, "body": parsed JSON,
        or {"raw": text} when the body is not JSON}``.

    Raises:
        requests.exceptions.RequestException: A second ``ReadTimeout`` or any
            other transport error.
        ValueError: ``SMOKE_TIMEOUT`` is not a number.
    """
    url = f"{API_BASE}/api/v1/nl2sql"
    headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
    payload = {"db_id": db_id, "query": query}
    timeout_s = float(os.getenv("SMOKE_TIMEOUT", "180"))
    # perf_counter is monotonic, so the latency is immune to wall-clock jumps.
    t0 = time.perf_counter()
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s)
    except requests.exceptions.ReadTimeout:
        # One retry to smooth over transient provider/LLM slowness.
        time.sleep(2)
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s)
    # Total wall time, including a timed-out first attempt and the back-off.
    dt_ms = int(round((time.perf_counter() - t0) * 1000))
    try:
        out = resp.json()
    except ValueError:
        # Non-JSON body (e.g. a proxy error page): keep the text for the log.
        out = {"raw": resp.text}
    return {"status": resp.status_code, "latency_ms": dt_ms, "body": out}
def _get_error_code(body: dict) -> str | None:
    """Pull ``error.code`` out of an API response body.

    Returns the code converted to ``str``. Returns None when ``error`` is not
    a dict, when it has no non-None ``code``, or when ``body.get`` fails
    (e.g. the body is a list, string or None).
    """
    try:
        error = body.get("error")
        if not isinstance(error, dict) or error.get("code") is None:
            return None
        return str(error["code"])
    except Exception:
        # Bodies without a usable .get() simply carry no error code.
        return None
def _is_expected_block(status: int, body: dict, allowed_codes: set[str]) -> bool:
    """Tell whether a response is a deliberate safety rejection.

    A 200 response is never a rejection. Any other status counts only when the
    body's ``error.code`` (looked up via ``_get_error_code``) is one of
    *allowed_codes*.
    """
    # Short-circuit: the body is not inspected at all for a 200.
    return status != 200 and _get_error_code(body) in allowed_codes
def _check_passed(
    status: int, body: object, should_succeed: bool, allowed_codes: set[str]
) -> bool:
    """Decide whether one smoke query's response matches its expectation."""
    if should_succeed:
        return status == 200
    if status != 200:
        # A non-200 only counts as a pass if it is a recognised safety rejection.
        return _is_expected_block(
            status=status, body=body, allowed_codes=allowed_codes
        )
    # Accept safe refusal: 200 but SQL must be SELECT-only.
    sql = body.get("sql") if isinstance(body, dict) else None
    return _is_select_only_sql(sql)


def main() -> int:
    """Run the demo smoke test end to end and return a process exit code.

    The steps are:
    1. Create ``DB_DIR`` and build the demo DB at ``DB_PATH``.
    2. Upload the DB.
    3. Run each check query.

    Queries flagged "should succeed" must return HTTP 200. The others must be
    either rejected with one of the allowed error codes, or answered with 200
    and SELECT-only SQL.

    Returns:
        0 if every check passed.
        2 if the demo DB could not be created.
        3 if the upload failed.
        4 if at least one check failed, including transport errors on a query.
    """
    DB_DIR.mkdir(parents=True, exist_ok=True)
    try:
        _ensure_demo_db(DB_PATH)
    except Exception as e:
        print(f"❌ Failed to create demo DB: {e}")
        return 2
    try:
        db_id = _upload_db_and_get_id(DB_PATH)
    except Exception as e:
        print(f"❌ Failed to upload demo DB: {e}")
        return 3

    # (query, should_succeed)
    checks = [
        ("List the first 10 artists.", True),
        ("Which customer spent the most based on total invoice amount?", True),
        ("SELECT * FROM Invoice;", False),  # must be blocked (full scan without LIMIT)
    ]
    # Error codes accepted as an intentional rejection for should-fail checks.
    allowed = {
        "LLM_BAD_OUTPUT",
        "PIPELINE_CRASH",  # e.g. full_scan_without_limit guardrail
        "SAFETY_NON_SELECT",
        "SAFETY_MULTI_STATEMENT",
    }
    failed: list[str] = []
    for q, should_succeed in checks:
        print(f"\nQuery: {q}")
        try:
            r = _run_query(db_id=db_id, query=q)
        except requests.exceptions.RequestException as e:
            # A transport failure fails this check but must not abort the rest
            # or bypass the documented exit codes.
            print(f"❌ Request failed: {e}")
            failed.append(q)
            continue
        status = r["status"]
        body = r["body"]
        print(f"HTTP {status} | {r['latency_ms']} ms")
        print(json.dumps(body, indent=2)[:800])
        if not _check_passed(status, body, should_succeed, allowed):
            failed.append(q)

    if not failed:
        print("\n✅ demo-smoke passed")
        return 0
    print("\n❌ demo-smoke failed (see output above)")
    for q in failed:
        print(f"   - {q}")
    return 4
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status for Make/CI.
    exit_code = main()
    raise SystemExit(exit_code)
|