github-actions[bot] committed on
Commit
7666573
·
1 Parent(s): 11975fd

Sync from GitHub main @ 51304ea81c450e4e5ce90de52f10a63dcca33c64

Browse files
adapters/metrics/base.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
  from abc import ABC, abstractmethod
4
  from typing import Literal
5
 
 
6
  RepairOutcome = Literal["attempt", "success", "failed", "skipped"]
7
 
8
 
@@ -11,7 +12,7 @@ class Metrics(ABC):
11
  def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None: ...
12
 
13
  @abstractmethod
14
- def inc_pipeline_run(self, *, status: str) -> None: ...
15
 
16
  @abstractmethod
17
  def inc_stage_call(self, *, stage: str, ok: bool) -> None: ...
 
3
  from abc import ABC, abstractmethod
4
  from typing import Literal
5
 
6
+ PipelineStatus = Literal["ok", "error", "ambiguous"]
7
  RepairOutcome = Literal["attempt", "success", "failed", "skipped"]
8
 
9
 
 
12
  def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None: ...
13
 
14
  @abstractmethod
15
+ def inc_pipeline_run(self, *, status: PipelineStatus) -> None: ...
16
 
17
  @abstractmethod
18
  def inc_stage_call(self, *, stage: str, ok: bool) -> None: ...
adapters/metrics/noop.py CHANGED
@@ -1,13 +1,13 @@
1
  from __future__ import annotations
2
 
3
- from adapters.metrics.base import Metrics, RepairOutcome
4
 
5
 
6
  class NoOpMetrics(Metrics):
7
  def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
8
  return
9
 
10
- def inc_pipeline_run(self, *, status: str) -> None:
11
  return
12
 
13
  def inc_stage_call(self, *, stage: str, ok: bool) -> None:
 
1
  from __future__ import annotations
2
 
3
+ from adapters.metrics.base import Metrics, PipelineStatus, RepairOutcome
4
 
5
 
6
  class NoOpMetrics(Metrics):
7
  def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
8
  return
9
 
10
+ def inc_pipeline_run(self, *, status: PipelineStatus) -> None:
11
  return
12
 
13
  def inc_stage_call(self, *, stage: str, ok: bool) -> None:
adapters/metrics/prometheus.py CHANGED
@@ -1,50 +1,154 @@
1
  from __future__ import annotations
2
 
3
- from prometheus_client import Counter
4
- from adapters.metrics.base import Metrics, RepairOutcome
5
- from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  stage_calls_total = Counter(
9
  "stage_calls_total",
10
- "Total number of stage calls by stage and success",
11
  ["stage", "ok"],
 
12
  )
13
 
14
  stage_errors_total = Counter(
15
  "stage_errors_total",
16
- "Total number of stage errors by stage and error code",
17
  ["stage", "error_code"],
 
18
  )
19
 
 
 
 
20
  repair_attempts_total = Counter(
21
  "repair_attempts_total",
22
- "Total repair attempts by stage and outcome",
23
  ["stage", "outcome"],
 
24
  )
25
 
