github-actions[bot] committed on
Commit
3a2092d
·
1 Parent(s): d8756b9

Sync from GitHub main @ bb85132be0e6ce9fb3b72683c3321f8a7056eeb6

Browse files
app/dependencies.py CHANGED
@@ -20,6 +20,8 @@ def get_nl2sql_service() -> NL2SQLService:
20
  def get_cache() -> NL2SQLCache:
21
  """
22
  Singleton in-memory cache for NL2SQL responses.
23
- TTL is intentionally short; this is a per-process best-effort cache.
 
24
  """
25
- return NL2SQLCache(ttl=15.0)
 
 
20
def get_cache() -> NL2SQLCache:
    """
    Singleton in-memory cache for NL2SQL responses.

    TTL is loaded from Settings (NL2SQL_CACHE_TTL_SEC).
    """
    # Memoize on the function object: the docstring promises a singleton, but
    # constructing a fresh NL2SQLCache on every dependency resolution would
    # make every handler lookup a miss, defeating the cache entirely.
    # NOTE(review): if NL2SQLCache already shares storage internally, this
    # memoization is redundant but harmless — confirm against its definition.
    cached = getattr(get_cache, "_instance", None)
    if cached is None:
        settings = get_settings()
        cached = NL2SQLCache(ttl=float(settings.cache_ttl_sec))
        get_cache._instance = cached
    return cached
app/routers/nl2sql.py CHANGED
@@ -5,7 +5,7 @@ from dataclasses import asdict, is_dataclass
5
  import os
6
  from pathlib import Path
7
  import uuid
8
- from typing import Any, Dict, Optional, Tuple, cast
9
  import hashlib
10
  import logging
11
 
@@ -50,68 +50,12 @@ def require_api_key(key: Optional[str] = Security(api_key_header)):
50
  raise HTTPException(status_code=401, detail="invalid API key")
51
 
52
 
53
- ####################################
54
- # ---- Simple in-memory cache for NL→SQL responses ----
55
-
56
- # Cache TTL and max size from centralized settings
57
- _CACHE_TTL = settings.cache_ttl_sec
58
- _CACHE_MAX = settings.cache_max_entries
59
- _CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}
60
-
61
-
62
- def _norm_q(s: str) -> str:
63
- """Normalize a user query for cache key purposes."""
64
- return (s or "").strip().lower()
65
-
66
-
67
- def _schema_key(preview: str) -> str:
68
- """Hash the schema preview so we do not store huge strings in the cache key."""
69
- return hashlib.md5((preview or "").encode()).hexdigest()
70
-
71
-
72
- def _ck(
73
- db_id: Optional[str],
74
- query: str,
75
- schema_preview: str,
76
- ) -> str:
77
- """
78
- Build a stable cache key for (db_id, query, schema_preview).
79
-
80
- We keep the external cache API string-based, and hash the
81
- potentially large schema_preview to avoid huge dictionary keys.
82
- """
83
- # Normalize db_id
84
- db_part = db_id or "__default__"
85
-
86
- # Build a single string seed
87
- seed = f"{db_part}\n{query}\n{schema_preview}"
88
-
89
- # Short, deterministic key
90
- return hashlib.sha1(seed.encode("utf-8")).hexdigest()
91
-
92
-
93
- def _cache_gc(now: float) -> None:
94
- """
95
- Garbage-collect cache entries by TTL and max size.
96
- """
97
- # TTL eviction
98
- for k, (ts, _) in list(_CACHE.items()):
99
- if now - ts > _CACHE_TTL:
100
- _CACHE.pop(k, None)
101
-
102
- # Size eviction (naive FIFO-style)
103
- while len(_CACHE) > _CACHE_MAX:
104
- _CACHE.pop(next(iter(_CACHE)), None)
105
-
106
-
107
- ####################################
108
-
109
  router = APIRouter(prefix="/nl2sql")
110
 
111
  # -------------------------------
112
  # Config / Defaults
113
  # -------------------------------
114
- DB_MODE = settings.db_mode.lower() # "sqlite" or "postgres"
115
 
116
  # Runtime upload storage for SQLite DBs
117
  _DB_UPLOAD_DIR = settings.db_upload_dir
@@ -127,29 +71,14 @@ logger.debug(
127
  )
128
 
129
 
