File size: 11,919 Bytes
9c10293
 
 
570f7bd
64907d7
 
 
82e122c
773ba18
8143b24
64907d7
9c10293
82e122c
78a30b1
64907d7
9c10293
570f7bd
82e122c
 
 
 
 
 
 
 
 
 
 
 
570f7bd
78a30b1
 
 
 
82e122c
 
 
 
 
 
 
78a30b1
82e122c
 
78a30b1
 
 
 
 
773ba18
 
 
82e122c
 
 
773ba18
 
 
 
82e122c
773ba18
 
 
 
82e122c
773ba18
 
 
82e122c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
773ba18
 
 
82e122c
 
 
773ba18
 
 
 
82e122c
 
773ba18
 
 
 
 
 
570f7bd
 
b568b83
99fa656
b568b83
82e122c
99fa656
82e122c
 
b568b83
 
82e122c
99fa656
 
 
82e122c
 
 
 
b568b83
5cbfffe
78914bb
 
 
 
 
 
82e122c
 
 
 
78914bb
82e122c
 
 
 
 
 
 
78914bb
 
82e122c
 
 
78914bb
82e122c
 
 
 
 
 
78914bb
82e122c
1fa9a31
5cbfffe
b568b83
79a5f4a
b568b83
82e122c
 
977a885
82e122c
 
 
977a885
 
 
570f7bd
5cbfffe
6181651
79a5f4a
 
82e122c
79a5f4a
 
 
 
 
 
6181651
 
 
 
 
79a5f4a
 
 
6181651
 
 
 
 
79a5f4a
 
 
6181651
82e122c
6181651
79a5f4a
6181651
 
 
79a5f4a
6181651
 
 
 
 
79a5f4a
 
 
 
 
 
 
570f7bd
5cbfffe
b568b83
 
 
82e122c
 
78a30b1
b568b83
82e122c
 
 
 
 
 
 
b568b83
5cbfffe
 
 
b568b83
 
 
5cbfffe
 
 
b568b83
 
82e122c
b568b83
5cbfffe
 
 
b568b83
 
 
 
 
 
 
82e122c
b568b83
 
c76014a
82e122c
b568b83
 
5cbfffe
d53014f
 
82e122c
 
d53014f
 
b568b83
 
 
82e122c
 
78a30b1
2d682e2
 
82e122c
 
 
9c10293
82e122c
 
 
 
 
 
 
9c10293
6a94b42
82e122c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a94b42
773ba18
82e122c
 
 
 
570f7bd
82e122c
370553a
82e122c
 
 
 
 
 
 
 
370553a
82e122c
 
 
 
 
 
 
 
370553a
82e122c
a45c0eb
82e122c
 
 
 
 
 
 
 
570f7bd
82e122c
2d682e2
 
 
 
82e122c
d5f745f
82e122c
 
 
 
 
 
 
 
343ad62
 
570f7bd
82e122c
4dae3e6
3ef53b4
 
 
 
 
 
 
 
 
773ba18
570f7bd
a45c0eb
 
 
3ef53b4
570f7bd
3ef53b4
82e122c
 
773ba18
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
from __future__ import annotations

# --- Stdlib ---
from dataclasses import asdict, is_dataclass
import os
from pathlib import Path
import uuid
from typing import Any, Dict, Optional, Tuple, cast
import hashlib
import logging

# --- Third-party ---
from fastapi import APIRouter, Depends, HTTPException, Security, UploadFile, File
from fastapi.security import APIKeyHeader

# --- Local ---
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
from app.state import register_db
from nl2sql.pipeline import FinalResult
from app.dependencies import get_cache, get_nl2sql_service
from app.cache import NL2SQLCache
from app.services.nl2sql_service import NL2SQLService
from app.settings import get_settings
from app.errors import (
    AppError,
)

# Module-level logger named after the package path so log filtering works.
logger = logging.getLogger(__name__)
# Application settings snapshot (get_settings is presumably memoized — confirm).
settings = get_settings()

