File size: 5,419 Bytes
b3357d3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import argparse
import sys
import cv2
import numpy as np
import time
def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Bilinearly resize an image and return it as a uint8 ``(1, C, H, W)`` batch.

    Args:
        im: Image array indexed as ``(H, W)`` or ``(H, W, C)``.
        model_input_size: Two-element ``[height, width]`` of the resized output.

    Returns:
        Array of shape ``(1, C, height, width)`` and dtype ``uint8``. The
        interpolated values are truncated (not rounded) by the final cast.
    """
    # Give 2-D input an explicit channel axis so the layout below always holds.
    if im.ndim < 3:
        im = im[:, :, np.newaxis]

    # Channels-first float copy of the source pixels.
    chw = np.transpose(im, (2, 0, 1)).astype(np.float32)
    _, src_h, src_w = chw.shape
    dst_h, dst_w = model_input_size

    # Sampling grid: output positions spread evenly from the first to the
    # last source pixel along each axis.
    col_grid, row_grid = np.meshgrid(
        np.linspace(0, src_w - 1, dst_w),
        np.linspace(0, src_h - 1, dst_h),
    )

    # Integer neighbours of every sample; the far neighbour is clamped so the
    # last row/column never indexes outside the source.
    left = np.floor(col_grid).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    upper = np.floor(row_grid).astype(np.int32)
    lower = np.minimum(upper + 1, src_h - 1)

    # Fractional offsets act as the weights of the far neighbours.
    frac_x = col_grid - left
    frac_y = row_grid - upper
    near_x = 1 - frac_x
    near_y = 1 - frac_y

    # Blend horizontally on the two neighbouring rows, then vertically.
    # Indexing with [:, rows, cols] handles all channels in one step.
    upper_mix = near_x * chw[:, upper, left] + frac_x * chw[:, upper, right]
    lower_mix = near_x * chw[:, lower, left] + frac_x * chw[:, lower, right]
    blended = near_y * upper_mix + frac_y * lower_mix

    # Store as float32 first (as a float32 buffer would), then add the batch
    # axis and truncate to uint8.
    resized = blended.astype(np.float32)[np.newaxis]
    return resized.astype(np.uint8)
def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a model output to ``im_size`` and rescale it to 8-bit.

    The output is bilinearly resized, min-max normalised over the whole
    array, scaled to 0-255 and truncated to ``uint8``.

    Args:
        result: Array whose leading axis has length 1; after that axis is
            removed it must be ``(C, H, W)``.
        im_size: Two-element target size ``(height, width)``.

    Returns:
        ``uint8`` array of shape ``(height, width)`` when ``C == 1``,
        otherwise ``(height, width, C)``.
    """
    maps = np.squeeze(result, axis=0)
    num_channels, src_h, src_w = maps.shape
    dst_h, dst_w = im_size  # target size (H, W)

    # Sampling grid: output positions spread evenly from the first to the
    # last source pixel along each axis.
    col_grid, row_grid = np.meshgrid(
        np.linspace(0, src_w - 1, dst_w),
        np.linspace(0, src_h - 1, dst_h),
    )

    # Integer neighbours of every sample; the far neighbour is clamped so the
    # last row/column never indexes outside the source.
    left = np.floor(col_grid).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    upper = np.floor(row_grid).astype(np.int32)
    lower = np.minimum(upper + 1, src_h - 1)

    # Fractional offsets act as the weights of the far neighbours.
    frac_x = col_grid - left
    frac_y = row_grid - upper
    near_x = 1 - frac_x
    near_y = 1 - frac_y

    # Blend horizontally on the two neighbouring rows, then vertically.
    # Indexing with [:, rows, cols] handles all channels in one step.
    upper_mix = near_x * maps[:, upper, left] + frac_x * maps[:, upper, right]
    lower_mix = near_x * maps[:, lower, left] + frac_x * maps[:, lower, right]
    resized = (near_y * upper_mix + frac_y * lower_mix).astype(np.float32)

    # Min-max normalise; the small epsilon avoids division by zero when the
    # output is constant.
    hi = np.max(resized)
    lo = np.min(resized)
    normalized = (resized - lo) / (hi - lo + 1e-8)
    scaled = normalized * 255

    im_array = np.transpose(scaled, (1, 2, 0)).astype(np.uint8)

    # Drop only the channel axis of a single-channel result. A bare
    # np.squeeze() would also remove a height or width of 1 and hand back an
    # array with the wrong number of spatial axes.
    if num_channels == 1:
        im_array = im_array[:, :, 0]
    return im_array
