File size: 5,117 Bytes
ddd54ed
f0b4004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4596e5b
f0b4004
 
 
 
 
 
 
 
 
 
4596e5b
 
 
 
 
 
 
 
 
 
 
 
 
f0b4004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd54ed
 
 
 
 
 
 
 
f0b4004
 
 
 
 
 
 
 
 
 
 
ddd54ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0b4004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd54ed
f0b4004
4596e5b
f0b4004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd54ed
 
4596e5b
 
 
ddd54ed
4596e5b
 
 
 
 
 
 
 
 
 
 
f0b4004
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Portable smoke requests for NL2SQL Copilot.

- Ensures a demo SQLite DB exists under /tmp/nl2sql_dbs/smoke_demo.sqlite
- Uploads it to the API
- Runs a few representative queries
- Exits non-zero on failure (so Make/CI can trust it)

Env:
  API_BASE:      base URL of API (default: http://127.0.0.1:8000)
  API_KEY:       API key header value (default: dev-key)
  SMOKE_TIMEOUT: per-request timeout in seconds for NL2SQL queries (default: 180)
"""

from __future__ import annotations

import json
import os
import time
from pathlib import Path
import re

import requests


# API endpoint and credentials; both can be overridden from the environment.
API_BASE = os.environ.get("API_BASE", "http://127.0.0.1:8000").rstrip("/")
API_KEY = os.environ.get("API_KEY", "dev-key")

# Local location of the demo SQLite file that is uploaded to the API.
DB_DIR = Path("/tmp/nl2sql_dbs")
DB_PATH = Path(DB_DIR, "smoke_demo.sqlite")

# Case-insensitive, whole-word match for data-modifying / schema-changing SQL
# keywords. NOTE: "replace" also matches SQLite's replace() string function,
# so a SELECT that calls replace() is treated as not SELECT-only.
_DML_DDL_SQL_RE = re.compile(
    r"\b(delete|update|insert|drop|alter|truncate|create|replace)\b", flags=re.I
)


def _is_select_only_sql(sql: str | None) -> bool:
    """Report whether *sql* looks like a read-only SELECT statement.

    True only when *sql* is non-empty, begins with ``select`` (case-insensitive,
    after stripping surrounding whitespace) and contains none of the keywords
    matched by ``_DML_DDL_SQL_RE`` anywhere in the text. The check is purely
    lexical: a keyword inside a string literal also disqualifies the statement,
    and ``WITH ...`` CTE queries are rejected because they do not start with
    ``select``.
    """
    if sql and sql.strip().lower().startswith("select"):
        return not _DML_DDL_SQL_RE.search(sql)
    return False


def _ensure_demo_db(path: Path) -> None:
    """Create the deterministic demo SQLite DB at *path* via smoke_run.ensure_demo_db.

    The helper lives in ``scripts/smoke_run.py``. It is imported first as the
    top-level module ``smoke_run``, which resolves when this script's own
    directory is ``sys.path[0]`` (e.g. ``python scripts/<this_script>.py``).
    If that fails, it is imported as ``scripts.smoke_run``, which resolves when
    the repository root is on ``sys.path`` (e.g. ``python -m scripts.<module>``).
    NOTE(review): this assumes the script lives next to smoke_run.py in scripts/.

    Raises:
        RuntimeError: if smoke_run cannot be imported either way. Errors raised
            by ``ensure_demo_db`` itself propagate unchanged.
    """
    try:
        from smoke_run import ensure_demo_db  # type: ignore
    except ImportError as first_err:
        try:
            from scripts.smoke_run import ensure_demo_db  # type: ignore
        except ImportError:
            raise RuntimeError(
                "Cannot import ensure_demo_db from 'smoke_run' or "
                "'scripts.smoke_run'; run this script so that "
                "scripts/smoke_run.py is importable"
            ) from first_err

    ensure_demo_db(path)


def _upload_db_and_get_id(path: Path) -> str:
    """Upload the SQLite file at *path* to the API and return its ``db_id``.

    The file is sent as the multipart field ``file`` to
    ``{API_BASE}/api/v1/nl2sql/upload_db`` with the ``X-API-Key`` header.

    Raises:
        RuntimeError: on a non-200 status, a body that is not JSON, or a JSON
            body that is not an object with a truthy ``db_id``.
    """
    url = f"{API_BASE}/api/v1/nl2sql/upload_db"
    headers = {"X-API-Key": API_KEY}
    with path.open("rb") as f:
        resp = requests.post(url, headers=headers, files={"file": f}, timeout=30)
    if resp.status_code != 200:
        raise RuntimeError(f"Upload failed: {resp.status_code} {resp.text[:400]}")

    try:
        data = resp.json()
    except ValueError as err:
        # requests' JSONDecodeError subclasses ValueError on all supported versions.
        raise RuntimeError(
            f"Upload returned a non-JSON body: {resp.text[:400]}"
        ) from err

    # Guard against a JSON list/string body, which has no .get().
    db_id = data.get("db_id") if isinstance(data, dict) else None
    if not db_id:
        raise RuntimeError(f"Invalid upload response: {data}")
    return str(db_id)


def _run_query(db_id: str, query: str) -> dict:
    """POST natural-language *query* against the uploaded database *db_id*.

    The per-request timeout is ``SMOKE_TIMEOUT`` seconds (default 180). On a
    read timeout the request is retried once after a 2 s pause; a failure of
    the retry, or any other transport error, propagates to the caller.

    Returns:
        ``{"status": int, "latency_ms": int, "body": ...}``. ``body`` is the
        parsed JSON, or ``{"raw": <response text>}`` when the body is not JSON.
        ``latency_ms`` spans both attempts and the pause when a retry occurred.
    """
    url = f"{API_BASE}/api/v1/nl2sql"
    headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
    payload = {"db_id": db_id, "query": query}
    timeout_s = float(os.getenv("SMOKE_TIMEOUT", "180"))

    # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
    t0 = time.perf_counter()
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s)
    except requests.exceptions.ReadTimeout:
        # One retry to smooth over transient provider/LLM slowness.
        time.sleep(2)
        resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s)
    dt_ms = int(round((time.perf_counter() - t0) * 1000))

    try:
        out = resp.json()
    except ValueError:
        # Not JSON (requests' JSONDecodeError is a ValueError): keep the raw text.
        out = {"raw": resp.text}

    return {"status": resp.status_code, "latency_ms": dt_ms, "body": out}


def _get_error_code(body: dict) -> str | None:
    """Return ``body["error"]["code"]`` as a string, or None if it is absent.

    Tolerates any response shape. A *body* that is not a dict, an ``error``
    value that is not a dict, or a missing or null ``code`` all yield None.
    Non-string codes are converted with ``str()``.
    """
    # Explicit type checks instead of a catch-all except: JSON bodies may be
    # lists, strings or {"raw": ...} wrappers, and none of those carry a code.
    if not isinstance(body, dict):
        return None
    err = body.get("error")
    if not isinstance(err, dict):
        return None
    code = err.get("code")
    return None if code is None else str(code)


def _is_expected_block(status: int, body: dict, allowed_codes: set[str]) -> bool:
    """Decide whether a response is a deliberate safety rejection.

    An HTTP 200 is never a rejection. Any other status counts as one only when
    the body's ``error.code`` is a member of *allowed_codes*.
    """
    # Short-circuit keeps the error-code lookup off the 200 path.
    return status != 200 and _get_error_code(body) in allowed_codes


def main() -> int:
    """Run the demo smoke checks against the API and return a process exit code.

    Exit codes:
        0 -- every check passed
        2 -- the demo DB directory/file could not be created
        3 -- the demo DB could not be uploaded
        4 -- at least one query check failed (including transport errors)

    Each check is ``(query, should_succeed)``. A should-succeed query must get
    HTTP 200. A must-be-blocked query passes when the API rejects it with one of
    the allowed error codes, or when it answers 200 with SQL that is SELECT-only
    (treated as a safe refusal).
    """
    try:
        # Inside the try so an unwritable /tmp reports exit 2 rather than a traceback.
        DB_DIR.mkdir(parents=True, exist_ok=True)
        _ensure_demo_db(DB_PATH)
    except Exception as e:
        print(f"❌ Failed to create demo DB: {e}")
        return 2

    try:
        db_id = _upload_db_and_get_id(DB_PATH)
    except Exception as e:
        print(f"❌ Failed to upload demo DB: {e}")
        return 3

    checks = [
        ("List the first 10 artists.", True),
        ("Which customer spent the most based on total invoice amount?", True),
        ("SELECT * FROM Invoice;", False),  # must be blocked (full scan without LIMIT)
    ]

    # Error codes accepted as an intentional rejection of a must-be-blocked query.
    allowed_block_codes = {
        "LLM_BAD_OUTPUT",
        "PIPELINE_CRASH",  # e.g. full_scan_without_limit guardrail
        "SAFETY_NON_SELECT",
        "SAFETY_MULTI_STATEMENT",
    }

    ok_all = True
    for q, should_succeed in checks:
        print(f"\nQuery: {q}")
        try:
            r = _run_query(db_id=db_id, query=q)
        except requests.exceptions.RequestException as e:
            # Transport failure (e.g. a second read timeout): record it as a
            # failed check and keep running the remaining checks.
            print(f"Request failed: {e}")
            ok_all = False
            continue

        status = r["status"]
        body = r["body"]
        print(f"HTTP {status} | {r['latency_ms']} ms")
        print(json.dumps(body, indent=2)[:800])

        if should_succeed:
            passed = status == 200
        elif status != 200:
            passed = _is_expected_block(
                status=status, body=body, allowed_codes=allowed_block_codes
            )
        else:
            # Accept safe refusal: 200 but SQL must be SELECT-only.
            sql = body.get("sql") if isinstance(body, dict) else None
            passed = _is_select_only_sql(sql)

        if not passed:
            print("-> check FAILED")
            ok_all = False

    if ok_all:
        print("\n✅ demo-smoke passed")
        return 0

    print("\n❌ demo-smoke failed (see output above)")
    return 4


if __name__ == "__main__":
    # main()'s return value becomes the process exit status that Make/CI checks.
    _exit_code = main()
    raise SystemExit(_exit_code)