# Spaces: Sleeping
# Melika Kheirieh
# feat(pipeline): enable SQL-only repair + log-only repair for early stages with full traces/metrics
# 6004317
from __future__ import annotations

import time
import traceback
from dataclasses import dataclass
from typing import Dict, Any, Optional, List

from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
@dataclass
class FinalResult:
    """Outcome of one pipeline run, as returned to the caller.

    Declared as a dataclass so it can be built with keyword arguments,
    which is how every construction site in this module creates it.
    """

    # True when the run ended without a blocking problem (an ambiguity
    # report also counts as ok).
    ok: bool
    # True when the detector returned clarification questions and the run
    # stopped before planning.
    ambiguous: bool
    # True when the run failed and error details were collected.
    error: bool
    # Error messages (or a short ambiguity summary); None when there are none.
    details: Optional[List[str]]
    # Final SQL text, if any was produced.
    sql: Optional[str]
    # Rationale reported by the generator stage, if any.
    rationale: Optional[str]
    # Verification verdict; None when the run stopped before verification.
    verified: Optional[bool]
    # Clarification questions when `ambiguous` is True, otherwise None.
    questions: Optional[List[str]]
    # Normalized per-stage trace dictionaries, in execution order.
    traces: List[dict]
class Pipeline:
    """
    NL2SQL Copilot pipeline:
    detector -> planner -> generator -> safety -> executor -> verifier -> (optional repair loop).

    `run` is the entry point. Each stage is timed, reported to the
    `stage_duration_ms` metric and recorded exactly once in the trace list
    returned on the `FinalResult`.
    """

    # Stages that take a SQL string; only for these is repaired SQL injected
    # back into the stage arguments before the retry.
    SQL_REPAIR_STAGES = {"safety", "executor", "verifier"}

    # Maximum number of repair -> safety -> executor -> verifier rounds that
    # `run` performs when the first verification does not succeed.
    MAX_REPAIR_ROUNDS = 2

    def __init__(
        self,
        *,
        detector: AmbiguityDetector,
        planner: Planner,
        generator: Generator,
        safety: Safety,
        executor: Optional[Executor] = None,
        verifier: Optional[Verifier] = None,
        repair: Optional[Repair] = None,
    ):
        """Store the stage objects; missing optional stages become no-op stubs."""
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor or NoOpExecutor()
        self.verifier = verifier or NoOpVerifier()
        self.repair = repair or NoOpRepair()
        # If the verifier explicitly requires verification, enforce it in finalize.
        self.require_verification = bool(getattr(self.verifier, "required", False))

    # ---------------------------- helpers ----------------------------
    @staticmethod
    def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
        """Collect the traces of the given stage results, skipping missing ones.

        A trace object is converted through its ``__dict__``; a trace without
        one (for example a plain dict) is passed through unchanged.
        """
        traces: List[dict] = []
        for s in stages:
            if not s:
                continue
            t = getattr(s, "trace", None)
            if t is not None:
                traces.append(getattr(t, "__dict__", t))
        return traces

    @staticmethod
    def _mk_trace(
        stage: str,
        duration_ms: float,
        summary: str,
        notes: Optional[Dict[str, Any]] = None,
    ) -> dict:
        """Build a trace dictionary with the four standard keys."""
        return {
            "stage": stage,
            "duration_ms": float(duration_ms),
            "summary": summary,
            "notes": notes or {},
        }

    @staticmethod
    def _normalize_traces(traces: List[dict]) -> List[dict]:
        """Return traces reduced to stage/duration_ms/summary/notes.

        Durations are rounded to whole milliseconds with a minimum of 1.
        Values that cannot be converted (wrong type, unparsable text, NaN,
        infinity) also fall back to that minimum instead of raising.
        """
        norm: List[dict] = []
        for t in traces:
            try:
                dur_int = int(round(float(t.get("duration_ms", 0))))
            except (TypeError, ValueError, OverflowError):
                # round() raises on NaN (ValueError) and infinity (OverflowError).
                dur_int = 0
            norm.append(
                {
                    "stage": str(t.get("stage", "unknown")),
                    "duration_ms": max(1, dur_int),
                    "summary": t.get("summary")
                    or ("ok" if t.get("ok") else "failed"),
                    "notes": t.get("notes") or {},
                }
            )
        return norm

    @staticmethod
    def _safe_stage(fn, **kwargs) -> StageResult:
        """Call a stage and always hand back a StageResult.

        A non-StageResult return value is wrapped as a successful result; an
        exception becomes a failed result whose error list holds the message
        and the formatted traceback.
        """
        try:
            r = fn(**kwargs)
            if isinstance(r, StageResult):
                return r
            return StageResult(ok=True, data=r, trace=None)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])

    def _timed_stage(
        self,
        stage_name: str,
        fn,
        traces: list,
        stage_kwargs: Dict[str, Any],
        notes: Optional[Dict[str, Any]] = None,
    ) -> StageResult:
        """Run one stage call, observe its duration and append one trace.

        The stage's own trace is used when it has one; otherwise a fallback
        trace is built from the measured duration, the ok flag and `notes`.
        """
        t0 = time.perf_counter()
        result = self._safe_stage(fn, **stage_kwargs)
        dt = (time.perf_counter() - t0) * 1000.0
        stage_duration_ms.labels(stage_name).observe(dt)

        trace = getattr(result, "trace", None)
        if trace:
            # Accept both trace objects and plain dict traces.
            traces.append(getattr(trace, "__dict__", trace))
        else:
            traces.append(
                self._mk_trace(
                    stage_name, dt, "ok" if result.ok else "failed", notes
                )
            )
        return result

    def _run_with_repair(
        self,
        stage_name: str,
        fn,
        *,
        repair_input_builder,
        max_attempts: int = 1,
        traces: list,
        **kwargs,
    ) -> StageResult:
        """
        Run a stage with per-stage repair + full observability integration.
        SQL-only repair occurs for safety/executor/verifier.
        Planner/Generator get log-only repair (trace only, no effect).

        Every stage attempt and every repair call is timed, reported to
        `stage_duration_ms` and appended to `traces`. Returns the result of
        the last stage attempt; when the repair call itself fails, the
        failed stage result is returned unchanged.
        """
        attempt = 0
        while True:
            # --- 1) Run the stage ---
            result = self._timed_stage(stage_name, fn, traces, kwargs)
            if result.ok:
                return result

            # Stage failed -> check whether a repair attempt is still allowed.
            attempt += 1
            if attempt > max_attempts:
                return result

            # --- 2) Build repair input ---
            repair_args = repair_input_builder(result, kwargs)

            # --- 3) Run repair (always logged) ---
            repair_attempts_total.labels(outcome="attempt").inc()
            fix = self._timed_stage(
                "repair",
                self.repair.run,
                traces,
                repair_args,
                notes={"stage": stage_name},
            )
            if not fix.ok:
                repair_attempts_total.labels(outcome="failed").inc()
                return result  # repair itself failed -> stop here

            # --- 4) Inject repaired SQL only for SQL-consuming stages ---
            if (
                stage_name in self.SQL_REPAIR_STAGES
                and "sql" in repair_args
                and "sql" in kwargs
            ):
                kwargs["sql"] = (fix.data or {}).get("sql", kwargs["sql"])

            # A repair call that returned ok counts as "success" for both
            # SQL stages and log-only stages; the loop then re-runs the stage
            # (unchanged arguments for log-only stages).
            repair_attempts_total.labels(outcome="success").inc()

    @staticmethod
    def _planner_repair_input_builder(stage_result, kwargs):
        """Repair input for a failed planner run (there is no SQL yet)."""
        return {
            "sql": "",
            "error_msg": "; ".join(stage_result.error or ["planner_failed"]),
            "schema_preview": kwargs.get("schema_preview", ""),
        }

    @staticmethod
    def _generator_repair_input_builder(stage_result, kwargs):
        """Repair input for a failed generator run, using any SQL it returned."""
        return {
            "sql": (stage_result.data or {}).get("sql", ""),
            "error_msg": "; ".join(stage_result.error or ["generator_failed"]),
            "schema_preview": kwargs.get("schema_preview", ""),
        }

    @staticmethod
    def _sql_repair_input_builder(stage_result, kwargs):
        """Repair input for a failed SQL stage, using the SQL it was called with."""
        return {
            "sql": kwargs.get("sql", ""),
            "error_msg": "; ".join(stage_result.error or ["stage_failed"]),
            "schema_preview": kwargs.get("schema_preview", ""),
        }

    def _error_result(
        self,
        details: Optional[List[str]],
        traces: List[dict],
        *,
        sql: Optional[str] = None,
        rationale: Optional[str] = None,
    ) -> FinalResult:
        """Count the run as an error and build the matching FinalResult."""
        pipeline_runs_total.labels(status="error").inc()
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=details,
            questions=None,
            sql=sql,
            rationale=rationale,
            verified=None,
            traces=self._normalize_traces(traces),
        )

    def _verify(self, sql: str, exec_result, traces: List[dict]):
        """Run the verifier once and return ``(verified, sql)``.

        `verified` is True when the verifier data carries a truthy
        "verified" entry or the verifier result itself is ok. If the verifier
        data carries a non-empty "sql" entry, that SQL replaces the input.
        """
        r_ver = self._timed_stage(
            "verifier",
            self.verifier.run,
            traces,
            {
                "sql": sql,
                "exec_result": exec_result,
                # let verifier use adapter
                "adapter": getattr(self.executor, "adapter", None),
            },
        )
        verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
        # consume repaired SQL from verifier if any
        if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
            sql = r_ver.data["sql"]
        return verified, sql

    def run(
        self,
        *,
        user_query: str,
        schema_preview: Optional[str] = None,
        clarify_answers: Optional[Dict[str, Any]] = None,
    ) -> FinalResult:
        """Run the full pipeline for one user query.

        Args:
            user_query: Natural-language question passed to the detector,
                planner and generator.
            schema_preview: Schema text passed to the stages; None is
                treated as an empty string.
            clarify_answers: Answers passed to the generator; None is
                treated as an empty dict.

        Returns:
            A FinalResult. The run stops early with ``ambiguous=True`` when
            the detector returns questions, and with ``error=True`` when the
            planner, generator or safety stage fails or the generator yields
            empty SQL. An executor failure is kept as a detail and the run
            continues to verification and the repair loop.

        Raises:
            Any exception escaping the pipeline body (for example from the
            detector) is re-raised after the run is counted as an error.
        """
        t_all0 = time.perf_counter()
        traces: List[dict] = []
        details: List[str] = []
        schema_preview = schema_preview or ""
        clarify_answers = clarify_answers or {}

        try:
            # --- 1) detector ---
            t0 = time.perf_counter()
            questions = self.detector.detect(user_query, schema_preview)
            dt = (time.perf_counter() - t0) * 1000.0
            is_amb = bool(questions)
            stage_duration_ms.labels("detector").observe(dt)
            traces.append(
                self._mk_trace(
                    stage="detector",
                    duration_ms=dt,
                    summary=("ambiguous" if is_amb else "clear"),
                    notes={"ambiguous": is_amb, "questions_len": len(questions or [])},
                )
            )
            if questions:
                pipeline_runs_total.labels(status="ambiguous").inc()
                return FinalResult(
                    ok=True,
                    ambiguous=True,
                    error=False,
                    details=[f"Ambiguities found: {len(questions)}"],
                    questions=questions,
                    sql=None,
                    rationale=None,
                    verified=None,
                    traces=self._normalize_traces(traces),
                )

            # Stages run through _run_with_repair record their own traces and
            # durations, so nothing is recorded for them a second time here.

            # --- 2) planner ---
            r_plan = self._run_with_repair(
                "planner",
                self.planner.run,
                repair_input_builder=self._planner_repair_input_builder,
                max_attempts=1,
                traces=traces,
                user_query=user_query,
                schema_preview=schema_preview,
            )
            if not r_plan.ok:
                return self._error_result(r_plan.error, traces)

            # --- 3) generator ---
            r_gen = self._run_with_repair(
                "generator",
                self.generator.run,
                repair_input_builder=self._generator_repair_input_builder,
                max_attempts=1,
                traces=traces,
                user_query=user_query,
                schema_preview=schema_preview,
                plan_text=(r_plan.data or {}).get("plan"),
                clarify_answers=clarify_answers,
            )
            if not r_gen.ok:
                return self._error_result(r_gen.error, traces)

            sql = (r_gen.data or {}).get("sql")
            rationale = (r_gen.data or {}).get("rationale")

            # Guard: empty SQL
            if not sql or not str(sql).strip():
                traces.append(
                    self._mk_trace("generator", 0.0, "failed", {"reason": "empty_sql"})
                )
                return self._error_result(["empty_sql"], traces, rationale=rationale)

            # --- 4) safety ---
            r_safe = self._run_with_repair(
                "safety",
                self.safety.run,
                repair_input_builder=self._sql_repair_input_builder,
                max_attempts=1,
                traces=traces,
                sql=sql,
            )
            if not r_safe.ok:
                return self._error_result(
                    r_safe.error, traces, sql=sql, rationale=rationale
                )
            # Use sanitized SQL from safety
            sql = (r_safe.data or {}).get("sql", sql)

            # --- 5) executor ---
            # NOTE(review): SQL repaired inside this call is not surfaced
            # to the `sql` variable here -- confirm that is intended.
            r_exec = self._run_with_repair(
                "executor",
                self.executor.run,
                repair_input_builder=self._sql_repair_input_builder,
                max_attempts=1,
                traces=traces,
                sql=sql,
            )
            if not r_exec.ok and r_exec.error:
                details.extend(r_exec.error)  # soft: keep for repair/verifier context

            # --- 6) verifier ---
            verified, sql = self._verify(sql, (r_exec.data or {}), traces)

            # --- 7) repair loop (if not verified) ---
            if not verified:
                for _ in range(self.MAX_REPAIR_ROUNDS):
                    # repair
                    repair_attempts_total.labels(outcome="attempt").inc()
                    r_fix = self._timed_stage(
                        "repair",
                        self.repair.run,
                        traces,
                        {
                            "sql": sql,
                            "error_msg": "; ".join(details or ["unknown"]),
                            "schema_preview": schema_preview,
                        },
                    )
                    if not r_fix.ok:
                        break
                    # update SQL
                    sql = (r_fix.data or {}).get("sql", sql)

                    # safety again
                    r_safe2 = self._timed_stage(
                        "safety", self.safety.run, traces, {"sql": sql}
                    )
                    if not r_safe2.ok:
                        if r_safe2.error:
                            details.extend(r_safe2.error)
                        continue
                    sql = (r_safe2.data or {}).get("sql", sql)

                    # executor again
                    r_exec2 = self._timed_stage(
                        "executor", self.executor.run, traces, {"sql": sql}
                    )
                    if not r_exec2.ok:
                        if r_exec2.error:
                            details.extend(r_exec2.error)
                        continue

                    # verifier again
                    verified, sql = self._verify(sql, (r_exec2.data or {}), traces)
                    if verified:
                        repair_attempts_total.labels(outcome="success").inc()
                        break
                    repair_attempts_total.labels(outcome="failed").inc()

            # --- 8) optional soft auto-verify (executor success, no details) ---
            if not verified and not details:
                any_exec_ok = any(
                    t.get("stage") == "executor"
                    and (t.get("notes") or {}).get("row_count")
                    for t in traces
                )
                if any_exec_ok:
                    traces.append(
                        self._mk_trace(
                            stage="pipeline",
                            duration_ms=0.0,
                            summary="auto-verified",
                            notes={"reason": "executor succeeded, verifier silent"},
                        )
                    )
                    verified = True

            # --- 9) finalize ---
            # NOTE(review): details collected from earlier failed attempts
            # still mark the run as failed after a later successful repair --
            # confirm that is the intended policy.
            has_errors = bool(details)
            need_ver = bool(self.require_verification)
            # base success condition
            final_ok_by_verifier = bool(verified)
            ok = bool(sql) and not has_errors and (final_ok_by_verifier or not need_ver)
            err = (not ok) and has_errors

            # align `verified` with baseline semantics:
            # if verification is NOT required and pipeline is ok, report verified=True
            if not need_ver and ok and not final_ok_by_verifier:
                verified_final = True
            else:
                verified_final = bool(verified)

            pipeline_runs_total.labels(status=("ok" if ok else "error")).inc()
            traces.append(
                self._mk_trace(
                    stage="pipeline",
                    duration_ms=0.0,
                    summary="finalize",
                    notes={
                        "final_verified": bool(verified_final),
                        "details_len": len(details),
                        "need_verification": need_ver,
                    },
                )
            )
            return FinalResult(
                ok=ok,
                ambiguous=False,
                error=err,
                details=details or None,
                sql=sql,
                rationale=rationale,
                verified=verified_final,
                questions=None,
                traces=self._normalize_traces(traces),
            )
        except Exception:
            pipeline_runs_total.labels(status="error").inc()
            # bubble up to make failures visible in tests and logs
            raise
        finally:
            # Always record total latency, even on early return/exception
            stage_duration_ms.labels("pipeline_total").observe(
                (time.perf_counter() - t_all0) * 1000.0
            )