nl2sql-copilot / nl2sql /pipeline.py
Melika Kheirieh
refactor: unify pipeline output via FinalResult model
a45c0eb
raw
history blame
7.82 kB
from __future__ import annotations
import traceback
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List
from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
# ---- NEW: FinalResult as domain-level, type-safe result ----
@dataclass(frozen=True)
class FinalResult:
    """Immutable, domain-level outcome of a single ``Pipeline.run()`` call.

    Every exit path of the pipeline builds one of these, so adapters can
    serialize a single shape (for example with ``dataclasses.asdict()``).
    The dataclass is frozen: fields cannot be reassigned after construction.
    """

    # Overall success flag of the run.
    ok: bool
    # Set when the ambiguity detector produced clarification questions
    # (the pipeline also sets it when the detector itself raised).
    ambiguous: bool
    # Set when the run ended because of a failure.
    error: bool
    # Human-readable failure/diagnostic messages, or None when there are none.
    details: Optional[List[str]]
    # SQL produced by the generator (or by the repair stage); None when the
    # run ended before generation.
    sql: Optional[str]
    # Generator-provided rationale for the SQL; None when not available.
    rationale: Optional[str]
    # Verifier outcome; None when the verifier stage was never reached.
    verified: Optional[bool]
    # Clarification questions from the ambiguity detector, if any.
    questions: Optional[List[str]]
    # Trace payloads (one dict per traced stage) collected during the run.
    traces: List[dict]
