github-actions[bot] committed on
Commit
82e122c
·
1 Parent(s): e7970d0

Sync from GitHub main

Browse files
app/cache.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import Any, Dict, Optional, Tuple
5
+
6
+ from prometheus_client import Counter
7
+
8
cache_hits_total = Counter("nl2sql_cache_hits", "NL2SQL cache hits")
cache_misses_total = Counter("nl2sql_cache_misses", "NL2SQL cache misses")
# NOTE(review): these counters land in prometheus_client's default registry,
# while other modules register their metrics on a custom REGISTRY — confirm
# both registries are actually scraped/exposed.


class NL2SQLCache:
    """
    Tiny in-memory TTL cache for NL2SQL responses.
    Stores serialized response payloads (dicts) keyed by a hash.

    Timestamps use time.monotonic() rather than time.time(): wall-clock time
    can jump backwards/forwards on NTP adjustments, which would break TTL
    accounting. The timestamps are internal only, so this is safe.
    """

    def __init__(self, ttl: float = 15.0) -> None:
        """
        :param ttl: entry lifetime in seconds (default 15s).
        """
        self.ttl = ttl
        # key -> (insertion timestamp from time.monotonic(), cached payload)
        self._store: Dict[str, Tuple[float, Dict[str, Any]]] = {}

    def _gc(self, now: float) -> None:
        """Remove expired entries based on the configured TTL."""
        expired_keys = [
            key for key, (ts, _) in self._store.items() if now - ts > self.ttl
        ]
        for key in expired_keys:
            del self._store[key]

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """
        Return cached payload if present and not expired, otherwise None.
        Also updates Prometheus counters for hits/misses.
        """
        now = time.monotonic()
        self._gc(now)

        entry = self._store.get(key)
        if entry is None:
            cache_misses_total.inc()
            return None

        # _gc() above already evicted everything with now - ts > ttl, so any
        # surviving entry is guaranteed fresh — no second expiry check needed.
        cache_hits_total.inc()
        return entry[1]

    def set(self, key: str, payload: Dict[str, Any]) -> None:
        """Store payload under the given key with the current timestamp."""
        self._store[key] = (time.monotonic(), payload)
app/dependencies.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ from app.services.nl2sql_service import NL2SQLService
4
+ from app.cache import NL2SQLCache
5
+ from app.settings import get_settings
6
+
7
+
8
@lru_cache()
def get_nl2sql_service() -> NL2SQLService:
    """
    Lazily-built, process-wide NL2SQLService for the FastAPI app.

    ``lru_cache`` on a zero-argument factory gives singleton-like behavior:
    centralized Settings are loaded once via get_settings() and injected.
    """
    return NL2SQLService(settings=get_settings())
17
+
18
+
19
@lru_cache()
def get_cache() -> NL2SQLCache:
    """
    Process-wide, in-memory cache for NL2SQL responses.

    The 15-second TTL is deliberately short: this is a best-effort,
    per-process cache rather than a shared/distributed one.
    """
    short_ttl_seconds = 15.0
    return NL2SQLCache(ttl=short_ttl_seconds)
app/errors.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class AppError(Exception):
    """Base class for domain-level errors.

    Subclasses are plain dataclasses carrying a human-readable ``message``.
    ``__post_init__`` forwards the message to ``Exception.__init__`` so that
    ``exc.args`` is populated — a bare ``@dataclass`` exception never calls
    ``Exception.__init__``, leaving ``args`` empty, which breaks pickling and
    generic exception logging.

    NOTE(review): ``@dataclass`` defaults to ``eq=True``, which sets
    ``__hash__ = None`` and makes these exceptions unhashable — confirm no
    caller stores them in sets/dict keys before changing.
    """

    # Human-readable description of the failure.
    message: str

    def __post_init__(self) -> None:
        # Populate Exception.args for pickling / generic handlers.
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message


# 4xx-ish
@dataclass
class DbNotFound(AppError):
    """Requested DB (or db_id) does not exist."""


@dataclass
class InvalidRequest(AppError):
    """User input is invalid or cannot be processed."""


@dataclass
class SchemaRequired(AppError):
    """Caller must provide schema_preview (e.g. postgres mode)."""


@dataclass
class SchemaDeriveError(AppError):
    """Failed to derive schema preview from DB."""


# 5xx-ish
@dataclass
class PipelineConfigError(AppError):
    """Pipeline/YAML/config is missing or malformed."""


@dataclass
class PipelineRunError(AppError):
    """Unexpected failure while running the pipeline."""
app/exception_handlers.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict
4
+
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.responses import JSONResponse
7
+
8
+ from app.errors import AppError
9
+
10
+
11
def register_exception_handlers(app: FastAPI) -> None:
    """
    Register global exception handlers for the FastAPI application.
    """

    @app.exception_handler(AppError)
    async def app_error_handler(request: Request, exc: AppError) -> JSONResponse:
        """
        Map domain-level AppError instances to HTTP responses.
        This keeps routers thin and lets the domain raise AppError freely.
        """
        # Optional attributes on the error drive the HTTP mapping; fall back
        # to generic values when a subclass does not define them.
        status = getattr(exc, "http_status", 500)
        body: Dict[str, Any] = {
            "code": getattr(exc, "code", "app_error"),
            "message": getattr(exc, "message", str(exc)),
            "extra": getattr(exc, "extra", {}) or {},
        }
        return JSONResponse(status_code=status, content=body)
app/main.py CHANGED
@@ -1,28 +1,35 @@
1
  import os
2
  import time
 
3
  from fastapi import FastAPI, Request, Response, HTTPException
4
  from fastapi.responses import PlainTextResponse, RedirectResponse
5
  from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
 
6
  from nl2sql.prom import REGISTRY
7
- from app.routers import dev
 
 
8
 
9
  try:
10
  from dotenv import load_dotenv
11
 
12
  load_dotenv()
13
  except Exception:
 
14
  pass
15
 
16
- from app.routers import nl2sql
 
17
 
18
  # ----------------------------------------------------------------------------
19
  # App definition
20
  # ----------------------------------------------------------------------------
21
  application = FastAPI(
22
  title="NL2SQL Copilot Prototype",
23
- version=os.getenv("APP_VERSION", "0.1.0"),
24
  description="Convert natural language to safe & verified SQL",
25
  )
 
26
 
27
  # Register only versioned API
28
  application.include_router(nl2sql.router, prefix="/api/v1")
@@ -30,6 +37,7 @@ application.include_router(nl2sql.router, prefix="/api/v1")
30
  # Register Dev-only routes (only when APP_ENV=dev)
31
  if os.getenv("APP_ENV", "dev").lower() == "dev":
32
  application.include_router(dev.router, prefix="/api/v1")
 
33
  # ----------------------------------------------------------------------------
34
  # Prometheus Metrics Middleware
35
  # ----------------------------------------------------------------------------
@@ -75,21 +83,30 @@ def healthz() -> str:
75
 
76
  @application.get("/readyz", response_class=PlainTextResponse, tags=["system"])
77
  def readyz() -> str:
78
- mode = os.getenv("DB_MODE", "sqlite").lower()
 
 
 
 
 
 
79
  try:
80
  if mode == "postgres":
81
  from adapters.db.postgres_adapter import PostgresAdapter
82
 
83
- pg = PostgresAdapter(os.environ["POSTGRES_DSN"])
 
 
 
 
84
  ping_fn = getattr(pg, "ping", None)
85
  if callable(ping_fn):
86
  ping_fn()
87
  else:
88
  from adapters.db.sqlite_adapter import SQLiteAdapter
89
 
90
- sq = SQLiteAdapter(
91
- os.getenv("DEFAULT_SQLITE_PATH", "data/Chinook_Sqlite.sqlite")
92
- )
93
  ping_fn = getattr(sq, "ping", None)
94
  if callable(ping_fn):
95
  ping_fn()
@@ -105,6 +122,7 @@ def root():
105
 
106
  @application.get("/health")
107
  def health():
 
108
  return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
109
 
110
 
 
1
  import os
2
  import time
3
+
4
  from fastapi import FastAPI, Request, Response, HTTPException
5
  from fastapi.responses import PlainTextResponse, RedirectResponse
6
  from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
7
+
8
  from nl2sql.prom import REGISTRY
9
+ from app.routers import dev, nl2sql
10
+ from app.settings import get_settings
11
+ from app.exception_handlers import register_exception_handlers
12
 
13
  try:
14
  from dotenv import load_dotenv
15
 
16
  load_dotenv()
17
  except Exception:
18
+ # Best-effort .env loading; app must not crash if dotenv is missing.
19
  pass
20
 
21
+
22
+ settings = get_settings()
23
 
24
  # ----------------------------------------------------------------------------
25
  # App definition
26
  # ----------------------------------------------------------------------------
27
  application = FastAPI(
28
  title="NL2SQL Copilot Prototype",
29
+ version=settings.app_version,
30
  description="Convert natural language to safe & verified SQL",
31
  )
32
+ register_exception_handlers(application)
33
 
34
  # Register only versioned API
35
  application.include_router(nl2sql.router, prefix="/api/v1")
 
37
  # Register Dev-only routes (only when APP_ENV=dev)
38
  if os.getenv("APP_ENV", "dev").lower() == "dev":
39
  application.include_router(dev.router, prefix="/api/v1")
40
+
41
  # ----------------------------------------------------------------------------
42
  # Prometheus Metrics Middleware
43
  # ----------------------------------------------------------------------------
 
83
 
84
  @application.get("/readyz", response_class=PlainTextResponse, tags=["system"])
85
  def readyz() -> str:
86
+ """
87
+ Lightweight readiness probe:
88
+
89
+ - For postgres mode → ping PostgresAdapter using configured DSN.
90
+ - For sqlite mode → ping SQLiteAdapter using configured default path.
91
+ """
92
+ mode = settings.db_mode.lower()
93
  try:
94
  if mode == "postgres":
95
  from adapters.db.postgres_adapter import PostgresAdapter
96
 
97
+ dsn = (settings.postgres_dsn or "").strip()
98
+ if not dsn:
99
+ raise RuntimeError("POSTGRES_DSN is not configured for readiness check")
100
+
101
+ pg = PostgresAdapter(dsn)
102
  ping_fn = getattr(pg, "ping", None)
103
  if callable(ping_fn):
104
  ping_fn()
105
  else:
106
  from adapters.db.sqlite_adapter import SQLiteAdapter
107
 
108
+ db_path = settings.default_sqlite_path or "data/Chinook_Sqlite.sqlite"
109
+ sq = SQLiteAdapter(db_path)
 
110
  ping_fn = getattr(sq, "ping", None)
111
  if callable(ping_fn):
112
  ping_fn()
 
122
 
123
@application.get("/health")
def health():
    """High-level health stub; real DB/LLM/uptime probes can be wired in later."""
    stub_payload = {
        "status": "ok",
        "db": "connected",
        "llm": "reachable",
        "uptime_sec": 123.4,
    }
    return stub_payload
127
 
128
 
app/routers/nl2sql.py CHANGED
@@ -4,82 +4,51 @@ from __future__ import annotations
4
  from dataclasses import asdict, is_dataclass
5
  import os
6
  from pathlib import Path
7
- import time
8
  import uuid
9
- from typing import Any, Dict, Optional, Union, cast, Callable, Tuple
10
  import hashlib
11
  import logging
12
 
13
  # --- Third-party ---
14
- from fastapi import APIRouter, HTTPException, UploadFile, File, Depends, Query
15
- from fastapi import Security
16
  from fastapi.security import APIKeyHeader
17
  from prometheus_client import Counter
18
 
19
  # --- Local ---
20
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
21
- from app.state import cleanup_stale_dbs, register_db
22
- from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
23
- from adapters.llm.openai_provider import OpenAIProvider
24
- from adapters.db.sqlite_adapter import SQLiteAdapter
25
- from adapters.db.postgres_adapter import PostgresAdapter
26
- from nl2sql.pipeline_factory import (
27
- pipeline_from_config_with_adapter,
28
- )
29
  from nl2sql.prom import REGISTRY
 
 
 
 
 
 
 
 
 
 
30
 
31
- log = logging.getLogger(__name__)
32
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
33
 
34
 
35
  def require_api_key(key: Optional[str] = Security(api_key_header)):
36
- raw = os.getenv("API_KEYS", "")
 
 
 
 
 
 
37
  allowed = {k.strip() for k in raw.split(",") if k.strip()}
38
- if not allowed: # no keys set → auth disabled (dev mode)
 
39
  return
40
  if not key or key not in allowed:
41
  raise HTTPException(status_code=401, detail="invalid API key")
42
 
43
 
44
- _PIPELINE: Optional[Any] = None # lazy cache
45
-
46
-
47
- Runner = Callable[..., _FinalResult]
48
-
49
-
50
- def get_runner() -> Runner:
51
- """Build pipeline lazily; under pytest return a stub runner."""
52
- if os.getenv("PYTEST_CURRENT_TEST"):
53
- # Minimal OK runner for route tests (no ambiguity)
54
- def _fake_runner(
55
- *, user_query: str, schema_preview: str | None = None
56
- ) -> _FinalResult:
57
- return _FinalResult(
58
- ok=True,
59
- ambiguous=False,
60
- error=False,
61
- details=None,
62
- questions=None,
63
- sql="SELECT 1;",
64
- rationale=None,
65
- verified=True,
66
- traces=[],
67
- )
68
-
69
- return _fake_runner
70
-
71
- global _PIPELINE
72
- if _PIPELINE is None:
73
- adapter = _select_adapter(None) # fallback demo.db
74
- _PIPELINE = pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
75
- return _PIPELINE.run
76
-
77
-
78
- def _build_pipeline(adapter) -> Any:
79
- """Thin wrapper for tests to monkeypatch; builds a pipeline bound to adapter."""
80
- return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
81
-
82
-
83
  ####################################
84
  # ---- Simple in-memory cache for NL→SQL responses ----
85
 