26
  repair_trigger_total = Counter(
27
  "repair_trigger_total",
28
- "Total repair triggers by stage and reason",
29
  ["stage", "reason"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
32
 
33
  class PrometheusMetrics(Metrics):
34
  def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
35
- stage_duration_ms.labels(stage=stage).observe(dt_ms)
36
 
37
- def inc_pipeline_run(self, *, status: str) -> None:
38
  pipeline_runs_total.labels(status=status).inc()
39
 
40
  def inc_stage_call(self, *, stage: str, ok: bool) -> None:
41
- stage_calls_total.labels(stage=stage, ok=str(ok).lower()).inc()
42
 
43
  def inc_stage_error(self, *, stage: str, error_code: str) -> None:
44
- stage_errors_total.labels(stage=stage, error_code=error_code).inc()
45
 
46
  def inc_repair_trigger(self, *, stage: str, reason: str) -> None:
47
- repair_trigger_total.labels(stage=stage, reason=reason).inc()
48
 
49
  def inc_repair_attempt(self, *, stage: str, outcome: RepairOutcome) -> None:
50
  repair_attempts_total.labels(stage=stage, outcome=outcome).inc()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from prometheus_client import Counter, Histogram
4
+ from nl2sql.prom import REGISTRY
 
5
 
6
+ from adapters.metrics.base import Metrics, PipelineStatus, RepairOutcome
7
+
8
+ # -----------------------------------------------------------------------------
9
+ # Stage-level metrics
10
+ # -----------------------------------------------------------------------------
11
+ stage_duration_ms = Histogram(
12
+ "stage_duration_ms",
13
+ "Duration (ms) of each pipeline stage",
14
+ ["stage"],
15
+ buckets=(1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000),
16
+ registry=REGISTRY,
17
+ )
18
 
19
  stage_calls_total = Counter(
20
  "stage_calls_total",
21
+ "Count of stage calls labeled by stage and ok",
22
  ["stage", "ok"],
23
+ registry=REGISTRY,
24
  )
25
 
26
  stage_errors_total = Counter(
27
  "stage_errors_total",
28
+ "Count of stage errors labeled by stage and error_code",
29
  ["stage", "error_code"],
30
+ registry=REGISTRY,
31
  )
32
 
33
+ # -----------------------------------------------------------------------------
34
+ # Repair metrics (Day 3 contract)
35
+ # -----------------------------------------------------------------------------
36
  repair_attempts_total = Counter(
37
  "repair_attempts_total",
38
+ "Count of repair attempts labeled by stage and outcome",
39
  ["stage", "outcome"],
40
+ registry=REGISTRY,
41
  )
42
 
43
  repair_trigger_total = Counter(
44
  "repair_trigger_total",
45
+ "Count of repair triggers labeled by stage and reason",
46
  ["stage", "reason"],
47
+ registry=REGISTRY,
48
+ )
49
+
50
+ # -----------------------------------------------------------------------------
51
+ # Safety stage metrics (existing / optional)
52
+ # -----------------------------------------------------------------------------
53
+ safety_blocks_total = Counter(
54
+ "safety_blocks_total",
55
+ "Count of blocked SQL queries by safety checks",
56
+ ["reason"],
57
+ registry=REGISTRY,
58
+ )
59
+
60
+ safety_checks_total = Counter(
61
+ "safety_checks_total",
62
+ "Total SQL queries checked by safety",
63
+ ["ok"],
64
+ registry=REGISTRY,
65
+ )
66
+
67
+ # -----------------------------------------------------------------------------
68
+ # Verifier stage metrics (existing / optional)
69
+ # -----------------------------------------------------------------------------
70
+ verifier_checks_total = Counter(
71
+ "verifier_checks_total",
72
+ "Count of verifier checks (success/failure)",
73
+ ["ok"],
74
+ registry=REGISTRY,
75
+ )
76
+
77
+ verifier_failures_total = Counter(
78
+ "verifier_failures_total",
79
+ "Count of verifier failures by type",
80
+ ["reason"],
81
+ registry=REGISTRY,
82
+ )
83
+
84
+ # -----------------------------------------------------------------------------
85
+ # Pipeline-level metrics
86
+ # -----------------------------------------------------------------------------
87
+ pipeline_runs_total = Counter(
88
+ "pipeline_runs_total",
89
+ "Total number of full pipeline runs",
90
+ ["status"],
91
+ registry=REGISTRY,
92
+ )
93
+
94
+ # -----------------------------------------------------------------------------
95
+ # Cache metrics (optional)
96
+ # -----------------------------------------------------------------------------
97
+ cache_events_total = Counter(
98
+ "cache_events_total",
99
+ "Cache hit/miss events in the pipeline",
100
+ ["hit"],
101
+ registry=REGISTRY,
102
  )
103
 
104
 
class PrometheusMetrics(Metrics):
    """Metrics adapter that forwards every event to the module-level Prometheus collectors.

    Each method resolves the labelled child series on the matching collector
    and records a single observation or increment on it.
    """

    def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
        """Record one duration sample for ``stage`` on ``stage_duration_ms``."""
        elapsed = float(dt_ms)  # coerce ints / numeric-likes before observing
        series = stage_duration_ms.labels(stage=stage)
        series.observe(elapsed)

    def inc_pipeline_run(self, *, status: PipelineStatus) -> None:
        """Increment ``pipeline_runs_total`` for the given run ``status``."""
        series = pipeline_runs_total.labels(status=status)
        series.inc()

    def inc_stage_call(self, *, stage: str, ok: bool) -> None:
        """Increment ``stage_calls_total``; ``ok`` is exported as "true" or "false"."""
        ok_label = "false"
        if ok:
            ok_label = "true"
        series = stage_calls_total.labels(stage=stage, ok=ok_label)
        series.inc()

    def inc_stage_error(self, *, stage: str, error_code: str) -> None:
        """Increment ``stage_errors_total`` for ``stage`` with the stringified ``error_code``."""
        code_label = str(error_code)
        series = stage_errors_total.labels(stage=stage, error_code=code_label)
        series.inc()

    def inc_repair_trigger(self, *, stage: str, reason: str) -> None:
        """Increment ``repair_trigger_total`` for ``stage`` with the stringified ``reason``."""
        reason_label = str(reason)
        series = repair_trigger_total.labels(stage=stage, reason=reason_label)
        series.inc()

    def inc_repair_attempt(self, *, stage: str, outcome: RepairOutcome) -> None:
        """Increment ``repair_attempts_total`` for the given ``stage`` and ``outcome``."""
        series = repair_attempts_total.labels(stage=stage, outcome=outcome)
        series.inc()
123
+
124
+
# -----------------------------------------------------------------------------
# Label priming to keep /metrics stable
# -----------------------------------------------------------------------------
def _prime_series() -> None:
    """Create the known label combinations up front by incrementing each by zero.

    Incrementing by 0 leaves every value untouched but instantiates the
    labelled child, so the series is present before the first real event.

    The priming runs inside a function (instead of bare module-level loops) so
    the loop variables ``ok``, ``status``, ``hit``, ``stage``, ``outcome`` and
    ``reason`` do not become module attributes that a star-import of this
    module would re-export.
    """
    bool_labels = ("true", "false")

    for ok in bool_labels:
        safety_checks_total.labels(ok=ok).inc(0)
        verifier_checks_total.labels(ok=ok).inc(0)

    for status in ("ok", "error", "ambiguous"):
        pipeline_runs_total.labels(status=status).inc(0)

    for hit in bool_labels:
        cache_events_total.labels(hit=hit).inc(0)

    # Prime Day 3 series
    for stage in (
        "detector",
        "planner",
        "generator",
        "safety",
        "executor",
        "verifier",
        "repair",
    ):
        for ok in bool_labels:
            stage_calls_total.labels(stage=stage, ok=ok).inc(0)
        for outcome in ("attempt", "success", "failed", "skipped"):
            repair_attempts_total.labels(stage=stage, outcome=outcome).inc(0)

    for reason in ("semantic_failure", "unknown"):
        repair_trigger_total.labels(stage="verifier", reason=reason).inc(0)

    # NOTE(review): the reason-labelled counters (safety_blocks_total,
    # verifier_failures_total) are not primed here — confirm that is intended.


_prime_series()
app/services/nl2sql_service.py CHANGED
@@ -9,6 +9,8 @@ from nl2sql.pipeline import FinalResult
9
  from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
10
  from adapters.db.sqlite_adapter import SQLiteAdapter
11
  from adapters.db.postgres_adapter import PostgresAdapter
 
 
12
  from app import state
13
  from app.settings import Settings
14
  from app.errors import (
@@ -65,7 +67,6 @@ class NL2SQLService:
65
  This is a straight port of the previous sqlite3 logic, but contained
66
  inside the service instead of the router.
67
  """
68
- # Try to locate the underlying .db path from the adapter
69
  db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
70
  if not db_path:
71
  raise RuntimeError(
@@ -105,8 +106,7 @@ class NL2SQLService:
105
 
106
  - If override is provided by the client → use it.
107
  - Else, in sqlite mode → introspect the DB.
108
- - In postgres mode without override → fail fast, the caller can map
109
- this to a proper HTTP error.
110
  """
111
  if override:
112
  return override
@@ -152,6 +152,13 @@ class NL2SQLService:
152
  f"Failed to build pipeline from {self.settings.pipeline_config_path!r}: {exc}"
153
  ) from exc
154
 
 
 
 
 
 
 
 
155
  try:
156
  result = pipeline.run(user_query=query, schema_preview=schema_preview)
157
  except AppError:
 
9
  from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
10
  from adapters.db.sqlite_adapter import SQLiteAdapter
11
  from adapters.db.postgres_adapter import PostgresAdapter
12
+ from adapters.metrics.prometheus import PrometheusMetrics
13
+
14
  from app import state
15
  from app.settings import Settings
16
  from app.errors import (
 
67
  This is a straight port of the previous sqlite3 logic, but contained
68
  inside the service instead of the router.
69
  """
 
70
  db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
71
  if not db_path:
72
  raise RuntimeError(
 
106
 
107
  - If override is provided by the client → use it.
108
  - Else, in sqlite mode → introspect the DB.
109
+ - In postgres mode without override → fail fast.
 
110
  """
111
  if override:
112
  return override
 
152
  f"Failed to build pipeline from {self.settings.pipeline_config_path!r}: {exc}"
153
  ) from exc
154
 
155
+ # Force PrometheusMetrics to avoid silent NoOp wiring via factory defaults.
156
+ if (
157
+ getattr(pipeline, "metrics", None) is None
158
+ or pipeline.metrics.__class__.__name__ == "NoOpMetrics"
159
+ ):
160
+ pipeline.metrics = PrometheusMetrics()
161
+
162
  try:
163
  result = pipeline.run(user_query=query, schema_preview=schema_preview)
164
  except AppError:
nl2sql/metrics.py CHANGED
@@ -1,111 +1,10 @@
1
- from prometheus_client import Counter, Histogram
2
- from nl2sql.prom import REGISTRY
3
 
 
 
4
 
5
- # -----------------------------------------------------------------------------
6
- # Stage-level metrics
7
- # -----------------------------------------------------------------------------
8
- stage_duration_ms = Histogram(
9
- "stage_duration_ms",
10
- "Duration (ms) of each pipeline stage",
11
- ["stage"], # e.g. detector|planner|generator|safety|verifier
12
- buckets=(1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000),
13
- registry=REGISTRY,
14
- )
15
 
16
- # -----------------------------------------------------------------------------
17
- # Safety stage metrics
18
- # -----------------------------------------------------------------------------
19
- safety_blocks_total = Counter(
20
- "safety_blocks_total",
21
- "Count of blocked SQL queries by safety checks",
22
- [
23
- "reason"
24
- ], # e.g. forbidden_keyword, multiple_statements, non_readonly, explain_not_allowed
25
- registry=REGISTRY,
26
- )
27
-
28
- safety_checks_total = Counter(
29
- "safety_checks_total",
30
- "Total SQL queries checked by safety",
31
- ["ok"], # "true" or "false"
32
- registry=REGISTRY,
33
- )
34
-
35
- # -----------------------------------------------------------------------------
36
- # Verifier stage metrics
37
- # -----------------------------------------------------------------------------
38
- verifier_checks_total = Counter(
39
- "verifier_checks_total",
40
- "Count of verifier checks (success/failure)",
41
- ["ok"], # "true" | "false"
42
- registry=REGISTRY,
43
- )
44
-
45
- verifier_failures_total = Counter(
46
- "verifier_failures_total",
47
- "Count of verifier failures by type",
48
- ["reason"], # e.g. parse_error, semantic_check_error, adapter_failure
49
- registry=REGISTRY,
50
- )
51
-
52
- # -----------------------------------------------------------------------------
53
- # Repair stage metrics
54
- # -----------------------------------------------------------------------------
55
- repair_attempts_total = Counter(
56
- "repair_attempts_total",
57
- "Number of repair loop attempts",
58
- ["outcome"], # attempt | success | failed
59
- registry=REGISTRY,
60
- )
61
-
62
- # -----------------------------------------------------------------------------
63
- # Pipeline-level metrics
64
- # -----------------------------------------------------------------------------
65
- pipeline_runs_total = Counter(
66
- "pipeline_runs_total",
67
- "Total number of full pipeline runs",
68
- ["status"], # ok | error | ambiguous
69
- registry=REGISTRY,
70
- )
71
-
72
- # -----------------------------------------------------------------------------
73
- # Cache metrics (optional)
74
- # -----------------------------------------------------------------------------
75
- cache_events_total = Counter(
76
- "cache_events_total",
77
- "Cache hit/miss events in the pipeline",
78
- ["hit"], # "true" | "false"
79
- registry=REGISTRY,
80
- )
81
-
82
- # -----------------------------------------------------------------------------
83
- # Prime all counters with zero to ensure Grafana panels always have data
84
- # -----------------------------------------------------------------------------
85
- for reason in (
86
- "forbidden_keyword",
87
- "multiple_statements",
88
- "non_readonly",
89
- "explain_not_allowed",
90
- "parse_error",
91
- "semantic_check_error",
92
- "adapter_failure",
93
- "unsafe-sql",
94
- "malformed-sql",
95
- "unknown",
96
- ):
97
- safety_blocks_total.labels(reason=reason).inc(0)
98
- verifier_failures_total.labels(reason=reason).inc(0)
99
-
100
- for ok in ("true", "false"):
101
- safety_checks_total.labels(ok=ok).inc(0)
102
- verifier_checks_total.labels(ok=ok).inc(0)
103
-
104
- for outcome in ("attempt", "success", "failed"):
105
- repair_attempts_total.labels(outcome=outcome).inc(0)
106
-
107
- for status in ("ok", "error", "ambiguous"):
108
- pipeline_runs_total.labels(status=status).inc(0)
109
-
110
- for hit in ("true", "false"):
111
- cache_events_total.labels(hit=hit).inc(0)
 
1
+ """
2
+ Deprecated shim.
3
 
4
+ All Prometheus metric definitions and the PrometheusMetrics adapter live in:
5
+ adapters.metrics.prometheus
6
 
7
+ This module only exists for backward-compatible imports.
8
+ """
 
 
 
 
 
 
 
 
9
 
10
+ from adapters.metrics.prometheus import * # noqa: F401,F403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nl2sql/pipeline.py CHANGED
@@ -36,7 +36,6 @@ class FinalResult:
36
  traces: List[dict]
37
 
38
  error_code: Optional[ErrorCode] = None
39
-
40
  result: Optional[Dict[str, Any]] = None
41
 
42
 
@@ -124,10 +123,36 @@ class Pipeline:
124
  )
125
  return norm
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  @staticmethod
128
  def _safe_stage(fn, **kwargs) -> StageResult:
129
  try:
130
- r = fn(**kwargs)
 
131
  if isinstance(r, StageResult):
132
  return r
133
  return StageResult(ok=True, data=r, trace=None)
@@ -154,12 +179,10 @@ class Pipeline:
154
  return any(p in m for p in patterns)
155
 
156
  def _should_repair(self, stage_name: str, r: StageResult) -> tuple[bool, str]:
157
- """Return (eligible, reason)."""
158
- # Never repair safety blocks: policy violations must not be 'fixed' via LLM.
159
  if stage_name == "safety":
160
  return (False, "blocked_by_safety")
161
 
162
- # Never repair cost guardrail blocks: require user constraint (LIMIT) / query refinement.
163
  if (
164
  stage_name == "executor"
165
  and getattr(r, "error_code", None)
@@ -167,15 +190,13 @@ class Pipeline:
167
  ):
168
  return (False, "blocked_by_cost")
169
 
170
- # Only repair SQL-related stages.
171
  if stage_name not in self.SQL_REPAIR_STAGES:
172
  return (False, "not_sql_stage")
173
 
174
- # Verifier may return ok=True but verified=False (semantic mismatch). Allow a single repair attempt.
175
  if stage_name == "verifier":
176
  data = r.data if isinstance(r.data, dict) else {}
177
- if data.get("verified") is False:
178
- return (True, "not_verified")
179
 
180
  errs = r.error or []
181
  if any(isinstance(e, str) and self._is_repairable_sql_error(e) for e in errs):
@@ -190,14 +211,18 @@ class Pipeline:
190
  *,
191
  repair_input_builder,
192
  max_attempts: int = 1,
193
- traces: list,
194
  **kwargs,
195
  ) -> StageResult:
