Melika Kheirieh commited on
Commit
79a5f4a
·
1 Parent(s): 8618ece

feat(trace): standardize StageTrace (add summary) and coerce duration_ms to int at API boundary

Browse files
Files changed (3) hide show
  1. app/routers/nl2sql.py +26 -71
  2. nl2sql/pipeline.py +83 -31
  3. nl2sql/types.py +2 -1
app/routers/nl2sql.py CHANGED
@@ -57,16 +57,9 @@ def get_runner() -> Runner:
57
 
58
  def _build_pipeline(adapter) -> Any:
59
  """Thin wrapper for tests to monkeypatch; builds a pipeline bound to adapter."""
60
-
61
  return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
62
 
63
 
64
- #
65
- # # Stable public re-exports
66
- # Pipeline = _Pipeline
67
- # FinalResult = _FinalResult
68
- # __all__ = ["Pipeline", "FinalResult"]
69
-
70
  router = APIRouter(prefix="/nl2sql")
71
 
72
  # -------------------------------
@@ -148,7 +141,6 @@ _load_db_map()
148
  # -------------------------------
149
  # Adapter selection (lazy)
150
  # -------------------------------
151
- # ---------- SELECT ADAPTER ----------
152
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
153
  """
154
  Resolve a DB adapter based on module-level DB_MODE and an optional db_id.
@@ -207,66 +199,8 @@ def _get_llm() -> OpenAIProvider:
207
  return OpenAIProvider()
208
 
209
 
210
- # def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
211
- # """
212
- # Build a fresh Pipeline bound to the given adapter.
213
- # All stateful/external pieces (LLM, executor) are instantiated here (lazy).
214
- # """
215
- # llm = _get_llm()
216
- # detector = AmbiguityDetector()
217
- # planner = Planner(llm=llm)
218
- # generator = Generator(llm=llm)
219
- # safety = Safety()
220
- # executor = Executor(adapter)
221
- # verifier = Verifier()
222
- # repair = Repair(llm=llm)
223
- # return Pipeline(
224
- # detector=detector,
225
- # planner=planner,
226
- # generator=generator,
227
- # safety=safety,
228
- # executor=executor,
229
- # verifier=verifier,
230
- # repair=repair,
231
- # )
232
-
233
-
234
- # -------------------------------
235
- # Dependency-injected runner
236
- # -------------------------------
237
- # class Runner(Protocol):
238
- # def __call__(
239
- # self, *, user_query: str, schema_preview: str | None = None
240
- # ) -> FinalResult: ...
241
- #
242
- #
243
- # def get_runner(request: Request) -> Runner:
244
- # """
245
- # Returns a callable runner. Preferred path in production:
246
- # - app.state.pipeline_runner (if set) -> used (e.g., tests or special wiring)
247
- # - app.state.pipeline -> reuse existing
248
- # - else build default pipeline lazily and cache
249
- # """
250
- # runner: Optional[Runner] = getattr(request.app.state, "pipeline_runner", None) # type: ignore[attr-defined]
251
- # if runner:
252
- # return runner
253
- #
254
- # pipeline: Optional[Pipeline] = getattr(request.app.state, "pipeline", None) # type: ignore[attr-defined]
255
- # if pipeline is None:
256
- # # Build a default pipeline lazily (no side-effect on import)
257
- # adapter = _select_adapter(db_id=None)
258
- # try:
259
- # pipeline = _build_pipeline(adapter)
260
- # request.app.state.pipeline = pipeline # type: ignore[attr-defined]
261
- # except Exception as exc:
262
- # raise HTTPException(
263
- # status_code=500, detail=f"Pipeline unavailable: {exc!s}"
264
- # )
265
- # return pipeline.run # type: ignore[return-value]
266
-
267
-
268
  # -------------------------------
269
- # Helpers (unchanged)
270
  # -------------------------------
271
  def _to_dict(obj: Any) -> Any:
272
  if is_dataclass(obj) and not isinstance(obj, type):
@@ -275,29 +209,50 @@ def _to_dict(obj: Any) -> Any:
275
 
276
 
277
  def _round_trace(t: Any) -> Dict[str, Any]:
278
- """Normalize a trace entry to a dict and coerce duration_ms to int."""
 
 
 
 
 
 
 
279
  if isinstance(t, dict):
280
  stage = t.get("stage", "?")
281
  ms = t.get("duration_ms", 0)
282
  notes = t.get("notes")