@@ -87,29 +56,54 @@ cache_hits_total = Counter("cache_hits_total", "NL2SQL cache hits", registry=REG
87
  cache_misses_total = Counter(
88
  "cache_misses_total", "NL2SQL cache misses", registry=REGISTRY
89
  )
90
- _CACHE_TTL = int(os.getenv("NL2SQL_CACHE_TTL_SEC", "300")) # 5 minutes
91
- _CACHE_MAX = int(os.getenv("NL2SQL_CACHE_MAX", "256"))
 
 
92
  _CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}
93
 
94
 
95
  def _norm_q(s: str) -> str:
 
96
  return (s or "").strip().lower()
97
 
98
 
99
  def _schema_key(preview: str) -> str:
 
100
  return hashlib.md5((preview or "").encode()).hexdigest()
101
 
102
 
103
- def _ck(db_id: Optional[str], query: str, preview: str) -> Tuple[str, str, str]:
104
- return (db_id or "default", _norm_q(query), _schema_key(preview))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
  def _cache_gc(now: float) -> None:
 
 
 
108
  # TTL eviction
109
  for k, (ts, _) in list(_CACHE.items()):
110
  if now - ts > _CACHE_TTL:
111
  _CACHE.pop(k, None)
112
- # size eviction
 
113
  while len(_CACHE) > _CACHE_MAX:
114
  _CACHE.pop(next(iter(_CACHE)), None)
115
 
@@ -121,92 +115,20 @@ router = APIRouter(prefix="/nl2sql")
121
  # -------------------------------
122
  # Config / Defaults
123
  # -------------------------------
124
- DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
125
- POSTGRES_DSN = os.getenv("POSTGRES_DSN")
126
- # Default demo DB used when no db_id is provided (can be full Chinook or a tiny demo DB)
127
- DEFAULT_SQLITE_PATH: str = os.getenv(
128
- "DEFAULT_SQLITE_PATH", "data/Chinook_Sqlite.sqlite"
129
- )
130
- print("=== STARTUP DEBUG ===")
131
- print("DEFAULT_SQLITE_PATH:", DEFAULT_SQLITE_PATH)
132
- print("CWD:", os.getcwd())
133
- print("FILES in ./:", os.listdir("."))
134
- print(
135
- "FILES in ./data:",
136
- os.listdir("data") if os.path.exists("data") else "NO DATA FOLDER",
137
- )
138
-
139
 
140
- # Runtime upload storage
141
- _DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
142
- _DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
143
  os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
144
 
145
- # Persisted map
146
- _DB_MAP_PATH = Path("data/uploads/db_map.json")
147
- _DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
148
-
149
  UPLOAD_DIR = Path("data/uploads")
150
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
151
 
152
- CONFIG_PATH = os.getenv("PIPELINE_CONFIG", "configs/sqlite_pipeline.yaml")
153
-
154
-
155
- # -------------------------------
156
- # Adapter selection (lazy)
157
- # -------------------------------
158
- def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
159
- """
160
- Resolve DB adapter path for SQLite or Postgres.
161
- """
162
- if DB_MODE == "postgres":
163
- dsn = os.environ.get("POSTGRES_DSN")
164
- if not dsn:
165
- raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
166
- return PostgresAdapter(dsn)
167
-
168
- if db_id:
169
- cleanup_stale_dbs()
170
-
171
- candidates = [
172
- Path("/tmp/nl2sql_dbs") / f"{db_id}.sqlite",
173
- Path("/tmp/nl2sql_dbs") / f"{db_id}.db",
174
- Path("data/uploads") / f"{db_id}.sqlite",
175
- Path("data/uploads") / f"{db_id}.db",
176
- Path("data") / f"{db_id}.sqlite",
177
- Path("data") / f"{db_id}.db",
178
- ]
179
-
180
- for candidate in candidates:
181
- if candidate.exists():
182
- log.info(f"Using DB file: {candidate}")
183
- return SQLiteAdapter(str(candidate))
184
-
185
- raise HTTPException(status_code=404, detail=f"db_id not found: {db_id}")
186
-
187
- # -------- Default SQLite Logic --------
188
- default_path = Path(DEFAULT_SQLITE_PATH)
189
- db_path = str(default_path)
190
-
191
- log.debug("DEFAULT SQLITE DEBUG INFO:")
192
- log.debug(f"DEFAULT_SQLITE_PATH env: {DEFAULT_SQLITE_PATH}")
193
- log.debug(f"CWD: {os.getcwd()}")
194
- log.debug(f"ABS PATH: {default_path.resolve()}")
195
- log.debug(f"EXISTS?: {default_path.exists()}")
196
- if os.path.exists("data"):
197
- log.debug(f"LIST DATA: {os.listdir('data')}")
198
- else:
199
- log.debug("LIST DATA: NO DATA DIRECTORY")
200
-
201
- if not default_path.exists():
202
- fallback = Path("data/demo.db")
203
- if fallback.exists():
204
- log.warning("Default sqlite missing; using fallback demo.db")
205
- db_path = str(fallback)
206
- else:
207
- raise HTTPException(status_code=500, detail="no sqlite database found")
208
-
209
- return SQLiteAdapter(db_path)
210
 
211
 
212
  # -------------------------------
@@ -215,37 +137,43 @@ def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapte
215
 
216
 
217
  @router.get("/schema")
218
- def get_schema(db_id: Optional[str] = Query(default=None)):
 
 
 
219
  """
220
- Return a schema preview for a given db_id (SQLite only).
221
- If db_id is omitted, returns the default database schema.
 
 
 
 
 
222
  """
223
  try:
224
- adapter = _select_adapter(db_id)
225
- preview = _derive_schema_preview(adapter)
226
- if not preview.strip():
227
- raise HTTPException(
228
- status_code=404, detail="Schema preview not available or empty"
229
- )
230
- return {"db_id": db_id or "default", "schema_preview": preview}
231
- except HTTPException:
232
  raise
233
- except Exception as e:
234
- raise HTTPException(status_code=500, detail=f"Schema introspection failed: {e}")
 
 
 
 
235
 
236
-
237
- # -------------------------------
238
- # LLM & Pipeline builders (lazy)
239
- # -------------------------------
240
- def _get_llm() -> OpenAIProvider:
241
- # Create provider on demand, after .env has been loaded in app.main
242
- return OpenAIProvider()
243
 
244
 
245
  # -------------------------------
246
  # Helpers
247
  # -------------------------------
 
 
248
  def _to_dict(obj: Any) -> Any:
 
 
 
249
  if is_dataclass(obj) and not isinstance(obj, type):
250
  return asdict(obj) # type: ignore[arg-type]
251
  return obj
@@ -254,6 +182,7 @@ def _to_dict(obj: Any) -> Any:
254
  def _round_trace(t: Any) -> Dict[str, Any]:
255
  """
256
  Normalize a trace entry (dict or StageTrace-like object) for API/UI:
 
257
  - stage: str (required)
258
  - duration_ms: int (rounded)
259
  - summary: optional (pass-through if exists)
@@ -277,7 +206,7 @@ def _round_trace(t: Any) -> Dict[str, Any]:
277
  token_in = getattr(t, "token_in", None)
278
  token_out = getattr(t, "token_out", None)
279
 
280
- # coerce duration to int with rounding
281
  try:
282
  ms_int = int(round(float(ms))) if ms is not None else 0
283
  except Exception:
@@ -301,8 +230,17 @@ def _round_trace(t: Any) -> Dict[str, Any]:
301
  # -------------------------------
302
  # Upload endpoint (SQLite only)
303
  # -------------------------------
 
 
304
  @router.post("/upload_db", dependencies=[Depends(require_api_key)])
305
  async def upload_db(file: UploadFile = File(...)):
 
 
 
 
 
 
 
306
  if DB_MODE != "sqlite":
307
  raise HTTPException(
308
  status_code=400, detail="DB upload is only supported in sqlite mode"
@@ -315,7 +253,7 @@ async def upload_db(file: UploadFile = File(...)):
315
  )
316
 
317
  data = await file.read()
318
- max_bytes = int(os.getenv("UPLOAD_MAX_BYTES", str(20 * 1024 * 1024))) # 20 MB
319
  if len(data) > max_bytes:
320
  raise HTTPException(
321
  status_code=400, detail=f"File too large (> {max_bytes} bytes)"
@@ -327,92 +265,119 @@ async def upload_db(file: UploadFile = File(...)):
327
  with open(out_path, "wb") as f:
328
  f.write(data)
329
  except Exception as e:
 
330
  raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")
331
 
332
  register_db(db_id, out_path)
 
333
  return {"db_id": db_id}
334
 
335
 
336
- def _final_schema_preview(db_id: Optional[str], provided_preview: Optional[str]) -> str:
337
- if provided_preview and provided_preview.strip():
338
- return provided_preview
339
-
340
- adapter = _select_adapter(db_id) # works for both None and explicit db_id
341
- return _derive_schema_preview(adapter) or ""
342
-
343
-
344
  @router.get("/health")
345
  def health():
346
- return {"status": "ok", "version": os.getenv("APP_VERSION", "dev")}
 
347
 
348
 
349
  # -------------------------------
350
  # Main NL2SQL endpoint
351
  # -------------------------------
 
 
352
  @router.post("", name="nl2sql_handler", dependencies=[Depends(require_api_key)])
353
  def nl2sql_handler(
354
  request: NL2SQLRequest,
355
- run: Runner = Depends(get_runner),
356
- ):
 
357
  """
358
- NL→SQL handler using YAML-driven DI. If 'db_id' is provided, we override only the adapter
359
- while keeping all other stages from the YAML configs intact.
 
 
 
 
 
