# nl2sql-copilot / scripts/smoke_run.py (4.12 kB)
# Melika Kheirieh
# fix(smoke): align smoke_run and smoke_metrics for CI stability and disambiguated queries (cf4af3c)
"""
Smoke test for NL2SQL Copilot
Creates a demo SQLite DB (with proper table casing),
uploads it, runs representative queries, and prints results.
Exit code is always 0 for metrics pipelines, even if some tests fail.
"""
import os
import sys
import json
import time
import sqlite3
import requests
from pathlib import Path
# Service location and credential; either may be overridden via the environment.
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8000")
API_KEY = os.environ.get("API_KEY", "dev-key")

# Working directory for the demo database; created at import time if absent.
DB_DIR = Path("/tmp/nl2sql_dbs")
DB_DIR.mkdir(exist_ok=True, parents=True)
DB_PATH = DB_DIR.joinpath("smoke_demo.sqlite")
def ensure_demo_db(path: Path):
    """Create the demo SQLite database at *path* unless a file is already there.

    The database contains three tables (Artist, Customer, Invoice) with a
    handful of seed rows. If anything goes wrong while building it, the
    partially written file is removed so that a later run does not mistake
    it for a finished database.

    Args:
        path: Location of the SQLite file to create.

    Raises:
        sqlite3.Error: If the database cannot be opened or populated.
    """
    if path.exists():
        print(f"✅ Demo DB already exists at {path}")
        return

    # --- schema (fixed casing) and seed data ---
    script = """
        DROP TABLE IF EXISTS Artist;
        DROP TABLE IF EXISTS Customer;
        DROP TABLE IF EXISTS Invoice;
        CREATE TABLE Artist (
            ArtistId INTEGER PRIMARY KEY,
            Name TEXT
        );
        CREATE TABLE Customer (
            CustomerId INTEGER PRIMARY KEY,
            FirstName TEXT,
            LastName TEXT,
            Country TEXT
        );
        CREATE TABLE Invoice (
            InvoiceId INTEGER PRIMARY KEY,
            CustomerId INTEGER,
            Total REAL,
            FOREIGN KEY(CustomerId) REFERENCES Customer(CustomerId)
        );
        INSERT INTO Artist (Name) VALUES
            ('Miles Davis'),
            ('Nina Simone'),
            ('Radiohead'),
            ('Björk'),
            ('Daft Punk');
        INSERT INTO Customer (FirstName, LastName, Country) VALUES
            ('Alice','Doe','USA'),
            ('Bob','Smith','Canada'),
            ('Claire','Johnson','France'),
            ('Diego','Martinez','Spain');
        INSERT INTO Invoice (CustomerId, Total) VALUES
            (1, 15.0),
            (2, 23.5),
            (3, 10.2),
            (4, 45.9),
            (1, 8.9);
    """

    created = False
    conn = sqlite3.connect(path)
    try:
        conn.executescript(script)
        conn.commit()
        created = True
    finally:
        # Always release the handle, even when the script fails.
        conn.close()
        if not created:
            # A half-built file would satisfy the exists() guard above on the
            # next run, so remove it.
            try:
                path.unlink()
            except FileNotFoundError:
                pass

    print(f"✅ Demo DB created at {path}")
def upload_db_and_get_id(path: Path, timeout: float = 60.0) -> str:
    """Upload the SQLite file at *path* to the API and return its ``db_id``.

    Every failure mode (unreadable file, transport error, non-200 status,
    non-JSON body, missing ``db_id``) prints a diagnostic and terminates the
    process with exit code 0, matching the script's "never fail the metrics
    pipeline" contract.

    Args:
        path: SQLite file to upload.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``db_id`` value from the JSON response.
    """
    url = f"{API_BASE}/api/v1/nl2sql/upload_db"
    headers = {"X-API-Key": API_KEY}

    try:
        with open(path, "rb") as f:
            # Without a timeout, requests waits forever on an unresponsive server.
            resp = requests.post(
                url, headers=headers, files={"file": f}, timeout=timeout
            )
    except (OSError, requests.RequestException) as exc:
        print(f"❌ Upload failed: {exc}")
        sys.exit(0)

    if resp.status_code != 200:
        print(f"❌ Upload failed: {resp.status_code} {resp.text}")
        sys.exit(0)

    try:
        data = resp.json()
    except ValueError:
        # A 200 response whose body is not JSON.
        print(f"❌ Invalid upload response: {resp.text}")
        sys.exit(0)

    db_id = data.get("db_id") if isinstance(data, dict) else None
    if not db_id:
        print(f"❌ Invalid upload response: {data}")
        sys.exit(0)

    print(f"✅ Uploaded DB, got db_id={db_id}")
    return db_id
def run_query(query: str, db_id: str, timeout: float = 120.0):
    """Send one natural-language query to the NL2SQL endpoint and report it.

    Prints a status line with the HTTP status and elapsed time, followed by
    up to 500 characters of the response (pretty-printed when it is JSON)
    and a separator line.

    Args:
        query: Text sent as the ``query`` field of the request.
        db_id: Identifier of the previously uploaded database.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True when the endpoint answered with HTTP 200, False otherwise
        (including when the request could not be completed at all).
    """
    url = f"{API_BASE}/api/v1/nl2sql"
    headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
    payload = {"db_id": db_id, "query": query}

    t0 = time.perf_counter()
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # A transport failure counts as one failed query, not a crashed run.
        dt = (time.perf_counter() - t0) * 1000
        print(f"❌ {query} (request error) — {dt:.0f} ms")
        print(str(exc)[:500])
        print("-" * 80)
        return False
    dt = (time.perf_counter() - t0) * 1000

    ok = resp.status_code == 200
    prefix = "✅" if ok else "❌"
    print(f"{prefix} {query} ({resp.status_code}) — {dt:.0f} ms")

    try:
        parsed = resp.json()
    except ValueError:
        # Body is not JSON; show the raw text instead.
        print(resp.text[:500])
    else:
        print(json.dumps(parsed, indent=2)[:500])
    print("-" * 80)
    return ok
def main():
    """Run the smoke suite: prepare the demo DB, upload it, issue the queries.

    Each query carries the outcome it is expected to produce; the run counts
    as successful when every query matches its expectation. The process
    always exits with code 0 so that metrics pipelines keep going.
    """
    # (query, expected result of run_query)
    cases = [
        ("How many artists are there?", True),
        ("List all artist names", True),
        # Disambiguated phrasing
        ("Which customer spent the most based on total invoice amount?", True),
        ("Average invoice total per country", True),
        ("DELETE FROM users;", False),  # expected to fail (Safety check)
    ]

    success = True
    try:
        ensure_demo_db(DB_PATH)
        db_id = upload_db_and_get_id(DB_PATH)
        for query, expect_ok in cases:
            ok = run_query(query, db_id)
            # A rejected destructive query is a pass, not a failure.
            if bool(ok) != expect_ok:
                success = False
    except Exception as exc:
        # Top-level boundary: report the problem but keep the exit code at 0.
        print(f"❌ Smoke run aborted: {exc!r}")
        success = False

    if success:
        print("🎉 Smoke tests completed successfully.")
    else:
        print("⚠️ Some smoke tests failed, but continuing for metrics.")
    sys.exit(0)


if __name__ == "__main__":
    main()