Melika Kheirieh commited on
Commit
a38f7c7
·
1 Parent(s): f89e294

feat(pipeline): enable SQL-only repair + log-only repair for early stages with full traces/metrics

Browse files
Files changed (1) hide show
  1. nl2sql/pipeline.py +147 -6
nl2sql/pipeline.py CHANGED
@@ -29,6 +29,9 @@ class FinalResult:
29
  traces: List[dict]
30
 
31
 
 
 
 
32
  class Pipeline:
33
  """
34
  NL2SQL Copilot pipeline:
@@ -118,7 +121,119 @@ class Pipeline:
118
  tb = traceback.format_exc()
119
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
120
 
121
- # ------------------------------ run ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def run(
123
  self,
124
  *,
@@ -173,8 +288,14 @@ class Pipeline:
173
 
174
  # --- 2) planner ---
175
  t0 = time.perf_counter()
176
- r_plan = self._safe_stage(
177
- self.planner.run, user_query=user_query, schema_preview=schema_preview
 
 
 
 
 
 
178
  )
179
  dt = (time.perf_counter() - t0) * 1000.0
180
  stage_duration_ms.labels("planner").observe(dt)
@@ -197,12 +318,16 @@ class Pipeline:
197
 
198
  # --- 3) generator ---
199
  t0 = time.perf_counter()
200
- r_gen = self._safe_stage(
 
201
  self.generator.run,
 
 
202
  user_query=user_query,
203
  schema_preview=schema_preview,
204
  plan_text=(r_plan.data or {}).get("plan"),
205
  clarify_answers=clarify_answers,
 
206
  )
207
  dt = (time.perf_counter() - t0) * 1000.0
208
  stage_duration_ms.labels("generator").observe(dt)
@@ -246,7 +371,15 @@ class Pipeline:
246
 
247
  # --- 4) safety ---
248
  t0 = time.perf_counter()
249
- r_safe = self._safe_stage(self.safety.run, sql=sql)
 
 
 
 
 
 
 
 
250
  dt = (time.perf_counter() - t0) * 1000.0
251
  stage_duration_ms.labels("safety").observe(dt)
252
  traces.extend(self._trace_list(r_safe))
@@ -271,7 +404,15 @@ class Pipeline:
271
 
272
  # --- 5) executor ---
273
  t0 = time.perf_counter()
274
- r_exec = self._safe_stage(self.executor.run, sql=sql)
 
 
 
 
 
 
 
 
275
  dt = (time.perf_counter() - t0) * 1000.0
276
  stage_duration_ms.labels("executor").observe(dt)
277
  traces.extend(self._trace_list(r_exec))
 
29
  traces: List[dict]
30
 
31
 
32
+ SQL_REPAIR_STAGES = {"safety", "executor", "verifier"}
33
+
34
+
35
  class Pipeline:
36
  """
37
  NL2SQL Copilot pipeline:
 
121
  tb = traceback.format_exc()
122
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
123
 
