# stroke_detect / main.py
# madhwdh11 — "Update main.py" (commit 5cb5395, verified)
from pathlib import Path

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np
app = FastAPI()

# The model file lives next to this script. Resolve it relative to this file
# rather than the process working directory, so the service starts no matter
# which directory uvicorn is launched from.
MODEL_PATH = Path(__file__).resolve().parent / "model.onnx"

# Load the ONNX model once at import time; every request reuses this session.
# Only the CPU execution provider is requested.
session = ort.InferenceSession(str(MODEL_PATH), providers=["CPUExecutionProvider"])

# The /predict endpoint feeds one input tensor and fetches one output, so only
# the first declared graph input and output are used.
_model_input = session.get_inputs()[0]
input_name = _model_input.name
output_name = session.get_outputs()[0].name

# Optional: print model input shape for debugging
print("Model expects input shape:", _model_input.shape)
class InputData(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Carries a single flat feature vector; the predict endpoint wraps it into
    a batch of one before handing it to the model.
    """

    # Numeric feature values, validated/coerced to floats by pydantic.
    features: list[float]
@app.get("/")
def root():
    """Liveness endpoint that tells callers how to reach the model."""
    usage_hint = "ONNX model API is running. Use POST /predict with a JSON payload."
    response = {"message": usage_hint}
    return response
@app.post("/predict")
def predict(data: InputData):
    """Run the ONNX model on one feature vector and return its first output.

    This is a plain ``def`` (not ``async def``) on purpose: ``session.run`` is
    a blocking call, and FastAPI runs synchronous path operations in a worker
    threadpool, so inference no longer stalls the event loop for other requests.

    Args:
        data: Request body. ``data.features`` is wrapped into a float32 array
            of shape ``(1, len(features))``, i.e. a batch of one.

    Returns:
        ``{"prediction": ...}``, where the value is the requested model output
        converted to (nested) Python lists.

    Raises:
        HTTPException: 422 if ``features`` is empty or its length does not
            match a fixed-width 2-D model input; 500 if inference fails.
    """
    # Validate before the try block so these client errors are not caught by
    # the broad handler below and re-reported as 500s.
    if not data.features:
        raise HTTPException(status_code=422, detail="features must not be empty")

    declared_shape = session.get_inputs()[0].shape
    # Dynamic dims are reported as None/str by ONNX Runtime, so only enforce
    # the width when the model declares a concrete positive int for it.
    if (
        len(declared_shape) == 2
        and isinstance(declared_shape[1], int)
        and declared_shape[1] > 0
        and len(data.features) != declared_shape[1]
    ):
        raise HTTPException(
            status_code=422,
            detail=(
                f"expected {declared_shape[1]} features, "
                f"got {len(data.features)}"
            ),
        )

    arr = np.array([data.features], dtype=np.float32)
    print("Received input shape:", arr.shape)
    try:
        pred = session.run([output_name], {input_name: arr})
        return {"prediction": pred[0].tolist()}
    except Exception as e:
        # Top-level boundary: report any runtime/model failure as a 500,
        # preserving the original exception as the cause.
        print("Prediction error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e