Spaces:
Sleeping
Sleeping
File size: 4,122 Bytes
454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 cf4af3c 454d146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""
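Smoke test for NL2SQL Copilot.

Creates a demo SQLite DB (Artist, Customer and Invoice tables with
capitalized names) if it does not exist yet, uploads it to the API,
runs representative natural-language queries, and prints the results.
The exit code is always 0 so metrics pipelines keep running, even if
some tests fail.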
"""
import os
import sys
import json
import time
import sqlite3
import requests
from pathlib import Path
# Base URL of the NL2SQL API. A trailing slash is stripped so URLs built as
# f"{API_BASE}/api/..." never contain a double slash.
API_BASE = os.getenv("API_BASE", "http://127.0.0.1:8000").rstrip("/")
# Key sent in the X-API-Key header of every request.
API_KEY = os.getenv("API_KEY", "dev-key")
# Directory holding the demo database; override with NL2SQL_DB_DIR.
DB_DIR = Path(os.getenv("NL2SQL_DB_DIR", "/tmp/nl2sql_dbs"))
DB_DIR.mkdir(parents=True, exist_ok=True)
DB_PATH = DB_DIR / "smoke_demo.sqlite"
def ensure_demo_db(path: Path) -> None:
    """Create the demo SQLite database at *path* unless it already exists.

    The schema has three tables with capitalized names -- ``Artist``,
    ``Customer`` and ``Invoice`` (``Invoice.CustomerId`` references
    ``Customer.CustomerId``) -- seeded with a few rows so the smoke-test
    queries have data to answer.

    An existing file is reused as-is; its contents are not validated.
    If creation fails part-way, the partially written file is removed so
    a later run rebuilds it instead of reporting it as already present.

    Args:
        path: Location of the SQLite file (a ``str`` is also accepted).

    Raises:
        Whatever ``sqlite3`` raises while building the database; the
        partial file is deleted before the exception propagates.
    """
    path = Path(path)
    if path.exists():
        print(f"✅ Demo DB already exists at {path}")
        return
    # sqlite3.connect() creates the file immediately, before any schema exists.
    conn = sqlite3.connect(path)
    try:
        # --- create schema (capitalized table names) and seed demo rows ---
        conn.executescript(
            """
            DROP TABLE IF EXISTS Artist;
            DROP TABLE IF EXISTS Customer;
            DROP TABLE IF EXISTS Invoice;
            CREATE TABLE Artist (
                ArtistId INTEGER PRIMARY KEY,
                Name TEXT
            );
            CREATE TABLE Customer (
                CustomerId INTEGER PRIMARY KEY,
                FirstName TEXT,
                LastName TEXT,
                Country TEXT
            );
            CREATE TABLE Invoice (
                InvoiceId INTEGER PRIMARY KEY,
                CustomerId INTEGER,
                Total REAL,
                FOREIGN KEY(CustomerId) REFERENCES Customer(CustomerId)
            );
            INSERT INTO Artist (Name) VALUES
                ('Miles Davis'),
                ('Nina Simone'),
                ('Radiohead'),
                ('Björk'),
                ('Daft Punk');
            INSERT INTO Customer (FirstName, LastName, Country) VALUES
                ('Alice','Doe','USA'),
                ('Bob','Smith','Canada'),
                ('Claire','Johnson','France'),
                ('Diego','Martinez','Spain');
            INSERT INTO Invoice (CustomerId, Total) VALUES
                (1, 15.0),
                (2, 23.5),
                (3, 10.2),
                (4, 45.9),
                (1, 8.9);
            """
        )
        conn.commit()
    except Exception:
        # Close first so the file can be deleted on every platform, then
        # remove the half-built DB so the exists() check above stays honest.
        conn.close()
        path.unlink(missing_ok=True)
        raise
    conn.close()
    print(f"✅ Demo DB created at {path}")
def upload_db_and_get_id(path: Path, timeout: float = 60.0) -> str:
    """Upload the SQLite file at *path* and return the ``db_id`` the API assigns.

    The file is posted as multipart field ``file`` to
    ``{API_BASE}/api/v1/nl2sql/upload_db`` with the ``X-API-Key`` header.

    Every failure -- unreadable file, network error, non-200 status,
    non-JSON body, or a response without a ``db_id`` -- is printed and ends
    the process via ``sys.exit(0)``, keeping the script's always-exit-0
    contract for metrics pipelines.

    Args:
        path: SQLite file to upload.
        timeout: Seconds to wait for the server; without one, ``requests``
            would wait indefinitely.

    Returns:
        The non-empty ``db_id`` value from the JSON response.
    """
    url = f"{API_BASE}/api/v1/nl2sql/upload_db"
    headers = {"X-API-Key": API_KEY}
    try:
        with open(path, "rb") as f:
            resp = requests.post(
                url, headers=headers, files={"file": f}, timeout=timeout
            )
    except (OSError, requests.RequestException) as exc:
        print(f"❌ Upload failed: {exc}")
        sys.exit(0)
    if resp.status_code != 200:
        print(f"❌ Upload failed: {resp.status_code} {resp.text}")
        sys.exit(0)
    try:
        data = resp.json()
    except ValueError:
        print(f"❌ Invalid upload response (not JSON): {resp.text[:500]}")
        sys.exit(0)
    # Guard against a JSON list/string body, which has no .get().
    db_id = data.get("db_id") if isinstance(data, dict) else None
    if not db_id:
        print(f"❌ Invalid upload response: {data}")
        sys.exit(0)
    print(f"✅ Uploaded DB, got db_id={db_id}")
    return db_id
def run_query(query: str, db_id: str, timeout: float = 120.0) -> bool:
    """Send one query to the NL2SQL endpoint and print the outcome.

    Posts ``{"db_id": db_id, "query": query}`` as JSON to
    ``{API_BASE}/api/v1/nl2sql``. It prints a status line with the HTTP
    code and round-trip latency. It then prints up to 500 characters of
    the response: pretty-printed JSON when the body parses, raw text
    otherwise.

    Args:
        query: Natural-language question (or raw text) to send.
        db_id: Identifier returned by the upload endpoint.
        timeout: Seconds to wait for the server; without one, ``requests``
            would wait indefinitely.

    Returns:
        True if the server answered HTTP 200, False for any other status
        or when the request could not be completed at all.
    """
    url = f"{API_BASE}/api/v1/nl2sql"
    headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
    payload = {"db_id": db_id, "query": query}
    # perf_counter is monotonic, unlike time.time(), so latency can't go negative.
    t0 = time.perf_counter()
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Report and keep going so one unreachable call doesn't abort the suite.
        dt = (time.perf_counter() - t0) * 1000
        print(f"❌ {query} (request error) → {dt:.0f} ms")
        print(str(exc)[:500])
        print("-" * 80)
        return False
    dt = (time.perf_counter() - t0) * 1000
    ok = resp.status_code == 200
    prefix = "✅" if ok else "❌"
    print(f"{prefix} {query} ({resp.status_code}) → {dt:.0f} ms")
    try:
        parsed = resp.json()
    except ValueError:
        print(resp.text[:500])
    else:
        # ensure_ascii=False keeps non-ASCII values (e.g. "Björk") readable.
        print(json.dumps(parsed, indent=2, ensure_ascii=False)[:500])
    print("-" * 80)
    return ok
def main() -> None:
    """Run the smoke suite end to end; always exits with status 0.

    Builds (or reuses) the demo database, uploads it, then sends each
    query. A query passes when its outcome matches its expectation:
    - Ordinary questions must succeed (``run_query`` returns True).
    - The destructive ``DELETE`` statement must be rejected by the
      service's safety check (``run_query`` returns False).
    """
    ensure_demo_db(DB_PATH)
    db_id = upload_db_and_get_id(DB_PATH)
    # (query, expected_ok) pairs.
    queries = [
        ("How many artists are there?", True),
        ("List all artist names", True),
        # Disambiguated phrasing
        ("Which customer spent the most based on total invoice amount?", True),
        ("Average invoice total per country", True),
        # Expected to fail (Safety check): writes must be refused.
        ("DELETE FROM users;", False),
    ]
    success = True
    for q, expected_ok in queries:
        ok = run_query(q, db_id)
        if ok != expected_ok:
            expectation = "success" if expected_ok else "rejection"
            print(f"⚠️ Unexpected outcome for {q!r}: expected {expectation}.")
            success = False
    if success:
        print("🎉 Smoke tests completed successfully.")
    else:
        print("⚠️ Some smoke tests failed, but continuing for metrics.")
    # Always exit 0 so downstream metrics pipelines are not interrupted.
    sys.exit(0)
if __name__ == "__main__":
    # Run the smoke suite only when executed as a script, never on import.
    main()
|