196
  """
197
  Run a stage with per-stage repair + full observability integration.
198
  SQL-only repair occurs for safety/executor/verifier.
199
- Planner/Generator get log-only repair (trace only, no effect).
 
200
  """
 
 
 
 
201
  attempt = 0
202
 
203
  while True:
@@ -228,13 +253,32 @@ class Pipeline:
228
  }
229
  )
230
 
231
- if r.ok:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return r
233
 
234
  # stage failed → check repair availability
235
  eligible, reason = self._should_repair(stage_name, r)
236
  if not eligible:
237
- self.metrics.inc_repair_attempt(stage="verifier", outcome="skipped")
238
  # annotate latest stage trace entry
239
  if traces and isinstance(traces[-1], dict):
240
  notes = traces[-1].get("notes") or {}
@@ -254,7 +298,7 @@ class Pipeline:
254
 
255
  # --- 3) Run repair (always logged) ---
256
  self.metrics.inc_repair_trigger(stage=stage_name, reason=reason)
257
- self.metrics.inc_repair_attempt(stage="verifier", outcome="attempt")
258
  t1 = time.perf_counter()
259
  r_fix = self._safe_stage(self.repair.run, **repair_args)
260
  dt_fix = (time.perf_counter() - t1) * 1000.0
@@ -274,7 +318,7 @@ class Pipeline:
274
  )
