github-actions[bot] commited on
Commit
c758204
·
1 Parent(s): b1d64ab

Sync from GitHub main

Browse files
app/main.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import time
3
 
4
  from fastapi import FastAPI, Request, Response, HTTPException
5
- from fastapi.responses import PlainTextResponse, RedirectResponse
6
  from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
7
 
8
  from nl2sql.prom import REGISTRY
@@ -38,6 +38,15 @@ application.include_router(nl2sql.router, prefix="/api/v1")
38
  if os.getenv("APP_ENV", "dev").lower() == "dev":
39
  application.include_router(dev.router, prefix="/api/v1")
40
 
 
 
 
 
 
 
 
 
 
41
  # ----------------------------------------------------------------------------
42
  # Prometheus Metrics Middleware
43
  # ----------------------------------------------------------------------------
 
2
  import time
3
 
4
  from fastapi import FastAPI, Request, Response, HTTPException
5
+ from fastapi.responses import PlainTextResponse, RedirectResponse, JSONResponse
6
  from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
7
 
8
  from nl2sql.prom import REGISTRY
 
38
  if os.getenv("APP_ENV", "dev").lower() == "dev":
39
  application.include_router(dev.router, prefix="/api/v1")
40
 
41
+
42
+ @application.exception_handler(HTTPException)
43
+ async def http_exception_to_error_contract(request: Request, exc: HTTPException):
44
+ if isinstance(exc.detail, dict) and "error" in exc.detail:
45
+ return JSONResponse(status_code=exc.status_code, content=exc.detail)
46
+
47
+ return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
48
+
49
+
50
  # ----------------------------------------------------------------------------
51
  # Prometheus Metrics Middleware
52
  # ----------------------------------------------------------------------------
app/routers/nl2sql.py CHANGED
@@ -23,11 +23,10 @@ from app.services.nl2sql_service import NL2SQLService
23
  from app.settings import get_settings
24
  from app.errors import (
25
  AppError,
26
- BadRequestError,
27
- SafetyViolationError,
28
- DependencyError,
29
  PipelineRunError,
30
  )
 
 
31
 
32
  logger = logging.getLogger(__name__)
33
  settings = get_settings()
