from __future__ import annotations

# --- Stdlib ---
from dataclasses import asdict, is_dataclass
import hmac
import os
from pathlib import Path
import uuid
from typing import Any, Dict, Optional, Tuple, cast
import hashlib
import logging

# --- Third-party ---
from fastapi import APIRouter, Depends, HTTPException, Security, UploadFile, File
from fastapi.security import APIKeyHeader

# --- Local ---
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
from app.state import register_db
from nl2sql.pipeline import FinalResult
from app.dependencies import get_cache, get_nl2sql_service
from app.cache import NL2SQLCache
from app.services.nl2sql_service import NL2SQLService
from app.settings import get_settings
from app.errors import (
    AppError,
)

logger = logging.getLogger(__name__)
settings = get_settings()

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def _key_matches(candidate: str, allowed: set[str]) -> bool:
    """
    Return True if ``candidate`` equals any key in ``allowed``.

    Uses ``hmac.compare_digest`` and checks every configured key without
    short-circuiting, so response timing does not leak which key (or how much
    of a key) matched. Strings are compared as UTF-8 bytes because
    ``compare_digest`` rejects non-ASCII ``str`` arguments with TypeError.
    """
    candidate_bytes = candidate.encode("utf-8")
    matched = False
    for allowed_key in allowed:
        matched |= hmac.compare_digest(candidate_bytes, allowed_key.encode("utf-8"))
    return matched


def require_api_key(key: Optional[str] = Security(api_key_header)) -> None:
    """
    Simple API key check using X-API-Key header and configured API keys.

    - Settings.api_keys_raw is a comma-separated list of keys.
    - If api_keys_raw is empty → auth disabled (dev mode).

    Raises:
        HTTPException(401): keys are configured and the header is missing or
            does not match any of them.
    """
    raw = settings.api_keys_raw or ""
    allowed = {k.strip() for k in raw.split(",") if k.strip()}
    if not allowed:
        # No keys configured → treat as dev mode (auth off).
        return
    # Constant-time comparison (instead of `key in allowed`) to avoid timing leaks.
    if not key or not _key_matches(key, allowed):
        raise HTTPException(status_code=401, detail="invalid API key")


####################################
# ---- Simple in-memory cache for NL→SQL responses ----
# NOTE(review): _CACHE, _norm_q, _schema_key and _cache_gc are not referenced
# anywhere in this module (nl2sql_handler uses the injected NL2SQLCache), and
# _CACHE's tuple key type does not match the string keys produced by _ck.
# Confirm no other module imports them before removing.

# Cache TTL and max size from centralized settings
_CACHE_TTL = settings.cache_ttl_sec
_CACHE_MAX = settings.cache_max_entries
_CACHE: Dict[Tuple[str, str, str], Tuple[float, Dict[str, Any]]] = {}


def _norm_q(s: str) -> str:
    """Normalize a user query for cache key purposes."""
    return (s or "").strip().lower()


def _schema_key(preview: str) -> str:
    """Hash the schema preview so we do not store huge strings in the cache key."""
    return hashlib.md5((preview or "").encode()).hexdigest()


def _ck(
    db_id: Optional[str],
    query: str,
    schema_preview: str,
) -> str:
    """
    Build a stable cache key for (db_id, query, schema_preview).

    The three parts are joined with newlines into one seed string and hashed
    with SHA-1, so the key stays a short hex string even when schema_preview
    is large. A missing db_id is mapped to "__default__". The query is used
    verbatim (not normalized), so queries that differ only in case or
    surrounding whitespace produce different keys.
    """
    # Normalize db_id
    db_part = db_id or "__default__"

    # Build a single string seed
    seed = f"{db_part}\n{query}\n{schema_preview}"

    # Short, deterministic key
    return hashlib.sha1(seed.encode("utf-8")).hexdigest()


def _cache_gc(now: float) -> None:
    """
    Garbage-collect _CACHE entries by TTL and max size.

    Entries older than _CACHE_TTL seconds (relative to ``now``) are dropped
    first; then, while the cache exceeds _CACHE_MAX, the oldest-inserted entry
    is removed (dict insertion order).
    """
    # TTL eviction
    for k, (ts, _) in list(_CACHE.items()):
        if now - ts > _CACHE_TTL:
            _CACHE.pop(k, None)

    # Size eviction (naive FIFO-style)
    while len(_CACHE) > _CACHE_MAX:
        _CACHE.pop(next(iter(_CACHE)), None)


####################################

router = APIRouter(prefix="/nl2sql")

# -------------------------------
# Config / Defaults
# -------------------------------
DB_MODE = settings.db_mode.lower()  # "sqlite" or "postgres"

