Melika Kheirieh commited on
Commit
713d3ca
·
1 Parent(s): 105e019

Fix some typo

Browse files
app/main.py CHANGED
@@ -1,9 +1,9 @@
1
  from dotenv import load_dotenv
 
 
2
 
3
  load_dotenv()
4
 
5
- from fastapi import FastAPI
6
- from app.routers import nl2sql
7
 
8
  app = FastAPI(
9
  title="NL2SQL Copilot Prototype",
 
1
  from dotenv import load_dotenv
2
+ from fastapi import FastAPI
3
+ from app.routers import nl2sql
4
 
5
  load_dotenv()
6
 
 
 
7
 
8
  app = FastAPI(
9
  title="NL2SQL Copilot Prototype",
benchmarks/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .spider_loader import load_spider_sqlite, open_readonly_connection
2
 
3
- __all__ = ["load_spider_sqlite"]
 
1
  from .spider_loader import load_spider_sqlite, open_readonly_connection
2
 
3
+ __all__ = ["load_spider_sqlite", "open_readonly_connection"]
nl2sql/pipeline.py CHANGED
@@ -9,7 +9,7 @@ from nl2sql.safety import Safety
9
  from nl2sql.executor import Executor
10
  from nl2sql.verifier import Verifier
11
  from nl2sql.repair import Repair
12
-
13
 
14
  class Pipeline:
15
  """
@@ -24,17 +24,17 @@ class Pipeline:
24
  planner: Planner,
25
  generator: Generator,
26
  safety: Safety,
27
- executor: Executor,
28
- verifier: Verifier,
29
- repair: Repair,
30
  ):
31
  self.detector = detector
32
  self.planner = planner
33
  self.generator = generator
34
  self.safety = safety
35
- self.executor = executor
36
- self.verifier = verifier
37
- self.repair = repair
38
 
39
  # ------------------------------------------------------------
40
  def _trace_list(self, *stages: StageResult) -> List[dict]:
@@ -59,7 +59,7 @@ class Pipeline:
59
  return StageResult(ok=True, data=r, trace=None)
60
  except Exception as e:
61
  tb = traceback.format_exc()
62
- return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
63
 
64
  # ------------------------------------------------------------
65
  def run(
@@ -113,7 +113,7 @@ class Pipeline:
113
  return {
114
  "ambiguous": False,
115
  "error": True,
116
- "details": r_plan.errors,
117
  "traces": traces,
118
  }
119
 
@@ -143,7 +143,7 @@ class Pipeline:
143
  return {
144
  "ambiguous": False,
145
  "error": True,
146
- "details": r_safe.errors,
147
  "traces": traces,
148
  }
149
 
@@ -151,7 +151,7 @@ class Pipeline:
151
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
152
  traces.extend(self._trace_list(r_exec))
153
  if not r_exec.ok:
154
- details.extend(r_exec.errors or [])
155
 
156
  # --- 6) verifier
157
  r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
@@ -174,12 +174,12 @@ class Pipeline:
174
  r_safe = self._safe_stage(self.safety.check, sql=sql)
175
  traces.extend(self._trace_list(r_safe))
176
  if not r_safe.ok:
177
- details.extend(r_safe.errors or [])
178
  continue
179
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
180
  traces.extend(self._trace_list(r_exec))
181
  if not r_exec.ok:
182
- details.extend(r_exec.errors or [])
183
  continue
184
  r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
185
  traces.extend(self._trace_list(r_ver))
 
9
  from nl2sql.executor import Executor
10
  from nl2sql.verifier import Verifier
11
  from nl2sql.repair import Repair
12
+ from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
13
 
14
  class Pipeline:
15
  """
 
24
  planner: Planner,
25
  generator: Generator,
26
  safety: Safety,
27
+ executor: Optional[Executor] = None,
28
+ verifier: Optional[Verifier] = None ,
29
+ repair: Optional[Repair] = None,
30
  ):
31
  self.detector = detector
32
  self.planner = planner
33
  self.generator = generator
34
  self.safety = safety
35
+ self.executor = executor or NoOpExecutor()
36
+ self.verifier = verifier or NoOpVerifier()
37
+ self.repair = repair or NoOpRepair()
38
 
39
  # ------------------------------------------------------------
40
  def _trace_list(self, *stages: StageResult) -> List[dict]:
 
59
  return StageResult(ok=True, data=r, trace=None)
60
  except Exception as e:
61
  tb = traceback.format_exc()
62
+ return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
63
 
64
  # ------------------------------------------------------------
65
  def run(
 
113
  return {
114
  "ambiguous": False,
115
  "error": True,
116
+ "details": r_plan.error,
117
  "traces": traces,
118
  }
119
 
 
143
  return {
144
  "ambiguous": False,
145
  "error": True,
146
+ "details": r_safe.error,
147
  "traces": traces,
148
  }
149
 
 
151
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
152
  traces.extend(self._trace_list(r_exec))
153
  if not r_exec.ok:
154
+ details.extend(r_exec.error or [])
155
 
156
  # --- 6) verifier
157
  r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
 
174
  r_safe = self._safe_stage(self.safety.check, sql=sql)
175
  traces.extend(self._trace_list(r_safe))
176
  if not r_safe.ok:
177
+ details.extend(r_safe.error or [])
178
  continue
179
  r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
180
  traces.extend(self._trace_list(r_exec))
181
  if not r_exec.ok:
182
+ details.extend(r_exec.error or [])
183
  continue
184
  r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
185
  traces.extend(self._trace_list(r_ver))
nl2sql/verifier.py CHANGED
@@ -14,7 +14,7 @@ class Verifier:
14
  trace=StageTrace(
15
  stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
16
  ),
17
- error=exec_result.errors,
18
  )
19
 
20
  # Rule 1: check SELECT / GROUP consistency
 
14
  trace=StageTrace(
15
  stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
16
  ),
17
+ error=exec_result.error,
18
  )
19
 
20
  # Rule 1: check SELECT / GROUP consistency