@@ -358,65 +357,34 @@ def nl2sql_handler(
358
  qs = result.questions or []
359
  return ClarifyResponse(ambiguous=True, questions=qs)
360
 
361
- # ---- error path: map pipeline failures to stable HTTP+JSON error contract ----
362
  if (not result.ok) or result.error:
363
  logger.debug(
364
  "Pipeline reported failure",
365
- extra={"ok": result.ok, "error": result.error, "details": result.details},
 
 
 
 
 
366
  )
367
 
368
- details = list(result.details or [])
369
- traces = list(result.traces or [])
370
- last_stage = str(traces[-1].get("stage", "unknown")) if traces else "unknown"
371
- details_l = " ".join(d.lower() for d in details)
372
-
373
- # 1) Safety violations → 422
374
- if last_stage == "safety":
375
- raise SafetyViolationError(
376
- message="Rejected by safety checks.",
377
- details=details or None,
378
- extra={"stage": last_stage},
379
- )
380
-
381
- # 2) Retryable dependency failures → 503
382
- retry_hints = (
383
- "timeout",
384
- "timed out",
385
- "rate limit",
386
- "429",
387
- "too many requests",
388
- "locked",
389
- "busy",
390
- "unavailable",
391
- "connection",
392
- )
393
- if any(h in details_l for h in retry_hints):
394
- raise DependencyError(
395
- message="Temporary dependency failure. Please retry.",
396
- details=details or None,
397
- extra={"stage": last_stage},
398
- )
399
-
400
- # 3) User-fixable parse/syntax-ish errors → 400
401
- user_hints = (
402
- "parse_error",
403
- "non-select",
404
- "explain not allowed",
405
- "multiple statements",
406
- "forbidden",
407
- )
408
- if any(h in details_l for h in user_hints):
409
- raise BadRequestError(
410
- message="Request could not be processed.",
411
- details=details or None,
412
- extra={"stage": last_stage},
413
- )
414
-
415
- # 4) Default → 500
416
- raise PipelineRunError(
417
- message="Pipeline failed unexpectedly.",
418
- details=details or None,
419
- extra={"stage": last_stage},
420
  )
421
 
422
  # ---- success path → 200 (normalize traces and executor result) ----
 
23
  from app.settings import get_settings
24
  from app.errors import (
25
  AppError,
 
 
 
26
  PipelineRunError,
27
  )
28
+ from nl2sql.errors.mapper import map_error
29
+ from nl2sql.errors.codes import ErrorCode
30
 
31
  logger = logging.getLogger(__name__)
32
  settings = get_settings()
 
357
  qs = result.questions or []
358
  return ClarifyResponse(ambiguous=True, questions=qs)
359
 
360
+ # ---- error path: contract-based mapping (Phase 3) ----
361
  if (not result.ok) or result.error:
362
  logger.debug(
363
  "Pipeline reported failure",
364
+ extra={
365
+ "ok": result.ok,
366
+ "error": result.error,
367
+ "error_code": getattr(result, "error_code", None),
368
+ "details": result.details,
369
+ },
370
  )
371
 
372
+ # 1) Normalize code (never string-match here)
373
+ code = result.error_code or ErrorCode.PIPELINE_CRASH
374
+
375
+ # 2) Single source of truth for HTTP semantics
376
+ status, retryable = map_error(code)
377
+
378
+ # 3) Stable error payload for UI/clients
379
+ raise HTTPException(
380
+ status_code=status,
381
+ detail={
382
+ "error": {
383
+ "code": code.value,
384
+ "retryable": retryable,
385
+ "details": list(result.details or []),
386
+ }
387
+ },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  )
389
 
390
  # ---- success path → 200 (normalize traces and executor result) ----
nl2sql/errors/codes.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
+ class ErrorCode(str, Enum):
5
+ # --- Safety ---
6
+ SAFETY_NON_SELECT = "SAFETY_NON_SELECT"
7
+ SAFETY_MULTI_STATEMENT = "SAFETY_MULTI_STATEMENT"
8
+
9
+ # --- Verifier ---
10
+ PLAN_NO_SUCH_TABLE = "PLAN_NO_SUCH_TABLE"
11
+ PLAN_NO_SUCH_COLUMN = "PLAN_NO_SUCH_COLUMN"
12
+ PLAN_SYNTAX_ERROR = "PLAN_SYNTAX_ERROR"
13
+
14
+ # --- Executor / DB ---
15
+ DB_LOCKED = "DB_LOCKED"
16
+ DB_TIMEOUT = "DB_TIMEOUT"
17
+
18
+ # --- LLM ---
19
+ LLM_TIMEOUT = "LLM_TIMEOUT"
20
+ LLM_BAD_OUTPUT = "LLM_BAD_OUTPUT"
21
+
22
+ # --- Internal ---
23
+ PIPELINE_CRASH = "PIPELINE_CRASH"
nl2sql/errors/mapper.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.errors.codes import ErrorCode
2
+
3
+ ERROR_MAP = {
4
+ ErrorCode.SAFETY_NON_SELECT: (422, False),
5
+ ErrorCode.SAFETY_MULTI_STATEMENT: (422, False),
6
+ ErrorCode.PLAN_NO_SUCH_TABLE: (422, False),
7
+ ErrorCode.PLAN_NO_SUCH_COLUMN: (422, False),
8
+ ErrorCode.PLAN_SYNTAX_ERROR: (422, False),
9
+ ErrorCode.DB_LOCKED: (503, True),
10
+ ErrorCode.DB_TIMEOUT: (503, True),
11
+ ErrorCode.LLM_TIMEOUT: (503, True),
12
+ ErrorCode.PIPELINE_CRASH: (500, False),
13
+ }
14
+
15
+
16
+ def map_error(code: ErrorCode | None) -> tuple[int, bool]:
17
+ if code is None:
18
+ return (500, False)
19
+ return ERROR_MAP.get(code, (500, False))
nl2sql/pipeline.py CHANGED
@@ -14,6 +14,7 @@ from nl2sql.verifier import Verifier
14
  from nl2sql.repair import Repair
15
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
16
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
 
17
 
18
 
19
  @dataclass(frozen=True)
@@ -21,12 +22,16 @@ class FinalResult:
21
  ok: bool
22
  ambiguous: bool
23
  error: bool
 
24
  details: Optional[List[str]]
25
  sql: Optional[str]
26
  rationale: Optional[str]
27
  verified: Optional[bool]
28
  questions: Optional[List[str]]
29
  traces: List[dict]
 
 
 
30
  result: Optional[Dict[str, Any]] = None
31
 
32
 
 
14
  from nl2sql.repair import Repair
15
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
16
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
17
+ from nl2sql.errors.codes import ErrorCode
18
 
19
 
20
  @dataclass(frozen=True)
 
22
  ok: bool
23
  ambiguous: bool
24
  error: bool
25
+
26
  details: Optional[List[str]]
27
  sql: Optional[str]
28
  rationale: Optional[str]
29
  verified: Optional[bool]
30
  questions: Optional[List[str]]
31
  traces: List[dict]
32
+
33
+ error_code: Optional[ErrorCode] = None
34
+
35
  result: Optional[Dict[str, Any]] = None
36
 
37
 
nl2sql/types.py CHANGED
@@ -1,35 +1,62 @@
1
  from dataclasses import dataclass
2
  from typing import Any, Dict, Optional, List
3
 
 
 
 
 
 
 
 
4
 
5
  @dataclass(frozen=True)
6
  class StageTrace:
7
  stage: str
8
- duration_ms: float # keep float internally if you like
9
- summary: str = "" # ← default to keep legacy call-sites working
10
  notes: Optional[Dict[str, Any]] = None
 
 
11
  token_in: Optional[int] = None
12
  token_out: Optional[int] = None
13
  cost_usd: Optional[float] = None
14
 
15
- # Enriched fields
16
  sql_length: Optional[int] = None
17
  row_count: Optional[int] = None
18
  verified: Optional[bool] = None
19
- error_type: Optional[str] = None
20
  repair_attempts: Optional[int] = None
21
  skipped: bool = False
22
 
23
 
 
 
 
 
 
24
  @dataclass(frozen=True)
25
  class StageResult:
26
  ok: bool
 
27
  data: Optional[Any] = None
28
  trace: Optional[StageTrace] = None
 
 
29
  error: Optional[List[str]] = None
 
 
 
 
 
 
30
  notes: Optional[Dict[str, Any]] = None
31
 
32
 
 
 
 
 
 
33
  @dataclass(frozen=True)
34
  class FinalResult:
35
  """
@@ -37,12 +64,21 @@ class FinalResult:
37
  Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
38
  """
39
 
40
- ok: bool # end-to-end success
41
  ambiguous: bool
42
  error: bool
 
 
43
  sql: Optional[str]
44
  rationale: Optional[str]
45
  verified: Optional[bool]
 
 
 
46
  details: Optional[List[str]]
 
 
47
  questions: Optional[List[str]]
 
 
48
  traces: List[Dict[str, Any]]
 
1
  from dataclasses import dataclass
2
  from typing import Any, Dict, Optional, List
3
 
4
+ from nl2sql.errors.codes import ErrorCode
5
+
6
+
7
+ # =====================
8
+ # Tracing / Observability
9
+ # =====================
10
+
11
 
12
  @dataclass(frozen=True)
13
  class StageTrace:
14
  stage: str
15
+ duration_ms: float
16
+ summary: str = ""
17
  notes: Optional[Dict[str, Any]] = None
18
+
19
+ # Optional observability fields
20
  token_in: Optional[int] = None
21
  token_out: Optional[int] = None
22
  cost_usd: Optional[float] = None
23
 
24
+ # Enriched / debug-only fields
25
  sql_length: Optional[int] = None
26
  row_count: Optional[int] = None
27
  verified: Optional[bool] = None
 
28
  repair_attempts: Optional[int] = None
29
  skipped: bool = False
30
 
31
 
32
+ # =====================
33
+ # Stage-level contract
34
+ # =====================
35
+
36
+
37
  @dataclass(frozen=True)
38
  class StageResult:
39
  ok: bool
40
+
41
  data: Optional[Any] = None
42
  trace: Optional[StageTrace] = None
43
+
44
+ # Human-readable error messages (debug / UI only)
45
  error: Optional[List[str]] = None
46
+
47
+ # === Contract-level semantics ===
48
+ error_code: Optional[ErrorCode] = None
49
+ retryable: Optional[bool] = None
50
+
51
+ # Free-form notes (internal use)
52
  notes: Optional[Dict[str, Any]] = None
53
 
54
 
55
+ # =====================
56
+ # Final pipeline result
57
+ # =====================
58
+
59
+
60
  @dataclass(frozen=True)
61
  class FinalResult:
62
  """
 
64
  Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
65
  """
66
 
67
+ ok: bool
68
  ambiguous: bool
69
  error: bool
70
+
71
+ # Output
72
  sql: Optional[str]
73
  rationale: Optional[str]
74
  verified: Optional[bool]
75
+
76
+ # Error surface
77
+ error_code: Optional[ErrorCode]
78
  details: Optional[List[str]]
79
+
80
+ # UX helpers
81
  questions: Optional[List[str]]
82
+
83
+ # Observability
84
  traces: List[Dict[str, Any]]
nl2sql/verifier.py CHANGED
@@ -4,6 +4,7 @@ import re
4
  import time
5
  from typing import Any, Dict
6
 
 
7
  from nl2sql.metrics import verifier_checks_total, verifier_failures_total
8
  from nl2sql.types import StageResult, StageTrace
9
 
@@ -43,6 +44,7 @@ class Verifier:
43
  notes,
44
  error=["parse_error"],
45
  reason=reason,
 
46
  )
