Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
713d3ca
1
Parent(s):
105e019
Fix some typo
Browse files- app/main.py +2 -2
- benchmarks/__init__.py +1 -1
- nl2sql/pipeline.py +13 -13
- nl2sql/verifier.py +1 -1
app/main.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
| 2 |
|
| 3 |
load_dotenv()
|
| 4 |
|
| 5 |
-
from fastapi import FastAPI
|
| 6 |
-
from app.routers import nl2sql
|
| 7 |
|
| 8 |
app = FastAPI(
|
| 9 |
title="NL2SQL Copilot Prototype",
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
+
from fastapi import FastAPI
|
| 3 |
+
from app.routers import nl2sql
|
| 4 |
|
| 5 |
load_dotenv()
|
| 6 |
|
|
|
|
|
|
|
| 7 |
|
| 8 |
app = FastAPI(
|
| 9 |
title="NL2SQL Copilot Prototype",
|
benchmarks/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
from .spider_loader import load_spider_sqlite, open_readonly_connection
|
| 2 |
|
| 3 |
-
__all__ = ["load_spider_sqlite"]
|
|
|
|
| 1 |
from .spider_loader import load_spider_sqlite, open_readonly_connection
|
| 2 |
|
| 3 |
+
__all__ = ["load_spider_sqlite", "open_readonly_connection"]
|
nl2sql/pipeline.py
CHANGED
|
@@ -9,7 +9,7 @@ from nl2sql.safety import Safety
|
|
| 9 |
from nl2sql.executor import Executor
|
| 10 |
from nl2sql.verifier import Verifier
|
| 11 |
from nl2sql.repair import Repair
|
| 12 |
-
|
| 13 |
|
| 14 |
class Pipeline:
|
| 15 |
"""
|
|
@@ -24,17 +24,17 @@ class Pipeline:
|
|
| 24 |
planner: Planner,
|
| 25 |
generator: Generator,
|
| 26 |
safety: Safety,
|
| 27 |
-
executor: Executor,
|
| 28 |
-
verifier: Verifier,
|
| 29 |
-
repair: Repair,
|
| 30 |
):
|
| 31 |
self.detector = detector
|
| 32 |
self.planner = planner
|
| 33 |
self.generator = generator
|
| 34 |
self.safety = safety
|
| 35 |
-
self.executor = executor
|
| 36 |
-
self.verifier = verifier
|
| 37 |
-
self.repair = repair
|
| 38 |
|
| 39 |
# ------------------------------------------------------------
|
| 40 |
def _trace_list(self, *stages: StageResult) -> List[dict]:
|
|
@@ -59,7 +59,7 @@ class Pipeline:
|
|
| 59 |
return StageResult(ok=True, data=r, trace=None)
|
| 60 |
except Exception as e:
|
| 61 |
tb = traceback.format_exc()
|
| 62 |
-
return StageResult(ok=False, data=None, trace=None,
|
| 63 |
|
| 64 |
# ------------------------------------------------------------
|
| 65 |
def run(
|
|
@@ -113,7 +113,7 @@ class Pipeline:
|
|
| 113 |
return {
|
| 114 |
"ambiguous": False,
|
| 115 |
"error": True,
|
| 116 |
-
"details": r_plan.
|
| 117 |
"traces": traces,
|
| 118 |
}
|
| 119 |
|
|
@@ -143,7 +143,7 @@ class Pipeline:
|
|
| 143 |
return {
|
| 144 |
"ambiguous": False,
|
| 145 |
"error": True,
|
| 146 |
-
"details": r_safe.
|
| 147 |
"traces": traces,
|
| 148 |
}
|
| 149 |
|
|
@@ -151,7 +151,7 @@ class Pipeline:
|
|
| 151 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
| 152 |
traces.extend(self._trace_list(r_exec))
|
| 153 |
if not r_exec.ok:
|
| 154 |
-
details.extend(r_exec.
|
| 155 |
|
| 156 |
# --- 6) verifier
|
| 157 |
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
|
|
@@ -174,12 +174,12 @@ class Pipeline:
|
|
| 174 |
r_safe = self._safe_stage(self.safety.check, sql=sql)
|
| 175 |
traces.extend(self._trace_list(r_safe))
|
| 176 |
if not r_safe.ok:
|
| 177 |
-
details.extend(r_safe.
|
| 178 |
continue
|
| 179 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
| 180 |
traces.extend(self._trace_list(r_exec))
|
| 181 |
if not r_exec.ok:
|
| 182 |
-
details.extend(r_exec.
|
| 183 |
continue
|
| 184 |
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
|
| 185 |
traces.extend(self._trace_list(r_ver))
|
|
|
|
| 9 |
from nl2sql.executor import Executor
|
| 10 |
from nl2sql.verifier import Verifier
|
| 11 |
from nl2sql.repair import Repair
|
| 12 |
+
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 13 |
|
| 14 |
class Pipeline:
|
| 15 |
"""
|
|
|
|
| 24 |
planner: Planner,
|
| 25 |
generator: Generator,
|
| 26 |
safety: Safety,
|
| 27 |
+
executor: Optional[Executor] = None,
|
| 28 |
+
verifier: Optional[Verifier] = None ,
|
| 29 |
+
repair: Optional[Repair] = None,
|
| 30 |
):
|
| 31 |
self.detector = detector
|
| 32 |
self.planner = planner
|
| 33 |
self.generator = generator
|
| 34 |
self.safety = safety
|
| 35 |
+
self.executor = executor or NoOpExecutor()
|
| 36 |
+
self.verifier = verifier or NoOpVerifier()
|
| 37 |
+
self.repair = repair or NoOpRepair()
|
| 38 |
|
| 39 |
# ------------------------------------------------------------
|
| 40 |
def _trace_list(self, *stages: StageResult) -> List[dict]:
|
|
|
|
| 59 |
return StageResult(ok=True, data=r, trace=None)
|
| 60 |
except Exception as e:
|
| 61 |
tb = traceback.format_exc()
|
| 62 |
+
return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
|
| 63 |
|
| 64 |
# ------------------------------------------------------------
|
| 65 |
def run(
|
|
|
|
| 113 |
return {
|
| 114 |
"ambiguous": False,
|
| 115 |
"error": True,
|
| 116 |
+
"details": r_plan.error,
|
| 117 |
"traces": traces,
|
| 118 |
}
|
| 119 |
|
|
|
|
| 143 |
return {
|
| 144 |
"ambiguous": False,
|
| 145 |
"error": True,
|
| 146 |
+
"details": r_safe.error,
|
| 147 |
"traces": traces,
|
| 148 |
}
|
| 149 |
|
|
|
|
| 151 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
| 152 |
traces.extend(self._trace_list(r_exec))
|
| 153 |
if not r_exec.ok:
|
| 154 |
+
details.extend(r_exec.error or [])
|
| 155 |
|
| 156 |
# --- 6) verifier
|
| 157 |
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
|
|
|
|
| 174 |
r_safe = self._safe_stage(self.safety.check, sql=sql)
|
| 175 |
traces.extend(self._trace_list(r_safe))
|
| 176 |
if not r_safe.ok:
|
| 177 |
+
details.extend(r_safe.error or [])
|
| 178 |
continue
|
| 179 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
| 180 |
traces.extend(self._trace_list(r_exec))
|
| 181 |
if not r_exec.ok:
|
| 182 |
+
details.extend(r_exec.error or [])
|
| 183 |
continue
|
| 184 |
r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
|
| 185 |
traces.extend(self._trace_list(r_ver))
|
nl2sql/verifier.py
CHANGED
|
@@ -14,7 +14,7 @@ class Verifier:
|
|
| 14 |
trace=StageTrace(
|
| 15 |
stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
|
| 16 |
),
|
| 17 |
-
error=exec_result.
|
| 18 |
)
|
| 19 |
|
| 20 |
# Rule 1: check SELECT / GROUP consistency
|
|
|
|
| 14 |
trace=StageTrace(
|
| 15 |
stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
|
| 16 |
),
|
| 17 |
+
error=exec_result.error,
|
| 18 |
)
|
| 19 |
|
| 20 |
# Rule 1: check SELECT / GROUP consistency
|