130
- # -------------------------------
131
- # Schema preview endpoint
132
- # -------------------------------
133
-
134
-
135
  @router.get("/schema")
136
  def schema_endpoint(
137
  db_id: Optional[str] = None,
138
  svc: NL2SQLService = Depends(get_nl2sql_service),
139
  ):
140
- """
141
- Return a lightweight schema preview string for the given DB.
142
-
143
- - If db_id is provided, service will resolve the uploaded DB.
144
- - If not, service falls back to the default DB.
145
- - In postgres mode, caller must usually provide schema_preview explicitly.
146
- Domain errors (AppError subclasses) are handled by the global exception handler.
147
- This endpoint only wraps truly unexpected errors into a generic HTTP 500
148
- """
149
  try:
150
  preview = svc.get_schema_preview(db_id=db_id, override=None)
151
  except AppError:
152
- # Let the global AppError handler deal with it.
153
  raise
154
  except Exception as exc:
155
  logger.exception("Unexpected error in schema_endpoint", exc_info=exc)
@@ -176,15 +105,6 @@ def _to_dict(obj: Any) -> Any:
176
 
177
 
178
  def _round_trace(t: Any) -> Dict[str, Any]:
179
- """
180
- Normalize a trace entry (dict or StageTrace-like object) for API/UI:
181
-
182
- - stage: str (required)
183
- - duration_ms: int (rounded)
184
- - summary: optional (pass-through if exists)
185
- - notes: optional
186
- - token_in/out, cost_usd: pass-through if present
187
- """
188
  if isinstance(t, dict):
189
  stage = t.get("stage", "?")
190
  ms = t.get("duration_ms", 0)
@@ -275,26 +195,23 @@ def health():
275
  return {"status": "ok", "version": settings.app_version}
276
 
277
 
278
- # -------------------------------
279
- # Main NL2SQL endpoint
280
- # -------------------------------
 
281
 
282
 
283
- @router.post("", name="nl2sql_handler", dependencies=[Depends(require_api_key)])
 
 
 
 
 
284
  def nl2sql_handler(
285
  request: NL2SQLRequest,
286
  svc: NL2SQLService = Depends(get_nl2sql_service),
287
  cache: NL2SQLCache = Depends(get_cache),
288
- ) -> NL2SQLResponse | ClarifyResponse | Dict[str, Any]:
289
- """
290
- Main NL→SQL handler.
291
-
292
- Flow:
293
- - Resolve schema preview (client override or derived from DB).
294
- - Check in-memory cache (db_id + query + schema hash).
295
- - Run the pipeline through NL2SQLService.
296
- - Map FinalResult to API response or HTTP error.
297
- """
298
  db_id = getattr(request, "db_id", None)
299
 
300
  # ---- schema preview ----
