Update infer_onnx.py
#1
by
hangyang-amd
- opened
- infer_onnx.py +3 -0
infer_onnx.py
CHANGED
|
@@ -25,6 +25,7 @@ import torchvision.transforms as transforms
|
|
| 25 |
parser = argparse.ArgumentParser()
|
| 26 |
parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
|
| 27 |
parser.add_argument('--image_path', type=str, required=True)
|
|
|
|
| 28 |
parser.add_argument(
|
| 29 |
"--ipu",
|
| 30 |
action="store_true",
|
|
@@ -51,6 +52,8 @@ def read_image():
|
|
| 51 |
normalize,
|
| 52 |
])
|
| 53 |
img_tensor = transform(image).unsqueeze(0)
|
|
|
|
|
|
|
| 54 |
return img_tensor.numpy()
|
| 55 |
|
| 56 |
|
|
|
|
| 25 |
parser = argparse.ArgumentParser()
|
| 26 |
parser.add_argument('--onnx_path', type=str, default="EfficientNet_int.onnx", required=False)
|
| 27 |
parser.add_argument('--image_path', type=str, required=True)
|
| 28 |
+
parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
|
| 29 |
parser.add_argument(
|
| 30 |
"--ipu",
|
| 31 |
action="store_true",
|
|
|
|
| 52 |
normalize,
|
| 53 |
])
|
| 54 |
img_tensor = transform(image).unsqueeze(0)
|
| 55 |
+
if args.data_format == "nhwc":
|
| 56 |
+
img_tensor = transform(image).unsqueeze(0).transpose(1, 3)
|
| 57 |
return img_tensor.numpy()
|
| 58 |
|
| 59 |
|