File size: 2,770 Bytes
c86f57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5e505f
c86f57f
 
 
 
 
 
b5e505f
c86f57f
 
 
 
 
b5e505f
 
 
 
 
c86f57f
 
 
b5e505f
c86f57f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5e505f
 
 
 
c86f57f
 
 
 
 
b5e505f
 
 
c86f57f
b5e505f
c86f57f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
import argparse
import glob
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt
import os
import onnxruntime as ort
import axengine as axe


def load_image(imfile, size=(512, 384)):
    """Load an image file as a 1x3xHxW float32 tensor of raw 0-255 pixel values.

    The image is resized to ``size`` with PIL's default resampling and then
    reduced to exactly three RGB channels.

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to ``(512, 384)``, the size used throughout this script.

    Returns:
        ``torch.Tensor`` of shape ``(1, 3, size[1], size[0])`` and dtype
        ``float32``; values are unnormalized pixel intensities in [0, 255].
    """
    # Context manager releases the underlying file handle deterministically
    # instead of leaving it to garbage collection.
    with Image.open(imfile) as im:
        # Resize first (same order as before, so RGB/RGBA inputs resample
        # identically), then force RGB. For RGBA this drops alpha exactly like
        # the previous ``[..., :3]`` slice, but it also handles grayscale ("L"),
        # "LA" and palette ("P") images, whose arrays previously did not have
        # three channels and broke the HWC -> CHW permute below.
        rgb = im.resize(size).convert("RGB")
    # np.array (not asarray) copies into a writable buffer, which
    # torch.from_numpy expects.
    img = np.array(rgb, dtype=np.uint8)
    # HWC -> CHW, cast to float32, then add a leading batch dimension.
    return torch.from_numpy(img).permute(2, 0, 1).float()[None]


def visualize_disparity(disparity_map, title, name="test"):
    """Render a disparity map with the ``jet`` colormap and save it as a PNG.

    The image is written to the current working directory as
    ``f"{title}-rt-{name}.png"``; nothing is shown interactively.

    Args:
        disparity_map: Array-like passed to ``plt.imshow`` (presumably a 2-D
            HxW disparity map -- confirm against callers).
        title: Plot title; also used as the output filename prefix.
        name: Sample identifier appended to the filename (default ``"test"``).
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.imshow(disparity_map, cmap='jet')
        plt.colorbar(label="Disparity")
        plt.title(title)
        plt.axis('off')
        plt.savefig(f"{title}-rt-{name}.png")
    finally:
        # Always release the figure: this is called repeatedly in a loop, and
        # pyplot keeps every open figure alive otherwise, growing memory and
        # eventually triggering matplotlib's "too many figures" warning.
        plt.close(fig)


def demo(args):
    """Compare ONNX and AXModel disparity predictions on stereo image pairs.

    Left and right images are collected with recursive ``glob`` from
    ``args.left_imgs`` / ``args.right_imgs`` and paired in sorted order. For
    each pair both models are run, and the two disparity maps are saved via
    ``visualize_disparity``, tagged with the left image's parent directory name.

    The ONNX reference is always ``./models/AX650.onnx``; the AXModel is
    ``./models/{args.target_chip}_RTIGEV.axmodel``.

    Args:
        args: Parsed CLI namespace providing ``left_imgs``, ``right_imgs``
            and ``target_chip``.

    Raises:
        ValueError: If the left and right patterns match different numbers of
            files, so the sorted lists cannot be paired reliably.
        FileNotFoundError: If the patterns match no files at all.
    """
    left_images = sorted(glob.glob(args.left_imgs, recursive=True))
    right_images = sorted(glob.glob(args.right_imgs, recursive=True))
    # zip() would silently truncate to the shorter list, and a single missing
    # file would misalign the sorted lists so that images from different scenes
    # get compared -- fail loudly instead.
    if len(left_images) != len(right_images):
        raise ValueError(
            f"Found {len(left_images)} left images ({args.left_imgs}) but "
            f"{len(right_images)} right images ({args.right_imgs})"
        )
    if not left_images:
        raise FileNotFoundError(
            f"No images match {args.left_imgs!r} / {args.right_imgs!r}"
        )

    if args.target_chip == "AX637":
        print("\033[91mWarning: AX637 uses quant_axmodel, which can not be run by onnxruntime, "
              "so we use AX650's onnx model for comparison\033[0m")
    # Pin the CPU provider explicitly: ORT builds that ship several execution
    # providers refuse to create a session without a `providers` argument.
    ort_session = ort.InferenceSession("./models/AX650.onnx",
                                       providers=["CPUExecutionProvider"])
    ax_session = axe.InferenceSession(f"./models/{args.target_chip}_RTIGEV.axmodel")

    for imfile1, imfile2 in tqdm(list(zip(left_images, right_images))):
        img_name = Path(imfile1).parent.name
        # load_image yields 1x3xHxW float32 tensors holding raw 0-255 pixels.
        input_l_np = load_image(imfile1).cpu().numpy()
        input_r_np = load_image(imfile2).cpu().numpy()

        # The AXModel is fed raw uint8 pixels in NHWC layout.
        ax_inputs = {
            "left": input_l_np.transpose(0, 2, 3, 1).astype(np.uint8),
            "right": input_r_np.transpose(0, 2, 3, 1).astype(np.uint8),
        }
        # The ONNX model is fed NCHW floats rescaled from [0, 255] to [-1, 1].
        onnx_inputs = {
            "left": 2 * (input_l_np / 255.0) - 1.0,
            "right": 2 * (input_r_np / 255.0) - 1.0,
        }

        disp_onnx = ort_session.run(None, onnx_inputs)[0].squeeze()
        disp_ax = ax_session.run(None, ax_inputs)[0].squeeze()

        visualize_disparity(disp_onnx, title="ONNX_Disparity_Map", name=img_name)
        visualize_disparity(disp_ax, title="AXModel_Disparity_Map", name=img_name)


if __name__ == '__main__':
    # CLI: glob patterns selecting the left/right stereo frames, plus the chip
    # whose compiled ``<chip>_RTIGEV.axmodel`` is compared against ONNX.
    # Declaration order is kept as-is since it determines the --help listing.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--left_imgs',
        default="./demo-imgs/*/im0.png",
        help="path to all first (left) frames",
    )
    cli.add_argument(
        '-t', '--target_chip',
        choices=["AX637", "AX650"],
        default="AX650",
        help="target chip for inference",
    )
    cli.add_argument(
        '-r', '--right_imgs',
        default="./demo-imgs/*/im1.png",
        help="path to all second (right) frames",
    )
    demo(cli.parse_args())