Spaces:
Running
Running
Melika Kheirieh
committed on
Commit
·
977a885
1
Parent(s):
370553a
fix(mypy): add types-requests, validate numeric inputs, enforce non-null schema_preview
Browse files- app/routers/nl2sql.py +63 -37
- requirements.txt +1 -0
app/routers/nl2sql.py
CHANGED
|
@@ -18,7 +18,7 @@ from pathlib import Path
|
|
| 18 |
import time
|
| 19 |
import json
|
| 20 |
import uuid
|
| 21 |
-
from typing import Union, Optional, Dict
|
| 22 |
|
| 23 |
router = APIRouter(prefix="/nl2sql")
|
| 24 |
|
|
@@ -27,20 +27,24 @@ router = APIRouter(prefix="/nl2sql")
|
|
| 27 |
# Files are stored under /tmp, mapped by a short-lived db_id
|
| 28 |
# -------------------------------
|
| 29 |
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
|
| 30 |
-
_DB_TTL_SECONDS = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
|
| 31 |
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# In-memory map: db_id -> {"path": str, "ts": float}
|
| 34 |
-
_DB_MAP: Dict[str,
|
| 35 |
|
| 36 |
# -------------------------------
|
| 37 |
# Default DB resolution
|
| 38 |
# -------------------------------
|
| 39 |
DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
|
| 40 |
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
|
| 41 |
-
DEFAULT_SQLITE_DB = os.getenv(
|
| 42 |
-
"DEFAULT_SQLITE_DB", "data/chinook.db"
|
| 43 |
-
) # keep your current default
|
| 44 |
|
| 45 |
# -------------------------------
|
| 46 |
# Path to persist db_id → file map
|
|
@@ -48,14 +52,13 @@ DEFAULT_SQLITE_DB = os.getenv(
|
|
| 48 |
_DB_MAP_PATH = Path("data/uploads/db_map.json")
|
| 49 |
_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 50 |
|
| 51 |
-
|
| 52 |
UPLOAD_DIR = Path("data/uploads")
|
| 53 |
UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
|
| 54 |
|
| 55 |
DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
|
| 56 |
|
| 57 |
|
| 58 |
-
def _save_db_map():
|
| 59 |
"""Persist the in-memory DB map to disk as JSON."""
|
| 60 |
try:
|
| 61 |
with open(_DB_MAP_PATH, "w") as f:
|
|
@@ -64,15 +67,22 @@ def _save_db_map():
|
|
| 64 |
print(f"⚠️ Failed to save DB map: {e}")
|
| 65 |
|
| 66 |
|
| 67 |
-
def _load_db_map():
|
| 68 |
"""Load the DB map from disk if it exists (called on startup)."""
|
| 69 |
global _DB_MAP
|
| 70 |
if _DB_MAP_PATH.exists():
|
| 71 |
try:
|
| 72 |
with open(_DB_MAP_PATH, "r") as f:
|
| 73 |
data = json.load(f)
|
|
|
|
| 74 |
if isinstance(data, dict):
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
|
| 77 |
except Exception as e:
|
| 78 |
print(f"⚠️ Failed to load DB map: {e}")
|
|
@@ -81,13 +91,11 @@ def _load_db_map():
|
|
| 81 |
def _cleanup_db_map() -> None:
|
| 82 |
"""Remove expired uploaded DB files (best-effort)."""
|
| 83 |
now = time.time()
|
| 84 |
-
expired = [
|
| 85 |
-
k for k, v in _DB_MAP.items() if now - float(v.get("ts", 0)) > _DB_TTL_SECONDS
|
| 86 |
-
]
|
| 87 |
for k in expired:
|
| 88 |
-
path = _DB_MAP[k]
|
| 89 |
try:
|
| 90 |
-
if
|
| 91 |
os.remove(path)
|
| 92 |
except Exception:
|
| 93 |
pass
|
|
@@ -98,11 +106,11 @@ def _resolve_sqlite_path(db_id: Optional[str]) -> str:
|
|
| 98 |
"""Resolve a SQLite file path from db_id or fallback to default."""
|
| 99 |
_cleanup_db_map()
|
| 100 |
if db_id and db_id in _DB_MAP:
|
| 101 |
-
return
|
| 102 |
return DEFAULT_SQLITE_DB
|
| 103 |
|
| 104 |
|
| 105 |
-
def _select_adapter(db_id: str
|
| 106 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 107 |
if mode == "postgres":
|
| 108 |
dsn = os.environ.get("POSTGRES_DSN")
|
|
@@ -113,23 +121,23 @@ def _select_adapter(db_id: str | None):
|
|
| 113 |
# sqlite mode
|
| 114 |
if db_id:
|
| 115 |
_cleanup_db_map()
|
| 116 |
-
db_path = None
|
| 117 |
# first check runtime map
|
| 118 |
if db_id in _DB_MAP:
|
| 119 |
-
db_path = _DB_MAP[db_id]
|
| 120 |
# fallback: check /tmp or uploads
|
| 121 |
if not db_path or not os.path.exists(db_path):
|
| 122 |
fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
|
| 123 |
-
fallback_uploads = UPLOAD_DIR / f"{db_id}.sqlite"
|
| 124 |
for candidate in (fallback_tmp, fallback_uploads):
|
| 125 |
if os.path.exists(candidate):
|
| 126 |
-
db_path =
|
| 127 |
break
|
| 128 |
if not db_path or not os.path.exists(db_path):
|
| 129 |
raise HTTPException(
|
| 130 |
status_code=400, detail="invalid db_id (file not found)"
|
| 131 |
)
|
| 132 |
-
return SQLiteAdapter(
|
| 133 |
|
| 134 |
# fallback to default Chinook
|
| 135 |
if not Path(DEFAULT_SQLITE_PATH).exists():
|
|
@@ -169,17 +177,28 @@ def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
|
|
| 169 |
# -------------------------------
|
| 170 |
# Helpers
|
| 171 |
# -------------------------------
|
| 172 |
-
def _to_dict(obj):
|
| 173 |
-
"""Safely convert dataclass → dict.
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
-
def _round_trace(t:
|
| 178 |
"""Round float fields to keep responses tidy and stable."""
|
| 179 |
if t.get("cost_usd") is not None:
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
| 181 |
if t.get("duration_ms") is not None:
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
return t
|
| 184 |
|
| 185 |
|
|
@@ -244,18 +263,22 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 244 |
pipeline = _build_pipeline(adapter)
|
| 245 |
|
| 246 |
# 2) Resolve schema_preview (optional in request)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
| 252 |
)
|
| 253 |
|
|
|
|
|
|
|
|
|
|
| 254 |
# 3) Run pipeline
|
| 255 |
try:
|
| 256 |
result = pipeline.run(
|
| 257 |
-
user_query=request.query, # assumes NL2SQLRequest has `query`
|
| 258 |
-
schema_preview=
|
| 259 |
)
|
| 260 |
except Exception as exc:
|
| 261 |
# Hard failure in pipeline itself
|
|
@@ -291,7 +314,7 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 291 |
)
|
| 292 |
|
| 293 |
|
| 294 |
-
def _derive_schema_preview(adapter) -> str:
|
| 295 |
"""
|
| 296 |
Build a strict, exact-cased schema preview for the LLM.
|
| 297 |
Works for SQLite adapters by querying sqlite_master / pragma table_info.
|
|
@@ -299,7 +322,10 @@ def _derive_schema_preview(adapter) -> str:
|
|
| 299 |
import sqlite3
|
| 300 |
import os
|
| 301 |
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
| 303 |
if not db_path or not os.path.exists(db_path):
|
| 304 |
return ""
|
| 305 |
|
|
|
|
| 18 |
import time
|
| 19 |
import json
|
| 20 |
import uuid
|
| 21 |
+
from typing import Union, Optional, Dict, TypedDict, Any, cast
|
| 22 |
|
| 23 |
router = APIRouter(prefix="/nl2sql")
|
| 24 |
|
|
|
|
| 27 |
# Files are stored under /tmp, mapped by a short-lived db_id
|
| 28 |
# -------------------------------
|
| 29 |
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
|
| 30 |
+
_DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
|
| 31 |
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
|
| 32 |
|
| 33 |
+
|
| 34 |
+
# Shape of one uploaded-DB record kept in the in-memory map:
# "path" is the on-disk SQLite file, "ts" the upload timestamp (epoch seconds).
DBEntry = TypedDict("DBEntry", {"path": str, "ts": float})
|
| 37 |
+
|
| 38 |
+
|
| 39 |
# In-memory map: db_id -> {"path": str, "ts": float}
|
| 40 |
+
_DB_MAP: Dict[str, DBEntry] = {}
|
| 41 |
|
| 42 |
# -------------------------------
|
| 43 |
# Default DB resolution
|
| 44 |
# -------------------------------
|
| 45 |
DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
|
| 46 |
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
|
| 47 |
+
DEFAULT_SQLITE_DB: str = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db")
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# -------------------------------
|
| 50 |
# Path to persist db_id → file map
|
|
|
|
| 52 |
_DB_MAP_PATH = Path("data/uploads/db_map.json")
|
| 53 |
_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 54 |
|
|
|
|
| 55 |
UPLOAD_DIR = Path("data/uploads")
|
| 56 |
UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
|
| 57 |
|
| 58 |
DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
|
| 59 |
|
| 60 |
|
| 61 |
+
def _save_db_map() -> None:
|
| 62 |
"""Persist the in-memory DB map to disk as JSON."""
|
| 63 |
try:
|
| 64 |
with open(_DB_MAP_PATH, "w") as f:
|
|
|
|
| 67 |
print(f"⚠️ Failed to save DB map: {e}")
|
| 68 |
|
| 69 |
|
| 70 |
+
def _load_db_map() -> None:
    """Load the DB map from disk if it exists (called on startup).

    Best-effort: a malformed file or entry is skipped rather than aborting
    startup; failures are only printed.
    """
    global _DB_MAP
    if _DB_MAP_PATH.exists():
        try:
            with open(_DB_MAP_PATH, "r") as f:
                data = json.load(f)
            # Be liberal in what we accept; validate each entry into DBEntry.
            if isinstance(data, dict):
                restored: Dict[str, DBEntry] = {}
                for k, v in data.items():
                    # Guard against non-dict values: without this, v.get(...)
                    # raises AttributeError and the broad except below would
                    # discard the ENTIRE map because of one bad record.
                    if not isinstance(v, dict):
                        continue
                    path = v.get("path")
                    ts = v.get("ts")
                    if isinstance(path, str) and isinstance(ts, (int, float)):
                        restored[k] = {"path": path, "ts": float(ts)}
                _DB_MAP.update(restored)
            print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
        except Exception as e:
            print(f"⚠️ Failed to load DB map: {e}")
|
|
|
|
| 91 |
def _cleanup_db_map() -> None:
    """Remove expired uploaded DB files AND their map entries (best-effort)."""
    now = time.time()
    expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
    for k in expired:
        path: str = _DB_MAP[k]["path"]
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception:
            # Best-effort: never let cleanup failures break request handling.
            pass
        # Drop the stale entry too; previously only the file was deleted, so
        # an expired db_id still resolved through _DB_MAP to a removed path.
        _DB_MAP.pop(k, None)
|
|
|
|
| 106 |
"""Resolve a SQLite file path from db_id or fallback to default."""
|
| 107 |
_cleanup_db_map()
|
| 108 |
if db_id and db_id in _DB_MAP:
|
| 109 |
+
return _DB_MAP[db_id]["path"]
|
| 110 |
return DEFAULT_SQLITE_DB
|
| 111 |
|
| 112 |
|
| 113 |
+
def _select_adapter(db_id: Optional[str]):
|
| 114 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 115 |
if mode == "postgres":
|
| 116 |
dsn = os.environ.get("POSTGRES_DSN")
|
|
|
|
| 121 |
# sqlite mode
|
| 122 |
if db_id:
|
| 123 |
_cleanup_db_map()
|
| 124 |
+
db_path: Optional[str] = None
|
| 125 |
# first check runtime map
|
| 126 |
if db_id in _DB_MAP:
|
| 127 |
+
db_path = _DB_MAP[db_id]["path"]
|
| 128 |
# fallback: check /tmp or uploads
|
| 129 |
if not db_path or not os.path.exists(db_path):
|
| 130 |
fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
|
| 131 |
+
fallback_uploads = str(UPLOAD_DIR / f"{db_id}.sqlite")
|
| 132 |
for candidate in (fallback_tmp, fallback_uploads):
|
| 133 |
if os.path.exists(candidate):
|
| 134 |
+
db_path = candidate
|
| 135 |
break
|
| 136 |
if not db_path or not os.path.exists(db_path):
|
| 137 |
raise HTTPException(
|
| 138 |
status_code=400, detail="invalid db_id (file not found)"
|
| 139 |
)
|
| 140 |
+
return SQLiteAdapter(db_path)
|
| 141 |
|
| 142 |
# fallback to default Chinook
|
| 143 |
if not Path(DEFAULT_SQLITE_PATH).exists():
|
|
|
|
| 177 |
# -------------------------------
|
| 178 |
# Helpers
|
| 179 |
# -------------------------------
|
| 180 |
+
def _to_dict(obj: Any) -> Any:
|
| 181 |
+
"""Safely convert dataclass instance → dict, otherwise return as-is.
|
| 182 |
+
|
| 183 |
+
Note: dataclasses.is_dataclass returns True for both classes and instances.
|
| 184 |
+
We must exclude classes; mypy cannot refine this perfectly, so we ignore arg-type.
|
| 185 |
+
"""
|
| 186 |
+
if is_dataclass(obj) and not isinstance(obj, type):
|
| 187 |
+
return asdict(obj) # type: ignore[arg-type]
|
| 188 |
+
return obj
|
| 189 |
|
| 190 |
|
| 191 |
+
def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
|
| 192 |
"""Round float fields to keep responses tidy and stable."""
|
| 193 |
if t.get("cost_usd") is not None:
|
| 194 |
+
# Ensure numeric before rounding
|
| 195 |
+
cost = t["cost_usd"]
|
| 196 |
+
if isinstance(cost, (int, float)):
|
| 197 |
+
t["cost_usd"] = round(float(cost), 6)
|
| 198 |
if t.get("duration_ms") is not None:
|
| 199 |
+
dur = t["duration_ms"]
|
| 200 |
+
if isinstance(dur, (int, float)):
|
| 201 |
+
t["duration_ms"] = round(float(dur), 2)
|
| 202 |
return t
|
| 203 |
|
| 204 |
|
|
|
|
| 263 |
pipeline = _build_pipeline(adapter)
|
| 264 |
|
| 265 |
# 2) Resolve schema_preview (optional in request)
|
| 266 |
+
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 267 |
+
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 268 |
+
|
| 269 |
+
derived_preview: str = _derive_schema_preview(adapter)
|
| 270 |
+
schema_preview_opt: Optional[str] = (
|
| 271 |
+
provided_preview if provided_preview not in ("", None) else derived_preview
|
| 272 |
)
|
| 273 |
|
| 274 |
+
# Guarantee a str for Pipeline.run (mypy requirement)
|
| 275 |
+
final_preview: str = schema_preview_opt or ""
|
| 276 |
+
|
| 277 |
# 3) Run pipeline
|
| 278 |
try:
|
| 279 |
result = pipeline.run(
|
| 280 |
+
user_query=request.query, # assumes NL2SQLRequest has `query: str`
|
| 281 |
+
schema_preview=final_preview, # str guaranteed
|
| 282 |
)
|
| 283 |
except Exception as exc:
|
| 284 |
# Hard failure in pipeline itself
|
|
|
|
| 314 |
)
|
| 315 |
|
| 316 |
|
| 317 |
+
def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
|
| 318 |
"""
|
| 319 |
Build a strict, exact-cased schema preview for the LLM.
|
| 320 |
Works for SQLite adapters by querying sqlite_master / pragma table_info.
|
|
|
|
| 322 |
import sqlite3
|
| 323 |
import os
|
| 324 |
|
| 325 |
+
# Adapters may expose db_path or path; both are str in our codebase
|
| 326 |
+
db_path: Optional[str] = cast(
|
| 327 |
+
Optional[str], getattr(adapter, "db_path", None)
|
| 328 |
+
) or cast(Optional[str], getattr(adapter, "path", None))
|
| 329 |
if not db_path or not os.path.exists(db_path):
|
| 330 |
return ""
|
| 331 |
|
requirements.txt
CHANGED
|
@@ -12,3 +12,4 @@ psycopg[binary]~=3.2
|
|
| 12 |
ruff
|
| 13 |
gradio
|
| 14 |
sqlalchemy
|
|
|
|
|
|
| 12 |
ruff
|
| 13 |
gradio
|
| 14 |
sqlalchemy
|
| 15 |
+
types-requests
|