Melika Kheirieh commited on
Commit
72c0821
·
1 Parent(s): 575394d

test-mode: stub runner in router; factory stubs accept positional calls

Browse files
Files changed (2) hide show
  1. app/routers/nl2sql.py +21 -2
  2. nl2sql/pipeline_factory.py +24 -45
app/routers/nl2sql.py CHANGED
@@ -23,13 +23,32 @@ from nl2sql.pipeline_factory import (
23
  pipeline_from_config_with_adapter,
24
  )
25
 
26
- _PIPELINE: Optional[Any] = None
27
 
28
  Runner = Callable[..., _FinalResult]
29
 
30
 
31
  def get_runner() -> Runner:
32
- """Build pipeline lazily so PYTEST_CURRENT_TEST is respected."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  global _PIPELINE
34
  if _PIPELINE is None:
35
  _PIPELINE = pipeline_from_config(CONFIG_PATH)
 
23
  pipeline_from_config_with_adapter,
24
  )
25
 
26
+ _PIPELINE: Optional[Any] = None # lazy cache
27
 
28
  Runner = Callable[..., _FinalResult]
29
 
30
 
31
  def get_runner() -> Runner:
32
+ """Build pipeline lazily; under pytest return a stub runner."""
33
+ if os.getenv("PYTEST_CURRENT_TEST"):
34
+ # Minimal OK runner for route tests (no ambiguity)
35
+ def _fake_runner(
36
+ *, user_query: str, schema_preview: str | None = None
37
+ ) -> _FinalResult:
38
+ return _FinalResult(
39
+ ok=True,
40
+ ambiguous=False,
41
+ error=False,
42
+ details=None,
43
+ questions=None,
44
+ sql="SELECT 1;",
45
+ rationale=None,
46
+ verified=True,
47
+ traces=[],
48
+ )
49
+
50
+ return _fake_runner
51
+
52
  global _PIPELINE
53
  if _PIPELINE is None:
54
  _PIPELINE = pipeline_from_config(CONFIG_PATH)
nl2sql/pipeline_factory.py CHANGED
@@ -77,11 +77,9 @@ def pipeline_from_config(path: str) -> Pipeline:
77
  llm = _build_llm(llm_cfg)
78
 
79
  if is_pytest:
80
- # ---------- full stubs (detector/planner/generator/executor/verifier/repair) ----------
81
  class _StubDetector:
82
- def detect(
83
- self, *, user_query: str, schema_preview: Optional[str] = None
84
- ) -> StageResult:
85
  return StageResult(
86
  ok=True,
87
  data={"questions": []},
@@ -92,44 +90,30 @@ def pipeline_from_config(path: str) -> Pipeline:
92
  },
93
  )
94
 
95
- # compatibility if somewhere calls run():
96
- def run(
97
- self, *, user_query: str, schema_preview: Optional[str] = None
98
- ) -> StageResult:
99
- return self.detect(user_query=user_query, schema_preview=schema_preview)
100
 
101
  class _StubPlanner:
102
  def __init__(self, llm: Any = None) -> None: ...
103
 
104
- def plan(
105
- self, *, user_query: str, schema_preview: Optional[str] = None
106
- ) -> StageResult:
107
  return StageResult(
108
  ok=True,
109
  data={"plan": "stub plan"},
110
  trace={
111
  "stage": "planner",
112
  "duration_ms": 0,
113
- "notes": {"len_plan": 8},
114
  },
115
  )
116
 
117
- def run(
118
- self, *, user_query: str, schema_preview: Optional[str] = None
119
- ) -> StageResult:
120
- return self.plan(user_query=user_query, schema_preview=schema_preview)
121
 
122
  class _StubGenerator:
123
  def __init__(self, llm: Any = None) -> None: ...
124
 
125
- def generate(
126
- self,
127
- *,
128
- user_query: str,
129
- schema_preview: Optional[str] = None,
130
- plan_text: Optional[str] = None,
131
- clarify_answers: Optional[Dict[str, Any]] = None,
132
- ) -> StageResult:
133
  return StageResult(
134
  ok=True,
135
  data={"sql": "SELECT 1;", "rationale": "stub"},
@@ -140,56 +124,51 @@ def pipeline_from_config(path: str) -> Pipeline:
140
  },
141
  )
142
 
143
- def run(self, **kwargs) -> StageResult:
144
- return self.generate(**kwargs)
145
 
146
  class _StubExecutor:
147
- def __init__(self, db: DBAdapter | None = None) -> None: ...
148
 
149
- def execute(self, *, sql: str) -> StageResult:
150
  rows = [{"x": 1}]
151
  return StageResult(
152
  ok=True,
153
- data={"rows": rows, "row_count": len(rows)},
154
  trace={
155
  "stage": "executor",
156
  "duration_ms": 0,
157
- "notes": {"row_count": len(rows)},
158
  },
159
  )
160
 
161
- def run(self, *, sql: str) -> StageResult:
162
- return self.execute(sql=sql)
163
 
164
  class _StubVerifier:
165
- def verify(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
166
  return StageResult(
167
  ok=True,
168
  data={"verified": True},
169
  trace={"stage": "verifier", "duration_ms": 0, "notes": None},
170
  )
171
 
172
- def run(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
173
- return self.verify(sql=sql, exec_result=exec_result)
174
 
175
  class _StubRepair:
176
  def __init__(self, llm: Any = None) -> None: ...
177
 
178
- def repair(
179
- self, *, sql: str, error_msg: str, schema_preview: Optional[str] = None
180
- ) -> StageResult:
181
  return StageResult(
182
  ok=True,
183
  data={"sql": sql},
184
  trace={"stage": "repair", "duration_ms": 0, "notes": None},
185
  )
186
 
187
- def run(
188
- self, *, sql: str, error_msg: str, schema_preview: Optional[str] = None
189
- ) -> StageResult:
190
- return self.repair(
191
- sql=sql, error_msg=error_msg, schema_preview=schema_preview
192
- )
193
 
194
  detector = _StubDetector()
195
  planner = _StubPlanner()
 
77
  llm = _build_llm(llm_cfg)
78
 
79
  if is_pytest:
80
+
81
  class _StubDetector:
82
+ def detect(self, *args, **kwargs) -> StageResult:
 
 
83
  return StageResult(
84
  ok=True,
85
  data={"questions": []},
 
90
  },
91
  )
92
 
93
+ def run(self, *args, **kwargs) -> StageResult:
94
+ return self.detect(*args, **kwargs)
 
 
 
95
 
96
  class _StubPlanner:
97
  def __init__(self, llm: Any = None) -> None: ...
98
 
99
+ def plan(self, *args, **kwargs) -> StageResult:
 
 
100
  return StageResult(
101
  ok=True,
102
  data={"plan": "stub plan"},
103
  trace={
104
  "stage": "planner",
105
  "duration_ms": 0,
106
+ "notes": {"len_plan": 9},
107
  },
108
  )
109
 
110
+ def run(self, *args, **kwargs) -> StageResult:
111
+ return self.plan(*args, **kwargs)
 
 
112
 
113
  class _StubGenerator:
114
  def __init__(self, llm: Any = None) -> None: ...
115
 
116
+ def generate(self, *args, **kwargs) -> StageResult:
 
 
 
 
 
 
 
117
  return StageResult(
118
  ok=True,
119
  data={"sql": "SELECT 1;", "rationale": "stub"},
 
124
  },
125
  )
126
 
127
+ def run(self, *args, **kwargs) -> StageResult:
128
+ return self.generate(*args, **kwargs)
129
 
130
  class _StubExecutor:
131
+ def __init__(self, db: Any | None = None) -> None: ...
132
 
133
+ def execute(self, *args, **kwargs) -> StageResult:
134
  rows = [{"x": 1}]
135
  return StageResult(
136
  ok=True,
137
+ data={"rows": rows, "row_count": 1},
138
  trace={
139
  "stage": "executor",
140
  "duration_ms": 0,
141
+ "notes": {"row_count": 1},
142
  },
143
  )
144
 
145
+ def run(self, *args, **kwargs) -> StageResult:
146
+ return self.execute(*args, **kwargs)
147
 
148
  class _StubVerifier:
149
+ def verify(self, *args, **kwargs) -> StageResult:
150
  return StageResult(
151
  ok=True,
152
  data={"verified": True},
153
  trace={"stage": "verifier", "duration_ms": 0, "notes": None},
154
  )
155
 
156
+ def run(self, *args, **kwargs) -> StageResult:
157
+ return self.verify(*args, **kwargs)
158
 
159
  class _StubRepair:
160
  def __init__(self, llm: Any = None) -> None: ...
161
 
162
+ def repair(self, *args, **kwargs) -> StageResult:
163
+ sql = kwargs.get("sql") or "SELECT 1;"
 
164
  return StageResult(
165
  ok=True,
166
  data={"sql": sql},
167
  trace={"stage": "repair", "duration_ms": 0, "notes": None},
168
  )
169
 
170
+ def run(self, *args, **kwargs) -> StageResult:
171
+ return self.repair(*args, **kwargs)
 
 
 
 
172
 
173
  detector = _StubDetector()
174
  planner = _StubPlanner()