275
 
276
  if not r_fix.ok:
277
- self.metrics.inc_repair_attempt(stage="verifier", outcome="failed")
278
  return r # repair itself failed → stop here
279
 
280
  # --- 4) Only inject SQL if the stage is an SQL-producing stage ---
@@ -282,16 +326,9 @@ class Pipeline:
282
  if "sql" in repair_args and "sql" in kwargs:
283
  kwargs["sql"] = (r_fix.data or {}).get("sql", kwargs["sql"])
284
 
285
- # important: success metric must reflect if repair was applied meaningfully
286
- if stage_name in self.SQL_REPAIR_STAGES:
287
- self.metrics.inc_repair_attempt(stage="verifier", outcome="success")
288
- else:
289
- # log-only mode counts as a success-attempt but not semantic success
290
- self.metrics.inc_repair_attempt(stage="verifier", outcome="success")
291
 
292
- # for SQL stages, we re-run the stage again with modified kwargs
293
- # for log-only stages, this simply loops and stage is re-run unchanged
294
- # (which is correct)
295
 
296
  @staticmethod
297
  def _planner_repair_input_builder(stage_result, kwargs):
@@ -395,6 +432,7 @@ class Pipeline:
395
  dt = (time.perf_counter() - t0) * 1000.0
396
  is_amb = bool(questions)