47
 
48
  # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
@@ -83,6 +85,7 @@ class Verifier:
83
  notes,
84
  error=["aggregation_without_group_by"],
85
  reason=reason,
 
86
  )
87
 
88
  # --- DB-backed plan validation (read-only), if adapter provided ---
@@ -94,12 +97,16 @@ class Verifier:
94
  except Exception as e:
95
  reason = "plan-error"
96
  notes["plan_check"] = "failed"
 
 
 
97
  return self._fail(
98
  t0,
99
  notes,
100
  error=[str(e)],
101
  reason=reason,
102
  exc_type=type(e).__name__,
 
103
  )
104
 
105
  # --- pass ---
@@ -125,6 +132,7 @@ class Verifier:
125
  error=[str(e)],
126
  reason=reason,
127
  exc_type=type(e).__name__,
 
128
  )
129
 
130
  def run(
@@ -137,6 +145,25 @@ class Verifier:
137
  # exec_result kept for signature compatibility, not used here.
138
  return self.verify(sql, adapter=adapter)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def _fail(
141
  self,
142
  t0: float,
@@ -145,6 +172,7 @@ class Verifier:
145
  error: list[str],
146
  reason: str,
147
  exc_type: str | None = None,
 
148
  ) -> StageResult:
149
  dt = int(round((time.perf_counter() - t0) * 1000.0))
150
 
@@ -166,4 +194,6 @@ class Verifier:
166
  data={"verified": False},
167
  trace=trace,
168
  error=error,
 
 
169
  )
 