class Pipeline:
    """
    NL2SQL Copilot pipeline.

    Stages run in order: ambiguity detection -> planner -> generator ->
    safety -> executor -> verifier, followed by a bounded repair loop when
    verification fails. Stages return StageResult; the final result is a
    type-safe FinalResult. Adapters (e.g. FastAPI) can serialize it with
    dataclasses.asdict().
    """

    # Upper bound on repair attempts after a failed verification.
    _MAX_REPAIR_ATTEMPTS = 2

    def __init__(
        self,
        *,
        detector: AmbiguityDetector,
        planner: Planner,
        generator: Generator,
        safety: Safety,
        executor: Optional[Executor] = None,
        verifier: Optional[Verifier] = None,
        repair: Optional[Repair] = None,
    ):
        """Wire the stages together.

        ``executor``, ``verifier`` and ``repair`` are optional; when omitted
        (or falsy) the corresponding no-op stub is used instead.
        """
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor or NoOpExecutor()
        self.verifier = verifier or NoOpVerifier()
        self.repair = repair or NoOpRepair()

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """Return the trace of every given stage that has one, as dicts.

        Falsy stages and stages without a (truthy) ``trace`` are skipped.
        """
        traces = []
        for s in stages:
            if not s:
                continue
            t = getattr(s, "trace", None)
            if t:
                traces.append(t.__dict__)
        return traces

    # ------------------------------------------------------------
    @staticmethod
    def _payload(stage: StageResult, key: str, default: Any = None) -> Any:
        """Read ``key`` from ``stage.data`` without assuming its shape.

        A stage may succeed with ``data=None``, and ``_safe_stage`` wraps
        arbitrary return values as ``data``; in both cases ``default`` is
        returned instead of raising ``AttributeError``.
        """
        getter = getattr(getattr(stage, "data", None), "get", None)
        if callable(getter):
            return getter(key, default)
        return default

    # ------------------------------------------------------------
    @staticmethod
    def _stage_failure(
        stage: StageResult,
        traces: List[dict],
        *,
        sql: Optional[str] = None,
        rationale: Optional[str] = None,
    ) -> FinalResult:
        """Build the FinalResult for a stage failure that aborts the run."""
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=stage.error,
            questions=None,
            sql=sql,
            rationale=rationale,
            verified=None,
            traces=traces,
        )

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult.

        Non-StageResult return values are wrapped as a successful result
        whose ``data`` is the returned object. On an exception the result
        has ``ok=False`` and ``error=[message, formatted traceback]``.
        """
        try:
            r = fn(**kwargs)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
        if isinstance(r, StageResult):
            return r
        # Normalize non-StageResult returns
        return StageResult(ok=True, data=r, trace=None)

    # ------------------------------------------------------------
    def run(
        self,
        *,
        user_query: str,
        schema_preview: str,
        clarify_answers: Optional[Dict[str, Any]] = None,
    ) -> FinalResult:
        """Translate ``user_query`` into SQL and return the outcome.

        Args:
            user_query: Natural-language question to translate.
            schema_preview: Schema description handed to the detector,
                planner, generator and repair stages.
            clarify_answers: Optional answers to earlier clarification
                questions; forwarded to the generator (``{}`` when omitted).

        Returns:
            A FinalResult. The run ends early when the detector reports
            ambiguities (``ambiguous=True`` with ``questions``) or when the
            detector, planner, generator or the first safety check fails
            (``error=True``). Otherwise ``ok`` is True only when the
            verifier accepted the SQL and no unresolved error messages
            remain in ``details``.
        """
        traces: List[dict] = []
        details: List[str] = []
        sql: Optional[str] = None
        rationale: Optional[str] = None
        verified: Optional[bool] = None

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            if questions:
                return FinalResult(
                    ok=True,
                    ambiguous=True,
                    error=False,
                    details=[f"Ambiguities found: {len(questions)}"],
                    questions=questions,
                    sql=None,
                    rationale=None,
                    verified=None,
                    traces=[],
                )
        except Exception as e:
            # NOTE(review): ambiguous=True is kept here for backward
            # compatibility although no questions are available -- confirm
            # whether callers rely on it.
            return FinalResult(
                ok=False,
                ambiguous=True,
                error=True,
                details=[f"Detector failed: {e}"],
                questions=None,
                sql=None,
                rationale=None,
                verified=None,
                traces=[],
            )

        # --- 2) planner
        r_plan = self._safe_stage(
            self.planner.run, user_query=user_query, schema_preview=schema_preview
        )
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return self._stage_failure(r_plan, traces)

        # --- 3) generator
        r_gen = self._safe_stage(
            self.generator.run,
            user_query=user_query,
            schema_preview=schema_preview,
            plan_text=self._payload(r_plan, "plan"),
            clarify_answers=clarify_answers or {},
        )
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return self._stage_failure(r_gen, traces)
        sql = self._payload(r_gen, "sql")
        rationale = self._payload(r_gen, "rationale")

        # --- 4) safety
        # Safety is invoked through .run (not .check) to match the
        # DummySafety signature.
        r_safe = self._safe_stage(self.safety.run, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return self._stage_failure(r_safe, traces, sql=sql, rationale=rationale)

        # --- 5) executor (runs the SQL returned by safety, if it returned one)
        r_exec = self._safe_stage(
            self.executor.run, sql=self._payload(r_safe, "sql", sql)
        )
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.error or [])

        # --- 6) verifier
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
        traces.extend(self._trace_list(r_ver))
        verified = bool(r_ver.ok)
        if not verified:
            # Keep the verifier's reason so the repair stage and the caller
            # see why verification failed instead of "unknown".
            details.extend(r_ver.error or [])

        # --- 7) repair loop if verification failed
        if not verified:
            for _attempt in range(self._MAX_REPAIR_ATTEMPTS):
                r_fix = self._safe_stage(
                    self.repair.run,
                    sql=sql,
                    # str() guards against stages reporting non-string errors.
                    error_msg="; ".join(map(str, details or ["unknown"])),
                    schema_preview=schema_preview,
                )
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    break
                sql = self._payload(r_fix, "sql")

                r_safe = self._safe_stage(self.safety.run, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.error or [])
                    continue

                r_exec = self._safe_stage(
                    self.executor.run, sql=self._payload(r_safe, "sql", sql)
                )
                traces.extend(self._trace_list(r_exec))
                if not r_exec.ok:
                    details.extend(r_exec.error or [])
                    continue

                r_ver = self._safe_stage(
                    self.verifier.run, sql=sql, exec_result=r_exec.data
                )
                traces.extend(self._trace_list(r_ver))
                verified = bool(r_ver.ok)
                if verified:
                    # The repaired SQL passed safety, execution and
                    # verification, so errors from earlier attempts are
                    # resolved; leaving them in `details` would force
                    # ok=False below. The history stays in `traces`.
                    details = []
                    break
                details.extend(r_ver.error or [])

        return FinalResult(
            ok=bool(verified) and not details,
            ambiguous=False,
            error=bool(details) and not bool(verified),
            details=details or None,
            sql=sql,
            rationale=rationale,
            verified=verified,
            questions=None,
            traces=traces,
        )