File size: 1,917 Bytes
7d2a23e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a290f8
 
 
 
 
7d2a23e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
from pathlib import Path

from fastapi import FastAPI, HTTPException

from .registry import BundleConfigError, ModelRegistry, RequestValidationError
from .schemas import PredictionRequest


def create_app(bundle_root: Path) -> FastAPI:
    """Build the FastAPI application that serves the bundled ML models.

    A single ``ModelRegistry`` is constructed from ``bundle_root`` when the
    app is created and is shared by every route handler defined below.

    Args:
        bundle_root: Directory handed to ``ModelRegistry``; its string form
            is also reported by the ``/health`` endpoint.

    Returns:
        A configured ``FastAPI`` instance exposing ``/``, ``/health``,
        ``/models`` and ``/predict``.
    """
    # FastAPI's default HTTPException handler only returns the JSON error
    # response; without explicit logging, the traceback of the underlying
    # failure behind a 500 would not be recorded anywhere server-side.
    logger = logging.getLogger(__name__)

    registry = ModelRegistry(bundle_root)

    app = FastAPI(
        title="SQuADDS ML Inference API",
        version="0.1.0",
        description=(
            "HTTP API for running inference against ML models trained in "
            "ML_qubit_design and packaged for the SQuADDS Hugging Face Space."
        ),
    )

    # Landing payload that points clients at the docs and main endpoints.
    # (Handlers use comments rather than docstrings so the generated
    # OpenAPI descriptions stay unchanged.)
    @app.get("/")
    def root() -> dict:
        return {
            "service": "SQuADDS ML Inference API",
            "docs": "/docs",
            "models_endpoint": "/models",
            "predict_endpoint": "/predict",
        }

    # Status check that also reports the model ids the registry exposes and
    # the bundle directory the app was created with.
    @app.get("/health")
    def health() -> dict:
        return {
            "status": "ok",
            "available_models": registry.available_model_ids(),
            "bundle_root": str(bundle_root),
        }

    # Per-model descriptions exactly as produced by the registry.
    @app.get("/models")
    def list_models() -> dict:
        return {"models": registry.describe_models()}

    # Run inference for one request and map registry errors to HTTP codes:
    #   RequestValidationError -> 400 (the registry's own exception imported
    #                                  from .registry, not FastAPI's class of
    #                                  the same name)
    #   BundleConfigError      -> 500 (server-side bundle problem)
    #   anything else          -> 500
    @app.post("/predict")
    def predict(request: PredictionRequest) -> dict:
        try:
            payload = registry.predict(
                model_id=request.model_id,
                inputs=request.inputs,
                include_scaled_outputs=request.options.include_scaled_outputs,
            )
        except RequestValidationError as exc:
            # Client error: the message is returned as-is, no traceback logged.
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except BundleConfigError as exc:
            logger.exception(
                "Bundle configuration error for model %r", request.model_id
            )
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        except Exception as exc:
            # Top-level boundary for the endpoint: record the full traceback
            # before converting the failure into a generic 500 response.
            logger.exception(
                "Unexpected inference error for model %r", request.model_id
            )
            raise HTTPException(
                status_code=500,
                detail=f"Unexpected inference error for model '{request.model_id}': {exc}",
            ) from exc
        return payload

    return app