# FastAPI service exposing an ONNX model for inference.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np
# FastAPI application instance; the route handlers below register on it.
app = FastAPI()
# Load the ONNX model (assumed in the same folder)
# NOTE(review): loads at import time, so startup fails fast if model.onnx is
# missing. CPUExecutionProvider pins inference to the CPU.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
# Cache the model's first input/output tensor names for session.run() calls.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# Optional: print model input shape for debugging
print("Model expects input shape:", session.get_inputs()[0].shape)
class InputData(BaseModel):
    """Request body for POST /predict: a flat list of model features."""

    # assumes the model takes a single 1-D float vector — TODO confirm
    # against the model input shape printed at startup.
    features: list[float]
@app.get("/")
def root():
    """Health/usage endpoint confirming the service is running."""
    return {"message": "ONNX model API is running. Use POST /predict with a JSON payload."}
@app.post("/predict")
async def predict(data: InputData):
    """Run the ONNX model on a single feature vector.

    Args:
        data: Request body holding the flat feature list.

    Returns:
        dict with key "prediction": the model's first output converted to
        nested Python lists.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            inference fails (e.g. the feature vector does not match the
            model's expected input shape).
    """
    try:
        # Add a leading batch dimension of 1; ONNX inputs here are float32.
        arr = np.array([data.features], dtype=np.float32)
        print("Received input shape:", arr.shape)
        pred = session.run([output_name], {input_name: arr})
        return {"prediction": pred[0].tolist()}
    except Exception as e:
        # Boundary handler: surface any inference failure as a 500 so the
        # client sees the reason; chain the cause for server-side tracebacks.
        print("Prediction error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e