# Runtime upload storage for SQLite DBs
_DB_UPLOAD_DIR = settings.db_upload_dir
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)

# Optional: separate directory for other uploads (kept as-is for now)
UPLOAD_DIR = Path("data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

logger.debug(
    "NL2SQL router configured",
    extra={"db_mode": DB_MODE, "upload_dir": _DB_UPLOAD_DIR},
)


# -------------------------------
# Schema preview endpoint
# -------------------------------
# NOTE(review): unlike /upload_db and the main handler, this endpoint has no
# require_api_key dependency, so schema previews are readable without a key.
# Confirm this is intentional.
@router.get("/schema")
def schema_endpoint(
    db_id: Optional[str] = None,
    svc: NL2SQLService = Depends(get_nl2sql_service),
):
    """
    Return a lightweight schema preview string for the given DB.

    - If db_id is provided, service will resolve the uploaded DB.
    - If not, service falls back to the default DB.
    - In postgres mode, caller must usually provide schema_preview explicitly.

    Domain errors (AppError subclasses) are handled by the global exception
    handler. This endpoint only wraps truly unexpected errors into a generic
    HTTP 500.
    """
    try:
        preview = svc.get_schema_preview(db_id=db_id, override=None)
    except AppError:
        # Let the global AppError handler deal with it.
        raise
    except Exception as exc:
        logger.exception("Unexpected error in schema_endpoint", exc_info=exc)
        raise HTTPException(
            status_code=500,
            detail="failed to derive schema preview",
        ) from exc

    return {"schema_preview": preview}


# -------------------------------
# Helpers
# -------------------------------
def _to_dict(obj: Any) -> Any:
    """
    Convert dataclass instances to plain dicts for JSON.

    Dataclass *classes* (as opposed to instances) and all other objects are
    returned unchanged.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)  # type: ignore[arg-type]
    return obj


def _round_trace(t: Any) -> Dict[str, Any]:
    """
    Normalize a trace entry (dict or StageTrace-like object) for API/UI:

    - stage: str (defaults to "?")
    - duration_ms: int (rounded; 0 if missing or not numeric)
    - notes, cost_usd: always present (None if absent)
    - summary, token_in, token_out: included only when not None
    """
    if isinstance(t, dict):
        stage = t.get("stage", "?")
        ms = t.get("duration_ms", 0)
        notes = t.get("notes")
        cost = t.get("cost_usd")
        summary = t.get("summary")
        token_in = t.get("token_in")
        token_out = t.get("token_out")
    else:
        stage = getattr(t, "stage", "?")
        ms = getattr(t, "duration_ms", 0)
        notes = getattr(t, "notes", None)
        cost = getattr(t, "cost_usd", None)
        summary = getattr(t, "summary", None)
        token_in = getattr(t, "token_in", None)
        token_out = getattr(t, "token_out", None)

    # Coerce duration to int with rounding; best-effort, falls back to 0
    # (covers non-numeric values as well as NaN/inf).
    try:
        ms_int = int(round(float(ms))) if ms is not None else 0
    except Exception:
        ms_int = 0

    out: Dict[str, Any] = {
        "stage": str(stage) if stage is not None else "?",
        "duration_ms": ms_int,
        "notes": notes,
        "cost_usd": cost,
    }
    if summary is not None:
        out["summary"] = summary
    if token_in is not None:
        out["token_in"] = token_in
    if token_out is not None:
        out["token_out"] = token_out
    return out


# -------------------------------
# Upload endpoint (SQLite only)
# -------------------------------
@router.post("/upload_db", dependencies=[Depends(require_api_key)])
async def upload_db(file: UploadFile = File(...)):
    """
    Upload a SQLite DB file and register it under a generated db_id.

    Only available when DB_MODE is 'sqlite':
    - Allowed extensions: .db, .sqlite (case-insensitive)
    - File size capped by configured upload_max_bytes (default 20 MB)

    The file is stored as ``<_DB_UPLOAD_DIR>/<db_id>.sqlite``; the
    client-supplied filename is only used for the extension check.

    Raises:
        HTTPException(400): wrong DB mode, bad extension, or file too large.
        HTTPException(500): the file could not be written to disk.
    """
    if DB_MODE != "sqlite":
        raise HTTPException(
            status_code=400, detail="DB upload is only supported in sqlite mode"
        )

    filename = file.filename or "db.sqlite"
    if not filename.lower().endswith((".db", ".sqlite")):
        raise HTTPException(
            status_code=400, detail="Only .db or .sqlite files are allowed"
        )

    max_bytes = settings.upload_max_bytes
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering the whole thing in memory.
    data = await file.read(max_bytes + 1)
    if len(data) > max_bytes:
        raise HTTPException(
            status_code=400, detail=f"File too large (> {max_bytes} bytes)"
        )

    db_id = str(uuid.uuid4())
    out_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")

    try:
        with open(out_path, "wb") as f:
            f.write(data)
    except Exception as e:
        logger.exception("Failed to store uploaded DB file", extra={"path": out_path})
        # Best-effort cleanup: do not leave a truncated file in the upload dir.
        try:
            os.remove(out_path)
        except OSError:
            pass
        # Generic detail: the raw OS error can expose server filesystem paths.
        raise HTTPException(status_code=500, detail="Failed to store DB") from e

    register_db(db_id, out_path)
    logger.debug("Registered uploaded DB", extra={"db_id": db_id, "path": out_path})

    return {"db_id": db_id}


@router.get("/health")
def health():
    """Simple router-level health endpoint."""
    return {"status": "ok", "version": settings.app_version}


# -------------------------------
# Main NL2SQL endpoint
# -------------------------------
@router.post("", name="nl2sql_handler", dependencies=[Depends(require_api_key)])
def nl2sql_handler(
    request: NL2SQLRequest,
    svc: NL2SQLService = Depends(get_nl2sql_service),
    cache: NL2SQLCache = Depends(get_cache),
) -> NL2SQLResponse | ClarifyResponse | Dict[str, Any]:
    """
    Main NL→SQL handler.

    Flow:
      - Resolve schema preview (client override or derived from DB).
      - Check the injected NL2SQLCache (key = hash of db_id + query + schema
        preview); a hit returns the cached plain dict directly.
      - Run the pipeline through NL2SQLService.
      - Map FinalResult to API response or HTTP error:
          * ambiguous → 200 ClarifyResponse (not cached)
          * failure   → HTTP 400 with "; "-joined details
          * success   → 200 NL2SQLResponse (cached as a plain dict)
    """
    db_id = getattr(request, "db_id", None)

    # ---- schema preview ----
    try:
        final_preview = svc.get_schema_preview(
            db_id=db_id,
            override=request.schema_preview,
        )
    except AppError:
        # Domain-level errors are handled by the global AppError handler.
        raise
    except Exception as exc:
        logger.exception(
            "Unexpected error while preparing schema preview",
            exc_info=exc,
        )
        raise HTTPException(
            status_code=500,
            detail="failed to prepare schema",
        ) from exc

    # ---- cache lookup ----
    cache_key = _ck(db_id, request.query, final_preview)
    cached_payload = cache.get(cache_key)
    if cached_payload is not None:
        return cached_payload

    # ---- pipeline execution via service ----
    try:
        result = svc.run_query(
            query=request.query,
            db_id=db_id,
            schema_preview=final_preview,
        )
    except AppError:
        # Let the global handler convert it to an HTTP response.
        raise
    except Exception as exc:
        logger.exception(
            "Unexpected pipeline crash in NL2SQLService.run_query",
            exc_info=exc,
        )
        raise HTTPException(
            status_code=500,
            detail="internal pipeline error",
        ) from exc

    # ---- type sanity check ----
    if not isinstance(result, FinalResult):
        logger.debug(
            "Pipeline returned unexpected type",
            extra={"type": type(result).__name__},
        )
        raise HTTPException(
            status_code=500,
            detail="pipeline returned unexpected type",
        )

    # ---- ambiguity path → 200 with clarification questions ----
    if result.ambiguous:
        qs = result.questions or []
        return ClarifyResponse(ambiguous=True, questions=qs)

    # ---- error path → 400 with joined details ----
    if (not result.ok) or result.error:
        logger.debug(
            "Pipeline reported failure",
            extra={
                "ok": result.ok,
                "error": result.error,
                "details": result.details,
            },
        )
        message = "; ".join(result.details or []) or "Unknown error"
        raise HTTPException(status_code=400, detail=message)

    # ---- success path → 200 (normalize traces and executor result) ----
    traces = [_round_trace(t) for t in (result.traces or [])]

    response_result: Dict[str, Any] = {}
    raw_result = getattr(result, "result", None)
    if raw_result is not None:
        if isinstance(raw_result, dict):
            response_result = raw_result
        else:
            response_result = cast(Dict[str, Any], _to_dict(raw_result))

    payload = NL2SQLResponse(
        ambiguous=False,
        sql=result.sql,
        rationale=result.rationale,
        traces=traces,
        result=response_result,
    )

    # Store in cache (as plain dict)
    cache.set(cache_key, payload.model_dump())

    return payload