Melika Kheirieh committed on
Commit
d367a93
·
1 Parent(s): a578b12

feat(metrics): instrument per-stage and pipeline_total latency; count pipeline_runs_total (ok/error/ambiguous)

Browse files
Files changed (1) hide show
  1. nl2sql/pipeline.py +214 -173
nl2sql/pipeline.py CHANGED
@@ -13,6 +13,7 @@ from nl2sql.executor import Executor
13
  from nl2sql.verifier import Verifier
14
  from nl2sql.repair import Repair
15
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
 
16
 
17
 
18
  @dataclass(frozen=True)
@@ -147,11 +148,9 @@ class Pipeline:
147
  schema_preview: str | None = None,
148
  clarify_answers: Optional[Dict[str, Any]] = None,
149
  ) -> FinalResult:
 
150
  traces: List[dict] = []
151
  details: List[str] = []
152
- sql: Optional[str] = None
153
- rationale: Optional[str] = None
154
- verified: Optional[bool] = None
155
 
156
  # Normalize inputs
157
  schema_preview = schema_preview or ""
@@ -159,23 +158,23 @@ class Pipeline:
159
 
160
  # --- 1) ambiguity detection (with explicit timing & trace) ---
161
  try:
162
- t0 = time.perf_counter()
 
163
  questions = self.detector.detect(user_query, schema_preview)
164
- t1 = time.perf_counter()
165
  is_amb = bool(questions)
 
166
  traces.append(
167
  self._mk_trace(
168
  stage="detector",
169
- duration_ms=(t1 - t0) * 1000.0,
170
  summary=("ambiguous" if is_amb else "clear"),
171
- notes={
172
- "ambiguous": is_amb,
173
- "questions_len": len(questions or []),
174
- },
175
  )
176
  )
177
 
178
  if questions:
 
179
  return FinalResult(
180
  ok=True,
181
  ambiguous=True,
@@ -187,184 +186,226 @@ class Pipeline:
187
  verified=None,
188
  traces=self._normalize_traces(traces),
189
  )
