yz committed on
Commit ·
b3357d3
1
Parent(s): 40cf2c0
Add files
Browse files- .gitattributes +4 -0
- README.md +0 -0
- ax_inference.py +142 -0
- ax_rmbg +3 -0
- axmodel/build_config.json +43 -0
- axmodel/rmbg1_4_ax650.axmodel +3 -0
- img/3_1920x1080.jpg +3 -0
- img/3_1920x1080_mask.png +3 -0
- img/3_1920x1080_result.png +3 -0
- img/example_input.jpg +3 -0
- img/example_input_result.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
img/*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
img/*.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ax_rmbg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.axmodel filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
File without changes
|
ax_inference.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import sys
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Resize an HWC image to the model input size and return it as NCHW uint8.

    The resize is a hand-written bilinear interpolation. Its sampling grid is
    corner-aligned: the first and last target pixels land exactly on the first
    and last source pixels.

    Args:
        im: Image array of shape (H, W, C), or (H, W) for a single channel.
        model_input_size: Target size as ``[height, width]``.

    Returns:
        Array of shape (1, C, height, width) with dtype uint8. The final cast
        truncates the interpolated values instead of rounding them.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]

    # HWC -> CHW; the batch axis is added on the way out.
    planes = im.transpose(2, 0, 1).astype(np.float32)
    src_h, src_w = planes.shape[1:]
    dst_h, dst_w = model_input_size

    # Fractional source coordinates of every target pixel.
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, src_w - 1, dst_w),
        np.linspace(0, src_h - 1, dst_h),
    )

    # Integer neighbours, clamped so the last row/column reuses itself.
    left = np.floor(grid_x).astype(np.int32)
    upper = np.floor(grid_y).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    lower = np.minimum(upper + 1, src_h - 1)

    frac_x = grid_x - left
    frac_y = grid_y - upper

    # Index grids on the two trailing axes gather every channel in one go,
    # producing (C, dst_h, dst_w) arrays.
    row_up = (1 - frac_x) * planes[:, upper, left] + frac_x * planes[:, upper, right]
    row_down = (1 - frac_x) * planes[:, lower, left] + frac_x * planes[:, lower, right]
    blended = ((1 - frac_y) * row_up + frac_y * row_down).astype(np.float32)

    return blended[np.newaxis].astype(np.uint8)
|
| 40 |
+
|
| 41 |
+
def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize the model output to the image size and stretch it to 0-255.

    The output is resized with a corner-aligned bilinear interpolation, then
    min-max normalised over the whole array and scaled to the uint8 range.

    Args:
        result: Model output of shape (1, C, H, W).
        im_size: Target size as ``(height, width)``.

    Returns:
        uint8 array of shape (height, width) when C == 1, otherwise
        (height, width, C). Scaled values are truncated, not rounded.
    """
    result_np = np.squeeze(result, axis=0)
    C, H_ori, W_ori = result_np.shape
    H_target, W_target = im_size  # target size (H, W)

    # Fractional source coordinates of every target pixel.
    x_target = np.linspace(0, W_ori - 1, W_target)
    y_target = np.linspace(0, H_ori - 1, H_target)
    xx_target, yy_target = np.meshgrid(x_target, y_target)

    # Integer neighbours, clamped so the last row/column reuses itself.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    result_interp = np.zeros((C, H_target, W_target), dtype=np.float32)
    for c in range(C):
        channel_data = result_np[c, :, :]
        top = wx1 * channel_data[y0, x0] + wx0 * channel_data[y0, x1]
        bottom = wx1 * channel_data[y1, x0] + wx0 * channel_data[y1, x1]
        result_interp[c, :, :] = wy1 * top + wy0 * bottom

    ma = np.max(result_interp)
    mi = np.min(result_interp)

    # The small epsilon avoids a division by zero for a constant output.
    result_norm = (result_interp - mi) / (ma - mi + 1e-8)
    result_scaled = result_norm * 255
    im_array = np.transpose(result_scaled, (1, 2, 0)).astype(np.uint8)

    # Drop only the channel axis. A bare np.squeeze() would also collapse a
    # height or width of 1 and hand back a mask of the wrong rank.
    if C == 1:
        im_array = im_array[:, :, 0]
    return im_array
|
| 75 |
+
|
| 76 |
+
def inference(img_path,
              model_path,
              save_path):
    """Remove the background of one image with an .axmodel and save the result.

    The image is resized to the fixed 1024x1024 model input, run through the
    model, and the predicted mask is attached to the image as its alpha
    channel. For JPEG output paths the alpha channel is dropped again before
    writing, because JPEG cannot store it.

    Args:
        img_path: Path of the image to process.
        model_path: Path of the compiled model; must end with ".axmodel".
        save_path: Output file path.

    Raises:
        ValueError: If ``model_path`` does not end with ".axmodel".
        FileNotFoundError: If the image cannot be read.
        OSError: If OpenCV reports that the result could not be written.
    """
    if not model_path.endswith(".axmodel"):
        # axengine is the only runtime this script imports; fail with a clear
        # message instead of an unbound `ort` further down.
        raise ValueError(f"不支持的模型格式:{model_path},仅支持 .axmodel 文件")
    import axengine as ort

    session = ort.InferenceSession(model_path)
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")

    model_input_size = [1024, 1024]
    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # convert to RGB (H, W, 3)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    t1 = time.perf_counter()
    result = session.run(None, {input_name: image})
    t2 = time.perf_counter()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    mask = postprocess_image(result[0], orig_im_size)  # single-channel mask (H, W), values 0-255

    # Reuse the image the mask was computed from rather than re-reading the
    # file with IMREAD_UNCHANGED. Per the OpenCV docs that flag skips EXIF
    # orientation and keeps grayscale / 16-bit data as-is, so the mask could
    # differ from the pixels in size or depth. Any alpha channel in the file
    # was replaced by the mask anyway, so nothing is lost.
    b, g, r = cv2.split(orig_im_bgr)
    no_bg_image = cv2.merge((b, g, r, mask))  # merged as BGRA

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        saved = cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        done_message = f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}"
    else:
        saved = cv2.imwrite(save_path, no_bg_image)
        done_message = f"推理完成,结果已保存至:{save_path}"

    # cv2.imwrite signals failure through its return value, not an exception.
    if not saved:
        raise OSError(f"结果保存失败:{save_path},请检查输出目录是否存在")
    print(done_message)
|
| 123 |
+
|
| 124 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the background-removal example.

    ``--model`` and ``--img`` are mandatory, so a missing option is reported
    by argparse instead of surfacing later as a ``None`` path.

    Returns:
        Namespace with the attributes ``model``, ``img`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(description="ax rmbg exsample")
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png", help="save result path (png)")

    return parser.parse_args()
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
    # Script entry point: parse the options, echo them, then run inference.
    args = parse_args()

    print(f"Command: {' '.join(sys.argv)}")
    print("Parameters:")
    print(f" --model: {args.model}")
    # Label matches the real flag name (--img); there is no --img_path option.
    print(f" --img: {args.img}")
    print(f" --save_path: {args.save_path}")

    inference(args.img, args.model, args.save_path)
|
ax_rmbg
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11efeb576105cb6354076dec70a0c587775a942251600b5e3e1f5473456dc65d
|
| 3 |
+
size 5296200
|
axmodel/build_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"input": "./rmbg_1_4_sim.onnx",
|
| 3 |
+
"output_dir": "output",
|
| 4 |
+
"output_name": "rmbg1_4_ax650.axmodel",
|
| 5 |
+
"work_dir": "",
|
| 6 |
+
"model_type": "ONNX",
|
| 7 |
+
"target_hardware": "AX650",
|
| 8 |
+
"npu_mode": "NPU1",
|
| 9 |
+
"quant": {
|
| 10 |
+
"input_configs": [
|
| 11 |
+
{
|
| 12 |
+
"tensor_name": "input",
|
| 13 |
+
"calibration_dataset": "pic.zip",
|
| 14 |
+
"calibration_format": "Image",
|
| 15 |
+
"calibration_size": 10,
|
| 16 |
+
"calibration_mean": [128 ,128 ,128],
|
| 17 |
+
"calibration_std": [255, 255, 255]
|
| 18 |
+
}
|
| 19 |
+
],
|
| 20 |
+
"calibration_method": "MinMax",
|
| 21 |
+
"precision_analysis": false,
|
| 22 |
+
"precision_analysis_method": "EndToEnd"
|
| 23 |
+
},
|
| 24 |
+
"input_processors": [
|
| 25 |
+
{
|
| 26 |
+
"tensor_name": "input",
|
| 27 |
+
"tensor_format": "RGB",
|
| 28 |
+
"tensor_layout": "NCHW",
|
| 29 |
+
"src_format": "RGB",
|
| 30 |
+
"src_layout": "NCHW",
|
| 31 |
+
"src_dtype": "U8"
|
| 32 |
+
}
|
| 33 |
+
],
|
| 34 |
+
"output_processors": [
|
| 35 |
+
{
|
| 36 |
+
"tensor_name": "DEFAULT"
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"compiler": {
|
| 40 |
+
"check": 0
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
axmodel/rmbg1_4_ax650.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6849b3aad38a1f67e2098f860d0f58f20407502ce5a468200e24303763a95f0f
|
| 3 |
+
size 45552006
|
img/3_1920x1080.jpg
ADDED
|
Git LFS Details
|
img/3_1920x1080_mask.png
ADDED
|
Git LFS Details
|
img/3_1920x1080_result.png
ADDED
|
Git LFS Details
|
img/example_input.jpg
ADDED
|
Git LFS Details
|
img/example_input_result.png
ADDED
|
Git LFS Details
|