File size: 4,122 Bytes
454d146
cf4af3c
454d146
cf4af3c
 
454d146
cf4af3c
454d146
 
 
cf4af3c
 
454d146
 
cf4af3c
 
454d146
cf4af3c
 
454d146
cf4af3c
 
 
454d146
 
cf4af3c
 
 
 
454d146
 
cf4af3c
 
454d146
cf4af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454d146
 
 
cf4af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454d146
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Smoke test for NL2SQL Copilot

Creates a demo SQLite DB (with proper table casing),
uploads it, runs representative queries, and prints results.

Exit code is always 0 for metrics pipelines, even if some tests fail.
"""

import os
import sys
import json
import time
import sqlite3
import requests
from pathlib import Path

# Endpoint and credentials of the service under test; both can be
# overridden through environment variables of the same name.
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8000")
API_KEY = os.environ.get("API_KEY", "dev-key")

# Scratch directory holding the demo database; created at import time.
DB_DIR = Path("/tmp") / "nl2sql_dbs"
DB_DIR.mkdir(exist_ok=True, parents=True)
DB_PATH = DB_DIR.joinpath("smoke_demo.sqlite")


def ensure_demo_db(path: Path):
    """Create the demo SQLite database at *path* unless it already exists.

    The database contains three small tables (Artist, Customer, Invoice)
    seeded with a handful of rows that the smoke-test queries run against.

    Args:
        path: Location of the SQLite file to create.

    Raises:
        sqlite3.Error: If the schema script fails. The connection is closed
            and the partially written file is removed before the error
            propagates, so a later run does not mistake a broken file for
            a valid demo DB.
    """
    if path.exists():
        print(f"✅ Demo DB already exists at {path}")
        return

    conn = sqlite3.connect(path)
    created = False
    try:
        # --- create schema (fixed casing) ---
        conn.executescript(
            """
            DROP TABLE IF EXISTS Artist;
            DROP TABLE IF EXISTS Customer;
            DROP TABLE IF EXISTS Invoice;

            CREATE TABLE Artist (
                ArtistId INTEGER PRIMARY KEY,
                Name TEXT
            );

            CREATE TABLE Customer (
                CustomerId INTEGER PRIMARY KEY,
                FirstName TEXT,
                LastName TEXT,
                Country TEXT
            );

            CREATE TABLE Invoice (
                InvoiceId INTEGER PRIMARY KEY,
                CustomerId INTEGER,
                Total REAL,
                FOREIGN KEY(CustomerId) REFERENCES Customer(CustomerId)
            );

            INSERT INTO Artist (Name) VALUES
                ('Miles Davis'),
                ('Nina Simone'),
                ('Radiohead'),
                ('Björk'),
                ('Daft Punk');

            INSERT INTO Customer (FirstName, LastName, Country) VALUES
                ('Alice','Doe','USA'),
                ('Bob','Smith','Canada'),
                ('Claire','Johnson','France'),
                ('Diego','Martinez','Spain');

            INSERT INTO Invoice (CustomerId, Total) VALUES
                (1, 15.0),
                (2, 23.5),
                (3, 10.2),
                (4, 45.9),
                (1, 8.9);
            """
        )
        conn.commit()
        created = True
    finally:
        # Close before any unlink: an open handle blocks deletion on some
        # platforms.
        conn.close()
        if not created and path.exists():
            # The early-return guard above would otherwise accept this
            # half-initialised file on the next run.
            path.unlink()

    print(f"✅ Demo DB created at {path}")


def upload_db_and_get_id(path: Path, timeout: float = 60.0) -> str:
    """Upload the SQLite file at *path* to the API and return its ``db_id``.

    Every failure mode (unreadable file, network error or timeout, non-200
    status, non-JSON body, missing ``db_id``) is reported on stdout and ends
    the process with exit status 0, matching the script's rule that the
    exit code is always 0.

    Args:
        path: SQLite file to upload.
        timeout: Seconds to wait for the server before giving up; without
            a timeout ``requests`` may block indefinitely.

    Returns:
        The ``db_id`` value taken from the JSON response.
    """
    url = f"{API_BASE}/api/v1/nl2sql/upload_db"
    headers = {"X-API-Key": API_KEY}
    try:
        with open(path, "rb") as f:
            resp = requests.post(
                url, headers=headers, files={"file": f}, timeout=timeout
            )
    except (OSError, requests.RequestException) as exc:
        print(f"❌ Upload failed: {exc}")
        sys.exit(0)

    if resp.status_code != 200:
        print(f"❌ Upload failed: {resp.status_code} {resp.text}")
        sys.exit(0)

    try:
        data = resp.json()
    except ValueError:
        # A 200 response whose body is not JSON.
        print(f"❌ Invalid upload response: {resp.text[:500]}")
        sys.exit(0)

    # Guard against a JSON body that is not an object (e.g. a list).
    db_id = data.get("db_id") if isinstance(data, dict) else None
    if not db_id:
        print(f"❌ Invalid upload response: {data}")
        sys.exit(0)
    print(f"✅ Uploaded DB, got db_id={db_id}")
    return db_id


def run_query(query: str, db_id: str, timeout: float = 120.0) -> bool:
    """Send one natural-language query to the NL2SQL endpoint and report it.

    Prints a status line with the HTTP status and elapsed time, followed by
    at most 500 characters of the response (pretty-printed when it is JSON)
    and a separator line.

    Args:
        query: Text sent as the ``query`` field of the request payload.
        db_id: Identifier of the previously uploaded database.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the endpoint answered with HTTP 200, otherwise False. A
        request that cannot be completed (connection error, timeout) is
        reported and counted as False rather than aborting the run.
    """
    url = f"{API_BASE}/api/v1/nl2sql"
    headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
    payload = {"db_id": db_id, "query": query}

    # perf_counter is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        dt = (time.perf_counter() - t0) * 1000
        print(f"❌ {query} (request failed: {exc}) — {dt:.0f} ms")
        print("-" * 80)
        return False
    dt = (time.perf_counter() - t0) * 1000

    ok = resp.status_code == 200
    prefix = "✅" if ok else "❌"
    print(f"{prefix} {query} ({resp.status_code}) — {dt:.0f} ms")

    try:
        parsed = resp.json()
    except ValueError:
        # Body is not JSON; fall back to the raw text.
        print(resp.text[:500])
    else:
        print(json.dumps(parsed, indent=2)[:500])

    print("-" * 80)
    return ok


def main():
    """Run the smoke test end to end and always exit with status 0.

    Ensures the demo database exists, uploads it, then sends each
    representative query. Every query carries the outcome it is expected to
    have: ordinary questions must succeed, while the destructive
    ``DELETE`` statement must be rejected. A run counts as successful only
    when each query's actual outcome matches its expectation.
    """
    ensure_demo_db(DB_PATH)
    db_id = upload_db_and_get_id(DB_PATH)

    # (query text, whether an HTTP-200 success is expected)
    cases = [
        ("How many artists are there?", True),
        ("List all artist names", True),
        # ✅ Disambiguated phrasing
        ("Which customer spent the most based on total invoice amount?", True),
        ("Average invoice total per country", True),
        ("DELETE FROM users;", False),  # expected to fail (Safety check)
    ]

    # Build the full list first so every query runs even after a mismatch.
    outcomes = [run_query(q, db_id) == expected_ok for q, expected_ok in cases]

    if all(outcomes):
        print("🎉 Smoke tests completed successfully.")
    else:
        print("⚠️  Some smoke tests failed, but continuing for metrics.")
    sys.exit(0)


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()