@@ -320,7 +237,10 @@ def nl2sql_handler(
320
  cache_key = _ck(db_id, request.query, final_preview)
321
  cached_payload = cache.get(cache_key)
322
  if cached_payload is not None:
323
- return cached_payload
 
 
 
324
 
325
  # ---- pipeline execution via service ----
326
  try:
@@ -354,8 +274,7 @@ def nl2sql_handler(
354
 
355
  # ---- ambiguity path → 200 with clarification questions ----
356
  if result.ambiguous:
357
- qs = result.questions or []
358
- return ClarifyResponse(ambiguous=True, questions=qs)
359
 
360
  # ---- error path: contract-based mapping (Phase 3) ----
361
  if (not result.ok) or result.error:
 
5
  import os
6
  from pathlib import Path
7
  import uuid
8
+ from typing import Any, Dict, Optional, cast
9
  import hashlib
10
  import logging
11
 
 
50
  raise HTTPException(status_code=401, detail="invalid API key")
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  router = APIRouter(prefix="/nl2sql")
54
 
55
  # -------------------------------
56
  # Config / Defaults
57
  # -------------------------------
58
+ DB_MODE = settings.db_mode.lower()
59
 
60
  # Runtime upload storage for SQLite DBs
61
  _DB_UPLOAD_DIR = settings.db_upload_dir
 
71
  )
72
 
73
 
 
 
 
 
 
74
  @router.get("/schema")
75
  def schema_endpoint(
76
  db_id: Optional[str] = None,
77
  svc: NL2SQLService = Depends(get_nl2sql_service),
78
  ):
 
 
 
 
 
 
 
 
 
79
  try:
80
  preview = svc.get_schema_preview(db_id=db_id, override=None)
81
  except AppError:
 
82
  raise
83
  except Exception as exc:
84
  logger.exception("Unexpected error in schema_endpoint", exc_info=exc)
 
105
 
106
 
107
  def _round_trace(t: Any) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
108
  if isinstance(t, dict):
109
  stage = t.get("stage", "?")
110
  ms = t.get("duration_ms", 0)
 
195
  return {"status": "ok", "version": settings.app_version}
196
 
197
 
198
+ def _ck(db_id: Optional[str], query: str, schema_preview: str) -> str:
199
+ db_part = db_id or "__default__"
200
+ seed = f"{db_part}\n{query}\n{schema_preview}"
201
+ return hashlib.sha1(seed.encode("utf-8")).hexdigest()
202
 
203
 
204
+ @router.post(
205
+ "",
206
+ name="nl2sql_handler",
207
+ dependencies=[Depends(require_api_key)],
208
+ response_model=NL2SQLResponse | ClarifyResponse,
209
+ )
210
  def nl2sql_handler(
211
  request: NL2SQLRequest,
212
  svc: NL2SQLService = Depends(get_nl2sql_service),
213
  cache: NL2SQLCache = Depends(get_cache),
214
+ ) -> NL2SQLResponse | ClarifyResponse:
 
 
 
 
 
 
 
 
 
215
  db_id = getattr(request, "db_id", None)
216
 
217
  # ---- schema preview ----
 
237
  cache_key = _ck(db_id, request.query, final_preview)
238
  cached_payload = cache.get(cache_key)
239
  if cached_payload is not None:
240
+ # Cache stores dicts; convert back to response models for type safety.
241
+ if isinstance(cached_payload, dict) and cached_payload.get("ambiguous") is True:
242
+ return ClarifyResponse.model_validate(cached_payload)
243
+ return NL2SQLResponse.model_validate(cached_payload)
244
 
245
  # ---- pipeline execution via service ----
246
  try:
 
274
 
275
  # ---- ambiguity path → 200 with clarification questions ----
276
  if result.ambiguous:
277
+ return ClarifyResponse(questions=(result.questions or []))
 
278
 
279
  # ---- error path: contract-based mapping (Phase 3) ----
280
  if (not result.ok) or result.error:
app/schemas.py CHANGED
@@ -1,6 +1,7 @@
1
- from pydantic import BaseModel
2
  from typing import List, Optional, Any, Dict
3
 
 
 
4
 
5
  class NL2SQLRequest(BaseModel):
6
  query: str
@@ -21,16 +22,16 @@ class TraceModel(BaseModel):
21
 
22
 
23
  class NL2SQLResponse(BaseModel):
24
- ambiguous: bool
25
  sql: Optional[str] = None
26
  rationale: Optional[str] = None
27
- traces: List[Dict[str, Any]] = []
28
- result: Dict[str, Any] = {}
29
 
30
 
31
  class ClarifyResponse(BaseModel):
32
  ambiguous: bool = True
33
- questions: List[str]
34
 
35
 
36
  class ErrorResponse(BaseModel):
 
 
1
  from typing import List, Optional, Any, Dict
2
 
3
+ from pydantic import BaseModel, Field
4
+
5
 
6
  class NL2SQLRequest(BaseModel):
7
  query: str
 
22
 
23
 
24
class NL2SQLResponse(BaseModel):
    """Successful NL->SQL result returned by the /nl2sql endpoint."""

    # False distinguishes this model from ClarifyResponse (ambiguous=True)
    # so clients can discriminate the two 200-status payloads.
    ambiguous: bool = False
    # Generated SQL statement; None when the pipeline produced none.
    sql: Optional[str] = None
    # Model-provided explanation of the generated SQL, if any.
    rationale: Optional[str] = None
    # Per-stage trace entries (e.g. stage, duration_ms) as plain dicts.
    traces: List[Dict[str, Any]] = Field(default_factory=list)
    # Execution result payload — schema defined by the service; not visible here.
    result: Dict[str, Any] = Field(default_factory=dict)
30
 
31
 
32
class ClarifyResponse(BaseModel):
    """Returned (HTTP 200) when the query is ambiguous and needs clarification."""

    # Always True; lets clients discriminate from NL2SQLResponse.
    ambiguous: bool = True
    # Clarifying questions to present to the user; may be empty.
    questions: List[str] = Field(default_factory=list)
35
 
36
 
37
  class ErrorResponse(BaseModel):
scripts/verify_metrics_wiring.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import json
4
+ import re
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Any, Iterable
8
+
9
+ ROOT = Path(__file__).resolve().parents[1]
10
+
11
+ PROMETHEUS_FILE = ROOT / "adapters" / "metrics" / "prometheus.py"
12
+ RULES_FILE = ROOT / "infra" / "prometheus" / "rules.yml"
13
+ DASHBOARD_DIR = ROOT / "infra" / "grafana" / "dashboards"
14
+
15
+ # Extract metric names from prometheus client constructors:
16
+ # Counter("x", ...), Gauge("x", ...), Histogram("x", ...), Summary("x", ...)
17
+ METRIC_CTOR_RE = re.compile(r'\b(?:Counter|Gauge|Histogram|Summary)\(\s*"([^"]+)"')
18
+
19
+ # Fallback in case a metric is defined via keyword arg name="..."
20
+ METRIC_NAME_KW_RE = re.compile(r'\bname\s*=\s*"([^"]+)"')
21
+
22
+ # PromQL token pattern
23
+ PROMQL_TOKEN_RE = re.compile(r"([a-zA-Z_:][a-zA-Z0-9_:]*)")
24
+
25
+ PROMQL_KEYWORDS_AND_FUNCS = {
26
+ # aggregations / funcs
27
+ "sum",
28
+ "rate",
29
+ "increase",
30
+ "irate",
31
+ "avg",
32
+ "min",
33
+ "max",
34
+ "count",
35
+ "count_values",
36
+ "stddev",
37
+ "stdvar",
38
+ "bottomk",
39
+ "topk",
40
+ "quantile",
41
+ "histogram_quantile",
42
+ "clamp_min",
43
+ "clamp_max",
44
+ "abs",
45
+ "round",
46
+ "floor",
47
+ "ceil",
48
+ "scalar",
49
+ "vector",
50
+ "sort",
51
+ "sort_desc",
52
+ "label_replace",
53
+ "label_join",
54
+ "time",
55
+ # modifiers / keywords
56
+ "by",
57
+ "without",
58
+ "offset",
59
+ "bool",
60
+ "on",
61
+ "ignoring",
62
+ "group_left",
63
+ "group_right",
64
+ # literals / common
65
+ "true",
66
+ "false",
67
+ "nan",
68
+ "inf",
69
+ }
70
+
71
+ PROMQL_LABEL_KEYS = {
72
+ "le",
73
+ "job",
74
+ "instance",
75
+ "stage",
76
+ "status",
77
+ "outcome",
78
+ "hit",
79
+ "ok",
80
+ }
81
+
82
+ # label values that appear in your rules/dashboards
83
+ PROMQL_COMMON_LABEL_VALUES = {
84
+ "attempt",
85
+ "success",
86
+ "failed",
87
+ "ok",
88
+ "error",
89
+ "true",
90
+ "false",
91
+ }
92
+
93
+ # time units that can show up e.g. [5m], [10s]
94
+ PROMQL_TIME_UNITS = {"ms", "s", "m", "h", "d", "w", "y"}
95
+
96
+
97
def extract_defined_metrics() -> set[str]:
    """Return every metric name defined in the Prometheus adapter module."""
    source = PROMETHEUS_FILE.read_text(encoding="utf-8")
    # Positional constructor names plus the name="..." keyword fallback.
    names = set(METRIC_CTOR_RE.findall(source))
    names.update(METRIC_NAME_KW_RE.findall(source))
    return names
102
+
103
+
104
+ def _collect_promql_from_rules_yml(text: str) -> list[str]:
105
+ """
106
+ Extract only PromQL expressions from rules.yml:
107
+ - expr: <single line>
108
+ - expr: | (multiline indented block)
109
+ - expr: > (multiline indented block)
110
+ """
111
+ lines = text.splitlines()
112
+ exprs: list[str] = []
113
+
114
+ i = 0
115
+ while i < len(lines):
116
+ line = lines[i]
117
+ stripped = line.lstrip()
118
+ if not stripped.startswith("expr:"):
119
+ i += 1
120
+ continue
121
+
122
+ indent = len(line) - len(stripped)
123
+ rest = stripped[len("expr:") :].strip()
124
+
125
+ # Case 1: expr: <single-line>
126
+ if rest and rest not in {"|", ">"}:
127
+ exprs.append(rest)
128
+ i += 1
129
+ continue
130
+
131
+ # Case 2: expr: | or expr: > or expr: (empty) with following indented block
132
+ i += 1
133
+ block_lines: list[str] = []
134
+ while i < len(lines):
135
+ nxt = lines[i]
136
+ nxt_stripped = nxt.lstrip()
137
+ nxt_indent = len(nxt) - len(nxt_stripped)
138
+
139
+ # Stop when indentation returns to expr level (or less)
140
+ if nxt_stripped and nxt_indent <= indent:
141
+ break
142
+
143
+ # Keep blank lines inside block as separators
144
+ block_lines.append(nxt_stripped)
145
+ i += 1
146
+
147
+ expr = "\n".join(block_lines).strip()
148
+ if expr:
149
+ exprs.append(expr)
150
+
151
+ return exprs
152
+
153
+
154
+ def _collect_promql_from_dashboard_json(obj: Any) -> Iterable[str]:
155
+ """
156
+ Recursively collect PromQL strings from Grafana dashboard JSON.
157
+ Common keys are: "expr" (Prometheus target), sometimes "query".
158
+ """
159
+ if isinstance(obj, dict):
160
+ for k, v in obj.items():
161
+ if k in {"expr", "query"} and isinstance(v, str):
162
+ yield v
163
+ else:
164
+ yield from _collect_promql_from_dashboard_json(v)
165
+ elif isinstance(obj, list):
166
+ for item in obj:
167
+ yield from _collect_promql_from_dashboard_json(item)
168
+
169
+
170
def extract_promql_sources() -> list[str]:
    """Collect every PromQL expression used by alert rules and dashboards."""
    # Alerting/recording rules first.
    rules_text = RULES_FILE.read_text(encoding="utf-8")
    sources = list(_collect_promql_from_rules_yml(rules_text))

    # Then every Grafana dashboard JSON under the dashboards directory.
    for dashboard_path in DASHBOARD_DIR.glob("**/*.json"):
        payload = json.loads(dashboard_path.read_text(encoding="utf-8"))
        sources.extend(_collect_promql_from_dashboard_json(payload))

    return sources
183
+
184
+
185
def extract_metrics_from_promql(promql: str) -> set[str]:
    """
    Return candidate metric names referenced in a PromQL expression.

    Tokens matching PromQL keywords/functions, known label keys/values, time
    units, or all-uppercase identifiers are filtered out.
    """
    non_metric_tokens = (
        PROMQL_KEYWORDS_AND_FUNCS
        | PROMQL_LABEL_KEYS
        | PROMQL_COMMON_LABEL_VALUES
        | PROMQL_TIME_UNITS
    )
    return {
        token
        for token in set(PROMQL_TOKEN_RE.findall(promql))
        if token not in non_metric_tokens and not token.isupper()
    }
201
+
202
+
203
def is_generated_from_defined(metric: str, defined: set[str]) -> bool:
    """
    Accept generated series from client libraries:
    - Histogram: <base>_bucket, <base>_sum, <base>_count, <base>_created
    - Summary: <base>_sum, <base>_count, <base>_created
    """
    return any(
        metric == f"{base}{suffix}"
        for base in defined
        for suffix in ("_bucket", "_sum", "_count", "_created")
    )
215
+
216
+
217
def main() -> None:
    """Exit non-zero if any raw metric used in rules/dashboards is undefined."""
    defined = extract_defined_metrics()

    used: set[str] = set()
    for promql in extract_promql_sources():
        used |= extract_metrics_from_promql(promql)

    # Recorded series (names containing ':') are derived and thus allowed;
    # generated _bucket/_sum/_count/_created series map back to their base.
    missing = sorted(
        metric
        for metric in used
        if ":" not in metric
        and metric not in defined
        and not is_generated_from_defined(metric, defined)
    )

    if missing:
        print("❌ Metrics used but not defined (raw):")
        for m in missing:
            print(f" - {m}")
        sys.exit(1)

    print("✅ Metrics wiring OK — no drift detected.")
241
+
242
+
243
+ if __name__ == "__main__":
244
+ main()