Melika Kheirieh committed on
Commit
977a885
·
1 Parent(s): 370553a

fix(mypy): add types-requests, validate numeric inputs, enforce non-null schema_preview

Browse files
Files changed (2) hide show
  1. app/routers/nl2sql.py +63 -37
  2. requirements.txt +1 -0
app/routers/nl2sql.py CHANGED
@@ -18,7 +18,7 @@ from pathlib import Path
18
  import time
19
  import json
20
  import uuid
21
- from typing import Union, Optional, Dict
22
 
23
  router = APIRouter(prefix="/nl2sql")
24
 
@@ -27,20 +27,24 @@ router = APIRouter(prefix="/nl2sql")
27
  # Files are stored under /tmp, mapped by a short-lived db_id
28
  # -------------------------------
29
  _DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
30
- _DB_TTL_SECONDS = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
31
  os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
32
 
 
 
 
 
 
 
33
  # In-memory map: db_id -> {"path": str, "ts": float}
34
- _DB_MAP: Dict[str, Dict[str, object]] = {}
35
 
36
  # -------------------------------
37
  # Default DB resolution
38
  # -------------------------------
39
  DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
40
  POSTGRES_DSN = os.getenv("POSTGRES_DSN")
41
- DEFAULT_SQLITE_DB = os.getenv(
42
- "DEFAULT_SQLITE_DB", "data/chinook.db"
43
- ) # keep your current default
44
 
45
  # -------------------------------
46
  # Path to persist db_id → file map
@@ -48,14 +52,13 @@ DEFAULT_SQLITE_DB = os.getenv(
48
  _DB_MAP_PATH = Path("data/uploads/db_map.json")
49
  _DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
50
 
51
-
52
  UPLOAD_DIR = Path("data/uploads")
53
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
54
 
55
  DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
56
 
57
 
58
- def _save_db_map():
59
  """Persist the in-memory DB map to disk as JSON."""
60
  try:
61
  with open(_DB_MAP_PATH, "w") as f:
@@ -64,15 +67,22 @@ def _save_db_map():
64
  print(f"⚠️ Failed to save DB map: {e}")
65
 
66
 
67
- def _load_db_map():
68
  """Load the DB map from disk if it exists (called on startup)."""
69
  global _DB_MAP
70
  if _DB_MAP_PATH.exists():
71
  try:
72
  with open(_DB_MAP_PATH, "r") as f:
73
  data = json.load(f)
 
74
  if isinstance(data, dict):
75
- _DB_MAP.update(data)
 
 
 
 
 
 
76
  print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
77
  except Exception as e:
78
  print(f"⚠️ Failed to load DB map: {e}")
@@ -81,13 +91,11 @@ def _load_db_map():
81
  def _cleanup_db_map() -> None:
82
  """Remove expired uploaded DB files (best-effort)."""
83
  now = time.time()
84
- expired = [
85
- k for k, v in _DB_MAP.items() if now - float(v.get("ts", 0)) > _DB_TTL_SECONDS
86
- ]
87
  for k in expired:
88
- path = _DB_MAP[k].get("path")
89
  try:
90
- if isinstance(path, str) and os.path.exists(path):
91
  os.remove(path)
92
  except Exception:
93
  pass
@@ -98,11 +106,11 @@ def _resolve_sqlite_path(db_id: Optional[str]) -> str:
98
  """Resolve a SQLite file path from db_id or fallback to default."""
99
  _cleanup_db_map()
100
  if db_id and db_id in _DB_MAP:
101
- return str(_DB_MAP[db_id]["path"])
102
  return DEFAULT_SQLITE_DB
103
 
104
 
105
- def _select_adapter(db_id: str | None):
106
  mode = os.getenv("DB_MODE", "sqlite").lower()
107
  if mode == "postgres":
108
  dsn = os.environ.get("POSTGRES_DSN")
@@ -113,23 +121,23 @@ def _select_adapter(db_id: str | None):
113
  # sqlite mode
114
  if db_id:
115
  _cleanup_db_map()
116
- db_path = None
117
  # first check runtime map
118
  if db_id in _DB_MAP:
119
- db_path = _DB_MAP[db_id].get("path")
120
  # fallback: check /tmp or uploads
121
  if not db_path or not os.path.exists(db_path):
122
  fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
123
- fallback_uploads = UPLOAD_DIR / f"{db_id}.sqlite"
124
  for candidate in (fallback_tmp, fallback_uploads):
125
  if os.path.exists(candidate):
126
- db_path = str(candidate)
127
  break
128
  if not db_path or not os.path.exists(db_path):
129
  raise HTTPException(
130
  status_code=400, detail="invalid db_id (file not found)"
131
  )
132
- return SQLiteAdapter(str(db_path))
133
 
134
  # fallback to default Chinook
135
  if not Path(DEFAULT_SQLITE_PATH).exists():
@@ -169,17 +177,28 @@ def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
169
  # -------------------------------
170
  # Helpers
171
  # -------------------------------
172
- def _to_dict(obj):
173
- """Safely convert dataclass → dict."""
174
- return asdict(obj) if is_dataclass(obj) else obj
 
 
 
 
 
 
175
 
176
 
177
- def _round_trace(t: dict) -> dict:
178
  """Round float fields to keep responses tidy and stable."""
179
  if t.get("cost_usd") is not None:
180
- t["cost_usd"] = round(t["cost_usd"], 6)
 
 
 
181
  if t.get("duration_ms") is not None:
182
- t["duration_ms"] = round(t["duration_ms"], 2)
 
 
183
  return t
184
 
185
 
@@ -244,18 +263,22 @@ def nl2sql_handler(request: NL2SQLRequest):
244
  pipeline = _build_pipeline(adapter)
245
 
246
  # 2) Resolve schema_preview (optional in request)
247
- provided_preview = getattr(request, "schema_preview", None)
248
- schema_preview = (
249
- provided_preview
250
- if provided_preview not in ("", None)
251
- else _derive_schema_preview(adapter)
 
252
  )
253
 
 
 
 
254
  # 3) Run pipeline
255
  try:
256
  result = pipeline.run(
257
- user_query=request.query, # assumes NL2SQLRequest has `query`
258
- schema_preview=schema_preview, # may be empty string if adapter can't derive
259
  )
260
  except Exception as exc:
261
  # Hard failure in pipeline itself
@@ -291,7 +314,7 @@ def nl2sql_handler(request: NL2SQLRequest):
291
  )
