File size: 3,399 Bytes
e4b1ed6
 
 
 
 
 
 
 
 
 
 
c789799
e4b1ed6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c789799
e4b1ed6
 
 
 
c789799
 
e4b1ed6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""FastAPI app for live datacenter verification model inference."""

from __future__ import annotations

import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from src.datacenter_verification_api import __version__, build_info
from src.datacenter_verification_api.model_service import ModelService
from src.datacenter_verification_api.schemas import (
    BatchPredictRequest,
    BatchPredictResponse,
    HealthResponse,
    MetadataResponse,
    PredictRequest,
    PredictResponse,
)


# CORS origins used when the DCV_ALLOWED_ORIGINS override is unset or blank:
# the hosted github.io page, plus plain-HTTP local origins on ports 8000,
# 5173 and 7860, each reachable through both "localhost" and "127.0.0.1".
DEFAULT_ALLOWED_ORIGINS = [
    "https://idacy.github.io",
    *(
        f"http://{host}:{port}"
        for host in ("localhost", "127.0.0.1")
        for port in (8000, 5173, 7860)
    ),
]

# Process-wide state assigned by ``lifespan`` at startup. ``service`` holds
# the loaded model service (None until loaded or if loading failed), and
# ``load_error`` holds the failure text when loading raised, else None.
service: ModelService | None = None
load_error: str | None = None


def parse_allowed_origins() -> list[str]:
    """Return the list of CORS origins the API should accept.

    Reads the comma-separated ``DCV_ALLOWED_ORIGINS`` environment variable,
    trimming whitespace around each entry and dropping empty entries. When
    the variable is unset or contains only whitespace, the built-in
    ``DEFAULT_ALLOWED_ORIGINS`` are used instead.

    A value made only of separators (e.g. ``",,"``) is not blank, so it
    yields an empty list rather than the defaults.

    Returns:
        A fresh list on every call; mutating it never affects the module
        defaults or the result of later calls.
    """
    raw = os.getenv("DCV_ALLOWED_ORIGINS", "")
    if not raw.strip():
        # Copy: handing out the module-level list would let any caller that
        # appends/removes entries silently change the defaults process-wide.
        return list(DEFAULT_ALLOWED_ORIGINS)
    return [item.strip() for item in raw.split(",") if item.strip()]


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Build the model service when the application starts.

    Calls ``ModelService.from_env()`` and stores the result in the module
    global ``service``. A failure does not abort startup. Instead:

    * the exception is logged with its traceback;
    * ``service`` is left as None;
    * a description of the failure is stored in ``load_error``, which
      ``/health`` reports and which is appended to the 503 detail.
    """
    global service, load_error
    try:
        service = ModelService.from_env()
    except Exception as exc:  # pragma: no cover - exercised by deployment misconfiguration
        # Keep the full traceback in the server logs; only a short string is
        # exposed through the API.
        logging.getLogger(__name__).exception("failed to load model service")
        service = None
        # Some exceptions stringify to "" (e.g. ``RuntimeError()``). An empty
        # load_error is falsy and would read as "no error", so fall back to
        # the exception class name.
        load_error = str(exc) or type(exc).__name__
    else:
        load_error = None
    yield


app = FastAPI(
    lifespan=lifespan,
    version=__version__,
    title="Datacenter Verification Live Inference API",
)

# CORS wiring: if "*" appears anywhere in the configured origins, the whole
# list collapses to a single wildcard. Credentials are not allowed, and only
# GET/POST (plus OPTIONS preflight) are permitted.
allowed_origins = parse_allowed_origins()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=False,
    allow_origins=allowed_origins if "*" not in allowed_origins else ["*"],
)


def require_service() -> ModelService:
    """Return the loaded model service, or abort the request with HTTP 503.

    Raises:
        HTTPException: status 503 when no service is loaded. The detail is
            "model service is not loaded", followed by ": <load_error>"
            when a load error was recorded.
    """
    if service is not None:
        return service
    reason = f": {load_error}" if load_error else ""
    raise HTTPException(status_code=503, detail="model service is not loaded" + reason)


@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report build identity and whether the model service is loaded.

    This endpoint does not go through ``require_service``, so it still
    answers when loading failed. In that case ``status`` is ``"error"``,
    the model-specific fields fall back to None/0/False, and ``error``
    carries the recorded ``load_error``.
    """
    # Read the mutable global once so every field describes the same
    # instance. Use ``is not None`` throughout: mixing it with truthiness
    # checks could disagree for an object that defines __bool__/__len__.
    current = service
    loaded = current is not None
    build = build_info()
    return HealthResponse(
        status="ok" if loaded else "error",
        model_loaded=loaded,
        api_version=__version__,
        build_sha=build.sha,
        build_source=build.source,
        model_run_id=current.model_run_id if current is not None else None,
        dataset_id=current.dataset_id if current is not None else None,
        feature_count=len(current.feature_columns) if current is not None else 0,
        base_row_lookup_enabled=(
            bool(current.feature_lookup) if current is not None else False
        ),
        error=load_error,
    )


@app.get("/metadata", response_model=MetadataResponse)
def metadata() -> MetadataResponse:
    """Return the loaded model service's metadata (HTTP 503 if not loaded)."""
    model_service = require_service()
    return model_service.metadata()


@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest) -> PredictResponse:
    """Score a single request with the loaded model (HTTP 503 if not loaded)."""
    model_service = require_service()
    return model_service.predict(request)


@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(request: BatchPredictRequest) -> BatchPredictResponse:
    """Score every entry of ``request.requests`` with one model service.

    The service is resolved once, before any item is scored, so the whole
    batch uses the same instance (HTTP 503 if none is loaded). Items are
    scored in order, and the response echoes that instance's
    ``model_run_id``. An exception from any single item propagates and
    fails the whole request.
    """
    model_service = require_service()
    results = list(map(model_service.predict, request.requests))
    return BatchPredictResponse(
        model_run_id=model_service.model_run_id,
        predictions=results,
    )