360
  """
361
  db_id = getattr(request, "db_id", None)
362
- final_preview = _final_schema_preview(
363
- db_id, cast(Optional[str], getattr(request, "schema_preview", None))
364
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  # ---- cache lookup ----
367
- now = time.time()
368
- _cache_gc(now)
369
- ck = _ck(db_id, request.query, final_preview)
370
- hit = _CACHE.get(ck)
371
- if hit and now - hit[0] <= _CACHE_TTL:
372
- cache_hits_total.inc()
373
- return hit[1] # early return
374
- cache_misses_total.inc()
375
-
376
- # Choose runner: default pipeline from YAML OR per-request override with a specific adapter
377
- if db_id:
378
- adapter = _select_adapter(db_id)
379
- pipeline = _build_pipeline(adapter)
380
- runner = pipeline.run
381
- else:
382
- runner = run
383
 
384
- # Execute pipeline
385
  try:
386
- result = runner(user_query=request.query, schema_preview=final_preview)
 
 
 
 
 
 
 
387
  except Exception as exc:
388
- raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
 
 
 
 
 
 
 
389
 
390
- # Type sanity
391
  if not isinstance(result, FinalResult):
392
- raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
 
 
 
 
 
 
 
393
 
394
- # Ambiguity path → 200 with questions
395
  if result.ambiguous:
396
  qs = result.questions or []
397
  return ClarifyResponse(ambiguous=True, questions=qs)
398
 
399
- if not isinstance(result, _FinalResult):
400
- raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
401
-
402
- # Error path → 400 with joined details
403
  if (not result.ok) or result.error:
404
- print("❌ Pipeline failure dump:")
405
- print(" ok:", result.ok)
406
- print(" error:", result.error)
407
- print(" details:", result.details)
408
- print(" traces:", result.traces)
 
 
 
409
  message = "; ".join(result.details or []) or "Unknown error"
410
  raise HTTPException(status_code=400, detail=message)
411
 
412
- # Success path → 200 (coerce/standardize traces for API)
413
  traces = [_round_trace(t) for t in (result.traces or [])]
414
 
415
- # Normalize execution result (if executor attached one)
416
  response_result: Dict[str, Any] = {}
417
  raw_result = getattr(result, "result", None)
418
  if raw_result is not None:
@@ -429,35 +394,6 @@ def nl2sql_handler(
429
  result=response_result,
430
  )
431
 
432
- # store in cache
433
- _CACHE[ck] = (time.time(), cast(Dict[str, Any], payload.model_dump()))
434
  return payload
435
-
436
-
437
- def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
438
- """
439
- Build a strict, exact-cased schema preview for the LLM (SQLite only).
440
- """
441
- import sqlite3
442
-
443
- db_path: Optional[str] = cast(
444
- Optional[str], getattr(adapter, "db_path", None)
445
- ) or cast(Optional[str], getattr(adapter, "path", None))
446
- if not db_path or not os.path.exists(db_path):
447
- return ""
448
-
449
- try:
450
- conn = sqlite3.connect(db_path)
451
- cur = conn.cursor()
452
- tables = cur.execute(
453
- "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
454
- ).fetchall()
455
- lines = []
456
- for (tname,) in tables:
457
- cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
458
- colnames = [c[1] for c in cols] # (cid, name, type, notnull, dflt, pk)
459
- lines.append(f"{tname}({', '.join(colnames)})")
460
- conn.close()
461
- return "\n".join(lines)
462
- except Exception:
463
- return ""
 
4
  from dataclasses import asdict, is_dataclass
5
  import os
6
  from pathlib import Path
 
7
  import uuid
8
+ from typing import Any, Dict, Optional, Tuple, cast
9
  import hashlib
10
  import logging
11
 
12
  # --- Third-party ---
13
+ from fastapi import APIRouter, Depends, HTTPException, Security, UploadFile, File
 
14
  from fastapi.security import APIKeyHeader
15
  from prometheus_client import Counter
16
 
17
  # --- Local ---
18
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
19
+ from app.state import register_db
20
+ from nl2sql.pipeline import FinalResult
 
 
 
 
 
 
21
  from nl2sql.prom import REGISTRY
22
+ from app.dependencies import get_cache, get_nl2sql_service
23
+ from app.cache import NL2SQLCache
24
+ from app.services.nl2sql_service import NL2SQLService
25
+ from app.settings import get_settings
26
+ from app.errors import (
27
+ AppError,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+ settings = get_settings()
32
 
 
33
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
34
 
35
 
36
def require_api_key(key: Optional[str] = Security(api_key_header)):
    """
    Validate the X-API-Key header against the configured key list.

    Settings.api_keys_raw holds a comma-separated list of keys; when it is
    empty, authentication is disabled entirely (dev mode).

    Raises HTTPException(401) when keys are configured and the header is
    missing or does not match any of them.
    """
    configured = settings.api_keys_raw or ""
    allowed = {token.strip() for token in configured.split(",") if token.strip()}
    if not allowed:
        # No keys configured → auth is off (dev mode).
        return
    if not key or key not in allowed:
        raise HTTPException(status_code=401, detail="invalid API key")
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ####################################
53
  # ---- Simple in-memory cache for NL→SQL responses ----
54
 
 
56
  cache_misses_total = Counter(
57
  "cache_misses_total", "NL2SQL cache misses", registry=REGISTRY
58
  )
59
+
60
+ # Cache TTL and max size from centralized settings
61
+ _CACHE_TTL = settings.cache_ttl_sec
62
+ _CACHE_MAX = settings.cache_max_entries
63
  _CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}
64
 
65
 
66
def _norm_q(s: str) -> str:
    """Normalize a user query (trim + lowercase) for cache key purposes."""
    text = s or ""
    return text.strip().lower()
69
 
70
 
71
def _schema_key(preview: str) -> str:
    """Digest the schema preview so cache keys stay small (md5 hex, not security)."""
    text = preview if preview else ""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
74
 
75
 
76
def _ck(
    db_id: Optional[str],
    query: str,
    schema_preview: str,
) -> str:
    """
    Build a stable cache key for (db_id, query, schema_preview).

    The three parts are joined with newlines and folded into a short SHA-1
    digest, so dictionary keys stay small even for huge schema previews while
    remaining deterministic.

    NOTE(review): _CACHE is still annotated with Tuple[str, str, str] keys,
    but this now returns a plain str digest — the annotation should be updated.
    """
    parts = [db_id or "__default__", query, schema_preview]
    seed = "\n".join(parts)
    return hashlib.sha1(seed.encode("utf-8")).hexdigest()
95
 
96
 
97
def _cache_gc(now: float) -> None:
    """
    Garbage-collect cache entries: first by TTL, then by max size.
    """
    # TTL eviction: collect stale keys first, then drop them.
    stale = [key for key, (stamp, _) in _CACHE.items() if now - stamp > _CACHE_TTL]
    for key in stale:
        _CACHE.pop(key, None)

    # Size eviction (naive FIFO-style: oldest insertion order goes first).
    while len(_CACHE) > _CACHE_MAX:
        _CACHE.pop(next(iter(_CACHE)), None)
 
 
115
  # -------------------------------
116
  # Config / Defaults
117
  # -------------------------------
118
+ DB_MODE = settings.db_mode.lower() # "sqlite" or "postgres"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # Runtime upload storage for SQLite DBs
121
+ _DB_UPLOAD_DIR = settings.db_upload_dir
 
122
  os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
123
 
124
+ # Optional: separate directory for other uploads (kept as-is for now)
 
 
 
125
  UPLOAD_DIR = Path("data/uploads")
126
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
127
 
128
+ logger.debug(
129
+ "NL2SQL router configured",
130
+ extra={"db_mode": DB_MODE, "upload_dir": _DB_UPLOAD_DIR},
131
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  # -------------------------------
 
137
 
138
 
139
@router.get("/schema")
def schema_endpoint(
    db_id: Optional[str] = None,
    svc: NL2SQLService = Depends(get_nl2sql_service),
):
    """
    Return a lightweight schema preview string for the given DB.

    db_id selects an uploaded DB; when omitted the service falls back to the
    default DB. In postgres mode callers usually must supply schema_preview
    themselves. Domain errors (AppError subclasses) propagate to the global
    exception handler; only truly unexpected failures become a generic 500.
    """
    try:
        preview = svc.get_schema_preview(db_id=db_id, override=None)
    except AppError:
        # Handled by the global AppError exception handler.
        raise
    except Exception as exc:
        logger.exception("Unexpected error in schema_endpoint", exc_info=exc)
        raise HTTPException(
            status_code=500,
            detail="failed to derive schema preview",
        ) from exc
    else:
        return {"schema_preview": preview}
 
 
 
 
 
 
166
 
167
 
168
  # -------------------------------
169
  # Helpers
170
  # -------------------------------
171
+
172
+
173
  def _to_dict(obj: Any) -> Any:
174
+ """
175
+ Convert dataclass-like objects (and similar) to plain dicts for JSON.
176
+ """
177
  if is_dataclass(obj) and not isinstance(obj, type):
178
  return asdict(obj) # type: ignore[arg-type]
179
  return obj
 
182
  def _round_trace(t: Any) -> Dict[str, Any]:
183
  """
184
  Normalize a trace entry (dict or StageTrace-like object) for API/UI:
185
+
186
  - stage: str (required)
187
  - duration_ms: int (rounded)
188
  - summary: optional (pass-through if exists)
 
206
  token_in = getattr(t, "token_in", None)
207
  token_out = getattr(t, "token_out", None)
208
 
209
+ # Coerce duration to int with rounding
210
  try:
211
  ms_int = int(round(float(ms))) if ms is not None else 0
212
  except Exception:
 
230
  # -------------------------------
231
  # Upload endpoint (SQLite only)
232
  # -------------------------------
233
+
234
+
235
  @router.post("/upload_db", dependencies=[Depends(require_api_key)])
236
  async def upload_db(file: UploadFile = File(...)):
237
+ """
238
+ Upload a SQLite DB file and register it under a generated db_id.
239
+
240
+ Only available when DB_MODE is 'sqlite':
241
+ - Allowed extensions: .db, .sqlite
242
+ - File size capped by configured upload_max_bytes (default 20 MB)
243
+ """
244
  if DB_MODE != "sqlite":
245
  raise HTTPException(
246
  status_code=400, detail="DB upload is only supported in sqlite mode"
 
253
  )
254
 
255
  data = await file.read()
256
+ max_bytes = settings.upload_max_bytes
257
  if len(data) > max_bytes:
258
  raise HTTPException(
259
  status_code=400, detail=f"File too large (> {max_bytes} bytes)"
 
265
  with open(out_path, "wb") as f:
266
  f.write(data)
267
  except Exception as e:
268
+ logger.debug("Failed to store uploaded DB file", exc_info=e)
269
  raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")
270
 
271
  register_db(db_id, out_path)
272
+ logger.debug("Registered uploaded DB", extra={"db_id": db_id, "path": out_path})
273
  return {"db_id": db_id}
274
 
275
 
 
 
 
 
 
 
 
 
276
@router.get("/health")
def health():
    """Liveness probe for this router; also reports the running app version."""
    body = {"status": "ok"}
    body["version"] = settings.app_version
    return body
280
 
281
 
282
  # -------------------------------
283
  # Main NL2SQL endpoint
284
  # -------------------------------
285
+
286
+
287
  @router.post("", name="nl2sql_handler", dependencies=[Depends(require_api_key)])
288
  def nl2sql_handler(
289
  request: NL2SQLRequest,
290
+ svc: NL2SQLService = Depends(get_nl2sql_service),
291
+ cache: NL2SQLCache = Depends(get_cache),
292
+ ) -> NL2SQLResponse | ClarifyResponse | Dict[str, Any]:
293
  """
294
+ Main NL→SQL handler.
295
+
296
+ Flow:
297
+ - Resolve schema preview (client override or derived from DB).
298
+ - Check in-memory cache (db_id + query + schema hash).
299
+ - Run the pipeline through NL2SQLService.
300
+ - Map FinalResult to API response or HTTP error.
301
  """
302
  db_id = getattr(request, "db_id", None)
303
+
304
+ # ---- schema preview ----
305
+ try:
306
+ final_preview = svc.get_schema_preview(
307
+ db_id=db_id,
308
+ override=request.schema_preview,
309
+ )
310
+ except AppError:
311
+ # Domain-level errors are handled by the global AppError handler.
312
+ raise
313
+ except Exception as exc:
314
+ logger.exception(
315
+ "Unexpected error while preparing schema preview",
316
+ exc_info=exc,
317
+ )
318
+ raise HTTPException(
319
+ status_code=500,
320
+ detail="failed to prepare schema",
321
+ ) from exc
322
 
323
  # ---- cache lookup ----
324
+ cache_key = _ck(db_id, request.query, final_preview)
325
+ cached_payload = cache.get(cache_key)
326
+ if cached_payload is not None:
327
+ return cached_payload
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ # ---- pipeline execution via service ----
330
  try:
331
+ result = svc.run_query(
332
+ query=request.query,
333
+ db_id=db_id,
334
+ schema_preview=final_preview,
335
+ )
336
+ except AppError:
337
+ # Let the global handler convert it to an HTTP response.
338
+ raise
339
  except Exception as exc:
340
+ logger.exception(
341
+ "Unexpected pipeline crash in NL2SQLService.run_query",
342
+ exc_info=exc,
343
+ )
344
+ raise HTTPException(
345
+ status_code=500,
346
+ detail="internal pipeline error",
347
+ ) from exc
348
 
349
+ # ---- type sanity check ----
350
  if not isinstance(result, FinalResult):
351
+ logger.debug(
352
+ "Pipeline returned unexpected type",
353
+ extra={"type": type(result).__name__},
354
+ )
355
+ raise HTTPException(
356
+ status_code=500,
357
+ detail="pipeline returned unexpected type",
358
+ )
359
 
360
+ # ---- ambiguity path → 200 with clarification questions ----
361
  if result.ambiguous:
362
  qs = result.questions or []
363
  return ClarifyResponse(ambiguous=True, questions=qs)
364
 
365
+ # ---- error path → 400 with joined details ----
 
 
 
366
  if (not result.ok) or result.error:
367
+ logger.debug(
368
+ "Pipeline reported failure",
369
+ extra={
370
+ "ok": result.ok,
371
+ "error": result.error,
372
+ "details": result.details,
373
+ },
374
+ )
375
  message = "; ".join(result.details or []) or "Unknown error"
376
  raise HTTPException(status_code=400, detail=message)
377
 
378
+ # ---- success path → 200 (normalize traces and executor result) ----
379
  traces = [_round_trace(t) for t in (result.traces or [])]
380
 
 
381
  response_result: Dict[str, Any] = {}
382
  raw_result = getattr(result, "result", None)
383
  if raw_result is not None:
 
394
  result=response_result,
395
  )
396
 
397
+ # Store in cache (as plain dict)
398
+ cache.set(cache_key, payload.model_dump())
399
  return payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/nl2sql_service.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sqlite3
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional
6
+ from pathlib import Path
7
+
8
+ from nl2sql.pipeline import FinalResult
9
+ from nl2sql.pipeline_factory import pipeline_from_config_with_adapter
10
+ from adapters.db.sqlite_adapter import SQLiteAdapter
11
+ from adapters.db.postgres_adapter import PostgresAdapter
12
+ from app import state
13
+ from app.settings import Settings
14
+ from app.errors import (
15
+ AppError,
16
+ DbNotFound,
17
+ SchemaRequired,
18
+ SchemaDeriveError,
19
+ PipelineConfigError,
20
+ PipelineRunError,
21
+ )
22
+
23
+ Adapter = Any # You can replace this with a Protocol later
24
+
25
+
26
@dataclass
class NL2SQLService:
    """
    Application-level service for the NL2SQL use-case.

    Responsibilities:
    - Choose the right DB adapter based on db_mode + db_id.
    - Derive or accept a textual schema preview.
    - Build and run the pipeline for a given query.
    """

    # Application configuration (db mode, DSN, paths, pipeline config).
    settings: Settings

    def _select_adapter(self, db_id: Optional[str]) -> Adapter:
        """
        Pick the DB adapter for this request.

        Precedence:
        - postgres mode -> PostgresAdapter (requires a configured DSN);
        - sqlite mode with db_id -> previously uploaded DB from app.state;
        - sqlite mode without db_id -> the default (demo) SQLite file.

        Raises:
            PipelineConfigError: postgres mode without a DSN.
            DbNotFound: db_id cannot be resolved, or the default SQLite
                file does not exist.
        """
        mode = self.settings.db_mode.lower()

        if mode == "postgres":
            dsn = (self.settings.postgres_dsn or "").strip()
            if not dsn:
                raise PipelineConfigError("Postgres DSN is not configured")
            return PostgresAdapter(dsn=dsn)

        if db_id:
            # Drop expired uploads before resolving, so we never hand out
            # a path that is about to be garbage-collected.
            state.cleanup_stale_dbs()
            path = state.get_db_path(db_id)
            if not path:
                raise DbNotFound(f"Could not resolve DB for db_id={db_id!r}")
            return SQLiteAdapter(path=path)

        default_path = self.settings.default_sqlite_path
        if not Path(default_path).exists():
            raise DbNotFound(f"SQLite database path does not exist: {default_path!r}")

        return SQLiteAdapter(path=default_path)

    def _introspect_sqlite_schema(self, adapter: Adapter) -> str:
        """
        Build a lightweight textual schema preview for a SQLite database.

        Returns one line per user table, formatted as
        ``table(col1, col2, ...)``; tables without columns are skipped.

        Raises:
            RuntimeError: the adapter exposes neither .db_path nor .path.
            FileNotFoundError: the path does not point at an existing file.
        """
        # Try to locate the underlying .db path from the adapter.
        db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
        if not db_path:
            raise RuntimeError(
                "SQLite adapter must expose a .db_path or .path attribute"
            )

        if not Path(db_path).exists():
            raise FileNotFoundError(f"SQLite database path does not exist: {db_path}")

        lines: list[str] = []
        conn = sqlite3.connect(db_path)
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' "
                "AND name NOT LIKE 'sqlite_%' ORDER BY name"
            )
            tables = [row[0] for row in cur.fetchall()]

            for table in tables:
                # Quote the identifier (escaping embedded quotes) so table
                # names containing spaces, quotes or reserved words do not
                # break the PRAGMA statement.
                quoted = table.replace('"', '""')
                cur.execute(f'PRAGMA table_info("{quoted}")')
                cols = [row[1] for row in cur.fetchall()]
                if cols:
                    lines.append(f"{table}({', '.join(cols)})")
        finally:
            conn.close()

        return "\n".join(lines)

    def get_schema_preview(
        self,
        db_id: Optional[str],
        override: Optional[str],
    ) -> str:
        """
        Decide which schema preview to use.

        - If override is provided by the client -> use it verbatim.
        - Else, in sqlite mode -> introspect the DB.
        - In postgres mode without override -> fail fast; the caller maps
          this to a proper HTTP error.

        Raises:
            SchemaRequired: postgres mode and no override supplied.
            DbNotFound: the target DB cannot be located.
            SchemaDeriveError: any other failure while introspecting.
        """
        if override:
            return override

        mode = self.settings.db_mode.lower()
        if mode == "postgres":
            raise SchemaRequired("schema_preview is required in postgres mode")

        try:
            adapter = self._select_adapter(db_id)
            return self._introspect_sqlite_schema(adapter)
        except DbNotFound:
            # Preserve the specific "missing DB" signal for the API layer.
            raise
        except Exception as exc:
            raise SchemaDeriveError("failed to derive schema preview") from exc

    def run_query(
        self,
        *,
        query: str,
        db_id: Optional[str],
        schema_preview: str,
    ) -> FinalResult:
        """
        Build a pipeline for the given DB and run the query through it.

        Raises:
            PipelineConfigError: pipeline config missing or unbuildable.
            PipelineRunError: adapter selection or pipeline execution crashed.
            AppError: domain errors are re-raised untouched.
        """
        try:
            adapter = self._select_adapter(db_id)
        except AppError:
            raise
        except Exception as exc:
            raise PipelineRunError("failed to select adapter") from exc

        try:
            pipeline = pipeline_from_config_with_adapter(
                self.settings.pipeline_config_path,
                adapter=adapter,
            )
        except FileNotFoundError as exc:
            raise PipelineConfigError(
                f"Pipeline config not found at {self.settings.pipeline_config_path!r}"
            ) from exc
        except Exception as exc:
            raise PipelineConfigError(
                f"Failed to build pipeline from {self.settings.pipeline_config_path!r}: {exc}"
            ) from exc

        try:
            result = pipeline.run(user_query=query, schema_preview=schema_preview)
        except AppError:
            raise
        except Exception as exc:
            raise PipelineRunError("pipeline crashed during execution") from exc

        return result