283
  cost = t.get("cost_usd")
 
 
 
284
  else:
285
  stage = getattr(t, "stage", "?")
286
  ms = getattr(t, "duration_ms", 0)
287
  notes = getattr(t, "notes", None)
288
  cost = getattr(t, "cost_usd", None)
 
 
 
289
 
 
290
  try:
291
- ms_int = int(ms) if ms is not None else 0
292
  except Exception:
293
  ms_int = 0
294
 
295
- return {
296
  "stage": str(stage) if stage is not None else "?",
297
  "duration_ms": ms_int,
298
  "notes": notes,
299
  "cost_usd": cost,
300
  }
 
 
 
 
 
 
 
301
 
302
 
303
  # -------------------------------
@@ -391,7 +346,7 @@ def nl2sql_handler(
391
  message = "; ".join(result.details or []) or "Unknown error"
392
  raise HTTPException(status_code=400, detail=message)
393
 
394
- # Success path → 200
395
  traces = [_round_trace(t) for t in (result.traces or [])]
396
  return NL2SQLResponse(
397
  ambiguous=False,
 
57
 
58
  def _build_pipeline(adapter) -> Any:
59
  """Thin wrapper for tests to monkeypatch; builds a pipeline bound to adapter."""
 
60
  return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
61
 
62
 
 
 
 
 
 
 
63
  router = APIRouter(prefix="/nl2sql")
64
 
65
  # -------------------------------
 
141
  # -------------------------------
142
  # Adapter selection (lazy)
143
  # -------------------------------
 
144
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
145
  """
146
  Resolve a DB adapter based on module-level DB_MODE and an optional db_id.
 
199
  return OpenAIProvider()
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # -------------------------------
203
+ # Helpers
204
  # -------------------------------
205
  def _to_dict(obj: Any) -> Any:
206
  if is_dataclass(obj) and not isinstance(obj, type):
 
209
 
210
 
211
  def _round_trace(t: Any) -> Dict[str, Any]:
212
+ """
213
+ Normalize a trace entry (dict or StageTrace-like object) for API/UI:
214
+ - stage: str (required)
215
+ - duration_ms: int (rounded)
216
+ - summary: optional (pass-through if exists)
217
+ - notes: optional
218
+ - token_in/out, cost_usd: pass-through if present
219
+ """
220
  if isinstance(t, dict):
221
  stage = t.get("stage", "?")
222
  ms = t.get("duration_ms", 0)
223
  notes = t.get("notes")
224
  cost = t.get("cost_usd")
225
+ summary = t.get("summary")
226
+ token_in = t.get("token_in")
227
+ token_out = t.get("token_out")
228
  else:
229
  stage = getattr(t, "stage", "?")
230
  ms = getattr(t, "duration_ms", 0)
231
  notes = getattr(t, "notes", None)
232
  cost = getattr(t, "cost_usd", None)
233
+ summary = getattr(t, "summary", None)
234
+ token_in = getattr(t, "token_in", None)
235
+ token_out = getattr(t, "token_out", None)
236
 
237
+ # coerce duration to int with rounding
238
  try:
239
+ ms_int = int(round(float(ms))) if ms is not None else 0
240
  except Exception:
241
  ms_int = 0
242
 
243
+ out: Dict[str, Any] = {
244
  "stage": str(stage) if stage is not None else "?",
245
  "duration_ms": ms_int,
246
  "notes": notes,
247
  "cost_usd": cost,
248
  }
249
+ if summary is not None:
250
+ out["summary"] = summary
251
+ if token_in is not None:
252
+ out["token_in"] = token_in
253
+ if token_out is not None:
254
+ out["token_out"] = token_out
255
+ return out
256
 
257
 
258
  # -------------------------------
 
346
  message = "; ".join(result.details or []) or "Unknown error"
347
  raise HTTPException(status_code=400, detail=message)
348
 
349
+ # Success path → 200 (coerce/standardize traces for API)
350
  traces = [_round_trace(t) for t in (result.traces or [])]
351
  return NL2SQLResponse(
352
  ambiguous=False,
nl2sql/pipeline.py CHANGED
@@ -68,6 +68,61 @@ class Pipeline:
68
  traces.append(getattr(t, "__dict__", t))
69
  return traces
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # ------------------------------------------------------------
72
  @staticmethod
73
  def _safe_stage(fn, **kwargs) -> StageResult:
@@ -84,18 +139,6 @@ class Pipeline:
84
  tb = traceback.format_exc()
85
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
86
 
87
- # ------------------------------------------------------------
88
- @staticmethod
89
- def _mk_trace(
90
- stage: str, duration_ms: float, notes: Optional[Dict[str, Any]] = None
91
- ) -> dict:
92
- """Create a normalized trace dict."""
93
- return {
94
- "stage": stage,
95
- "duration_ms": float(duration_ms),
96
- "notes": notes or {},
97
- }
98
-
99
  # ------------------------------------------------------------
100
  def run(
101
  self,
@@ -119,12 +162,14 @@ class Pipeline:
119
  t0 = time.perf_counter()
120
  questions = self.detector.detect(user_query, schema_preview)
121
  t1 = time.perf_counter()
 
122
  traces.append(
123
  self._mk_trace(
124
- "detector",
125
- (t1 - t0) * 1000.0,
126
- {
127
- "ambiguous": bool(questions),
 
128
  "questions_len": len(questions or []),
129
  },
130
  )
@@ -140,11 +185,18 @@ class Pipeline:
140
  sql=None,
141
  rationale=None,
142
  verified=None,
143
- traces=traces,
144
  )
145
  except Exception as e:
146
  # detector crash – mark as error but keep trace so far
147
- traces.append(self._mk_trace("detector", 0.0, {"error": str(e)}))
 
 
 
 
 
 
 
148
  return FinalResult(
149
  ok=False,
150
  ambiguous=True,
@@ -154,7 +206,7 @@ class Pipeline:
154
  sql=None,
155
  rationale=None,
156
  verified=None,
157
- traces=traces,
158
  )
159
 
160
  # --- 2) planner ---
@@ -172,7 +224,7 @@ class Pipeline:
172
  sql=None,
173
  rationale=None,
174
  verified=None,
175
- traces=traces,
176
  )
177
 
178
  # --- 3) generator ---
@@ -194,7 +246,7 @@ class Pipeline:
194
  sql=None,
195
  rationale=None,
196
  verified=None,
197
- traces=traces,
198
  )
199
 
200
  sql = (r_gen.data or {}).get("sql")
@@ -213,7 +265,7 @@ class Pipeline:
213
  sql=sql,
214
  rationale=rationale,
215
  verified=None,
216
- traces=traces,
217
  )
218
 
219
  # --- 5) executor ---
@@ -283,11 +335,10 @@ class Pipeline:
283
  if any_exec_ok:
284
  traces.append(
285
  self._mk_trace(
286
- "pipeline",
287
- 0.0,
288
- {
289
- "auto_fix": "verified=True (executor succeeded, verifier silent)"
290
- },
291
  )
292
  )
293
  verified = True
@@ -299,9 +350,10 @@ class Pipeline:
299
 
300
  traces.append(
301
  self._mk_trace(
302
- "pipeline",
303
- 0.0,
304
- {"final_verified": bool(verified), "details_len": len(details)},
 
305
  )
306
  )
307
 
@@ -314,5 +366,5 @@ class Pipeline:
314
  rationale=rationale,
315
  verified=verified,
316
  questions=None,
317
- traces=traces,
318
  )
 
68
  traces.append(getattr(t, "__dict__", t))
69
  return traces
70
 
71
+ # ------------------------------------------------------------
72
+ @staticmethod
73
+ def _mk_trace(
74
+ stage: str,
75
+ duration_ms: float,
76
+ summary: str,
77
+ notes: Optional[Dict[str, Any]] = None,
78
+ ) -> dict:
79
+ """Create a normalized trace dict (internal: duration may be float)."""
80
+ return {
81
+ "stage": stage,
82
+ "duration_ms": float(duration_ms),
83
+ "summary": summary,
84
+ "notes": notes or {},
85
+ }
86
+
87
+ @staticmethod
88
+ def _normalize_traces(traces: List[dict]) -> List[dict]:
89
+ """
90
+ Normalize trace list for API/UI:
91
+ - coerce duration_ms to int
92
+ - ensure `summary` exists (fallback to a minimal one)
93
+ """
94
+ norm: List[dict] = []
95
+ for t in traces:
96
+ stage = str(t.get("stage", "unknown"))
97
+ dur = t.get("duration_ms", 0)
98
+ try:
99
+ dur_int = int(round(float(dur)))
100
+ except Exception:
101
+ dur_int = 0
102
+ summary = t.get("summary")
103
+ if not summary:
104
+ # fallback summary if not provided by stage
105
+ notes = t.get("notes") or {}
106
+ failed = bool(notes.get("error") or notes.get("errors"))
107
+ summary = "failed" if failed else "ok"
108
+ notes = t.get("notes") or {}
109
+ # preserve any accounting fields if present (token_in/out, cost_usd, ...)
110
+ payload = {
111
+ "stage": stage,
112
+ "duration_ms": dur_int,
113
+ "summary": summary,
114
+ "notes": notes,
115
+ }
116
+ # keep extra accounting if exists
117
+ if "token_in" in t:
118
+ payload["token_in"] = t["token_in"]
119
+ if "token_out" in t:
120
+ payload["token_out"] = t["token_out"]
121
+ if "cost_usd" in t:
122
+ payload["cost_usd"] = t["cost_usd"]
123
+ norm.append(payload)
124
+ return norm
125
+
126
  # ------------------------------------------------------------
127
  @staticmethod
128
  def _safe_stage(fn, **kwargs) -> StageResult:
 
139
  tb = traceback.format_exc()
140
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
141
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # ------------------------------------------------------------
143
  def run(
144
  self,
 
162
  t0 = time.perf_counter()
163
  questions = self.detector.detect(user_query, schema_preview)
164
  t1 = time.perf_counter()
165
+ is_amb = bool(questions)
166
  traces.append(
167
  self._mk_trace(
168
+ stage="detector",
169
+ duration_ms=(t1 - t0) * 1000.0,
170
+ summary=("ambiguous" if is_amb else "clear"),
171
+ notes={
172
+ "ambiguous": is_amb,
173
  "questions_len": len(questions or []),
174
  },
175
  )
 
185
  sql=None,
186
  rationale=None,
187
  verified=None,
188
+ traces=self._normalize_traces(traces),
189
  )
190
  except Exception as e:
191
  # detector crash – mark as error but keep trace so far
192
+ traces.append(
193
+ self._mk_trace(
194
+ stage="detector",
195
+ duration_ms=0.0,
196
+ summary="failed",
197
+ notes={"error": str(e)},
198
+ )
199
+ )
200
  return FinalResult(
201
  ok=False,
202
  ambiguous=True,
 
206
  sql=None,
207
  rationale=None,
208
  verified=None,
209
+ traces=self._normalize_traces(traces),
210
  )
211
 
212
  # --- 2) planner ---
 
224
  sql=None,
225
  rationale=None,
226
  verified=None,
227
+ traces=self._normalize_traces(traces),
228
  )
229
 
230
  # --- 3) generator ---
 
246
  sql=None,
247
  rationale=None,
248
  verified=None,
249
+ traces=self._normalize_traces(traces),
250
  )
251
 
252
  sql = (r_gen.data or {}).get("sql")
 
265
  sql=sql,
266
  rationale=rationale,
267
  verified=None,
268
+ traces=self._normalize_traces(traces),
269
  )
270
 
271
  # --- 5) executor ---
 
335
  if any_exec_ok:
336
  traces.append(
337
  self._mk_trace(
338
+ stage="pipeline",
339
+ duration_ms=0.0,
340
+ summary="auto-verified",
341
+ notes={"reason": "executor succeeded, verifier silent"},
 
342
  )
343
  )
344
  verified = True
 
350
 
351
  traces.append(
352
  self._mk_trace(
353
+ stage="pipeline",
354
+ duration_ms=0.0,
355
+ summary="finalize",
356
+ notes={"final_verified": bool(verified), "details_len": len(details)},
357
  )
358
  )
359
 
 
366
  rationale=rationale,
367
  verified=verified,
368
  questions=None,
369
+ traces=self._normalize_traces(traces),
370
  )
nl2sql/types.py CHANGED
@@ -5,7 +5,8 @@ from typing import Any, Dict, Optional, List
5
  @dataclass(frozen=True)
6
  class StageTrace:
7
  stage: str
8
- duration_ms: float
 
9
  notes: Optional[Dict[str, Any]] = None
10
  token_in: Optional[int] = None
11
  token_out: Optional[int] = None
 
5
  @dataclass(frozen=True)
6
  class StageTrace:
7
  stage: str
8
+ duration_ms: float # keep float internally if you like
9
+ summary: str = "" # ← default to keep legacy call-sites working
10
  notes: Optional[Dict[str, Any]] = None
11
  token_in: Optional[int] = None
12
  token_out: Optional[int] = None