Melika Kheirieh commited on
Commit
9b94364
·
1 Parent(s): d5f745f

fix(api): map FinalResult → HTTP (200/400) and stabilize nl2sql handler; prevent 500s in tests

Browse files
Files changed (2) hide show
  1. app/routers/nl2sql.py +6 -15
  2. requirements.txt +1 -0
app/routers/nl2sql.py CHANGED
@@ -225,9 +225,6 @@ async def upload_db(file: UploadFile = File(...)):
225
  return {"db_id": db_id}
226
 
227
 
228
- # -------------------------------
229
- # Main NL2SQL endpoint
230
- # -------------------------------
231
  # -------------------------------
232
  # Main NL2SQL endpoint
233
  # -------------------------------
@@ -246,16 +243,16 @@ def nl2sql_handler(request: NL2SQLRequest):
246
  _derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
247
  )
248
 
249
- # Resolve schema_preview (send None when empty)
250
  provided_preview_any: Any = getattr(request, "schema_preview", None)
251
  provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
252
  final_preview: Optional[str] = provided_preview or (derived_preview_val or None)
253
 
254
- # Run pipeline
255
  try:
256
  result = pipeline.run(
257
  user_query=request.query,
258
- schema_preview=final_preview,
259
  )
260
  except Exception as exc:
261
  raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
@@ -263,17 +260,14 @@ def nl2sql_handler(request: NL2SQLRequest):
263
  if not isinstance(result, FinalResult):
264
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
265
 
266
- # Ambiguous → 200 with clarify payload
267
  if result.ambiguous and (result.questions is not None):
268
  return ClarifyResponse(
269
- ok=True, # minimal addition for schema compatibility
270
  ambiguous=True,
271
  questions=result.questions,
272
- traces=result.traces or [], # safe default to avoid validation errors
273
- details=result.details or None,
274
  )
275
 
276
- # Error → 400
277
  if (not result.ok) or result.error:
278
  print("❌ Pipeline failure dump:")
279
  print(" ok:", result.ok)
@@ -285,16 +279,13 @@ def nl2sql_handler(request: NL2SQLRequest):
285
  detail="; ".join(result.details or []) or (result.error or "Unknown error"),
286
  )
287
 
288
- # Success → 200
289
  traces = [_round_trace(t) for t in (result.traces or [])]
290
  return NL2SQLResponse(
291
- ok=True, # minimal addition
292
  ambiguous=False,
293
  sql=result.sql,
294
  rationale=result.rationale,
295
- verified=getattr(result, "verified", None),
296
  traces=traces,
297
- details=result.details or None,
298
  )
299
 
300
 
 
225
  return {"db_id": db_id}
226
 
227
 
 
 
 
228
  # -------------------------------
229
  # Main NL2SQL endpoint
230
  # -------------------------------
 
243
  _derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
244
  )
245
 
246
+ # Resolve schema_preview
247
  provided_preview_any: Any = getattr(request, "schema_preview", None)
248
  provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
249
  final_preview: Optional[str] = provided_preview or (derived_preview_val or None)
250
 
251
+ # Run pipeline (ensure schema_preview is str for typing)
252
  try:
253
  result = pipeline.run(
254
  user_query=request.query,
255
+ schema_preview=(final_preview or ""), # pipeline expects str
256
  )
257
  except Exception as exc:
258
  raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
 
260
  if not isinstance(result, FinalResult):
261
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
262
 
263
+ # Ambiguous → 200 with ClarifyResponse schema
264
  if result.ambiguous and (result.questions is not None):
265
  return ClarifyResponse(
 
266
  ambiguous=True,
267
  questions=result.questions,
 
 
268
  )
269
 
270
+ # Error → 400, with debug print
271
  if (not result.ok) or result.error:
272
  print("❌ Pipeline failure dump:")
273
  print(" ok:", result.ok)
 
279
  detail="; ".join(result.details or []) or (result.error or "Unknown error"),
280
  )
281
 
282
+ # Success → 200 with NL2SQLResponse schema
283
  traces = [_round_trace(t) for t in (result.traces or [])]
284
  return NL2SQLResponse(
 
285
  ambiguous=False,
286
  sql=result.sql,
287
  rationale=result.rationale,
 
288
  traces=traces,
 
289
  )
290
 
291
 
requirements.txt CHANGED
@@ -10,6 +10,7 @@ python-dotenv==1.1.1
10
  openai==2.6.1
11
  psycopg[binary]~=3.2
12
  prometheus-client>=0.20.0
 
13
  ruff
14
  gradio
15
  sqlalchemy
 
10
  openai==2.6.1
11
  psycopg[binary]~=3.2
12
  prometheus-client>=0.20.0
13
+ types-requests>=2.32.0.20241016
14
  ruff
15
  gradio
16
  sqlalchemy