app/settings.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+
8
# Resolve repo root from this file's location:
# app/settings.py → parent = app/ → parent = repo root
REPO_ROOT = Path(__file__).resolve().parents[1]

# Canonical demo DB and pipeline config shipped with the repo
DEFAULT_DEMO_DB = REPO_ROOT / "data" / "demo.db"
DEFAULT_PIPELINE_CONFIG = REPO_ROOT / "configs" / "sqlite_pipeline.yaml"


@dataclass
class Settings:
    """
    Central application configuration.

    A plain dataclass (deliberately no pydantic dependency); populate it
    from the process environment via :meth:`Settings.from_env`.
    """

    # --- DB mode / adapters ---
    db_mode: str = "sqlite"  # "sqlite" or "postgres"
    postgres_dsn: str = ""

    # --- Pipeline config ---
    pipeline_config_path: str = str(DEFAULT_PIPELINE_CONFIG)

    # --- SQLite uploaded DBs ---
    db_upload_dir: str = "/tmp/nl2sql_dbs"
    db_ttl_seconds: int = 7200  # 2 hours

    # --- Upload constraints ---
    upload_max_bytes: int = 20 * 1024 * 1024  # 20MB

    # --- Cache settings ---
    cache_ttl_sec: int = 300
    cache_max_entries: int = 256

    # --- Default SQLite path (demo DB) ---
    default_sqlite_path: str = str(DEFAULT_DEMO_DB)

    # --- API keys (comma-separated) ---
    api_keys_raw: str = ""

    # --- App version ---
    app_version: str = "dev"

    @classmethod
    def from_env(cls) -> "Settings":
        """
        Construct Settings from environment variables with sane fallbacks.

        DEFAULT_SQLITE_PATH and PIPELINE_CONFIG may be absolute or
        relative; relative values are resolved against REPO_ROOT.
        Malformed integer variables silently fall back to their defaults.
        """

        def _int_env(var: str, fallback: int) -> int:
            # Unset/blank or non-numeric values fall back to the default.
            text = os.getenv(var)
            if text is None or not text.strip():
                return fallback
            try:
                return int(text)
            except ValueError:
                return fallback

        def _path_env(var: str, fallback: Path) -> Path:
            # Empty -> shipped default; relative -> anchored at REPO_ROOT.
            text = os.getenv(var, "").strip()
            if not text:
                return fallback
            candidate = Path(text)
            return candidate if candidate.is_absolute() else REPO_ROOT / text

        db_candidate = _path_env("DEFAULT_SQLITE_PATH", DEFAULT_DEMO_DB)
        cfg_candidate = _path_env("PIPELINE_CONFIG", DEFAULT_PIPELINE_CONFIG)

        return cls(
            db_mode=os.getenv("DB_MODE", cls.db_mode),
            postgres_dsn=os.getenv("POSTGRES_DSN", cls.postgres_dsn),
            pipeline_config_path=str(cfg_candidate),
            db_upload_dir=os.getenv("DB_UPLOAD_DIR", cls.db_upload_dir),
            db_ttl_seconds=_int_env("DB_TTL_SECONDS", cls.db_ttl_seconds),
            upload_max_bytes=_int_env("UPLOAD_MAX_BYTES", cls.upload_max_bytes),
            cache_ttl_sec=_int_env("NL2SQL_CACHE_TTL_SEC", cls.cache_ttl_sec),
            cache_max_entries=_int_env("NL2SQL_CACHE_MAX", cls.cache_max_entries),
            default_sqlite_path=str(db_candidate),
            api_keys_raw=os.getenv("API_KEYS", cls.api_keys_raw),
            app_version=os.getenv("APP_VERSION", cls.app_version),
        )
102
+
103
+
104
@lru_cache()
def get_settings() -> Settings:
    """Build Settings from the environment once and memoize the instance."""
    loaded = Settings.from_env()
    return loaded
app/state.py CHANGED
@@ -1,79 +1,166 @@
 
 
 
1
  import os
2
  import time
3
- import logging
4
  from pathlib import Path
5
- from typing import Optional, TypedDict
6
 
7
  log = logging.getLogger(__name__)
8
 
