File size: 1,062 Bytes
70ad082
 
b1d342c
70ad082
 
b1d342c
9624ca0
5cb5395
e46266e
b1d342c
 
6bb200e
5cb5395
 
 
b1d342c
 
e29208d
5cb5395
 
 
 
e29208d
e46266e
e29208d
e46266e
5cb5395
e46266e
 
e29208d
5cb5395
6bb200e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np

# Application and model bootstrap: the ONNX session is created once at
# import time and shared by every request handler.
app = FastAPI()

# CPU-only inference session for the model file shipped alongside this module.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Cache the graph's first input/output tensor names for use in session.run.
_first_input = session.get_inputs()[0]
input_name = _first_input.name
output_name = session.get_outputs()[0].name

# Optional: print model input shape for debugging
print("Model expects input shape:", _first_input.shape)

class InputData(BaseModel):
    """Request body for POST /predict: one flat feature vector (a single sample)."""

    # Flat list of floats; the handler wraps it in a batch dimension of 1
    # before inference, so the model sees shape (1, len(features)).
    features: list[float]

@app.get("/")
def root():
    """Health-check endpoint that tells callers how to use the service."""
    message = "ONNX model API is running. Use POST /predict with a JSON payload."
    return {"message": message}

@app.post("/predict")
async def predict(data: InputData):
    """Run the ONNX model on a single feature vector.

    Args:
        data: Request body whose ``features`` is a flat list of floats
            forming one sample; a batch dimension of 1 is added here.

    Returns:
        dict with key ``"prediction"`` holding the model output converted
        to nested Python lists.

    Raises:
        HTTPException: 400 if ``features`` is empty; 500 if inference fails.
    """
    # Reject empty input up front with a client error rather than letting
    # session.run fail deep inside the runtime and surface as an opaque 500.
    if not data.features:
        raise HTTPException(status_code=400, detail="features must be a non-empty list")

    # The model expects a batch: wrap the single sample in an outer dimension.
    arr = np.array([data.features], dtype=np.float32)
    print("Received input shape:", arr.shape)

    try:
        # Keep the try body minimal: only the call that can actually fail here.
        pred = session.run([output_name], {input_name: arr})
    except Exception as e:  # inference boundary: translate any failure to HTTP 500
        print("Prediction error:", str(e))
        # Chain the cause so the original traceback is preserved in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"prediction": pred[0].tolist()}