# Header-based API key scheme. auto_error=False defers the missing-key
# decision to require_api_key, which allows dev mode with no keys configured.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def require_api_key(key: Optional[str] = Security(api_key_header)):
    """
    Validate the X-API-Key header against the configured key list.

    Settings.api_keys_raw holds a comma-separated list of accepted keys.
    An empty/unset list disables authentication entirely (dev mode).

    Raises:
        HTTPException: 401 when keys are configured and the supplied key
            is missing or not in the allowed set.
    """
    configured = settings.api_keys_raw or ""
    valid_keys = {part.strip() for part in configured.split(",") if part.strip()}
    if not valid_keys:
        # Nothing configured → auth disabled; accept every request.
        return
    if key in valid_keys:
        return
    raise HTTPException(status_code=401, detail="invalid API key")


####################################
# ---- Simple in-memory cache for NL→SQL responses ----

# Cache TTL and max size from centralized settings
_CACHE_TTL = settings.cache_ttl_sec
_CACHE_MAX = settings.cache_max_entries
# Maps cache key -> (insertion timestamp, payload dict).
# NOTE(review): the key annotation says Tuple[str, str, str] but _ck() below
# produces plain str keys — one of the two looks stale. Also, the endpoints
# in this module use the injected NL2SQLCache, so this store and its helpers
# (_norm_q/_schema_key/_cache_gc) appear unused here — confirm before relying
# on or removing them.
_CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}


def _norm_q(s: str) -> str:
    """Normalize a user query for cache key purposes."""
    return (s or "").strip().lower()


def _schema_key(preview: str) -> str:
    """Hash the schema preview so we do not store huge strings in the cache key."""
    return hashlib.md5((preview or "").encode()).hexdigest()


def _ck(
    db_id: Optional[str],
    query: str,
    schema_preview: str,
) -> str:
    """
    Build a stable cache key for (db_id, query, schema_preview).

    We keep the external cache API string-based, and hash the
    potentially large schema_preview to avoid huge dictionary keys.
    """
    # Normalize db_id
    db_part = db_id or "__default__"

    # Build a single string seed
    seed = f"{db_part}\n{query}\n{schema_preview}"

    # Short, deterministic key
    return hashlib.sha1(seed.encode("utf-8")).hexdigest()


def _cache_gc(now: float) -> None:
    """
    Evict stale and excess entries from the module-level _CACHE.

    Two passes: first drop everything older than _CACHE_TTL, then trim
    the dict down to _CACHE_MAX entries by removing the oldest-inserted
    keys (dicts preserve insertion order, so this is FIFO eviction).
    """
    # Pass 1: TTL eviction. Snapshot the items since we mutate while scanning.
    expired = [key for key, (stamp, _) in list(_CACHE.items()) if now - stamp > _CACHE_TTL]
    for key in expired:
        _CACHE.pop(key, None)

    # Pass 2: size cap — drop oldest entries until under the limit.
    while len(_CACHE) > _CACHE_MAX:
        oldest = next(iter(_CACHE))
        _CACHE.pop(oldest, None)


####################################

# Router mounted under /nl2sql; routes opt into auth via require_api_key.
router = APIRouter(prefix="/nl2sql")

# -------------------------------
# Config / Defaults
# -------------------------------
DB_MODE = settings.db_mode.lower()  # "sqlite" or "postgres"

# Runtime upload storage for SQLite DBs.
# NOTE: the directory is created eagerly at import time (side effect).
_DB_UPLOAD_DIR = settings.db_upload_dir
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)

