github-actions[bot] committed on
Commit
b1d64ab
·
1 Parent(s): d279df5

Sync from GitHub main

Browse files
adapters/db/base.py CHANGED
@@ -12,3 +12,6 @@ class DBAdapter(Protocol):
12
 
13
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
14
  """Execute a SELECT query and return (rows, columns)."""
 
 
 
 
12
 
13
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
14
  """Execute a SELECT query and return (rows, columns)."""
15
+
16
+ def explain_query_plan(self, sql: str) -> None:
17
+ """Validate SQL by asking the DB to plan it (must be read-only). Raise on failure."""
adapters/db/postgres_adapter.py CHANGED
@@ -68,3 +68,16 @@ class PostgresAdapter(DBAdapter):
68
  desc = cur.description or ()
69
  cols: List[str] = [d[0] for d in desc if d]
70
  return rows, cols
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  desc = cur.description or ()
69
  cols: List[str] = [d[0] for d in desc if d]
70
  return rows, cols
71
+
72
+ def explain_query_plan(self, sql: str) -> None:
73
+ sql_stripped = (sql or "").strip().rstrip(";")
74
+ if not sql_stripped.lower().startswith("select"):
75
+ raise ValueError("Only SELECT statements are allowed.")
76
+
77
+ with psycopg.connect(self.dsn) as conn:
78
+ # Make it explicitly read-only at the session level
79
+ with conn.cursor() as cur:
80
+ cur.execute("SET TRANSACTION READ ONLY;")
81
+ cur.execute(f"EXPLAIN {sql_stripped}")
82
+ # We don't need the output; if planning fails, it raises.
83
+ _ = cur.fetchall()
adapters/db/sqlite_adapter.py CHANGED
@@ -44,3 +44,20 @@ class SQLiteAdapter(DBAdapter):
44
  cols = [desc[0] for desc in cur.description]
45
  log.info("Query executed successfully. Returned %d rows.", len(rows))
46
  return rows, cols
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  cols = [desc[0] for desc in cur.description]
45
  log.info("Query executed successfully. Returned %d rows.", len(rows))
46
  return rows, cols
47
+
48
+ def explain_query_plan(self, sql: str) -> None:
49
+ if not self.path.exists():
50
+ raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
51
+
52
+ sql_stripped = (sql or "").strip().rstrip(";")
53
+ if not sql_stripped.lower().startswith("select"):
54
+ raise ValueError("Only SELECT statements are allowed.")
55
+
56
+ uri = f"file:{self.path}?mode=ro"
57
+ with sqlite3.connect(uri, uri=True, timeout=3) as conn:
58
+ # Extra safety: enforce query-only mode if available
59
+ try:
60
+ conn.execute("PRAGMA query_only = ON;")
61
+ except Exception:
62
+ pass
63
+ conn.execute(f"EXPLAIN QUERY PLAN {sql_stripped}")
app/errors.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
- from dataclasses import dataclass
 
4
 
5
 
6
  @dataclass
@@ -8,38 +9,60 @@ class AppError(Exception):
8
  """Base class for domain-level errors."""
9
 
10
  message: str
 
 
 
 
 
11
 
12
  def __str__(self) -> str:
13
  return self.message
14
 
15
 
16
- # 4xx-ish
17
  @dataclass
18
- class DbNotFound(AppError):
19
- """Requested DB (or db_id) does not exist."""
 
20
 
21
 
22
  @dataclass
23
- class InvalidRequest(AppError):
24
- """User input is invalid or cannot be processed."""
 
25
 
26
 
27
  @dataclass
28
- class SchemaRequired(AppError):
29
- """Caller must provide schema_preview (e.g. postgres mode)."""
 
30
 
31
 
 
32
  @dataclass
33
- class SchemaDeriveError(AppError):
34
- """Failed to derive schema preview from DB."""
 
 
35
 
36
 
37
- # 5xx-ish
38
  @dataclass
