| """FastAPI app for live datacenter verification model inference.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from contextlib import asynccontextmanager |
| from typing import AsyncIterator |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from src.datacenter_verification_api import __version__, build_info |
| from src.datacenter_verification_api.model_service import ModelService |
| from src.datacenter_verification_api.schemas import ( |
| BatchPredictRequest, |
| BatchPredictResponse, |
| HealthResponse, |
| MetadataResponse, |
| PredictRequest, |
| PredictResponse, |
| ) |
|
|
|
|
# Origins handed to the CORS middleware when the DCV_ALLOWED_ORIGINS
# environment variable is unset or blank (see parse_allowed_origins).
DEFAULT_ALLOWED_ORIGINS = [
    "https://idacy.github.io",
    # Loopback origins on ports 8000, 5173 and 7860 -- presumably local
    # development servers; confirm before changing.
    "http://localhost:8000",
    "http://localhost:5173",
    "http://localhost:7860",
    "http://127.0.0.1:8000",
    "http://127.0.0.1:5173",
    "http://127.0.0.1:7860",
]


# Model service shared by the request handlers. None until lifespan() has
# loaded it, and reset to None when loading fails.
service: ModelService | None = None
# Text of the exception raised by the last failed load attempt; None after
# a successful load.
load_error: str | None = None
|
|
|
|
def parse_allowed_origins() -> list[str]:
    """Return the CORS origins configured for this process.

    The ``DCV_ALLOWED_ORIGINS`` environment variable is read as a
    comma-separated list. Whitespace around each entry is stripped and
    empty entries are dropped. When the variable is unset or contains only
    whitespace, the built-in defaults are used instead.

    Returns:
        A new list on every call, so callers may modify the result without
        altering ``DEFAULT_ALLOWED_ORIGINS``.
    """
    raw = os.getenv("DCV_ALLOWED_ORIGINS", "")
    if not raw.strip():
        # Copy: handing out the module-level list would let any caller
        # mutation silently change the defaults for everyone else.
        return list(DEFAULT_ALLOWED_ORIGINS)
    stripped = (item.strip() for item in raw.split(","))
    return [origin for origin in stripped if origin]
|
|
|
|
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Load the model service once at application startup.

    A load failure is deliberately not fatal: the application still starts
    so that ``/health`` can report the problem, while ``require_service()``
    rejects requests that need the model. The failure is logged with its
    traceback and its message is kept in ``load_error``.
    """
    import logging

    global service, load_error
    try:
        loaded = ModelService.from_env()
    except Exception as exc:
        # Best-effort startup boundary: record and log instead of crashing.
        logging.getLogger(__name__).exception("model service failed to load")
        service = None
        # str(exc) is empty for exceptions raised without a message; fall
        # back to the class name so the recorded error is never blank.
        load_error = str(exc) or type(exc).__name__
    else:
        service = loaded
        load_error = None
    yield
|
|
|
|
# ASGI application object; model loading is delegated to the lifespan handler.
app = FastAPI(
    lifespan=lifespan,
    version=__version__,
    title="Datacenter Verification Live Inference API",
)


# A "*" entry anywhere in the configured list collapses it to a bare wildcard.
allowed_origins = parse_allowed_origins()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=False,
    allow_origins=allowed_origins if "*" not in allowed_origins else ["*"],
)
|
|
|
|
def require_service() -> ModelService:
    """Return the loaded model service, or fail the request with HTTP 503.

    Raises:
        HTTPException: status 503 when no service is loaded. The detail
            text includes the recorded load error when there is one.
    """
    if service is not None:
        return service

    message = "model service is not loaded"
    if load_error:
        message = f"{message}: {load_error}"
    raise HTTPException(status_code=503, detail=message)
|
|
|
|
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report API build information and whether the model is loaded.

    This endpoint does not go through ``require_service()``, so it also
    answers when the model failed to load. In that case ``status`` is
    ``"error"``, the model fields take their empty values and ``error``
    carries the recorded load error.
    """
    # Read the global once and decide with a single ``is None`` check, so
    # ``model_loaded`` and the model fields can never disagree.
    current = service
    build = build_info()

    if current is None:
        model_run_id = None
        dataset_id = None
        feature_count = 0
        lookup_enabled = False
    else:
        model_run_id = current.model_run_id
        dataset_id = current.dataset_id
        feature_count = len(current.feature_columns)
        lookup_enabled = bool(current.feature_lookup)

    loaded = current is not None
    return HealthResponse(
        status="ok" if loaded else "error",
        model_loaded=loaded,
        api_version=__version__,
        build_sha=build.sha,
        build_source=build.source,
        model_run_id=model_run_id,
        dataset_id=dataset_id,
        feature_count=feature_count,
        base_row_lookup_enabled=lookup_enabled,
        error=load_error,
    )
|
|
|
|
@app.get("/metadata", response_model=MetadataResponse)
def metadata() -> MetadataResponse:
    """Return the loaded model's metadata (HTTP 503 if no model is loaded)."""
    model = require_service()
    return model.metadata()
|
|
|
|
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest) -> PredictResponse:
    """Run one prediction on the loaded model (HTTP 503 if none is loaded)."""
    model = require_service()
    return model.predict(request)
|
|
|
|
@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(request: BatchPredictRequest) -> BatchPredictResponse:
    """Run every request in the batch through the loaded model, in order.

    The service is resolved once up front, so an unloaded model fails the
    whole batch with HTTP 503 before any prediction is attempted.
    """
    model = require_service()
    results = []
    for item in request.requests:
        results.append(model.predict(item))
    return BatchPredictResponse(predictions=results, model_run_id=model.model_run_id)
|
|