"""FastAPI app for live datacenter verification model inference."""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from src.datacenter_verification_api import __version__, build_info
from src.datacenter_verification_api.model_service import ModelService
from src.datacenter_verification_api.schemas import (
BatchPredictRequest,
BatchPredictResponse,
HealthResponse,
MetadataResponse,
PredictRequest,
PredictResponse,
)
# Origins handed to the CORS middleware when the DCV_ALLOWED_ORIGINS
# environment variable is unset or blank (see parse_allowed_origins).
DEFAULT_ALLOWED_ORIGINS: list[str] = [
    "https://idacy.github.io",
    # Loopback origins on ports 8000, 5173 and 7860 -- presumably local
    # development front-ends; confirm before changing.
    "http://localhost:8000",
    "http://localhost:5173",
    "http://localhost:7860",
    "http://127.0.0.1:8000",
    "http://127.0.0.1:5173",
    "http://127.0.0.1:7860",
]
# Model service shared by the request handlers. Assigned by lifespan() at
# startup; stays None when loading raised.
service: ModelService | None = None
# str() of the exception raised by the last failed load, otherwise None.
load_error: str | None = None
def parse_allowed_origins() -> list[str]:
    """Return the CORS origin allow-list configured for this process.

    The ``DCV_ALLOWED_ORIGINS`` environment variable is read as a
    comma-separated list. Whitespace around each entry is trimmed and empty
    entries are dropped.

    Returns:
        The parsed origins, or a fresh copy of ``DEFAULT_ALLOWED_ORIGINS``
        when the variable is unset, blank, or holds nothing but separators.
    """
    raw = os.getenv("DCV_ALLOWED_ORIGINS", "")
    origins = [item.strip() for item in raw.split(",") if item.strip()]
    if origins:
        return origins
    # A value such as "," parses to nothing; treat it like a blank value
    # instead of silently producing an empty allow-list. Return a copy so a
    # caller mutating the result cannot alter the module-level default.
    return list(DEFAULT_ALLOWED_ORIGINS)
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Load the model service once when the application starts.

    A failed load is not re-raised: the exception text is stored in
    ``load_error`` and ``service`` is left as ``None``, so the application
    still starts. Nothing runs after the ``yield``, i.e. there is no
    shutdown work.
    """
    global service, load_error
    try:
        loaded_service = ModelService.from_env()
    except Exception as exc:  # pragma: no cover - exercised by deployment misconfiguration
        service = None
        load_error = str(exc)
    else:
        service = loaded_service
        load_error = None
    yield
# Application object; the lifespan hook performs the model load at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Datacenter Verification Live Inference API",
    version=__version__,
)

# Resolved once at import time from the environment (or the defaults).
allowed_origins = parse_allowed_origins()

app.add_middleware(
    CORSMiddleware,
    # A "*" anywhere in the configured list collapses it to the wildcard.
    allow_origins=allowed_origins if "*" not in allowed_origins else ["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
def require_service() -> ModelService:
    """Return the loaded model service or fail the request.

    Returns:
        The module-level ``service`` when it has been loaded.

    Raises:
        HTTPException: status 503 when ``service`` is ``None``. The detail
            includes ``load_error`` when one was recorded.
    """
    if service is not None:
        return service
    reason = "model service is not loaded"
    if load_error:
        reason = f"{reason}: {load_error}"
    raise HTTPException(status_code=503, detail=reason)
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report build information and whether the model service is loaded.

    This handler does not raise when the model is missing. It reports
    ``status="error"`` with ``model_loaded=False`` and empty model fields,
    alongside any recorded ``load_error``.
    """
    # Read the global once so every field describes the same service object,
    # and test identity against None so the model fields always agree with
    # ``model_loaded`` regardless of the service object's truth value.
    current = service
    build = build_info()
    if current is None:
        model_fields = {
            "model_run_id": None,
            "dataset_id": None,
            "feature_count": 0,
            "base_row_lookup_enabled": False,
        }
    else:
        model_fields = {
            "model_run_id": current.model_run_id,
            "dataset_id": current.dataset_id,
            "feature_count": len(current.feature_columns),
            "base_row_lookup_enabled": bool(current.feature_lookup),
        }
    loaded = current is not None
    return HealthResponse(
        status="ok" if loaded else "error",
        model_loaded=loaded,
        api_version=__version__,
        build_sha=build.sha,
        build_source=build.source,
        error=load_error,
        **model_fields,
    )
@app.get("/metadata", response_model=MetadataResponse)
def metadata() -> MetadataResponse:
    """Return the metadata reported by the loaded model service.

    Raises:
        HTTPException: status 503, raised by ``require_service`` when no
            model service is loaded.
    """
    model_service = require_service()
    return model_service.metadata()
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest) -> PredictResponse:
    """Run a single prediction request through the loaded model service.

    Raises:
        HTTPException: status 503, raised by ``require_service`` when no
            model service is loaded.
    """
    model_service = require_service()
    return model_service.predict(request)
@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(request: BatchPredictRequest) -> BatchPredictResponse:
    """Run every item of a batch request through the loaded model service.

    Items in ``request.requests`` are predicted one after another, in order.
    No per-item error handling is applied, so an exception from any item
    propagates and fails the whole batch.

    Raises:
        HTTPException: status 503, raised by ``require_service`` when no
            model service is loaded.
    """
    model_service = require_service()
    results = list(map(model_service.predict, request.requests))
    return BatchPredictResponse(
        model_run_id=model_service.model_run_id,
        predictions=results,
    )
|