"""FastAPI app for live datacenter verification model inference.""" from __future__ import annotations import os from contextlib import asynccontextmanager from typing import AsyncIterator from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from src.datacenter_verification_api import __version__, build_info from src.datacenter_verification_api.model_service import ModelService from src.datacenter_verification_api.schemas import ( BatchPredictRequest, BatchPredictResponse, HealthResponse, MetadataResponse, PredictRequest, PredictResponse, ) DEFAULT_ALLOWED_ORIGINS = [ "https://idacy.github.io", "http://localhost:8000", "http://localhost:5173", "http://localhost:7860", "http://127.0.0.1:8000", "http://127.0.0.1:5173", "http://127.0.0.1:7860", ] service: ModelService | None = None load_error: str | None = None def parse_allowed_origins() -> list[str]: raw = os.getenv("DCV_ALLOWED_ORIGINS", "") if not raw.strip(): return DEFAULT_ALLOWED_ORIGINS return [item.strip() for item in raw.split(",") if item.strip()] @asynccontextmanager async def lifespan(_: FastAPI) -> AsyncIterator[None]: global service, load_error try: service = ModelService.from_env() load_error = None except Exception as exc: # pragma: no cover - exercised by deployment misconfiguration service = None load_error = str(exc) yield app = FastAPI( title="Datacenter Verification Live Inference API", version=__version__, lifespan=lifespan, ) allowed_origins = parse_allowed_origins() app.add_middleware( CORSMiddleware, allow_origins=["*"] if "*" in allowed_origins else allowed_origins, allow_credentials=False, allow_methods=["GET", "POST", "OPTIONS"], allow_headers=["*"], ) def require_service() -> ModelService: if service is None: detail = "model service is not loaded" if load_error: detail = f"{detail}: {load_error}" raise HTTPException(status_code=503, detail=detail) return service @app.get("/health", response_model=HealthResponse) def health() -> HealthResponse: loaded = service is not None build = build_info() return HealthResponse( status="ok" if loaded else "error", model_loaded=loaded, api_version=__version__, build_sha=build.sha, build_source=build.source, model_run_id=service.model_run_id if service else None, dataset_id=service.dataset_id if service else None, feature_count=len(service.feature_columns) if service else 0, base_row_lookup_enabled=bool(service.feature_lookup) if service else False, error=load_error, ) @app.get("/metadata", response_model=MetadataResponse) def metadata() -> MetadataResponse: return require_service().metadata() @app.post("/predict", response_model=PredictResponse) def predict(request: PredictRequest) -> PredictResponse: return require_service().predict(request) @app.post("/predict/batch", response_model=BatchPredictResponse) def predict_batch(request: BatchPredictRequest) -> BatchPredictResponse: loaded = require_service() predictions = [loaded.predict(item) for item in request.requests] return BatchPredictResponse(model_run_id=loaded.model_run_id, predictions=predictions)