Melika Kheirieh commited on
Commit
c76014a
·
1 Parent(s): 24a500b

fix(router): align _select_adapter with upload_db via shared DB_MAP + TTL cleanup

Browse files
Files changed (2) hide show
  1. app/routers/nl2sql.py +10 -31
  2. app/state.py +24 -0
app/routers/nl2sql.py CHANGED
@@ -7,13 +7,14 @@ import os
7
  from pathlib import Path
8
  import time
9
  import uuid
10
- from typing import Any, Dict, Optional, TypedDict, Union, cast, List, Callable
11
 
12
  # --- Third-party ---
13
  from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
14
 
15
  # --- Local ---
16
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
 
17
  from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
18
  from adapters.llm.openai_provider import OpenAIProvider
19
  from adapters.db.sqlite_adapter import SQLiteAdapter
@@ -144,17 +145,6 @@ _load_db_map()
144
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
145
  """
146
  Resolve a DB adapter based on module-level DB_MODE and an optional db_id.
147
-
148
- - postgres mode:
149
- requires POSTGRES_DSN in env
150
- - sqlite mode:
151
- if db_id provided, resolve file by:
152
- 1) absolute path (if user supplied a full path)
153
- 2) uploads/{db_id}.sqlite
154
- 3) uploads/{db_id}.db
155
- 4) data/{db_id}.sqlite
156
- 5) data/{db_id}.db
157
- else fallback to DEFAULT_SQLITE_PATH
158
  """
159
  if DB_MODE == "postgres":
160
  dsn = os.environ.get("POSTGRES_DSN")
@@ -164,25 +154,13 @@ def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapte
164
 
165
  # sqlite mode
166
  if db_id:
167
- # 1) absolute path
168
- p = Path(db_id)
169
- candidates: List[Path] = []
170
- if p.is_absolute():
171
- candidates.append(p)
172
-
173
- # 2) uploads/
174
- candidates.append(UPLOAD_DIR / f"{db_id}.sqlite")
175
- candidates.append(UPLOAD_DIR / f"{db_id}.db")
176
-
177
- # 3) data/
178
- candidates.append(Path("data") / f"{db_id}.sqlite")
179
- candidates.append(Path("data") / f"{db_id}.db")
180
-
181
- for c in candidates:
182
- if c.exists() and c.is_file():
183
- return SQLiteAdapter(str(c))
184
-
185
- raise HTTPException(status_code=400, detail="invalid db_id (file not found)")
186
 
187
  # default sqlite fallback
188
  default_path = Path(DEFAULT_SQLITE_PATH)
@@ -288,6 +266,7 @@ async def upload_db(file: UploadFile = File(...)):
288
 
289
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
290
  _save_db_map()
 
291
  return {"db_id": db_id}
292
 
293
 
 
7
  from pathlib import Path
8
  import time
9
  import uuid
10
+ from typing import Any, Dict, Optional, TypedDict, Union, cast, Callable
11
 
12
  # --- Third-party ---
13
  from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
14
 
15
  # --- Local ---
16
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
17
+ from app.state import get_db_path, cleanup_stale_dbs, register_db
18
  from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
19
  from adapters.llm.openai_provider import OpenAIProvider
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
 
145
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
146
  """
147
  Resolve a DB adapter based on module-level DB_MODE and an optional db_id.
 
 
 
 
 
 
 
 
 
 
 
148
  """
149
  if DB_MODE == "postgres":
150
  dsn = os.environ.get("POSTGRES_DSN")
 
154
 
155
  # sqlite mode
156
  if db_id:
157
+ cleanup_stale_dbs()
158
+ path = get_db_path(db_id)
159
+ if path and os.path.exists(path):
160
+ return SQLiteAdapter(path)
161
+ raise HTTPException(
162
+ status_code=404, detail=f"db_id not found or expired: {db_id}"
163
+ )
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # default sqlite fallback
166
  default_path = Path(DEFAULT_SQLITE_PATH)
 
266
 
267
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
268
  _save_db_map()
269
+ register_db(db_id, out_path)
270
  return {"db_id": db_id}
271
 
272
 
app/state.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Dict, Any
4
+
5
+ DB_TTL_SECONDS = int(os.getenv("NL2SQL_DB_TTL_SEC", "86400"))
6
+ DB_MAP: Dict[str, Dict[str, Any]] = {}
7
+
8
+
9
+ def register_db(db_id: str, path: str) -> None:
10
+ DB_MAP[db_id] = {"path": path, "created_at": time.time()}
11
+
12
+
13
+ def get_db_path(db_id: str) -> str | None:
14
+ entry = DB_MAP.get(db_id)
15
+ return entry["path"] if entry else None
16
+
17
+
18
+ def cleanup_stale_dbs() -> None:
19
+ now = time.time()
20
+ stale = [
21
+ k for k, v in DB_MAP.items() if now - v.get("created_at", now) > DB_TTL_SECONDS
22
+ ]
23
+ for k in stale:
24
+ DB_MAP.pop(k, None)