# RAFT-stereo / infer_onnx.py
# Page header residue: fangmingguo's picture; "Upload 22 files"; commit e4fa43e (verified).
import argparse
import cv2
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Returns:
        Namespace with the string attributes ``left``, ``right`` and
        ``model``. All three options are required.
    """
    # (flag, help text) pairs, registered in this order.
    options = (
        ("--left", "Path to left image."),
        ("--right", "Path to right image."),
        ("--model", "Path to ONNX model."),
    )

    cli = argparse.ArgumentParser()
    for flag, description in options:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()
def _load_image(path: str, label: str, width: int, height: int):
    """Read one image and turn it into a batched float32 CHW array.

    Args:
        path: Image file to read with ``cv2.imread``.
        label: Name used in the progress/error messages (e.g. "左图").
        width: Target width passed to ``cv2.resize``.
        height: Target height passed to ``cv2.resize``.

    Returns:
        Tuple ``(tensor, orig_h, orig_w)`` where ``tensor`` is the resized
        RGB image transposed to channel-first order with a leading batch
        axis and cast to ``np.float32`` (pixel values are not rescaled),
        and ``orig_h`` / ``orig_w`` are the dimensions before resizing.

    Raises:
        ValueError: If ``cv2.imread`` cannot read ``path``.
    """
    print(f"正在读取{label}: {path}")
    raw = cv2.imread(path)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"无法读取{label}: {path}")
    rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = rgb.shape[:2]
    print(f"{label}原始尺寸: {orig_h}x{orig_w}")
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(rgb, (width, height))
    tensor = resized.transpose(2, 0, 1)[None].astype(np.float32)
    return tensor, orig_h, orig_w


def infer(left: str, right: str, model: str, output: str = "output-onnx.png"):
    """Run an ONNX stereo model on an image pair and save the result as an image.

    Both images are resized to the spatial size read from the model's first
    input (dimensions 2 and 3 of its shape), fed to the model, and the first
    output is resized back to the original image size, rescaled by the
    horizontal resize factor, and written with the ``jet`` colormap.

    Args:
        left: Path to the left image.
        right: Path to the right image.
        model: Path to the ONNX model file.
        output: Path of the image to write. Defaults to ``output-onnx.png``.

    Returns:
        2-D array holding the absolute value of the rescaled model output,
        at the original image resolution.

    Raises:
        ValueError: If an image cannot be read, the two images differ in
            size, the model has fewer than two inputs, or the model's input
            height/width are not fixed integers.
    """
    # Pick the execution provider: prefer CUDA when available, else CPU.
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        print("使用 CUDA 执行提供者")
    else:
        providers = ["CPUExecutionProvider"]
        print("使用 CPU 执行提供者 (CUDA 不可用)")

    print(f"正在加载模型: {model}")
    session = ort.InferenceSession(model, providers=providers)
    print("模型加载完成")

    model_inputs = session.get_inputs()
    if len(model_inputs) < 2:
        raise ValueError(
            f"模型需要两个输入 (左图和右图), 实际只有 {len(model_inputs)} 个"
        )
    left_name, right_name = model_inputs[0].name, model_inputs[1].name

    # Symbolic/unknown dimensions are reported as str/None; cv2.resize needs
    # concrete integers, so reject them here with a readable message.
    input_shape = model_inputs[0].shape
    spatial = tuple(input_shape[2:4])
    if len(spatial) != 2 or not all(
        isinstance(dim, (int, np.integer)) for dim in spatial
    ):
        raise ValueError(f"模型输入尺寸必须是固定值, 实际为: {input_shape}")
    height, width = int(spatial[0]), int(spatial[1])
    print(f"输入尺寸: {height}x{width}")

    image_left, orig_h_left, orig_w_left = _load_image(left, "左图", width, height)
    image_right, orig_h_right, orig_w_right = _load_image(
        right, "右图", width, height
    )

    # Explicit raise instead of assert so the check survives `python -O`.
    if (orig_h_left, orig_w_left) != (orig_h_right, orig_w_right):
        raise ValueError(
            f"左右图尺寸不一致: 左图 {orig_h_left}x{orig_w_left}, "
            f"右图 {orig_h_right}x{orig_w_right}"
        )

    print("正在进行推理...")
    flow_up = session.run(None, {left_name: image_left, right_name: image_right})[0]
    print("推理完成")

    # Takes the [0, 0] plane of the first output — assumes it is laid out as
    # (batch, channel, H, W); TODO confirm against the exported model.
    flow_up = cv2.resize(flow_up[0, 0], (orig_w_left, orig_h_left))
    # Values were predicted at the model's width; scale them by the
    # horizontal resize factor to match the original image width.
    flow_up *= orig_w_left / width
    result = np.abs(flow_up)

    print(f"正在保存结果到 {output}")
    plt.imsave(output, result, cmap='jet')
    print(f"推理完成,结果已保存到 {output}")
    return result
if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to infer() as a
    # keyword argument.
    cli_options = vars(parse_args())
    infer(**cli_options)