File size: 3,198 Bytes
570f7bd
 
a45c0eb
570f7bd
 
 
4dae3e6
a45c0eb
 
 
c1bc4eb
4dae3e6
570f7bd
 
c1bc4eb
570f7bd
 
 
 
a45c0eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
ccefd8e
570f7bd
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
570f7bd
 
 
 
 
 
 
a45c0eb
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
ccefd8e
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
ccefd8e
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.pipeline import FinalResult

# Shared HTTP test client used by every test in this module.
# NOTE(review): the client is not entered as a context manager (`with
# TestClient(app) as client:`), so the app's lifespan/startup handlers do not
# run — confirm the nl2sql route does not depend on them.
client = TestClient(app)


def fake_trace(stage: str) -> dict:
    """Return a stub trace record for *stage*, shaped like one entry of FinalResult.traces.

    FinalResult.traces holds plain dicts (each a StageTrace's ``__dict__``), so
    the stub is a dict too. Only ``stage`` varies; duration is a fixed 10.0 ms
    and cost/notes are left as None.
    """
    return dict(stage=stage, duration_ms=10.0, cost_usd=None, notes=None)


# URL of the NL->SQL endpoint, looked up by route name so the tests use
# whatever path the route is actually registered at. This runs at import time:
# if no route is named "nl2sql_handler", url_path_for raises (Starlette's
# NoMatchFound) and the whole module fails to collect.
path = app.url_path_for("nl2sql_handler")


# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous pipeline result is returned as 200 with clarifying questions."""
    from app.routers import nl2sql

    def stubbed_run(*_args, **_kwargs):
        # Canned outcome: the detector stage flagged one ambiguity and asks
        # a clarifying question instead of producing SQL.
        return FinalResult(
            traces=[fake_trace("detector")],
            questions=["Which table do you mean?"],
            details=["Ambiguities found: 1"],
            ok=True,
            ambiguous=True,
            error=False,
            sql=None,
            rationale=None,
            verified=None,
        )

    # Patch on the class so any Pipeline instance the router uses is stubbed.
    monkeypatch.setattr(nl2sql.Pipeline, "run", stubbed_run)

    request_body = {
        "query": "show all records",
        "schema_preview": "CREATE TABLE ...",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    # NOTE(review): only the presence/type of `questions` is checked, not its
    # contents; consider asserting the stubbed question round-trips.
    assert "questions" in body
    assert isinstance(body["questions"], list)


# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline result is surfaced as HTTP 400 carrying the failure details."""
    from app.routers import nl2sql

    def stubbed_run(*_args, **_kwargs):
        # Canned outcome: the safety stage rejected the request.
        return FinalResult(
            traces=[fake_trace("safety")],
            details=["Bad SQL"],
            ok=False,
            ambiguous=False,
            error=True,
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
        )

    monkeypatch.setattr(nl2sql.Pipeline, "run", stubbed_run)

    request_body = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 400
    # `in` works whether the router returns `detail` as a string or a list.
    detail = response.json()["detail"]
    assert "Bad SQL" in detail


# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """A successful pipeline result returns 200 with the SQL and per-stage traces."""
    from app.routers import nl2sql

    def stubbed_run(*_args, **_kwargs):
        # Canned outcome: planner and generator ran and produced verified SQL.
        return FinalResult(
            traces=[fake_trace("planner"), fake_trace("generator")],
            sql="SELECT * FROM users;",
            rationale="Simple listing",
            verified=True,
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
        )

    monkeypatch.setattr(nl2sql.Pipeline, "run", stubbed_run)

    request_body = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["sql"].lower().startswith("select")

    traces = body["traces"]
    assert isinstance(traces, list)
    stage_names = [trace["stage"] for trace in traces]
    assert "planner" in stage_names
    assert "generator" in stage_names