397
  self.metrics.observe_stage_duration_ms(stage="detector", dt_ms=dt)
 
398
  traces.append(
399
  self._mk_trace(
400
  stage="detector",
@@ -405,6 +443,7 @@ class Pipeline:
405
  )
406
  if questions:
407
  self.metrics.inc_pipeline_run(status="ambiguous")
 
408
  return FinalResult(
409
  ok=True,
410
  ambiguous=True,
@@ -418,8 +457,6 @@ class Pipeline:
418
  )
419
 
420
  # --- 2) planner ---
421
- t0 = time.perf_counter()
422
-
423
  planner_kwargs: Dict[str, Any] = {
424
  "user_query": user_query,
425
  "schema_preview": schema_for_llm,
@@ -454,14 +491,13 @@ class Pipeline:
454
  )
455
 
456
  # --- 3) generator ---
457
- t0 = time.perf_counter()
458
-
459
  gen_kwargs: Dict[str, Any] = {
460
  "user_query": user_query,
461
  "schema_preview": schema_for_llm,
462
  "plan_text": (r_plan.data or {}).get("plan"),
463
  "clarify_answers": clarify_answers,
464
  "traces": traces,
 
465
  }
466
  try:
467
  if "schema_pack" in inspect.signature(self.generator.run).parameters:
@@ -494,33 +530,6 @@ class Pipeline:
494
  sql = (r_gen.data or {}).get("sql")
495
  rationale = (r_gen.data or {}).get("rationale")
496
 
497
- # --- schema drift signal (planner vs generator table usage)
498
- planner_used_tables = (
499
- (r_plan.data or {}).get("used_tables")
500
- or (r_plan.data or {}).get("tables")
501
- or []
502
- )
503
- generator_used_tables = (
504
- (r_gen.data or {}).get("used_tables")
505
- or (r_gen.data or {}).get("tables")
506
- or []
507
- )
508
- planner_set = set(planner_used_tables)
509
- generator_set = set(generator_used_tables)
510
- schema_drift = bool(generator_set - planner_set)
511
- traces.append(
512
- self._mk_trace(
513
- stage="schema_drift_check",
514
- duration_ms=0.0,
515
- summary="compare planner vs generator table usage",
516
- notes={
517
- "planner_used_tables": sorted(planner_set),
518
- "generator_used_tables": sorted(generator_set),
519
- "schema_drift": schema_drift,
520
- },
521
- )
522
- )
523
-
524
  # Guard: empty SQL
525
  if not sql or not str(sql).strip():
526
  self.metrics.inc_pipeline_run(status="error")
@@ -541,13 +550,13 @@ class Pipeline:
541
  )
542
 
543
  # --- 4) safety ---
544
- t0 = time.perf_counter()
545
  r_safe = self._run_with_repair(
546
  "safety",
547
  self.safety.run,
548
  repair_input_builder=self._sql_repair_input_builder,
549
  max_attempts=1,
550
  sql=sql,
 
551
  traces=traces,
552
  )
553
  if not r_safe.ok:
@@ -569,201 +578,49 @@ class Pipeline:
569
  sql = (r_safe.data or {}).get("sql", sql)
570
 
571
  # --- 5) executor ---
572
- t0 = time.perf_counter()
573
  r_exec = self._run_with_repair(
574
  "executor",
575
  self.executor.run,
576
  repair_input_builder=self._sql_repair_input_builder,
577
  max_attempts=1,
578
  sql=sql,
 
579
  traces=traces,
580
  )
581
  if not r_exec.ok and r_exec.error:
582
- details.extend(
583
- r_exec.error
584
- ) # soft: keep for repair/verifier context_engineering
585
  if r_exec.ok and isinstance(r_exec.data, dict):
586
  exec_result = dict(r_exec.data)
587
 
588
  # --- 6) verifier (only if execution succeeded) ---
589
- r_ver: StageResult | None = None
590
- verifier_failed = False
591
  verified = False
592
-
593
  if r_exec.ok:
594
- t0 = time.perf_counter()
595
- r_ver = self._safe_stage(
596
  self._call_verifier,
 
 
597
  sql=sql,
598
  exec_result=(r_exec.data or {}),
 
599
  traces=traces,
600
  )
