Spaces:
Running
Running
| from __future__ import annotations | |
| import traceback | |
| from typing import Dict, Any, Optional, List | |
| from nl2sql.types import StageResult | |
| from nl2sql.ambiguity_detector import AmbiguityDetector | |
| from nl2sql.planner import Planner | |
| from nl2sql.generator import Generator | |
| from nl2sql.safety import Safety | |
| from nl2sql.executor import Executor | |
| from nl2sql.verifier import Verifier | |
| from nl2sql.repair import Repair | |
class Pipeline:
    """
    NL2SQL Copilot pipeline with guaranteed dict output.

    Stage order: ambiguity detection -> planner -> generator -> safety ->
    executor -> verifier, then up to ``max_repair_attempts`` rounds of
    repair -> safety -> executor -> verifier while verification fails.

    Every stage after detection is invoked through ``_safe_stage``, so a stage
    that raises becomes a failed ``StageResult`` instead of an exception, and
    ``run`` always returns a plain dict carrying every documented key.
    """

    # Default number of repair rounds attempted after a failed verification.
    max_repair_attempts: int = 2

    def __init__(self, *,
                 detector: AmbiguityDetector,
                 planner: Planner,
                 generator: Generator,
                 safety: Safety,
                 executor: Executor,
                 verifier: Verifier,
                 repair: Repair,
                 max_repair_attempts: int = 2):
        """Wire the pipeline stages.

        Args:
            detector: provides ``detect(user_query, schema_preview)``.
            planner, generator, executor, verifier, repair: each provides ``run(**kwargs)``.
            safety: provides ``check(sql=...)``.
            max_repair_attempts: upper bound on repair rounds (0 disables repair).
        """
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor
        self.verifier = verifier
        self.repair = repair
        self.max_repair_attempts = max_repair_attempts

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        """Collect a dict copy of each stage's ``trace``, skipping falsy stages/traces."""
        collected: List[dict] = []
        for stage in stages:
            if not stage:
                continue
            trace = getattr(stage, "trace", None)
            if not trace:
                continue
            # Copy so later mutation of the trace object cannot alter the
            # already-emitted result; accept traces that are already dicts.
            collected.append(dict(trace) if isinstance(trace, dict) else dict(trace.__dict__))
        return collected

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult.

        A non-``StageResult`` return value is wrapped as ``ok=True`` with the
        value in ``data``; an exception becomes ``ok=False`` with the message
        and formatted traceback in ``errors``.
        """
        try:
            r = fn(**kwargs)
            if isinstance(r, StageResult):
                return r
            # not ideal, but wrap it
            return StageResult(ok=True, data=r, trace=None)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])

    # ------------------------------------------------------------
    @staticmethod
    def _as_dict(data: Any) -> Any:
        """Return *data* when it supports ``.get``; otherwise an empty dict.

        Guards the ``data.get(...)`` look-ups in ``run``: a stage reporting
        ``ok=True`` with ``data=None`` (or a non-mapping value wrapped by
        ``_safe_stage``) would otherwise raise outside the safety net.
        """
        return data if hasattr(data, "get") else {}

    @staticmethod
    def _result(*, ambiguous: bool, error: bool,
                details: Optional[List[str]], traces: List[dict],
                sql: Optional[str] = None, rationale: Optional[str] = None,
                verified: Optional[bool] = None, **extra: Any) -> Dict[str, Any]:
        """Build a result dict that always contains every documented key."""
        out: Dict[str, Any] = {
            "ambiguous": ambiguous,
            "error": error,
            "details": details,
            "sql": sql,
            "rationale": rationale,
            "verified": verified,
            "traces": traces,
        }
        out.update(extra)
        return out

    # ------------------------------------------------------------
    def _execute(self, r_safe: StageResult, sql: Optional[str],
                 traces: List[dict], details: List[str]) -> StageResult:
        """Execute the safety-approved SQL; record its trace and any errors.

        Runs ``r_safe.data["sql"]`` when present, else the SQL that was checked.
        """
        approved_sql = self._as_dict(r_safe.data).get("sql", sql)
        r_exec = self._safe_stage(self.executor.run, sql=approved_sql)
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.errors or [])
        return r_exec

    def _verify(self, sql: Optional[str], r_exec: StageResult,
                traces: List[dict], details: List[str]) -> bool:
        """Run the verifier; record its trace and, on failure, its errors."""
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
        traces.extend(self._trace_list(r_ver))
        if not r_ver.ok:
            # Without this, repair only ever saw "unknown" when the SQL
            # executed fine but was rejected by the verifier.
            details.extend(r_ver.errors or [])
        return bool(r_ver.ok)

    # ------------------------------------------------------------
    def run(self, *, user_query: str, schema_preview: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Run the pipeline for one natural-language question.

        Args:
            user_query: the user's question.
            schema_preview: schema text passed to detector, planner, generator and repair.
            clarify_answers: forwarded to the generator (``{}`` when None).

        Always returns:
            {
                "ambiguous": bool,
                "error": bool,
                "details": list[str] | None,
                "sql": str | None,
                "rationale": str | None,
                "verified": bool | None,
                "traces": list[dict]
            }
        When the detector reports ambiguities, a "questions" key is added.
        On the normal path "error" is True when at least one stage reported
        an error and the final SQL is not verified.
        """
        traces: List[dict] = []
        details: List[str] = []

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            if questions:
                return self._result(
                    ambiguous=True,
                    error=False,
                    details=[f"Ambiguities found: {len(questions)}"],
                    traces=[],
                    questions=questions,
                )
        except Exception as e:
            return self._result(ambiguous=True, error=True,
                                details=[f"Detector failed: {e}"], traces=[])

        # --- 2) planner
        r_plan = self._safe_stage(self.planner.run,
                                  user_query=user_query,
                                  schema_preview=schema_preview)
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_plan.errors, traces=traces)

        # --- 3) generator
        r_gen = self._safe_stage(self.generator.run,
                                 user_query=user_query,
                                 schema_preview=schema_preview,
                                 plan_text=self._as_dict(r_plan.data).get("plan"),
                                 clarify_answers=clarify_answers or {})
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return self._result(ambiguous=False, error=True,
                                details=r_gen.errors, traces=traces)
        gen_data = self._as_dict(r_gen.data)
        sql = gen_data.get("sql")
        rationale = gen_data.get("rationale")

        # --- 4) safety (rejected SQL is not surfaced in the result)
        r_safe = self._safe_stage(self.safety.check, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return self._result(ambiguous=False, error=True, details=r_safe.errors,
                                traces=traces, rationale=rationale)

        # --- 5) executor; on the first pass the verifier runs even if this fails
        r_exec = self._execute(r_safe, sql, traces, details)

        # --- 6) verifier
        verified = self._verify(sql, r_exec, traces, details)

        # --- 7) repair loop if verification failed
        if not verified:
            for _attempt in range(self.max_repair_attempts):
                r_fix = self._safe_stage(self.repair.run,
                                         sql=sql,
                                         error_msg="; ".join(details or ["unknown"]),
                                         schema_preview=schema_preview)
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    # Record why repair gave up instead of dropping it silently.
                    details.extend(r_fix.errors or [])
                    break
                # NOTE(review): `sql` is replaced before the safety check, so if
                # the loop ends on a candidate the safety stage rejected, that
                # rejected SQL is what gets returned.
                sql = self._as_dict(r_fix.data).get("sql")
                r_safe = self._safe_stage(self.safety.check, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.errors or [])
                    continue
                r_exec = self._execute(r_safe, sql, traces, details)
                if not r_exec.ok:
                    continue
                verified = self._verify(sql, r_exec, traces, details)
                if verified:
                    break

        # --- Final result dict
        return self._result(
            ambiguous=False,
            error=bool(details) and not verified,
            details=details or None,
            sql=sql,
            rationale=rationale,
            verified=verified,
            traces=traces,
        )