File size: 2,674 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytest
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.types import StageResult, StageTrace

# Shared TestClient bound to the FastAPI app; every test in this module posts through it.
client = TestClient(app)


def fake_trace(stage: str):
    """Build a StageTrace for the given stage name with duration_ms fixed at 10.0."""
    trace_fields = {"stage": stage, "duration_ms": 10.0}
    return StageTrace(**trace_fields)

# Resolve the endpoint URL from the route name "nl2sql_handler" instead of
# hard-coding the path string in each test.
path = app.url_path_for("nl2sql_handler")

# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """Check that an ambiguous pipeline result yields HTTP 200 with the
    ``ambiguous`` flag set and a ``questions`` key in the JSON body."""
    from app.routers import nl2sql

    def _ambiguous_pipeline(*_args, **_kwargs):
        # Stand-in for the pipeline's run(): accepts any arguments and always
        # reports the query as ambiguous.
        stub_payload = {
            "ambiguous": True,
            "questions": ["Which table do you mean?"],
            "traces": [fake_trace("detector")],
        }
        return StageResult(ok=True, data=stub_payload)

    monkeypatch.setattr(nl2sql._pipeline, "run", _ambiguous_pipeline)

    request_body = {
        "query": "show all records",
        "schema_preview": "CREATE TABLE ...",
    }
    response = client.post(path, json=request_body)
    assert response.status_code == 200

    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body


# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """Check that a pipeline result with ``ok=False`` yields HTTP 400 and that
    the error text appears in the response's ``detail`` field."""
    from app.routers import nl2sql

    def _failing_pipeline(*_args, **_kwargs):
        # Stand-in for the pipeline's run(): always returns a failed result.
        failure_traces = [fake_trace("safety")]
        return StageResult(
            ok=False,
            error=["Bad SQL"],
            data={"traces": failure_traces},
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", _failing_pipeline)

    request_body = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    response = client.post(path, json=request_body)
    assert response.status_code == 400

    detail = response.json()["detail"]
    assert "Bad SQL" in detail


# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """Check that a successful pipeline result yields HTTP 200 with a SELECT
    statement in ``sql`` and a ``traces`` list containing the planner stage."""
    from app.routers import nl2sql

    def _successful_pipeline(*_args, **_kwargs):
        # Stand-in for the pipeline's run(): always returns a non-ambiguous,
        # successful result with two stage traces.
        stub_payload = {
            "ambiguous": False,
            "sql": "SELECT * FROM users;",
            "rationale": "Simple listing",
            "traces": [fake_trace("planner"), fake_trace("generator")],
        }
        return StageResult(ok=True, data=stub_payload)

    monkeypatch.setattr(nl2sql._pipeline, "run", _successful_pipeline)

    request_body = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    response = client.post(path, json=request_body)
    assert response.status_code == 200

    body = response.json()
    generated_sql = body["sql"]
    assert generated_sql.lower().startswith("select")

    traces = body["traces"]
    assert isinstance(traces, list)
    assert any(entry["stage"] == "planner" for entry in traces)