Spaces:
Sleeping
Sleeping
Melika Kheirieh
committed on
Commit
·
773ba18
1
Parent(s):
3954d57
feat(cache): add lightweight in-memory TTL cache for NL→SQL responses
Browse files- app/routers/nl2sql.py +47 -2
app/routers/nl2sql.py
CHANGED
|
@@ -7,7 +7,8 @@ import os
|
|
| 7 |
from pathlib import Path
|
| 8 |
import time
|
| 9 |
import uuid
|
| 10 |
-
from typing import Any, Dict, Optional, TypedDict, Union, cast, Callable
|
|
|
|
| 11 |
|
| 12 |
# --- Third-party ---
|
| 13 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
|
|
@@ -61,6 +62,39 @@ def _build_pipeline(adapter) -> Any:
|
|
| 61 |
return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
router = APIRouter(prefix="/nl2sql")
|
| 65 |
|
| 66 |
# -------------------------------
|
|
@@ -321,6 +355,14 @@ def nl2sql_handler(
|
|
| 321 |
db_id, cast(Optional[str], getattr(request, "schema_preview", None))
|
| 322 |
)
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
# Choose runner: default pipeline from YAML OR per-request override with a specific adapter
|
| 325 |
if db_id:
|
| 326 |
adapter = _select_adapter(db_id)
|
|
@@ -359,12 +401,15 @@ def nl2sql_handler(
|
|
| 359 |
|
| 360 |
# Success path → 200 (coerce/standardize traces for API)
|
| 361 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 362 |
-
|
| 363 |
ambiguous=False,
|
| 364 |
sql=result.sql,
|
| 365 |
rationale=result.rationale,
|
| 366 |
traces=traces,
|
| 367 |
)
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
|
| 370 |
def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
import time
|
| 9 |
import uuid
|
| 10 |
+
from typing import Any, Dict, Optional, TypedDict, Union, cast, Callable, Tuple
|
| 11 |
+
import hashlib
|
| 12 |
|
| 13 |
# --- Third-party ---
|
| 14 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
|
|
|
|
| 62 |
return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
|
| 63 |
|
| 64 |
|
| 65 |
+
####################################
|
| 66 |
+
# ---- Simple in-memory cache for NL→SQL responses ----
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Seconds a cached response stays valid; the NL2SQL_CACHE_TTL_SEC environment
# variable overrides the 300-second (5-minute) default.
_CACHE_TTL = int(os.environ.get("NL2SQL_CACHE_TTL_SEC", "300"))
# Entry-count limit for the cache, read from NL2SQL_CACHE_MAX (default 256).
_CACHE_MAX = int(os.environ.get("NL2SQL_CACHE_MAX", "256"))
# Cache storage: (db id, normalized query, schema-preview digest)
# -> (time the entry was stored, response payload as a plain dict).
_CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _norm_q(s: str) -> str:
    """Return *s* trimmed of surrounding whitespace and lower-cased.

    A falsy value (``None`` or the empty string) yields ``""``.
    """
    if not s:
        return ""
    trimmed = s.strip()
    return trimmed.lower()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _schema_key(preview: str) -> str:
    """Return the hex MD5 digest of *preview* for use as a cache-key part.

    A falsy *preview* (``None`` or ``""``) is hashed as the empty string.
    The digest is only a compact fingerprint of the schema text, not a
    security measure, so MD5 is requested with ``usedforsecurity=False``;
    this keeps the call working on FIPS-restricted builds that otherwise
    reject MD5.
    """
    data = (preview or "").encode()
    try:
        digest = hashlib.md5(data, usedforsecurity=False)
    except TypeError:
        # Python < 3.9 does not accept the ``usedforsecurity`` keyword.
        digest = hashlib.md5(data)
    return digest.hexdigest()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _ck(db_id: Optional[str], query: str, preview: str) -> Tuple[str, str, str]:
    """Build the cache key for one NL→SQL request.

    Returns:
        A 3-tuple of the database id (``"default"`` when *db_id* is falsy),
        the query normalized by ``_norm_q``, and the schema-preview digest
        produced by ``_schema_key``.
    """
    database = db_id if db_id else "default"
    question = _norm_q(query)
    schema_digest = _schema_key(preview)
    return (database, question, schema_digest)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _cache_gc(now: float) -> None:
    """Drop expired cache entries, then trim the cache to ``_CACHE_MAX`` items.

    Args:
        now: Current timestamp, on the same clock as the timestamps stored
            with each cached payload. An entry is expired when
            ``now - stored_at`` exceeds ``_CACHE_TTL``.
    """
    # TTL eviction: iterate over a snapshot so the dict can be mutated safely.
    for key, (stored_at, _payload) in list(_CACHE.items()):
        if now - stored_at > _CACHE_TTL:
            _CACHE.pop(key, None)

    # Size eviction (simple): remove entries from the front of the dict's
    # insertion order until the limit is met. The emptiness guard prevents
    # ``next(iter(_CACHE))`` from raising StopIteration when ``_CACHE_MAX``
    # is configured as a negative number.
    while _CACHE and len(_CACHE) > _CACHE_MAX:
        _CACHE.pop(next(iter(_CACHE)), None)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
####################################
|
| 97 |
+
|
| 98 |
router = APIRouter(prefix="/nl2sql")
|
| 99 |
|
| 100 |
# -------------------------------
|
|
|
|
| 355 |
db_id, cast(Optional[str], getattr(request, "schema_preview", None))
|
| 356 |
)
|
| 357 |
|
| 358 |
+
# ---- cache lookup ----
|
| 359 |
+
now = time.time()
|
| 360 |
+
_cache_gc(now)
|
| 361 |
+
ck = _ck(db_id, request.query, final_preview)
|
| 362 |
+
hit = _CACHE.get(ck)
|
| 363 |
+
if hit and now - hit[0] <= _CACHE_TTL:
|
| 364 |
+
return cast(Dict[str, Any], hit[1]) # early return
|
| 365 |
+
|
| 366 |
# Choose runner: default pipeline from YAML OR per-request override with a specific adapter
|
| 367 |
if db_id:
|
| 368 |
adapter = _select_adapter(db_id)
|
|
|
|
| 401 |
|
| 402 |
# Success path → 200 (coerce/standardize traces for API)
|
| 403 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 404 |
+
payload = NL2SQLResponse(
|
| 405 |
ambiguous=False,
|
| 406 |
sql=result.sql,
|
| 407 |
rationale=result.rationale,
|
| 408 |
traces=traces,
|
| 409 |
)
|
| 410 |
+
# store in cache
|
| 411 |
+
_CACHE[ck] = (time.time(), cast(Dict[str, Any], payload.model_dump()))
|
| 412 |
+
return payload
|
| 413 |
|
| 414 |
|
| 415 |
def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
|