601
- dt = (time.perf_counter() - t0) * 1000.0
602
- self.metrics.observe_stage_duration_ms(stage="verifier", dt_ms=dt)
603
-
604
- # Attach a trace entry if verifier didn't provide one
605
- if getattr(r_ver, "trace", None):
606
- traces.append(r_ver.trace.__dict__)
607
- else:
608
- traces.append(
609
- {
610
- "stage": "verifier",
611
- "duration_ms": dt,
612
- "summary": "ok" if r_ver.ok else "failed",
613
- "notes": {},
614
- }
615
- )
616
-
617
- verifier_failed = not bool(r_ver.ok)
618
- data0 = (
619
- r_ver.data if (r_ver.ok and isinstance(r_ver.data, dict)) else {}
620
- )
621
- verified = bool(data0.get("verified") is True)
622
-
623
- # --- 7) semantic repair gating (verifier fail OR not_verified) ---
624
- # If verifier failed or said "not verified", attempt a single repair and then:
625
- # safety → executor → verifier again.
626
- if r_exec.ok and (verifier_failed or not verified):
627
- eligible, _reason = self._should_repair(
628
- "verifier",
629
- StageResult(
630
- ok=True, data={"verified": False}, error=None, trace=None
631
- ),
632
- )
633
- if eligible:
634
- self.metrics.inc_repair_trigger(stage="verifier", reason=_reason)
635
- # Prefer the real verifier message if present (tests expect this).
636
- err_list = (r_ver.error if (r_ver and r_ver.error) else None) or []
637
- error_msg = (
638
- "; ".join([e for e in err_list if isinstance(e, str)])
639
- or "not_verified"
640
- )
641
-
642
- rep_kwargs_all: Dict[str, Any] = {
643
- "user_query": user_query,
644
- "schema_preview": schema_for_llm,
645
- "sql": sql,
646
- "error_msg": error_msg,
647
- "error_message": error_msg,
648
- "traces": traces,
649
- "constraints": constraints,
650
- }
651
- try:
652
- params = inspect.signature(self.repair.run).parameters
653
- rep_kwargs = {
654
- k: v for k, v in rep_kwargs_all.items() if k in params
655
- }
656
- except (TypeError, ValueError):
657
- rep_kwargs = {
658
- "sql": sql,
659
- "error_msg": error_msg,
660
- "schema_preview": schema_for_llm,
661
- }
662
-
663
- self.metrics.inc_repair_attempt(stage="verifier", outcome="attempt")
664
- r_rep = self.repair.run(**rep_kwargs)
665
-
666
- new_sql = (
667
- r_rep.data.get("sql")
668
- if (r_rep.ok and isinstance(r_rep.data, dict))
669
- else None
670
- )
671
-
672
- if new_sql and str(new_sql).strip():
673
- sql = str(new_sql)
674
-
675
- # Re-run safety → executor → verifier using the standard path.
676
- r_safe2 = self._run_with_repair(
677
- "safety",
678
- self.safety.run,
679
- repair_input_builder=self._sql_repair_input_builder,
680
- max_attempts=1,
681
- sql=sql,
682
- traces=traces,
683
- )
684
- if not r_safe2.ok:
685
- self.metrics.inc_pipeline_run(status="error")
686
- return FinalResult(
687
- ok=False,
688
- ambiguous=False,
689
- error=True,
690
- details=r_safe2.error,
691
- error_code=r_safe2.error_code,
692
- questions=None,
693
- sql=sql,
694
- rationale=rationale,
695
- verified=None,
696
- traces=self._normalize_traces(traces),
697
- )
698
-
699
- sql = (r_safe2.data or {}).get("sql", sql)
700
-
701
- r_exec2 = self._run_with_repair(
702
- "executor",
703
- self.executor.run,
704
- repair_input_builder=self._sql_repair_input_builder,
705
- max_attempts=1,
706
- sql=sql,
707
- traces=traces,
708
- )
709
- if r_exec2.ok and isinstance(r_exec2.data, dict):
710
- exec_result = dict(r_exec2.data)
711
-
712
- r_ver2 = self._run_with_repair(
713
- "verifier",
714
- self._call_verifier,
715
- repair_input_builder=self._sql_repair_input_builder,
716
- max_attempts=1,
717
- sql=sql,
718
- exec_result=(r_exec2.data or {}),
719
- traces=traces,
720
- )
721
- data2 = r_ver2.data if isinstance(r_ver2.data, dict) else {}
722
- verified = bool(data2.get("verified") is True)
723
-
724
- if verified:
725
- self.metrics.inc_repair_attempt(
726
- stage="verifier", outcome="success"
727
- )
728
- else:
729
- self.metrics.inc_repair_attempt(
730
- stage="verifier", outcome="failed"
731
- )
732
- else:
733
- self.metrics.inc_repair_attempt(stage="verifier", outcome="skipped")
734
-
735
- # --- 8) optional soft auto-verify (executor success, no details) --- (executor success, no details) ---
736
- if (verified is None or not verified) and not details:
737
- any_exec_ok = any(
738
- t.get("stage") == "executor"
739
- and (t.get("notes") or {}).get("row_count")
740
- for t in traces
741
- )
742
- if any_exec_ok:
743
- traces.append(
744
- self._mk_trace(
745
- stage="pipeline",
746
- duration_ms=0.0,
747
- summary="auto-verified",
748
- notes={"reason": "executor succeeded, verifier silent"},
749
- )
750
- )
751
- verified = True
752
 
753
  # --- 9) finalize ---
754
  has_errors = bool(details)
755
  need_ver = bool(self.require_verification)
756
 
757
- # base success condition
758
  final_ok_by_verifier = bool(verified)
759
- base_ok = (
760
- bool(sql) and not has_errors and (final_ok_by_verifier or not need_ver)
 
 
761
  )
762
- ok = base_ok
763
  err = (not ok) and has_errors
764
 
765
- # align `verified` with baseline semantics:
766
- # if verification is NOT required and pipeline is ok, report verified=True
767
  if not need_ver and ok and not final_ok_by_verifier:
768
  verified_final = True
769
  else:
@@ -799,7 +656,6 @@ class Pipeline:
799
 
800
  except Exception:
801
  self.metrics.inc_pipeline_run(status="error")
802
- # bubble up to make failures visible in tests and logs
803
  raise
804
 
805
  finally:
 
36
  traces: List[dict]
37
 
38
  error_code: Optional[ErrorCode] = None
 
39
  result: Optional[Dict[str, Any]] = None
40
 
41
 
 
123
  )
124
  return norm
125
 
