# nl2sql-copilot / nl2sql / pipeline.py
# Author: Melika Kheirieh
# Commit: init: NL2SQL Copilot base with API and Dockerfile (570f7bd)
# (file-viewer residue: raw / history / blame, 6.39 kB)
from __future__ import annotations
import traceback
from typing import Dict, Any, Optional, List
from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
class Pipeline:
    """
    NL2SQL Copilot pipeline with guaranteed dict output.

    Stages run in order: ambiguity detection, planning, SQL generation,
    safety check, execution, verification and, when verification fails, a
    bounded repair loop.  Stage failures are captured as traces and error
    messages; the final result is always a plain dict with the same core keys.
    """

    # Default upper bound on repair attempts after a failed verification.
    DEFAULT_MAX_REPAIR_ATTEMPTS = 2

    def __init__(self, *,
                 detector: AmbiguityDetector,
                 planner: Planner,
                 generator: Generator,
                 safety: Safety,
                 executor: Executor,
                 verifier: Verifier,
                 repair: Repair,
                 max_repair_attempts: int = DEFAULT_MAX_REPAIR_ATTEMPTS):
        """
        Store the stage objects used by :meth:`run`.

        Args:
            detector: object exposing ``detect(user_query, schema_preview)``.
            planner: object exposing ``run(user_query=..., schema_preview=...)``.
            generator: object exposing ``run(...)`` that produces the SQL.
            safety: object exposing ``check(sql=...)``.
            executor: object exposing ``run(sql=...)``.
            verifier: object exposing ``run(sql=..., exec_result=...)``.
            repair: object exposing ``run(sql=..., error_msg=..., schema_preview=...)``.
            max_repair_attempts: how many repair rounds to try when
                verification fails (defaults to 2, the former hard-coded value).
        """
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor
        self.verifier = verifier
        self.repair = repair
        self.max_repair_attempts = max_repair_attempts

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """
        Collect the traces of the given stage results as plain dicts.

        Falsy stages and stages without a (truthy) ``trace`` are skipped.
        Dict traces and attribute-based traces are shallow-copied so later
        mutation of the trace object does not alter the collected output;
        any other trace type is recorded by its ``repr`` instead of crashing.
        """
        traces: List[dict] = []
        for stage in stages:
            trace = getattr(stage, "trace", None) if stage else None
            if not trace:
                continue
            if isinstance(trace, dict):
                traces.append(dict(trace))
            elif hasattr(trace, "__dict__"):
                traces.append(dict(vars(trace)))
            else:
                # e.g. a __slots__ object: keep something rather than raise.
                traces.append({"repr": repr(trace)})
        return traces

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """
        Run a stage safely; if it throws, catch and convert to StageResult.

        A return value that is not a ``StageResult`` is wrapped in a
        successful one (``data`` holds the raw value, no trace).  On an
        exception the returned result has ``ok=False`` and ``errors`` set to
        ``[message, formatted traceback]``.
        """
        try:
            outcome = fn(**kwargs)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
        if isinstance(outcome, StageResult):
            return outcome
        # not ideal, but wrap it
        return StageResult(ok=True, data=outcome, trace=None)

    # ------------------------------------------------------------
    @staticmethod
    def _payload(stage: StageResult) -> Dict[str, Any]:
        """Return ``stage.data`` if it is mapping-like, else an empty dict."""
        data = getattr(stage, "data", None)
        if callable(getattr(data, "get", None)):
            return data
        return {}

    # ------------------------------------------------------------
    @staticmethod
    def _result(*, ambiguous: bool, error: bool, details: Optional[List[str]],
                traces: List[dict], sql: Optional[str] = None,
                rationale: Optional[str] = None,
                verified: Optional[bool] = None, **extra: Any) -> Dict[str, Any]:
        """Build the result dict so every exit path carries the same core keys."""
        result: Dict[str, Any] = {
            "ambiguous": ambiguous,
            "error": error,
            "details": details,
            "sql": sql,
            "rationale": rationale,
            "verified": verified,
            "traces": traces,
        }
        result.update(extra)
        return result

    # ------------------------------------------------------------
    def run(self, *, user_query: str, schema_preview: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run the full pipeline for one user query.

        Always returns:
        {
            "ambiguous": bool,
            "error": bool,
            "details": list[str] | None,
            "sql": str | None,
            "rationale": str | None,
            "verified": bool | None,
            "traces": list[dict]
        }
        When the detector reports ambiguities the dict additionally holds
        "questions".  "sql" is only ever SQL that passed the safety check
        (or None); "error" is True when there are details and the SQL is not
        verified, or when a stage before execution failed.
        """
        traces: List[dict] = []
        details: List[str] = []

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            if questions:
                return self._result(
                    ambiguous=True,
                    error=False,
                    details=[f"Ambiguities found: {len(questions)}"],
                    traces=[],
                    questions=questions,
                )
        except Exception as e:
            # NOTE(review): a detector failure is reported with ambiguous=True;
            # kept for backward compatibility -- confirm this is intended.
            return self._result(ambiguous=True, error=True,
                                details=[f"Detector failed: {e}"], traces=[])

        # --- 2) planner
        r_plan = self._safe_stage(self.planner.run,
                                  user_query=user_query,
                                  schema_preview=schema_preview)
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_plan.errors, traces=traces)

        # --- 3) generator
        r_gen = self._safe_stage(self.generator.run,
                                 user_query=user_query,
                                 schema_preview=schema_preview,
                                 plan_text=self._payload(r_plan).get("plan"),
                                 clarify_answers=clarify_answers or {})
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_gen.errors, traces=traces)
        gen_data = self._payload(r_gen)
        sql = gen_data.get("sql")
        rationale = gen_data.get("rationale")

        # --- 4) safety
        r_safe = self._safe_stage(self.safety.check, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            # Rejected SQL is deliberately not echoed back in the result.
            return self._result(ambiguous=False, error=True,
                                details=r_safe.errors, traces=traces)

        # --- 5) executor
        # Prefer the SQL returned by the safety stage; fall back to the SQL it
        # approved when the stage returned no payload.
        r_exec = self._safe_stage(self.executor.run,
                                  sql=self._payload(r_safe).get("sql", sql))
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.errors or [])

        # --- 6) verifier
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
        traces.extend(self._trace_list(r_ver))
        verified = bool(r_ver.ok)
        if not verified:
            # Keep the verifier's reasons so repair gets real feedback.
            details.extend(r_ver.errors or [])

        # --- 7) repair loop if verification failed
        if not verified:
            candidate = sql
            for _ in range(self.max_repair_attempts):
                r_fix = self._safe_stage(
                    self.repair.run,
                    sql=candidate,
                    error_msg="; ".join(str(d) for d in (details or ["unknown"])),
                    schema_preview=schema_preview)
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    break

                repaired = self._payload(r_fix).get("sql")
                if not repaired:
                    details.append("Repair stage returned no SQL")
                    break
                candidate = repaired

                r_safe = self._safe_stage(self.safety.check, sql=candidate)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.errors or [])
                    continue

                # Only SQL that passed the safety check may be reported.
                sql = candidate

                r_exec = self._safe_stage(
                    self.executor.run,
                    sql=self._payload(r_safe).get("sql", candidate))
                traces.extend(self._trace_list(r_exec))
                if not r_exec.ok:
                    details.extend(r_exec.errors or [])
                    continue

                r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
                traces.extend(self._trace_list(r_ver))
                verified = bool(r_ver.ok)
                if verified:
                    break
                details.extend(r_ver.errors or [])

        # --- Final result dict
        return self._result(
            ambiguous=False,
            error=bool(details) and not verified,
            details=details or None,
            sql=sql,
            rationale=rationale,
            verified=verified,
            traces=traces,
        )