9
- # ------------------------------
10
- # Config
11
- # ------------------------------
12
-
13
- # default upload directory (can override via .env)
14
- _DB_UPLOAD_DIR = Path(os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs"))
15
- _DB_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
16
-
17
-
18
- class DBEntry(TypedDict):
19
- path: str
20
- ts: float
21
-
22
 
23
- # in-memory map: {db_id: {"path": str, "ts": float}}
24
- DB_MAP: dict[str, DBEntry] = {}
25
-
26
- # cleanup threshold (hours)
27
- DB_TTL_HOURS = 6
28
-
29
-
30
- # ------------------------------
31
- # Helpers
32
- # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def register_db(db_id: str, path: str) -> None:
36
- """Register new DB in memory (and ensure dir exists)."""
37
- _DB_UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
38
- DB_MAP[db_id] = {"path": path, "ts": time.time()}
39
- log.info(f"📦 Registered DB {db_id} -> {path}")
 
 
40
 
41
 
42
  def cleanup_stale_dbs() -> None:
43
- """Remove expired DBs from /tmp/nl2sql_dbs and memory map."""
44
- now = time.time()
45
- cutoff = DB_TTL_HOURS * 3600
46
- stale_ids = [db_id for db_id, entry in DB_MAP.items() if now - entry["ts"] > cutoff]
47
- for db_id in stale_ids:
48
- path_str = DB_MAP[db_id]["path"]
49
- path = Path(path_str)
50
- try:
51
- if path.exists():
52
- path.unlink()
53
- log.info(f"🧹 Deleted stale DB: {path}")
54
- except FileNotFoundError:
55
- pass
56
- DB_MAP.pop(db_id, None)
57
 
58
 
59
  def get_db_path(db_id: str) -> Optional[str]:
60
- """Return full path of an uploaded DB (persistent lookup)."""
61
- entry = DB_MAP.get(db_id)
62
- if entry:
63
- path_str = entry["path"]
64
- if Path(path_str).exists():
65
- return path_str
66
-
67
- candidates = [
68
- _DB_UPLOAD_DIR / f"{db_id}.sqlite",
69
- _DB_UPLOAD_DIR / f"{db_id}.db",
70
- Path("data/uploads") / f"{db_id}.sqlite",
71
- Path("data/uploads") / f"{db_id}.db",
72
- ]
73
- for p in candidates:
74
- if p.exists():
75
- log.info(f"🔍 Recovered DB path for {db_id}: {p}")
76
- return str(p)
77
-
78
- log.warning(f"⚠️ DB file not found for id={db_id}")
79
- return None
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
  import os
5
  import time
 
6
  from pathlib import Path
7
+ from typing import Dict, Tuple, Optional
8
 
9
  log = logging.getLogger(__name__)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
class DbUploadStore:
    """
    In-memory registry of uploaded DB files with TTL-based expiry.

    Tracks db_id -> (path, registration time), evicts records once their
    TTL elapses or their file vanishes, and best-effort deletes expired
    files from disk.
    """

    def __init__(self, upload_dir: str, ttl_seconds: int) -> None:
        self.upload_dir = upload_dir
        self.ttl_seconds = ttl_seconds
        # db_id -> (filesystem path, registration timestamp)
        self._records: Dict[str, Tuple[str, float]] = {}

        Path(self.upload_dir).mkdir(parents=True, exist_ok=True)
        log.debug(
            "Initialized DbUploadStore",
            extra={
                "upload_dir": self.upload_dir,
                "ttl_seconds": self.ttl_seconds,
            },
        )

    def _expired(self, ts: float, now: Optional[float] = None) -> bool:
        """True once more than ttl_seconds have elapsed since *ts*."""
        reference = time.time() if now is None else now
        return (reference - ts) > self.ttl_seconds

    def _sweep(self, now: Optional[float] = None) -> None:
        """
        Evict every record that is past its TTL or whose file vanished,
        deleting still-present files from disk (best effort).
        """
        reference = time.time() if now is None else now

        doomed = [
            (db_id, path)
            for db_id, (path, ts) in list(self._records.items())
            if self._expired(ts, reference) or not os.path.exists(path)
        ]

        for db_id, path in doomed:
            self._records.pop(db_id, None)
            try:
                if os.path.exists(path):
                    os.remove(path)
                    log.debug(
                        "Deleted expired uploaded DB file",
                        extra={"db_id": db_id, "path": path},
                    )
            except Exception as exc:
                # Filesystem hiccups must never take the app down.
                log.debug(
                    "Failed to delete expired uploaded DB file",
                    extra={"db_id": db_id, "path": path},
                    exc_info=exc,
                )

    def cleanup_stale(self) -> None:
        """Public cleanup hook: evict expired uploads now."""
        self._sweep()

    def register(self, db_id: str, path: str) -> None:
        """Record *db_id* -> *path*, stamped with the current time."""
        stamp = time.time()
        self._records[db_id] = (path, stamp)
        log.debug(
            "Registered uploaded DB in DbUploadStore",
            extra={"db_id": db_id, "path": path},
        )
        # Piggyback a sweep of old entries on each registration.
        self._sweep(now=stamp)

    def resolve(self, db_id: str) -> Optional[str]:
        """
        Map *db_id* back to its file path.

        Returns:
            The path when the record is known, unexpired, and its file
            still exists; otherwise None.
        """
        self._sweep()
        record = self._records.get(db_id)
        if not record:
            return None

        path, ts = record
        if self._expired(ts):
            # Raced past the sweep above; expire it here instead.
            self._records.pop(db_id, None)
            try:
                if os.path.exists(path):
                    os.remove(path)
            except Exception as exc:
                log.debug(
                    "Failed to delete DB file on late-expiration",
                    extra={"db_id": db_id, "path": path},
                    exc_info=exc,
                )
            return None

        if not os.path.exists(path):
            # File disappeared out from under us; drop the record.
            self._records.pop(db_id, None)
            return None

        return path
130
+
131
+
132
# --------------------------------------------------------------------
# Module-level singleton and legacy helper functions
# --------------------------------------------------------------------

# Upload directory and TTL are read straight from the environment so this
# module stays usable without the app.settings machinery.
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
# NOTE(review): int() raises ValueError at import time if DB_TTL_SECONDS
# is set but not numeric — consider a tolerant fallback to 7200.
_DB_TTL_SECONDS = int(os.getenv("DB_TTL_SECONDS", "7200"))  # default: 2 hours

# Process-wide store backing the legacy register/cleanup/get helpers below.
_STORE = DbUploadStore(upload_dir=_DB_UPLOAD_DIR, ttl_seconds=_DB_TTL_SECONDS)
140
 
141
 
142
def register_db(db_id: str, path: str) -> None:
    """
    Backwards-compatible shim.

    Records *db_id* -> *path* in the process-wide DbUploadStore.
    """
    _STORE.register(db_id, path)
149
 
150
 
151
def cleanup_stale_dbs() -> None:
    """
    Backwards-compatible shim.

    Runs TTL-based eviction on the process-wide DbUploadStore.
    """
    _STORE.cleanup_stale()
 
 
 
 
 
 
 
 
158
 
159
 
160
def get_db_path(db_id: str) -> Optional[str]:
    """
    Backwards-compatible shim.

    Resolves *db_id* via the process-wide DbUploadStore; None when the
    entry is unknown, expired, or its file is gone.
    """
    return _STORE.resolve(db_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
huggingface.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ sdk: gradio
2
+ app_file: app/app.py
3
+ python_version: "3.10"