yz committed on
Commit
b3357d3
·
1 Parent(s): 40cf2c0

Add files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ img/*.jpg filter=lfs diff=lfs merge=lfs -text
37
+ img/*.png filter=lfs diff=lfs merge=lfs -text
38
+ ax_rmbg filter=lfs diff=lfs merge=lfs -text
39
+ *.axmodel filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
File without changes
ax_inference.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ import cv2
4
+ import numpy as np
5
+ import time
6
+
7
def preprocess_image(im: np.ndarray, model_input_size: list) -> np.ndarray:
    """Resize an image to the model resolution and return it as NCHW uint8.

    The resize is a hand-written bilinear interpolation. The sampling grid is
    built with ``np.linspace`` from 0 to size - 1, so the first and last
    output pixels land exactly on the first and last input pixels.

    Args:
        im: Image array of shape (H, W, C) or (H, W). A 2-D input is treated
            as a single-channel image.
        model_input_size: Target size as ``[height, width]``.

    Returns:
        Array of shape (1, C, height, width) with dtype uint8. Interpolated
        values are truncated (not rounded) by the final integer cast.
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]

    # (H, W, C) -> (C, H, W) float32 planes.
    planes = np.transpose(im, (2, 0, 1)).astype(np.float32)
    src_h, src_w = planes.shape[1:]
    dst_h, dst_w = model_input_size

    # Fractional source coordinates for every destination pixel.
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, src_w - 1, dst_w),
        np.linspace(0, src_h - 1, dst_h),
    )

    # Integer neighbours, clamped so the right/bottom edge never overruns.
    left = np.floor(grid_x).astype(np.int32)
    right = np.minimum(left + 1, src_w - 1)
    upper = np.floor(grid_y).astype(np.int32)
    lower = np.minimum(upper + 1, src_h - 1)

    frac_x = grid_x - left
    frac_y = grid_y - upper

    # Blend horizontally on the two neighbouring rows, then vertically.
    # Indexing with [:, rows, cols] handles all channels at once.
    row_up = (1 - frac_x) * planes[:, upper, left] + frac_x * planes[:, upper, right]
    row_down = (1 - frac_x) * planes[:, lower, left] + frac_x * planes[:, lower, right]
    blended = (1 - frac_y) * row_up + frac_y * row_down

    # Go through float32 before the integer cast so truncation happens on
    # float32 values, then add the batch axis.
    return blended.astype(np.float32)[np.newaxis].astype(np.uint8)
40
+
41
def postprocess_image(result: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a raw model output to the image size and scale it to uint8.

    The output is bilinearly resized (same ``np.linspace`` grid as the
    preprocessing step), min-max normalised over the whole array and mapped
    to the 0-255 range.

    Args:
        result: Model output of shape (1, C, H, W). The leading axis must
            have length 1.
        im_size: Target size as ``(height, width)``.

    Returns:
        uint8 array of shape (height, width) when C == 1, otherwise
        (height, width, C). Values are truncated by the integer cast.

    Raises:
        ValueError: If the leading axis of ``result`` is not of length 1
            (raised by ``np.squeeze``) or the squeezed array is not 3-D.
    """
    result_np = np.squeeze(result, axis=0)
    C, H_ori, W_ori = result_np.shape
    H_target, W_target = im_size  # target size (H, W)

    # Fractional source coordinates for every destination pixel.
    xx_target, yy_target = np.meshgrid(
        np.linspace(0, W_ori - 1, W_target),
        np.linspace(0, H_ori - 1, H_target),
    )

    # Integer neighbours, clamped at the right/bottom border.
    x0 = np.floor(xx_target).astype(np.int32)
    x1 = np.minimum(x0 + 1, W_ori - 1)
    y0 = np.floor(yy_target).astype(np.int32)
    y1 = np.minimum(y0 + 1, H_ori - 1)

    wx0 = xx_target - x0
    wx1 = 1 - wx0
    wy0 = yy_target - y0
    wy1 = 1 - wy0

    # Bilinear blend for all channels at once; keep the float32 working
    # precision the per-channel float32 buffer used to impose.
    top = wx1 * result_np[:, y0, x0] + wx0 * result_np[:, y0, x1]
    bottom = wx1 * result_np[:, y1, x0] + wx0 * result_np[:, y1, x1]
    result_interp = (wy1 * top + wy0 * bottom).astype(np.float32)

    ma = np.max(result_interp)
    mi = np.min(result_interp)

    # The small epsilon avoids a division by zero on a constant output.
    result_norm = (result_interp - mi) / (ma - mi + 1e-8)
    im_array = np.transpose(result_norm * 255, (1, 2, 0)).astype(np.uint8)

    # Drop only the channel axis. A bare np.squeeze would also remove a
    # height or width of 1 and hand back a 1-D array for such images.
    if C == 1:
        im_array = im_array[:, :, 0]
    return im_array
75
+
76
def inference(img_path,
              model_path,
              save_path):
    """Run the background-removal model on one image and save the result.

    The image is read with OpenCV, resized to the model input size, passed
    through an ``axengine`` inference session, and the predicted mask is
    attached to the image as an alpha channel before saving.

    Args:
        img_path: Path of the input image.
        model_path: Path of the compiled model; must end with ``.axmodel``.
        save_path: Output path. For ``.jpg``/``.jpeg`` the alpha channel is
            dropped before writing; any other extension receives the
            4-channel BGRA image.

    Raises:
        ValueError: If ``model_path`` does not end with ``.axmodel``.
        FileNotFoundError: If OpenCV cannot read ``img_path``.
        OSError: If OpenCV reports that the result could not be written.
    """
    # Fail early with a clear message; previously any other extension left
    # the runtime module unbound and surfaced as a NameError.
    if not model_path.endswith(".axmodel"):
        raise ValueError(f"不支持的模型格式:{model_path},仅支持 .axmodel")
    import axengine as ort

    session = ort.InferenceSession(model_path)
    input_name = None
    for inp_meta in session.get_inputs():
        input_shape = inp_meta.shape[2:]
        input_name = inp_meta.name
        print(f"输入名称:{input_name},输入尺寸:{input_shape}")

    # NOTE(review): the input size is fixed here and not taken from the
    # shape reported by the session above — confirm they agree for the
    # model in use.
    model_input_size = [1024, 1024]
    orig_im_bgr = cv2.imread(img_path)
    if orig_im_bgr is None:
        raise FileNotFoundError(f"无法读取图片文件:{img_path},请检查路径是否正确或图片是否损坏")
    orig_im = cv2.cvtColor(orig_im_bgr, cv2.COLOR_BGR2RGB)  # RGB, shape (H, W, 3)
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size)

    t1 = time.perf_counter()
    result = session.run(None, {input_name: image})
    t2 = time.perf_counter()
    print(f"推理时间:{(t2-t1)*1000:.2f} ms")

    # Single-channel mask (H, W) with values in 0-255.
    mask = postprocess_image(result[0], orig_im_size)

    # Take the colour planes from the image decoded above instead of reading
    # the file again with IMREAD_UNCHANGED. Per the OpenCV imread
    # documentation that flag skips EXIF orientation and keeps the original
    # bit depth and channel count, so its planes can disagree with the mask
    # in size or dtype, and grayscale files have no channel axis at all.
    b, g, r = cv2.split(orig_im_bgr)
    no_bg_image = cv2.merge((b, g, r, mask))  # BGRA

    if save_path.lower().endswith(('.jpg', '.jpeg')):
        written = cv2.imwrite(save_path, cv2.cvtColor(no_bg_image, cv2.COLOR_BGRA2BGR))
        if not written:
            raise OSError(f"保存结果失败:{save_path}")
        print(f"JPG格式不支持透明通道,已丢弃Alpha通道,结果保存至:{save_path}")
    else:
        written = cv2.imwrite(save_path, no_bg_image)
        # imwrite signals failure through its return value, so check it
        # before reporting success.
        if not written:
            raise OSError(f"保存结果失败:{save_path}")
        print(f"推理完成,结果已保存至:{save_path}")
123
+
124
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Returns:
        Namespace with ``model`` (path of the compiled .axmodel), ``img``
        (input image path) and ``save_path`` (output path, default
        ``./result.png``).

    Raises:
        SystemExit: Raised by argparse when ``--model`` or ``--img`` is
            missing or the command line is otherwise invalid.
    """
    parser = argparse.ArgumentParser(description="ax rmbg exsample")
    # Both paths are mandatory: without them the script would pass None
    # into inference() and crash there with an unrelated error.
    parser.add_argument("--model", "-m", type=str, required=True, help="compiled.axmodel path")
    parser.add_argument("--img", "-i", type=str, required=True, help="img path")
    parser.add_argument("--save_path", type=str, default="./result.png", help="save result path (png)")

    return parser.parse_args()
132
+
133
if __name__ == "__main__":
    # Parse first so an invalid command line exits before anything is echoed.
    cli = parse_args()

    # Echo the invocation and the resolved options.
    invocation = ' '.join(sys.argv)
    print(f"Command: {invocation}")
    print("Parameters:")
    print(f" --model: {cli.model}")
    print(f" --img_path: {cli.img}")
    print(f" --save_path: {cli.save_path}")

    inference(img_path=cli.img, model_path=cli.model, save_path=cli.save_path)
ax_rmbg ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11efeb576105cb6354076dec70a0c587775a942251600b5e3e1f5473456dc65d
3
+ size 5296200
axmodel/build_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input": "./rmbg_1_4_sim.onnx",
3
+ "output_dir": "output",
4
+ "output_name": "rmbg1_4_ax650.axmodel",
5
+ "work_dir": "",
6
+ "model_type": "ONNX",
7
+ "target_hardware": "AX650",
8
+ "npu_mode": "NPU1",
9
+ "quant": {
10
+ "input_configs": [
11
+ {
12
+ "tensor_name": "input",
13
+ "calibration_dataset": "pic.zip",
14
+ "calibration_format": "Image",
15
+ "calibration_size": 10,
16
+ "calibration_mean": [128 ,128 ,128],
17
+ "calibration_std": [255, 255, 255]
18
+ }
19
+ ],
20
+ "calibration_method": "MinMax",
21
+ "precision_analysis": false,
22
+ "precision_analysis_method": "EndToEnd"
23
+ },
24
+ "input_processors": [
25
+ {
26
+ "tensor_name": "input",
27
+ "tensor_format": "RGB",
28
+ "tensor_layout": "NCHW",
29
+ "src_format": "RGB",
30
+ "src_layout": "NCHW",
31
+ "src_dtype": "U8"
32
+ }
33
+ ],
34
+ "output_processors": [
35
+ {
36
+ "tensor_name": "DEFAULT"
37
+ }
38
+ ],
39
+ "compiler": {
40
+ "check": 0
41
+ }
42
+ }
43
+
axmodel/rmbg1_4_ax650.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6849b3aad38a1f67e2098f860d0f58f20407502ce5a468200e24303763a95f0f
3
+ size 45552006
img/3_1920x1080.jpg ADDED

Git LFS Details

  • SHA256: 9a4bff8a0d00cdd45b26a0d331cde508c8d561ca5f8388f0737f26abecfa4efc
  • Pointer size: 131 Bytes
  • Size of remote file: 311 kB
img/3_1920x1080_mask.png ADDED

Git LFS Details

  • SHA256: a3a90102c30111be9027c9468b67d56aaebde839a94d7ea49b249c52baaade63
  • Pointer size: 131 Bytes
  • Size of remote file: 269 kB
img/3_1920x1080_result.png ADDED

Git LFS Details

  • SHA256: 226d1ea35465bc42e39f5e3c6f919f738272e0cd899f1680499ccf1787dc8f5c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
img/example_input.jpg ADDED

Git LFS Details

  • SHA256: 1e9cff13a43d13ec0d0d733a55234e862a35c282cdbfa197c85223a937f28a56
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
img/example_input_result.png ADDED

Git LFS Details

  • SHA256: d8e0ad94d86106ca6964ec77b29639a5001d65138f64174882dc0a07631ad220
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB