Spaces:
Sleeping
Sleeping
github-actions[bot] commited on
Commit ·
7666573
1
Parent(s): 11975fd
Sync from GitHub main @ 51304ea81c450e4e5ce90de52f10a63dcca33c64
Browse files- adapters/metrics/base.py +2 -1
- adapters/metrics/noop.py +2 -2
- adapters/metrics/prometheus.py +116 -12
- app/services/nl2sql_service.py +10 -3
- nl2sql/metrics.py +7 -108
- nl2sql/pipeline.py +79 -223
- nl2sql/pipeline_factory.py +12 -9
adapters/metrics/base.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from typing import Literal
|
| 5 |
|
|
|
|
| 6 |
RepairOutcome = Literal["attempt", "success", "failed", "skipped"]
|
| 7 |
|
| 8 |
|
|
@@ -11,7 +12,7 @@ class Metrics(ABC):
|
|
| 11 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None: ...
|
| 12 |
|
| 13 |
@abstractmethod
|
| 14 |
-
def inc_pipeline_run(self, *, status:
|
| 15 |
|
| 16 |
@abstractmethod
|
| 17 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None: ...
|
|
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from typing import Literal
|
| 5 |
|
| 6 |
+
PipelineStatus = Literal["ok", "error", "ambiguous"]
|
| 7 |
RepairOutcome = Literal["attempt", "success", "failed", "skipped"]
|
| 8 |
|
| 9 |
|
|
|
|
| 12 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None: ...
|
| 13 |
|
| 14 |
@abstractmethod
|
| 15 |
+
def inc_pipeline_run(self, *, status: PipelineStatus) -> None: ...
|
| 16 |
|
| 17 |
@abstractmethod
|
| 18 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None: ...
|
adapters/metrics/noop.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from adapters.metrics.base import Metrics, RepairOutcome
|
| 4 |
|
| 5 |
|
| 6 |
class NoOpMetrics(Metrics):
|
| 7 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
|
| 8 |
return
|
| 9 |
|
| 10 |
-
def inc_pipeline_run(self, *, status:
|
| 11 |
return
|
| 12 |
|
| 13 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None:
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from adapters.metrics.base import Metrics, PipelineStatus, RepairOutcome
|
| 4 |
|
| 5 |
|
| 6 |
class NoOpMetrics(Metrics):
|
| 7 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
|
| 8 |
return
|
| 9 |
|
| 10 |
+
def inc_pipeline_run(self, *, status: PipelineStatus) -> None:
|
| 11 |
return
|
| 12 |
|
| 13 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None:
|
adapters/metrics/prometheus.py
CHANGED
|
@@ -1,50 +1,154 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from prometheus_client import Counter
|
| 4 |
-
from
|
| 5 |
-
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
stage_calls_total = Counter(
|
| 9 |
"stage_calls_total",
|
| 10 |
-
"
|
| 11 |
["stage", "ok"],
|
|
|
|
| 12 |
)
|
| 13 |
|
| 14 |
stage_errors_total = Counter(
|
| 15 |
"stage_errors_total",
|
| 16 |
-
"
|
| 17 |
["stage", "error_code"],
|
|
|
|
| 18 |
)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
repair_attempts_total = Counter(
|
| 21 |
"repair_attempts_total",
|
| 22 |
-
"
|
| 23 |
["stage", "outcome"],
|
|
|
|
| 24 |
)
|
| 25 |
|
| 26 |
repair_trigger_total = Counter(
|
| 27 |
"repair_trigger_total",
|
| 28 |
-
"
|
| 29 |
["stage", "reason"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
)
|
| 31 |
|
| 32 |
|
| 33 |
class PrometheusMetrics(Metrics):
|
| 34 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
|
| 35 |
-
stage_duration_ms.labels(stage=stage).observe(dt_ms)
|
| 36 |
|
| 37 |
-
def inc_pipeline_run(self, *, status:
|
| 38 |
pipeline_runs_total.labels(status=status).inc()
|
| 39 |
|
| 40 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None:
|
| 41 |
-
stage_calls_total.labels(stage=stage, ok=
|
| 42 |
|
| 43 |
def inc_stage_error(self, *, stage: str, error_code: str) -> None:
|
| 44 |
-
stage_errors_total.labels(stage=stage, error_code=error_code).inc()
|
| 45 |
|
| 46 |
def inc_repair_trigger(self, *, stage: str, reason: str) -> None:
|
| 47 |
-
repair_trigger_total.labels(stage=stage, reason=reason).inc()
|
| 48 |
|
| 49 |
def inc_repair_attempt(self, *, stage: str, outcome: RepairOutcome) -> None:
|
| 50 |
repair_attempts_total.labels(stage=stage, outcome=outcome).inc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from prometheus_client import Counter, Histogram
|
| 4 |
+
from nl2sql.prom import REGISTRY
|
|
|
|
| 5 |
|
| 6 |
+
from adapters.metrics.base import Metrics, PipelineStatus, RepairOutcome
|
| 7 |
+
|
| 8 |
+
# -----------------------------------------------------------------------------
|
| 9 |
+
# Stage-level metrics
|
| 10 |
+
# -----------------------------------------------------------------------------
|
| 11 |
+
stage_duration_ms = Histogram(
|
| 12 |
+
"stage_duration_ms",
|
| 13 |
+
"Duration (ms) of each pipeline stage",
|
| 14 |
+
["stage"],
|
| 15 |
+
buckets=(1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000),
|
| 16 |
+
registry=REGISTRY,
|
| 17 |
+
)
|
| 18 |
|
| 19 |
stage_calls_total = Counter(
|
| 20 |
"stage_calls_total",
|
| 21 |
+
"Count of stage calls labeled by stage and ok",
|
| 22 |
["stage", "ok"],
|
| 23 |
+
registry=REGISTRY,
|
| 24 |
)
|
| 25 |
|
| 26 |
stage_errors_total = Counter(
|
| 27 |
"stage_errors_total",
|
| 28 |
+
"Count of stage errors labeled by stage and error_code",
|
| 29 |
["stage", "error_code"],
|
| 30 |
+
registry=REGISTRY,
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# -----------------------------------------------------------------------------
|
| 34 |
+
# Repair metrics (Day 3 contract)
|
| 35 |
+
# -----------------------------------------------------------------------------
|
| 36 |
repair_attempts_total = Counter(
|
| 37 |
"repair_attempts_total",
|
| 38 |
+
"Count of repair attempts labeled by stage and outcome",
|
| 39 |
["stage", "outcome"],
|
| 40 |
+
registry=REGISTRY,
|
| 41 |
)
|
| 42 |
|
| 43 |
repair_trigger_total = Counter(
|
| 44 |
"repair_trigger_total",
|
| 45 |
+
"Count of repair triggers labeled by stage and reason",
|
| 46 |
["stage", "reason"],
|
| 47 |
+
registry=REGISTRY,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# -----------------------------------------------------------------------------
|
| 51 |
+
# Safety stage metrics (existing / optional)
|
| 52 |
+
# -----------------------------------------------------------------------------
|
| 53 |
+
safety_blocks_total = Counter(
|
| 54 |
+
"safety_blocks_total",
|
| 55 |
+
"Count of blocked SQL queries by safety checks",
|
| 56 |
+
["reason"],
|
| 57 |
+
registry=REGISTRY,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
safety_checks_total = Counter(
|
| 61 |
+
"safety_checks_total",
|
| 62 |
+
"Total SQL queries checked by safety",
|
| 63 |
+
["ok"],
|
| 64 |
+
registry=REGISTRY,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# -----------------------------------------------------------------------------
|
| 68 |
+
# Verifier stage metrics (existing / optional)
|
| 69 |
+
# -----------------------------------------------------------------------------
|
| 70 |
+
verifier_checks_total = Counter(
|
| 71 |
+
"verifier_checks_total",
|
| 72 |
+
"Count of verifier checks (success/failure)",
|
| 73 |
+
["ok"],
|
| 74 |
+
registry=REGISTRY,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
verifier_failures_total = Counter(
|
| 78 |
+
"verifier_failures_total",
|
| 79 |
+
"Count of verifier failures by type",
|
| 80 |
+
["reason"],
|
| 81 |
+
registry=REGISTRY,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# -----------------------------------------------------------------------------
|
| 85 |
+
# Pipeline-level metrics
|
| 86 |
+
# -----------------------------------------------------------------------------
|
| 87 |
+
pipeline_runs_total = Counter(
|
| 88 |
+
"pipeline_runs_total",
|
| 89 |
+
"Total number of full pipeline runs",
|
| 90 |
+
["status"],
|
| 91 |
+
registry=REGISTRY,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# -----------------------------------------------------------------------------
|
| 95 |
+
# Cache metrics (optional)
|
| 96 |
+
# -----------------------------------------------------------------------------
|
| 97 |
+
cache_events_total = Counter(
|
| 98 |
+
"cache_events_total",
|
| 99 |
+
"Cache hit/miss events in the pipeline",
|
| 100 |
+
["hit"],
|
| 101 |
+
registry=REGISTRY,
|
| 102 |
)
|
| 103 |
|
| 104 |
|
| 105 |
class PrometheusMetrics(Metrics):
|
| 106 |
def observe_stage_duration_ms(self, *, stage: str, dt_ms: float) -> None:
|
| 107 |
+
stage_duration_ms.labels(stage=stage).observe(float(dt_ms))
|
| 108 |
|
| 109 |
+
def inc_pipeline_run(self, *, status: PipelineStatus) -> None:
|
| 110 |
pipeline_runs_total.labels(status=status).inc()
|
| 111 |
|
| 112 |
def inc_stage_call(self, *, stage: str, ok: bool) -> None:
|
| 113 |
+
stage_calls_total.labels(stage=stage, ok=("true" if ok else "false")).inc()
|
| 114 |
|
| 115 |
def inc_stage_error(self, *, stage: str, error_code: str) -> None:
|
| 116 |
+
stage_errors_total.labels(stage=stage, error_code=str(error_code)).inc()
|
| 117 |
|
| 118 |
def inc_repair_trigger(self, *, stage: str, reason: str) -> None:
|
| 119 |
+
repair_trigger_total.labels(stage=stage, reason=str(reason)).inc()
|
| 120 |
|
| 121 |
def inc_repair_attempt(self, *, stage: str, outcome: RepairOutcome) -> None:
|
| 122 |
repair_attempts_total.labels(stage=stage, outcome=outcome).inc()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# -----------------------------------------------------------------------------
|
| 126 |
+
# Label priming to keep /metrics stable
|
| 127 |
+
# -----------------------------------------------------------------------------
|
| 128 |
+
for ok in ("true", "false"):
|
| 129 |
+
safety_checks_total.labels(ok=ok).inc(0)
|
| 130 |
+
verifier_checks_total.labels(ok=ok).inc(0)
|
| 131 |
+
|
| 132 |
+
for status in ("ok", "error", "ambiguous"):
|
| 133 |
+
pipeline_runs_total.labels(status=status).inc(0)
|
| 134 |
+
|
| 135 |
+
for hit in ("true", "false"):
|
| 136 |
+
cache_events_total.labels(hit=hit).inc(0)
|
| 137 |
+
|
| 138 |
+
# Prime Day 3 series
|
| 139 |
+
for stage in (
|
| 140 |
+
"detector",
|
| 141 |
+
"planner",
|
| 142 |
+
"generator",
|
| 143 |
+
"safety",
|
| 144 |
+
"executor",
|
| 145 |
+
"verifier",
|
| 146 |
+
"repair",
|
| 147 |
+
):
|
| 148 |
+
for ok in ("true", "false"):
|
| 149 |
+
stage_calls_total.labels(stage=stage, ok=ok).inc(0)
|
| 150 |
+
for outcome in ("attempt", "success", "failed", "skipped"):
|
| 151 |
+
repair_attempts_total.labels(stage=stage, outcome=outcome).inc(0)
|
| 152 |
+
|
| 153 |
+
for reason in ("semantic_failure", "unknown"):
|
| 154 |
+
repair_trigger_total.labels(stage="verifier", reason=reason).inc(0)
|
app/services/nl2sql_service.py
CHANGED
|
@@ -9,6 +9,8 @@ from nl2sql.pipeline import FinalResult
|
|
| 9 |
from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
|
| 10 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 11 |
from adapters.db.postgres_adapter import PostgresAdapter
|
|
|
|
|
|
|
| 12 |
from app import state
|
| 13 |
from app.settings import Settings
|
| 14 |
from app.errors import (
|
|
@@ -65,7 +67,6 @@ class NL2SQLService:
|
|
| 65 |
This is a straight port of the previous sqlite3 logic, but contained
|
| 66 |
inside the service instead of the router.
|
| 67 |
"""
|
| 68 |
-
# Try to locate the underlying .db path from the adapter
|
| 69 |
db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
|
| 70 |
if not db_path:
|
| 71 |
raise RuntimeError(
|
|
@@ -105,8 +106,7 @@ class NL2SQLService:
|
|
| 105 |
|
| 106 |
- If override is provided by the client → use it.
|
| 107 |
- Else, in sqlite mode → introspect the DB.
|
| 108 |
-
- In postgres mode without override → fail fast
|
| 109 |
-
this to a proper HTTP error.
|
| 110 |
"""
|
| 111 |
if override:
|
| 112 |
return override
|
|
@@ -152,6 +152,13 @@ class NL2SQLService:
|
|
| 152 |
f"Failed to build pipeline from {self.settings.pipeline_config_path!r}: {exc}"
|
| 153 |
) from exc
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
try:
|
| 156 |
result = pipeline.run(user_query=query, schema_preview=schema_preview)
|
| 157 |
except AppError:
|
|
|
|
| 9 |
from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
|
| 10 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 11 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 12 |
+
from adapters.metrics.prometheus import PrometheusMetrics
|
| 13 |
+
|
| 14 |
from app import state
|
| 15 |
from app.settings import Settings
|
| 16 |
from app.errors import (
|
|
|
|
| 67 |
This is a straight port of the previous sqlite3 logic, but contained
|
| 68 |
inside the service instead of the router.
|
| 69 |
"""
|
|
|
|
| 70 |
db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
|
| 71 |
if not db_path:
|
| 72 |
raise RuntimeError(
|
|
|
|
| 106 |
|
| 107 |
- If override is provided by the client → use it.
|
| 108 |
- Else, in sqlite mode → introspect the DB.
|
| 109 |
+
- In postgres mode without override → fail fast.
|
|
|
|
| 110 |
"""
|
| 111 |
if override:
|
| 112 |
return override
|
|
|
|
| 152 |
f"Failed to build pipeline from {self.settings.pipeline_config_path!r}: {exc}"
|
| 153 |
) from exc
|
| 154 |
|
| 155 |
+
# Force PrometheusMetrics to avoid silent NoOp wiring via factory defaults.
|
| 156 |
+
if (
|
| 157 |
+
getattr(pipeline, "metrics", None) is None
|
| 158 |
+
or pipeline.metrics.__class__.__name__ == "NoOpMetrics"
|
| 159 |
+
):
|
| 160 |
+
pipeline.metrics = PrometheusMetrics()
|
| 161 |
+
|
| 162 |
try:
|
| 163 |
result = pipeline.run(user_query=query, schema_preview=schema_preview)
|
| 164 |
except AppError:
|
nl2sql/metrics.py
CHANGED
|
@@ -1,111 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
# -----------------------------------------------------------------------------
|
| 8 |
-
stage_duration_ms = Histogram(
|
| 9 |
-
"stage_duration_ms",
|
| 10 |
-
"Duration (ms) of each pipeline stage",
|
| 11 |
-
["stage"], # e.g. detector|planner|generator|safety|verifier
|
| 12 |
-
buckets=(1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000),
|
| 13 |
-
registry=REGISTRY,
|
| 14 |
-
)
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
# Safety stage metrics
|
| 18 |
-
# -----------------------------------------------------------------------------
|
| 19 |
-
safety_blocks_total = Counter(
|
| 20 |
-
"safety_blocks_total",
|
| 21 |
-
"Count of blocked SQL queries by safety checks",
|
| 22 |
-
[
|
| 23 |
-
"reason"
|
| 24 |
-
], # e.g. forbidden_keyword, multiple_statements, non_readonly, explain_not_allowed
|
| 25 |
-
registry=REGISTRY,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
safety_checks_total = Counter(
|
| 29 |
-
"safety_checks_total",
|
| 30 |
-
"Total SQL queries checked by safety",
|
| 31 |
-
["ok"], # "true" or "false"
|
| 32 |
-
registry=REGISTRY,
|
| 33 |
-
)
|
| 34 |
-
|
| 35 |
-
# -----------------------------------------------------------------------------
|
| 36 |
-
# Verifier stage metrics
|
| 37 |
-
# -----------------------------------------------------------------------------
|
| 38 |
-
verifier_checks_total = Counter(
|
| 39 |
-
"verifier_checks_total",
|
| 40 |
-
"Count of verifier checks (success/failure)",
|
| 41 |
-
["ok"], # "true" | "false"
|
| 42 |
-
registry=REGISTRY,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
verifier_failures_total = Counter(
|
| 46 |
-
"verifier_failures_total",
|
| 47 |
-
"Count of verifier failures by type",
|
| 48 |
-
["reason"], # e.g. parse_error, semantic_check_error, adapter_failure
|
| 49 |
-
registry=REGISTRY,
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
# -----------------------------------------------------------------------------
|
| 53 |
-
# Repair stage metrics
|
| 54 |
-
# -----------------------------------------------------------------------------
|
| 55 |
-
repair_attempts_total = Counter(
|
| 56 |
-
"repair_attempts_total",
|
| 57 |
-
"Number of repair loop attempts",
|
| 58 |
-
["outcome"], # attempt | success | failed
|
| 59 |
-
registry=REGISTRY,
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
# -----------------------------------------------------------------------------
|
| 63 |
-
# Pipeline-level metrics
|
| 64 |
-
# -----------------------------------------------------------------------------
|
| 65 |
-
pipeline_runs_total = Counter(
|
| 66 |
-
"pipeline_runs_total",
|
| 67 |
-
"Total number of full pipeline runs",
|
| 68 |
-
["status"], # ok | error | ambiguous
|
| 69 |
-
registry=REGISTRY,
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
# -----------------------------------------------------------------------------
|
| 73 |
-
# Cache metrics (optional)
|
| 74 |
-
# -----------------------------------------------------------------------------
|
| 75 |
-
cache_events_total = Counter(
|
| 76 |
-
"cache_events_total",
|
| 77 |
-
"Cache hit/miss events in the pipeline",
|
| 78 |
-
["hit"], # "true" | "false"
|
| 79 |
-
registry=REGISTRY,
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
# -----------------------------------------------------------------------------
|
| 83 |
-
# Prime all counters with zero to ensure Grafana panels always have data
|
| 84 |
-
# -----------------------------------------------------------------------------
|
| 85 |
-
for reason in (
|
| 86 |
-
"forbidden_keyword",
|
| 87 |
-
"multiple_statements",
|
| 88 |
-
"non_readonly",
|
| 89 |
-
"explain_not_allowed",
|
| 90 |
-
"parse_error",
|
| 91 |
-
"semantic_check_error",
|
| 92 |
-
"adapter_failure",
|
| 93 |
-
"unsafe-sql",
|
| 94 |
-
"malformed-sql",
|
| 95 |
-
"unknown",
|
| 96 |
-
):
|
| 97 |
-
safety_blocks_total.labels(reason=reason).inc(0)
|
| 98 |
-
verifier_failures_total.labels(reason=reason).inc(0)
|
| 99 |
-
|
| 100 |
-
for ok in ("true", "false"):
|
| 101 |
-
safety_checks_total.labels(ok=ok).inc(0)
|
| 102 |
-
verifier_checks_total.labels(ok=ok).inc(0)
|
| 103 |
-
|
| 104 |
-
for outcome in ("attempt", "success", "failed"):
|
| 105 |
-
repair_attempts_total.labels(outcome=outcome).inc(0)
|
| 106 |
-
|
| 107 |
-
for status in ("ok", "error", "ambiguous"):
|
| 108 |
-
pipeline_runs_total.labels(status=status).inc(0)
|
| 109 |
-
|
| 110 |
-
for hit in ("true", "false"):
|
| 111 |
-
cache_events_total.labels(hit=hit).inc(0)
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deprecated shim.
|
| 3 |
|
| 4 |
+
All Prometheus metric definitions and the PrometheusMetrics adapter live in:
|
| 5 |
+
adapters.metrics.prometheus
|
| 6 |
|
| 7 |
+
This module only exists for backward-compatible imports.
|
| 8 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
from adapters.metrics.prometheus import * # noqa: F401,F403
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nl2sql/pipeline.py
CHANGED
|
@@ -36,7 +36,6 @@ class FinalResult:
|
|
| 36 |
traces: List[dict]
|
| 37 |
|
| 38 |
error_code: Optional[ErrorCode] = None
|
| 39 |
-
|
| 40 |
result: Optional[Dict[str, Any]] = None
|
| 41 |
|
| 42 |
|
|
@@ -124,10 +123,36 @@ class Pipeline:
|
|
| 124 |
)
|
| 125 |
return norm
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
@staticmethod
|
| 128 |
def _safe_stage(fn, **kwargs) -> StageResult:
|
| 129 |
try:
|
| 130 |
-
|
|
|
|
| 131 |
if isinstance(r, StageResult):
|
| 132 |
return r
|
| 133 |
return StageResult(ok=True, data=r, trace=None)
|
|
@@ -154,12 +179,10 @@ class Pipeline:
|
|
| 154 |
return any(p in m for p in patterns)
|
| 155 |
|
| 156 |
def _should_repair(self, stage_name: str, r: StageResult) -> tuple[bool, str]:
|
| 157 |
-
|
| 158 |
-
# Never repair safety blocks: policy violations must not be 'fixed' via LLM.
|
| 159 |
if stage_name == "safety":
|
| 160 |
return (False, "blocked_by_safety")
|
| 161 |
|
| 162 |
-
# Never repair cost guardrail blocks: require user constraint (LIMIT) / query refinement.
|
| 163 |
if (
|
| 164 |
stage_name == "executor"
|
| 165 |
and getattr(r, "error_code", None)
|
|
@@ -167,15 +190,13 @@ class Pipeline:
|
|
| 167 |
):
|
| 168 |
return (False, "blocked_by_cost")
|
| 169 |
|
| 170 |
-
# Only repair SQL-related stages.
|
| 171 |
if stage_name not in self.SQL_REPAIR_STAGES:
|
| 172 |
return (False, "not_sql_stage")
|
| 173 |
|
| 174 |
-
# Verifier may return ok=True but verified=False (semantic mismatch). Allow a single repair attempt.
|
| 175 |
if stage_name == "verifier":
|
| 176 |
data = r.data if isinstance(r.data, dict) else {}
|
| 177 |
-
if data.get("verified") is False:
|
| 178 |
-
return (True, "
|
| 179 |
|
| 180 |
errs = r.error or []
|
| 181 |
if any(isinstance(e, str) and self._is_repairable_sql_error(e) for e in errs):
|
|
@@ -190,14 +211,18 @@ class Pipeline:
|
|
| 190 |
*,
|
| 191 |
repair_input_builder,
|
| 192 |
max_attempts: int = 1,
|
| 193 |
-
traces: list,
|
| 194 |
**kwargs,
|
| 195 |
) -> StageResult:
|
| 196 |
"""
|
| 197 |
Run a stage with per-stage repair + full observability integration.
|
| 198 |
SQL-only repair occurs for safety/executor/verifier.
|
| 199 |
-
|
|
|
|
| 200 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
attempt = 0
|
| 202 |
|
| 203 |
while True:
|
|
@@ -228,13 +253,32 @@ class Pipeline:
|
|
| 228 |
}
|
| 229 |
)
|
| 230 |
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
return r
|
| 233 |
|
| 234 |
# stage failed → check repair availability
|
| 235 |
eligible, reason = self._should_repair(stage_name, r)
|
| 236 |
if not eligible:
|
| 237 |
-
self.metrics.inc_repair_attempt(stage=
|
| 238 |
# annotate latest stage trace entry
|
| 239 |
if traces and isinstance(traces[-1], dict):
|
| 240 |
notes = traces[-1].get("notes") or {}
|
|
@@ -254,7 +298,7 @@ class Pipeline:
|
|
| 254 |
|
| 255 |
# --- 3) Run repair (always logged) ---
|
| 256 |
self.metrics.inc_repair_trigger(stage=stage_name, reason=reason)
|
| 257 |
-
self.metrics.inc_repair_attempt(stage=
|
| 258 |
t1 = time.perf_counter()
|
| 259 |
r_fix = self._safe_stage(self.repair.run, **repair_args)
|
| 260 |
dt_fix = (time.perf_counter() - t1) * 1000.0
|
|
@@ -274,7 +318,7 @@ class Pipeline:
|
|
| 274 |
)
|
| 275 |
|
| 276 |
if not r_fix.ok:
|
| 277 |
-
self.metrics.inc_repair_attempt(stage=
|
| 278 |
return r # repair itself failed → stop here
|
| 279 |
|
| 280 |
# --- 4) Only inject SQL if the stage is an SQL-producing stage ---
|
|
@@ -282,16 +326,9 @@ class Pipeline:
|
|
| 282 |
if "sql" in repair_args and "sql" in kwargs:
|
| 283 |
kwargs["sql"] = (r_fix.data or {}).get("sql", kwargs["sql"])
|
| 284 |
|
| 285 |
-
|
| 286 |
-
if stage_name in self.SQL_REPAIR_STAGES:
|
| 287 |
-
self.metrics.inc_repair_attempt(stage="verifier", outcome="success")
|
| 288 |
-
else:
|
| 289 |
-
# log-only mode counts as a success-attempt but not semantic success
|
| 290 |
-
self.metrics.inc_repair_attempt(stage="verifier", outcome="success")
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
# for log-only stages, this simply loops and stage is re-run unchanged
|
| 294 |
-
# (which is correct)
|
| 295 |
|
| 296 |
@staticmethod
|
| 297 |
def _planner_repair_input_builder(stage_result, kwargs):
|
|
@@ -395,6 +432,7 @@ class Pipeline:
|
|
| 395 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 396 |
is_amb = bool(questions)
|
| 397 |
self.metrics.observe_stage_duration_ms(stage="detector", dt_ms=dt)
|
|
|
|
| 398 |
traces.append(
|
| 399 |
self._mk_trace(
|
| 400 |
stage="detector",
|
|
@@ -405,6 +443,7 @@ class Pipeline:
|
|
| 405 |
)
|
| 406 |
if questions:
|
| 407 |
self.metrics.inc_pipeline_run(status="ambiguous")
|
|
|
|
| 408 |
return FinalResult(
|
| 409 |
ok=True,
|
| 410 |
ambiguous=True,
|
|
@@ -418,8 +457,6 @@ class Pipeline:
|
|
| 418 |
)
|
| 419 |
|
| 420 |
# --- 2) planner ---
|
| 421 |
-
t0 = time.perf_counter()
|
| 422 |
-
|
| 423 |
planner_kwargs: Dict[str, Any] = {
|
| 424 |
"user_query": user_query,
|
| 425 |
"schema_preview": schema_for_llm,
|
|
@@ -454,14 +491,13 @@ class Pipeline:
|
|
| 454 |
)
|
| 455 |
|
| 456 |
# --- 3) generator ---
|
| 457 |
-
t0 = time.perf_counter()
|
| 458 |
-
|
| 459 |
gen_kwargs: Dict[str, Any] = {
|
| 460 |
"user_query": user_query,
|
| 461 |
"schema_preview": schema_for_llm,
|
| 462 |
"plan_text": (r_plan.data or {}).get("plan"),
|
| 463 |
"clarify_answers": clarify_answers,
|
| 464 |
"traces": traces,
|
|
|
|
| 465 |
}
|
| 466 |
try:
|
| 467 |
if "schema_pack" in inspect.signature(self.generator.run).parameters:
|
|
@@ -494,33 +530,6 @@ class Pipeline:
|
|
| 494 |
sql = (r_gen.data or {}).get("sql")
|
| 495 |
rationale = (r_gen.data or {}).get("rationale")
|
| 496 |
|
| 497 |
-
# --- schema drift signal (planner vs generator table usage)
|
| 498 |
-
planner_used_tables = (
|
| 499 |
-
(r_plan.data or {}).get("used_tables")
|
| 500 |
-
or (r_plan.data or {}).get("tables")
|
| 501 |
-
or []
|
| 502 |
-
)
|
| 503 |
-
generator_used_tables = (
|
| 504 |
-
(r_gen.data or {}).get("used_tables")
|
| 505 |
-
or (r_gen.data or {}).get("tables")
|
| 506 |
-
or []
|
| 507 |
-
)
|
| 508 |
-
planner_set = set(planner_used_tables)
|
| 509 |
-
generator_set = set(generator_used_tables)
|
| 510 |
-
schema_drift = bool(generator_set - planner_set)
|
| 511 |
-
traces.append(
|
| 512 |
-
self._mk_trace(
|
| 513 |
-
stage="schema_drift_check",
|
| 514 |
-
duration_ms=0.0,
|
| 515 |
-
summary="compare planner vs generator table usage",
|
| 516 |
-
notes={
|
| 517 |
-
"planner_used_tables": sorted(planner_set),
|
| 518 |
-
"generator_used_tables": sorted(generator_set),
|
| 519 |
-
"schema_drift": schema_drift,
|
| 520 |
-
},
|
| 521 |
-
)
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
# Guard: empty SQL
|
| 525 |
if not sql or not str(sql).strip():
|
| 526 |
self.metrics.inc_pipeline_run(status="error")
|
|
@@ -541,13 +550,13 @@ class Pipeline:
|
|
| 541 |
)
|
| 542 |
|
| 543 |
# --- 4) safety ---
|
| 544 |
-
t0 = time.perf_counter()
|
| 545 |
r_safe = self._run_with_repair(
|
| 546 |
"safety",
|
| 547 |
self.safety.run,
|
| 548 |
repair_input_builder=self._sql_repair_input_builder,
|
| 549 |
max_attempts=1,
|
| 550 |
sql=sql,
|
|
|
|
| 551 |
traces=traces,
|
| 552 |
)
|
| 553 |
if not r_safe.ok:
|
|
@@ -569,201 +578,49 @@ class Pipeline:
|
|
| 569 |
sql = (r_safe.data or {}).get("sql", sql)
|
| 570 |
|
| 571 |
# --- 5) executor ---
|
| 572 |
-
t0 = time.perf_counter()
|
| 573 |
r_exec = self._run_with_repair(
|
| 574 |
"executor",
|
| 575 |
self.executor.run,
|
| 576 |
repair_input_builder=self._sql_repair_input_builder,
|
| 577 |
max_attempts=1,
|
| 578 |
sql=sql,
|
|
|
|
| 579 |
traces=traces,
|
| 580 |
)
|
| 581 |
if not r_exec.ok and r_exec.error:
|
| 582 |
-
details.extend(
|
| 583 |
-
r_exec.error
|
| 584 |
-
) # soft: keep for repair/verifier context_engineering
|
| 585 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 586 |
exec_result = dict(r_exec.data)
|
| 587 |
|
| 588 |
# --- 6) verifier (only if execution succeeded) ---
|
| 589 |
-
r_ver: StageResult | None = None
|
| 590 |
-
verifier_failed = False
|
| 591 |
verified = False
|
| 592 |
-
|
| 593 |
if r_exec.ok:
|
| 594 |
-
|
| 595 |
-
|
| 596 |
self._call_verifier,
|
|
|
|
|
|
|
| 597 |
sql=sql,
|
| 598 |
exec_result=(r_exec.data or {}),
|
|
|
|
| 599 |
traces=traces,
|
| 600 |
)
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
# Attach a trace entry if verifier didn't provide one
|
| 605 |
-
if getattr(r_ver, "trace", None):
|
| 606 |
-
traces.append(r_ver.trace.__dict__)
|
| 607 |
-
else:
|
| 608 |
-
traces.append(
|
| 609 |
-
{
|
| 610 |
-
"stage": "verifier",
|
| 611 |
-
"duration_ms": dt,
|
| 612 |
-
"summary": "ok" if r_ver.ok else "failed",
|
| 613 |
-
"notes": {},
|
| 614 |
-
}
|
| 615 |
-
)
|
| 616 |
-
|
| 617 |
-
verifier_failed = not bool(r_ver.ok)
|
| 618 |
-
data0 = (
|
| 619 |
-
r_ver.data if (r_ver.ok and isinstance(r_ver.data, dict)) else {}
|
| 620 |
-
)
|
| 621 |
-
verified = bool(data0.get("verified") is True)
|
| 622 |
-
|
| 623 |
-
# --- 7) semantic repair gating (verifier fail OR not_verified) ---
|
| 624 |
-
# If verifier failed or said "not verified", attempt a single repair and then:
|
| 625 |
-
# safety → executor → verifier again.
|
| 626 |
-
if r_exec.ok and (verifier_failed or not verified):
|
| 627 |
-
eligible, _reason = self._should_repair(
|
| 628 |
-
"verifier",
|
| 629 |
-
StageResult(
|
| 630 |
-
ok=True, data={"verified": False}, error=None, trace=None
|
| 631 |
-
),
|
| 632 |
-
)
|
| 633 |
-
if eligible:
|
| 634 |
-
self.metrics.inc_repair_trigger(stage="verifier", reason=_reason)
|
| 635 |
-
# Prefer the real verifier message if present (tests expect this).
|
| 636 |
-
err_list = (r_ver.error if (r_ver and r_ver.error) else None) or []
|
| 637 |
-
error_msg = (
|
| 638 |
-
"; ".join([e for e in err_list if isinstance(e, str)])
|
| 639 |
-
or "not_verified"
|
| 640 |
-
)
|
| 641 |
-
|
| 642 |
-
rep_kwargs_all: Dict[str, Any] = {
|
| 643 |
-
"user_query": user_query,
|
| 644 |
-
"schema_preview": schema_for_llm,
|
| 645 |
-
"sql": sql,
|
| 646 |
-
"error_msg": error_msg,
|
| 647 |
-
"error_message": error_msg,
|
| 648 |
-
"traces": traces,
|
| 649 |
-
"constraints": constraints,
|
| 650 |
-
}
|
| 651 |
-
try:
|
| 652 |
-
params = inspect.signature(self.repair.run).parameters
|
| 653 |
-
rep_kwargs = {
|
| 654 |
-
k: v for k, v in rep_kwargs_all.items() if k in params
|
| 655 |
-
}
|
| 656 |
-
except (TypeError, ValueError):
|
| 657 |
-
rep_kwargs = {
|
| 658 |
-
"sql": sql,
|
| 659 |
-
"error_msg": error_msg,
|
| 660 |
-
"schema_preview": schema_for_llm,
|
| 661 |
-
}
|
| 662 |
-
|
| 663 |
-
self.metrics.inc_repair_attempt(stage="verifier", outcome="attempt")
|
| 664 |
-
r_rep = self.repair.run(**rep_kwargs)
|
| 665 |
-
|
| 666 |
-
new_sql = (
|
| 667 |
-
r_rep.data.get("sql")
|
| 668 |
-
if (r_rep.ok and isinstance(r_rep.data, dict))
|
| 669 |
-
else None
|
| 670 |
-
)
|
| 671 |
-
|
| 672 |
-
if new_sql and str(new_sql).strip():
|
| 673 |
-
sql = str(new_sql)
|
| 674 |
-
|
| 675 |
-
# Re-run safety → executor → verifier using the standard path.
|
| 676 |
-
r_safe2 = self._run_with_repair(
|
| 677 |
-
"safety",
|
| 678 |
-
self.safety.run,
|
| 679 |
-
repair_input_builder=self._sql_repair_input_builder,
|
| 680 |
-
max_attempts=1,
|
| 681 |
-
sql=sql,
|
| 682 |
-
traces=traces,
|
| 683 |
-
)
|
| 684 |
-
if not r_safe2.ok:
|
| 685 |
-
self.metrics.inc_pipeline_run(status="error")
|
| 686 |
-
return FinalResult(
|
| 687 |
-
ok=False,
|
| 688 |
-
ambiguous=False,
|
| 689 |
-
error=True,
|
| 690 |
-
details=r_safe2.error,
|
| 691 |
-
error_code=r_safe2.error_code,
|
| 692 |
-
questions=None,
|
| 693 |
-
sql=sql,
|
| 694 |
-
rationale=rationale,
|
| 695 |
-
verified=None,
|
| 696 |
-
traces=self._normalize_traces(traces),
|
| 697 |
-
)
|
| 698 |
-
|
| 699 |
-
sql = (r_safe2.data or {}).get("sql", sql)
|
| 700 |
-
|
| 701 |
-
r_exec2 = self._run_with_repair(
|
| 702 |
-
"executor",
|
| 703 |
-
self.executor.run,
|
| 704 |
-
repair_input_builder=self._sql_repair_input_builder,
|
| 705 |
-
max_attempts=1,
|
| 706 |
-
sql=sql,
|
| 707 |
-
traces=traces,
|
| 708 |
-
)
|
| 709 |
-
if r_exec2.ok and isinstance(r_exec2.data, dict):
|
| 710 |
-
exec_result = dict(r_exec2.data)
|
| 711 |
-
|
| 712 |
-
r_ver2 = self._run_with_repair(
|
| 713 |
-
"verifier",
|
| 714 |
-
self._call_verifier,
|
| 715 |
-
repair_input_builder=self._sql_repair_input_builder,
|
| 716 |
-
max_attempts=1,
|
| 717 |
-
sql=sql,
|
| 718 |
-
exec_result=(r_exec2.data or {}),
|
| 719 |
-
traces=traces,
|
| 720 |
-
)
|
| 721 |
-
data2 = r_ver2.data if isinstance(r_ver2.data, dict) else {}
|
| 722 |
-
verified = bool(data2.get("verified") is True)
|
| 723 |
-
|
| 724 |
-
if verified:
|
| 725 |
-
self.metrics.inc_repair_attempt(
|
| 726 |
-
stage="verifier", outcome="success"
|
| 727 |
-
)
|
| 728 |
-
else:
|
| 729 |
-
self.metrics.inc_repair_attempt(
|
| 730 |
-
stage="verifier", outcome="failed"
|
| 731 |
-
)
|
| 732 |
-
else:
|
| 733 |
-
self.metrics.inc_repair_attempt(stage="verifier", outcome="skipped")
|
| 734 |
-
|
| 735 |
-
# --- 8) optional soft auto-verify (executor success, no details) --- (executor success, no details) ---
|
| 736 |
-
if (verified is None or not verified) and not details:
|
| 737 |
-
any_exec_ok = any(
|
| 738 |
-
t.get("stage") == "executor"
|
| 739 |
-
and (t.get("notes") or {}).get("row_count")
|
| 740 |
-
for t in traces
|
| 741 |
-
)
|
| 742 |
-
if any_exec_ok:
|
| 743 |
-
traces.append(
|
| 744 |
-
self._mk_trace(
|
| 745 |
-
stage="pipeline",
|
| 746 |
-
duration_ms=0.0,
|
| 747 |
-
summary="auto-verified",
|
| 748 |
-
notes={"reason": "executor succeeded, verifier silent"},
|
| 749 |
-
)
|
| 750 |
-
)
|
| 751 |
-
verified = True
|
| 752 |
|
| 753 |
# --- 9) finalize ---
|
| 754 |
has_errors = bool(details)
|
| 755 |
need_ver = bool(self.require_verification)
|
| 756 |
|
| 757 |
-
# base success condition
|
| 758 |
final_ok_by_verifier = bool(verified)
|
| 759 |
-
|
| 760 |
-
bool(sql)
|
|
|
|
|
|
|
| 761 |
)
|
| 762 |
-
ok = base_ok
|
| 763 |
err = (not ok) and has_errors
|
| 764 |
|
| 765 |
-
#
|
| 766 |
-
# if verification is NOT required and pipeline is ok, report verified=True
|
| 767 |
if not need_ver and ok and not final_ok_by_verifier:
|
| 768 |
verified_final = True
|
| 769 |
else:
|
|
@@ -799,7 +656,6 @@ class Pipeline:
|
|
| 799 |
|
| 800 |
except Exception:
|
| 801 |
self.metrics.inc_pipeline_run(status="error")
|
| 802 |
-
# bubble up to make failures visible in tests and logs
|
| 803 |
raise
|
| 804 |
|
| 805 |
finally:
|
|
|
|
| 36 |
traces: List[dict]
|
| 37 |
|
| 38 |
error_code: Optional[ErrorCode] = None
|
|
|
|
| 39 |
result: Optional[Dict[str, Any]] = None
|
| 40 |
|
| 41 |
|
|
|
|
| 123 |
)
|
| 124 |
return norm
|
| 125 |
|
| 126 |
+
@staticmethod
|
| 127 |
+
def _accepts_kwargs(fn) -> bool:
|
| 128 |
+
try:
|
| 129 |
+
sig = inspect.signature(fn)
|
| 130 |
+
except (TypeError, ValueError):
|
| 131 |
+
return True
|
| 132 |
+
return any(
|
| 133 |
+
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
@staticmethod
|
| 137 |
+
def _filter_kwargs(fn, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
| 138 |
+
"""
|
| 139 |
+
Make stage calls backward-compatible with older stubs/fakes that don't accept
|
| 140 |
+
extra kwargs like `traces`, `schema_preview`, etc.
|
| 141 |
+
"""
|
| 142 |
+
if Pipeline._accepts_kwargs(fn):
|
| 143 |
+
return kwargs
|
| 144 |
+
try:
|
| 145 |
+
sig = inspect.signature(fn)
|
| 146 |
+
allowed = set(sig.parameters.keys())
|
| 147 |
+
return {k: v for k, v in kwargs.items() if k in allowed}
|
| 148 |
+
except (TypeError, ValueError):
|
| 149 |
+
return kwargs
|
| 150 |
+
|
| 151 |
@staticmethod
|
| 152 |
def _safe_stage(fn, **kwargs) -> StageResult:
|
| 153 |
try:
|
| 154 |
+
call_kwargs = Pipeline._filter_kwargs(fn, kwargs)
|
| 155 |
+
r = fn(**call_kwargs)
|
| 156 |
if isinstance(r, StageResult):
|
| 157 |
return r
|
| 158 |
return StageResult(ok=True, data=r, trace=None)
|
|
|
|
| 179 |
return any(p in m for p in patterns)
|
| 180 |
|
| 181 |
def _should_repair(self, stage_name: str, r: StageResult) -> tuple[bool, str]:
|
| 182 |
+
# Never repair safety blocks
|
|
|
|
| 183 |
if stage_name == "safety":
|
| 184 |
return (False, "blocked_by_safety")
|
| 185 |
|
|
|
|
| 186 |
if (
|
| 187 |
stage_name == "executor"
|
| 188 |
and getattr(r, "error_code", None)
|
|
|
|
| 190 |
):
|
| 191 |
return (False, "blocked_by_cost")
|
| 192 |
|
|
|
|
| 193 |
if stage_name not in self.SQL_REPAIR_STAGES:
|
| 194 |
return (False, "not_sql_stage")
|
| 195 |
|
|
|
|
| 196 |
if stage_name == "verifier":
|
| 197 |
data = r.data if isinstance(r.data, dict) else {}
|
| 198 |
+
if data.get("verified") is False or (r.ok is False):
|
| 199 |
+
return (True, "semantic_failure")
|
| 200 |
|
| 201 |
errs = r.error or []
|
| 202 |
if any(isinstance(e, str) and self._is_repairable_sql_error(e) for e in errs):
|
|
|
|
| 211 |
*,
|
| 212 |
repair_input_builder,
|
| 213 |
max_attempts: int = 1,
|
|
|
|
| 214 |
**kwargs,
|
| 215 |
) -> StageResult:
|
| 216 |
"""
|
| 217 |
Run a stage with per-stage repair + full observability integration.
|
| 218 |
SQL-only repair occurs for safety/executor/verifier.
|
| 219 |
+
|
| 220 |
+
IMPORTANT: `traces` must be provided in kwargs as a list.
|
| 221 |
"""
|
| 222 |
+
traces = kwargs.get("traces")
|
| 223 |
+
if traces is None or not isinstance(traces, list):
|
| 224 |
+
raise TypeError("_run_with_repair requires `traces` (list) in kwargs")
|
| 225 |
+
|
| 226 |
attempt = 0
|
| 227 |
|
| 228 |
while True:
|
|
|
|
| 253 |
}
|
| 254 |
)
|
| 255 |
|
| 256 |
+
# --- 1.5) Verifier semantic failure is repairable even if ok=True ---
|
| 257 |
+
if r.ok and stage_name == "verifier":
|
| 258 |
+
data0 = r.data if isinstance(r.data, dict) else {}
|
| 259 |
+
if data0.get("verified") is True:
|
| 260 |
+
return r
|
| 261 |
+
# ok=True but verified=False → treat as eligible for repair path
|
| 262 |
+
eligible, reason = self._should_repair(stage_name, r)
|
| 263 |
+
if not eligible:
|
| 264 |
+
self.metrics.inc_repair_attempt(stage=stage_name, outcome="skipped")
|
| 265 |
+
if traces and isinstance(traces[-1], dict):
|
| 266 |
+
notes = traces[-1].get("notes") or {}
|
| 267 |
+
if not isinstance(notes, dict):
|
| 268 |
+
notes = {}
|
| 269 |
+
notes["repair_eligible"] = False
|
| 270 |
+
notes["repair_skip_reason"] = reason
|
| 271 |
+
traces[-1]["notes"] = notes
|
| 272 |
+
return r
|
| 273 |
+
# fallthrough into repair branch below
|
| 274 |
+
|
| 275 |
+
elif r.ok:
|
| 276 |
return r
|
| 277 |
|
| 278 |
# stage failed → check repair availability
|
| 279 |
eligible, reason = self._should_repair(stage_name, r)
|
| 280 |
if not eligible:
|
| 281 |
+
self.metrics.inc_repair_attempt(stage=stage_name, outcome="skipped")
|
| 282 |
# annotate latest stage trace entry
|
| 283 |
if traces and isinstance(traces[-1], dict):
|
| 284 |
notes = traces[-1].get("notes") or {}
|
|
|
|
| 298 |
|
| 299 |
# --- 3) Run repair (always logged) ---
|
| 300 |
self.metrics.inc_repair_trigger(stage=stage_name, reason=reason)
|
| 301 |
+
self.metrics.inc_repair_attempt(stage=stage_name, outcome="attempt")
|
| 302 |
t1 = time.perf_counter()
|
| 303 |
r_fix = self._safe_stage(self.repair.run, **repair_args)
|
| 304 |
dt_fix = (time.perf_counter() - t1) * 1000.0
|
|
|
|
| 318 |
)
|
| 319 |
|
| 320 |
if not r_fix.ok:
|
| 321 |
+
self.metrics.inc_repair_attempt(stage=stage_name, outcome="failed")
|
| 322 |
return r # repair itself failed → stop here
|
| 323 |
|
| 324 |
# --- 4) Only inject SQL if the stage is an SQL-producing stage ---
|
|
|
|
| 326 |
if "sql" in repair_args and "sql" in kwargs:
|
| 327 |
kwargs["sql"] = (r_fix.data or {}).get("sql", kwargs["sql"])
|
| 328 |
|
| 329 |
+
self.metrics.inc_repair_attempt(stage=stage_name, outcome="success")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
+
# re-run stage with updated kwargs
|
|
|
|
|
|
|
| 332 |
|
| 333 |
@staticmethod
|
| 334 |
def _planner_repair_input_builder(stage_result, kwargs):
|
|
|
|
| 432 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 433 |
is_amb = bool(questions)
|
| 434 |
self.metrics.observe_stage_duration_ms(stage="detector", dt_ms=dt)
|
| 435 |
+
self.metrics.inc_stage_call(stage="detector", ok=True)
|
| 436 |
traces.append(
|
| 437 |
self._mk_trace(
|
| 438 |
stage="detector",
|
|
|
|
| 443 |
)
|
| 444 |
if questions:
|
| 445 |
self.metrics.inc_pipeline_run(status="ambiguous")
|
| 446 |
+
self.metrics.inc_stage_call(stage="detector", ok=False)
|
| 447 |
return FinalResult(
|
| 448 |
ok=True,
|
| 449 |
ambiguous=True,
|
|
|
|
| 457 |
)
|
| 458 |
|
| 459 |
# --- 2) planner ---
|
|
|
|
|
|
|
| 460 |
planner_kwargs: Dict[str, Any] = {
|
| 461 |
"user_query": user_query,
|
| 462 |
"schema_preview": schema_for_llm,
|
|
|
|
| 491 |
)
|
| 492 |
|
| 493 |
# --- 3) generator ---
|
|
|
|
|
|
|
| 494 |
gen_kwargs: Dict[str, Any] = {
|
| 495 |
"user_query": user_query,
|
| 496 |
"schema_preview": schema_for_llm,
|
| 497 |
"plan_text": (r_plan.data or {}).get("plan"),
|
| 498 |
"clarify_answers": clarify_answers,
|
| 499 |
"traces": traces,
|
| 500 |
+
"constraints": constraints,
|
| 501 |
}
|
| 502 |
try:
|
| 503 |
if "schema_pack" in inspect.signature(self.generator.run).parameters:
|
|
|
|
| 530 |
sql = (r_gen.data or {}).get("sql")
|
| 531 |
rationale = (r_gen.data or {}).get("rationale")
|
| 532 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
# Guard: empty SQL
|
| 534 |
if not sql or not str(sql).strip():
|
| 535 |
self.metrics.inc_pipeline_run(status="error")
|
|
|
|
| 550 |
)
|
| 551 |
|
| 552 |
# --- 4) safety ---
|
|
|
|
| 553 |
r_safe = self._run_with_repair(
|
| 554 |
"safety",
|
| 555 |
self.safety.run,
|
| 556 |
repair_input_builder=self._sql_repair_input_builder,
|
| 557 |
max_attempts=1,
|
| 558 |
sql=sql,
|
| 559 |
+
schema_preview=schema_for_llm,
|
| 560 |
traces=traces,
|
| 561 |
)
|
| 562 |
if not r_safe.ok:
|
|
|
|
| 578 |
sql = (r_safe.data or {}).get("sql", sql)
|
| 579 |
|
| 580 |
# --- 5) executor ---
|
|
|
|
| 581 |
r_exec = self._run_with_repair(
|
| 582 |
"executor",
|
| 583 |
self.executor.run,
|
| 584 |
repair_input_builder=self._sql_repair_input_builder,
|
| 585 |
max_attempts=1,
|
| 586 |
sql=sql,
|
| 587 |
+
schema_preview=schema_for_llm,
|
| 588 |
traces=traces,
|
| 589 |
)
|
| 590 |
if not r_exec.ok and r_exec.error:
|
| 591 |
+
details.extend(r_exec.error)
|
|
|
|
|
|
|
| 592 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 593 |
exec_result = dict(r_exec.data)
|
| 594 |
|
| 595 |
# --- 6) verifier (only if execution succeeded) ---
|
|
|
|
|
|
|
| 596 |
verified = False
|
|
|
|
| 597 |
if r_exec.ok:
|
| 598 |
+
r_ver = self._run_with_repair(
|
| 599 |
+
"verifier",
|
| 600 |
self._call_verifier,
|
| 601 |
+
repair_input_builder=self._sql_repair_input_builder,
|
| 602 |
+
max_attempts=1,
|
| 603 |
sql=sql,
|
| 604 |
exec_result=(r_exec.data or {}),
|
| 605 |
+
schema_preview=schema_for_llm,
|
| 606 |
traces=traces,
|
| 607 |
)
|
| 608 |
+
data_v = r_ver.data if isinstance(r_ver.data, dict) else {}
|
| 609 |
+
verified = bool(data_v.get("verified") is True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
# --- 9) finalize ---
|
| 612 |
has_errors = bool(details)
|
| 613 |
need_ver = bool(self.require_verification)
|
| 614 |
|
|
|
|
| 615 |
final_ok_by_verifier = bool(verified)
|
| 616 |
+
ok = (
|
| 617 |
+
bool(sql)
|
| 618 |
+
and (not has_errors)
|
| 619 |
+
and (final_ok_by_verifier or not need_ver)
|
| 620 |
)
|
|
|
|
| 621 |
err = (not ok) and has_errors
|
| 622 |
|
| 623 |
+
# If verification is NOT required and pipeline is ok, report verified=True
|
|
|
|
| 624 |
if not need_ver and ok and not final_ok_by_verifier:
|
| 625 |
verified_final = True
|
| 626 |
else:
|
|
|
|
| 656 |
|
| 657 |
except Exception:
|
| 658 |
self.metrics.inc_pipeline_run(status="error")
|
|
|
|
| 659 |
raise
|
| 660 |
|
| 661 |
finally:
|
nl2sql/pipeline_factory.py
CHANGED
|
@@ -77,6 +77,16 @@ def _make_metrics() -> Metrics:
|
|
| 77 |
return PrometheusMetrics()
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def _tr(
|
| 81 |
stage: str,
|
| 82 |
*,
|
|
@@ -207,14 +217,6 @@ def pipeline_from_config(path: str) -> Pipeline:
|
|
| 207 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 208 |
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
| 209 |
|
| 210 |
-
context_engineer = ContextEngineer(
|
| 211 |
-
budget=ContextBudget(
|
| 212 |
-
max_tables=25,
|
| 213 |
-
max_columns_per_table=25,
|
| 214 |
-
max_total_columns=400,
|
| 215 |
-
)
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
return Pipeline(
|
| 219 |
detector=detector,
|
| 220 |
planner=planner,
|
|
@@ -223,7 +225,7 @@ def pipeline_from_config(path: str) -> Pipeline:
|
|
| 223 |
executor=executor,
|
| 224 |
verifier=verifier,
|
| 225 |
repair=repair,
|
| 226 |
-
context_engineer=
|
| 227 |
metrics=_make_metrics(),
|
| 228 |
)
|
| 229 |
|
|
@@ -338,5 +340,6 @@ def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipel
|
|
| 338 |
executor=executor,
|
| 339 |
verifier=verifier,
|
| 340 |
repair=repair,
|
|
|
|
| 341 |
metrics=_make_metrics(),
|
| 342 |
)
|
|
|
|
| 77 |
return PrometheusMetrics()
|
| 78 |
|
| 79 |
|
| 80 |
+
def _default_context_engineer() -> ContextEngineer:
|
| 81 |
+
return ContextEngineer(
|
| 82 |
+
budget=ContextBudget(
|
| 83 |
+
max_tables=25,
|
| 84 |
+
max_columns_per_table=25,
|
| 85 |
+
max_total_columns=400,
|
| 86 |
+
)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
def _tr(
|
| 91 |
stage: str,
|
| 92 |
*,
|
|
|
|
| 217 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 218 |
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
return Pipeline(
|
| 221 |
detector=detector,
|
| 222 |
planner=planner,
|
|
|
|
| 225 |
executor=executor,
|
| 226 |
verifier=verifier,
|
| 227 |
repair=repair,
|
| 228 |
+
context_engineer=_default_context_engineer(),
|
| 229 |
metrics=_make_metrics(),
|
| 230 |
)
|
| 231 |
|
|
|
|
| 340 |
executor=executor,
|
| 341 |
verifier=verifier,
|
| 342 |
repair=repair,
|
| 343 |
+
context_engineer=_default_context_engineer(),
|
| 344 |
metrics=_make_metrics(),
|
| 345 |
)
|