190
- except Exception as e:
191
- # detector crash mark as error but keep trace so far
192
- traces.append(
193
- self._mk_trace(
194
- stage="detector",
195
- duration_ms=0.0,
196
- summary="failed",
197
- notes={"error": str(e)},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
 
 
 
 
 
 
 
 
 
199
  )
200
- return FinalResult(
201
- ok=False,
202
- ambiguous=True,
203
- error=True,
204
- details=[f"Detector failed: {e}"],
205
- questions=None,
206
- sql=None,
207
- rationale=None,
208
- verified=None,
209
- traces=self._normalize_traces(traces),
210
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- # --- 2) planner ---
213
- r_plan = self._safe_stage(
214
- self.planner.run, user_query=user_query, schema_preview=schema_preview
215
- )
216
- traces.extend(self._trace_list(r_plan))
217
- if not r_plan.ok:
218
- return FinalResult(
219
- ok=False,
220
- ambiguous=False,
221
- error=True,
222
- details=r_plan.error,
223
- questions=None,
224
- sql=None,
225
- rationale=None,
226
- verified=None,
227
- traces=self._normalize_traces(traces),
228
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- # --- 3) generator ---
231
- r_gen = self._safe_stage(
232
- self.generator.run,
233
- user_query=user_query,
234
- schema_preview=schema_preview,
235
- plan_text=(r_plan.data or {}).get("plan"),
236
- clarify_answers=clarify_answers,
237
- )
238
- traces.extend(self._trace_list(r_gen))
239
- if not r_gen.ok:
240
- return FinalResult(
241
- ok=False,
242
- ambiguous=False,
243
- error=True,
244
- details=r_gen.error,
245
- questions=None,
246
- sql=None,
247
- rationale=None,
248
- verified=None,
249
- traces=self._normalize_traces(traces),
250
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- sql = (r_gen.data or {}).get("sql")
253
- rationale = (r_gen.data or {}).get("rationale")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
- # --- 4) safety ---
256
- r_safe = self._safe_stage(self.safety.run, sql=sql)
257
- traces.extend(self._trace_list(r_safe))
258
- if not r_safe.ok:
259
  return FinalResult(
260
- ok=False,
261
  ambiguous=False,
262
- error=True,
263
- details=r_safe.error,
264
- questions=None,
265
  sql=sql,
266
  rationale=rationale,
267
- verified=None,
 
268
  traces=self._normalize_traces(traces),
269
  )
270
 
271
- # --- 5) executor ---
272
- r_exec = self._safe_stage(
273
- self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
274
- )
275
- traces.extend(self._trace_list(r_exec))
276
- if not r_exec.ok:
277
- # executor failure does not hard-fail the pipeline; accumulate details
278
- if r_exec.error:
279
- details.extend(r_exec.error)
280
 
281
- # --- 6) verifier ---
282
- r_ver = self._safe_stage(
283
- self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
284
- )
285
- traces.extend(self._trace_list(r_ver))
286
- verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
287
-
288
- # --- 7) repair loop if verification failed ---
289
- if not verified:
290
- for _attempt in range(2):
291
- r_fix = self._safe_stage(
292
- self.repair.run,
293
- sql=sql,
294
- error_msg="; ".join(details or ["unknown"]),
295
- schema_preview=schema_preview,
296
- )
297
- traces.extend(self._trace_list(r_fix))
298
- if not r_fix.ok:
299
- # repair failed – stop trying further
300
- break
301
-
302
- # re-run safety → executor → verifier on the fixed SQL
303
- sql = (r_fix.data or {}).get("sql", sql)
304
-
305
- r_safe = self._safe_stage(self.safety.run, sql=sql)
306
- traces.extend(self._trace_list(r_safe))
307
- if not r_safe.ok:
308
- if r_safe.error:
309
- details.extend(r_safe.error)
310
- continue
311
-
312
- r_exec = self._safe_stage(
313
- self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
314
- )
315
- traces.extend(self._trace_list(r_exec))
316
- if not r_exec.ok:
317
- if r_exec.error:
318
- details.extend(r_exec.error)
319
- continue
320
-
321
- r_ver = self._safe_stage(
322
- self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
323
- )
324
- traces.extend(self._trace_list(r_ver))
325
- verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
326
- if verified:
327
- break
328
-
329
- # --- 8) fallback: verifier silent but executor succeeded ---
330
- if (verified is None or not verified) and not details:
331
- any_exec_ok = any(
332
- t.get("stage") == "executor" and (t.get("notes") or {}).get("row_count")
333
- for t in traces
334
- )
335
- if any_exec_ok:
336
- traces.append(
337
- self._mk_trace(
338
- stage="pipeline",
339
- duration_ms=0.0,
340
- summary="auto-verified",
341
- notes={"reason": "executor succeeded, verifier silent"},
342
- )
343
- )
344
- verified = True
345
-
346
- # --- 9) finalize result ---
347
- has_errors = bool(details)
348
- ok = bool(verified) and not has_errors
349
- err = has_errors and not bool(verified)
350
-
351
- traces.append(
352
- self._mk_trace(
353
- stage="pipeline",
354
- duration_ms=0.0,
355
- summary="finalize",
356
- notes={"final_verified": bool(verified), "details_len": len(details)},
357
  )
358
- )
359
-
360
- return FinalResult(
361
- ok=ok,
362
- ambiguous=False,
363
- error=err,
364
- details=details or None,
365
- sql=sql,
366
- rationale=rationale,
367
- verified=verified,
368
- questions=None,
369
- traces=self._normalize_traces(traces),
370
- )
 
13
  from nl2sql.verifier import Verifier
14
  from nl2sql.repair import Repair
15
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
16
+ from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
17
 
18
 
19
  @dataclass(frozen=True)
 
148
  schema_preview: str | None = None,
149
  clarify_answers: Optional[Dict[str, Any]] = None,
150
  ) -> FinalResult:
151
+ t_all0 = time.perf_counter()
152
  traces: List[dict] = []
153
  details: List[str] = []
 
 
 
154
 
155
  # Normalize inputs
156
  schema_preview = schema_preview or ""
 
158
 
159
  # --- 1) ambiguity detection (with explicit timing & trace) ---
160
  try:
161
+ # --- 1) detector ---
162
+ t_det0 = time.perf_counter()
163
  questions = self.detector.detect(user_query, schema_preview)
164
+ det_ms = (time.perf_counter() - t_det0) * 1000.0
165
  is_amb = bool(questions)
166
+ stage_duration_ms.labels("detector").observe(det_ms)
167
  traces.append(
168
  self._mk_trace(
169
  stage="detector",
170
+ duration_ms=det_ms,
171
  summary=("ambiguous" if is_amb else "clear"),
172
+ notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
 
 
 
173
  )
174
  )
175
 
176
  if questions:
177
+ pipeline_runs_total.labels(status="ambiguous").inc()
178
  return FinalResult(
179
  ok=True,
180
  ambiguous=True,
 
186
  verified=None,
187
  traces=self._normalize_traces(traces),
188
  )
189
+
190
+ # --- 2) planner ---
191
+ t_pln0 = time.perf_counter()
192
+ r_plan = self._safe_stage(
193
+ self.planner.run, user_query=user_query, schema_preview=schema_preview
194
+ )
195
+ stage_duration_ms.labels("planner").observe(
196
+ (time.perf_counter() - t_pln0) * 1000.0
197
+ )
198
+ traces.extend(self._trace_list(r_plan))
199
+ if not r_plan.ok:
200
+ pipeline_runs_total.labels(status="error").inc()
201
+ return FinalResult(
202
+ ok=False,
203
+ ambiguous=False,
204
+ error=True,
205
+ details=r_plan.error,
206
+ questions=None,
207
+ sql=None,
208
+ rationale=None,
209
+ verified=None,
210
+ traces=self._normalize_traces(traces),
211
  )
212
+
213
+ # --- 3) generator ---
214
+ t_gen0 = time.perf_counter()
215
+ r_gen = self._safe_stage(
216
+ self.generator.run,
217
+ user_query=user_query,
218
+ schema_preview=schema_preview,
219
+ plan_text=(r_plan.data or {}).get("plan"),
220
+ clarify_answers=clarify_answers,
221
  )
222
+ stage_duration_ms.labels("generator").observe(
223
+ (time.perf_counter() - t_gen0) * 1000.0
 
 
 
 
 
 
 
 
224
  )
225
+ traces.extend(self._trace_list(r_gen))
226
+ if not r_gen.ok:
227
+ pipeline_runs_total.labels(status="error").inc()
228
+ return FinalResult(
229
+ ok=False,
230
+ ambiguous=False,
231
+ error=True,
232
+ details=r_gen.error,
233
+ questions=None,
234
+ sql=None,
235
+ rationale=None,
236
+ verified=None,
237
+ traces=self._normalize_traces(traces),
238
+ )
239
 
240
+ sql = (r_gen.data or {}).get("sql")
241
+ rationale = (r_gen.data or {}).get("rationale")
242
+
243
+ # --- 4) safety ---
244
+ t_saf0 = time.perf_counter()
245
+ r_safe = self._safe_stage(self.safety.run, sql=sql)
246
+ stage_duration_ms.labels("safety").observe(
247
+ (time.perf_counter() - t_saf0) * 1000.0
 
 
 
 
 
 
 
 
248
  )
249
+ traces.extend(self._trace_list(r_safe))
250
+ if not r_safe.ok:
251
+ pipeline_runs_total.labels(status="error").inc()
252
+ return FinalResult(
253
+ ok=False,
254
+ ambiguous=False,
255
+ error=True,
256
+ details=r_safe.error,
257
+ questions=None,
258
+ sql=sql,
259
+ rationale=rationale,
260
+ verified=None,
261
+ traces=self._normalize_traces(traces),
262
+ )
263
 
264
+ # --- 5) executor ---
265
+ t_exe0 = time.perf_counter()
266
+ r_exec = self._safe_stage(
267
+ self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
268
+ )
269
+ stage_duration_ms.labels("executor").observe(
270
+ (time.perf_counter() - t_exe0) * 1000.0
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  )
272
+ traces.extend(self._trace_list(r_exec))
273
+ if not r_exec.ok and r_exec.error:
274
+ # executor failure is soft; collect for repair/verifier context
275
+ details.extend(r_exec.error)
276
+
277
+ # --- 6) verifier ---
278
+ t_ver0 = time.perf_counter()
279
+ r_ver = self._safe_stage(
280
+ self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
281
+ )
282
+ stage_duration_ms.labels("verifier").observe(
283
+ (time.perf_counter() - t_ver0) * 1000.0
284
+ )
285
+ traces.extend(self._trace_list(r_ver))
286
+ verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
287
+
288
+ # --- 7) repair loop if verification failed ---
289
+ if not verified:
290
+ for _attempt in range(2):
291
+ # repair
292
+ t_fix0 = time.perf_counter()
293
+ r_fix = self._safe_stage(
294
+ self.repair.run,
295
+ sql=sql,
296
+ error_msg="; ".join(details or ["unknown"]),
297
+ schema_preview=schema_preview,
298
+ )
299
+ stage_duration_ms.labels("repair").observe(
300
+ (time.perf_counter() - t_fix0) * 1000.0
301
+ )
302
+ traces.extend(self._trace_list(r_fix))
303
+ if not r_fix.ok:
304
+ break # give up on repair
305
+
306
+ # fixed SQL
307
+ sql = (r_fix.data or {}).get("sql", sql)
308
+
309
+ # safety
310
+ t_saf0 = time.perf_counter()
311
+ r_safe = self._safe_stage(self.safety.run, sql=sql)
312
+ stage_duration_ms.labels("safety").observe(
313
+ (time.perf_counter() - t_saf0) * 1000.0
314
+ )
315
+ traces.extend(self._trace_list(r_safe))
316
+ if not r_safe.ok:
317
+ if r_safe.error:
318
+ details.extend(r_safe.error)
319
+ continue
320
+
321
+ # executor
322
+ t_exe0 = time.perf_counter()
323
+ r_exec = self._safe_stage(
324
+ self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
325
+ )
326
+ stage_duration_ms.labels("executor").observe(
327
+ (time.perf_counter() - t_exe0) * 1000.0
328
+ )
329
+ traces.extend(self._trace_list(r_exec))
330
+ if not r_exec.ok:
331
+ if r_exec.error:
332
+ details.extend(r_exec.error)
333
+ continue
334
+
335
+ # verifier
336
+ t_ver0 = time.perf_counter()
337
+ r_ver = self._safe_stage(
338
+ self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
339
+ )
340
+ stage_duration_ms.labels("verifier").observe(
341
+ (time.perf_counter() - t_ver0) * 1000.0
342
+ )
343
+ traces.extend(self._trace_list(r_ver))
344
+ verified = (
345
+ bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
346
+ )
347
+ if verified:
348
+ break
349
+
350
+ # --- 8) fallback: verifier silent but executor succeeded ---
351
+ if (verified is None or not verified) and not details:
352
+ any_exec_ok = any(
353
+ t.get("stage") == "executor"
354
+ and (t.get("notes") or {}).get("row_count")
355
+ for t in traces
356
+ )
357
+ if any_exec_ok:
358
+ traces.append(
359
+ self._mk_trace(
360
+ stage="pipeline",
361
+ duration_ms=0.0,
362
+ summary="auto-verified",
363
+ notes={"reason": "executor succeeded, verifier silent"},
364
+ )
365
+ )
366
+ verified = True
367
 
368
+ # --- 9) finalize ---
369
+ has_errors = bool(details)
370
+ ok = bool(verified) and not has_errors
371
+ err = has_errors and not bool(verified)
372
+
373
+ if ok:
374
+ pipeline_runs_total.labels(status="ok").inc()
375
+ else:
376
+ pipeline_runs_total.labels(status="error").inc()
377
+
378
+ traces.append(
379
+ self._mk_trace(
380
+ stage="pipeline",
381
+ duration_ms=0.0,
382
+ summary="finalize",
383
+ notes={
384
+ "final_verified": bool(verified),
385
+ "details_len": len(details),
386
+ },
387
+ )
388
+ )
389
 
 
 
 
 
390
  return FinalResult(
391
+ ok=ok,
392
  ambiguous=False,
393
+ error=err,
394
+ details=details or None,
 
395
  sql=sql,
396
  rationale=rationale,
397
+ verified=verified,
398
+ questions=None,
399
  traces=self._normalize_traces(traces),
400
  )
401
 
402
+ except Exception:
403
+ # catch-all for the whole try block (including the detector call, which no longer has its own handler): count the run as an error and re-raise
404
+ pipeline_runs_total.labels(status="error").inc()
405
+ raise
 
 
 
 
 
406
 
407
+ finally:
408
+ # Always record total latency even on early-return / exceptions
409
+ stage_duration_ms.labels("pipeline_total").observe(
410
+ (time.perf_counter() - t_all0) * 1000.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  )