| import argparse |
| import cv2 |
| import numpy as np |
| import onnxruntime as ort |
| import matplotlib.pyplot as plt |
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line (``sys.argv``) for the stereo inference script.

    Three options are accepted, all mandatory string paths:
    ``--left`` and ``--right`` (the image pair) and ``--model`` (the ONNX file).
    argparse exits the process with a usage message if any is missing.
    """
    cli = argparse.ArgumentParser()
    # (flag, help text) for every required path option, registered in order.
    required_paths = (
        ("--left", "Path to left image."),
        ("--right", "Path to right image."),
        ("--model", "Path to ONNX model."),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)
    namespace = cli.parse_args()
    return namespace
|
|
|
|
def _load_rgb_image(path: str, label: str) -> np.ndarray:
    """Read the image at *path* with OpenCV and return it converted BGR -> RGB.

    *label* is the name used in progress/error messages ("左图" or "右图").

    Raises:
        ValueError: if ``cv2.imread`` cannot read the file (it returns None).
    """
    print(f"正在读取{label}: {path}")
    raw = cv2.imread(path)
    if raw is None:
        raise ValueError(f"无法读取{label}: {path}")
    image = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    height, width = image.shape[:2]
    print(f"{label}原始尺寸: {height}x{width}")
    return image


def _to_model_input(image: np.ndarray, width: int, height: int) -> np.ndarray:
    """Resize an HWC image to (width, height) and pack it as a 1xCxHxW float32 batch.

    No normalization is applied: pixel values stay in the image's original
    range (assumes the model expects raw 0-255 values — TODO confirm).
    """
    # cv2.resize takes dsize as (width, height), not (height, width).
    resized = cv2.resize(image, (width, height))
    return resized.transpose(2, 0, 1)[None].astype(np.float32)


def _input_hw(shape, fallback_hw):
    """Return the (H, W) spatial size the model's input expects.

    *shape* is an ONNX Runtime input shape; dims 2 and 3 are read as H and W
    (assumes NCHW layout — TODO confirm for other models). ONNX Runtime
    reports dynamic axes as strings or None, which ``cv2.resize`` cannot use,
    so in that case *fallback_hw* (the original image size) is returned.
    """
    dims = list(shape[2:4])
    if len(dims) == 2 and all(isinstance(d, int) and d > 0 for d in dims):
        return dims[0], dims[1]
    return fallback_hw


def infer(left: str, right: str, model: str, output_path: str = "output-onnx.png") -> np.ndarray:
    """Run a stereo ONNX model on a left/right image pair and save a jet colormap.

    Args:
        left: Path to the left image.
        right: Path to the right image.
        model: Path to the ONNX model. Its first input receives the left image
            and its second input the right image.
        output_path: File the jet-colormapped result is written to
            (defaults to ``output-onnx.png``).

    Returns:
        ``np.abs`` of the ``[0, 0]`` plane of the model's first output, resized
        to the original image resolution and multiplied by
        ``original_width / model_input_width``.

    Raises:
        ValueError: if the model has fewer than two inputs, an image cannot be
            read, or the left and right images differ in size.
    """
    # Prefer CUDA when this onnxruntime build exposes it; CPU stays as fallback.
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        print("使用 CUDA 执行提供者")
    else:
        providers = ["CPUExecutionProvider"]
        print("使用 CPU 执行提供者 (CUDA 不可用)")

    print(f"正在加载模型: {model}")
    session = ort.InferenceSession(model, providers=providers)
    print("模型加载完成")

    inputs = session.get_inputs()
    if len(inputs) < 2:
        raise ValueError(f"模型需要至少 2 个输入 (左图, 右图), 实际为 {len(inputs)} 个")

    image_left = _load_rgb_image(left, "左图")
    image_right = _load_rgb_image(right, "右图")

    orig_h, orig_w = image_left.shape[:2]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if image_right.shape[:2] != (orig_h, orig_w):
        right_h, right_w = image_right.shape[:2]
        raise ValueError(
            f"左右图尺寸不一致: {orig_h}x{orig_w} vs {right_h}x{right_w}"
        )

    H, W = _input_hw(inputs[0].shape, (orig_h, orig_w))
    print(f"输入尺寸: {H}x{W}")

    feeds = {
        inputs[0].name: _to_model_input(image_left, W, H),
        inputs[1].name: _to_model_input(image_right, W, H),
    }

    print("正在进行推理...")
    flow_up = session.run(None, feeds)[0]
    print("推理完成")

    # Take the [0, 0] plane of the first output (presumably shaped (1, 1, H, W)
    # — TODO confirm) and bring it back to the original resolution.
    flow_up = cv2.resize(flow_up[0, 0], (orig_w, orig_h))
    # Values are rescaled by the width ratio only, i.e. treated as horizontal
    # pixel offsets (disparity-like), which grow with the image width.
    flow_up *= orig_w / W

    output = np.abs(flow_up)

    print(f"正在保存结果到 {output_path}")
    plt.imsave(output_path, output, cmap="jet")
    print(f"推理完成,结果已保存到 {output_path}")

    return output
|
|
|
|
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options as keyword arguments.
    infer(**vars(parse_args()))
|
|