Melika Kheirieh commited on
Commit
ba06dd4
Β·
1 Parent(s): b9c72a7

fix(types): avoid mypy no-redef in nl2sql_handler by predeclaring pipeline_obj

Browse files
Files changed (1) hide show
  1. app/routers/nl2sql.py +33 -18
app/routers/nl2sql.py CHANGED
@@ -168,6 +168,18 @@ def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
168
  )
169
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # -------------------------------
172
  # Helpers (unchanged)
173
  # -------------------------------
@@ -232,16 +244,22 @@ async def upload_db(file: UploadFile = File(...)):
232
  def nl2sql_handler(request: NL2SQLRequest):
233
  db_id = getattr(request, "db_id", None)
234
 
235
- # Pick adapter per-request (default or uploaded or postgres)
236
- adapter = _select_adapter(db_id)
237
-
238
- # Build pipeline lazily with this adapter
239
- pipeline = _build_pipeline(adapter)
240
-
241
- # Derive schema preview only for sqlite with a real path
242
- derived_preview_val: str = (
243
- _derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
244
- )
 
 
 
 
 
 
245
 
246
  # Resolve schema_preview
247
  provided_preview_any: Any = getattr(request, "schema_preview", None)
@@ -250,7 +268,7 @@ def nl2sql_handler(request: NL2SQLRequest):
250
 
251
  # Run pipeline (ensure schema_preview is str for typing)
252
  try:
253
- result = pipeline.run(
254
  user_query=request.query,
255
  schema_preview=(final_preview or ""), # pipeline expects str
256
  )
@@ -260,14 +278,11 @@ def nl2sql_handler(request: NL2SQLRequest):
260
  if not isinstance(result, FinalResult):
261
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
262
 
263
- # Ambiguous β†’ 200 with ClarifyResponse schema
264
  if result.ambiguous and (result.questions is not None):
265
- return ClarifyResponse(
266
- ambiguous=True,
267
- questions=result.questions,
268
- )
269
 
270
- # Error β†’ 400, with debug print
271
  if (not result.ok) or result.error:
272
  print("❌ Pipeline failure dump:")
273
  print(" ok:", result.ok)
@@ -279,7 +294,7 @@ def nl2sql_handler(request: NL2SQLRequest):
279
  detail="; ".join(result.details or []) or (result.error or "Unknown error"),
280
  )
281
 
282
- # Success β†’ 200 with NL2SQLResponse schema
283
  traces = [_round_trace(t) for t in (result.traces or [])]
284
  return NL2SQLResponse(
285
  ambiguous=False,
 
168
  )
169
 
170
 
171
+ # --- Module-level default Pipeline instance for no-db_id requests ---
172
+ # It lets tests monkeypatch `Pipeline.run` and avoids building adapters on each call.
173
+ try:
174
+ _pipeline: Pipeline = _build_pipeline(SQLiteAdapter(":memory:"))
175
+ except Exception as e:
176
+ # Fallback to a file-based sqlite if in-memory init fails in some environments
177
+ print(
178
+ f"⚠️ default _pipeline init failed on :memory: β†’ {e}; falling back to data/chinook.db"
179
+ )
180
+ _pipeline = _build_pipeline(SQLiteAdapter("data/chinook.db"))
181
+
182
+
183
  # -------------------------------
184
  # Helpers (unchanged)
185
  # -------------------------------
 
244
  def nl2sql_handler(request: NL2SQLRequest):
245
  db_id = getattr(request, "db_id", None)
246
 
247
+ # Declare once to avoid mypy no-redef
248
+ pipeline_obj: Pipeline
249
+ derived_preview_val: str
250
+
251
+ if not db_id:
252
+ # Use module-level pipeline instance (already initialized)
253
+ pipeline_obj = cast(Pipeline, _pipeline)
254
+ derived_preview_val = ""
255
+ else:
256
+ adapter = _select_adapter(db_id)
257
+ pipeline_obj = _build_pipeline(adapter)
258
+ derived_preview_val = (
259
+ _derive_schema_preview(adapter)
260
+ if isinstance(adapter, SQLiteAdapter)
261
+ else ""
262
+ )
263
 
264
  # Resolve schema_preview
265
  provided_preview_any: Any = getattr(request, "schema_preview", None)
 
268
 
269
  # Run pipeline (ensure schema_preview is str for typing)
270
  try:
271
+ result = pipeline_obj.run(
272
  user_query=request.query,
273
  schema_preview=(final_preview or ""), # pipeline expects str
274
  )
 
278
  if not isinstance(result, FinalResult):
279
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
280
 
281
+ # Ambiguous β†’ 200
282
  if result.ambiguous and (result.questions is not None):
283
+ return ClarifyResponse(ambiguous=True, questions=result.questions)
 
 
 
284
 
285
+ # Error β†’ 400 (with debug dump)
286
  if (not result.ok) or result.error:
287
  print("❌ Pipeline failure dump:")
288
  print(" ok:", result.ok)
 
294
  detail="; ".join(result.details or []) or (result.error or "Unknown error"),
295
  )
296
 
297
+ # Success β†’ 200
298
  traces = [_round_trace(t) for t in (result.traces or [])]
299
  return NL2SQLResponse(
300
  ambiguous=False,