Add appropriate pre- to input image and postprocessing to validate function
Browse files- convert_to_mixed.py +54 -18
convert_to_mixed.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
import sys
|
| 2 |
import numpy as np
|
| 3 |
-
from PIL import Image
|
| 4 |
import onnx
|
| 5 |
-
import onnxruntime as ort
|
| 6 |
from onnxconverter_common import auto_mixed_precision_model_path
|
| 7 |
import argparse
|
| 8 |
from rtmo_gpu import RTMO_GPU, draw_skeleton
|
|
@@ -24,28 +21,49 @@ def detect_model_input_size(model_path):
|
|
| 24 |
return tuple(dims[2:4]) # Return (height, width)
|
| 25 |
raise ValueError("Input node 'input' not found in the model")
|
| 26 |
|
| 27 |
-
def load_and_preprocess_image(image_path,
|
| 28 |
-
|
| 29 |
-
image =
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
| 33 |
return image
|
| 34 |
|
| 35 |
-
def
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return False
|
| 43 |
return True
|
| 44 |
|
| 45 |
def infer_on_image(onnx_model, model_input_size, test_image_path):
|
| 46 |
body = RTMO_GPU(onnx_model=onnx_model,
|
| 47 |
model_input_size=model_input_size,
|
| 48 |
-
is_yolo_nas_pose=
|
| 49 |
|
| 50 |
frame = cv2.imread(test_image_path)
|
| 51 |
img_show = frame.copy()
|
|
@@ -63,12 +81,29 @@ def infer_on_image(onnx_model, model_input_size, test_image_path):
|
|
| 63 |
|
| 64 |
def main(args):
|
| 65 |
model_input_size = detect_model_input_size(args.source_model_path)
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
auto_mixed_precision_model_path.auto_convert_mixed_precision_model_path(source_model_path=args.source_model_path,
|
| 69 |
input_feed=input_feed,
|
| 70 |
target_model_path=args.target_model_path,
|
| 71 |
-
customized_validate_func=validate_pose,
|
| 72 |
rtol=args.rtol, atol=args.atol,
|
| 73 |
provider=PROVIDERS,
|
| 74 |
keep_io_types=True,
|
|
@@ -83,6 +118,7 @@ if __name__ == "__main__":
|
|
| 83 |
parser.add_argument("test_image_path", type=str, help="Path to a test image for validating the model conversion.")
|
| 84 |
parser.add_argument('--rtol', type=float, default=0.01, help=' the relative tolerance to do validation')
|
| 85 |
parser.add_argument('--atol', type=float, default=0.001, help=' the absolute tolerance to do validation')
|
|
|
|
| 86 |
|
| 87 |
args = parser.parse_args()
|
| 88 |
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
import onnx
|
|
|
|
| 3 |
from onnxconverter_common import auto_mixed_precision_model_path
|
| 4 |
import argparse
|
| 5 |
from rtmo_gpu import RTMO_GPU, draw_skeleton
|
|
|
|
| 21 |
return tuple(dims[2:4]) # Return (height, width)
|
| 22 |
raise ValueError("Input node 'input' not found in the model")
|
| 23 |
|
| 24 |
+
def load_and_preprocess_image(image_path, preprocesss=None):
|
| 25 |
+
|
| 26 |
+
image = cv2.imread(image_path)
|
| 27 |
+
|
| 28 |
+
if preprocesss is not None:
|
| 29 |
+
image = preprocesss(image)
|
| 30 |
+
|
| 31 |
return image
|
| 32 |
|
| 33 |
+
def compare_result(res1, res2):
|
| 34 |
+
keypoints1, scores1 = res1
|
| 35 |
+
keypoints2, scores2 = res2
|
| 36 |
+
|
| 37 |
+
from termcolor import colored
|
| 38 |
+
|
| 39 |
+
for d1, d2 in zip(keypoints1, keypoints2):
|
| 40 |
+
for i, (j1, j2) in enumerate(zip(d1, d2)):
|
| 41 |
+
x1, y1 = j1
|
| 42 |
+
x2, y2 = j2
|
| 43 |
+
print(f"Joint-{i}: X: {colored(x1,'green')} VS {colored(x2, 'blue')} Y: {colored(y1, 'green')} VS {colored(y2, 'blue')}")
|
| 44 |
+
|
| 45 |
+
for d1, d2 in zip(scores1, scores2):
|
| 46 |
+
for i, (s1, s2) in enumerate(zip(d1, d2)):
|
| 47 |
+
print(f"Joint-{i}: S: {colored(s1,'green')} VS {colored(s2, 'blue')}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def validate_pose(res1, res2, postprocess=None):
|
| 51 |
|
| 52 |
+
if postprocess is not None:
|
| 53 |
+
res1 = postprocess(res1)
|
| 54 |
+
res2 = postprocess(res2)
|
| 55 |
+
|
| 56 |
+
compare_result(res1, res2)
|
| 57 |
+
|
| 58 |
+
for r1, r2 in zip(res1, res2):
|
| 59 |
+
if not np.allclose(r1, r2, rtol=args.rtol, atol=args.atol):
|
| 60 |
return False
|
| 61 |
return True
|
| 62 |
|
| 63 |
def infer_on_image(onnx_model, model_input_size, test_image_path):
|
| 64 |
body = RTMO_GPU(onnx_model=onnx_model,
|
| 65 |
model_input_size=model_input_size,
|
| 66 |
+
is_yolo_nas_pose=args.yolo_nas_pose)
|
| 67 |
|
| 68 |
frame = cv2.imread(test_image_path)
|
| 69 |
img_show = frame.copy()
|
|
|
|
| 81 |
|
| 82 |
def main(args):
|
| 83 |
model_input_size = detect_model_input_size(args.source_model_path)
|
| 84 |
+
|
| 85 |
+
body = RTMO_GPU(onnx_model=args.source_model_path,
|
| 86 |
+
model_input_size=model_input_size,
|
| 87 |
+
is_yolo_nas_pose=args.yolo_nas_pose)
|
| 88 |
+
|
| 89 |
+
def preprocess(image, body, is_yolo_nas_pose):
|
| 90 |
+
|
| 91 |
+
img, _ = body.preprocess(image)
|
| 92 |
+
|
| 93 |
+
# build input to (1, 3, H, W)
|
| 94 |
+
img = img.transpose(2, 0, 1)
|
| 95 |
+
img = np.ascontiguousarray(img, dtype=np.float32 if not is_yolo_nas_pose else np.uint8)
|
| 96 |
+
img = img[None, :, :, :]
|
| 97 |
+
return img
|
| 98 |
+
|
| 99 |
+
image = load_and_preprocess_image(args.test_image_path, lambda img: preprocess(img, body, args.yolo_nas_pose))
|
| 100 |
+
|
| 101 |
+
input_feed = {'input': image}
|
| 102 |
|
| 103 |
auto_mixed_precision_model_path.auto_convert_mixed_precision_model_path(source_model_path=args.source_model_path,
|
| 104 |
input_feed=input_feed,
|
| 105 |
target_model_path=args.target_model_path,
|
| 106 |
+
customized_validate_func=lambda res1,res2:validate_pose(res1, res2, body.postprocess),
|
| 107 |
rtol=args.rtol, atol=args.atol,
|
| 108 |
provider=PROVIDERS,
|
| 109 |
keep_io_types=True,
|
|
|
|
| 118 |
parser.add_argument("test_image_path", type=str, help="Path to a test image for validating the model conversion.")
|
| 119 |
parser.add_argument('--rtol', type=float, default=0.01, help=' the relative tolerance to do validation')
|
| 120 |
parser.add_argument('--atol', type=float, default=0.001, help=' the absolute tolerance to do validation')
|
| 121 |
+
parser.add_argument('--yolo_nas_pose', action='store_true', help='Use YOLO NAS Pose (flat format only) instead of RTMO Model')
|
| 122 |
|
| 123 |
args = parser.parse_args()
|
| 124 |
|