def inference(img_path,
              model_path,
              save_path):
    """Run the model on one image and save the image with the mask as alpha.

    Steps: load the ``.axmodel`` session, read the image, resize it with
    ``preprocess_image``, run the session, turn the first output into a mask
    with ``postprocess_image``, attach that mask as the alpha channel of the
    original BGR pixels and write the result to ``save_path``.

    Args:
        img_path: Path of the input image.
        model_path: Path of the compiled model; must end with ``.axmodel``.
        save_path: Output path. For ``.jpg``/``.jpeg`` the alpha channel is
            dropped before writing.

    Raises:
        ValueError: ``model_path`` does not end with ``.axmodel``.
        FileNotFoundError: the image could not be read.
        OSError: the result could not be written.
    """
    # Reject unsupported models up front; otherwise the runtime module would
    # never be imported and the failure would surface as a NameError.
    if not model_path.endswith(".axmodel"):
        raise ValueError(f"不支持的模型格式:{model_path},仅支持 .axmodel 模型")
    import axengine as ort

    session = ort.InferenceSession(model_path)

    # Report every input; the name of the last one is used to feed the model.
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")

    # NOTE(review): the resize target is fixed at 1024x1024; the shape
    # reported by the session above is only printed, not used.
    model_input_size = [1024, 1024]

    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # RGB, (H, W, 3)
    orig_im_size = orig_im.shape[0:2]

    image = preprocess_image(orig_im, model_input_size)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    t1 = time.perf_counter()
    result = session.run(None, {input_name: image})
    t2 = time.perf_counter()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    # Mask resized back to the original (H, W), values 0-255.
    # Assumes the first model output is the single-channel mask - TODO confirm.
    mask = postprocess_image(result[0], orig_im_size)

    # Reuse the BGR pixels already loaded instead of re-reading the file with
    # IMREAD_UNCHANGED. Any original alpha was replaced by the mask anyway, and
    # the default read always yields three channels with the same geometry
    # the mask was sized for (an unchanged read returns 2-D arrays for
    # grayscale files and, per the OpenCV docs, skips EXIF orientation).
    b, g, r = cv2.split(orig_im_bgr)
    no_bg_image = cv2.merge((b, g, r, mask))  # BGRA

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        to_write = cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR)
        done_message = f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}"
    else:
        to_write = no_bg_image
        done_message = f"推理完成,结果已保存至:{save_path}"

    # cv2.imwrite signals failure through its return value rather than an
    # exception; do not report success unless the file was written.
    if not cv2.imwrite(save_path, to_write):
        raise OSError(f"结果保存失败:{save_path}")
    print(done_message)
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``.

    Returns:
        Namespace with ``model``, ``img`` and ``save_path`` attributes.
    """
    parser = argparse.ArgumentParser(description="ax rmbg exsample")
    # Both paths are mandatory: without them the run would only fail later,
    # when string methods are called on None.
    parser.add_argument("--model", "-m", type=str, required=True,
                        help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True,
                        help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png",
                        help="save result path (png)")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the options, echo them, then run inference.
    args = parse_args()
    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    # Labels use the real option names (the image option is --img).
    print(f" --model: {args.model}")
    print(f" --img: {args.img}")
    print(f" --save_path: {args.save_path}")
    inference(args.img, args.model, args.save_path)