Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
import mxnet
|
| 10 |
-
from
|
| 11 |
|
| 12 |
os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
|
| 13 |
|
|
@@ -30,28 +30,21 @@ def get_image(path):
|
|
| 30 |
img = np.array(img.convert('RGB'))
|
| 31 |
return img
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
transforms.CenterCrop(224),
|
| 41 |
-
transforms.ToTensor(),
|
| 42 |
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 43 |
-
])
|
| 44 |
-
img = mxnet.ndarray.array(img)
|
| 45 |
-
img = transform_fn(img)
|
| 46 |
-
img = img.expand_dims(axis=0) # batchify
|
| 47 |
|
| 48 |
-
return img.asnumpy()
|
| 49 |
|
| 50 |
|
| 51 |
def predict(path):
|
| 52 |
img = get_image(path)
|
| 53 |
-
|
| 54 |
-
|
|
|
|
| 55 |
preds = session.run(None, ort_inputs)[0]
|
| 56 |
preds = np.squeeze(preds)
|
| 57 |
a = np.argsort(preds)
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
import mxnet
|
| 10 |
+
from torchvision import transforms
|
| 11 |
|
| 12 |
os.system("wget https://s3.amazonaws.com/onnx-model-zoo/synset.txt")
|
| 13 |
|
|
|
|
| 30 |
img = np.array(img.convert('RGB'))
|
| 31 |
return img
|
| 32 |
|
| 33 |
+
preprocess = transforms.Compose([
|
| 34 |
+
transforms.Resize(256),
|
| 35 |
+
transforms.CenterCrop(224),
|
| 36 |
+
transforms.ToTensor(),
|
| 37 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 38 |
+
])
|
| 39 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def predict(path):
|
| 44 |
img = get_image(path)
|
| 45 |
+
input_tensor = preprocess(img)
|
| 46 |
+
img = input_tensor.unsqueeze(0)
|
| 47 |
+
ort_inputs = {session.get_inputs()[0].name: img.cpu().detach().numpy()}
|
| 48 |
preds = session.run(None, ort_inputs)[0]
|
| 49 |
preds = np.squeeze(preds)
|
| 50 |
a = np.argsort(preds)
|