Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,10 +33,11 @@ ort_session = ort.InferenceSession("densenet-9.onnx")
|
|
| 33 |
def predict(pil):
|
| 34 |
input_tensor = preprocess(pil)
|
| 35 |
img_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
|
|
|
| 36 |
|
| 37 |
outputs = ort_session.run(
|
| 38 |
None,
|
| 39 |
-
{"data_0":
|
| 40 |
)
|
| 41 |
|
| 42 |
a = np.argsort(-outputs[0].flatten())
|
|
|
|
| 33 |
def predict(pil):
|
| 34 |
input_tensor = preprocess(pil)
|
| 35 |
img_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
| 36 |
+
img_batch_np = img_batch.cpu().detach().numpy()
|
| 37 |
|
| 38 |
outputs = ort_session.run(
|
| 39 |
None,
|
| 40 |
+
{"data_0": img_batch_np.astype(np.float32)},
|
| 41 |
)
|
| 42 |
|
| 43 |
a = np.argsort(-outputs[0].flatten())
|