Spaces:
Running
Running
Melika Kheirieh
commited on
Commit
·
99fa656
1
Parent(s):
fb51384
fix(typing): split adapter vars in /readyz to satisfy mypy
Browse files- .pre-commit-config.yaml +3 -3
- app/bootstrap.py +13 -3
- app/main.py +28 -38
- app/routers/nl2sql.py +78 -142
.pre-commit-config.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
repos:
|
| 2 |
# --- Basic hygiene checks ---
|
| 3 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 4 |
-
rev:
|
| 5 |
hooks:
|
| 6 |
- id: check-merge-conflict
|
| 7 |
- id: end-of-file-fixer
|
|
@@ -9,7 +9,7 @@ repos:
|
|
| 9 |
|
| 10 |
# --- Ruff: linting and formatting ---
|
| 11 |
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 12 |
-
rev: v0.
|
| 13 |
hooks:
|
| 14 |
- id: ruff
|
| 15 |
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -17,7 +17,7 @@ repos:
|
|
| 17 |
|
| 18 |
# --- Mypy: type-checking on staged Python files ---
|
| 19 |
- repo: https://github.com/pre-commit/mirrors-mypy
|
| 20 |
-
rev: v1.
|
| 21 |
hooks:
|
| 22 |
- id: mypy
|
| 23 |
args:
|
|
|
|
| 1 |
repos:
|
| 2 |
# --- Basic hygiene checks ---
|
| 3 |
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 4 |
+
rev: v6.0.0
|
| 5 |
hooks:
|
| 6 |
- id: check-merge-conflict
|
| 7 |
- id: end-of-file-fixer
|
|
|
|
| 9 |
|
| 10 |
# --- Ruff: linting and formatting ---
|
| 11 |
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 12 |
+
rev: v0.14.3
|
| 13 |
hooks:
|
| 14 |
- id: ruff
|
| 15 |
args: [--fix, --exit-non-zero-on-fix]
|
|
|
|
| 17 |
|
| 18 |
# --- Mypy: type-checking on staged Python files ---
|
| 19 |
- repo: https://github.com/pre-commit/mirrors-mypy
|
| 20 |
+
rev: v1.18.2
|
| 21 |
hooks:
|
| 22 |
- id: mypy
|
| 23 |
args:
|
app/bootstrap.py
CHANGED
|
@@ -1,7 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
try:
|
| 2 |
from dotenv import load_dotenv
|
| 3 |
|
| 4 |
load_dotenv()
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
| 1 |
+
"""App bootstrap: load .env and prepare environment paths."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Change current dir to project root (so relative paths like data/ work)
|
| 7 |
+
ROOT_DIR = Path(__file__).resolve().parent.parent
|
| 8 |
+
os.chdir(ROOT_DIR)
|
| 9 |
+
|
| 10 |
+
# Load .env if available
|
| 11 |
try:
|
| 12 |
from dotenv import load_dotenv
|
| 13 |
|
| 14 |
load_dotenv()
|
| 15 |
+
print("✅ bootstrap: .env loaded")
|
| 16 |
+
except Exception as e:
|
| 17 |
+
print(f"⚠️ bootstrap: could not load .env ({e})")
|
app/main.py
CHANGED
|
@@ -1,13 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
-
from typing import Protocol, runtime_checkable, cast
|
| 4 |
-
|
| 5 |
from fastapi import FastAPI, Request, Response, HTTPException
|
| 6 |
from fastapi.responses import PlainTextResponse
|
| 7 |
-
|
| 8 |
-
from app.routers import nl2sql
|
| 9 |
-
|
| 10 |
-
# Prometheus
|
| 11 |
from prometheus_client import (
|
| 12 |
Counter,
|
| 13 |
Histogram,
|
|
@@ -16,13 +10,14 @@ from prometheus_client import (
|
|
| 16 |
CONTENT_TYPE_LATEST,
|
| 17 |
)
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def ping(self) -> None: ...
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
# ---- Optionally restore uploaded DB map ----
|
| 28 |
try:
|
|
@@ -32,24 +27,22 @@ try:
|
|
| 32 |
except Exception as e:
|
| 33 |
print(f"⚠️ DB map not restored: {e}")
|
| 34 |
|
| 35 |
-
|
| 36 |
title="NL2SQL Copilot Prototype",
|
| 37 |
version=os.getenv("APP_VERSION", "0.1.0"),
|
| 38 |
description="Convert natural language to safe & verified SQL",
|
| 39 |
)
|
| 40 |
|
| 41 |
-
|
| 42 |
|
| 43 |
# ---- Prometheus metrics ----
|
| 44 |
-
REGISTRY
|
| 45 |
-
|
| 46 |
REQUEST_COUNT = Counter(
|
| 47 |
"http_requests_total",
|
| 48 |
"Total HTTP requests",
|
| 49 |
["path", "method", "status_code"],
|
| 50 |
registry=REGISTRY,
|
| 51 |
)
|
| 52 |
-
|
| 53 |
REQUEST_LATENCY = Histogram(
|
| 54 |
"http_request_latency_seconds",
|
| 55 |
"Request latency",
|
|
@@ -58,20 +51,13 @@ REQUEST_LATENCY = Histogram(
|
|
| 58 |
)
|
| 59 |
|
| 60 |
|
| 61 |
-
@
|
| 62 |
async def metrics_middleware(request: Request, call_next):
|
| 63 |
start = time.perf_counter()
|
| 64 |
response: Response = await call_next(request)
|
| 65 |
elapsed = time.perf_counter() - start
|
| 66 |
-
|
| 67 |
-
# Use route path if available, else raw path (typed guard for mypy)
|
| 68 |
route = request.scope.get("route")
|
| 69 |
-
path =
|
| 70 |
-
route.path
|
| 71 |
-
if (route is not None and hasattr(route, "path"))
|
| 72 |
-
else request.url.path
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
REQUEST_COUNT.labels(
|
| 76 |
path=path, method=request.method, status_code=str(response.status_code)
|
| 77 |
).inc()
|
|
@@ -79,14 +65,16 @@ async def metrics_middleware(request: Request, call_next):
|
|
| 79 |
return response
|
| 80 |
|
| 81 |
|
| 82 |
-
# --- Liveness
|
| 83 |
-
@
|
| 84 |
def healthz() -> str:
|
| 85 |
return "ok"
|
| 86 |
|
| 87 |
|
| 88 |
-
# --- Readiness
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
def readyz() -> str:
|
| 91 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 92 |
try:
|
|
@@ -94,32 +82,34 @@ def readyz() -> str:
|
|
| 94 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 95 |
|
| 96 |
dsn = os.environ["POSTGRES_DSN"]
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
else:
|
| 100 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 101 |
|
| 102 |
db_path = os.getenv("SQLITE_DB_PATH", "data/chinook.db")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
return "ready"
|
| 107 |
except Exception:
|
| 108 |
raise HTTPException(status_code=503, detail="not ready")
|
| 109 |
|
| 110 |
|
| 111 |
-
@
|
| 112 |
def root():
|
| 113 |
return {"status": "ok", "message": "NL2SQL Copilot API is running"}
|
| 114 |
|
| 115 |
|
| 116 |
-
@
|
| 117 |
def health():
|
| 118 |
-
# You might want to replace the placeholders with real checks later.
|
| 119 |
return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
|
| 120 |
|
| 121 |
|
| 122 |
-
@
|
| 123 |
def metrics():
|
| 124 |
data = generate_latest(REGISTRY)
|
| 125 |
return Response(content=data, media_type=CONTENT_TYPE_LATEST)
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
|
|
|
|
|
|
| 3 |
from fastapi import FastAPI, Request, Response, HTTPException
|
| 4 |
from fastapi.responses import PlainTextResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from prometheus_client import (
|
| 6 |
Counter,
|
| 7 |
Histogram,
|
|
|
|
| 10 |
CONTENT_TYPE_LATEST,
|
| 11 |
)
|
| 12 |
|
| 13 |
+
try:
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
|
| 16 |
+
load_dotenv()
|
| 17 |
+
except Exception:
|
| 18 |
+
pass
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
from app.routers import nl2sql
|
| 21 |
|
| 22 |
# ---- Optionally restore uploaded DB map ----
|
| 23 |
try:
|
|
|
|
| 27 |
except Exception as e:
|
| 28 |
print(f"⚠️ DB map not restored: {e}")
|
| 29 |
|
| 30 |
+
application: FastAPI = FastAPI(
|
| 31 |
title="NL2SQL Copilot Prototype",
|
| 32 |
version=os.getenv("APP_VERSION", "0.1.0"),
|
| 33 |
description="Convert natural language to safe & verified SQL",
|
| 34 |
)
|
| 35 |
|
| 36 |
+
application.include_router(nl2sql.router, prefix="/api/v1")
|
| 37 |
|
| 38 |
# ---- Prometheus metrics ----
|
| 39 |
+
REGISTRY = CollectorRegistry()
|
|
|
|
| 40 |
REQUEST_COUNT = Counter(
|
| 41 |
"http_requests_total",
|
| 42 |
"Total HTTP requests",
|
| 43 |
["path", "method", "status_code"],
|
| 44 |
registry=REGISTRY,
|
| 45 |
)
|
|
|
|
| 46 |
REQUEST_LATENCY = Histogram(
|
| 47 |
"http_request_latency_seconds",
|
| 48 |
"Request latency",
|
|
|
|
| 51 |
)
|
| 52 |
|
| 53 |
|
| 54 |
+
@application.middleware("http")
|
| 55 |
async def metrics_middleware(request: Request, call_next):
|
| 56 |
start = time.perf_counter()
|
| 57 |
response: Response = await call_next(request)
|
| 58 |
elapsed = time.perf_counter() - start
|
|
|
|
|
|
|
| 59 |
route = request.scope.get("route")
|
| 60 |
+
path = route.path if route else request.url.path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
REQUEST_COUNT.labels(
|
| 62 |
path=path, method=request.method, status_code=str(response.status_code)
|
| 63 |
).inc()
|
|
|
|
| 65 |
return response
|
| 66 |
|
| 67 |
|
| 68 |
+
# --- Liveness ---
|
| 69 |
+
@application.get("/healthz", response_class=PlainTextResponse, tags=["system"])
|
| 70 |
def healthz() -> str:
|
| 71 |
return "ok"
|
| 72 |
|
| 73 |
|
| 74 |
+
# --- Readiness ---
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@application.get("/readyz", response_class=PlainTextResponse, tags=["system"])
|
| 78 |
def readyz() -> str:
|
| 79 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 80 |
try:
|
|
|
|
| 82 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 83 |
|
| 84 |
dsn = os.environ["POSTGRES_DSN"]
|
| 85 |
+
pg = PostgresAdapter(dsn)
|
| 86 |
+
ping = getattr(pg, "ping", None)
|
| 87 |
+
if callable(ping):
|
| 88 |
+
ping()
|
| 89 |
else:
|
| 90 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 91 |
|
| 92 |
db_path = os.getenv("SQLITE_DB_PATH", "data/chinook.db")
|
| 93 |
+
sq = SQLiteAdapter(db_path)
|
| 94 |
+
ping = getattr(sq, "ping", None)
|
| 95 |
+
if callable(ping):
|
| 96 |
+
ping()
|
| 97 |
return "ready"
|
| 98 |
except Exception:
|
| 99 |
raise HTTPException(status_code=503, detail="not ready")
|
| 100 |
|
| 101 |
|
| 102 |
+
@application.get("/")
|
| 103 |
def root():
|
| 104 |
return {"status": "ok", "message": "NL2SQL Copilot API is running"}
|
| 105 |
|
| 106 |
|
| 107 |
+
@application.get("/health")
|
| 108 |
def health():
|
|
|
|
| 109 |
return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
|
| 110 |
|
| 111 |
|
| 112 |
+
@application.get("/metrics", tags=["system"])
|
| 113 |
def metrics():
|
| 114 |
data = generate_latest(REGISTRY)
|
| 115 |
return Response(content=data, media_type=CONTENT_TYPE_LATEST)
|
app/routers/nl2sql.py
CHANGED
|
@@ -22,46 +22,25 @@ from typing import Union, Optional, Dict, TypedDict, Any, cast
|
|
| 22 |
|
| 23 |
router = APIRouter(prefix="/nl2sql")
|
| 24 |
|
| 25 |
-
# --- Database adapter selection ---
|
| 26 |
-
DB_MODE = os.getenv("DB_MODE", "sqlite").lower()
|
| 27 |
-
|
| 28 |
-
_db: Union[PostgresAdapter, SQLiteAdapter]
|
| 29 |
-
if DB_MODE == "postgres":
|
| 30 |
-
dsn = os.environ.get("POSTGRES_DSN")
|
| 31 |
-
if not dsn:
|
| 32 |
-
raise RuntimeError(
|
| 33 |
-
"POSTGRES_DSN environment variable is required in postgres mode"
|
| 34 |
-
)
|
| 35 |
-
_db = PostgresAdapter(dsn)
|
| 36 |
-
else:
|
| 37 |
-
_db = SQLiteAdapter("data/chinook.db")
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
# --- Build a single shared pipeline for all routes ---
|
| 41 |
-
def _make_pipeline() -> Pipeline:
|
| 42 |
-
llm = OpenAIProvider()
|
| 43 |
-
return Pipeline(
|
| 44 |
-
detector=AmbiguityDetector(),
|
| 45 |
-
planner=Planner(llm=llm),
|
| 46 |
-
generator=Generator(llm=llm),
|
| 47 |
-
safety=Safety(),
|
| 48 |
-
executor=Executor(db=_db),
|
| 49 |
-
verifier=Verifier(),
|
| 50 |
-
repair=Repair(llm=llm),
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
_pipeline: Pipeline = _make_pipeline()
|
| 55 |
-
|
| 56 |
-
|
| 57 |
# -------------------------------
|
| 58 |
-
#
|
| 59 |
-
# Files are stored under /tmp, mapped by a short-lived db_id
|
| 60 |
# -------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
|
| 62 |
_DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
|
| 63 |
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
class DBEntry(TypedDict):
|
| 67 |
path: str
|
|
@@ -71,27 +50,8 @@ class DBEntry(TypedDict):
|
|
| 71 |
# In-memory map: db_id -> {"path": str, "ts": float}
|
| 72 |
_DB_MAP: Dict[str, DBEntry] = {}
|
| 73 |
|
| 74 |
-
# -------------------------------
|
| 75 |
-
# Default DB resolution
|
| 76 |
-
# -------------------------------
|
| 77 |
-
DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
|
| 78 |
-
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
|
| 79 |
-
DEFAULT_SQLITE_DB: str = os.getenv("DEFAULT_SQLITE_DB", "data/chinook.db")
|
| 80 |
-
|
| 81 |
-
# -------------------------------
|
| 82 |
-
# Path to persist db_id → file map
|
| 83 |
-
# -------------------------------
|
| 84 |
-
_DB_MAP_PATH = Path("data/uploads/db_map.json")
|
| 85 |
-
_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 86 |
-
|
| 87 |
-
UPLOAD_DIR = Path("data/uploads")
|
| 88 |
-
UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
|
| 89 |
-
|
| 90 |
-
DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
|
| 91 |
-
|
| 92 |
|
| 93 |
def _save_db_map() -> None:
|
| 94 |
-
"""Persist the in-memory DB map to disk as JSON."""
|
| 95 |
try:
|
| 96 |
with open(_DB_MAP_PATH, "w") as f:
|
| 97 |
json.dump(_DB_MAP, f)
|
|
@@ -100,13 +60,11 @@ def _save_db_map() -> None:
|
|
| 100 |
|
| 101 |
|
| 102 |
def _load_db_map() -> None:
|
| 103 |
-
"""Load the DB map from disk if it exists (called on startup)."""
|
| 104 |
global _DB_MAP
|
| 105 |
if _DB_MAP_PATH.exists():
|
| 106 |
try:
|
| 107 |
with open(_DB_MAP_PATH, "r") as f:
|
| 108 |
data = json.load(f)
|
| 109 |
-
# Be liberal in what we accept; validate into TypedDict
|
| 110 |
if isinstance(data, dict):
|
| 111 |
restored: Dict[str, DBEntry] = {}
|
| 112 |
for k, v in data.items():
|
|
@@ -121,7 +79,6 @@ def _load_db_map() -> None:
|
|
| 121 |
|
| 122 |
|
| 123 |
def _cleanup_db_map() -> None:
|
| 124 |
-
"""Remove expired uploaded DB files (best-effort)."""
|
| 125 |
now = time.time()
|
| 126 |
expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
|
| 127 |
for k in expired:
|
|
@@ -134,15 +91,20 @@ def _cleanup_db_map() -> None:
|
|
| 134 |
_DB_MAP.pop(k, None)
|
| 135 |
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
_cleanup_db_map()
|
| 140 |
-
if db_id and db_id in _DB_MAP:
|
| 141 |
-
return _DB_MAP[db_id]["path"]
|
| 142 |
-
return DEFAULT_SQLITE_DB
|
| 143 |
|
| 144 |
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 147 |
if mode == "postgres":
|
| 148 |
dsn = os.environ.get("POSTGRES_DSN")
|
|
@@ -151,79 +113,72 @@ def _select_adapter(db_id: Optional[str]):
|
|
| 151 |
return PostgresAdapter(dsn)
|
| 152 |
|
| 153 |
# sqlite mode
|
|
|
|
| 154 |
if db_id:
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
if
|
| 159 |
-
|
| 160 |
-
#
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
)
|
| 172 |
-
return SQLiteAdapter(db_path)
|
| 173 |
-
|
| 174 |
-
# fallback to default Chinook
|
| 175 |
if not Path(DEFAULT_SQLITE_PATH).exists():
|
| 176 |
raise HTTPException(status_code=500, detail="default DB not found")
|
| 177 |
return SQLiteAdapter(DEFAULT_SQLITE_PATH)
|
| 178 |
|
| 179 |
|
| 180 |
# -------------------------------
|
| 181 |
-
# LLM
|
| 182 |
# -------------------------------
|
| 183 |
-
def
|
|
|
|
| 184 |
return OpenAIProvider()
|
| 185 |
|
| 186 |
|
| 187 |
-
_detector = AmbiguityDetector()
|
| 188 |
-
_planner = Planner(get_llm())
|
| 189 |
-
_generator = Generator(get_llm())
|
| 190 |
-
_safety = Safety()
|
| 191 |
-
_verifier = Verifier()
|
| 192 |
-
_repair = Repair(get_llm())
|
| 193 |
-
|
| 194 |
-
|
| 195 |
def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
|
| 196 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
executor = Executor(adapter)
|
|
|
|
|
|
|
| 198 |
return Pipeline(
|
| 199 |
-
detector=
|
| 200 |
-
planner=
|
| 201 |
-
generator=
|
| 202 |
-
safety=
|
| 203 |
executor=executor,
|
| 204 |
-
verifier=
|
| 205 |
-
repair=
|
| 206 |
)
|
| 207 |
|
| 208 |
|
| 209 |
# -------------------------------
|
| 210 |
-
# Helpers
|
| 211 |
# -------------------------------
|
| 212 |
def _to_dict(obj: Any) -> Any:
|
| 213 |
-
"""Safely convert dataclass instance → dict, otherwise return as-is.
|
| 214 |
-
|
| 215 |
-
Note: dataclasses.is_dataclass returns True for both classes and instances.
|
| 216 |
-
We must exclude classes; mypy cannot refine this perfectly, so we ignore arg-type.
|
| 217 |
-
"""
|
| 218 |
if is_dataclass(obj) and not isinstance(obj, type):
|
| 219 |
return asdict(obj) # type: ignore[arg-type]
|
| 220 |
return obj
|
| 221 |
|
| 222 |
|
| 223 |
def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
|
| 224 |
-
"""Round float fields to keep responses tidy and stable."""
|
| 225 |
if t.get("cost_usd") is not None:
|
| 226 |
-
# Ensure numeric before rounding
|
| 227 |
cost = t["cost_usd"]
|
| 228 |
if isinstance(cost, (int, float)):
|
| 229 |
t["cost_usd"] = round(float(cost), 6)
|
|
@@ -236,17 +191,9 @@ def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 236 |
|
| 237 |
# -------------------------------
|
| 238 |
# Upload endpoint (SQLite only)
|
| 239 |
-
# Path will be /api/nl2sql/upload_db if your root APIRouter is mounted at /api
|
| 240 |
# -------------------------------
|
| 241 |
@router.post("/upload_db")
|
| 242 |
async def upload_db(file: UploadFile = File(...)):
|
| 243 |
-
"""
|
| 244 |
-
Upload a SQLite database (.db/.sqlite). Returns a short-lived db_id.
|
| 245 |
-
Notes:
|
| 246 |
-
- Only SQLite files are allowed here (not for Postgres mode).
|
| 247 |
-
- Max size ~20MB recommended for demo environments like HF Spaces.
|
| 248 |
-
- Files are stored under /tmp and cleaned by TTL.
|
| 249 |
-
"""
|
| 250 |
if DB_MODE != "sqlite":
|
| 251 |
raise HTTPException(
|
| 252 |
status_code=400, detail="DB upload is only supported in sqlite mode"
|
|
@@ -280,48 +227,42 @@ async def upload_db(file: UploadFile = File(...)):
|
|
| 280 |
|
| 281 |
# -------------------------------
|
| 282 |
# Main NL2SQL endpoint
|
| 283 |
-
# Path will be /api/nl2sql if your root APIRouter is mounted at /api
|
| 284 |
# -------------------------------
|
| 285 |
@router.post("", name="nl2sql_handler")
|
| 286 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 287 |
db_id = getattr(request, "db_id", None)
|
| 288 |
|
| 289 |
-
#
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
# 2) Resolve schema_preview
|
| 302 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 303 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 304 |
final_preview: str = provided_preview or derived_preview_val
|
| 305 |
|
| 306 |
-
#
|
| 307 |
try:
|
| 308 |
result = pipeline.run(
|
| 309 |
user_query=request.query,
|
| 310 |
schema_preview=final_preview,
|
| 311 |
)
|
| 312 |
except Exception as exc:
|
| 313 |
-
# Hard failure in pipeline itself
|
| 314 |
raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
|
| 315 |
|
| 316 |
-
# 4) Type check
|
| 317 |
if not isinstance(result, FinalResult):
|
| 318 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 319 |
|
| 320 |
-
# 5) Ambiguity → ask for clarification
|
| 321 |
if result.ambiguous and result.questions:
|
| 322 |
return ClarifyResponse(ambiguous=True, questions=result.questions)
|
| 323 |
|
| 324 |
-
# 6) Soft errors → bubble up details with 400
|
| 325 |
if not result.ok or result.error:
|
| 326 |
print("❌ Pipeline failure dump:")
|
| 327 |
print(" ok:", result.ok)
|
|
@@ -333,7 +274,6 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 333 |
detail="; ".join(result.details or []) or (result.error or "Unknown error"),
|
| 334 |
)
|
| 335 |
|
| 336 |
-
# 7) Success
|
| 337 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 338 |
return NL2SQLResponse(
|
| 339 |
ambiguous=False,
|
|
@@ -345,13 +285,10 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 345 |
|
| 346 |
def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
|
| 347 |
"""
|
| 348 |
-
Build a strict, exact-cased schema preview for the LLM.
|
| 349 |
-
Works for SQLite adapters by querying sqlite_master / pragma table_info.
|
| 350 |
"""
|
| 351 |
import sqlite3
|
| 352 |
-
import os
|
| 353 |
|
| 354 |
-
# Adapters may expose db_path or path; both are str in our codebase
|
| 355 |
db_path: Optional[str] = cast(
|
| 356 |
Optional[str], getattr(adapter, "db_path", None)
|
| 357 |
) or cast(Optional[str], getattr(adapter, "path", None))
|
|
@@ -367,8 +304,7 @@ def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> st
|
|
| 367 |
lines = []
|
| 368 |
for (tname,) in tables:
|
| 369 |
cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
|
| 370 |
-
|
| 371 |
-
colnames = [c[1] for c in cols]
|
| 372 |
lines.append(f"{tname}({', '.join(colnames)})")
|
| 373 |
conn.close()
|
| 374 |
return "\n".join(lines)
|
|
|
|
| 22 |
|
| 23 |
router = APIRouter(prefix="/nl2sql")
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# -------------------------------
|
| 26 |
+
# Config / Defaults
|
|
|
|
| 27 |
# -------------------------------
|
| 28 |
+
DB_MODE = os.getenv("DB_MODE", "sqlite").lower() # "sqlite" or "postgres"
|
| 29 |
+
POSTGRES_DSN = os.getenv("POSTGRES_DSN")
|
| 30 |
+
DEFAULT_SQLITE_PATH: str = os.getenv("DEFAULT_SQLITE_DB", "data/Chinook_Sqlite.sqlite")
|
| 31 |
+
|
| 32 |
+
# Runtime upload storage
|
| 33 |
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")
|
| 34 |
_DB_TTL_SECONDS: int = int(os.getenv("DB_TTL_SECONDS", "7200")) # default 2 hours
|
| 35 |
os.makedirs(_DB_UPLOAD_DIR, exist_ok=True)
|
| 36 |
|
| 37 |
+
# Persisted map
|
| 38 |
+
_DB_MAP_PATH = Path("data/uploads/db_map.json")
|
| 39 |
+
_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 40 |
+
|
| 41 |
+
UPLOAD_DIR = Path("data/uploads")
|
| 42 |
+
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
| 43 |
+
|
| 44 |
|
| 45 |
class DBEntry(TypedDict):
|
| 46 |
path: str
|
|
|
|
| 50 |
# In-memory map: db_id -> {"path": str, "ts": float}
|
| 51 |
_DB_MAP: Dict[str, DBEntry] = {}
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def _save_db_map() -> None:
|
|
|
|
| 55 |
try:
|
| 56 |
with open(_DB_MAP_PATH, "w") as f:
|
| 57 |
json.dump(_DB_MAP, f)
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def _load_db_map() -> None:
|
|
|
|
| 63 |
global _DB_MAP
|
| 64 |
if _DB_MAP_PATH.exists():
|
| 65 |
try:
|
| 66 |
with open(_DB_MAP_PATH, "r") as f:
|
| 67 |
data = json.load(f)
|
|
|
|
| 68 |
if isinstance(data, dict):
|
| 69 |
restored: Dict[str, DBEntry] = {}
|
| 70 |
for k, v in data.items():
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
def _cleanup_db_map() -> None:
|
|
|
|
| 82 |
now = time.time()
|
| 83 |
expired = [k for k, v in _DB_MAP.items() if (now - v["ts"]) > _DB_TTL_SECONDS]
|
| 84 |
for k in expired:
|
|
|
|
| 91 |
_DB_MAP.pop(k, None)
|
| 92 |
|
| 93 |
|
| 94 |
+
# Call once at import (safe & light); heavy things remain lazy.
|
| 95 |
+
_load_db_map()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
+
# -------------------------------
|
| 99 |
+
# Adapter selection (lazy)
|
| 100 |
+
# -------------------------------
|
| 101 |
+
def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
|
| 102 |
+
"""
|
| 103 |
+
Resolve a DB adapter:
|
| 104 |
+
- postgres: requires POSTGRES_DSN
|
| 105 |
+
- sqlite with db_id: uploaded file or fallback locations
|
| 106 |
+
- sqlite default: DEFAULT_SQLITE_PATH must exist
|
| 107 |
+
"""
|
| 108 |
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 109 |
if mode == "postgres":
|
| 110 |
dsn = os.environ.get("POSTGRES_DSN")
|
|
|
|
| 113 |
return PostgresAdapter(dsn)
|
| 114 |
|
| 115 |
# sqlite mode
|
| 116 |
+
_cleanup_db_map()
|
| 117 |
if db_id:
|
| 118 |
+
# Check runtime map
|
| 119 |
+
entry = _DB_MAP.get(db_id)
|
| 120 |
+
candidates = []
|
| 121 |
+
if entry and os.path.exists(entry["path"]):
|
| 122 |
+
candidates.append(entry["path"])
|
| 123 |
+
# Fallback locations based on convention
|
| 124 |
+
candidates.append(os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite"))
|
| 125 |
+
candidates.append(str(UPLOAD_DIR / f"{db_id}.sqlite"))
|
| 126 |
+
|
| 127 |
+
for p in candidates:
|
| 128 |
+
if p and os.path.exists(p):
|
| 129 |
+
return SQLiteAdapter(p)
|
| 130 |
+
|
| 131 |
+
raise HTTPException(status_code=400, detail="invalid db_id (file not found)")
|
| 132 |
+
|
| 133 |
+
# default sqlite
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
if not Path(DEFAULT_SQLITE_PATH).exists():
|
| 135 |
raise HTTPException(status_code=500, detail="default DB not found")
|
| 136 |
return SQLiteAdapter(DEFAULT_SQLITE_PATH)
|
| 137 |
|
| 138 |
|
| 139 |
# -------------------------------
|
| 140 |
+
# LLM & Pipeline builders (lazy)
|
| 141 |
# -------------------------------
|
| 142 |
+
def _get_llm() -> OpenAIProvider:
|
| 143 |
+
# Create provider on demand, after .env has been loaded in app.main
|
| 144 |
return OpenAIProvider()
|
| 145 |
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
|
| 148 |
+
"""
|
| 149 |
+
Build a fresh Pipeline bound to the given adapter.
|
| 150 |
+
All stateful/external pieces (LLM, executor) are instantiated here (lazy).
|
| 151 |
+
"""
|
| 152 |
+
llm = _get_llm()
|
| 153 |
+
detector = AmbiguityDetector()
|
| 154 |
+
planner = Planner(llm=llm)
|
| 155 |
+
generator = Generator(llm=llm)
|
| 156 |
+
safety = Safety()
|
| 157 |
executor = Executor(adapter)
|
| 158 |
+
verifier = Verifier()
|
| 159 |
+
repair = Repair(llm=llm)
|
| 160 |
return Pipeline(
|
| 161 |
+
detector=detector,
|
| 162 |
+
planner=planner,
|
| 163 |
+
generator=generator,
|
| 164 |
+
safety=safety,
|
| 165 |
executor=executor,
|
| 166 |
+
verifier=verifier,
|
| 167 |
+
repair=repair,
|
| 168 |
)
|
| 169 |
|
| 170 |
|
| 171 |
# -------------------------------
|
| 172 |
+
# Helpers (unchanged)
|
| 173 |
# -------------------------------
|
| 174 |
def _to_dict(obj: Any) -> Any:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if is_dataclass(obj) and not isinstance(obj, type):
|
| 176 |
return asdict(obj) # type: ignore[arg-type]
|
| 177 |
return obj
|
| 178 |
|
| 179 |
|
| 180 |
def _round_trace(t: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
| 181 |
if t.get("cost_usd") is not None:
|
|
|
|
| 182 |
cost = t["cost_usd"]
|
| 183 |
if isinstance(cost, (int, float)):
|
| 184 |
t["cost_usd"] = round(float(cost), 6)
|
|
|
|
| 191 |
|
| 192 |
# -------------------------------
|
| 193 |
# Upload endpoint (SQLite only)
|
|
|
|
| 194 |
# -------------------------------
|
| 195 |
@router.post("/upload_db")
|
| 196 |
async def upload_db(file: UploadFile = File(...)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
if DB_MODE != "sqlite":
|
| 198 |
raise HTTPException(
|
| 199 |
status_code=400, detail="DB upload is only supported in sqlite mode"
|
|
|
|
| 227 |
|
| 228 |
# -------------------------------
|
| 229 |
# Main NL2SQL endpoint
|
|
|
|
| 230 |
# -------------------------------
|
| 231 |
@router.post("", name="nl2sql_handler")
|
| 232 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 233 |
db_id = getattr(request, "db_id", None)
|
| 234 |
|
| 235 |
+
# Pick adapter per-request (default or uploaded or postgres)
|
| 236 |
+
adapter = _select_adapter(db_id)
|
| 237 |
+
|
| 238 |
+
# Build pipeline lazily with this adapter
|
| 239 |
+
pipeline = _build_pipeline(adapter)
|
| 240 |
+
|
| 241 |
+
# Derive schema preview only for sqlite with a real path
|
| 242 |
+
derived_preview_val: str = (
|
| 243 |
+
_derive_schema_preview(adapter) if isinstance(adapter, SQLiteAdapter) else ""
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Resolve schema_preview
|
|
|
|
| 247 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 248 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 249 |
final_preview: str = provided_preview or derived_preview_val
|
| 250 |
|
| 251 |
+
# Run pipeline
|
| 252 |
try:
|
| 253 |
result = pipeline.run(
|
| 254 |
user_query=request.query,
|
| 255 |
schema_preview=final_preview,
|
| 256 |
)
|
| 257 |
except Exception as exc:
|
|
|
|
| 258 |
raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
|
| 259 |
|
|
|
|
| 260 |
if not isinstance(result, FinalResult):
|
| 261 |
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 262 |
|
|
|
|
| 263 |
if result.ambiguous and result.questions:
|
| 264 |
return ClarifyResponse(ambiguous=True, questions=result.questions)
|
| 265 |
|
|
|
|
| 266 |
if not result.ok or result.error:
|
| 267 |
print("❌ Pipeline failure dump:")
|
| 268 |
print(" ok:", result.ok)
|
|
|
|
| 274 |
detail="; ".join(result.details or []) or (result.error or "Unknown error"),
|
| 275 |
)
|
| 276 |
|
|
|
|
| 277 |
traces = [_round_trace(t) for t in (result.traces or [])]
|
| 278 |
return NL2SQLResponse(
|
| 279 |
ambiguous=False,
|
|
|
|
| 285 |
|
| 286 |
def _derive_schema_preview(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> str:
|
| 287 |
"""
|
| 288 |
+
Build a strict, exact-cased schema preview for the LLM (SQLite only).
|
|
|
|
| 289 |
"""
|
| 290 |
import sqlite3
|
|
|
|
| 291 |
|
|
|
|
| 292 |
db_path: Optional[str] = cast(
|
| 293 |
Optional[str], getattr(adapter, "db_path", None)
|
| 294 |
) or cast(Optional[str], getattr(adapter, "path", None))
|
|
|
|
| 304 |
lines = []
|
| 305 |
for (tname,) in tables:
|
| 306 |
cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
|
| 307 |
+
colnames = [c[1] for c in cols] # (cid, name, type, notnull, dflt, pk)
|
|
|
|
| 308 |
lines.append(f"{tname}({', '.join(colnames)})")
|
| 309 |
conn.close()
|
| 310 |
return "\n".join(lines)
|