126
+ @staticmethod
127
+ def _accepts_kwargs(fn) -> bool:
128
+ try:
129
+ sig = inspect.signature(fn)
130
+ except (TypeError, ValueError):
131
+ return True
132
+ return any(
133
+ p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
134
+ )
135
+
136
+ @staticmethod
137
+ def _filter_kwargs(fn, kwargs: Dict[str, Any]) -> Dict[str, Any]:
138
+ """
139
+ Make stage calls backward-compatible with older stubs/fakes that don't accept
140
+ extra kwargs like `traces`, `schema_preview`, etc.
141
+ """
142
+ if Pipeline._accepts_kwargs(fn):
143
+ return kwargs
144
+ try:
145
+ sig = inspect.signature(fn)
146
+ allowed = set(sig.parameters.keys())
147
+ return {k: v for k, v in kwargs.items() if k in allowed}
148
+ except (TypeError, ValueError):
149
+ return kwargs
150
+
151
  @staticmethod
152
  def _safe_stage(fn, **kwargs) -> StageResult:
153
  try:
154
+ call_kwargs = Pipeline._filter_kwargs(fn, kwargs)
155
+ r = fn(**call_kwargs)
156
  if isinstance(r, StageResult):
157
  return r
158
  return StageResult(ok=True, data=r, trace=None)
 
179
  return any(p in m for p in patterns)
180
 
181
  def _should_repair(self, stage_name: str, r: StageResult) -> tuple[bool, str]:
182
+ # Never repair safety blocks
 
183
  if stage_name == "safety":
184
  return (False, "blocked_by_safety")
185
 
 
186
  if (
187
  stage_name == "executor"
188
  and getattr(r, "error_code", None)
 
190
  ):
191
  return (False, "blocked_by_cost")
192
 
 
193
  if stage_name not in self.SQL_REPAIR_STAGES:
194
  return (False, "not_sql_stage")
195
 
 
196
  if stage_name == "verifier":
197
  data = r.data if isinstance(r.data, dict) else {}
198
+ if data.get("verified") is False or (r.ok is False):
199
+ return (True, "semantic_failure")
200
 
201
  errs = r.error or []
202
  if any(isinstance(e, str) and self._is_repairable_sql_error(e) for e in errs):
 
211
  *,
212
  repair_input_builder,
213
  max_attempts: int = 1,
 
214
  **kwargs,
215
  ) -> StageResult:
216
  """
217
  Run a stage with per-stage repair + full observability integration.
218
  SQL-only repair occurs for safety/executor/verifier.
219
+
220
+ IMPORTANT: `traces` must be provided in kwargs as a list.
221
  """
222
+ traces = kwargs.get("traces")
223
+ if traces is None or not isinstance(traces, list):
224
+ raise TypeError("_run_with_repair requires `traces` (list) in kwargs")
225
+
226
  attempt = 0
227
 
228
  while True:
 
253
  }
254
  )
255
 
256
+ # --- 1.5) Verifier semantic failure is repairable even if ok=True ---
257
+ if r.ok and stage_name == "verifier":
258
+ data0 = r.data if isinstance(r.data, dict) else {}
259
+ if data0.get("verified") is True:
260
+ return r
261
+ # ok=True but verified=False → treat as eligible for repair path
262
+ eligible, reason = self._should_repair(stage_name, r)
263
+ if not eligible:
264
+ self.metrics.inc_repair_attempt(stage=stage_name, outcome="skipped")
265
+ if traces and isinstance(traces[-1], dict):
266
+ notes = traces[-1].get("notes") or {}
267
+ if not isinstance(notes, dict):
268
+ notes = {}
269
+ notes["repair_eligible"] = False
270
+ notes["repair_skip_reason"] = reason
271
+ traces[-1]["notes"] = notes
272
+ return r
273
+ # fallthrough into repair branch below
274
+
275
+ elif r.ok:
276
  return r
277
 
278
  # stage failed → check repair availability
279
  eligible, reason = self._should_repair(stage_name, r)
280
  if not eligible:
281
+ self.metrics.inc_repair_attempt(stage=stage_name, outcome="skipped")
282
  # annotate latest stage trace entry
283
  if traces and isinstance(traces[-1], dict):
284
  notes = traces[-1].get("notes") or {}
 
298
 
299
  # --- 3) Run repair (always logged) ---
300
  self.metrics.inc_repair_trigger(stage=stage_name, reason=reason)
301
+ self.metrics.inc_repair_attempt(stage=stage_name, outcome="attempt")
302
  t1 = time.perf_counter()
303
  r_fix = self._safe_stage(self.repair.run, **repair_args)
304
  dt_fix = (time.perf_counter() - t1) * 1000.0
 
318
  )
319
 
320
  if not r_fix.ok:
321
+ self.metrics.inc_repair_attempt(stage=stage_name, outcome="failed")
322
  return r # repair itself failed → stop here
323
 
324
  # --- 4) Only inject SQL if the stage is an SQL-producing stage ---
 
326
  if "sql" in repair_args and "sql" in kwargs:
327
  kwargs["sql"] = (r_fix.data or {}).get("sql", kwargs["sql"])
328
 
329
+ self.metrics.inc_repair_attempt(stage=stage_name, outcome="success")
 
 
 
 
 
330
 
331
+ # re-run stage with updated kwargs
 
 
332
 
333
  @staticmethod
334
  def _planner_repair_input_builder(stage_result, kwargs):
 
432
  dt = (time.perf_counter() - t0) * 1000.0
433
  is_amb = bool(questions)
434
  self.metrics.observe_stage_duration_ms(stage="detector", dt_ms=dt)