39
  class PipelineConfigError(AppError):
40
- """Pipeline/YAML/config is missing or malformed."""
 
41
 
42
 
43
  @dataclass
44
  class PipelineRunError(AppError):
45
- """Unexpected failure while running the pipeline."""
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, Optional, List
5
 
6
 
7
  @dataclass
 
9
  """Base class for domain-level errors."""
10
 
11
  message: str
12
+ http_status: int = 500
13
+ code: str = "internal_error"
14
+ retryable: bool = False
15
+ extra: Dict[str, Any] = field(default_factory=dict)
16
+ details: Optional[List[str]] = None
17
 
18
  def __str__(self) -> str:
19
  return self.message
20
 
21
 
22
+ # 4xx
23
  @dataclass
24
+ class BadRequestError(AppError):
25
+ http_status: int = 400
26
+ code: str = "bad_request"
27
 
28
 
29
  @dataclass
30
+ class SafetyViolationError(AppError):
31
+ http_status: int = 422
32
+ code: str = "safety_violation"
33
 
34
 
35
  @dataclass
36
+ class SchemaDeriveError(AppError):
37
+ http_status: int = 400
38
+ code: str = "schema_derive_error"
39
 
40
 
41
+ # 5xx-ish
42
  @dataclass
43
+ class DependencyError(AppError):
44
+ http_status: int = 503
45
+ code: str = "dependency_error"
46
+ retryable: bool = True
47
 
48
 
 
49
  @dataclass
50
  class PipelineConfigError(AppError):
51
+ http_status: int = 500
52
+ code: str = "pipeline_config_error"
53
 
54
 
55
  @dataclass
56
  class PipelineRunError(AppError):
57
+ http_status: int = 500
58
+ code: str = "pipeline_run_error"
59
+
60
+
61
+ @dataclass
62
+ class DbNotFound(BadRequestError):
63
+ code: str = "db_not_found"
64
+
65
+
66
+ @dataclass
67
+ class SchemaRequired(BadRequestError):
68
+ code: str = "schema_required"
app/exception_handlers.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
- from typing import Any, Dict
 
4
 
5
  from fastapi import FastAPI, Request
6
  from fastapi.responses import JSONResponse
@@ -9,24 +10,32 @@ from app.errors import AppError
9
 
10
 
11
  def register_exception_handlers(app: FastAPI) -> None:
12
- """
13
- Register global exception handlers for the FastAPI application.
14
- """
15
 
16
  @app.exception_handler(AppError)
17
  async def app_error_handler(request: Request, exc: AppError) -> JSONResponse:
18
- """
19
- Map domain-level AppError instances to HTTP responses.
20
- This keeps routers thin and lets the domain raise AppError freely.
21
- """
22
  status = getattr(exc, "http_status", 500)
23
  code = getattr(exc, "code", "app_error")
24
  message = getattr(exc, "message", str(exc))
 
25
  extra: Dict[str, Any] = getattr(exc, "extra", {}) or {}
 
26
 
27
  payload = {
28
- "code": code,
29
- "message": message,
30
- "extra": extra,
 
 
 
 
 
31
  }
32
- return JSONResponse(status_code=status, content=payload)
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import uuid
4
+ from typing import Any, Dict, Optional, List
5
 
6
  from fastapi import FastAPI, Request
7
  from fastapi.responses import JSONResponse
 
10
 
11
 
12
  def register_exception_handlers(app: FastAPI) -> None:
13
+ """Register global exception handlers for the FastAPI application."""
 
 
14
 
15
  @app.exception_handler(AppError)
16
  async def app_error_handler(request: Request, exc: AppError) -> JSONResponse:
17
+ request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
18
+
 
 
19
  status = getattr(exc, "http_status", 500)
20
  code = getattr(exc, "code", "app_error")
21
  message = getattr(exc, "message", str(exc))
