madhwdh11 commited on
Commit
5cb5395
·
verified ·
1 Parent(s): 4822fef

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -1
main.py CHANGED
@@ -5,19 +5,28 @@ import numpy as np
5
 
6
  app = FastAPI()
7
 
8
- # Load the ONNX model (in same folder as this script)
9
  session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
10
  input_name = session.get_inputs()[0].name
11
  output_name = session.get_outputs()[0].name
12
 
 
 
 
13
  class InputData(BaseModel):
14
  features: list[float]
15
 
 
 
 
 
16
  @app.post("/predict")
17
  async def predict(data: InputData):
18
  try:
19
  arr = np.array([data.features], dtype=np.float32)
 
20
  pred = session.run([output_name], {input_name: arr})
21
  return {"prediction": pred[0].tolist()}
22
  except Exception as e:
 
23
  raise HTTPException(status_code=500, detail=str(e))
 
5
 
6
  app = FastAPI()
7
 
8
+ # Load the ONNX model (assumed in the same folder)
9
  session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
10
  input_name = session.get_inputs()[0].name
11
  output_name = session.get_outputs()[0].name
12
 
13
+ # Optional: print model input shape for debugging
14
+ print("Model expects input shape:", session.get_inputs()[0].shape)
15
+
16
  class InputData(BaseModel):
17
  features: list[float]
18
 
19
+ @app.get("/")
20
+ def root():
21
+ return {"message": "ONNX model API is running. Use POST /predict with a JSON payload."}
22
+
23
  @app.post("/predict")
24
  async def predict(data: InputData):
25
  try:
26
  arr = np.array([data.features], dtype=np.float32)
27
+ print("Received input shape:", arr.shape)
28
  pred = session.run([output_name], {input_name: arr})
29
  return {"prediction": pred[0].tolist()}
30
  except Exception as e:
31
+ print("Prediction error:", str(e))
32
  raise HTTPException(status_code=500, detail=str(e))