# Optional: separate directory for other uploads (kept as-is for now)
# NOTE(review): UPLOAD_DIR is created but not referenced anywhere else in
# this module — confirm it is still needed before removing.
UPLOAD_DIR = Path("data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

logger.debug(
    "NL2SQL router configured",
    extra={"db_mode": DB_MODE, "upload_dir": _DB_UPLOAD_DIR},
)


# -------------------------------
# Schema preview endpoint
# -------------------------------


@router.get("/schema")
def schema_endpoint(
    db_id: Optional[str] = None,
    svc: NL2SQLService = Depends(get_nl2sql_service),
):
    """
    Expose a lightweight textual schema preview for a database.

    Resolution inside the service:
    - db_id given → resolve the corresponding uploaded DB;
    - no db_id    → fall back to the default DB.
    (In postgres mode callers usually supply schema_preview explicitly.)

    AppError subclasses propagate to the global exception handler; only
    genuinely unexpected failures are wrapped into a generic HTTP 500 here.
    """
    try:
        return {"schema_preview": svc.get_schema_preview(db_id=db_id, override=None)}
    except AppError:
        # Domain errors: defer to the global AppError handler.
        raise
    except Exception as exc:
        logger.exception("Unexpected error in schema_endpoint", exc_info=exc)
        raise HTTPException(
            status_code=500,
            detail="failed to derive schema preview",
        ) from exc


# -------------------------------
# Helpers
# -------------------------------


def _to_dict(obj: Any) -> Any:
    """
    Convert dataclass-like objects (and similar) to plain dicts for JSON.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)  # type: ignore[arg-type]
    return obj


def _round_trace(t: Any) -> Dict[str, Any]:
    """
    Normalize a trace entry (dict or StageTrace-like object) for API/UI:

    - stage: str (required)
    - duration_ms: int (rounded)
    - summary: optional (pass-through if exists)
    - notes: optional
    - token_in/out, cost_usd: pass-through if present
    """
    if isinstance(t, dict):
        stage = t.get("stage", "?")
        ms = t.get("duration_ms", 0)
        notes = t.get("notes")
        cost = t.get("cost_usd")
        summary = t.get("summary")
        token_in = t.get("token_in")
        token_out = t.get("token_out")
    else:
        stage = getattr(t, "stage", "?")
        ms = getattr(t, "duration_ms", 0)
        notes = getattr(t, "notes", None)
        cost = getattr(t, "cost_usd", None)
        summary = getattr(t, "summary", None)
        token_in = getattr(t, "token_in", None)
        token_out = getattr(t, "token_out", None)

    # Coerce duration to int with rounding
    try:
        ms_int = int(round(float(ms))) if ms is not None else 0
    except Exception:
        ms_int = 0

    out: Dict[str, Any] = {
        "stage": str(stage) if stage is not None else "?",
        "duration_ms": ms_int,
        "notes": notes,
        "cost_usd": cost,
    }
    if summary is not None:
        out["summary"] = summary
    if token_in is not None:
        out["token_in"] = token_in
    if token_out is not None:
        out["token_out"] = token_out
    return out


# -------------------------------
# Upload endpoint (SQLite only)
# -------------------------------


@router.post("/upload_db", dependencies=[Depends(require_api_key)])
async def upload_db(file: UploadFile = File(...)):
    """
    Upload a SQLite DB file and register it under a generated db_id.

    Only available when DB_MODE is 'sqlite':
    - Allowed extensions: .db, .sqlite
    - File size capped by configured upload_max_bytes (default 20 MB)

    Returns:
        {"db_id": "<uuid>"} for use in subsequent NL→SQL requests.

    Raises:
        HTTPException: 400 for wrong mode, disallowed extension, or an
            oversized file; 500 when the file cannot be written to disk.
    """
    if DB_MODE != "sqlite":
        raise HTTPException(
            status_code=400, detail="DB upload is only supported in sqlite mode"
        )

    filename = file.filename or "db.sqlite"
    # Tuple form checks both allowed extensions in a single call.
    if not filename.endswith((".db", ".sqlite")):
        raise HTTPException(
            status_code=400, detail="Only .db or .sqlite files are allowed"
        )

    data = await file.read()
    max_bytes = settings.upload_max_bytes
    if len(data) > max_bytes:
        raise HTTPException(
            status_code=400, detail=f"File too large (> {max_bytes} bytes)"
        )

    # Store under a fresh UUID so uploads can never collide or overwrite.
    db_id = str(uuid.uuid4())
    out_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
    try:
        with open(out_path, "wb") as f:
            f.write(data)
    except OSError as e:
        # Narrowed from Exception: open()/write() failures surface as OSError.
        logger.debug("Failed to store uploaded DB file", exc_info=e)
        # Chain the cause so tracebacks show the underlying I/O error.
        raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}") from e

    register_db(db_id, out_path)
    logger.debug("Registered uploaded DB", extra={"db_id": db_id, "path": out_path})
    return {"db_id": db_id}


@router.get("/health")
def health():
    """Lightweight liveness probe scoped to this router."""
    payload = {"status": "ok", "version": settings.app_version}
    return payload


# -------------------------------
# Main NL2SQL endpoint
# -------------------------------


@router.post("", name="nl2sql_handler", dependencies=[Depends(require_api_key)])
def nl2sql_handler(
    request: NL2SQLRequest,
    svc: NL2SQLService = Depends(get_nl2sql_service),
    cache: NL2SQLCache = Depends(get_cache),
) -> NL2SQLResponse | ClarifyResponse | Dict[str, Any]:
    """
    Main NL→SQL handler.

    Flow:
    - Resolve schema preview (client override or derived from DB).
    - Check the injected cache (keyed on db_id + query + schema preview).
    - Run the pipeline through NL2SQLService.
    - Map FinalResult to API response or HTTP error.

    Returns:
        NL2SQLResponse on success, ClarifyResponse when the query is
        ambiguous, or a plain dict when serving a cached payload
        (entries are stored via model_dump()).

    Raises:
        HTTPException: 400 when the pipeline reports a failure; 500 for
            unexpected crashes or malformed pipeline output.
        AppError: domain errors propagate to the global handler.
    """
    # db_id is optional on the request model; None selects the default DB.
    db_id = getattr(request, "db_id", None)

    # ---- schema preview ----
    # A client-supplied schema_preview (override) takes precedence over
    # whatever the service would derive from the DB.
    try:
        final_preview = svc.get_schema_preview(
            db_id=db_id,
            override=request.schema_preview,
        )
    except AppError:
        # Domain-level errors are handled by the global AppError handler.
        raise
    except Exception as exc:
        logger.exception(
            "Unexpected error while preparing schema preview",
            exc_info=exc,
        )
        raise HTTPException(
            status_code=500,
            detail="failed to prepare schema",
        ) from exc

    # ---- cache lookup ----
    # NOTE(review): a cache hit returns the stored plain dict, while the
    # fresh path returns a NL2SQLResponse model — confirm clients accept
    # both shapes.
    cache_key = _ck(db_id, request.query, final_preview)
    cached_payload = cache.get(cache_key)
    if cached_payload is not None:
        return cached_payload

    # ---- pipeline execution via service ----
    try:
        result = svc.run_query(
            query=request.query,
            db_id=db_id,
            schema_preview=final_preview,
        )
    except AppError:
        # Let the global handler convert it to an HTTP response.
        raise
    except Exception as exc:
        logger.exception(
            "Unexpected pipeline crash in NL2SQLService.run_query",
            exc_info=exc,
        )
        raise HTTPException(
            status_code=500,
            detail="internal pipeline error",
        ) from exc

    # ---- type sanity check ----
    # Defensive: the service contract is FinalResult; anything else is a bug.
    if not isinstance(result, FinalResult):
        logger.debug(
            "Pipeline returned unexpected type",
            extra={"type": type(result).__name__},
        )
        raise HTTPException(
            status_code=500,
            detail="pipeline returned unexpected type",
        )

    # ---- ambiguity path → 200 with clarification questions ----
    # Checked before ok/error: an ambiguous result is not a failure.
    if result.ambiguous:
        qs = result.questions or []
        return ClarifyResponse(ambiguous=True, questions=qs)

    # ---- error path → 400 with joined details ----
    if (not result.ok) or result.error:
        logger.debug(
            "Pipeline reported failure",
            extra={
                "ok": result.ok,
                "error": result.error,
                "details": result.details,
            },
        )
        message = "; ".join(result.details or []) or "Unknown error"
        raise HTTPException(status_code=400, detail=message)

    # ---- success path → 200 (normalize traces and executor result) ----
    traces = [_round_trace(t) for t in (result.traces or [])]

    # Executor result may already be a dict or a dataclass-like object.
    response_result: Dict[str, Any] = {}
    raw_result = getattr(result, "result", None)
    if raw_result is not None:
        if isinstance(raw_result, dict):
            response_result = raw_result
        else:
            response_result = cast(Dict[str, Any], _to_dict(raw_result))

    payload = NL2SQLResponse(
        ambiguous=False,
        sql=result.sql,
        rationale=result.rationale,
        traces=traces,
        result=response_result,
    )

    # Store in cache (as plain dict) so future identical requests skip the pipeline.
    cache.set(cache_key, payload.model_dump())
    return payload