Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
a45c0eb
1
Parent(s):
713d3ca
refactor: unify pipeline output via FinalResult model
Browse files
- app/routers/nl2sql.py +12 -13
- nl2sql/pipeline.py +103 -66
- nl2sql/types.py +17 -0
- tests/test_nl2sql_router.py +35 -23
- tests/test_pipeline_integration.py +23 -17
app/routers/nl2sql.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
from dataclasses import asdict, is_dataclass
|
| 2 |
from fastapi import APIRouter, HTTPException
|
| 3 |
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
|
| 4 |
-
from nl2sql.pipeline import Pipeline
|
| 5 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 6 |
from nl2sql.safety import Safety
|
| 7 |
from nl2sql.planner import Planner
|
| 8 |
from nl2sql.generator import Generator
|
| 9 |
from adapters.llm.openai_provider import OpenAIProvider
|
| 10 |
-
from nl2sql.types import StageResult
|
| 11 |
from nl2sql.executor import Executor
|
| 12 |
from nl2sql.verifier import Verifier
|
| 13 |
from nl2sql.repair import Repair
|
|
@@ -59,28 +58,28 @@ def _round_trace(t: dict) -> dict:
|
|
| 59 |
@router.post("", name="nl2sql_handler")
|
| 60 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 61 |
result = _pipeline.run(
|
| 62 |
-
user_query=request.query,
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
# --- Ensure result type ---
|
| 66 |
-
if not isinstance(result,
|
| 67 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 68 |
|
| 69 |
-
data = result.data or {}
|
| 70 |
-
|
| 71 |
# --- Handle ambiguity ---
|
| 72 |
-
if
|
| 73 |
-
return ClarifyResponse(ambiguous=True, questions=
|
| 74 |
|
| 75 |
# --- Handle error ---
|
| 76 |
-
if not result.ok:
|
| 77 |
-
detail = "; ".join(result.
|
| 78 |
raise HTTPException(status_code=400, detail=detail)
|
| 79 |
|
| 80 |
# --- Success case ---
|
|
|
|
| 81 |
return NL2SQLResponse(
|
| 82 |
ambiguous=False,
|
| 83 |
-
sql=
|
| 84 |
-
rationale=
|
| 85 |
-
traces=
|
| 86 |
)
|
|
|
|
| 1 |
from dataclasses import asdict, is_dataclass
|
| 2 |
from fastapi import APIRouter, HTTPException
|
| 3 |
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
|
| 4 |
+
from nl2sql.pipeline import Pipeline, FinalResult
|
| 5 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 6 |
from nl2sql.safety import Safety
|
| 7 |
from nl2sql.planner import Planner
|
| 8 |
from nl2sql.generator import Generator
|
| 9 |
from adapters.llm.openai_provider import OpenAIProvider
|
|
|
|
| 10 |
from nl2sql.executor import Executor
|
| 11 |
from nl2sql.verifier import Verifier
|
| 12 |
from nl2sql.repair import Repair
|
|
|
|
| 58 |
@router.post("", name="nl2sql_handler")
|
| 59 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 60 |
result = _pipeline.run(
|
| 61 |
+
user_query=request.query,
|
| 62 |
+
schema_preview=request.schema_preview,
|
| 63 |
)
|
| 64 |
|
| 65 |
# --- Ensure result type ---
|
| 66 |
+
if not isinstance(result, FinalResult):
|
| 67 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 68 |
|
|
|
|
|
|
|
| 69 |
# --- Handle ambiguity ---
|
| 70 |
+
if result.ambiguous and result.questions:
|
| 71 |
+
return ClarifyResponse(ambiguous=True, questions=result.questions)
|
| 72 |
|
| 73 |
# --- Handle error ---
|
| 74 |
+
if not result.ok or result.error:
|
| 75 |
+
detail = "; ".join(result.details or ["Unknown error"])
|
| 76 |
raise HTTPException(status_code=400, detail=detail)
|
| 77 |
|
| 78 |
# --- Success case ---
|
| 79 |
+
traces = [ _round_trace(t) for t in (result.traces or []) ]
|
| 80 |
return NL2SQLResponse(
|
| 81 |
ambiguous=False,
|
| 82 |
+
sql=result.sql,
|
| 83 |
+
rationale=result.rationale,
|
| 84 |
+
traces=traces,
|
| 85 |
)
|
nl2sql/pipeline.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import traceback
|
|
|
|
| 3 |
from typing import Dict, Any, Optional, List
|
|
|
|
| 4 |
from nl2sql.types import StageResult
|
| 5 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 6 |
from nl2sql.planner import Planner
|
|
@@ -11,10 +13,26 @@ from nl2sql.verifier import Verifier
|
|
| 11 |
from nl2sql.repair import Repair
|
| 12 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class Pipeline:
|
| 15 |
"""
|
| 16 |
-
NL2SQL Copilot pipeline
|
| 17 |
-
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
def __init__(
|
|
@@ -25,7 +43,7 @@ class Pipeline:
|
|
| 25 |
generator: Generator,
|
| 26 |
safety: Safety,
|
| 27 |
executor: Optional[Executor] = None,
|
| 28 |
-
verifier: Optional[Verifier] = None
|
| 29 |
repair: Optional[Repair] = None,
|
| 30 |
):
|
| 31 |
self.detector = detector
|
|
@@ -55,7 +73,7 @@ class Pipeline:
|
|
| 55 |
if isinstance(r, StageResult):
|
| 56 |
return r
|
| 57 |
else:
|
| 58 |
-
#
|
| 59 |
return StageResult(ok=True, data=r, trace=None)
|
| 60 |
except Exception as e:
|
| 61 |
tb = traceback.format_exc()
|
|
@@ -68,41 +86,40 @@ class Pipeline:
|
|
| 68 |
user_query: str,
|
| 69 |
schema_preview: str,
|
| 70 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 71 |
-
) ->
|
| 72 |
-
"""
|
| 73 |
-
Always returns:
|
| 74 |
-
{
|
| 75 |
-
"ambiguous": bool,
|
| 76 |
-
"error": bool,
|
| 77 |
-
"details": list[str] | None,
|
| 78 |
-
"sql": str | None,
|
| 79 |
-
"rationale": str | None,
|
| 80 |
-
"verified": bool | None,
|
| 81 |
-
"traces": list[dict]
|
| 82 |
-
}
|
| 83 |
-
"""
|
| 84 |
traces: List[dict] = []
|
| 85 |
details: List[str] = []
|
| 86 |
-
sql
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# --- 1) ambiguity detection
|
| 89 |
try:
|
| 90 |
questions = self.detector.detect(user_query, schema_preview)
|
| 91 |
if questions:
|
| 92 |
-
return
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
"
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
except Exception as e:
|
| 100 |
-
return
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
"
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
# --- 2) planner
|
| 108 |
r_plan = self._safe_stage(
|
|
@@ -110,12 +127,17 @@ class Pipeline:
|
|
| 110 |
)
|
| 111 |
traces.extend(self._trace_list(r_plan))
|
| 112 |
if not r_plan.ok:
|
| 113 |
-
return
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
# --- 3) generator
|
| 121 |
r_gen = self._safe_stage(
|
|
@@ -127,40 +149,51 @@ class Pipeline:
|
|
| 127 |
)
|
| 128 |
traces.extend(self._trace_list(r_gen))
|
| 129 |
if not r_gen.ok:
|
| 130 |
-
return
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
sql = r_gen.data.get("sql")
|
| 137 |
rationale = r_gen.data.get("rationale")
|
| 138 |
|
| 139 |
# --- 4) safety
|
| 140 |
-
|
|
|
|
| 141 |
traces.extend(self._trace_list(r_safe))
|
| 142 |
if not r_safe.ok:
|
| 143 |
-
return
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
# --- 5) executor
|
| 151 |
-
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data
|
| 152 |
traces.extend(self._trace_list(r_exec))
|
| 153 |
if not r_exec.ok:
|
| 154 |
details.extend(r_exec.error or [])
|
| 155 |
|
| 156 |
# --- 6) verifier
|
| 157 |
-
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
|
| 158 |
traces.extend(self._trace_list(r_ver))
|
| 159 |
verified = bool(r_ver.ok)
|
| 160 |
|
| 161 |
# --- 7) repair loop if verification failed
|
| 162 |
if not verified:
|
| 163 |
-
for
|
| 164 |
r_fix = self._safe_stage(
|
| 165 |
self.repair.run,
|
| 166 |
sql=sql,
|
|
@@ -171,29 +204,33 @@ class Pipeline:
|
|
| 171 |
if not r_fix.ok:
|
| 172 |
break
|
| 173 |
sql = r_fix.data.get("sql")
|
| 174 |
-
|
|
|
|
| 175 |
traces.extend(self._trace_list(r_safe))
|
| 176 |
if not r_safe.ok:
|
| 177 |
details.extend(r_safe.error or [])
|
| 178 |
continue
|
| 179 |
-
|
|
|
|
| 180 |
traces.extend(self._trace_list(r_exec))
|
| 181 |
if not r_exec.ok:
|
| 182 |
details.extend(r_exec.error or [])
|
| 183 |
continue
|
| 184 |
-
|
|
|
|
| 185 |
traces.extend(self._trace_list(r_ver))
|
| 186 |
verified = bool(r_ver.ok)
|
| 187 |
if verified:
|
| 188 |
break
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import traceback
|
| 3 |
+
from dataclasses import dataclass, asdict
|
| 4 |
from typing import Dict, Any, Optional, List
|
| 5 |
+
|
| 6 |
from nl2sql.types import StageResult
|
| 7 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 8 |
from nl2sql.planner import Planner
|
|
|
|
| 13 |
from nl2sql.repair import Repair
|
| 14 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 15 |
|
| 16 |
+
|
| 17 |
+
# ---- NEW: FinalResult as domain-level, type-safe result ----
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class FinalResult:
|
| 20 |
+
ok: bool
|
| 21 |
+
ambiguous: bool
|
| 22 |
+
error: bool
|
| 23 |
+
details: Optional[List[str]]
|
| 24 |
+
sql: Optional[str]
|
| 25 |
+
rationale: Optional[str]
|
| 26 |
+
verified: Optional[bool]
|
| 27 |
+
questions: Optional[List[str]]
|
| 28 |
+
traces: List[dict]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
class Pipeline:
|
| 32 |
"""
|
| 33 |
+
NL2SQL Copilot pipeline.
|
| 34 |
+
Stages return StageResult; final result is a type-safe FinalResult.
|
| 35 |
+
Adapters (e.g. FastAPI) can serialize with dataclasses.asdict().
|
| 36 |
"""
|
| 37 |
|
| 38 |
def __init__(
|
|
|
|
| 43 |
generator: Generator,
|
| 44 |
safety: Safety,
|
| 45 |
executor: Optional[Executor] = None,
|
| 46 |
+
verifier: Optional[Verifier] = None,
|
| 47 |
repair: Optional[Repair] = None,
|
| 48 |
):
|
| 49 |
self.detector = detector
|
|
|
|
| 73 |
if isinstance(r, StageResult):
|
| 74 |
return r
|
| 75 |
else:
|
| 76 |
+
# Normalize non-StageResult returns
|
| 77 |
return StageResult(ok=True, data=r, trace=None)
|
| 78 |
except Exception as e:
|
| 79 |
tb = traceback.format_exc()
|
|
|
|
| 86 |
user_query: str,
|
| 87 |
schema_preview: str,
|
| 88 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 89 |
+
) -> FinalResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
traces: List[dict] = []
|
| 91 |
details: List[str] = []
|
| 92 |
+
sql: Optional[str] = None
|
| 93 |
+
rationale: Optional[str] = None
|
| 94 |
+
verified: Optional[bool] = None
|
| 95 |
|
| 96 |
# --- 1) ambiguity detection
|
| 97 |
try:
|
| 98 |
questions = self.detector.detect(user_query, schema_preview)
|
| 99 |
if questions:
|
| 100 |
+
return FinalResult(
|
| 101 |
+
ok=True,
|
| 102 |
+
ambiguous=True,
|
| 103 |
+
error=False,
|
| 104 |
+
details=[f"Ambiguities found: {len(questions)}"],
|
| 105 |
+
questions=questions,
|
| 106 |
+
sql=None,
|
| 107 |
+
rationale=None,
|
| 108 |
+
verified=None,
|
| 109 |
+
traces=[],
|
| 110 |
+
)
|
| 111 |
except Exception as e:
|
| 112 |
+
return FinalResult(
|
| 113 |
+
ok=False,
|
| 114 |
+
ambiguous=True,
|
| 115 |
+
error=True,
|
| 116 |
+
details=[f"Detector failed: {e}"],
|
| 117 |
+
questions=None,
|
| 118 |
+
sql=None,
|
| 119 |
+
rationale=None,
|
| 120 |
+
verified=None,
|
| 121 |
+
traces=[],
|
| 122 |
+
)
|
| 123 |
|
| 124 |
# --- 2) planner
|
| 125 |
r_plan = self._safe_stage(
|
|
|
|
| 127 |
)
|
| 128 |
traces.extend(self._trace_list(r_plan))
|
| 129 |
if not r_plan.ok:
|
| 130 |
+
return FinalResult(
|
| 131 |
+
ok=False,
|
| 132 |
+
ambiguous=False,
|
| 133 |
+
error=True,
|
| 134 |
+
details=r_plan.error,
|
| 135 |
+
questions=None,
|
| 136 |
+
sql=None,
|
| 137 |
+
rationale=None,
|
| 138 |
+
verified=None,
|
| 139 |
+
traces=traces,
|
| 140 |
+
)
|
| 141 |
|
| 142 |
# --- 3) generator
|
| 143 |
r_gen = self._safe_stage(
|
|
|
|
| 149 |
)
|
| 150 |
traces.extend(self._trace_list(r_gen))
|
| 151 |
if not r_gen.ok:
|
| 152 |
+
return FinalResult(
|
| 153 |
+
ok=False,
|
| 154 |
+
ambiguous=False,
|
| 155 |
+
error=True,
|
| 156 |
+
details=r_gen.error,
|
| 157 |
+
questions=None,
|
| 158 |
+
sql=None,
|
| 159 |
+
rationale=None,
|
| 160 |
+
verified=None,
|
| 161 |
+
traces=traces,
|
| 162 |
+
)
|
| 163 |
sql = r_gen.data.get("sql")
|
| 164 |
rationale = r_gen.data.get("rationale")
|
| 165 |
|
| 166 |
# --- 4) safety
|
| 167 |
+
# fix: align with DummySafety signature → use .run (not .check)
|
| 168 |
+
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 169 |
traces.extend(self._trace_list(r_safe))
|
| 170 |
if not r_safe.ok:
|
| 171 |
+
return FinalResult(
|
| 172 |
+
ok=False,
|
| 173 |
+
ambiguous=False,
|
| 174 |
+
error=True,
|
| 175 |
+
details=r_safe.error,
|
| 176 |
+
questions=None,
|
| 177 |
+
sql=sql,
|
| 178 |
+
rationale=rationale,
|
| 179 |
+
verified=None,
|
| 180 |
+
traces=traces,
|
| 181 |
+
)
|
| 182 |
|
| 183 |
# --- 5) executor
|
| 184 |
+
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data.get("sql", sql))
|
| 185 |
traces.extend(self._trace_list(r_exec))
|
| 186 |
if not r_exec.ok:
|
| 187 |
details.extend(r_exec.error or [])
|
| 188 |
|
| 189 |
# --- 6) verifier
|
| 190 |
+
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
|
| 191 |
traces.extend(self._trace_list(r_ver))
|
| 192 |
verified = bool(r_ver.ok)
|
| 193 |
|
| 194 |
# --- 7) repair loop if verification failed
|
| 195 |
if not verified:
|
| 196 |
+
for _attempt in range(2):
|
| 197 |
r_fix = self._safe_stage(
|
| 198 |
self.repair.run,
|
| 199 |
sql=sql,
|
|
|
|
| 204 |
if not r_fix.ok:
|
| 205 |
break
|
| 206 |
sql = r_fix.data.get("sql")
|
| 207 |
+
|
| 208 |
+
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 209 |
traces.extend(self._trace_list(r_safe))
|
| 210 |
if not r_safe.ok:
|
| 211 |
details.extend(r_safe.error or [])
|
| 212 |
continue
|
| 213 |
+
|
| 214 |
+
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data.get("sql", sql))
|
| 215 |
traces.extend(self._trace_list(r_exec))
|
| 216 |
if not r_exec.ok:
|
| 217 |
details.extend(r_exec.error or [])
|
| 218 |
continue
|
| 219 |
+
|
| 220 |
+
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
|
| 221 |
traces.extend(self._trace_list(r_ver))
|
| 222 |
verified = bool(r_ver.ok)
|
| 223 |
if verified:
|
| 224 |
break
|
| 225 |
|
| 226 |
+
return FinalResult(
|
| 227 |
+
ok=bool(verified) and not details,
|
| 228 |
+
ambiguous=False,
|
| 229 |
+
error=bool(details) and not bool(verified),
|
| 230 |
+
details=details or None,
|
| 231 |
+
sql=sql,
|
| 232 |
+
rationale=rationale,
|
| 233 |
+
verified=verified,
|
| 234 |
+
questions=None,
|
| 235 |
+
traces=traces,
|
| 236 |
+
)
|
nl2sql/types.py
CHANGED
|
@@ -19,3 +19,20 @@ class StageResult:
|
|
| 19 |
trace: Optional[StageTrace] = None
|
| 20 |
error: Optional[List[str]] = None
|
| 21 |
notes: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
trace: Optional[StageTrace] = None
|
| 20 |
error: Optional[List[str]] = None
|
| 21 |
notes: Optional[Dict[str, Any]] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
|
| 25 |
+
class FinalResult:
|
| 26 |
+
"""
|
| 27 |
+
Final domain result of the whole pipeline.
|
| 28 |
+
Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
|
| 29 |
+
"""
|
| 30 |
+
ok: bool # end-to-end success
|
| 31 |
+
ambiguous: bool
|
| 32 |
+
error: bool
|
| 33 |
+
sql: Optional[str]
|
| 34 |
+
rationale: Optional[str]
|
| 35 |
+
verified: Optional[bool]
|
| 36 |
+
details: Optional[List[str]]
|
| 37 |
+
questions: Optional[List[str]]
|
| 38 |
+
traces: List[Dict[str, Any]]
|
tests/test_nl2sql_router.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
from fastapi.testclient import TestClient
|
| 2 |
from app.main import app
|
| 3 |
-
from nl2sql.
|
| 4 |
|
| 5 |
client = TestClient(app)
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
return
|
| 10 |
-
|
| 11 |
|
| 12 |
path = app.url_path_for("nl2sql_handler")
|
| 13 |
|
|
@@ -16,15 +15,18 @@ path = app.url_path_for("nl2sql_handler")
|
|
| 16 |
def test_ambiguity_route(monkeypatch):
|
| 17 |
from app.routers import nl2sql
|
| 18 |
|
| 19 |
-
# mock pipeline to return
|
| 20 |
def fake_run(*args, **kwargs):
|
| 21 |
-
return
|
| 22 |
ok=True,
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
|
| 30 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
@@ -36,11 +38,11 @@ def test_ambiguity_route(monkeypatch):
|
|
| 36 |
"schema_preview": "CREATE TABLE ...",
|
| 37 |
},
|
| 38 |
)
|
| 39 |
-
|
| 40 |
assert resp.status_code == 200
|
| 41 |
data = resp.json()
|
| 42 |
assert data["ambiguous"] is True
|
| 43 |
assert "questions" in data
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
# --- 2) Error / failure case -------------------------------------------------
|
|
@@ -48,8 +50,16 @@ def test_error_route(monkeypatch):
|
|
| 48 |
from app.routers import nl2sql
|
| 49 |
|
| 50 |
def fake_run(*args, **kwargs):
|
| 51 |
-
return
|
| 52 |
-
ok=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
@@ -61,7 +71,6 @@ def test_error_route(monkeypatch):
|
|
| 61 |
"schema_preview": "CREATE TABLE users(id int);",
|
| 62 |
},
|
| 63 |
)
|
| 64 |
-
|
| 65 |
assert resp.status_code == 400
|
| 66 |
assert "Bad SQL" in resp.json()["detail"]
|
| 67 |
|
|
@@ -71,14 +80,16 @@ def test_success_route(monkeypatch):
|
|
| 71 |
from app.routers import nl2sql
|
| 72 |
|
| 73 |
def fake_run(*args, **kwargs):
|
| 74 |
-
return
|
| 75 |
ok=True,
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
)
|
| 83 |
|
| 84 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
@@ -96,3 +107,4 @@ def test_success_route(monkeypatch):
|
|
| 96 |
assert data["sql"].lower().startswith("select")
|
| 97 |
assert isinstance(data["traces"], list)
|
| 98 |
assert any(t["stage"] == "planner" for t in data["traces"])
|
|
|
|
|
|
| 1 |
from fastapi.testclient import TestClient
|
| 2 |
from app.main import app
|
| 3 |
+
from nl2sql.pipeline import FinalResult
|
| 4 |
|
| 5 |
client = TestClient(app)
|
| 6 |
|
| 7 |
+
def fake_trace(stage: str) -> dict:
|
| 8 |
+
# FinalResult.traces is a list of dicts (StageTrace.__dict__)
|
| 9 |
+
return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
|
|
|
|
| 10 |
|
| 11 |
path = app.url_path_for("nl2sql_handler")
|
| 12 |
|
|
|
|
| 15 |
def test_ambiguity_route(monkeypatch):
|
| 16 |
from app.routers import nl2sql
|
| 17 |
|
| 18 |
+
# mock pipeline to return FinalResult with ambiguous=True
|
| 19 |
def fake_run(*args, **kwargs):
|
| 20 |
+
return FinalResult(
|
| 21 |
ok=True,
|
| 22 |
+
ambiguous=True,
|
| 23 |
+
error=False,
|
| 24 |
+
details=["Ambiguities found: 1"],
|
| 25 |
+
questions=["Which table do you mean?"],
|
| 26 |
+
sql=None,
|
| 27 |
+
rationale=None,
|
| 28 |
+
verified=None,
|
| 29 |
+
traces=[fake_trace("detector")],
|
| 30 |
)
|
| 31 |
|
| 32 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
|
|
| 38 |
"schema_preview": "CREATE TABLE ...",
|
| 39 |
},
|
| 40 |
)
|
|
|
|
| 41 |
assert resp.status_code == 200
|
| 42 |
data = resp.json()
|
| 43 |
assert data["ambiguous"] is True
|
| 44 |
assert "questions" in data
|
| 45 |
+
assert isinstance(data["questions"], list)
|
| 46 |
|
| 47 |
|
| 48 |
# --- 2) Error / failure case -------------------------------------------------
|
|
|
|
| 50 |
from app.routers import nl2sql
|
| 51 |
|
| 52 |
def fake_run(*args, **kwargs):
|
| 53 |
+
return FinalResult(
|
| 54 |
+
ok=False,
|
| 55 |
+
ambiguous=False,
|
| 56 |
+
error=True,
|
| 57 |
+
details=["Bad SQL"],
|
| 58 |
+
questions=None,
|
| 59 |
+
sql=None,
|
| 60 |
+
rationale=None,
|
| 61 |
+
verified=None,
|
| 62 |
+
traces=[fake_trace("safety")],
|
| 63 |
)
|
| 64 |
|
| 65 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
|
|
| 71 |
"schema_preview": "CREATE TABLE users(id int);",
|
| 72 |
},
|
| 73 |
)
|
|
|
|
| 74 |
assert resp.status_code == 400
|
| 75 |
assert "Bad SQL" in resp.json()["detail"]
|
| 76 |
|
|
|
|
| 80 |
from app.routers import nl2sql
|
| 81 |
|
| 82 |
def fake_run(*args, **kwargs):
|
| 83 |
+
return FinalResult(
|
| 84 |
ok=True,
|
| 85 |
+
ambiguous=False,
|
| 86 |
+
error=False,
|
| 87 |
+
details=None,
|
| 88 |
+
questions=None,
|
| 89 |
+
sql="SELECT * FROM users;",
|
| 90 |
+
rationale="Simple listing",
|
| 91 |
+
verified=True,
|
| 92 |
+
traces=[fake_trace("planner"), fake_trace("generator")],
|
| 93 |
)
|
| 94 |
|
| 95 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
|
|
|
| 107 |
assert data["sql"].lower().startswith("select")
|
| 108 |
assert isinstance(data["traces"], list)
|
| 109 |
assert any(t["stage"] == "planner" for t in data["traces"])
|
| 110 |
+
assert any(t["stage"] == "generator" for t in data["traces"])
|
tests/test_pipeline_integration.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
-
from nl2sql.pipeline import Pipeline
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
|
| 4 |
|
| 5 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 6 |
|
| 7 |
-
|
| 8 |
class DummyDetector:
|
| 9 |
"""Simulates ambiguity detector stage."""
|
| 10 |
|
| 11 |
-
def __init__(self, ambiguous=False):
|
| 12 |
self.ambiguous = ambiguous
|
| 13 |
|
| 14 |
def detect(self, user_query, schema_preview):
|
|
@@ -43,10 +42,12 @@ class DummyGenerator:
|
|
| 43 |
class DummySafety:
|
| 44 |
"""Simulates safety stage."""
|
| 45 |
|
| 46 |
-
|
|
|
|
| 47 |
trace = StageTrace(stage="safety", duration_ms=1.0)
|
| 48 |
if "DROP" in sql.upper():
|
| 49 |
return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
|
|
|
|
| 50 |
return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
|
| 51 |
|
| 52 |
|
|
@@ -64,13 +65,13 @@ def test_pipeline_success():
|
|
| 64 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 65 |
)
|
| 66 |
|
| 67 |
-
assert isinstance(r,
|
| 68 |
assert r.ok is True
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
assert any(t.stage == "planner" for t in
|
| 72 |
-
assert any(t.stage == "generator" for t in
|
| 73 |
-
assert any(t.stage == "safety" for t in
|
| 74 |
|
| 75 |
|
| 76 |
# --- 2) Ambiguity case --------------------------------------------------------
|
|
@@ -84,10 +85,10 @@ def test_pipeline_ambiguity():
|
|
| 84 |
|
| 85 |
r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
|
| 86 |
|
| 87 |
-
assert isinstance(r,
|
| 88 |
assert r.ok is True
|
| 89 |
-
assert r.
|
| 90 |
-
assert isinstance(r.
|
| 91 |
|
| 92 |
|
| 93 |
# --- 3) Planner failure -------------------------------------------------------
|
|
@@ -101,9 +102,10 @@ def test_pipeline_plan_fail():
|
|
| 101 |
r = pipeline.run(
|
| 102 |
user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
|
| 103 |
)
|
| 104 |
-
assert isinstance(r,
|
| 105 |
assert r.ok is False
|
| 106 |
-
assert
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
# --- 4) Generator failure -----------------------------------------------------
|
|
@@ -117,8 +119,10 @@ def test_pipeline_gen_fail():
|
|
| 117 |
r = pipeline.run(
|
| 118 |
user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
|
| 119 |
)
|
|
|
|
| 120 |
assert r.ok is False
|
| 121 |
-
assert
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
# --- 5) Safety failure --------------------------------------------------------
|
|
@@ -140,5 +144,7 @@ def test_pipeline_safety_fail():
|
|
| 140 |
r = pipeline.run(
|
| 141 |
user_query="drop something", schema_preview="CREATE TABLE x(id int);"
|
| 142 |
)
|
|
|
|
| 143 |
assert r.ok is False
|
| 144 |
-
assert
|
|
|
|
|
|
| 1 |
+
from nl2sql.pipeline import Pipeline, FinalResult
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
|
| 4 |
|
| 5 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 6 |
|
|
|
|
| 7 |
class DummyDetector:
|
| 8 |
"""Simulates ambiguity detector stage."""
|
| 9 |
|
| 10 |
+
def __init__(self, ambiguous: bool = False):
|
| 11 |
self.ambiguous = ambiguous
|
| 12 |
|
| 13 |
def detect(self, user_query, schema_preview):
|
|
|
|
| 42 |
class DummySafety:
|
| 43 |
"""Simulates safety stage."""
|
| 44 |
|
| 45 |
+
# NOTE: pipeline now calls safety.run(sql=...)
|
| 46 |
+
def run(self, *, sql):
|
| 47 |
trace = StageTrace(stage="safety", duration_ms=1.0)
|
| 48 |
if "DROP" in sql.upper():
|
| 49 |
return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
|
| 50 |
+
# echo back sql in data to feed executor
|
| 51 |
return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
|
| 52 |
|
| 53 |
|
|
|
|
| 65 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 66 |
)
|
| 67 |
|
| 68 |
+
assert isinstance(r, FinalResult)
|
| 69 |
assert r.ok is True
|
| 70 |
+
assert r.sql is not None and r.sql.lower().startswith("select")
|
| 71 |
+
# traces is a list of dicts (StageTrace.__dict__)
|
| 72 |
+
assert any(t.get("stage") == "planner" for t in r.traces)
|
| 73 |
+
assert any(t.get("stage") == "generator" for t in r.traces)
|
| 74 |
+
assert any(t.get("stage") == "safety" for t in r.traces)
|
| 75 |
|
| 76 |
|
| 77 |
# --- 2) Ambiguity case --------------------------------------------------------
|
|
|
|
| 85 |
|
| 86 |
r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
|
| 87 |
|
| 88 |
+
assert isinstance(r, FinalResult)
|
| 89 |
assert r.ok is True
|
| 90 |
+
assert r.ambiguous is True
|
| 91 |
+
assert isinstance(r.questions, list) and len(r.questions) > 0
|
| 92 |
|
| 93 |
|
| 94 |
# --- 3) Planner failure -------------------------------------------------------
|
|
|
|
| 102 |
r = pipeline.run(
|
| 103 |
user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
|
| 104 |
)
|
| 105 |
+
assert isinstance(r, FinalResult)
|
| 106 |
assert r.ok is False
|
| 107 |
+
assert r.details is not None
|
| 108 |
+
assert "Planner failed" in " ".join(r.details)
|
| 109 |
|
| 110 |
|
| 111 |
# --- 4) Generator failure -----------------------------------------------------
|
|
|
|
| 119 |
r = pipeline.run(
|
| 120 |
user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
|
| 121 |
)
|
| 122 |
+
assert isinstance(r, FinalResult)
|
| 123 |
assert r.ok is False
|
| 124 |
+
assert r.details is not None
|
| 125 |
+
assert "Generator failed" in " ".join(r.details)
|
| 126 |
|
| 127 |
|
| 128 |
# --- 5) Safety failure --------------------------------------------------------
|
|
|
|
| 144 |
r = pipeline.run(
|
| 145 |
user_query="drop something", schema_preview="CREATE TABLE x(id int);"
|
| 146 |
)
|
| 147 |
+
assert isinstance(r, FinalResult)
|
| 148 |
assert r.ok is False
|
| 149 |
+
assert r.details is not None
|
| 150 |
+
assert "unsafe" in " ".join(r.details).lower()
|