Melika Kheirieh committed
Commit 99fa656 · Parent: fb51384

fix(typing): split adapter vars in /readyz to satisfy mypy

Files changed (4)
  1. .pre-commit-config.yaml +3 -3
  2. app/bootstrap.py +13 -3
  3. app/main.py +28 -38
  4. app/routers/nl2sql.py +78 -142
.pre-commit-config.yaml CHANGED
@@ -1,7 +1,7 @@
 repos:
   # --- Basic hygiene checks ---
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v6.0.0
     hooks:
       - id: check-merge-conflict
       - id: end-of-file-fixer
@@ -9,7 +9,7 @@ repos:
 
   # --- Ruff: linting and formatting ---
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.1
+    rev: v0.14.3
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -17,7 +17,7 @@ repos:
 
   # --- Mypy: type-checking on staged Python files ---
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.2
+    rev: v1.18.2
     hooks:
       - id: mypy
         args:
app/bootstrap.py CHANGED
@@ -1,7 +1,17 @@
+"""App bootstrap: load .env and prepare environment paths."""
+
+import os
+from pathlib import Path
+
+# Change current dir to project root (so relative paths like data/ work)
+ROOT_DIR = Path(__file__).resolve().parent.parent
+os.chdir(ROOT_DIR)
+
+# Load .env if available
 try:
     from dotenv import load_dotenv
 
     load_dotenv()
-except Exception:
-    # optional: silently continue if python-dotenv is not installed
-    pass
+    print("✅ bootstrap: .env loaded")
+except Exception as e:
+    print(f"⚠️ bootstrap: could not load .env ({e})")
app/main.py CHANGED
@@ -1,13 +1,7 @@
 import os
 import time
-from typing import Protocol, runtime_checkable, cast
-
 from fastapi import FastAPI, Request, Response, HTTPException
 from fastapi.responses import PlainTextResponse
-
-from app.routers import nl2sql
-
-# Prometheus
 from prometheus_client import (
     Counter,
     Histogram,
@@ -16,13 +10,14 @@ from prometheus_client import (
     CONTENT_TYPE_LATEST,
 )
 
+try:
+    from dotenv import load_dotenv
 
-@runtime_checkable
-class HasPing(Protocol):
-    """Minimal interface for adapters that support a connectivity check."""
-
-    def ping(self) -> None: ...
+    load_dotenv()
+except Exception:
+    pass
 
+from app.routers import nl2sql
 
 # ---- Optionally restore uploaded DB map ----
 try:
@@ -32,24 +27,22 @@ try:
 except Exception as e:
     print(f"⚠️ DB map not restored: {e}")
 
-app = FastAPI(
+application: FastAPI = FastAPI(
     title="NL2SQL Copilot Prototype",
     version=os.getenv("APP_VERSION", "0.1.0"),
     description="Convert natural language to safe & verified SQL",
 )
 
-app.include_router(nl2sql.router, prefix="/api/v1")
+application.include_router(nl2sql.router, prefix="/api/v1")
 
 # ---- Prometheus metrics ----
-REGISTRY: CollectorRegistry = CollectorRegistry()
-
+REGISTRY = CollectorRegistry()
 REQUEST_COUNT = Counter(
     "http_requests_total",
     "Total HTTP requests",
     ["path", "method", "status_code"],
     registry=REGISTRY,
 )
-
 REQUEST_LATENCY = Histogram(
     "http_request_latency_seconds",
     "Request latency",
@@ -58,20 +51,13 @@ REQUEST_LATENCY = Histogram(
 )
 
 
-@app.middleware("http")
+@application.middleware("http")
 async def metrics_middleware(request: Request, call_next):
     start = time.perf_counter()
     response: Response = await call_next(request)
     elapsed = time.perf_counter() - start
-
-    # Use route path if available, else raw path (typed guard for mypy)
     route = request.scope.get("route")
-    path = (
-        route.path
-        if (route is not None and hasattr(route, "path"))
-        else request.url.path
-    )
-
+    path = route.path if route else request.url.path
     REQUEST_COUNT.labels(
         path=path, method=request.method, status_code=str(response.status_code)
     ).inc()
@@ -79,14 +65,16 @@ async def metrics_middleware(request: Request, call_next):
     return response
 
 
-# --- Liveness (super light) ---
-@app.get("/healthz", response_class=PlainTextResponse, tags=["system"])
+# --- Liveness ---
+@application.get("/healthz", response_class=PlainTextResponse, tags=["system"])
 def healthz() -> str:
     return "ok"
 
 
-# --- Readiness (checks DB/env lightly) ---
-@app.get("/readyz", response_class=PlainTextResponse, tags=["system"])
+# --- Readiness ---
+
+
+@application.get("/readyz", response_class=PlainTextResponse, tags=["system"])
 def readyz() -> str:
     mode = os.getenv("DB_MODE", "sqlite").lower()
     try:
@@ -94,32 +82,34 @@ def readyz() -> str:
             from adapters.db.postgres_adapter import PostgresAdapter
 
             dsn = os.environ["POSTGRES_DSN"]
-            # Call ping inline; avoid cross-branch variable typing
-            cast(HasPing, PostgresAdapter(dsn)).ping()
+            pg = PostgresAdapter(dsn)
+            ping = getattr(pg, "ping", None)
+            if callable(ping):
+                ping()
         else:
             from adapters.db.sqlite_adapter import SQLiteAdapter
 
             db_path = os.getenv("SQLITE_DB_PATH", "data/chinook.db")
-            cast(HasPing, SQLiteAdapter(db_path)).ping()
-
-            # if not os.getenv("PROXY_API_KEY"): pass
+            sq = SQLiteAdapter(db_path)
+            ping = getattr(sq, "ping", None)
+            if callable(ping):
+                ping()
         return "ready"
     except Exception:
         raise HTTPException(status_code=503, detail="not ready")
 
 
-@app.get("/")
+@application.get("/")
 def root():
     return {"status": "ok", "message": "NL2SQL Copilot API is running"}
 
 
-@app.get("/health")
+@application.get("/health")
 def health():
-    # You might want to replace the placeholders with real checks later.
     return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
 
 
-@app.get("/metrics", tags=["system"])
+@application.get("/metrics", tags=["system"])
 def metrics():
     data = generate_latest(REGISTRY)
     return Response(content=data, media_type=CONTENT_TYPE_LATEST)
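The /readyz change above is what the commit message refers to: the previous version funneled both adapters through one code path (earlier via a HasPing Protocol plus cast), while the new version gives each branch its own variable (pg / sq) and an optional ping lookup. A minimal, self-contained sketch of the underlying mypy behaviour, using stand-in adapter classes rather than the project's real ones:

# Standalone sketch of the mypy issue this commit works around.
# "PostgresAdapter" / "SQLiteAdapter" here are stand-ins, not the real adapters.

class PostgresAdapter:
    def __init__(self, dsn: str) -> None:
        self.dsn = dsn

    def ping(self) -> None: ...


class SQLiteAdapter:
    def __init__(self, path: str) -> None:
        self.path = path

    def ping(self) -> None: ...


def readyz_before(mode: str) -> str:
    # One variable reused across branches: mypy typically infers its type from
    # the first assignment and then flags the second assignment as incompatible.
    if mode == "postgres":
        db = PostgresAdapter("postgresql://...")
    else:
        db = SQLiteAdapter("data/chinook.db")  # mypy: incompatible assignment
    db.ping()
    return "ready"


def readyz_after(mode: str) -> str:
    # Split variables, one per branch, as in the commit: each name keeps a
    # single concrete type, so no union annotation or cast is needed.
    if mode == "postgres":
        pg = PostgresAdapter("postgresql://...")
        pg.ping()
    else:
        sq = SQLiteAdapter("data/chinook.db")
        sq.ping()
    return "ready"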
app/routers/nl2sql.py CHANGED
@@ -22,46 +22,25 @@ from typing import Union, Optional, Dict, TypedDict, Any, cast
 
 router = APIRouter(prefix="/nl2sql")
 
-# --- Database adapter selection ---
-DB_MODE = os.getenv("DB_MODE", "sqlite").lower()
-
-_db: Union[PostgresAdapter, SQLiteAdapter]
-if DB_MODE == "postgres":
-    dsn = os.environ.get("POSTGRES_DSN")
-    if not dsn:
-        raise RuntimeError(
-            "POSTGRES_DSN environment variable is required in postgres mode"
-        )
-    _db = PostgresAdapter(dsn)
-else:
-    _db = SQLiteAdapter("data/chinook.db")
-
-
-# --- Build a single shared pipeline for all routes ---
-def _make_pipeline() -> Pipeline:
-    llm = OpenAIProvider()
-    return Pipeline(
-        detector=AmbiguityDetector(),
-        planner=Planner(llm=llm),
-        generator=Generator(llm=llm),
-        safety=Safety(),
-        executor=Executor(db=_db),
-        verifier=Verifier(),
-        repair=Repair(llm=llm),
-    )
-
-
-_pipeline: Pipeline = _make_pipeline()
-
-
 # -------------------------------
-# Runtime DB registry (for uploaded SQLite files)
-# Files are stored under /tmp, mapped by a short-lived db_id
+# Config / Defaults
 # -------------------------------
+DB_MODE = os.getenv("DB_MODE", "sqlite").lower()  # "sqlite" or "postgres"
+POSTGRES_DSN = os.getenv("POSTGRES_DSN")
+DEFAULT_SQLITE_PATH: str = os.getenv("DEFAULT_SQLITE_DB", "data/Chinook_Sqlite.sqlite")
+
+# Runtime upload storage
 _DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
 _DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200"))  # default 2 hours
 os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
 
+# Persisted map
+_DB_MAP_PATH = Path("data/uploads/db_map.json")
+_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+UPLOAD_DIR = Path("data/uploads")
+UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
 
 class DBEntry(TypedDict):
     path: str
@@ -71,27 +50,8 @@ class DBEntry(TypedDict):
 # In-memory map: db_id -> {"path": str, "ts": float}
 _DB_MAP: Dict[str, DBEntry] = {}
 
-# -------------------------------
-# Default DB resolution
-# -------------------------------
-DB_MODE = os.getenv("DB_MODE", "sqlite").lower()  # "sqlite" or "postgres"
-POSTGRES_DSN = os.getenv("POSTGRES_DSN")
-DEFAULT_SQLITE_DB: str = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db")
-
-# -------------------------------
-# Path to persist db_id → file map
-# -------------------------------
-_DB_MAP_PATH = Path("data/uploads/db_map.json")
-_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
-
-UPLOAD_DIR = Path("data/uploads")
-UPLOAD_DIR.mkdir(parents=True, exist_ok=True)  # ensure folder exists
-
-DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
-
 
 def _save_db_map() -> None:
-    """Persist the in-memory DB map to disk as JSON."""
     try:
         with open(_DB_MAP_PATH, "w") as f:
             json.dump(_DB_MAP, f)
@@ -100,13 +60,11 @@ def _save_db_map() -> None:
 
 
 def _load_db_map() -> None:
-    """Load the DB map from disk if it exists (called on startup)."""
     global _DB_MAP
     if _DB_MAP_PATH.exists():
         try:
             with open(_DB_MAP_PATH, "r") as f:
                 data = json.load(f)
-            # Be liberal in what we accept; validate into TypedDict
             if isinstance(data, dict):
                 restored: Dict[str, DBEntry] = {}
                 for k, v in data.items():
@@ -121,7 +79,6 @@ def _load_db_map() -> None:
 
 
 def _cleanup_db_map() -> None:
-    """Remove expired uploaded DB files (best-effort)."""
     now = time.time()
     expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
     for k in expired:
@@ -134,15 +91,20 @@ def _cleanup_db_map() -> None:
         _DB_MAP.pop(k, None)
 
 
-def _resolve_sqlite_path(db_id: Optional[str]) -> str:
-    """Resolve a SQLite file path from db_id or fallback to default."""
-    _cleanup_db_map()
-    if db_id and db_id in _DB_MAP:
-        return _DB_MAP[db_id]["path"]
-    return DEFAULT_SQLITE_DB
+# Call once at import (safe & light); heavy things remain lazy.
+_load_db_map()
 
 
-def _select_adapter(db_id: Optional[str]):
+# -------------------------------
+# Adapter selection (lazy)
+# -------------------------------
+def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
+    """
+    Resolve a DB adapter:
+      - postgres: requires POSTGRES_DSN
+      - sqlite with db_id: uploaded file or fallback locations
+      - sqlite default: DEFAULT_SQLITE_PATH must exist
+    """
     mode = os.getenv("DB_MODE", "sqlite").lower()
     if mode == "postgres":
         dsn = os.environ.get("POSTGRES_DSN")
@@ -151,79 +113,72 @@ def _select_adapter(db_id: Optional[str]):
         return PostgresAdapter(dsn)
 
     # sqlite mode
+    _cleanup_db_map()
     if db_id:
-        _cleanup_db_map()
-        db_path: Optional[str] = None
-        # first check runtime map
-        if db_id in _DB_MAP:
-            db_path = _DB_MAP[db_id]["path"]
-        # fallback: check /tmp or uploads
-        if not db_path or not os.path.exists(db_path):
-            fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
-            fallback_uploads = str(UPLOAD_DIR / f"{db_id}.sqlite")
-            for candidate in (fallback_tmp, fallback_uploads):
-                if os.path.exists(candidate):
-                    db_path = candidate
-                    break
-        if not db_path or not os.path.exists(db_path):
-            raise HTTPException(
-                status_code=400, detail="invalid db_id (file not found)"
-            )
-        return SQLiteAdapter(db_path)
-
-    # fallback to default Chinook
+        # Check runtime map
+        entry = _DB_MAP.get(db_id)
+        candidates = []
+        if entry and os.path.exists(entry["path"]):
+            candidates.append(entry["path"])
+        # Fallback locations based on convention
+        candidates.append(os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite"))
+        candidates.append(str(UPLOAD_DIR / f"{db_id}.sqlite"))
+
+        for p in candidates:
+            if p and os.path.exists(p):
+                return SQLiteAdapter(p)
+
+        raise HTTPException(status_code=400, detail="invalid db_id (file not found)")
+
+    # default sqlite
     if not Path(DEFAULT_SQLITE_PATH).exists():
         raise HTTPException(status_code=500, detail="default DB not found")
     return SQLiteAdapter(DEFAULT_SQLITE_PATH)
 
 
 # -------------------------------
-# LLM providers & shared components (stateless)
+# LLM & Pipeline builders (lazy)
 # -------------------------------
-def get_llm():
+def _get_llm() -> OpenAIProvider:
+    # Create provider on demand, after .env has been loaded in app.main
     return OpenAIProvider()
 
 
-_detector = AmbiguityDetector()
-_planner = Planner(get_llm())
-_generator = Generator(get_llm())
-_safety = Safety()
-_verifier = Verifier()
-_repair = Repair(get_llm())
-
-
 def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
-    """Build a fresh Pipeline with a per-request Executor bound to the chosen adapter."""
+    """
+    Build a fresh Pipeline bound to the given adapter.
+    All stateful/external pieces (LLM, executor) are instantiated here (lazy).
+    """
+    llm = _get_llm()
+    detector = AmbiguityDetector()
+    planner = Planner(llm=llm)
+    generator = Generator(llm=llm)
+    safety = Safety()
     executor = Executor(adapter)
+    verifier = Verifier()
+    repair = Repair(llm=llm)
     return Pipeline(
-        detector=_detector,
-        planner=_planner,
-        generator=_generator,
-        safety=_safety,
+        detector=detector,
+        planner=planner,
+        generator=generator,
+        safety=safety,
         executor=executor,
-        verifier=_verifier,
-        repair=_repair,
+        verifier=verifier,
+        repair=repair,
     )
 
 
 # -------------------------------
-# Helpers
+# Helpers (unchanged)
 # -------------------------------
 def _to_dict(obj: Any) -> Any:
-    """Safely convert dataclass instance → dict, otherwise return as-is.
-
-    Note: dataclasses.is_dataclass returns True for both classes and instances.
-    We must exclude classes; mypy cannot refine this perfectly, so we ignore arg-type.
-    """
     if is_dataclass(obj) and not isinstance(obj, type):
         return asdict(obj)  # type: ignore[arg-type]
     return obj
 
 
 def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
-    """Round float fields to keep responses tidy and stable."""
     if t.get("cost_usd") is not None:
-        # Ensure numeric before rounding
         cost = t["cost_usd"]
         if isinstance(cost, (int, float)):
             t["cost_usd"] = round(float(cost), 6)
@@ -236,17 +191,9 @@ def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
 
 # -------------------------------
 # Upload endpoint (SQLite only)
-# Path will be /api/nl2sql/upload_db if your root APIRouter is mounted at /api
 # -------------------------------
 @router.post("/upload_db")
 async def upload_db(file: UploadFile = File(...)):
-    """
-    Upload a SQLite database (.db/.sqlite). Returns a short-lived db_id.
-    Notes:
-      - Only SQLite files are allowed here (not for Postgres mode).
-      - Max size ~20MB recommended for demo environments like HF Spaces.
-      - Files are stored under /tmp and cleaned by TTL.
-    """
     if DB_MODE != "sqlite":
         raise HTTPException(
             status_code=400, detail="DB upload is only supported in sqlite mode"
@@ -280,48 +227,42 @@ async def upload_db(file: UploadFile = File(...)):
 
 # -------------------------------
 # Main NL2SQL endpoint
-# Path will be /api/nl2sql if your root APIRouter is mounted at /api
 # -------------------------------
 @router.post("", name="nl2sql_handler")
 def nl2sql_handler(request: NL2SQLRequest):
     db_id = getattr(request, "db_id", None)
 
-    # 1) Pick pipeline (+ optional per-request adapter)
-    pipeline: Pipeline
-    if db_id:
-        adapter = _select_adapter(db_id)  # returns PostgresAdapter | SQLiteAdapter
-        # If _select_adapter could theoretically return None, uncomment the next line:
-        # assert adapter is not None, "adapter must be set when db_id is provided"
-        pipeline = _build_pipeline(adapter)
-        derived_preview_val: str = _derive_schema_preview(adapter)
-    else:
-        pipeline = _pipeline
-        derived_preview_val = ""  # no adapter → no derive
-
-    # 2) Resolve schema_preview
+    # Pick adapter per-request (default or uploaded or postgres)
+    adapter = _select_adapter(db_id)
+
+    # Build pipeline lazily with this adapter
+    pipeline = _build_pipeline(adapter)
+
+    # Derive schema preview only for sqlite with a real path
+    derived_preview_val: str = (
+        _derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
+    )
+
+    # Resolve schema_preview
     provided_preview_any: Any = getattr(request, "schema_preview", None)
     provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
     final_preview: str = provided_preview or derived_preview_val
 
-    # 3) Run pipeline
+    # Run pipeline
     try:
         result = pipeline.run(
             user_query=request.query,
             schema_preview=final_preview,
         )
    except Exception as exc:
-        # Hard failure in pipeline itself
        raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
 
-    # 4) Type check
     if not isinstance(result, FinalResult):
         raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
 
-    # 5) Ambiguity → ask for clarification
     if result.ambiguous and result.questions:
         return ClarifyResponse(ambiguous=True, questions=result.questions)
 
-    # 6) Soft errors → bubble up details with 400
     if not result.ok or result.error:
         print("❌ Pipeline failure dump:")
         print(" ok:", result.ok)
@@ -333,7 +274,6 @@ def nl2sql_handler(request: NL2SQLRequest):
            detail="; ".join(result.details or []) or (result.error or "Unknown error"),
        )
 
-    # 7) Success
     traces = [_round_trace(t) for t in (result.traces or [])]
     return NL2SQLResponse(
         ambiguous=False,
@@ -345,13 +285,10 @@ def nl2sql_handler(request: NL2SQLRequest):
 
 def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
     """
-    Build a strict, exact-cased schema preview for the LLM.
-    Works for SQLite adapters by querying sqlite_master / pragma table_info.
+    Build a strict, exact-cased schema preview for the LLM (SQLite only).
     """
     import sqlite3
-    import os
 
-    # Adapters may expose db_path or path; both are str in our codebase
     db_path: Optional[str] = cast(
         Optional[str], getattr(adapter, "db_path", None)
     ) or cast(Optional[str], getattr(adapter, "path", None))
@@ -367,8 +304,7 @@ def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> st
     lines = []
     for (tname,) in tables:
         cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
-        # sqlite: pragma columns (cid, name, type, notnull, dflt_value, pk)
-        colnames = [c[1] for c in cols]
+        colnames = [c[1] for c in cols]  # (cid, name, type, notnull, dflt, pk)
         lines.append(f"{tname}({', '.join(colnames)})")
     conn.close()
     return "\n".join(lines)