idacy's picture
Fix live inference slider regression
c789799 verified
Raw
History Blame Contribute Delete
3.4 kB
"""FastAPI app for live datacenter verification model inference."""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from src.datacenter_verification_api import __version__, build_info
from src.datacenter_verification_api.model_service import ModelService
from src.datacenter_verification_api.schemas import (
BatchPredictRequest,
BatchPredictResponse,
HealthResponse,
MetadataResponse,
PredictRequest,
PredictResponse,
)
# CORS origins used when the DCV_ALLOWED_ORIGINS environment variable is unset
# or blank (see parse_allowed_origins).
DEFAULT_ALLOWED_ORIGINS = [
    "https://idacy.github.io",
    "http://localhost:8000",
    "http://localhost:5173",
    "http://localhost:7860",
    "http://127.0.0.1:8000",
    "http://127.0.0.1:5173",
    "http://127.0.0.1:7860",
]

# Module-level state written by `lifespan` and read by the route handlers.
# `service` is the ModelService built by ModelService.from_env(), or None when
# it has not been built or building it failed.
service: ModelService | None = None
# str() of the exception raised while building the service; None after a
# successful load.
load_error: str | None = None
def parse_allowed_origins() -> list[str]:
    """Return the CORS origins configured for this process.

    Reads the ``DCV_ALLOWED_ORIGINS`` environment variable as a
    comma-separated list. Whitespace around each entry is stripped and empty
    entries are dropped. When the variable is unset or contains only
    whitespace, the entries of ``DEFAULT_ALLOWED_ORIGINS`` are used instead.

    Returns:
        A new list on every call, so callers may mutate the result without
        altering ``DEFAULT_ALLOWED_ORIGINS``.
    """
    raw = os.getenv("DCV_ALLOWED_ORIGINS", "")
    if not raw.strip():
        # Return a copy: handing out the module-level list itself would let a
        # caller's mutation silently change the defaults for later calls.
        return list(DEFAULT_ALLOWED_ORIGINS)
    return [item.strip() for item in raw.split(",") if item.strip()]
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Build the model service before the app starts handling requests.

    A failure to build the service is not propagated: the module-level
    ``service`` is left as ``None`` and the exception text is stored in
    ``load_error`` so the handlers can report it. Nothing is done after the
    ``yield``.
    """
    global service, load_error
    try:
        built = ModelService.from_env()
    except Exception as exc:  # pragma: no cover - exercised by deployment misconfiguration
        service = None
        load_error = str(exc)
    else:
        service = built
        load_error = None
    yield
# The ASGI application object; `lifespan` is registered as its lifespan
# handler and is where the model service gets built.
app = FastAPI(
    title="Datacenter Verification Live Inference API",
    version=__version__,
    lifespan=lifespan,
)

# Resolved once, at import time, from DCV_ALLOWED_ORIGINS or the defaults.
allowed_origins = parse_allowed_origins()
app.add_middleware(
    CORSMiddleware,
    # A "*" entry anywhere in the configured list collapses the whole list to
    # the single wildcard origin.
    allow_origins=["*"] if "*" in allowed_origins else allowed_origins,
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
def require_service() -> ModelService:
    """Return the loaded model service or fail the request with HTTP 503.

    Raises:
        HTTPException: status 503 when ``service`` is ``None``. The detail
            message has ``load_error`` appended when one was recorded.
    """
    if service is not None:
        return service

    detail = "model service is not loaded"
    if load_error:
        detail = f"{detail}: {load_error}"
    raise HTTPException(status_code=503, detail=detail)
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report API version, build identity, and model-load status.

    This handler does not go through ``require_service``, so it still
    produces a response when the model service failed to load. In that case
    ``status`` is ``"error"``, the model fields carry empty values, and
    ``error`` holds the recorded ``load_error``.
    """
    # Snapshot the global once so every field describes the same object, and
    # use the same `is None` test as require_service rather than truthiness:
    # a service object that happened to be falsy must not be reported as
    # loaded while its fields are blanked out.
    current = service
    if current is None:
        loaded = False
        model_run_id = None
        dataset_id = None
        feature_count = 0
        lookup_enabled = False
    else:
        loaded = True
        model_run_id = current.model_run_id
        dataset_id = current.dataset_id
        feature_count = len(current.feature_columns)
        lookup_enabled = bool(current.feature_lookup)

    build = build_info()
    return HealthResponse(
        status="ok" if loaded else "error",
        model_loaded=loaded,
        api_version=__version__,
        build_sha=build.sha,
        build_source=build.source,
        model_run_id=model_run_id,
        dataset_id=dataset_id,
        feature_count=feature_count,
        base_row_lookup_enabled=lookup_enabled,
        error=load_error,
    )
@app.get("/metadata", response_model=MetadataResponse)
def metadata() -> MetadataResponse:
    """Return the loaded model service's metadata.

    Raises:
        HTTPException: status 503 (from ``require_service``) when no model
            service is loaded.
    """
    loaded = require_service()
    return loaded.metadata()
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest) -> PredictResponse:
    """Run a single prediction through the loaded model service.

    Raises:
        HTTPException: status 503 (from ``require_service``) when no model
            service is loaded.
    """
    loaded = require_service()
    return loaded.predict(request)
@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(request: BatchPredictRequest) -> BatchPredictResponse:
    """Run every item of a batch request through the loaded model service.

    Items are predicted one at a time, in the order given by
    ``request.requests``; an exception from any item propagates and no
    partial result is returned.

    Raises:
        HTTPException: status 503 (from ``require_service``) when no model
            service is loaded.
    """
    model = require_service()
    results = []
    for single in request.requests:
        results.append(model.predict(single))
    return BatchPredictResponse(
        model_run_id=model.model_run_id,
        predictions=results,
    )