Spaces:
Running
Running
Melika Kheirieh
commited on
Commit
·
e7d7c61
1
Parent(s):
7f99e2c
fix(api): make adapter non-optional within branch and annotate pipeline for mypy
Browse files- app/bootstrap.py +7 -0
- app/main.py +98 -10
- app/routers/nl2sql.py +4 -3
- requirements.txt +1 -0
- tests/test_health_and_metrics.py +19 -0
app/bootstrap.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
except Exception:
|
| 6 |
+
# optional: silently continue if python-dotenv is not installed
|
| 7 |
+
pass
|
app/main.py
CHANGED
|
@@ -1,11 +1,30 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
| 4 |
|
| 5 |
-
from
|
| 6 |
-
from app.routers import nl2sql # noqa: E402
|
| 7 |
|
| 8 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
try:
|
| 10 |
from app.routers.nl2sql import _load_db_map
|
| 11 |
|
|
@@ -15,16 +34,78 @@ except Exception as e:
|
|
| 15 |
|
| 16 |
app = FastAPI(
|
| 17 |
title="NL2SQL Copilot Prototype",
|
| 18 |
-
version="0.1.0",
|
| 19 |
-
description="
|
| 20 |
)
|
| 21 |
|
| 22 |
app.include_router(nl2sql.router, prefix="/api/v1")
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
@app.
|
| 26 |
-
def
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
@app.get("/")
|
|
@@ -34,4 +115,11 @@ def root():
|
|
| 34 |
|
| 35 |
@app.get("/health")
|
| 36 |
def health():
|
|
|
|
| 37 |
return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from typing import Protocol, runtime_checkable, cast
|
| 4 |
|
| 5 |
+
from fastapi import FastAPI, Request, Response, HTTPException
|
| 6 |
+
from fastapi.responses import PlainTextResponse
|
| 7 |
|
| 8 |
+
from app.routers import nl2sql
|
|
|
|
| 9 |
|
| 10 |
+
# Prometheus
|
| 11 |
+
from prometheus_client import (
|
| 12 |
+
Counter,
|
| 13 |
+
Histogram,
|
| 14 |
+
CollectorRegistry,
|
| 15 |
+
generate_latest,
|
| 16 |
+
CONTENT_TYPE_LATEST,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@runtime_checkable
|
| 21 |
+
class HasPing(Protocol):
|
| 22 |
+
"""Minimal interface for adapters that support a connectivity check."""
|
| 23 |
+
|
| 24 |
+
def ping(self) -> None: ...
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---- Optionally restore uploaded DB map ----
|
| 28 |
try:
|
| 29 |
from app.routers.nl2sql import _load_db_map
|
| 30 |
|
|
|
|
| 34 |
|
| 35 |
app = FastAPI(
|
| 36 |
title="NL2SQL Copilot Prototype",
|
| 37 |
+
version=os.getenv("APP_VERSION", "0.1.0"),
|
| 38 |
+
description="Convert natural language to safe & verified SQL",
|
| 39 |
)
|
| 40 |
|
| 41 |
app.include_router(nl2sql.router, prefix="/api/v1")
|
| 42 |
|
| 43 |
+
# ---- Prometheus metrics ----
|
| 44 |
+
REGISTRY: CollectorRegistry = CollectorRegistry()
|
| 45 |
+
|
| 46 |
+
REQUEST_COUNT = Counter(
|
| 47 |
+
"http_requests_total",
|
| 48 |
+
"Total HTTP requests",
|
| 49 |
+
["path", "method", "status_code"],
|
| 50 |
+
registry=REGISTRY,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
REQUEST_LATENCY = Histogram(
|
| 54 |
+
"http_request_latency_seconds",
|
| 55 |
+
"Request latency",
|
| 56 |
+
["path", "method"],
|
| 57 |
+
registry=REGISTRY,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
|
| 61 |
+
@app.middleware("http")
|
| 62 |
+
async def metrics_middleware(request: Request, call_next):
|
| 63 |
+
start = time.perf_counter()
|
| 64 |
+
response: Response = await call_next(request)
|
| 65 |
+
elapsed = time.perf_counter() - start
|
| 66 |
+
|
| 67 |
+
# Use route path if available, else raw path (typed guard for mypy)
|
| 68 |
+
route = request.scope.get("route")
|
| 69 |
+
path = (
|
| 70 |
+
route.path
|
| 71 |
+
if (route is not None and hasattr(route, "path"))
|
| 72 |
+
else request.url.path
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
REQUEST_COUNT.labels(
|
| 76 |
+
path=path, method=request.method, status_code=str(response.status_code)
|
| 77 |
+
).inc()
|
| 78 |
+
REQUEST_LATENCY.labels(path=path, method=request.method).observe(elapsed)
|
| 79 |
+
return response
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# --- Liveness (super light) ---
|
| 83 |
+
@app.get("/healthz", response_class=PlainTextResponse, tags=["system"])
|
| 84 |
+
def healthz() -> str:
|
| 85 |
+
return "ok"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# --- Readiness (checks DB/env lightly) ---
|
| 89 |
+
@app.get("/readyz", response_class=PlainTextResponse, tags=["system"])
|
| 90 |
+
def readyz() -> str:
|
| 91 |
+
mode = os.getenv("DB_MODE", "sqlite").lower()
|
| 92 |
+
try:
|
| 93 |
+
if mode == "postgres":
|
| 94 |
+
from adapters.db.postgres_adapter import PostgresAdapter
|
| 95 |
+
|
| 96 |
+
dsn = os.environ["POSTGRES_DSN"]
|
| 97 |
+
# Call ping inline; avoid cross-branch variable typing
|
| 98 |
+
cast(HasPing, PostgresAdapter(dsn)).ping()
|
| 99 |
+
else:
|
| 100 |
+
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 101 |
+
|
| 102 |
+
db_path = os.getenv("SQLITE_DB_PATH", "data/chinook.db")
|
| 103 |
+
cast(HasPing, SQLiteAdapter(db_path)).ping()
|
| 104 |
+
|
| 105 |
+
# if not os.getenv("PROXY_API_KEY"): pass
|
| 106 |
+
return "ready"
|
| 107 |
+
except Exception:
|
| 108 |
+
raise HTTPException(status_code=503, detail="not ready")
|
| 109 |
|
| 110 |
|
| 111 |
@app.get("/")
|
|
|
|
| 115 |
|
| 116 |
@app.get("/health")
|
| 117 |
def health():
|
| 118 |
+
# You might want to replace the placeholders with real checks later.
|
| 119 |
return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@app.get("/metrics", tags=["system"])
|
| 123 |
+
def metrics():
|
| 124 |
+
data = generate_latest(REGISTRY)
|
| 125 |
+
return Response(content=data, media_type=CONTENT_TYPE_LATEST)
|
app/routers/nl2sql.py
CHANGED
|
@@ -285,11 +285,13 @@ async def upload_db(file: UploadFile = File(...)):
|
|
| 285 |
@router.post("", name="nl2sql_handler")
|
| 286 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 287 |
db_id = getattr(request, "db_id", None)
|
| 288 |
-
adapter: Optional[Union[PostgresAdapter, SQLiteAdapter]] = None
|
| 289 |
|
| 290 |
# 1) Pick pipeline (+ optional per-request adapter)
|
|
|
|
| 291 |
if db_id:
|
| 292 |
-
adapter = _select_adapter(db_id)
|
|
|
|
|
|
|
| 293 |
pipeline = _build_pipeline(adapter)
|
| 294 |
derived_preview_val: str = _derive_schema_preview(adapter)
|
| 295 |
else:
|
|
@@ -299,7 +301,6 @@ def nl2sql_handler(request: NL2SQLRequest):
|
|
| 299 |
# 2) Resolve schema_preview
|
| 300 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 301 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
| 302 |
-
|
| 303 |
final_preview: str = provided_preview or derived_preview_val
|
| 304 |
|
| 305 |
# 3) Run pipeline
|
|
|
|
| 285 |
@router.post("", name="nl2sql_handler")
|
| 286 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 287 |
db_id = getattr(request, "db_id", None)
|
|
|
|
| 288 |
|
| 289 |
# 1) Pick pipeline (+ optional per-request adapter)
|
| 290 |
+
pipeline: Pipeline
|
| 291 |
if db_id:
|
| 292 |
+
adapter = _select_adapter(db_id) # returns PostgresAdapter | SQLiteAdapter
|
| 293 |
+
# If _select_adapter could theoretically return None, uncomment the next line:
|
| 294 |
+
# assert adapter is not None, "adapter must be set when db_id is provided"
|
| 295 |
pipeline = _build_pipeline(adapter)
|
| 296 |
derived_preview_val: str = _derive_schema_preview(adapter)
|
| 297 |
else:
|
|
|
|
| 301 |
# 2) Resolve schema_preview
|
| 302 |
provided_preview_any: Any = getattr(request, "schema_preview", None)
|
| 303 |
provided_preview: Optional[str] = cast(Optional[str], provided_preview_any)
|
|
|
|
| 304 |
final_preview: str = provided_preview or derived_preview_val
|
| 305 |
|
| 306 |
# 3) Run pipeline
|
requirements.txt
CHANGED
|
@@ -9,6 +9,7 @@ pytest==8.3.3
|
|
| 9 |
python-dotenv==1.1.1
|
| 10 |
openai==2.6.1
|
| 11 |
psycopg[binary]~=3.2
|
|
|
|
| 12 |
ruff
|
| 13 |
gradio
|
| 14 |
sqlalchemy
|
|
|
|
| 9 |
python-dotenv==1.1.1
|
| 10 |
openai==2.6.1
|
| 11 |
psycopg[binary]~=3.2
|
| 12 |
+
prometheus-client>=0.20.0
|
| 13 |
ruff
|
| 14 |
gradio
|
| 15 |
sqlalchemy
|
tests/test_health_and_metrics.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_health_and_metrics.py
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
from app.main import app
|
| 4 |
+
|
| 5 |
+
client = TestClient(app)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_healthz_ok():
|
| 9 |
+
r = client.get("/healthz")
|
| 10 |
+
assert r.status_code == 200
|
| 11 |
+
assert r.text == "ok"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_metrics_exposes_prometheus():
|
| 15 |
+
# Hit one endpoint to bump counters
|
| 16 |
+
client.get("/healthz")
|
| 17 |
+
r = client.get("/metrics")
|
| 18 |
+
assert r.status_code == 200
|
| 19 |
+
assert "http_requests_total" in r.text
|