292
 
293
 
294
- def _derive_schema_preview(adapter) -> str:
295
  """
296
  Build a strict, exact-cased schema preview for the LLM.
297
  Works for SQLite adapters by querying sqlite_master / pragma table_info.
@@ -299,7 +322,10 @@ def _derive_schema_preview(adapter) -> str:
299
  import sqlite3
300
  import os
301
 
302
- db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
 
 
 
303
  if not db_path or not os.path.exists(db_path):
304
  return ""
305
 
 
18
  import time
19
  import json
20
  import uuid
21
+ from typing import Union, Optional, Dict, TypedDict, Any, cast
22
 
23
  router = APIRouter(prefix="/nl2sql")
24
 
 
27
  # Files are stored under /tmp, mapped by a short-lived db_id
28
  # -------------------------------
29
  _DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
30
+ _DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
31
  os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
32
 
33
+
34
+ class DBEntry(TypedDict):
35
+ path: str
36
+ ts: float
37
+
38
+
39
  # In-memory map: db_id -> {"path": str, "ts": float}
40
+ _DB_MAP: Dict[str, DBEntry] = {}
41
 
42
  # -------------------------------
43
  # Default DB resolution
44
  # -------------------------------
45
  DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
46
  POSTGRES_DSN = os.getenv("POSTGRES_DSN")
47
+ DEFAULT_SQLITE_DB: str = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db")
 
 
48
 
49
  # -------------------------------
50
  # Path to persist db_id → file map
 
52
  _DB_MAP_PATH = Path("data/uploads/db_map.json")
53
  _DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
54
 
 
55
  UPLOAD_DIR = Path("data/uploads")
56
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
57
 
58
  DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
59
 
60
 
61
+ def _save_db_map() -> None:
62
  """Persist the in-memory DB map to disk as JSON."""
63
  try:
64
  with open(_DB_MAP_PATH, "w") as f:
 
67
  print(f"⚠️ Failed to save DB map: {e}")
68
 
69
 
70
+ def _load_db_map() -> None:
71
  """Load the DB map from disk if it exists (called on startup)."""
72
  global _DB_MAP
73
  if _DB_MAP_PATH.exists():
74
  try:
75
  with open(_DB_MAP_PATH, "r") as f:
76
  data = json.load(f)
77
+ # Be liberal in what we accept; validate into TypedDict
78
  if isinstance(data, dict):
79
+ restored: Dict[str, DBEntry] = {}
80
+ for k, v in data.items():
81
+ path = v.get("path")
82
+ ts = v.get("ts")
83
+ if isinstance(path, str) and isinstance(ts, (int, float)):
84
+ restored[k] = {"path": path, "ts": float(ts)}
85
+ _DB_MAP.update(restored)
86
  print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
87
  except Exception as e:
88
  print(f"⚠️ Failed to load DB map: {e}")
 
91
  def _cleanup_db_map() -> None:
92
  """Remove expired uploaded DB files (best-effort)."""
93
  now = time.time()
94
+ expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
 
 
95
  for k in expired:
96
+ path: str = _DB_MAP[k]["path"]
97
  try:
98
+ if os.path.exists(path):
99
  os.remove(path)
100
  except Exception:
101
  pass
 
106
  """Resolve a SQLite file path from db_id or fallback to default."""
107
  _cleanup_db_map()
108
  if db_id and db_id in _DB_MAP:
109
+ return _DB_MAP[db_id]["path"]
110
  return DEFAULT_SQLITE_DB
111
 
112
 
113
+ def _select_adapter(db_id: Optional[str]):
114
  mode = os.getenv("DB_MODE", "sqlite").lower()
115
  if mode == "postgres":
116
  dsn = os.environ.get("POSTGRES_DSN")
 
121
  # sqlite mode
122
  if db_id:
123
  _cleanup_db_map()
124
+ db_path: Optional[str] = None
125
  # first check runtime map
126
  if db_id in _DB_MAP:
127
+ db_path = _DB_MAP[db_id]["path"]
128
  # fallback: check /tmp or uploads
129
  if not db_path or not os.path.exists(db_path):
130
  fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
131
+ fallback_uploads = str(UPLOAD_DIR / f"{db_id}.sqlite")
132
  for candidate in (fallback_tmp, fallback_uploads):
133
  if os.path.exists(candidate):
134
+ db_path = candidate
135
  break
136
  if not db_path or not os.path.exists(db_path):
137
  raise HTTPException(
138
  status_code=400, detail="invalid db_id (file not found)"
139
  )
140
+ return SQLiteAdapter(db_path)
141
 
142
  # fallback to default Chinook
143
  if not Path(DEFAULT_SQLITE_PATH).exists():
 
177
  # -------------------------------
178
  # Helpers
179
  # -------------------------------
180
+ def _to_dict(obj: Any) -> Any:
181
+ """Safely convert dataclass instance → dict, otherwise return as-is.
182
+
183
+ Note: dataclasses.is_dataclass returns True for both classes and instances.
184
+ We must exclude classes; mypy cannot refine this perfectly, so we ignore arg-type.
185
+ """
186
+ if is_dataclass(obj) and not isinstance(obj, type):
187
+ return asdict(obj) # type: ignore[arg-type]
188
+ return obj
189
 
190
 
191
+ def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
192
  """Round float fields to keep responses tidy and stable."""
193
  if t.get("cost_usd") is not None:
194
+ # Ensure numeric before rounding
195
+ cost = t["cost_usd"]
196
+ if isinstance(cost, (int, float)):
197
+ t["cost_usd"] = round(float(cost), 6)
198
  if t.get("duration_ms") is not None:
199
+ dur = t["duration_ms"]
200
+ if isinstance(dur, (int, float)):
201
+ t["duration_ms"] = round(float(dur), 2)
202
  return t
203
 
204
 
 
263
  pipeline = _build_pipeline(adapter)
264
 
265
  # 2) Resolve schema_preview (optional in request)
266
+ provided_preview_any: Any = getattr(request, "schema_preview", None)
267
+ provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
268
+
269
+ derived_preview: str = _derive_schema_preview(adapter)
270
+ schema_preview_opt: Optional[str] = (
271
+ provided_preview if provided_preview not in ("", None) else derived_preview
272
  )
273
 
274
+ # Guarantee a str for Pipeline.run (mypy requirement)
275
+ final_preview: str = schema_preview_opt or ""
276
+
277
  # 3) Run pipeline
278
  try:
279
  result = pipeline.run(
280
+ user_query=request.query, # assumes NL2SQLRequest has `query: str`
281
+ schema_preview=final_preview, # str guaranteed
282
  )
283
  except Exception as exc:
284
  # Hard failure in pipeline itself
 
314
  )
315
 
316
 
317
+ def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
318
  """
319
  Build a strict, exact-cased schema preview for the LLM.
320
  Works for SQLite adapters by querying sqlite_master / pragma table_info.
 
322
  import sqlite3
323
  import os
324
 
325
+ # Adapters may expose db_path or path; both are str in our codebase
326
+ db_path: Optional[str] = cast(
327
+ Optional[str], getattr(adapter, "db_path", None)
328
+ ) or cast(Optional[str], getattr(adapter, "path", None))
329
  if not db_path or not os.path.exists(db_path):
330
  return ""
331
 
requirements.txt CHANGED
@@ -12,3 +12,4 @@ psycopg[binary]~=3.2
12
  ruff
13
  gradio
14
  sqlalchemy
 
 
12
  ruff
13
  gradio
14
  sqlalchemy
15
+ types-requests