124
+ def _run_with_repair(
125
+ self,
126
+ stage_name: str,
127
+ fn,
128
+ *,
129
+ repair_input_builder,
130
+ max_attempts: int = 1,
131
+ traces: list,
132
+ **kwargs,
133
+ ) -> StageResult:
134
+ """
135
+ Run a stage with per-stage repair + full observability integration.
136
+ SQL-only repair occurs for safety/executor/verifier.
137
+ Planner/Generator get log-only repair (trace only, no effect).
138
+ """
139
+ attempt = 0
140
+
141
+ while True:
142
+ # --- 1) Run stage normally ---
143
+ t0 = time.perf_counter()
144
+ r = self._safe_stage(fn, **kwargs)
145
+ dt = (time.perf_counter() - t0) * 1000.0
146
+
147
+ stage_duration_ms.labels(stage_name).observe(dt)
148
+
149
+ # attach stage trace
150
+ if getattr(r, "trace", None):
151
+ traces.append(r.trace.__dict__)
152
+ else:
153
+ traces.append(
154
+ {
155
+ "stage": stage_name,
156
+ "duration_ms": dt,
157
+ "summary": "ok" if r.ok else "failed",
158
+ "notes": {},
159
+ }
160
+ )
161
+
162
+ if r.ok:
163
+ return r
164
+
165
+ # stage failed → check repair availability
166
+ attempt += 1
167
+ if attempt > max_attempts:
168
+ return r
169
+
170
+ # --- 2) Build repair input ---
171
+ repair_args = repair_input_builder(r, kwargs)
172
+
173
+ # --- 3) Run repair (always logged) ---
174
+ repair_attempts_total.labels(outcome="attempt").inc()
175
+ t1 = time.perf_counter()
176
+ r_fix = self._safe_stage(self.repair.run, **repair_args)
177
+ dt_fix = (time.perf_counter() - t1) * 1000.0
178
+
179
+ stage_duration_ms.labels("repair").observe(dt_fix)
180
+
181
+ if getattr(r_fix, "trace", None):
182
+ traces.append(r_fix.trace.__dict__)
183
+ else:
184
+ traces.append(
185
+ {
186
+ "stage": "repair",
187
+ "duration_ms": dt_fix,
188
+ "summary": "ok" if r_fix.ok else "failed",
189
+ "notes": {"stage": stage_name},
190
+ }
191
+ )
192
+
193
+ if not r_fix.ok:
194
+ repair_attempts_total.labels(outcome="failed").inc()
195
+ return r # repair itself failed → stop here
196
+
197
+ # --- 4) Only inject SQL if the stage is an SQL-producing stage ---
198
+ if stage_name in self.SQL_REPAIR_STAGES:
199
+ if "sql" in repair_args and "sql" in kwargs:
200
+ kwargs["sql"] = (r_fix.data or {}).get("sql", kwargs["sql"])
201
+
202
+ # important: success metric must reflect if repair was applied meaningfully
203
+ if stage_name in self.SQL_REPAIR_STAGES:
204
+ repair_attempts_total.labels(outcome="success").inc()
205
+ else:
206
+ # log-only mode counts as a success-attempt but not semantic success
207
+ repair_attempts_total.labels(outcome="success").inc()
208
+
209
+ # for SQL stages, we re-run the stage again with modified kwargs
210
+ # for log-only stages, this simply loops and stage is re-run unchanged
211
+ # (which is correct)
212
+
213
+ @staticmethod
214
+ def _planner_repair_input_builder(stage_result, kwargs):
215
+ return {
216
+ "sql": "",
217
+ "error_msg": "; ".join(stage_result.error or ["planner_failed"]),
218
+ "schema_preview": kwargs.get("schema_preview", ""),
219
+ }
220
+
221
+ @staticmethod
222
+ def _generator_repair_input_builder(stage_result, kwargs):
223
+ return {
224
+ "sql": (stage_result.data or {}).get("sql", ""),
225
+ "error_msg": "; ".join(stage_result.error or ["generator_failed"]),
226
+ "schema_preview": kwargs.get("schema_preview", ""),
227
+ }
228
+
229
+ @staticmethod
230
+ def _sql_repair_input_builder(stage_result, kwargs):
231
+ return {
232
+ "sql": kwargs.get("sql", ""),
233
+ "error_msg": "; ".join(stage_result.error or ["stage_failed"]),
234
+ "schema_preview": kwargs.get("schema_preview", ""),
235
+ }
236
+
237
  def run(
238
  self,
239
  *,
 
288
 
289
  # --- 2) planner ---
290
  t0 = time.perf_counter()
291
+ r_plan = self._run_with_repair(
292
+ "planner",
293
+ self.planner.run,
294
+ repair_input_builder=self._planner_repair_input_builder,
295
+ max_attempts=1,
296
+ user_query=user_query,
297
+ traces=traces,
298
+ schema_preview=schema_preview,
299
  )
300
  dt = (time.perf_counter() - t0) * 1000.0
301
  stage_duration_ms.labels("planner").observe(dt)
 
318
 
319
  # --- 3) generator ---
320
  t0 = time.perf_counter()
321
+ r_gen = self._run_with_repair(
322
+ "generator",
323
  self.generator.run,
324
+ repair_input_builder=self._generator_repair_input_builder,
325
+ max_attempts=1,
326
  user_query=user_query,
327
  schema_preview=schema_preview,
328
  plan_text=(r_plan.data or {}).get("plan"),
329
  clarify_answers=clarify_answers,
330
+ traces=traces,
331
  )
332
  dt = (time.perf_counter() - t0) * 1000.0
333
  stage_duration_ms.labels("generator").observe(dt)
 
371
 
372
  # --- 4) safety ---
373
  t0 = time.perf_counter()
374
+ r_safe = self._run_with_repair(
375
+ "safety",
376
+ self.safety.run,
377
+ repair_input_builder=self._sql_repair_input_builder,
378
+ max_attempts=1,
379
+ sql=sql,
380
+ schema_preview=schema_preview,
381
+ traces=traces,
382
+ )
383
  dt = (time.perf_counter() - t0) * 1000.0
384
  stage_duration_ms.labels("safety").observe(dt)
385
  traces.extend(self._trace_list(r_safe))
 
404
 
405
  # --- 5) executor ---
406
  t0 = time.perf_counter()
407
+ r_exec = self._run_with_repair(
408
+ "executor",
409
+ self.executor.run,
410
+ repair_input_builder=self._sql_repair_input_builder,
411
+ max_attempts=1,
412
+ sql=sql,
413
+ traces=traces,
414
+ schema_preview=schema_preview,
415
+ )
416
  dt = (time.perf_counter() - t0) * 1000.0
417
  stage_duration_ms.labels("executor").observe(dt)
418
  traces.extend(self._trace_list(r_exec))