# nl2sql-copilot / nl2sql/pipeline.py
# Melika Kheirieh
# Commit a337fad: build(mypy): fix type errors and add safety guards for None values
from __future__ import annotations
import traceback
from dataclasses import dataclass
from typing import Dict, Any, Optional, List
from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
@dataclass(frozen=True)
class FinalResult:
    """Immutable summary of a single pipeline run.

    Instances are frozen: fields are set once at construction and cannot be
    reassigned afterwards. Field order is part of the constructor signature.
    """

    # Overall outcome flags.
    ok: bool
    ambiguous: bool
    error: bool

    # Human-readable messages describing what happened, or None.
    details: Optional[List[str]]

    # Query text and its accompanying explanation, when available.
    sql: Optional[str]
    rationale: Optional[str]

    # Verification outcome; None when verification was never reached.
    verified: Optional[bool]

    # Clarifying questions for the user, or None.
    questions: Optional[List[str]]

    # One dict per collected stage trace.
    traces: List[dict]
class Pipeline:
    """
    NL2SQL Copilot pipeline.

    Stages run in this order: ambiguity detection, planner, generator,
    safety, executor, verifier. If the SQL cannot be executed and verified,
    a bounded repair loop asks the repair stage for a new candidate and
    re-runs safety, executor and verifier on it.

    Stages return StageResult; final result is a type-safe FinalResult.
    Adapters (e.g. FastAPI) can serialize with dataclasses.asdict().
    """

    def __init__(
        self,
        *,
        detector: AmbiguityDetector,
        planner: Planner,
        generator: Generator,
        safety: Safety,
        executor: Optional[Executor] = None,
        verifier: Optional[Verifier] = None,
        repair: Optional[Repair] = None,
        max_repair_attempts: int = 2,
    ):
        """Wire up the stages.

        Args:
            detector: Object whose ``detect(user_query, schema_preview)``
                returns clarifying questions (falsy when there are none).
            planner: Stage whose ``run`` is called with ``user_query`` and
                ``schema_preview``.
            generator: Stage whose ``run`` produces the SQL and rationale.
            safety: Stage whose ``run`` is called with the candidate ``sql``.
            executor: Optional executor stage; a ``NoOpExecutor`` is used
                when omitted.
            verifier: Optional verifier stage; a ``NoOpVerifier`` is used
                when omitted.
            repair: Optional repair stage; a ``NoOpRepair`` is used when
                omitted.
            max_repair_attempts: Upper bound on repair iterations after a
                failed execution/verification. Negative values are treated
                as 0 (no repair).
        """
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor or NoOpExecutor()
        self.verifier = verifier or NoOpVerifier()
        self.repair = repair or NoOpRepair()
        self.max_repair_attempts = max(0, int(max_repair_attempts))

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """Collect the traces of the given stage results as plain dicts.

        Falsy stage results and stages without a (truthy) ``trace`` are
        skipped. A trace that is already a dict is accepted as-is; any other
        trace object contributes its instance attributes. Each returned dict
        is a shallow copy, so editing it does not touch the stage's trace.
        """
        traces: List[dict] = []
        for s in stages:
            if not s:
                continue
            t = getattr(s, "trace", None)
            if not t:
                continue
            traces.append(dict(t) if isinstance(t, dict) else dict(vars(t)))
        return traces

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult.

        A return value that is not a StageResult is wrapped as the ``data``
        of a successful StageResult. On an exception the returned StageResult
        has ``ok=False`` and ``error`` holding the exception text followed by
        the formatted traceback.
        """
        try:
            r = fn(**kwargs)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
        if isinstance(r, StageResult):
            return r
        # Normalize non-StageResult returns
        return StageResult(ok=True, data=r, trace=None)

    # ------------------------------------------------------------
    @staticmethod
    def _data_get(stage: StageResult, key: str, default: Any = None) -> Any:
        """Read ``key`` from ``stage.data``, tolerating non-mapping payloads.

        ``_safe_stage`` may wrap an arbitrary return value as ``data``, so the
        payload is not guaranteed to offer ``.get``; ``default`` is returned
        in that case (and when ``data`` is None) instead of raising.
        """
        data = getattr(stage, "data", None)
        getter = getattr(data, "get", None)
        return getter(key, default) if callable(getter) else default

    # ------------------------------------------------------------
    @staticmethod
    def _failure(
        details: Optional[List[str]],
        traces: List[dict],
        *,
        sql: Optional[str] = None,
        rationale: Optional[str] = None,
    ) -> FinalResult:
        """Build the FinalResult for a stage failure that aborts the run."""
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=details,
            questions=None,
            sql=sql,
            rationale=rationale,
            verified=None,
            traces=traces,
        )

    # ------------------------------------------------------------
    def _execute_and_verify(
        self,
        *,
        sql: Optional[str],
        safe_result: StageResult,
        traces: List[dict],
        details: List[str],
    ) -> bool:
        """Execute the safety-approved SQL, then verify it.

        Traces are appended to ``traces`` and failure messages to ``details``
        (both mutated in place). The verifier is skipped when execution
        fails, because there is no execution result to verify.

        Returns:
            True only when both the executor and the verifier report ok.
        """
        # Prefer the SQL returned by the safety stage; fall back to the candidate.
        runnable_sql = self._data_get(safe_result, "sql", sql) or sql
        r_exec = self._safe_stage(self.executor.run, sql=runnable_sql)
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.error or [])
            return False

        # NOTE(review): the verifier receives the pre-safety `sql`, while the
        # executor ran the safety stage's SQL — confirm this is intended.
        r_ver = self._safe_stage(
            self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
        )
        traces.extend(self._trace_list(r_ver))
        if not r_ver.ok:
            # Record why verification failed so the repair stage (and the
            # caller) see the actual reason instead of "unknown".
            details.extend(r_ver.error or [])
            return False
        return True

    # ------------------------------------------------------------
    def run(
        self,
        *,
        user_query: str,
        schema_preview: str,
        clarify_answers: Optional[Dict[str, Any]] = None,
    ) -> FinalResult:
        """Run the full pipeline for one user query.

        Args:
            user_query: Natural-language question from the user.
            schema_preview: Schema description handed to the stages.
            clarify_answers: Optional answers to earlier clarifying
                questions; an empty dict is passed to the generator when
                omitted.

        Returns:
            A FinalResult. If the detector returns questions, the result is
            ``ambiguous=True`` and no SQL is generated. A planner, generator
            or initial safety failure aborts with ``error=True``. Otherwise
            ``verified`` tells whether the returned ``sql`` was executed and
            verified; ``details`` holds the failure messages, or None.
        """
        traces: List[dict] = []
        details: List[str] = []

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
        except Exception as e:
            # NOTE(review): ambiguous=True on a detector crash is kept as-is
            # for backward compatibility — confirm callers rely on it.
            return FinalResult(
                ok=False,
                ambiguous=True,
                error=True,
                details=[f"Detector failed: {e}"],
                questions=None,
                sql=None,
                rationale=None,
                verified=None,
                traces=[],
            )
        if questions:
            return FinalResult(
                ok=True,
                ambiguous=True,
                error=False,
                details=[f"Ambiguities found: {len(questions)}"],
                questions=questions,
                sql=None,
                rationale=None,
                verified=None,
                traces=[],
            )

        # --- 2) planner
        r_plan = self._safe_stage(
            self.planner.run, user_query=user_query, schema_preview=schema_preview
        )
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return self._failure(r_plan.error, traces)

        # --- 3) generator
        r_gen = self._safe_stage(
            self.generator.run,
            user_query=user_query,
            schema_preview=schema_preview,
            plan_text=self._data_get(r_plan, "plan"),
            clarify_answers=clarify_answers or {},
        )
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return self._failure(r_gen.error, traces)
        sql: Optional[str] = self._data_get(r_gen, "sql")
        rationale: Optional[str] = self._data_get(r_gen, "rationale")

        # --- 4) safety
        r_safe = self._safe_stage(self.safety.run, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return self._failure(r_safe.error, traces, sql=sql, rationale=rationale)

        # --- 5) executor + 6) verifier
        verified = self._execute_and_verify(
            sql=sql, safe_result=r_safe, traces=traces, details=details
        )

        # --- 7) repair loop if execution/verification failed
        if not verified:
            for _ in range(self.max_repair_attempts):
                error_msg = "; ".join(map(str, details)) if details else "unknown"
                r_fix = self._safe_stage(
                    self.repair.run,
                    sql=sql,
                    error_msg=error_msg,
                    schema_preview=schema_preview,
                )
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    details.extend(r_fix.error or [])
                    break

                repaired_sql = self._data_get(r_fix, "sql")
                if not repaired_sql:
                    # Keep the last candidate instead of replacing it with None.
                    details.append("Repair produced no SQL")
                    break
                sql = repaired_sql

                r_safe = self._safe_stage(self.safety.run, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.error or [])
                    continue

                verified = self._execute_and_verify(
                    sql=sql, safe_result=r_safe, traces=traces, details=details
                )
                if verified:
                    # Earlier messages describe superseded SQL, not the
                    # repaired query that is being returned.
                    details.clear()
                    break

        return FinalResult(
            ok=bool(verified) and not details,
            ambiguous=False,
            error=bool(details) and not bool(verified),
            details=details or None,
            sql=sql,
            rationale=rationale,
            verified=verified,
            questions=None,
            traces=traces,
        )