435
+ self.metrics.inc_stage_call(stage="detector", ok=True)
436
  traces.append(
437
  self._mk_trace(
438
  stage="detector",
 
443
  )
444
  if questions:
445
  self.metrics.inc_pipeline_run(status="ambiguous")
446
+ self.metrics.inc_stage_call(stage="detector", ok=False)
447
  return FinalResult(
448
  ok=True,
449
  ambiguous=True,
 
457
  )
458
 
459
  # --- 2) planner ---
 
 
460
  planner_kwargs: Dict[str, Any] = {
461
  "user_query": user_query,
462
  "schema_preview": schema_for_llm,
 
491
  )
492
 
493
  # --- 3) generator ---
 
 
494
  gen_kwargs: Dict[str, Any] = {
495
  "user_query": user_query,
496
  "schema_preview": schema_for_llm,
497
  "plan_text": (r_plan.data or {}).get("plan"),
498
  "clarify_answers": clarify_answers,
499
  "traces": traces,
500
+ "constraints": constraints,
501
  }
502
  try:
503
  if "schema_pack" in inspect.signature(self.generator.run).parameters:
 
530
  sql = (r_gen.data or {}).get("sql")
531
  rationale = (r_gen.data or {}).get("rationale")
532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  # Guard: empty SQL
534
  if not sql or not str(sql).strip():
535
  self.metrics.inc_pipeline_run(status="error")
 
550
  )
551
 
552
  # --- 4) safety ---
 
553
  r_safe = self._run_with_repair(
554
  "safety",
555
  self.safety.run,
556
  repair_input_builder=self._sql_repair_input_builder,
557
  max_attempts=1,
558
  sql=sql,
559
+ schema_preview=schema_for_llm,
560
  traces=traces,
561
  )
562
  if not r_safe.ok:
 
578
  sql = (r_safe.data or {}).get("sql", sql)
579
 
580
  # --- 5) executor ---
 
581
  r_exec = self._run_with_repair(
582
  "executor",
583
  self.executor.run,
584
  repair_input_builder=self._sql_repair_input_builder,
585
  max_attempts=1,
586
  sql=sql,
587
+ schema_preview=schema_for_llm,
588
  traces=traces,
589
  )
590
  if not r_exec.ok and r_exec.error:
591
+ details.extend(r_exec.error)
 
 
592
  if r_exec.ok and isinstance(r_exec.data, dict):
593
  exec_result = dict(r_exec.data)
594
 
595
  # --- 6) verifier (only if execution succeeded) ---
 
 
596
  verified = False
 
597
  if r_exec.ok:
598
+ r_ver = self._run_with_repair(
599
+ "verifier",
600
  self._call_verifier,
601
+ repair_input_builder=self._sql_repair_input_builder,
602
+ max_attempts=1,
603
  sql=sql,
604
  exec_result=(r_exec.data or {}),
605
+ schema_preview=schema_for_llm,
606
  traces=traces,
607
  )
608
+ data_v = r_ver.data if isinstance(r_ver.data, dict) else {}
609
+ verified = bool(data_v.get("verified") is True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
 
611
  # --- 9) finalize ---
612
  has_errors = bool(details)
613
  need_ver = bool(self.require_verification)
614
 
 
615
  final_ok_by_verifier = bool(verified)
616
+ ok = (
617
+ bool(sql)
618
+ and (not has_errors)
619
+ and (final_ok_by_verifier or not need_ver)
620
  )
 
621
  err = (not ok) and has_errors
622
 
623
+ # If verification is NOT required and pipeline is ok, report verified=True
 
624
  if not need_ver and ok and not final_ok_by_verifier:
625
  verified_final = True
626
  else:
 
656
 
657
  except Exception:
658
  self.metrics.inc_pipeline_run(status="error")
 
659
  raise
660
 
661
  finally:
nl2sql/pipeline_factory.py CHANGED
@@ -77,6 +77,16 @@ def _make_metrics() -> Metrics:
77
  return PrometheusMetrics()
78
 
79
 
 
 
 
 
 
 
 
 
 
 
80
  def _tr(
81
  stage: str,
82
  *,
@@ -207,14 +217,6 @@ def pipeline_from_config(path: str) -> Pipeline:
207
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
208
  repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
209
 
210
- context_engineer = ContextEngineer(
211
- budget=ContextBudget(
212
- max_tables=25,
213
- max_columns_per_table=25,
214
- max_total_columns=400,
215
- )
216
- )
217
-
218
  return Pipeline(
219
  detector=detector,
220
  planner=planner,
@@ -223,7 +225,7 @@ def pipeline_from_config(path: str) -> Pipeline:
223
  executor=executor,
224
  verifier=verifier,
225
  repair=repair,
226
- context_engineer=context_engineer,
227
  metrics=_make_metrics(),
228
  )
229
 
@@ -338,5 +340,6 @@ def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipel
338
  executor=executor,
339
  verifier=verifier,
340
  repair=repair,
 
341
  metrics=_make_metrics(),
342
  )
 
77
  return PrometheusMetrics()
78
 
79
 
80
def _default_context_engineer() -> ContextEngineer:
    """Build the ContextEngineer shared by the pipeline factory functions.

    The budget is fixed at 25 tables, 25 columns per table, and 400 columns
    in total.
    """
    default_budget = ContextBudget(
        max_tables=25,
        max_columns_per_table=25,
        max_total_columns=400,
    )
    return ContextEngineer(budget=default_budget)
88
+
89
+
90
  def _tr(
91
  stage: str,
92
  *,
 
217
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
218
  repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
219
 
 
 
 
 
 
 
 
 
220
  return Pipeline(
221
  detector=detector,
222
  planner=planner,
 
225
  executor=executor,
226
  verifier=verifier,
227
  repair=repair,
228
+ context_engineer=_default_context_engineer(),
229
  metrics=_make_metrics(),
230
  )
231
 
 
340
  executor=executor,
341
  verifier=verifier,
342
  repair=repair,
343
+ context_engineer=_default_context_engineer(),
344
  metrics=_make_metrics(),
345
  )