# FastAPI service that serves predictions from a local ONNX model.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
import numpy as np

app = FastAPI()

# Load the ONNX model once at import time (file expected next to this script).
# CPUExecutionProvider keeps the service runnable on hosts without a GPU.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Optional: print model input shape for debugging
print("Model expects input shape:", session.get_inputs()[0].shape)
class InputData(BaseModel):
    """Request body for the prediction endpoint: one sample's feature vector."""

    # Flat list of floats; wrapped into a (1, n_features) float32 batch before
    # inference. Presumably its length must match the model's expected input
    # size printed at startup — TODO confirm against model.onnx.
    features: list[float]
@app.get("/")
def root():
    """Landing/health endpoint.

    Without the route decorator this handler was never registered with the
    app, even though its own message advertises the API — restored here.
    """
    return {"message": "ONNX model API is running. Use POST /predict with a JSON payload."}
@app.post("/predict")
async def predict(data: InputData):
    """Run the ONNX model on a single feature vector.

    Wraps ``data.features`` in a batch of size 1 (float32), runs the session,
    and returns the first model output as a nested list under ``prediction``.
    Any inference failure is reported to the client as HTTP 500 with the
    underlying error message.

    Note: the route decorator was missing, so this handler was never
    registered with the app — restored here.
    """
    try:
        # Model expects a float32 batch: shape (1, n_features).
        arr = np.array([data.features], dtype=np.float32)
        print("Received input shape:", arr.shape)
        pred = session.run([output_name], {input_name: arr})
        return {"prediction": pred[0].tolist()}
    except Exception as e:
        # Boundary handler: surface inference errors as a 500 response,
        # chaining the original exception for server-side tracebacks.
        print("Prediction error:", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e