github-actions[bot] commited on
Commit
4596e5b
·
1 Parent(s): ddd54ed

Sync from GitHub main @ 8f40ad2807fc87dbdaae076316a949ce2aa8d865

Browse files
README.md CHANGED
@@ -10,7 +10,7 @@ pinned: false
10
  # NL2SQL Copilot — Safety-First, Production-Grade Text-to-SQL
11
 
12
  [![CI](https://github.com/melika-kheirieh/nl2sql-copilot/actions/workflows/ci.yml/badge.svg)](https://github.com/melika-kheirieh/nl2sql-copilot/actions/workflows/ci.yml)
13
- [![Docker](https://img.shields.io/badge/docker-ready-blue?logo=docker)](#)
14
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
15
 
16
  A **production-oriented, multi-stage Natural Language → SQL system** built around
@@ -103,7 +103,7 @@ make demo-metrics
103
  ```
104
 
105
  >For a complete end-to-end setup (API, infra, metrics, dashboards, UIs),
106
- see [docs//runbook.md](docs//runbook.md).
107
 
108
  ---
109
 
 
10
  # NL2SQL Copilot — Safety-First, Production-Grade Text-to-SQL
11
 
12
  [![CI](https://github.com/melika-kheirieh/nl2sql-copilot/actions/workflows/ci.yml/badge.svg)](https://github.com/melika-kheirieh/nl2sql-copilot/actions/workflows/ci.yml)
13
+ [![Docker](https://img.shields.io/badge/docker--compose-demo-blue?logo=docker)](docs/runbook.md)
14
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
15
 
16
  A **production-oriented, multi-stage Natural Language → SQL system** built around
 
103
  ```
104
 
105
  >For a complete end-to-end setup (API, infra, metrics, dashboards, UIs),
106
+ see [docs/runbook.md](docs/runbook.md).
107
 
108
  ---
109
 
app/routers/nl2sql.py CHANGED
@@ -213,7 +213,7 @@ def health():
213
  return {"status": "ok", "version": settings.app_version}
214
 
215
 
216
- def _ck(db_id: Optional[str], query: str, schema_preview: str) -> str:
217
  db_part = db_id or "__default__"
218
  seed = f"{db_part}\n{query.strip()}"
219
  return hashlib.sha1(seed.encode("utf-8")).hexdigest()
@@ -232,20 +232,6 @@ def nl2sql_handler(
232
  ) -> NL2SQLResponse | ClarifyResponse:
233
  db_id = getattr(request, "db_id", None)
234
 
235
- # # ---- deterministic SELECT-only guard ----
236
- # # Block DML/DDL intents early (before schema derivation, cache, or LLM calls).
237
- # if _is_unsafe_intent(getattr(request, "query", "")):
238
- # raise HTTPException(
239
- # status_code=400,
240
- # detail={
241
- # "error": {
242
- # "code": "BAD_REQUEST",
243
- # "retryable": False,
244
- # "details": ["non_select_query"],
245
- # }
246
- # },
247
- # )
248
-
249
  # ---- schema preview ----
250
  try:
251
  final_preview = svc.get_schema_preview(
@@ -266,7 +252,7 @@ def nl2sql_handler(
266
  ) from exc
267
 
268
  # ---- cache lookup ----
269
- cache_key = _ck(db_id, request.query, final_preview)
270
  cached_payload = cache.get(cache_key)
271
  if cached_payload is not None:
272
  # Cache stores dicts; convert back to response models for type safety.
 
213
  return {"status": "ok", "version": settings.app_version}
214
 
215
 
216
+ def _ck(db_id: Optional[str], query: str) -> str:
217
  db_part = db_id or "__default__"
218
  seed = f"{db_part}\n{query.strip()}"
219
  return hashlib.sha1(seed.encode("utf-8")).hexdigest()
 
232
  ) -> NL2SQLResponse | ClarifyResponse:
233
  db_id = getattr(request, "db_id", None)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  # ---- schema preview ----
236
  try:
237
  final_preview = svc.get_schema_preview(
 
252
  ) from exc
253
 
254
  # ---- cache lookup ----
255
+ cache_key = _ck(db_id, request.query)
256
  cached_payload = cache.get(cache_key)
257
  if cached_payload is not None:
258
  # Cache stores dicts; convert back to response models for type safety.
scripts/demo_cache_showcase.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate a reproducible cache/metrics screenshot workload.
3
+
4
+ What it does:
5
+ 1) Waits for API readiness (healthz + readyz + router health).
6
+ 2) Uploads a demo SQLite DB to the API (upload_db) and captures db_id.
7
+ 3) Sends a burst of unique queries (mostly misses).
8
+ 4) Sends repeated queries over ~70–90s (hits), with jitter so charts look natural.
9
+ 5) Triggers a safety violation once (should be blocked) WITHOUT failing the whole demo.
10
+ 6) Sends a final "recovery" query (OK).
11
+ 7) (Optional) Prints a Prometheus instant-query sanity check for cache metrics.
12
+
13
+ Expected API:
14
+ - POST {API_BASE}/api/v1/nl2sql/upload_db (multipart form: file=@db.sqlite) -> {db_id: "..."}
15
+ - POST {API_BASE}/api/v1/nl2sql (json: {db_id, query, schema_preview?}) -> 200 or 4xx/5xx
16
+ - GET {API_BASE}/healthz
17
+ - GET {API_BASE}/readyz
18
+ - GET {API_BASE}/api/v1/nl2sql/health
19
+
20
+ Env:
21
+ - API_BASE (default http://127.0.0.1:8000)
22
+ - API_KEY (default dev-key)
23
+ - DB_PATH (default /tmp/nl2sql_dbs/smoke_demo.sqlite)
24
+ - PROM_BASE (default http://127.0.0.1:9090) (optional; set empty to skip)
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import json
30
+ import os
31
+ import random
32
+ import subprocess
33
+ import time
34
+ from dataclasses import dataclass
35
+ from typing import Any
36
+
37
+
38
+ def sh(args: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]:
39
+ """Run a command and return the completed process (text mode)."""
40
+ return subprocess.run(
41
+ args,
42
+ check=check,
43
+ text=True,
44
+ stdout=subprocess.PIPE,
45
+ stderr=subprocess.PIPE,
46
+ )
47
+
48
+
49
+ @dataclass(frozen=True)
50
+ class Cfg:
51
+ api_base: str
52
+ api_key: str
53
+ db_path: str
54
+ prom_base: str | None
55
+
56
+
57
+ def load_cfg() -> Cfg:
58
+ api_base = os.getenv("API_BASE", "http://127.0.0.1:8000").rstrip("/")
59
+ api_key = os.getenv("API_KEY", "dev-key")
60
+ db_path = os.getenv("DB_PATH", "/tmp/nl2sql_dbs/smoke_demo.sqlite")
61
+ prom_base_env = os.getenv("PROM_BASE", "http://127.0.0.1:9090").rstrip("/")
62
+ prom_base: str | None = prom_base_env if prom_base_env else None
63
+ return Cfg(api_base=api_base, api_key=api_key, db_path=db_path, prom_base=prom_base)
64
+
65
+
66
+ def wait_for_ready(cfg: Cfg, timeout_s: float = 60.0) -> None:
67
+ """Wait until API is responsive and ready.
68
+
69
+ We try multiple endpoints because on cold starts the container may accept TCP but reset early requests.
70
+ """
71
+ endpoints = [
72
+ f"{cfg.api_base}/healthz",
73
+ f"{cfg.api_base}/readyz",
74
+ f"{cfg.api_base}/api/v1/nl2sql/health",
75
+ ]
76
+
77
+ start = time.time()
78
+ last = ""
79
+ while time.time() - start < timeout_s:
80
+ ok = True
81
+ for url in endpoints:
82
+ cp = subprocess.run(
83
+ ["curl", "-sS", "-o", "/dev/null", "-w", "%{http_code}", url],
84
+ check=False,
85
+ text=True,
86
+ capture_output=True,
87
+ )
88
+ code = (cp.stdout or "").strip()
89
+ if code != "200":
90
+ ok = False
91
+ last = f"url={url} http={code} stderr={cp.stderr.strip()!r}"
92
+ break
93
+
94
+ if ok:
95
+ return
96
+
97
+ time.sleep(0.6)
98
+
99
+ raise RuntimeError(f"API not ready after {timeout_s:.0f}s. Last={last}")
100
+
101
+
102
+ def upload_db(cfg: Cfg) -> str:
103
+ if not os.path.exists(cfg.db_path):
104
+ raise FileNotFoundError(f"DB_PATH not found: {cfg.db_path}")
105
+
106
+ url = f"{cfg.api_base}/api/v1/nl2sql/upload_db"
107
+
108
+ # Do NOT use -f here; on error we want the body.
109
+ cp = subprocess.run(
110
+ [
111
+ "curl",
112
+ "-sS",
113
+ "-D",
114
+ "-",
115
+ "-H",
116
+ f"X-API-Key: {cfg.api_key}",
117
+ "-F",
118
+ f"file=@{cfg.db_path}",
119
+ url,
120
+ ],
121
+ check=False,
122
+ text=True,
123
+ capture_output=True,
124
+ )
125
+
126
+ if cp.returncode != 0:
127
+ raise RuntimeError(
128
+ f"upload_db curl failed (rc={cp.returncode}). stderr={cp.stderr.strip()!r}\nstdout:\n{cp.stdout}"
129
+ )
130
+
131
+ # Split headers/body
132
+ raw = cp.stdout
133
+ parts = raw.split("\r\n\r\n", 1)
134
+ if len(parts) != 2:
135
+ parts = raw.split("\n\n", 1)
136
+ if len(parts) != 2:
137
+ raise RuntimeError(f"upload_db returned unexpected response:\n{raw}")
138
+
139
+ headers, body = parts[0], parts[1]
140
+ status_line = headers.splitlines()[0] if headers.splitlines() else ""
141
+ if " 200 " not in status_line:
142
+ raise RuntimeError(f"upload_db non-200.\n{headers}\n\n{body}")
143
+
144
+ try:
145
+ data = json.loads(body)
146
+ except json.JSONDecodeError as e:
147
+ raise RuntimeError(
148
+ f"upload_db returned non-JSON body.\n{headers}\n\n{body}"
149
+ ) from e
150
+
151
+ db_id = data.get("db_id")
152
+ if not isinstance(db_id, str) or not db_id:
153
+ raise RuntimeError(f"upload_db response missing db_id: {data}")
154
+ return db_id
155
+
156
+
157
+ def post_query(
158
+ cfg: Cfg, *, db_id: str, query: str, fail_on_non_200: bool = True
159
+ ) -> int:
160
+ """POST a query. Returns HTTP status code. Optionally raises on non-200 with full response."""
161
+ url = f"{cfg.api_base}/api/v1/nl2sql"
162
+ payload = json.dumps({"db_id": db_id, "query": query})
163
+
164
+ cp = subprocess.run(
165
+ [
166
+ "curl",
167
+ "-sS",
168
+ "-D",
169
+ "-",
170
+ "-H",
171
+ f"X-API-Key: {cfg.api_key}",
172
+ "-H",
173
+ "Content-Type: application/json",
174
+ "-d",
175
+ payload,
176
+ url,
177
+ ],
178
+ check=False,
179
+ text=True,
180
+ capture_output=True,
181
+ )
182
+
183
+ if cp.returncode != 0:
184
+ raise RuntimeError(
185
+ f"query curl failed (rc={cp.returncode}). query={query!r}\n"
186
+ f"stderr={cp.stderr.strip()!r}\nstdout:\n{cp.stdout}"
187
+ )
188
+
189
+ raw = cp.stdout
190
+ parts = raw.split("\r\n\r\n", 1)
191
+ if len(parts) != 2:
192
+ parts = raw.split("\n\n", 1)
193
+ if len(parts) != 2:
194
+ raise RuntimeError(
195
+ f"query returned unexpected response. query={query!r}\n{raw}"
196
+ )
197
+
198
+ headers, body = parts[0], parts[1]
199
+ status_line = headers.splitlines()[0] if headers.splitlines() else ""
200
+
201
+ # Parse HTTP status code from first line: HTTP/1.1 200 OK
202
+ status_code = 0
203
+ try:
204
+ status_code = int(status_line.split()[1])
205
+ except Exception:
206
+ status_code = 0
207
+
208
+ if fail_on_non_200 and status_code != 200:
209
+ raise RuntimeError(f"Non-200 response for query={query!r}\n{headers}\n\n{body}")
210
+
211
+ return status_code
212
+
213
+
214
+ def prom_instant_query(cfg: Cfg, expr: str) -> Any | None:
215
+ if not cfg.prom_base:
216
+ return None
217
+ url = f"{cfg.prom_base}/api/v1/query"
218
+ cp = sh(["curl", "-fsS", url, "--data-urlencode", f"query={expr}"])
219
+ return json.loads(cp.stdout)
220
+
221
+
222
+ def post_dev_safety(cfg: Cfg, sql: str) -> int:
223
+ """Trigger the Safety stage directly (dev endpoint) so OK-rate panels aren't affected."""
224
+ url = f"{cfg.api_base}/api/v1/_dev/safety"
225
+ payload = json.dumps({"sql": sql})
226
+ cp = subprocess.run(
227
+ [
228
+ "curl",
229
+ "-sS",
230
+ "-D",
231
+ "-",
232
+ "-H",
233
+ f"X-API-Key: {cfg.api_key}",
234
+ "-H",
235
+ "Content-Type: application/json",
236
+ "-d",
237
+ payload,
238
+ url,
239
+ ],
240
+ check=False,
241
+ text=True,
242
+ capture_output=True,
243
+ )
244
+ raw = cp.stdout
245
+ # Parse status code from HTTP status line.
246
+ header_block = raw.split("\r\n\r\n", 1)[0]
247
+ status_line = header_block.splitlines()[0] if header_block.splitlines() else ""
248
+ try:
249
+ return int(status_line.split()[1])
250
+ except Exception:
251
+ return 0
252
+
253
+
254
+ def print_cache_sanity(cfg: Cfg) -> None:
255
+ if not cfg.prom_base:
256
+ return
257
+
258
+ candidates = [
259
+ "nl2sql:cache_hit_ratio",
260
+ 'sum(rate(cache_events_total{hit="true"}[5m])) / sum(rate(cache_events_total[5m]))',
261
+ ]
262
+
263
+ for expr in candidates:
264
+ try:
265
+ data = prom_instant_query(cfg, expr)
266
+ if data is None:
267
+ continue
268
+ except Exception:
269
+ continue
270
+ try:
271
+ result = data["data"]["result"]
272
+ except Exception:
273
+ continue
274
+ if result:
275
+ value = result[0].get("value", [None, None])[1]
276
+ print(f"[prom] {expr} = {value}")
277
+ return
278
+
279
+ print("[prom] Could not find cache ratio metric yet (ok right after cold start).")
280
+
281
+
282
+ def main() -> int:
283
+ cfg = load_cfg()
284
+
285
+ random.seed(7) # deterministic-ish graphs
286
+
287
+ print("Waiting for API readiness...")
288
+ wait_for_ready(cfg, timeout_s=75)
289
+
290
+ print("Uploading DB...")
291
+ db_id = upload_db(cfg)
292
+ print(f"DB_ID={db_id}")
293
+
294
+ # Phase A: warm-up (mostly misses)
295
+ unique = [
296
+ "List the first 10 artists.",
297
+ "Which customer spent the most based on total invoice amount?",
298
+ "Top 5 tracks by duration.",
299
+ ]
300
+
301
+ print("Phase A: warmup (mostly misses)...")
302
+ for q in unique:
303
+ post_query(cfg, db_id=db_id, query=q)
304
+ time.sleep(0.7)
305
+
306
+ # Phase B: repeats (hits)
307
+ repeats = [
308
+ "Which customer spent the most based on total invoice amount?",
309
+ "List the first 10 artists.",
310
+ "Which customer spent the most based on total invoice amount?",
311
+ "Top 5 tracks by duration.",
312
+ "List the first 10 artists.",
313
+ ]
314
+
315
+ print("Phase B: repeated queries (hits)...")
316
+ # ~60 requests over ~1.5–2 minutes (enough signal for window-based panels)
317
+ for _ in range(60):
318
+ q = random.choice(repeats)
319
+ post_query(cfg, db_id=db_id, query=q)
320
+ time.sleep(1.1 + random.random() * 0.5)
321
+
322
+ # Give Prometheus a moment to scrape after the last request.
323
+ time.sleep(10)
324
+
325
+ print("\nSanity check:")
326
+ print_cache_sanity(cfg)
327
+
328
+ print("\n>>> NOW TAKE SCREENSHOT <<<")
329
+ print(
330
+ "Grafana: set time range to Last 10 minutes (or Last 15 minutes), refresh 5s, wait ~10s."
331
+ )
332
+ print("Tip: if hit% looks low, wait one more scrape interval and refresh.")
333
+
334
+ # Phase C: safety check (expected block) — after screenshot so OK% stays high in-window.
335
+ print("\nPhase C: safety check (expected block, after screenshot)...")
336
+ code = post_dev_safety(cfg, "drop table users;")
337
+ print(f"Safety request status={code} (expected non-200)")
338
+
339
+ # Phase D: recovery
340
+ print("Phase D: recovery...")
341
+ post_query(cfg, db_id=db_id, query="List the first 10 artists.")
342
+
343
+ print("\nDone. Suggested screenshot steps:")
344
+ print(" 1) In Grafana set time range: Last 10 minutes (or Last 15 minutes).")
345
+ print(" 2) Set refresh to 5s–10s and wait 10–20s for panels to catch up.")
346
+ print(" 3) Expect Requests-in-window > 10 and Cache Hit Ratio > 0.")
347
+ return 0
348
+
349
+
350
+ if __name__ == "__main__":
351
+ raise SystemExit(main())
scripts/smoke_api.py CHANGED
@@ -16,6 +16,7 @@ import json
16
  import os
17
  import time
18
  from pathlib import Path
 
19
 
20
  import requests
21
 
@@ -26,6 +27,19 @@ API_KEY = os.getenv("API_KEY", "dev-key")
26
  DB_DIR = Path("/tmp/nl2sql_dbs")
27
  DB_PATH = DB_DIR / "smoke_demo.sqlite"
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def _ensure_demo_db(path: Path) -> None:
31
  """Delegate to scripts/smoke_run.py if available; otherwise fail."""
@@ -112,7 +126,7 @@ def main() -> int:
112
  checks = [
113
  ("List the first 10 artists.", True),
114
  ("Which customer spent the most based on total invoice amount?", True),
115
- ("DELETE FROM users;", False), # must be blocked
116
  ]
117
 
118
  ok_all = True
@@ -130,12 +144,21 @@ def main() -> int:
130
  else:
131
  allowed = {
132
  "LLM_BAD_OUTPUT",
133
- "SQL_NOT_ALLOWED",
134
- "INVALID_SQL",
135
- "BAD_REQUEST",
136
  }
137
- if not _is_expected_block(status=status, body=body, allowed_codes=allowed):
138
- ok_all = False
 
 
 
 
 
 
 
 
 
139
 
140
  if ok_all:
141
  print("\n✅ demo-smoke passed")
 
16
  import os
17
  import time
18
  from pathlib import Path
19
+ import re
20
 
21
  import requests
22
 
 
27
  DB_DIR = Path("/tmp/nl2sql_dbs")
28
  DB_PATH = DB_DIR / "smoke_demo.sqlite"
29
 
30
+ _DML_DDL_SQL_RE = re.compile(
31
+ r"\b(delete|update|insert|drop|alter|truncate|create|replace)\b", re.IGNORECASE
32
+ )
33
+
34
+
35
+ def _is_select_only_sql(sql: str | None) -> bool:
36
+ if not sql:
37
+ return False
38
+ s = sql.strip().lower()
39
+ if not s.startswith("select"):
40
+ return False
41
+ return _DML_DDL_SQL_RE.search(sql) is None
42
+
43
 
44
  def _ensure_demo_db(path: Path) -> None:
45
  """Delegate to scripts/smoke_run.py if available; otherwise fail."""
 
126
  checks = [
127
  ("List the first 10 artists.", True),
128
  ("Which customer spent the most based on total invoice amount?", True),
129
+ ("SELECT * FROM Invoice;", False), # must be blocked (full scan without LIMIT)
130
  ]
131
 
132
  ok_all = True
 
144
  else:
145
  allowed = {
146
  "LLM_BAD_OUTPUT",
147
+ "PIPELINE_CRASH", # e.g. full_scan_without_limit guardrail
148
+ "SAFETY_NON_SELECT",
149
+ "SAFETY_MULTI_STATEMENT",
150
  }
151
+
152
+ if status != 200:
153
+ if not _is_expected_block(
154
+ status=status, body=body, allowed_codes=allowed
155
+ ):
156
+ ok_all = False
157
+ else:
158
+ # Accept safe refusal: 200 but SQL must be SELECT-only.
159
+ sql = body.get("sql") if isinstance(body, dict) else None
160
+ if not _is_select_only_sql(sql):
161
+ ok_all = False
162
 
163
  if ok_all:
164
  print("\n✅ demo-smoke passed")