Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
4dae3e6
1
Parent(s):
a45c0eb
style: format code with ruff
Browse files- app/routers/nl2sql.py +1 -1
- nl2sql/pipeline.py +6 -2
- nl2sql/types.py +2 -1
- tests/test_nl2sql_router.py +2 -0
- tests/test_pipeline_integration.py +1 -0
app/routers/nl2sql.py
CHANGED
|
@@ -76,7 +76,7 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 76 |
raise HTTPException(status_code=400, detail=detail)
|
| 77 |
|
| 78 |
# --- Success case ---
|
| 79 |
-
traces = [
|
| 80 |
return NL2SQLResponse(
|
| 81 |
ambiguous=False,
|
| 82 |
sql=result.sql,
|
|
|
|
| 76 |
raise HTTPException(status_code=400, detail=detail)
|
| 77 |
|
| 78 |
# --- Success case ---
|
| 79 |
+
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 80 |
return NL2SQLResponse(
|
| 81 |
ambiguous=False,
|
| 82 |
sql=result.sql,
|
nl2sql/pipeline.py
CHANGED
|
@@ -211,13 +211,17 @@ class Pipeline:
|
|
| 211 |
details.extend(r_safe.error or [])
|
| 212 |
continue
|
| 213 |
|
| 214 |
-
r_exec = self._safe_stage(
|
|
|
|
|
|
|
| 215 |
traces.extend(self._trace_list(r_exec))
|
| 216 |
if not r_exec.ok:
|
| 217 |
details.extend(r_exec.error or [])
|
| 218 |
continue
|
| 219 |
|
| 220 |
-
r_ver = self._safe_stage(
|
|
|
|
|
|
|
| 221 |
traces.extend(self._trace_list(r_ver))
|
| 222 |
verified = bool(r_ver.ok)
|
| 223 |
if verified:
|
|
|
|
| 211 |
details.extend(r_safe.error or [])
|
| 212 |
continue
|
| 213 |
|
| 214 |
+
r_exec = self._safe_stage(
|
| 215 |
+
self.executor.run, sql=r_safe.data.get("sql", sql)
|
| 216 |
+
)
|
| 217 |
traces.extend(self._trace_list(r_exec))
|
| 218 |
if not r_exec.ok:
|
| 219 |
details.extend(r_exec.error or [])
|
| 220 |
continue
|
| 221 |
|
| 222 |
+
r_ver = self._safe_stage(
|
| 223 |
+
self.verifier.run, sql=sql, exec_result=r_exec.data
|
| 224 |
+
)
|
| 225 |
traces.extend(self._trace_list(r_ver))
|
| 226 |
verified = bool(r_ver.ok)
|
| 227 |
if verified:
|
nl2sql/types.py
CHANGED
|
@@ -27,7 +27,8 @@ class FinalResult:
|
|
| 27 |
Final domain result of the whole pipeline.
|
| 28 |
Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
|
| 29 |
"""
|
| 30 |
-
|
|
|
|
| 31 |
ambiguous: bool
|
| 32 |
error: bool
|
| 33 |
sql: Optional[str]
|
|
|
|
| 27 |
Final domain result of the whole pipeline.
|
| 28 |
Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
|
| 29 |
"""
|
| 30 |
+
|
| 31 |
+
ok: bool # end-to-end success
|
| 32 |
ambiguous: bool
|
| 33 |
error: bool
|
| 34 |
sql: Optional[str]
|
tests/test_nl2sql_router.py
CHANGED
|
@@ -4,10 +4,12 @@ from nl2sql.pipeline import FinalResult
|
|
| 4 |
|
| 5 |
client = TestClient(app)
|
| 6 |
|
|
|
|
| 7 |
def fake_trace(stage: str) -> dict:
|
| 8 |
# FinalResult.traces is a list of dicts (StageTrace.__dict__)
|
| 9 |
return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
|
| 10 |
|
|
|
|
| 11 |
path = app.url_path_for("nl2sql_handler")
|
| 12 |
|
| 13 |
|
|
|
|
| 4 |
|
| 5 |
client = TestClient(app)
|
| 6 |
|
| 7 |
+
|
| 8 |
def fake_trace(stage: str) -> dict:
|
| 9 |
# FinalResult.traces is a list of dicts (StageTrace.__dict__)
|
| 10 |
return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
|
| 11 |
|
| 12 |
+
|
| 13 |
path = app.url_path_for("nl2sql_handler")
|
| 14 |
|
| 15 |
|
tests/test_pipeline_integration.py
CHANGED
|
@@ -4,6 +4,7 @@ from nl2sql.types import StageResult, StageTrace
|
|
| 4 |
|
| 5 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 6 |
|
|
|
|
| 7 |
class DummyDetector:
|
| 8 |
"""Simulates ambiguity detector stage."""
|
| 9 |
|
|
|
|
| 4 |
|
| 5 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 6 |
|
| 7 |
+
|
| 8 |
class DummyDetector:
|
| 9 |
"""Simulates ambiguity detector stage."""
|
| 10 |
|