Melika Kheirieh commited on
Commit
7ece28d
·
1 Parent(s): b21cd69

feat(core): always emit per-stage traces in Pipeline.run (fallback when StageResult.trace is empty)

Browse files
Files changed (1) hide show
  1. nl2sql/pipeline.py +47 -29
nl2sql/pipeline.py CHANGED
@@ -152,11 +152,20 @@ class Pipeline:
152
  traces: List[dict] = []
153
  details: List[str] = []
154
 
 
 
 
 
 
 
 
 
 
 
155
  # Normalize inputs
156
  schema_preview = schema_preview or ""
157
  clarify_answers = clarify_answers or {}
158
 
159
- # --- 1) ambiguity detection (with explicit timing & trace) ---
160
  try:
161
  # --- 1) detector ---
162
  t_det0 = time.perf_counter()
@@ -192,10 +201,11 @@ class Pipeline:
192
  r_plan = self._safe_stage(
193
  self.planner.run, user_query=user_query, schema_preview=schema_preview
194
  )
195
- stage_duration_ms.labels("planner").observe(
196
- (time.perf_counter() - t_pln0) * 1000.0
197
- )
198
  traces.extend(self._trace_list(r_plan))
 
 
199
  if not r_plan.ok:
200
  pipeline_runs_total.labels(status="error").inc()
201
  return FinalResult(
@@ -219,10 +229,11 @@ class Pipeline:
219
  plan_text=(r_plan.data or {}).get("plan"),
220
  clarify_answers=clarify_answers,
221
  )
222
- stage_duration_ms.labels("generator").observe(
223
- (time.perf_counter() - t_gen0) * 1000.0
224
- )
225
  traces.extend(self._trace_list(r_gen))
 
 
226
  if not r_gen.ok:
227
  pipeline_runs_total.labels(status="error").inc()
228
  return FinalResult(
@@ -243,10 +254,11 @@ class Pipeline:
243
  # --- 4) safety ---
244
  t_saf0 = time.perf_counter()
245
  r_safe = self._safe_stage(self.safety.run, sql=sql)
246
- stage_duration_ms.labels("safety").observe(
247
- (time.perf_counter() - t_saf0) * 1000.0
248
- )
249
  traces.extend(self._trace_list(r_safe))
 
 
250
  if not r_safe.ok:
251
  pipeline_runs_total.labels(status="error").inc()
252
  return FinalResult(
@@ -266,10 +278,11 @@ class Pipeline:
266
  r_exec = self._safe_stage(
267
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
268
  )
269
- stage_duration_ms.labels("executor").observe(
270
- (time.perf_counter() - t_exe0) * 1000.0
271
- )
272
  traces.extend(self._trace_list(r_exec))
 
 
273
  if not r_exec.ok and r_exec.error:
274
  # executor failure is soft; collect for repair/verifier context
275
  details.extend(r_exec.error)
@@ -279,10 +292,11 @@ class Pipeline:
279
  r_ver = self._safe_stage(
280
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
281
  )
282
- stage_duration_ms.labels("verifier").observe(
283
- (time.perf_counter() - t_ver0) * 1000.0
284
- )
285
  traces.extend(self._trace_list(r_ver))
 
 
286
  verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
287
 
288
  # --- 7) repair loop if verification failed ---
@@ -296,10 +310,11 @@ class Pipeline:
296
  error_msg="; ".join(details or ["unknown"]),
297
  schema_preview=schema_preview,
298
  )
299
- stage_duration_ms.labels("repair").observe(
300
- (time.perf_counter() - t_fix0) * 1000.0
301
- )
302
  traces.extend(self._trace_list(r_fix))
 
 
303
  if not r_fix.ok:
304
  break # give up on repair
305
 
@@ -309,10 +324,11 @@ class Pipeline:
309
  # safety
310
  t_saf0 = time.perf_counter()
311
  r_safe = self._safe_stage(self.safety.run, sql=sql)
312
- stage_duration_ms.labels("safety").observe(
313
- (time.perf_counter() - t_saf0) * 1000.0
314
- )
315
  traces.extend(self._trace_list(r_safe))
 
 
316
  if not r_safe.ok:
317
  if r_safe.error:
318
  details.extend(r_safe.error)
@@ -323,10 +339,11 @@ class Pipeline:
323
  r_exec = self._safe_stage(
324
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
325
  )
326
- stage_duration_ms.labels("executor").observe(
327
- (time.perf_counter() - t_exe0) * 1000.0
328
- )
329
  traces.extend(self._trace_list(r_exec))
 
 
330
  if not r_exec.ok:
331
  if r_exec.error:
332
  details.extend(r_exec.error)
@@ -337,10 +354,11 @@ class Pipeline:
337
  r_ver = self._safe_stage(
338
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
339
  )
340
- stage_duration_ms.labels("verifier").observe(
341
- (time.perf_counter() - t_ver0) * 1000.0
342
- )
343
  traces.extend(self._trace_list(r_ver))
 
 
344
  verified = (
345
  bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
346
  )
@@ -400,7 +418,7 @@ class Pipeline:
400
  )
401
 
402
  except Exception:
403
- # detector block already handled its own errors above; this is for any other crash
404
  pipeline_runs_total.labels(status="error").inc()
405
  raise
406
 
 
152
  traces: List[dict] = []
153
  details: List[str] = []
154
 
155
+ # Always push a normalized per-stage timing, even if StageResult.trace is empty
156
+ def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
157
+ traces.append(
158
+ self._mk_trace(
159
+ stage=stage_name,
160
+ duration_ms=dt_ms,
161
+ summary=("ok" if ok else "failed"),
162
+ )
163
+ )
164
+
165
  # Normalize inputs