22
+ retryable = bool(getattr(exc, "retryable", False))
23
  extra: Dict[str, Any] = getattr(exc, "extra", {}) or {}
24
+ details: Optional[List[str]] = getattr(exc, "details", None)
25
 
26
  payload = {
27
+ "error": {
28
+ "code": code,
29
+ "message": message,
30
+ "details": details,
31
+ "retryable": retryable,
32
+ "request_id": request_id,
33
+ "extra": extra,
34
+ }
35
  }
36
+
37
+ headers = {"X-Request-ID": request_id}
38
+ if retryable:
39
+ headers["Retry-After"] = "2"
40
+
41
+ return JSONResponse(status_code=status, content=payload, headers=headers)
app/routers/nl2sql.py CHANGED
@@ -23,6 +23,10 @@ from app.services.nl2sql_service import NL2SQLService
23
  from app.settings import get_settings
24
  from app.errors import (
25
  AppError,
 
 
 
 
26
  )
27
 
28
  logger = logging.getLogger(__name__)
@@ -330,14 +334,12 @@ def nl2sql_handler(
330
  # Let the global handler convert it to an HTTP response.
331
  raise
332
  except Exception as exc:
333
- logger.exception(
334
- "Unexpected pipeline crash in NL2SQLService.run_query",
335
- exc_info=exc,
 
 
336
  )
337
- raise HTTPException(
338
- status_code=500,
339
- detail="internal pipeline error",
340
- ) from exc
341
 
342
  # ---- type sanity check ----
343
  if not isinstance(result, FinalResult):
@@ -345,9 +347,10 @@ def nl2sql_handler(
345
  "Pipeline returned unexpected type",
346
  extra={"type": type(result).__name__},
347
  )
348
- raise HTTPException(
349
- status_code=500,
350
- detail="pipeline returned unexpected type",
 
351
  )
352
 
353
  # ---- ambiguity path → 200 with clarification questions ----
@@ -355,18 +358,66 @@ def nl2sql_handler(
355
  qs = result.questions or []
356
  return ClarifyResponse(ambiguous=True, questions=qs)
357
 
358
- # ---- error path 400 with joined details ----
359
  if (not result.ok) or result.error:
360
  logger.debug(
361
  "Pipeline reported failure",
362
- extra={
363
- "ok": result.ok,
364
- "error": result.error,
365
- "details": result.details,
366
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  )
368
- message = "; ".join(result.details or []) or "Unknown error"
369
- raise HTTPException(status_code=400, detail=message)
370
 
371
  # ---- success path → 200 (normalize traces and executor result) ----
372
  traces = [_round_trace(t) for t in (result.traces or [])]
 
23
  from app.settings import get_settings
24
  from app.errors import (
25
  AppError,
26
+ BadRequestError,
27
+ SafetyViolationError,
28
+ DependencyError,
29
+ PipelineRunError,
30
  )
31
 
32
  logger = logging.getLogger(__name__)
 
334
  # Let the global handler convert it to an HTTP response.
335
  raise
336
  except Exception as exc:
337
+ logger.exception("Unexpected pipeline crash in NL2SQLService.run_query")
338
+ raise PipelineRunError(
339
+ message="Internal pipeline error.",
340
+ details=[str(exc)],
341
+ extra={"stage": "unknown"},
342
  )
 
 
 
 
343
 
344
  # ---- type sanity check ----
345
  if not isinstance(result, FinalResult):
 
347
  "Pipeline returned unexpected type",
348
  extra={"type": type(result).__name__},
349
  )
350
+ raise PipelineRunError(
351
+ message="Pipeline returned unexpected type.",
352
+ details=[type(result).__name__],
353
+ extra={"stage": "unknown"},
354
  )
355
 
356
  # ---- ambiguity path → 200 with clarification questions ----
 
358
  qs = result.questions or []
359
  return ClarifyResponse(ambiguous=True, questions=qs)
360
 
361
+ # ---- error path: map pipeline failures to stable HTTP+JSON error contract ----
362
  if (not result.ok) or result.error:
363
  logger.debug(
364
  "Pipeline reported failure",
365
+ extra={"ok": result.ok, "error": result.error, "details": result.details},
366
+ )
367
+
368
+ details = list(result.details or [])
369
+ traces = list(result.traces or [])
370
+ last_stage = str(traces[-1].get("stage", "unknown")) if traces else "unknown"
371
+ details_l = " ".join(d.lower() for d in details)
372
+
373
+ # 1) Safety violations → 422
374
+ if last_stage == "safety":
375
+ raise SafetyViolationError(
376
+ message="Rejected by safety checks.",
377
+ details=details or None,
378
+ extra={"stage": last_stage},
379
+ )
380
+
381
+ # 2) Retryable dependency failures → 503
382
+ retry_hints = (
383
+ "timeout",
384
+ "timed out",
385
+ "rate limit",
386
+ "429",
387
+ "too many requests",
388
+ "locked",
389
+ "busy",
390
+ "unavailable",
391
+ "connection",
392
+ )
393
+ if any(h in details_l for h in retry_hints):
394
+ raise DependencyError(
395
+ message="Temporary dependency failure. Please retry.",
396
+ details=details or None,
397
+ extra={"stage": last_stage},
398
+ )
399
+
400
+ # 3) User-fixable parse/syntax-ish errors → 400
401
+ user_hints = (
402
+ "parse_error",
403
+ "non-select",
404
+ "explain not allowed",
405
+ "multiple statements",
406
+ "forbidden",
407
+ )
408
+ if any(h in details_l for h in user_hints):
409
+ raise BadRequestError(
410
+ message="Request could not be processed.",
411
+ details=details or None,
412
+ extra={"stage": last_stage},
413
+ )
414
+
415
+ # 4) Default → 500
416
+ raise PipelineRunError(
417
+ message="Pipeline failed unexpectedly.",
418
+ details=details or None,
419
+ extra={"stage": last_stage},
420
  )
 
 
421
 
422
  # ---- success path → 200 (normalize traces and executor result) ----
423
  traces = [_round_trace(t) for t in (result.traces or [])]
nl2sql/pipeline.py CHANGED
@@ -422,22 +422,41 @@ class Pipeline:
422
  if r_exec.ok and isinstance(r_exec.data, dict):
423
  exec_result = dict(r_exec.data)
424
 
425
- # --- 6) verifier ---
426
  t0 = time.perf_counter()
427
- r_ver = self._safe_stage(
 
428
  self.verifier.run,
 
 
429
  sql=sql,
430
  exec_result=(r_exec.data or {}),
431
- adapter=getattr(
432
- self.executor, "adapter", None
433
- ), # let verifier use adapter
434
  )
435
  dt = (time.perf_counter() - t0) * 1000.0
436
  stage_duration_ms.labels("verifier").observe(dt)
 
 
437
  traces.extend(self._trace_list(r_ver))
438
  if not getattr(r_ver, "trace", None):
439
  _fallback_trace("verifier", dt, r_ver.ok)
440
- verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  # consume repaired SQL from verifier if any
443
  if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
 
422
  if r_exec.ok and isinstance(r_exec.data, dict):
423
  exec_result = dict(r_exec.data)
424
 
425
+ # --- 6) verifier (run with repair for consistency) ---
426
  t0 = time.perf_counter()
427
+ r_ver = self._run_with_repair(
428
+ "verifier",
429
  self.verifier.run,
430
+ repair_input_builder=self._sql_repair_input_builder,
431
+ max_attempts=1,
432
  sql=sql,
433
  exec_result=(r_exec.data or {}),
434
+ adapter=getattr(self.executor, "adapter", None),
435
+ traces=traces,
 
436
  )
437
  dt = (time.perf_counter() - t0) * 1000.0
438
  stage_duration_ms.labels("verifier").observe(dt)
439
+
440
+ # Traces
441
  traces.extend(self._trace_list(r_ver))
442
  if not getattr(r_ver, "trace", None):
443
  _fallback_trace("verifier", dt, r_ver.ok)
444
+
445
+ # If verifier (or its repair) produced a new SQL, consume it
446
+ if r_ver.data and isinstance(r_ver.data, dict):
447
+ repaired_sql = r_ver.data.get("sql")
448
+ if repaired_sql:
449
+ sql = repaired_sql
450
+
451
+ # Verified flag
452
+ verified = (
453
+ bool(
454
+ r_ver.data
455
+ and isinstance(r_ver.data, dict)
456
+ and r_ver.data.get("verified")
457
+ )
458
+ or r_ver.ok
459
+ )
460
 
461
  # consume repaired SQL from verifier if any
462
  if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
nl2sql/verifier.py CHANGED
@@ -4,32 +4,33 @@ import re
4
  import time
5
  from typing import Any, Dict
6
 
 
7
  from nl2sql.types import StageResult, StageTrace
8
- from nl2sql.metrics import (
9
- verifier_checks_total,
10
- verifier_failures_total,
11
- )
12
 
 
13
 
14
- class Verifier:
15
- """Static verifier used by tests.
16
 
17
- Provides verify(...) for tests and run(...) for pipeline.
 
 
 
 
 
18
  """
19
 
20
  required = False
21
 
22
- def verify(self, sql: str, *, adapter: Any | None = None) -> StageResult:
23
  t0 = time.perf_counter()
24
  notes: Dict[str, Any] = {}
25
- reason = "ok" # new field
26
 
27
  s = (sql or "").strip()
28
  sl = s.lower()
29
  notes["sql_length"] = len(s)
30
 
31
  try:
32
- # --- quick parse sanity: require SELECT and FROM ---
33
  has_select = bool(re.search(r"\bselect\b", sl))
34
  has_from = bool(re.search(r"\bfrom\b", sl))
35
  notes["has_select"] = has_select
@@ -45,6 +46,7 @@ class Verifier:
45
  )
46
 
47
  # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
 
48
  has_over = " over (" in sl
49
  has_group_by = " group by " in sl
50
  has_distinct = sl.startswith("select distinct") or (
@@ -83,20 +85,29 @@ class Verifier:
83
  reason=reason,
84
  )
85
 
86
- # --- execution-error sentinel for tests ---
87
- if "imaginary_table" in sl:
88
- reason = "exec-error"
89
- return self._fail(
90
- t0,
91
- notes,
92
- error=["exec_error: no such table: imaginary_table"],
93
- reason=reason,
94
- )
 
 
 
 
 
 
 
95
 
96
  # --- pass ---
97
  dt = int(round((time.perf_counter() - t0) * 1000.0))
98
  notes.update({"verified": True, "reason": reason})
 
99
  verifier_checks_total.labels(ok="true").inc()
 
100
  trace = StageTrace(
101
  stage="verifier",
102
  duration_ms=dt,
@@ -106,6 +117,7 @@ class Verifier:
106
  return StageResult(ok=True, data={"verified": True}, trace=trace)
107
 
108
  except Exception as e:
 
109
  reason = "exception"
110
  return self._fail(
111
  t0,
@@ -115,6 +127,16 @@ class Verifier:
115
  exc_type=type(e).__name__,
116
  )
117
 
 
 
 
 
 
 
 
 
 
 
118
  def _fail(
119
  self,
120
  t0: float,
@@ -125,6 +147,7 @@ class Verifier:
125
  exc_type: str | None = None,
126
  ) -> StageResult:
127
  dt = int(round((time.perf_counter() - t0) * 1000.0))
 
128
  notes.update({"verified": False, "reason": reason})
129
  if exc_type:
130
  notes["exception_type"] = exc_type
@@ -144,8 +167,3 @@ class Verifier:
144
  trace=trace,
145
  error=error,
146
  )
147
-
148
- def run(
149
- self, *, sql: str, exec_result: Dict[str, Any], adapter: Any = None
150
- ) -> StageResult:
151
- return self.verify(sql, adapter=adapter)
 
4
  import time
5
  from typing import Any, Dict
6
 
7
+ from nl2sql.metrics import verifier_checks_total, verifier_failures_total
8
  from nl2sql.types import StageResult, StageTrace
 
 
 
 
9
 
10
+ from adapters.db.base import DBAdapter
11
 
 
 
12
 
13
+ class Verifier:
14
+ """
15
+ Verifier stage:
16
+ - Lightweight sanity checks (lint-like; NOT safety policy)
17
+ - Optional DB-backed plan validation via adapter.explain_query_plan(sql)
18
+ (read-only, no query execution)
19
  """
20
 
21
  required = False
22
 
23
+ def verify(self, sql: str, *, adapter: DBAdapter | None = None) -> StageResult:
24
  t0 = time.perf_counter()
25
  notes: Dict[str, Any] = {}
26
+ reason = "ok"
27
 
28
  s = (sql or "").strip()
29
  sl = s.lower()
30
  notes["sql_length"] = len(s)
31
 
32
  try:
33
+ # --- quick sanity: require SELECT and FROM (lint-like) ---
34
  has_select = bool(re.search(r"\bselect\b", sl))
35
  has_from = bool(re.search(r"\bfrom\b", sl))
36
  notes["has_select"] = has_select
 
46
  )
