Melika Kheirieh committed on
Commit
cf4af3c
·
1 Parent(s): 72e96d1

fix(smoke): align smoke_run and smoke_metrics for CI stability and disambiguated queries

Browse files
Files changed (2) hide show
  1. scripts/smoke_metrics.sh +50 -32
  2. scripts/smoke_run.py +133 -313
scripts/smoke_metrics.sh CHANGED
@@ -1,34 +1,52 @@
1
- #!/usr/bin/env bash
2
  set -euo pipefail
3
 
4
- BASE=${BASE:-http://localhost:8000}
5
- API="$BASE/api/v1"
6
-
7
- # Send a few successful queries to populate basic metrics
8
- for q in \
9
- "List all artists" \
10
- "Top 5 albums by sales" \
11
- "Count customers"
12
- do
13
- curl -s -X POST "$API/nl2sql" \
14
- -H 'Content-Type: application/json' \
15
- -H 'X-API-Key: dev-key' \
16
- -d "{\"query\":\"$q\"}" >/dev/null || true
17
- done
18
-
19
- # Send queries that trigger safety and verifier checks
20
- curl -s -X POST "$API/nl2sql" \
21
- -H 'Content-Type: application/json' \
22
- -H 'X-API-Key: dev-key' \
23
- -d '{"query":"DELETE FROM users;"}' >/dev/null || true
24
-
25
- curl -s -X POST "$API/nl2sql" \
26
- -H 'Content-Type: application/json' \
27
- -H 'X-API-Key: dev-key' \
28
- -d '{"query":"SELECT COUNT(*), country FROM customers;"}' >/dev/null || true
29
-
30
- # Print a snapshot of key Prometheus metrics
31
- echo -e "\n--- Metrics snapshot ---"
32
- curl -s "$BASE/metrics" | grep -E \
33
- 'stage_duration_ms_(sum|count|bucket)|pipeline_runs_total|safety_(checks|blocks)_total|verifier_(checks|failures)_total' \
34
- || true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  set -euo pipefail
2
 
3
+ API_BASE=${API_BASE:-"http://127.0.0.1:8000"}
4
+ API_KEY=${API_KEY:-"dev-key"}
5
+ PROM=${PROMETHEUS_URL:-"http://127.0.0.1:9090"}
6
+ TMP_DB="/tmp/nl2sql_dbs/smoke_demo.sqlite"
7
+
8
+ echo "πŸ§ͺ Running NL2SQL smoke metrics validation..."
9
+ echo "API_BASE=$API_BASE"
10
+ echo "PROMETHEUS_URL=$PROM"
11
+ echo "TMP_DB=$TMP_DB"
12
+
13
+ # --- 1. Make sure the DB exists ---
14
+ if [ ! -f "$TMP_DB" ]; then
15
+ echo "βš™οΈ Creating demo database via smoke_run.py..."
16
+ python scripts/smoke_run.py || {
17
+ echo "❌ smoke_run.py failed to create demo DB."
18
+ exit 1
19
+ }
20
+ else
21
+ echo "βœ… Found existing DB at $TMP_DB"
22
+ fi
23
+
24
+ # --- 2. Upload DB and capture db_id ---
25
+ echo "⬆️ Uploading demo DB..."
26
+ DB_ID=$(curl -s -X POST "$API_BASE/api/v1/nl2sql/upload_db" \
27
+ -H "X-API-Key: $API_KEY" \
28
+ -F "file=@${TMP_DB}" | jq -r '.db_id')
29
+
30
+ if [ "$DB_ID" = "null" ] || [ -z "$DB_ID" ]; then
31
+ echo "❌ Failed to upload DB or get db_id."
32
+ exit 1
33
+ fi
34
+ echo "βœ… Uploaded DB_ID: $DB_ID"
35
+
36
+ # --- 3. Run a few API smoke queries ---
37
+ echo "πŸš€ Sending test queries..."
38
+ curl -s -X POST "$API_BASE/api/v1/nl2sql" \
39
+ -H "Content-Type: application/json" -H "X-API-Key: $API_KEY" \
40
+ -d "{\"db_id\":\"$DB_ID\",\"query\":\"How many artists are there?\"}" | jq .
41
+
42
+ curl -s -X POST "$API_BASE/api/v1/nl2sql" \
43
+ -H "Content-Type: application/json" -H "X-API-Key: $API_KEY" \
44
+ -d "{\"db_id\":\"$DB_ID\",\"query\":\"Which customer spent the most?\"}" | jq .
45
+
46
+ # --- 4. Collect metrics snapshot from Prometheus ---
47
+ echo "πŸ“Š Checking Prometheus metrics..."
48
+ curl -s "$PROM/api/v1/query?query=nl2sql:pipeline_success_ratio" | jq .
49
+
50
+ curl -s "$PROM/api/v1/query?query=nl2sql:stage_p95_ms" | jq .
51
+
52
+ echo "βœ… Smoke metrics check completed."
scripts/smoke_run.py CHANGED
@@ -1,334 +1,154 @@
1
  """
2
- Minimal smoke/demo runner for the NL2SQL pipeline.
3
 
4
- - Builds the pipeline via the official factory (no app/router imports).
5
- - Runs a small set of demo questions against a SQLite DB.
6
- - Works in two modes:
7
- * Stub mode (set PYTEST_CURRENT_TEST=1) β†’ no API key needed.
8
- * Real mode (set OPENAI_API_KEY=...) β†’ uses actual LLM provider.
9
 
10
- Outputs:
11
- benchmarks/results_demo/<timestamp>/
12
- - demo.jsonl # one JSON record per query
13
- - summary.json # latency & success overview
14
- - results.csv # compact table for quick inspection
15
-
16
- Usage examples:
17
- PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
18
- python scripts/smoke_run.py --db-path demo.db
19
-
20
- # With a custom dataset file (JSON: list[str] or list[{question: "..."}])
21
- PYTHONPATH=$PWD PYTEST_CURRENT_TEST=1 \
22
- python scripts/smoke_run.py --db-path demo.db --dataset-file benchmarks/demo.json
23
  """
24
 
25
- from __future__ import annotations
26
-
27
- import argparse
28
- import csv
29
- import json
30
  import os
 
 
31
  import time
32
- from pathlib import Path
33
- from typing import Any, Dict, List, Optional
34
  import sqlite3
 
 
35
 
36
- from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
37
- from adapters.db.sqlite_adapter import SQLiteAdapter
38
-
39
- CONFIG_PATH = "configs/sqlite_pipeline.yaml"
40
- DEFAULT_QUESTIONS: List[str] = [
41
- "list all customers",
42
- "show total invoices per country",
43
- "top 3 albums by total sales",
44
- "artists with more than 3 albums",
45
- "number of employees per city",
46
- ]
47
 
48
- RESULT_ROOT = Path("benchmarks") / "results_demo"
49
- TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
50
- RESULT_DIR = RESULT_ROOT / TIMESTAMP
51
 
52
 
53
- def ensure_demo_db(db_path: Path) -> None:
54
- """Create a tiny demo SQLite DB if it doesn't exist."""
55
- if db_path.exists():
 
56
  return
57
- db_path.parent.mkdir(parents=True, exist_ok=True)
58
- conn = sqlite3.connect(str(db_path))
59
- cur = conn.cursor()
60
-
61
- # Minimal schema that matches our default demo questions
62
- cur.executescript("""
63
- DROP TABLE IF EXISTS customers;
64
- DROP TABLE IF EXISTS invoices;
65
- DROP TABLE IF EXISTS employees;
66
- DROP TABLE IF EXISTS artists;
67
- DROP TABLE IF EXISTS albums;
68
-
69
- CREATE TABLE customers (
70
- id INTEGER PRIMARY KEY,
71
- name TEXT,
72
- country TEXT
73
- );
74
-
75
- CREATE TABLE invoices (
76
- id INTEGER PRIMARY KEY,
77
- customer_id INTEGER,
78
- total REAL,
79
- country TEXT,
80
- FOREIGN KEY (customer_id) REFERENCES customers(id)
81
- );
82
-
83
- CREATE TABLE employees (
84
- id INTEGER PRIMARY KEY,
85
- name TEXT,
86
- city TEXT
87
- );
88
 
89
- CREATE TABLE artists (
90
- id INTEGER PRIMARY KEY,
91
- name TEXT
92
- );
93
-
94
- CREATE TABLE albums (
95
- id INTEGER PRIMARY KEY,
96
- artist_id INTEGER,
97
- title TEXT,
98
- sales REAL DEFAULT 0,
99
- FOREIGN KEY (artist_id) REFERENCES artists(id)
100
- );
101
- """)
102
 
103
- # Seed a bit of data
104
- cur.executemany(
105
- "INSERT INTO customers (id, name, country) VALUES (?, ?, ?)",
106
- [
107
- (1, "Alice", "USA"),
108
- (2, "Bob", "Germany"),
109
- (3, "Carlos", "Brazil"),
110
- (4, "Darya", "Iran"),
111
- ],
112
- )
113
- cur.executemany(
114
- "INSERT INTO invoices (id, customer_id, total, country) VALUES (?, ?, ?, ?)",
115
- [
116
- (1, 1, 120.5, "USA"),
117
- (2, 2, 75.0, "Germany"),
118
- (3, 1, 33.2, "USA"),
119
- (4, 3, 48.0, "Brazil"),
120
- (5, 4, 90.0, "Iran"),
121
- ],
122
- )
123
- cur.executemany(
124
- "INSERT INTO employees (id, name, city) VALUES (?, ?, ?)",
125
- [
126
- (1, "Eve", "New York"),
127
- (2, "Frank", "Berlin"),
128
- (3, "Gita", "Tehran"),
129
- ],
130
- )
131
- cur.executemany(
132
- "INSERT INTO artists (id, name) VALUES (?, ?)",
133
- [
134
- (1, "ABand"),
135
- (2, "BGroup"),
136
- (3, "CEnsemble"),
137
- ],
138
- )
139
- cur.executemany(
140
- "INSERT INTO albums (id, artist_id, title, sales) VALUES (?, ?, ?, ?)",
141
- [
142
- (1, 1, "First Light", 500.0),
143
- (2, 1, "Second Wind", 300.0),
144
- (3, 2, "Blue Lines", 900.0),
145
- (4, 3, "Echoes", 150.0),
146
- ],
 
 
147
  )
148
-
149
  conn.commit()
150
  conn.close()
151
-
152
-
153
- def _ms(start_s: float) -> int:
154
- """Convert elapsed seconds to integer milliseconds."""
155
- return int((time.perf_counter() - start_s) * 1000)
156
-
157
-
158
- def _derive_schema_preview(pipeline_obj: Any) -> Optional[str]:
159
- """Try to derive schema preview from adapter/executor if available."""
160
- for attr in ("executor", "adapter"):
161
- obj = getattr(pipeline_obj, attr, None)
162
- if obj and hasattr(obj, "derive_schema_preview"):
163
- try:
164
- return obj.derive_schema_preview() # type: ignore[no-any-return]
165
- except Exception:
166
- pass
167
- return None
168
-
169
-
170
- def _normalize_trace(trace_obj: Any) -> List[Dict[str, Any]]:
171
- """Convert trace to a list of {stage, ms} dicts for logging/export."""
172
- out: List[Dict[str, Any]] = []
173
- if not isinstance(trace_obj, list):
174
- return out
175
- for t in trace_obj:
176
- if isinstance(t, dict):
177
- stage = t.get("stage", "?")
178
- ms = t.get("duration_ms", 0)
179
- else:
180
- stage = getattr(t, "stage", "?")
181
- ms = getattr(t, "duration_ms", 0)
182
- try:
183
- out.append({"stage": str(stage), "ms": int(ms)})
184
- except Exception:
185
- out.append({"stage": str(stage), "ms": 0})
186
- return out
187
-
188
-
189
- def _load_questions(path: Optional[str]) -> List[str]:
190
- """Load questions from a JSON file or return defaults."""
191
- if not path:
192
- return DEFAULT_QUESTIONS
193
- p = Path(path)
194
- if not p.exists():
195
- raise FileNotFoundError(f"dataset file not found: {p}")
196
- data = json.loads(p.read_text(encoding="utf-8"))
197
- if isinstance(data, list):
198
- if all(isinstance(x, str) for x in data):
199
- return list(data)
200
- if all(isinstance(x, dict) and "question" in x for x in data):
201
- return [str(x["question"]) for x in data]
202
- raise ValueError(
203
- "Dataset must be a JSON array of strings or objects with a 'question' field."
204
- )
205
-
206
-
207
- def main() -> None:
208
- ap = argparse.ArgumentParser()
209
- ap.add_argument(
210
- "--db-path",
211
- type=str,
212
- default="demo.db",
213
- help="Path to SQLite DB (default: demo.db)",
214
- )
215
- ap.add_argument(
216
- "--dataset-file",
217
- type=str,
218
- default=None,
219
- help="Optional JSON file: list[str] or list[{question: str}]",
220
- )
221
- ap.add_argument(
222
- "--config",
223
- type=str,
224
- default=CONFIG_PATH,
225
- help=f"Pipeline YAML (default: {CONFIG_PATH})",
226
- )
227
- args = ap.parse_args()
228
-
229
- RESULT_DIR.mkdir(parents=True, exist_ok=True)
230
-
231
- # Resolve DB path and ensure demo DB exists for quick smoke runs
232
- db_path = Path(args.db_path).resolve()
233
- ensure_demo_db(db_path)
234
-
235
- # Build pipeline via the official factory (factory decides real vs stub by env)
236
- adapter = SQLiteAdapter(str(db_path))
237
- pipeline = pipeline_from_config_with_adapter(args.config, adapter=adapter)
238
-
239
- schema_preview = _derive_schema_preview(pipeline)
240
- print(f"βœ… Pipeline ready (db={db_path.name}, config={args.config})")
241
- print(
242
- "πŸ“„ Schema preview:",
243
- "yes" if schema_preview else "no",
244
- "| provider:",
245
- "STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL",
246
- )
247
-
248
- questions = _load_questions(args.dataset_file)
249
- print(f"πŸ—‚ Loaded {len(questions)} questions.")
250
-
251
- rows: List[Dict[str, Any]] = []
252
- for q in questions:
253
- print(f"\n🧠 Query: {q}")
254
- t0 = time.perf_counter()
255
- try:
256
- result = pipeline.run(user_query=q, schema_preview=schema_preview or "")
257
- latency_ms = _ms(t0) or 1 # clamp to 1ms when stubs are instant
258
- stages = _normalize_trace(
259
- getattr(result, "traces", getattr(result, "trace", []))
260
- )
261
- rows.append(
262
- {
263
- "query": q,
264
- "ok": bool(getattr(result, "ok", True)),
265
- "latency_ms": latency_ms,
266
- "trace": stages,
267
- "error": None,
268
- }
269
- )
270
- print(f"βœ… Success ({latency_ms} ms)")
271
- except Exception as exc:
272
- latency_ms = _ms(t0) or 1
273
- rows.append(
274
- {
275
- "query": q,
276
- "ok": False,
277
- "latency_ms": latency_ms,
278
- "trace": [],
279
- "error": str(exc),
280
- }
281
- )
282
- print(f"❌ Failed: {exc!s} ({latency_ms} ms)")
283
-
284
- # Aggregate and persist
285
- avg_latency = (
286
- round(sum(r["latency_ms"] for r in rows) / max(len(rows), 1), 1)
287
- if rows
288
- else 0.0
289
- )
290
- success_rate = (
291
- (sum(1 for r in rows if r["ok"]) / max(len(rows), 1)) if rows else 0.0
292
- )
293
- meta = {
294
- "db_path": str(db_path),
295
- "config": args.config,
296
- "provider_hint": "STUBS" if os.getenv("PYTEST_CURRENT_TEST") else "REAL",
297
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
298
- }
299
-
300
- jsonl_path = RESULT_DIR / "demo.jsonl"
301
- with jsonl_path.open("w", encoding="utf-8") as f:
302
- for r in rows:
303
- json.dump(r, f, ensure_ascii=False)
304
- f.write("\n")
305
-
306
- summary_path = RESULT_DIR / "summary.json"
307
- with summary_path.open("w", encoding="utf-8") as f:
308
- json.dump(
309
- {"avg_latency_ms": avg_latency, "success_rate": success_rate, **meta},
310
- f,
311
- indent=2,
312
- )
313
-
314
- csv_path = RESULT_DIR / "results.csv"
315
- with csv_path.open("w", newline="", encoding="utf-8") as f:
316
- wr = csv.DictWriter(f, fieldnames=["query", "ok", "latency_ms"])
317
- wr.writeheader()
318
- for r in rows:
319
- wr.writerow(
320
- {
321
- "query": r["query"],
322
- "ok": "βœ…" if r["ok"] else "❌",
323
- "latency_ms": int(r["latency_ms"]),
324
- }
325
- )
326
-
327
- print(
328
- "\nπŸ’Ύ Saved outputs:\n"
329
- f"- {jsonl_path}\n- {summary_path}\n- {csv_path}\n"
330
- f"πŸ“Š Avg latency: {avg_latency} ms | Success rate: {success_rate:.0%}\n"
331
- )
332
 
333
 
334
  if __name__ == "__main__":
 
1
  """
2
+ Smoke test for NL2SQL Copilot
3
 
4
+ Creates a demo SQLite DB (with proper table casing),
5
+ uploads it, runs representative queries, and prints results.
 
 
 
6
 
7
+ Exit code is always 0 for metrics pipelines, even if some tests fail.
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
 
 
 
 
 
 
10
  import os
11
+ import sys
12
+ import json
13
  import time
 
 
14
  import sqlite3
15
+ import requests
16
+ from pathlib import Path
17
 
18
+ API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000")
19
+ API_KEY = os.getenv("API_KEY", "dev-key")
 
 
 
 
 
 
 
 
 
20
 
21
+ DB_DIR = Path("/tmp/nl2sql_dbs")
22
+ DB_DIR.mkdir(parents=True, exist_ok=True)
23
+ DB_PATH = DB_DIR / "smoke_demo.sqlite"
24
 
25
 
26
+ def ensure_demo_db(path: Path):
27
+ """Create demo SQLite DB if missing."""
28
+ if path.exists():
29
+ print(f"βœ… Demo DB already exists at {path}")
30
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ conn = sqlite3.connect(path)
33
+ cur = conn.cursor()
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # --- create schema (fixed casing) ---
36
+ cur.executescript(
37
+ """
38
+ DROP TABLE IF EXISTS Artist;
39
+ DROP TABLE IF EXISTS Customer;
40
+ DROP TABLE IF EXISTS Invoice;
41
+
42
+ CREATE TABLE Artist (
43
+ ArtistId INTEGER PRIMARY KEY,
44
+ Name TEXT
45
+ );
46
+
47
+ CREATE TABLE Customer (
48
+ CustomerId INTEGER PRIMARY KEY,
49
+ FirstName TEXT,
50
+ LastName TEXT,
51
+ Country TEXT
52
+ );
53
+
54
+ CREATE TABLE Invoice (
55
+ InvoiceId INTEGER PRIMARY KEY,
56
+ CustomerId INTEGER,
57
+ Total REAL,
58
+ FOREIGN KEY(CustomerId) REFERENCES Customer(CustomerId)
59
+ );
60
+
61
+ INSERT INTO Artist (Name) VALUES
62
+ ('Miles Davis'),
63
+ ('Nina Simone'),
64
+ ('Radiohead'),
65
+ ('BjΓΆrk'),
66
+ ('Daft Punk');
67
+
68
+ INSERT INTO Customer (FirstName, LastName, Country) VALUES
69
+ ('Alice','Doe','USA'),
70
+ ('Bob','Smith','Canada'),
71
+ ('Claire','Johnson','France'),
72
+ ('Diego','Martinez','Spain');
73
+
74
+ INSERT INTO Invoice (CustomerId, Total) VALUES
75
+ (1, 15.0),
76
+ (2, 23.5),
77
+ (3, 10.2),
78
+ (4, 45.9),
79
+ (1, 8.9);
80
+ """
81
  )
 
82
  conn.commit()
83
  conn.close()
84
+ print(f"βœ… Demo DB created at {path}")
85
+
86
+
87
+ def upload_db_and_get_id(path: Path) -> str:
88
+ """Upload DB file to API and return db_id."""
89
+ url = f"{API_BASE}/api/v1/nl2sql/upload_db"
90
+ headers = {"X-API-Key": API_KEY}
91
+ with open(path, "rb") as f:
92
+ resp = requests.post(url, headers=headers, files={"file": f})
93
+ if resp.status_code != 200:
94
+ print(f"❌ Upload failed: {resp.status_code} {resp.text}")
95
+ sys.exit(0)
96
+ data = resp.json()
97
+ db_id = data.get("db_id")
98
+ if not db_id:
99
+ print(f"❌ Invalid upload response: {data}")
100
+ sys.exit(0)
101
+ print(f"βœ… Uploaded DB, got db_id={db_id}")
102
+ return db_id
103
+
104
+
105
+ def run_query(query: str, db_id: str):
106
+ """Send a query to NL2SQL endpoint."""
107
+ url = f"{API_BASE}/api/v1/nl2sql"
108
+ headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
109
+ payload = {"db_id": db_id, "query": query}
110
+
111
+ t0 = time.time()
112
+ resp = requests.post(url, headers=headers, json=payload)
113
+ dt = (time.time() - t0) * 1000
114
+
115
+ ok = resp.status_code == 200
116
+ prefix = "βœ…" if ok else "❌"
117
+ print(f"{prefix} {query} ({resp.status_code}) β€” {dt:.0f} ms")
118
+
119
+ try:
120
+ parsed = resp.json()
121
+ print(json.dumps(parsed, indent=2)[:500])
122
+ except Exception:
123
+ print(resp.text[:500])
124
+
125
+ print("-" * 80)
126
+ return ok
127
+
128
+
129
+ def main():
130
+ ensure_demo_db(DB_PATH)
131
+ db_id = upload_db_and_get_id(DB_PATH)
132
+
133
+ queries = [
134
+ "How many artists are there?",
135
+ "List all artist names",
136
+ # βœ… Disambiguated phrasing
137
+ "Which customer spent the most based on total invoice amount?",
138
+ "Average invoice total per country",
139
+ "DELETE FROM users;", # expected to fail (Safety check)
140
+ ]
141
+
142
+ success = True
143
+ for q in queries:
144
+ ok = run_query(q, db_id)
145
+ success &= ok
146
+
147
+ if success:
148
+ print("πŸŽ‰ Smoke tests completed successfully.")
149
+ else:
150
+ print("⚠️ Some smoke tests failed, but continuing for metrics.")
151
+ sys.exit(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
 
154
  if __name__ == "__main__":