166
  schema_preview = schema_preview or ""
167
  clarify_answers = clarify_answers or {}
168
 
 
169
  try:
170
  # --- 1) detector ---
171
  t_det0 = time.perf_counter()
 
201
  r_plan = self._safe_stage(
202
  self.planner.run, user_query=user_query, schema_preview=schema_preview
203
  )
204
+ pln_ms = (time.perf_counter() - t_pln0) * 1000.0
205
+ stage_duration_ms.labels("planner").observe(pln_ms)
 
206
  traces.extend(self._trace_list(r_plan))
207
+ if not getattr(r_plan, "trace", None):
208
+ _fallback_trace("planner", pln_ms, r_plan.ok)
209
  if not r_plan.ok:
210
  pipeline_runs_total.labels(status="error").inc()
211
  return FinalResult(
 
229
  plan_text=(r_plan.data or {}).get("plan"),
230
  clarify_answers=clarify_answers,
231
  )
232
+ gen_ms = (time.perf_counter() - t_gen0) * 1000.0
233
+ stage_duration_ms.labels("generator").observe(gen_ms)
 
234
  traces.extend(self._trace_list(r_gen))
235
+ if not getattr(r_gen, "trace", None):
236
+ _fallback_trace("generator", gen_ms, r_gen.ok)
237
  if not r_gen.ok:
238
  pipeline_runs_total.labels(status="error").inc()
239
  return FinalResult(
 
254
  # --- 4) safety ---
255
  t_saf0 = time.perf_counter()
256
  r_safe = self._safe_stage(self.safety.run, sql=sql)
257
+ saf_ms = (time.perf_counter() - t_saf0) * 1000.0
258
+ stage_duration_ms.labels("safety").observe(saf_ms)
 
259
  traces.extend(self._trace_list(r_safe))
260
+ if not getattr(r_safe, "trace", None):
261
+ _fallback_trace("safety", saf_ms, r_safe.ok)
262
  if not r_safe.ok:
263
  pipeline_runs_total.labels(status="error").inc()
264
  return FinalResult(
 
278
  r_exec = self._safe_stage(
279
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
280
  )
281
+ exe_ms = (time.perf_counter() - t_exe0) * 1000.0
282
+ stage_duration_ms.labels("executor").observe(exe_ms)
 
283
  traces.extend(self._trace_list(r_exec))
284
+ if not getattr(r_exec, "trace", None):
285
+ _fallback_trace("executor", exe_ms, r_exec.ok)
286
  if not r_exec.ok and r_exec.error:
287
  # executor failure is soft; collect for repair/verifier context
288
  details.extend(r_exec.error)
 
292
  r_ver = self._safe_stage(
293
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
294
  )
295
+ ver_ms = (time.perf_counter() - t_ver0) * 1000.0
296
+ stage_duration_ms.labels("verifier").observe(ver_ms)
 
297
  traces.extend(self._trace_list(r_ver))
298
+ if not getattr(r_ver, "trace", None):
299
+ _fallback_trace("verifier", ver_ms, r_ver.ok)
300
  verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
301
 
302
  # --- 7) repair loop if verification failed ---
 
310
  error_msg="; ".join(details or ["unknown"]),
311
  schema_preview=schema_preview,
312
  )
313
+ fix_ms = (time.perf_counter() - t_fix0) * 1000.0
314
+ stage_duration_ms.labels("repair").observe(fix_ms)
 
315
  traces.extend(self._trace_list(r_fix))
316
+ if not getattr(r_fix, "trace", None):
317
+ _fallback_trace("repair", fix_ms, r_fix.ok)
318
  if not r_fix.ok:
319
  break # give up on repair
320
 
 
324
  # safety
325
  t_saf0 = time.perf_counter()
326
  r_safe = self._safe_stage(self.safety.run, sql=sql)
327
+ saf_ms2 = (time.perf_counter() - t_saf0) * 1000.0
328
+ stage_duration_ms.labels("safety").observe(saf_ms2)
 
329
  traces.extend(self._trace_list(r_safe))
330
+ if not getattr(r_safe, "trace", None):
331
+ _fallback_trace("safety", saf_ms2, r_safe.ok)
332
  if not r_safe.ok:
333
  if r_safe.error:
334
  details.extend(r_safe.error)
 
339
  r_exec = self._safe_stage(
340
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
341
  )
342
+ exe_ms2 = (time.perf_counter() - t_exe0) * 1000.0
343
+ stage_duration_ms.labels("executor").observe(exe_ms2)
 
344
  traces.extend(self._trace_list(r_exec))
345
+ if not getattr(r_exec, "trace", None):
346
+ _fallback_trace("executor", exe_ms2, r_exec.ok)
347
  if not r_exec.ok:
348
  if r_exec.error:
349
  details.extend(r_exec.error)
 
354
  r_ver = self._safe_stage(
355
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
356
  )
357
+ ver_ms2 = (time.perf_counter() - t_ver0) * 1000.0
358
+ stage_duration_ms.labels("verifier").observe(ver_ms2)
 
359
  traces.extend(self._trace_list(r_ver))
360
+ if not getattr(r_ver, "trace", None):
361
+ _fallback_trace("verifier", ver_ms2, r_ver.ok)
362
  verified = (
363
  bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
364
  )
 
418
  )
419
 
420
  except Exception:
421
+ # Any unexpected crash
422
  pipeline_runs_total.labels(status="error").inc()
423
  raise
424