File size: 12,428 Bytes
9c10293
 
 
570f7bd
9c10293
64907d7
 
 
 
2d682e2
64907d7
9c10293
2d682e2
64907d7
9c10293
570f7bd
575394d
64907d7
570f7bd
 
343ad62
 
 
 
570f7bd
72c0821
2d682e2
575394d
2d682e2
 
 
72c0821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575394d
 
 
2d682e2
 
 
 
 
 
 
9c10293
570f7bd
 
b568b83
99fa656
b568b83
99fa656
 
 
 
 
b568b83
977a885
b568b83
 
99fa656
 
 
 
 
 
 
343ad62
 
 
977a885
 
 
 
 
 
b568b83
977a885
b568b83
370553a
977a885
370553a
 
 
 
 
 
 
977a885
370553a
 
 
 
 
 
977a885
 
 
 
 
 
 
370553a
 
 
 
b568b83
 
 
977a885
b568b83
977a885
b568b83
977a885
b568b83
 
 
 
 
5cbfffe
99fa656
 
b568b83
5cbfffe
99fa656
 
 
 
 
343ad62
 
 
 
 
 
 
 
 
 
 
 
99fa656
343ad62
370553a
 
 
 
 
 
 
343ad62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99fa656
 
 
343ad62
 
 
 
 
b568b83
5cbfffe
b568b83
99fa656
b568b83
99fa656
 
1fa9a31
 
5cbfffe
b568b83
79a5f4a
b568b83
977a885
 
 
 
570f7bd
5cbfffe
6181651
79a5f4a
 
 
 
 
 
 
 
6181651
 
 
 
 
79a5f4a
 
 
6181651
 
 
 
 
79a5f4a
 
 
6181651
79a5f4a
6181651
79a5f4a
6181651
 
 
79a5f4a
6181651
 
 
 
 
79a5f4a
 
 
 
 
 
 
570f7bd
5cbfffe
b568b83
 
 
 
 
 
5cbfffe
 
 
b568b83
 
 
5cbfffe
 
 
b568b83
 
 
 
5cbfffe
 
 
b568b83
 
 
 
 
 
 
 
 
 
370553a
b568b83
 
5cbfffe
b568b83
 
 
570f7bd
2d682e2
 
 
 
9c10293
343ad62
2d682e2
9c10293
6a94b42
343ad62
 
9c10293
6a94b42
343ad62
9c10293
ba06dd4
2d682e2
343ad62
 
9c10293
2d682e2
 
570f7bd
343ad62
370553a
9c10293
370553a
 
 
343ad62
a45c0eb
570f7bd
 
343ad62
2d682e2
 
 
 
 
 
570f7bd
343ad62
d5f745f
370553a
 
 
 
 
343ad62
 
570f7bd
79a5f4a
4dae3e6
570f7bd
 
a45c0eb
 
 
570f7bd
370553a
 
977a885
370553a
99fa656
370553a
 
 
977a885
 
 
370553a
 
 
 
 
 
 
 
 
 
 
 
99fa656
370553a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
from __future__ import annotations

# --- Stdlib ---
from dataclasses import asdict, is_dataclass
import json
import os
from pathlib import Path
import time
import uuid
from typing import Any, Dict, Optional, TypedDict, Union, cast, List, Callable

# --- Third-party ---
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends

# --- Local ---
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
from nl2sql.pipeline import FinalResult, FinalResult as _FinalResult
from adapters.llm.openai_provider import OpenAIProvider
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.db.postgres_adapter import PostgresAdapter
from nl2sql.pipeline_factory import (
    pipeline_from_config,
    pipeline_from_config_with_adapter,
)

_PIPELINE: Optional[Any] = None  # lazy cache, populated by get_runner()

Runner = Callable[..., _FinalResult]


def get_runner() -> Runner:
    """Return the callable that executes an NL→SQL request.

    When ``PYTEST_CURRENT_TEST`` is set (i.e. running under pytest), a stub
    runner is returned that always yields a successful, non-ambiguous result
    with ``sql="SELECT 1;"`` so route tests never touch the real pipeline.
    Otherwise the module-level pipeline is built from ``CONFIG_PATH`` on first
    use, cached in ``_PIPELINE``, and its bound ``run`` method is returned.
    """
    global _PIPELINE

    if not os.getenv("PYTEST_CURRENT_TEST"):
        if _PIPELINE is None:
            _PIPELINE = pipeline_from_config(CONFIG_PATH)
        return _PIPELINE.run

    def _stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> _FinalResult:
        # Deterministic success: no ambiguity, no error, trivial SQL.
        return _FinalResult(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT 1;",
            rationale=None,
            verified=True,
            traces=[],
        )

    return _stub_runner


