Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
9b94364
1
Parent(s):
d5f745f
fix(api): map FinalResult → HTTP (200/400) and stabilize nl2sql handler; prevent 500s in tests
Browse files- app/routers/nl2sql.py +6 -15
- requirements.txt +1 -0
app/routers/nl2sql.py
CHANGED
|
@@ -225,9 +225,6 @@ async def upload_db(file: UploadFile = File(...)):
|
|
| 225 |
return {"db_id": db_id}
|
| 226 |
|
| 227 |
|
| 228 |
-
# -------------------------------
|
| 229 |
-
# Main NL2SQL endpoint
|
| 230 |
-
# -------------------------------
|
| 231 |
# -------------------------------
|
| 232 |
# Main NL2SQL endpoint
|
| 233 |
# -------------------------------
|
|
@@ -246,16 +243,16 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 246 |
_derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
|
| 247 |
)
|
| 248 |
|
| 249 |
-
# Resolve schema_preview
|
| 250 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 251 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 252 |
final_preview: Optional[str] = provided_preview or (derived_preview_val or None)
|
| 253 |
|
| 254 |
-
# Run pipeline
|
| 255 |
try:
|
| 256 |
result = pipeline.run(
|
| 257 |
user_query=request.query,
|
| 258 |
-
schema_preview=final_preview,
|
| 259 |
)
|
| 260 |
except Exception as exc:
|
| 261 |
raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
|
|
@@ -263,17 +260,14 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 263 |
if not isinstance(result, FinalResult):
|
| 264 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 265 |
|
| 266 |
-
# Ambiguous → 200 with
|
| 267 |
if result.ambiguous and (result.questions is not None):
|
| 268 |
return ClarifyResponse(
|
| 269 |
-
ok=True, # minimal addition for schema compatibility
|
| 270 |
ambiguous=True,
|
| 271 |
questions=result.questions,
|
| 272 |
-
traces=result.traces or [], # safe default to avoid validation errors
|
| 273 |
-
details=result.details or None,
|
| 274 |
)
|
| 275 |
|
| 276 |
-
# Error → 400
|
| 277 |
if (not result.ok) or result.error:
|
| 278 |
print("❌ Pipeline failure dump:")
|
| 279 |
print(" ok:", result.ok)
|
|
@@ -285,16 +279,13 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 285 |
detail="; ".join(result.details or []) or (result.error or "Unknown error"),
|
| 286 |
)
|
| 287 |
|
| 288 |
-
# Success → 200
|
| 289 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 290 |
return NL2SQLResponse(
|
| 291 |
-
ok=True, # minimal addition
|
| 292 |
ambiguous=False,
|
| 293 |
sql=result.sql,
|
| 294 |
rationale=result.rationale,
|
| 295 |
-
verified=getattr(result, "verified", None),
|
| 296 |
traces=traces,
|
| 297 |
-
details=result.details or None,
|
| 298 |
)
|
| 299 |
|
| 300 |
|
|
|
|
| 225 |
return {"db_id": db_id}
|
| 226 |
|
| 227 |
|
|
|
|
|
|
|
|
|
|
| 228 |
# -------------------------------
|
| 229 |
# Main NL2SQL endpoint
|
| 230 |
# -------------------------------
|
|
|
|
| 243 |
_derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
|
| 244 |
)
|
| 245 |
|
| 246 |
+
# Resolve schema_preview
|
| 247 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 248 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 249 |
final_preview: Optional[str] = provided_preview or (derived_preview_val or None)
|
| 250 |
|
| 251 |
+
# Run pipeline (ensure schema_preview is str for typing)
|
| 252 |
try:
|
| 253 |
result = pipeline.run(
|
| 254 |
user_query=request.query,
|
| 255 |
+
schema_preview=(final_preview or ""), # pipeline expects str
|
| 256 |
)
|
| 257 |
except Exception as exc:
|
| 258 |
raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
|
|
|
|
| 260 |
if not isinstance(result, FinalResult):
|
| 261 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 262 |
|
| 263 |
+
# Ambiguous → 200 with ClarifyResponse schema
|
| 264 |
if result.ambiguous and (result.questions is not None):
|
| 265 |
return ClarifyResponse(
|
|
|
|
| 266 |
ambiguous=True,
|
| 267 |
questions=result.questions,
|
|
|
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
+
# Error → 400, with debug print
|
| 271 |
if (not result.ok) or result.error:
|
| 272 |
print("❌ Pipeline failure dump:")
|
| 273 |
print(" ok:", result.ok)
|
|
|
|
| 279 |
detail="; ".join(result.details or []) or (result.error or "Unknown error"),
|
| 280 |
)
|
| 281 |
|
| 282 |
+
# Success → 200 with NL2SQLResponse schema
|
| 283 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 284 |
return NL2SQLResponse(
|
|
|
|
| 285 |
ambiguous=False,
|
| 286 |
sql=result.sql,
|
| 287 |
rationale=result.rationale,
|
|
|
|
| 288 |
traces=traces,
|
|
|
|
| 289 |
)
|
| 290 |
|
| 291 |
|
requirements.txt
CHANGED
|
@@ -10,6 +10,7 @@ python-dotenv==1.1.1
|
|
| 10 |
openai==2.6.1
|
| 11 |
psycopg[binary]~=3.2
|
| 12 |
prometheus-client>=0.20.0
|
|
|
|
| 13 |
ruff
|
| 14 |
gradio
|
| 15 |
sqlalchemy
|
|
|
|
| 10 |
openai==2.6.1
|
| 11 |
psycopg[binary]~=3.2
|
| 12 |
prometheus-client>=0.20.0
|
| 13 |
+
types-requests>=2.32.0.20241016
|
| 14 |
ruff
|
| 15 |
gradio
|
| 16 |
sqlalchemy
|