47
 
48
  # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
49
+ # This is NOT a safety rule; it is a quality check to catch common mistakes.
50
  has_over = " over (" in sl
51
  has_group_by = " group by " in sl
52
  has_distinct = sl.startswith("select distinct") or (
 
85
  reason=reason,
86
  )
87
 
88
+ # --- DB-backed plan validation (read-only), if adapter provided ---
89
+ # Safety policy (SELECT-only, no multi-statement, etc.) must be enforced upstream.
90
+ if adapter is not None:
91
+ try:
92
+ adapter.explain_query_plan(s)
93
+ notes["plan_check"] = "ok"
94
+ except Exception as e:
95
+ reason = "plan-error"
96
+ notes["plan_check"] = "failed"
97
+ return self._fail(
98
+ t0,
99
+ notes,
100
+ error=[str(e)],
101
+ reason=reason,
102
+ exc_type=type(e).__name__,
103
+ )
104
 
105
  # --- pass ---
106
  dt = int(round((time.perf_counter() - t0) * 1000.0))
107
  notes.update({"verified": True, "reason": reason})
108
+
109
  verifier_checks_total.labels(ok="true").inc()
110
+
111
  trace = StageTrace(
112
  stage="verifier",
113
  duration_ms=dt,
 
117
  return StageResult(ok=True, data={"verified": True}, trace=trace)
118
 
119
  except Exception as e:
120
+ # Unexpected verifier crash (bug)
121
  reason = "exception"
122
  return self._fail(
123
  t0,
 
127
  exc_type=type(e).__name__,
128
  )
129
 
130
+ def run(
131
+ self,
132
+ *,
133
+ sql: str,
134
+ exec_result: Dict[str, Any],
135
+ adapter: DBAdapter | None = None,
136
+ ) -> StageResult:
137
+ # exec_result kept for signature compatibility, not used here.
138
+ return self.verify(sql, adapter=adapter)
139
+
140
  def _fail(
141
  self,
142
  t0: float,
 
147
  exc_type: str | None = None,
148
  ) -> StageResult:
149
  dt = int(round((time.perf_counter() - t0) * 1000.0))
150
+
151
  notes.update({"verified": False, "reason": reason})
152
  if exc_type:
153
  notes["exception_type"] = exc_type
 
167
  trace=trace,
168
  error=error,
169
  )