from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np

app = FastAPI()

# Load the ONNX model once at import time (file assumed in the working
# directory). CPUExecutionProvider keeps this portable to machines
# without GPU support.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Optional: print model input shape for debugging
print("Model expects input shape:", session.get_inputs()[0].shape)


class InputData(BaseModel):
    # Flat feature vector for a single sample; its length must match the
    # model's expected input width (see the shape printed at startup).
    features: list[float]


@app.get("/")
def root():
    """Health/landing endpoint describing how to use the API."""
    return {"message": "ONNX model API is running. Use POST /predict with a JSON payload."}


@app.post("/predict")
def predict(data: InputData):
    """Run the ONNX model on one feature vector and return its prediction.

    Declared as a plain ``def`` (not ``async def``) so FastAPI executes the
    blocking ``session.run`` call in its threadpool instead of stalling the
    event loop; the HTTP interface is unchanged.

    Raises:
        HTTPException: 400 when ``features`` is empty, 500 on any
            inference failure (detail carries the underlying error).
    """
    if not data.features:
        # An empty vector can never match the model input; reject it as a
        # client error rather than letting onnxruntime fail with a 500.
        raise HTTPException(status_code=400, detail="features must be a non-empty list")
    try:
        # Shape (1, n_features): the model expects a leading batch dimension.
        arr = np.array([data.features], dtype=np.float32)
        print("Received input shape:", arr.shape)
        pred = session.run([output_name], {input_name: arr})
        return {"prediction": pred[0].tolist()}
    except Exception as e:
        # Boundary handler: surface the failure as an HTTP 500, chaining
        # the original exception for accurate tracebacks in logs.
        print("Prediction error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e