Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
7ece28d
1
Parent(s):
b21cd69
feat(core): always emit per-stage traces in Pipeline.run (fallback when StageResult.trace is empty)
Browse files
- nl2sql/pipeline.py +47 -29
nl2sql/pipeline.py
CHANGED
|
@@ -152,11 +152,20 @@ class Pipeline:
|
|
| 152 |
traces: List[dict] = []
|
| 153 |
details: List[str] = []
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
# Normalize inputs
|
| 156 |
schema_preview = schema_preview or ""
|
| 157 |
clarify_answers = clarify_answers or {}
|
| 158 |
|
| 159 |
-
# --- 1) ambiguity detection (with explicit timing & trace) ---
|
| 160 |
try:
|
| 161 |
# --- 1) detector ---
|
| 162 |
t_det0 = time.perf_counter()
|
|
@@ -192,10 +201,11 @@ class Pipeline:
|
|
| 192 |
r_plan = self._safe_stage(
|
| 193 |
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 194 |
)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
)
|
| 198 |
traces.extend(self._trace_list(r_plan))
|
|
|
|
|
|
|
| 199 |
if not r_plan.ok:
|
| 200 |
pipeline_runs_total.labels(status="error").inc()
|
| 201 |
return FinalResult(
|
|
@@ -219,10 +229,11 @@ class Pipeline:
|
|
| 219 |
plan_text=(r_plan.data or {}).get("plan"),
|
| 220 |
clarify_answers=clarify_answers,
|
| 221 |
)
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
)
|
| 225 |
traces.extend(self._trace_list(r_gen))
|
|
|
|
|
|
|
| 226 |
if not r_gen.ok:
|
| 227 |
pipeline_runs_total.labels(status="error").inc()
|
| 228 |
return FinalResult(
|
|
@@ -243,10 +254,11 @@ class Pipeline:
|
|
| 243 |
# --- 4) safety ---
|
| 244 |
t_saf0 = time.perf_counter()
|
| 245 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
)
|
| 249 |
traces.extend(self._trace_list(r_safe))
|
|
|
|
|
|
|
| 250 |
if not r_safe.ok:
|
| 251 |
pipeline_runs_total.labels(status="error").inc()
|
| 252 |
return FinalResult(
|
|
@@ -266,10 +278,11 @@ class Pipeline:
|
|
| 266 |
r_exec = self._safe_stage(
|
| 267 |
self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
|
| 268 |
)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
)
|
| 272 |
traces.extend(self._trace_list(r_exec))
|
|
|
|
|
|
|
| 273 |
if not r_exec.ok and r_exec.error:
|
| 274 |
# executor failure is soft; collect for repair/verifier context
|
| 275 |
details.extend(r_exec.error)
|
|
@@ -279,10 +292,11 @@ class Pipeline:
|
|
| 279 |
r_ver = self._safe_stage(
|
| 280 |
self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
|
| 281 |
)
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
)
|
| 285 |
traces.extend(self._trace_list(r_ver))
|
|
|
|
|
|
|
| 286 |
verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 287 |
|
| 288 |
# --- 7) repair loop if verification failed ---
|
|
@@ -296,10 +310,11 @@ class Pipeline:
|
|
| 296 |
error_msg="; ".join(details or ["unknown"]),
|
| 297 |
schema_preview=schema_preview,
|
| 298 |
)
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
)
|
| 302 |
traces.extend(self._trace_list(r_fix))
|
|
|
|
|
|
|
| 303 |
if not r_fix.ok:
|
| 304 |
break # give up on repair
|
| 305 |
|
|
@@ -309,10 +324,11 @@ class Pipeline:
|
|
| 309 |
# safety
|
| 310 |
t_saf0 = time.perf_counter()
|
| 311 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
)
|
| 315 |
traces.extend(self._trace_list(r_safe))
|
|
|
|
|
|
|
| 316 |
if not r_safe.ok:
|
| 317 |
if r_safe.error:
|
| 318 |
details.extend(r_safe.error)
|
|
@@ -323,10 +339,11 @@ class Pipeline:
|
|
| 323 |
r_exec = self._safe_stage(
|
| 324 |
self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
|
| 325 |
)
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
)
|
| 329 |
traces.extend(self._trace_list(r_exec))
|
|
|
|
|
|
|
| 330 |
if not r_exec.ok:
|
| 331 |
if r_exec.error:
|
| 332 |
details.extend(r_exec.error)
|
|
@@ -337,10 +354,11 @@ class Pipeline:
|
|
| 337 |
r_ver = self._safe_stage(
|
| 338 |
self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
|
| 339 |
)
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
)
|
| 343 |
traces.extend(self._trace_list(r_ver))
|
|
|
|
|
|
|
| 344 |
verified = (
|
| 345 |
bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 346 |
)
|
|
@@ -400,7 +418,7 @@ class Pipeline:
|
|
| 400 |
)
|
| 401 |
|
| 402 |
except Exception:
|
| 403 |
-
#
|
| 404 |
pipeline_runs_total.labels(status="error").inc()
|
| 405 |
raise
|
| 406 |
|
|
|
|
| 152 |
traces: List[dict] = []
|
| 153 |
details: List[str] = []
|
| 154 |
|
| 155 |
+
# Always push a normalized per-stage timing, even if StageResult.trace is empty
|
| 156 |
+
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 157 |
+
traces.append(
|
| 158 |
+
self._mk_trace(
|
| 159 |
+
stage=stage_name,
|
| 160 |
+
duration_ms=dt_ms,
|
| 161 |
+
summary=("ok" if ok else "failed"),
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
# Normalize inputs
|
| 166 |
schema_preview = schema_preview or ""
|
| 167 |
clarify_answers = clarify_answers or {}
|
| 168 |
|
|
|
|
| 169 |
try:
|
| 170 |
# --- 1) detector ---
|
| 171 |
t_det0 = time.perf_counter()
|
|
|
|
| 201 |
r_plan = self._safe_stage(
|
| 202 |
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 203 |
)
|
| 204 |
+
pln_ms = (time.perf_counter() - t_pln0) * 1000.0
|
| 205 |
+
stage_duration_ms.labels("planner").observe(pln_ms)
|
|
|
|
| 206 |
traces.extend(self._trace_list(r_plan))
|
| 207 |
+
if not getattr(r_plan, "trace", None):
|
| 208 |
+
_fallback_trace("planner", pln_ms, r_plan.ok)
|
| 209 |
if not r_plan.ok:
|
| 210 |
pipeline_runs_total.labels(status="error").inc()
|
| 211 |
return FinalResult(
|
|
|
|
| 229 |
plan_text=(r_plan.data or {}).get("plan"),
|
| 230 |
clarify_answers=clarify_answers,
|
| 231 |
)
|
| 232 |
+
gen_ms = (time.perf_counter() - t_gen0) * 1000.0
|
| 233 |
+
stage_duration_ms.labels("generator").observe(gen_ms)
|
|
|
|
| 234 |
traces.extend(self._trace_list(r_gen))
|
| 235 |
+
if not getattr(r_gen, "trace", None):
|
| 236 |
+
_fallback_trace("generator", gen_ms, r_gen.ok)
|
| 237 |
if not r_gen.ok:
|
| 238 |
pipeline_runs_total.labels(status="error").inc()
|
| 239 |
return FinalResult(
|
|
|
|
| 254 |
# --- 4) safety ---
|
| 255 |
t_saf0 = time.perf_counter()
|
| 256 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 257 |
+
saf_ms = (time.perf_counter() - t_saf0) * 1000.0
|
| 258 |
+
stage_duration_ms.labels("safety").observe(saf_ms)
|
|
|
|
| 259 |
traces.extend(self._trace_list(r_safe))
|
| 260 |
+
if not getattr(r_safe, "trace", None):
|
| 261 |
+
_fallback_trace("safety", saf_ms, r_safe.ok)
|
| 262 |
if not r_safe.ok:
|
| 263 |
pipeline_runs_total.labels(status="error").inc()
|
| 264 |
return FinalResult(
|
|
|
|
| 278 |
r_exec = self._safe_stage(
|
| 279 |
self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
|
| 280 |
)
|
| 281 |
+
exe_ms = (time.perf_counter() - t_exe0) * 1000.0
|
| 282 |
+
stage_duration_ms.labels("executor").observe(exe_ms)
|
|
|
|
| 283 |
traces.extend(self._trace_list(r_exec))
|
| 284 |
+
if not getattr(r_exec, "trace", None):
|
| 285 |
+
_fallback_trace("executor", exe_ms, r_exec.ok)
|
| 286 |
if not r_exec.ok and r_exec.error:
|
| 287 |
# executor failure is soft; collect for repair/verifier context
|
| 288 |
details.extend(r_exec.error)
|
|
|
|
| 292 |
r_ver = self._safe_stage(
|
| 293 |
self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
|
| 294 |
)
|
| 295 |
+
ver_ms = (time.perf_counter() - t_ver0) * 1000.0
|
| 296 |
+
stage_duration_ms.labels("verifier").observe(ver_ms)
|
|
|
|
| 297 |
traces.extend(self._trace_list(r_ver))
|
| 298 |
+
if not getattr(r_ver, "trace", None):
|
| 299 |
+
_fallback_trace("verifier", ver_ms, r_ver.ok)
|
| 300 |
verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 301 |
|
| 302 |
# --- 7) repair loop if verification failed ---
|
|
|
|
| 310 |
error_msg="; ".join(details or ["unknown"]),
|
| 311 |
schema_preview=schema_preview,
|
| 312 |
)
|
| 313 |
+
fix_ms = (time.perf_counter() - t_fix0) * 1000.0
|
| 314 |
+
stage_duration_ms.labels("repair").observe(fix_ms)
|
|
|
|
| 315 |
traces.extend(self._trace_list(r_fix))
|
| 316 |
+
if not getattr(r_fix, "trace", None):
|
| 317 |
+
_fallback_trace("repair", fix_ms, r_fix.ok)
|
| 318 |
if not r_fix.ok:
|
| 319 |
break # give up on repair
|
| 320 |
|
|
|
|
| 324 |
# safety
|
| 325 |
t_saf0 = time.perf_counter()
|
| 326 |
r_safe = self._safe_stage(self.safety.run, sql=sql)
|
| 327 |
+
saf_ms2 = (time.perf_counter() - t_saf0) * 1000.0
|
| 328 |
+
stage_duration_ms.labels("safety").observe(saf_ms2)
|
|
|
|
| 329 |
traces.extend(self._trace_list(r_safe))
|
| 330 |
+
if not getattr(r_safe, "trace", None):
|
| 331 |
+
_fallback_trace("safety", saf_ms2, r_safe.ok)
|
| 332 |
if not r_safe.ok:
|
| 333 |
if r_safe.error:
|
| 334 |
details.extend(r_safe.error)
|
|
|
|
| 339 |
r_exec = self._safe_stage(
|
| 340 |
self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
|
| 341 |
)
|
| 342 |
+
exe_ms2 = (time.perf_counter() - t_exe0) * 1000.0
|
| 343 |
+
stage_duration_ms.labels("executor").observe(exe_ms2)
|
|
|
|
| 344 |
traces.extend(self._trace_list(r_exec))
|
| 345 |
+
if not getattr(r_exec, "trace", None):
|
| 346 |
+
_fallback_trace("executor", exe_ms2, r_exec.ok)
|
| 347 |
if not r_exec.ok:
|
| 348 |
if r_exec.error:
|
| 349 |
details.extend(r_exec.error)
|
|
|
|
| 354 |
r_ver = self._safe_stage(
|
| 355 |
self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
|
| 356 |
)
|
| 357 |
+
ver_ms2 = (time.perf_counter() - t_ver0) * 1000.0
|
| 358 |
+
stage_duration_ms.labels("verifier").observe(ver_ms2)
|
|
|
|
| 359 |
traces.extend(self._trace_list(r_ver))
|
| 360 |
+
if not getattr(r_ver, "trace", None):
|
| 361 |
+
_fallback_trace("verifier", ver_ms2, r_ver.ok)
|
| 362 |
verified = (
|
| 363 |
bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
|
| 364 |
)
|
|
|
|
| 418 |
)
|
| 419 |
|
| 420 |
except Exception:
|
| 421 |
+
# Any unexpected crash
|
| 422 |
pipeline_runs_total.labels(status="error").inc()
|
| 423 |
raise
|
| 424 |
|