def _build_pipeline(adapter) -> Any:
    """Build a pipeline from ``CONFIG_PATH`` whose DB stage uses *adapter*.

    Kept as a separate module-level function so tests can monkeypatch it.
    """
    pipeline = pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
    return pipeline


router = APIRouter(prefix="/nl2sql")

# -------------------------------
# Config / Defaults
# -------------------------------
DB_MODE = os.getenv("DB_MODE", "sqlite").lower()  # "sqlite" or "postgres"
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
DEFAULT_SQLITE_PATH: str = os.getenv("DEFAULT_SQLITE_DB", "data/Chinook_Sqlite.sqlite")

# Runtime upload storage: directory that /upload_db writes into, and the
# age (seconds) after which _cleanup_db_map evicts an upload.
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
_DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200"))  # default 2 hours
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)

# Persisted db_id -> {"path", "ts"} map (JSON on disk)
_DB_MAP_PATH = Path("data/uploads/db_map.json")
_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)

UPLOAD_DIR = Path("data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

CONFIG_PATH = os.getenv("PIPELINE_CONFIG", "configs/sqlite_pipeline.yaml")
# The default pipeline is intentionally NOT built here. get_runner() builds it
# on first use and caches it in _PIPELINE. Building it at import time defeated
# that lazy cache, and could construct pipeline components before app startup
# has finished loading environment configuration (cf. _get_llm).


# One uploaded-database record: the file path on disk and the time.time()
# at which it was registered (used for TTL eviction).
DBEntry = TypedDict("DBEntry", {"path": str, "ts": float})


# In-memory registry of uploads, keyed by db_id.
_DB_MAP: Dict[str, DBEntry] = dict()


def _save_db_map() -> None:
    """Persist the in-memory upload registry ``_DB_MAP`` to ``_DB_MAP_PATH``.

    The JSON is written to a sibling ``.tmp`` file first and then swapped into
    place with ``os.replace``. A crash mid-write therefore cannot leave a
    truncated file that ``_load_db_map`` would fail to parse on the next start.
    Persistence is best-effort: failures are printed, not raised.
    """
    tmp_path = _DB_MAP_PATH.with_name(_DB_MAP_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(_DB_MAP, f)
        os.replace(tmp_path, _DB_MAP_PATH)
    except Exception as e:
        print(f"⚠️ Failed to save DB map: {e}")


def _load_db_map() -> None:
    """Merge entries persisted at ``_DB_MAP_PATH`` into the in-memory ``_DB_MAP``.

    A missing file is not an error. Each entry must be an object with a string
    ``"path"`` and a numeric ``"ts"``. Malformed entries are skipped one by one
    rather than aborting the whole restore. Read/parse failures are printed
    and otherwise ignored (best-effort).
    """
    if not _DB_MAP_PATH.exists():
        return

    try:
        with open(_DB_MAP_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"⚠️ Failed to load DB map: {e}")
        return

    if not isinstance(data, dict):
        return

    restored: Dict[str, DBEntry] = {}
    for key, value in data.items():
        if not isinstance(value, dict):
            # A non-object entry used to raise AttributeError on .get() and
            # discard every other (valid) entry; skip just this one instead.
            continue
        path = value.get("path")
        ts = value.get("ts")
        if isinstance(path, str) and isinstance(ts, (int, float)):
            restored[key] = {"path": path, "ts": float(ts)}

    # Mutating in place; no rebinding, so no `global` declaration is needed.
    _DB_MAP.update(restored)
    print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")


def _cleanup_db_map() -> None:
    """Evict uploads registered more than ``_DB_TTL_SECONDS`` ago.

    For each expired db_id the backing file is deleted (a file that is already
    gone is fine) and the entry is dropped from ``_DB_MAP``. If anything was
    evicted, the registry is re-persisted so expired ids are not restored
    again by ``_load_db_map`` on the next start.
    """
    now = time.time()
    expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
    if not expired:
        return

    for k in expired:
        path: str = _DB_MAP[k]["path"]
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already removed; nothing to do
        except OSError as e:
            # Still drop the entry, but surface the failure instead of hiding it.
            print(f"⚠️ Failed to remove expired DB {path}: {e}")
        _DB_MAP.pop(k, None)

    _save_db_map()


# Restore the persisted upload registry once at import time (a single small file read).
_load_db_map()


# -------------------------------
# Adapter selection (lazy)
# -------------------------------
def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
    """
    Resolve a DB adapter based on module-level DB_MODE and an optional db_id.

    - postgres mode:
        requires POSTGRES_DSN in env (read at call time)
    - sqlite mode:
        if db_id provided, the first existing file among these wins:
            1) absolute path (if user supplied a full path)
            2) the path registered for db_id by /upload_db (_DB_MAP)
            3) {DB_UPLOAD_DIR}/{db_id}.sqlite  (where /upload_db writes)
            4) uploads/{db_id}.sqlite
            5) uploads/{db_id}.db
            6) data/{db_id}.sqlite
            7) data/{db_id}.db
        else fallback to DEFAULT_SQLITE_PATH

    Raises:
        HTTPException(500): POSTGRES_DSN missing, or default SQLite DB missing.
        HTTPException(400): db_id given but no candidate file exists.
    """
    if DB_MODE == "postgres":
        dsn = os.environ.get("POSTGRES_DSN")
        if not dsn:
            raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
        return PostgresAdapter(dsn)

    # sqlite mode
    if db_id:
        # NOTE(review): db_id comes from the request. Absolute paths and ".."
        # segments let a client open any .sqlite/.db file readable by the
        # server; consider restricting resolution to registered uploads.
        candidates: List[Path] = []

        # 1) absolute path
        p = Path(db_id)
        if p.is_absolute():
            candidates.append(p)

        # 2) uploads registered by /upload_db. Its files live in
        #    _DB_UPLOAD_DIR, which by default is not UPLOAD_DIR, so without
        #    these lookups freshly uploaded ids were never found.
        entry = _DB_MAP.get(db_id)
        if entry is not None:
            candidates.append(Path(entry["path"]))
        candidates.append(Path(_DB_UPLOAD_DIR) / f"{db_id}.sqlite")

        # 3) uploads/
        candidates.append(UPLOAD_DIR / f"{db_id}.sqlite")
        candidates.append(UPLOAD_DIR / f"{db_id}.db")

        # 4) data/
        candidates.append(Path("data") / f"{db_id}.sqlite")
        candidates.append(Path("data") / f"{db_id}.db")

        for c in candidates:
            if c.is_file():  # is_file() is False for missing paths too
                return SQLiteAdapter(str(c))

        raise HTTPException(status_code=400, detail="invalid db_id (file not found)")

    # default sqlite fallback
    default_path = Path(DEFAULT_SQLITE_PATH)
    if not default_path.exists():
        raise HTTPException(status_code=500, detail="default SQLite DB not found")
    return SQLiteAdapter(str(default_path))


# -------------------------------
# LLM & Pipeline builders (lazy)
# -------------------------------
def _get_llm() -> OpenAIProvider:
    """Return a freshly constructed OpenAI provider.

    Construction happens at call time rather than import time so the provider
    sees environment variables loaded during startup (the .env load in app.main).
    """
    provider = OpenAIProvider()
    return provider


# -------------------------------
# Helpers
# -------------------------------
def _to_dict(obj: Any) -> Any:
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)  # type: ignore[arg-type]
    return obj


def _round_trace(t: Any) -> Dict[str, Any]:
    """
    Normalize a trace entry (dict or StageTrace-like object) for API/UI:
    - stage: str (required)
    - duration_ms: int (rounded)
    - summary: optional (pass-through if exists)
    - notes: optional
    - token_in/out, cost_usd: pass-through if present
    """
    if isinstance(t, dict):
        stage = t.get("stage", "?")
        ms = t.get("duration_ms", 0)
        notes = t.get("notes")
        cost = t.get("cost_usd")
        summary = t.get("summary")
        token_in = t.get("token_in")
        token_out = t.get("token_out")
    else:
        stage = getattr(t, "stage", "?")
        ms = getattr(t, "duration_ms", 0)
        notes = getattr(t, "notes", None)
        cost = getattr(t, "cost_usd", None)
        summary = getattr(t, "summary", None)
        token_in = getattr(t, "token_in", None)
        token_out = getattr(t, "token_out", None)

    # coerce duration to int with rounding
    try:
        ms_int = int(round(float(ms))) if ms is not None else 0
    except Exception:
        ms_int = 0

    out: Dict[str, Any] = {
        "stage": str(stage) if stage is not None else "?",
        "duration_ms": ms_int,
        "notes": notes,
        "cost_usd": cost,
    }
    if summary is not None:
        out["summary"] = summary
    if token_in is not None:
        out["token_in"] = token_in
    if token_out is not None:
        out["token_out"] = token_out
    return out


# -------------------------------
# Upload endpoint (SQLite only)
# -------------------------------
@router.post("/upload_db")
async def upload_db(file: UploadFile = File(...)):
    """
    Store an uploaded SQLite database file and return a new ``db_id`` for it.

    Only available when DB_MODE is "sqlite". The filename must end in .db or
    .sqlite (case-insensitive), and the payload must not exceed the
    UPLOAD_MAX_BYTES env value (default 20 MB). The bytes are written to
    ``{_DB_UPLOAD_DIR}/{db_id}.sqlite`` and registered in the persisted map.

    Raises:
        HTTPException(400): wrong DB mode, bad extension, or file too large.
        HTTPException(500): the file could not be written.
    """
    if DB_MODE != "sqlite":
        raise HTTPException(
            status_code=400, detail="DB upload is only supported in sqlite mode"
        )

    filename = file.filename or "db.sqlite"
    if not filename.lower().endswith((".db", ".sqlite")):
        raise HTTPException(
            status_code=400, detail="Only .db or .sqlite files are allowed"
        )

    max_bytes = int(os.getenv("UPLOAD_MAX_BYTES", str(20 * 1024 * 1024)))  # 20 MB
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering an arbitrarily large body in memory first.
    data = await file.read(max_bytes + 1)
    if len(data) > max_bytes:
        raise HTTPException(
            status_code=400, detail=f"File too large (> {max_bytes} bytes)"
        )

    db_id = str(uuid.uuid4())
    out_path = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
    try:
        with open(out_path, "wb") as f:
            f.write(data)
    except Exception as e:
        # Best-effort removal so a truncated file is not left on disk.
        try:
            os.remove(out_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}") from e

    _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
    _save_db_map()
    return {"db_id": db_id}


# -------------------------------
# Main NL2SQL endpoint
# -------------------------------
@router.post("", name="nl2sql_handler")
def nl2sql_handler(
    request: NL2SQLRequest,
    run: Runner = Depends(get_runner),
):
    """
    NL→SQL handler using YAML-driven DI. If 'db_id' is provided, we override only the adapter
    while keeping all other stages from the YAML configs intact.

    Responses:
      - 200 ClarifyResponse when the pipeline reports ambiguity.
      - 200 NL2SQLResponse (with normalized traces) on success.
      - 400 with the pipeline's details joined by "; " when it reports failure.
      - 500 when the pipeline raises or returns something other than FinalResult.
    Adapter resolution for db_id may also raise 400/500 (see _select_adapter).
    """
    db_id = getattr(request, "db_id", None)
    final_preview: str = (
        cast(Optional[str], getattr(request, "schema_preview", None)) or ""
    )

    # Choose runner: default pipeline from YAML OR per-request override with a specific adapter.
    # NOTE(review): with db_id the caller's schema_preview (possibly "") is passed as-is;
    # _derive_schema_preview could fill it for SQLite if the pipeline does not derive it itself.
    if db_id:
        adapter = _select_adapter(db_id)
        runner = _build_pipeline(adapter).run
    else:
        runner = run

    # Execute pipeline
    try:
        result = runner(user_query=request.query, schema_preview=final_preview)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}") from exc

    # Type sanity. FinalResult and _FinalResult are the same imported class,
    # so a single check covers both (the former second check was dead code).
    if not isinstance(result, FinalResult):
        raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")

    # Ambiguity path → 200 with questions
    if result.ambiguous:
        return ClarifyResponse(ambiguous=True, questions=result.questions or [])

    # Error path → 400 with joined details
    if (not result.ok) or result.error:
        print("❌ Pipeline failure dump:")
        print("  ok:", result.ok)
        print("  error:", result.error)
        print("  details:", result.details)
        print("  traces:", result.traces)
        # str() each item so a non-string detail cannot turn this 400 into a TypeError/500.
        message = "; ".join(str(d) for d in (result.details or [])) or "Unknown error"
        raise HTTPException(status_code=400, detail=message)

    # Success path → 200 (coerce/standardize traces for API)
    traces = [_round_trace(t) for t in (result.traces or [])]
    return NL2SQLResponse(
        ambiguous=False,
        sql=result.sql,
        rationale=result.rationale,
        traces=traces,
    )


def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
    """
    Build a strict, exact-cased schema preview for the LLM (SQLite only).
    """
    import sqlite3

    db_path: Optional[str] = cast(
        Optional[str], getattr(adapter, "db_path", None)
    ) or cast(Optional[str], getattr(adapter, "path", None))
    if not db_path or not os.path.exists(db_path):
        return ""

    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        tables = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        lines = []
        for (tname,) in tables:
            cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
            colnames = [c[1] for c in cols]  # (cid, name, type, notnull, dflt, pk)
            lines.append(f"{tname}({', '.join(colnames)})")
        conn.close()
        return "\n".join(lines)
    except Exception:
        return ""