Spaces:
Sleeping
Sleeping
File size: 6,391 Bytes
570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from __future__ import annotations
import traceback
from typing import Dict, Any, Optional, List
from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
class Pipeline:
    """
    NL2SQL Copilot pipeline with guaranteed dict output.

    Stages run in order: ambiguity detection -> planner -> generator ->
    safety -> executor -> verifier, followed by a bounded repair loop when
    verification fails. Stage objects may return structured results with
    traces and errors; the final result of ``run`` is always a plain dict
    with a fixed set of keys.
    """

    # Upper bound on repair -> safety -> execute -> verify rounds.
    MAX_REPAIR_ATTEMPTS = 2

    def __init__(self, *,
                 detector: AmbiguityDetector,
                 planner: Planner,
                 generator: Generator,
                 safety: Safety,
                 executor: Executor,
                 verifier: Verifier,
                 repair: Repair):
        """Store the stage collaborators; all arguments are keyword-only."""
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor
        self.verifier = verifier
        self.repair = repair

    # ------------------------------------------------------------
    @staticmethod
    def _result(*, ambiguous: bool, error: bool, details: Optional[List[str]],
                traces: List[dict], sql: Optional[str] = None,
                rationale: Optional[str] = None,
                verified: Optional[bool] = None,
                **extra: Any) -> Dict[str, Any]:
        """Build the result dict so every exit path carries the same keys."""
        out: Dict[str, Any] = {
            "ambiguous": ambiguous,
            "error": error,
            "details": details,
            "sql": sql,
            "rationale": rationale,
            "verified": verified,
            "traces": traces,
        }
        out.update(extra)
        return out

    # ------------------------------------------------------------
    @staticmethod
    def _stage_value(stage: StageResult, key: str, default: Any = None) -> Any:
        """Read ``key`` from a stage's ``data`` payload without raising.

        ``_safe_stage`` may wrap arbitrary return values, so ``data`` can be
        ``None`` or a non-dict; in those cases ``default`` is returned.
        """
        data = getattr(stage, "data", None)
        if isinstance(data, dict):
            return data.get(key, default)
        return default

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """Collect the traces of the given stage results as plain dicts.

        Falsy stages and stages without a (truthy) ``trace`` are skipped.
        """
        traces: List[dict] = []
        for stage in stages:
            if not stage:
                continue
            trace = getattr(stage, "trace", None)
            if not trace:
                continue
            if isinstance(trace, dict):
                traces.append(dict(trace))
                continue
            fields = getattr(trace, "__dict__", None)
            if fields is not None:
                # Copy so later mutation of the trace object cannot alter
                # a result that was already returned.
                traces.append(dict(fields))
            else:
                # No attribute dict (e.g. a slotted object): keep a
                # readable placeholder instead of raising.
                traces.append({"trace": repr(trace)})
        return traces

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult.

        A return value that is not a ``StageResult`` is wrapped as a
        successful result with the value as ``data``.
        """
        try:
            r = fn(**kwargs)
            if isinstance(r, StageResult):
                return r
            # not ideal, but wrap it
            return StageResult(ok=True, data=r, trace=None)
        except Exception as e:
            # Boundary catch: the pipeline must never propagate stage errors.
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None,
                               errors=[f"{e}", tb])

    # ------------------------------------------------------------
    def run(self, *, user_query: str, schema_preview: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run the full pipeline for one user query.

        Always returns:
        {
            "ambiguous": bool,
            "error": bool,
            "details": list[str] | None,
            "sql": str | None,
            "rationale": str | None,
            "verified": bool | None,
            "traces": list[dict]
        }
        plus "questions" when the query is reported as ambiguous.

        When ``clarify_answers`` is supplied, detected ambiguities no longer
        short-circuit the run; the answers are forwarded to the generator.
        """
        traces: List[dict] = []
        details: List[str] = []

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            # Computed inside the try so an unsized result is reported as a
            # detector failure instead of escaping.
            ambiguity_note = (f"Ambiguities found: {len(questions)}"
                              if questions else None)
        except Exception as e:
            # NOTE(review): a detector crash is reported with ambiguous=True;
            # kept as-is for backward compatibility with existing callers.
            return self._result(ambiguous=True, error=True,
                                details=[f"Detector failed: {e}"], traces=[])
        if questions and not clarify_answers:
            return self._result(ambiguous=True, error=False,
                                details=[ambiguity_note], traces=[],
                                questions=questions)

        # --- 2) planner
        r_plan = self._safe_stage(self.planner.run,
                                  user_query=user_query,
                                  schema_preview=schema_preview)
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_plan.errors, traces=traces)

        # --- 3) generator
        r_gen = self._safe_stage(self.generator.run,
                                 user_query=user_query,
                                 schema_preview=schema_preview,
                                 plan_text=self._stage_value(r_plan, "plan"),
                                 clarify_answers=clarify_answers or {})
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_gen.errors, traces=traces)
        sql = self._stage_value(r_gen, "sql")
        rationale = self._stage_value(r_gen, "rationale")

        # --- 4) safety
        r_safe = self._safe_stage(self.safety.check, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_safe.errors, traces=traces)

        # --- 5) executor (falls back to the checked SQL when the safety
        # stage does not hand back its own "sql" entry)
        r_exec = self._safe_stage(self.executor.run,
                                  sql=self._stage_value(r_safe, "sql", sql))
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.errors or [])

        # --- 6) verifier
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
        traces.extend(self._trace_list(r_ver))
        verified = bool(r_ver.ok)
        if not verified:
            # Keep the verifier's feedback: it is what the repair stage needs
            # and what the caller should see if repair does not succeed.
            details.extend(r_ver.errors or [])

        # --- 7) repair loop if verification failed
        if not verified:
            for _ in range(self.MAX_REPAIR_ATTEMPTS):
                error_msg = "; ".join(str(d) for d in (details or ["unknown"]))
                r_fix = self._safe_stage(self.repair.run,
                                         sql=sql,
                                         error_msg=error_msg,
                                         schema_preview=schema_preview)
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    details.extend(r_fix.errors or [])
                    break
                repaired_sql = self._stage_value(r_fix, "sql")
                if not repaired_sql:
                    # Nothing to retry with; keep the last known SQL.
                    details.append("Repair produced no SQL")
                    break
                sql = repaired_sql

                r_safe = self._safe_stage(self.safety.check, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.errors or [])
                    continue

                r_exec = self._safe_stage(
                    self.executor.run,
                    sql=self._stage_value(r_safe, "sql", sql))
                traces.extend(self._trace_list(r_exec))
                if not r_exec.ok:
                    details.extend(r_exec.errors or [])
                    continue

                r_ver = self._safe_stage(self.verifier.run,
                                         sql=sql, exec_result=r_exec)
                traces.extend(self._trace_list(r_ver))
                verified = bool(r_ver.ok)
                if verified:
                    break
                details.extend(r_ver.errors or [])

        # --- Final result dict
        return self._result(ambiguous=False,
                            error=bool(details) and not verified,
                            details=details or None,
                            sql=sql,
                            rationale=rationale,
                            verified=verified,
                            traces=traces)
|