Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Generate a reproducible cache/metrics screenshot workload. | |
| What it does: | |
| 1) Waits for API readiness (healthz + readyz + router health). | |
| 2) Uploads a demo SQLite DB to the API (upload_db) and captures db_id. | |
| 3) Sends a burst of unique queries (mostly misses). | |
| 4) Sends repeated queries over ~70–90s (hits), with jitter so charts look natural. | |
| 5) Triggers a safety violation once (should be blocked) WITHOUT failing the whole demo. | |
| 6) Sends a final "recovery" query (OK). | |
| 7) (Optional) Prints a Prometheus instant-query sanity check for cache metrics. | |
| Expected API: | |
| - POST {API_BASE}/api/v1/nl2sql/upload_db (multipart form: file=@db.sqlite) -> {db_id: "..."} | |
| - POST {API_BASE}/api/v1/nl2sql (json: {db_id, query, schema_preview?}) -> 200 or 4xx/5xx | |
| - GET {API_BASE}/healthz | |
| - GET {API_BASE}/readyz | |
| - GET {API_BASE}/api/v1/nl2sql/health | |
| Env: | |
| - API_BASE (default http://127.0.0.1:8000) | |
| - API_KEY (default dev-key) | |
| - DB_PATH (default /tmp/nl2sql_dbs/smoke_demo.sqlite) | |
| - PROM_BASE (default http://127.0.0.1:9090) (optional; set empty to skip) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| import subprocess | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any | |
def sh(args: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]:
    """Execute *args* as a subprocess and hand back the finished process.

    Both stdout and stderr are captured and decoded as text. While ``check``
    stays at its default of ``True``, a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    completed = subprocess.run(args, capture_output=True, text=True, check=check)
    return completed
@dataclass(frozen=True)
class Cfg:
    """Immutable settings for the workload script.

    The ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    that ``load_cfg`` relies on; without it the class has only bare
    annotations and ``Cfg(api_base=..., ...)`` raises ``TypeError``.
    """

    # Base URL of the API (``load_cfg`` strips any trailing slash).
    api_base: str
    # Value sent in the ``X-API-Key`` request header.
    api_key: str
    # Filesystem path of the SQLite file uploaded by ``upload_db``.
    db_path: str
    # Base URL of Prometheus, or ``None`` to skip the metrics sanity check.
    prom_base: str | None
def load_cfg() -> Cfg:
    """Build a ``Cfg`` from the process environment, falling back to defaults.

    Trailing slashes are trimmed from both base URLs. A ``PROM_BASE`` that is
    empty after trimming is stored as ``None``.
    """
    env = os.environ
    prom = env.get("PROM_BASE", "http://127.0.0.1:9090").rstrip("/")
    return Cfg(
        api_base=env.get("API_BASE", "http://127.0.0.1:8000").rstrip("/"),
        api_key=env.get("API_KEY", "dev-key"),
        db_path=env.get("DB_PATH", "/tmp/nl2sql_dbs/smoke_demo.sqlite"),
        prom_base=prom or None,
    )
def wait_for_ready(cfg: Cfg, timeout_s: float = 60.0) -> None:
    """Block until every health endpoint answers HTTP 200, or raise on timeout.

    All endpoints are probed on each pass because on cold starts the
    container may accept TCP but reset early requests.

    Args:
        cfg: Settings; only ``api_base`` is read.
        timeout_s: Overall budget, in seconds, for the polling loop.

    Raises:
        RuntimeError: If the endpoints are not all returning 200 before the
            budget runs out. The message carries the last failing URL, the
            HTTP code curl reported and curl's stderr.
    """
    endpoints = (
        f"{cfg.api_base}/healthz",
        f"{cfg.api_base}/readyz",
        f"{cfg.api_base}/api/v1/nl2sql/health",
    )
    # Cap each probe so a single hung connection cannot outlive the overall budget.
    probe_timeout_s = 10
    # Monotonic clock: the deadline is immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    last = ""
    while time.monotonic() < deadline:
        for url in endpoints:
            cp = subprocess.run(
                [
                    "curl",
                    "-sS",
                    "--max-time",
                    str(probe_timeout_s),
                    "-o",
                    "/dev/null",
                    "-w",
                    "%{http_code}",
                    url,
                ],
                check=False,
                text=True,
                capture_output=True,
            )
            code = (cp.stdout or "").strip()
            if code != "200":
                last = f"url={url} http={code} stderr={cp.stderr.strip()!r}"
                break
        else:
            # No probe broke out of the loop: every endpoint returned 200.
            return
        time.sleep(0.6)
    raise RuntimeError(f"API not ready after {timeout_s:.0f}s. Last={last}")
def _http_status(header_block: str) -> int:
    """Return the numeric status from the first line of *header_block*, or 0 if unparsable."""
    lines = header_block.splitlines()
    try:
        return int(lines[0].split()[1])
    except (IndexError, ValueError):
        return 0


def _split_final_response(raw: str) -> tuple[str, str] | None:
    """Split ``curl -D -`` output into ``(headers, body)`` of the final response.

    Interim 1xx header blocks (such as ``100 Continue``) that precede the real
    response are skipped. Returns ``None`` when no blank line separates
    headers from body.
    """
    remaining = raw
    while True:
        head, sep, rest = remaining.partition("\r\n\r\n")
        if not sep:
            head, sep, rest = remaining.partition("\n\n")
        if not sep:
            return None
        if 100 <= _http_status(head) < 200 and rest.startswith("HTTP/"):
            # Interim response: the real status line follows in ``rest``.
            remaining = rest
            continue
        return head, rest


def upload_db(cfg: Cfg) -> str:
    """Upload the SQLite file at ``cfg.db_path`` and return the ``db_id`` from the reply.

    Args:
        cfg: Settings; ``api_base``, ``api_key`` and ``db_path`` are read.

    Returns:
        The non-empty ``db_id`` string found in the JSON response body.

    Raises:
        FileNotFoundError: If ``cfg.db_path`` does not exist.
        RuntimeError: If curl exits non-zero, the output cannot be split into
            headers and body, the final status is not 200, the body is not
            JSON, or the JSON lacks a non-empty string ``db_id``.
    """
    if not os.path.exists(cfg.db_path):
        raise FileNotFoundError(f"DB_PATH not found: {cfg.db_path}")
    url = f"{cfg.api_base}/api/v1/nl2sql/upload_db"
    # Do NOT use -f here; on error we want the body.
    cp = subprocess.run(
        [
            "curl",
            "-sS",
            "-D",
            "-",
            "-H",
            f"X-API-Key: {cfg.api_key}",
            "-F",
            f"file=@{cfg.db_path}",
            url,
        ],
        check=False,
        text=True,
        capture_output=True,
    )
    if cp.returncode != 0:
        raise RuntimeError(
            f"upload_db curl failed (rc={cp.returncode}). stderr={cp.stderr.strip()!r}\nstdout:\n{cp.stdout}"
        )
    raw = cp.stdout
    split = _split_final_response(raw)
    if split is None:
        raise RuntimeError(f"upload_db returned unexpected response:\n{raw}")
    headers, body = split
    # Compare the parsed status code rather than searching for " 200 " as a
    # substring, which depends on the exact spacing of the status line.
    if _http_status(headers) != 200:
        raise RuntimeError(f"upload_db non-200.\n{headers}\n\n{body}")
    try:
        data = json.loads(body)
    except json.JSONDecodeError as e:
        raise RuntimeError(
            f"upload_db returned non-JSON body.\n{headers}\n\n{body}"
        ) from e
    db_id = data.get("db_id")
    if not isinstance(db_id, str) or not db_id:
        raise RuntimeError(f"upload_db response missing db_id: {data}")
    return db_id
def post_query(
    cfg: Cfg, *, db_id: str, query: str, fail_on_non_200: bool = True
) -> int:
    """Send one NL2SQL query through curl and report the HTTP status code.

    The status is taken from the first line of the header dump; 0 is used when
    that line cannot be parsed. A ``RuntimeError`` is raised when curl exits
    non-zero, when headers and body cannot be told apart, or (only while
    ``fail_on_non_200`` is true) when the status differs from 200.
    """
    command = [
        "curl",
        "-sS",
        "-D",
        "-",
        "-H",
        f"X-API-Key: {cfg.api_key}",
        "-H",
        "Content-Type: application/json",
        "-d",
        json.dumps({"db_id": db_id, "query": query}),
        f"{cfg.api_base}/api/v1/nl2sql",
    ]
    proc = subprocess.run(command, check=False, text=True, capture_output=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"query curl failed (rc={proc.returncode}). query={query!r}\n"
            f"stderr={proc.stderr.strip()!r}\nstdout:\n{proc.stdout}"
        )

    output = proc.stdout
    # Prefer the CRLF blank line; fall back to a bare-LF blank line.
    headers, found, body = output.partition("\r\n\r\n")
    if not found:
        headers, found, body = output.partition("\n\n")
    if not found:
        raise RuntimeError(
            f"query returned unexpected response. query={query!r}\n{output}"
        )

    header_lines = headers.splitlines()
    first_line = header_lines[0] if header_lines else ""
    # Status line looks like: HTTP/1.1 200 OK
    try:
        status_code = int(first_line.split()[1])
    except (IndexError, ValueError):
        status_code = 0

    if status_code != 200 and fail_on_non_200:
        raise RuntimeError(f"Non-200 response for query={query!r}\n{headers}\n\n{body}")
    return status_code
def prom_instant_query(cfg: Cfg, expr: str) -> Any | None:
    """Run *expr* against the Prometheus query endpoint and return the decoded JSON.

    Returns ``None`` without issuing a request when ``cfg.prom_base`` is unset
    or empty. Otherwise curl is invoked through ``sh`` with ``-f``, so a failed
    request surfaces as the exception ``sh`` raises for a non-zero exit.
    """
    base = cfg.prom_base
    if base:
        response = sh(
            ["curl", "-fsS", f"{base}/api/v1/query", "--data-urlencode", f"query={expr}"]
        )
        return json.loads(response.stdout)
    return None
def post_dev_safety(cfg: Cfg, sql: str) -> int:
    """POST *sql* to the dev safety endpoint and return the HTTP status code.

    Uses the ``_dev/safety`` route so the Safety stage is triggered directly
    and OK-rate panels aren't affected. This is best-effort: curl's exit code
    is not inspected, and 0 is returned whenever no status code can be read
    from the output.
    """
    command = [
        "curl",
        "-sS",
        "-D",
        "-",
        "-H",
        f"X-API-Key: {cfg.api_key}",
        "-H",
        "Content-Type: application/json",
        "-d",
        json.dumps({"sql": sql}),
        f"{cfg.api_base}/api/v1/_dev/safety",
    ]
    proc = subprocess.run(command, check=False, text=True, capture_output=True)

    # Only the status line (first line of the header block) matters here.
    head, _, _ = proc.stdout.partition("\r\n\r\n")
    head_lines = head.splitlines()
    if not head_lines:
        return 0
    try:
        return int(head_lines[0].split()[1])
    except (IndexError, ValueError):
        return 0
def print_cache_sanity(cfg: Cfg) -> None:
    """Print the first cache-hit-ratio expression that Prometheus has data for.

    Does nothing when ``cfg.prom_base`` is unset. Each candidate expression is
    tried in order; a query error or an unexpected payload shape just moves on
    to the next one. If none yields a sample, a fallback notice is printed.
    """
    if not cfg.prom_base:
        return

    expressions = (
        "nl2sql:cache_hit_ratio",
        'sum(rate(cache_events_total{hit="true"}[5m])) / sum(rate(cache_events_total[5m]))',
    )
    for expr in expressions:
        try:
            payload = prom_instant_query(cfg, expr)
            series = None if payload is None else payload["data"]["result"]
        except Exception:
            # Best-effort check: any failure means "try the next expression".
            continue
        if not series:
            continue
        sample = series[0].get("value", [None, None])
        print(f"[prom] {expr} = {sample[1]}")
        return

    print("[prom] Could not find cache ratio metric yet (ok right after cold start).")
def main() -> int:
    """Run the whole workload in order and return 0.

    Sequence: wait for readiness, upload the DB, send the warm-up queries,
    send 60 randomly chosen repeat queries with jittered pauses, print the
    Prometheus sanity check and screenshot prompt, fire one dev safety request,
    then send a final recovery query.
    """
    cfg = load_cfg()
    random.seed(7)  # fixed seed so the request pattern repeats between runs

    print("Waiting for API readiness...")
    wait_for_ready(cfg, timeout_s=75)

    print("Uploading DB...")
    db_id = upload_db(cfg)
    print(f"DB_ID={db_id}")

    artists_q = "List the first 10 artists."
    spender_q = "Which customer spent the most based on total invoice amount?"
    tracks_q = "Top 5 tracks by duration."

    # Phase A: first sighting of each question (mostly cache misses).
    print("Phase A: warmup (mostly misses)...")
    for question in (artists_q, spender_q, tracks_q):
        post_query(cfg, db_id=db_id, query=question)
        time.sleep(0.7)

    # Phase B: the same questions again, so they should be served from cache.
    hit_pool = [spender_q, artists_q, spender_q, tracks_q, artists_q]
    print("Phase B: repeated queries (hits)...")
    # 60 requests, each followed by a 1.1–1.6s jittered pause.
    for _round in range(60):
        post_query(cfg, db_id=db_id, query=random.choice(hit_pool))
        time.sleep(1.1 + random.random() * 0.5)

    # Leave Prometheus time for a scrape after the final request.
    time.sleep(10)

    print("\nSanity check:")
    print_cache_sanity(cfg)

    print("\n>>> NOW TAKE SCREENSHOT <<<")
    print(
        "Grafana: set time range to Last 10 minutes (or Last 15 minutes), refresh 5s, wait ~10s."
    )
    print("Tip: if hit% looks low, wait one more scrape interval and refresh.")

    # Phase C: the blocked request is sent only after the screenshot prompt.
    print("\nPhase C: safety check (expected block, after screenshot)...")
    safety_status = post_dev_safety(cfg, "drop table users;")
    print(f"Safety request status={safety_status} (expected non-200)")

    # Phase D: one more ordinary query.
    print("Phase D: recovery...")
    post_query(cfg, db_id=db_id, query=artists_q)

    print("\nDone. Suggested screenshot steps:")
    print(" 1) In Grafana set time range: Last 10 minutes (or Last 15 minutes).")
    print(" 2) Set refresh to 5s–10s and wait 10–20s for panels to catch up.")
    print(" 3) Expect Requests-in-window > 10 and Cache Hit Ratio > 0.")
    return 0
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())