4
  import time
5
  from typing import Any, Dict
6
 
7
+ from nl2sql.errors.codes import ErrorCode
8
  from nl2sql.metrics import verifier_checks_total, verifier_failures_total
9
  from nl2sql.types import StageResult, StageTrace
10
 
 
44
  notes,
45
  error=["parse_error"],
46
  reason=reason,
47
+ error_code=ErrorCode.PLAN_SYNTAX_ERROR, # best-fit for malformed SQL
48
  )
49
 
50
  # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
 
85
  notes,
86
  error=["aggregation_without_group_by"],
87
  reason=reason,
88
+ error_code=ErrorCode.PLAN_SYNTAX_ERROR,
89
  )
90
 
91
  # --- DB-backed plan validation (read-only), if adapter provided ---
 
97
  except Exception as e:
98
  reason = "plan-error"
99
  notes["plan_check"] = "failed"
100
+
101
+ code = self._classify_plan_error(e)
102
+
103
  return self._fail(
104
  t0,
105
  notes,
106
  error=[str(e)],
107
  reason=reason,
108
  exc_type=type(e).__name__,
109
+ error_code=code,
110
  )
111
 
112
  # --- pass ---
 
132
  error=[str(e)],
133
  reason=reason,
134
  exc_type=type(e).__name__,
135
+ error_code=ErrorCode.PIPELINE_CRASH,
136
  )
137
 
138
  def run(
 
145
  # exec_result kept for signature compatibility, not used here.
146
  return self.verify(sql, adapter=adapter)
147
 
148
+ def _classify_plan_error(self, e: Exception) -> ErrorCode:
149
+ msg = str(e).lower()
150
+
151
+ # SQLite-style messages
152
+ if "no such table" in msg:
153
+ return ErrorCode.PLAN_NO_SUCH_TABLE
154
+ if "no such column" in msg:
155
+ return ErrorCode.PLAN_NO_SUCH_COLUMN
156
+ if "syntax error" in msg:
157
+ return ErrorCode.PLAN_SYNTAX_ERROR
158
+
159
+ # Postgres-style messages (common cases)
160
+ if "relation" in msg and "does not exist" in msg:
161
+ return ErrorCode.PLAN_NO_SUCH_TABLE
162
+ if "column" in msg and "does not exist" in msg:
163
+ return ErrorCode.PLAN_NO_SUCH_COLUMN
164
+
165
+ return ErrorCode.PLAN_SYNTAX_ERROR
166
+
167
  def _fail(
168
  self,
169
  t0: float,
 
172
  error: list[str],
173
  reason: str,
174
  exc_type: str | None = None,
175
+ error_code: ErrorCode | None = None,
176
  ) -> StageResult:
177
  dt = int(round((time.perf_counter() - t0) * 1000.0))
178
 
 
194
  data={"verified": False},
195
  trace=trace,
196
  error=error,
197
+ error_code=error_code,
198
+ retryable=False,
199
  )