File size: 3,201 Bytes
570f7bd
 
a45c0eb
570f7bd
 
 
4dae3e6
a45c0eb
 
 
c1bc4eb
4dae3e6
570f7bd
 
c1bc4eb
570f7bd
 
 
 
a45c0eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
570f7bd
 
 
 
 
 
 
a45c0eb
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a45c0eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.pipeline import FinalResult

# Shared test client wrapping the FastAPI application; every test in this
# module sends its requests through it.
client = TestClient(app)


def fake_trace(stage: str) -> dict:
    """Return a canned trace entry for the pipeline stage named *stage*.

    ``FinalResult.traces`` holds plain dicts (the ``__dict__`` of a
    ``StageTrace``), so the fixture is a dict with the same four keys.
    Only ``stage`` varies; the duration is a fixed 10.0 ms and the cost
    and notes are left unset.
    """
    return dict(
        stage=stage,
        duration_ms=10.0,
        cost_usd=None,
        notes=None,
    )


# Resolve the endpoint URL from the route name "nl2sql_handler" instead of
# hard-coding it; all tests below POST to this path.
path = app.url_path_for("nl2sql_handler")


# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous pipeline result comes back as HTTP 200 with questions.

    The router's pipeline ``run`` is swapped for a stub that reports
    ``ambiguous=True``; the response must carry that flag and expose the
    clarifying questions as a list.
    """
    from app.routers import nl2sql

    # Field values of the canned FinalResult returned by the stubbed pipeline.
    canned_fields = dict(
        ok=True,
        ambiguous=True,
        error=False,
        details=["Ambiguities found: 1"],
        questions=["Which table do you mean?"],
        sql=None,
        rationale=None,
        verified=None,
        traces=[fake_trace("detector")],
    )

    def stub_run(*_args, **_kwargs):
        return FinalResult(**canned_fields)

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    payload = {
        "query": "show all records",
        "schema_preview": "CREATE TABLE ...",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body
    assert isinstance(body["questions"], list)


# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline result is reported as HTTP 400 with its details.

    The router's pipeline ``run`` is swapped for a stub that reports
    ``error=True`` with the detail "Bad SQL"; that text must appear in the
    response's ``detail`` field.
    """
    from app.routers import nl2sql

    # Field values of the canned FinalResult returned by the stubbed pipeline.
    canned_fields = dict(
        ok=False,
        ambiguous=False,
        error=True,
        details=["Bad SQL"],
        questions=None,
        sql=None,
        rationale=None,
        verified=None,
        traces=[fake_trace("safety")],
    )

    def stub_run(*_args, **_kwargs):
        return FinalResult(**canned_fields)

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    payload = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 400
    assert "Bad SQL" in response.json()["detail"]


# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """A successful pipeline result comes back as HTTP 200 with SQL and traces.

    The router's pipeline ``run`` is swapped for a stub that returns a
    verified SELECT statement with "planner" and "generator" traces; the
    response must expose the SQL and both trace stages.
    """
    from app.routers import nl2sql

    # Field values of the canned FinalResult returned by the stubbed pipeline.
    canned_fields = dict(
        ok=True,
        ambiguous=False,
        error=False,
        details=None,
        questions=None,
        sql="SELECT * FROM users;",
        rationale="Simple listing",
        verified=True,
        traces=[fake_trace("planner"), fake_trace("generator")],
    )

    def stub_run(*_args, **_kwargs):
        return FinalResult(**canned_fields)

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    payload = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 200
    body = response.json()
    sql_text = body["sql"]
    assert sql_text.lower().startswith("select")
    assert isinstance(body["traces"], list)
    # Each expected stage must show up in at least one returned trace.
    for expected_stage in ("planner", "generator"):
        assert any(entry["stage"] == expected_stage for entry in body["traces"])