Melika Kheirieh commited on
Commit
773ba18
·
1 Parent(s): 3954d57

feat(cache): add lightweight in-memory TTL cache for NL→SQL responses

Browse files
Files changed (1) hide show
  1. app/routers/nl2sql.py +47 -2
app/routers/nl2sql.py CHANGED
@@ -7,7 +7,8 @@ import os
7
  from pathlib import Path
8
  import time
9
  import uuid
10
- from typing import Any, Dict, Optional, TypedDict, Union, cast, Callable
 
11
 
12
  # --- Third-party ---
13
  from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
@@ -61,6 +62,39 @@ def _build_pipeline(adapter) -> Any:
61
  return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  router = APIRouter(prefix="/nl2sql")
65
 
66
  # -------------------------------
@@ -321,6 +355,14 @@ def nl2sql_handler(
321
  db_id, cast(Optional[str], getattr(request, "schema_preview", None))
322
  )
323
 
 
 
 
 
 
 
 
 
324
  # Choose runner: default pipeline from YAML OR per-request override with a specific adapter
325
  if db_id:
326
  adapter = _select_adapter(db_id)
@@ -359,12 +401,15 @@ def nl2sql_handler(
359
 
360
  # Success path → 200 (coerce/standardize traces for API)
361
  traces = [_round_trace(t) for t in (result.traces or [])]
362
- return NL2SQLResponse(
363
  ambiguous=False,
364
  sql=result.sql,
365
  rationale=result.rationale,
366
  traces=traces,
367
  )
 
 
 
368
 
369
 
370
  def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
 
7
  from pathlib import Path
8
  import time
9
  import uuid
10
+ from typing import Any, Dict, Optional, TypedDict, Union, cast, Callable, Tuple
11
+ import hashlib
12
 
13
  # --- Third-party ---
14
  from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
 
62
  return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
63
 
64
 
65
+ ####################################
66
+ # ---- Simple in-memory cache for NL→SQL responses ----
67
+
68
+
69
+ _CACHE_TTL = int(os.getenv("NL2SQL_CACHE_TTL_SEC", "300")) # 5 minutes
70
+ _CACHE_MAX = int(os.getenv("NL2SQL_CACHE_MAX", "256"))
71
+ _CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}
72
+
73
+
74
+ def _norm_q(s: str) -> str:
75
+ return (s or "").strip().lower()
76
+
77
+
78
+ def _schema_key(preview: str) -> str:
79
+ return hashlib.md5((preview or "").encode()).hexdigest()
80
+
81
+
82
+ def _ck(db_id: Optional[str], query: str, preview: str) -> Tuple[str, str, str]:
83
+ return (db_id or "default", _norm_q(query), _schema_key(preview))
84
+
85
+
86
+ def _cache_gc(now: float) -> None:
87
+ # TTL eviction
88
+ for k, (ts, _) in list(_CACHE.items()):
89
+ if now - ts > _CACHE_TTL:
90
+ _CACHE.pop(k, None)
91
+ # size eviction (ساده)
92
+ while len(_CACHE) > _CACHE_MAX:
93
+ _CACHE.pop(next(iter(_CACHE)), None)
94
+
95
+
96
+ ####################################
97
+
98
  router = APIRouter(prefix="/nl2sql")
99
 
100
  # -------------------------------
 
355
  db_id, cast(Optional[str], getattr(request, "schema_preview", None))
356
  )
357
 
358
+ # ---- cache lookup ----
359
+ now = time.time()
360
+ _cache_gc(now)
361
+ ck = _ck(db_id, request.query, final_preview)
362
+ hit = _CACHE.get(ck)
363
+ if hit and now - hit[0] <= _CACHE_TTL:
364
+ return cast(Dict[str, Any], hit[1]) # early return
365
+
366
  # Choose runner: default pipeline from YAML OR per-request override with a specific adapter
367
  if db_id:
368
  adapter = _select_adapter(db_id)
 
401
 
402
  # Success path → 200 (coerce/standardize traces for API)
403
  traces = [_round_trace(t) for t in (result.traces or [])]
404
+ payload = NL2SQLResponse(
405
  ambiguous=False,
406
  sql=result.sql,
407
  rationale=result.rationale,
408
  traces=traces,
409
  )
410
+ # store in cache
411
+ _CACHE[ck] = (time.time(), cast(Dict[str, Any], payload.model_dump()))
412
+ return payload
413
 
414
 
415
  def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str: