wli1995 committed · verified
Commit ff5b345 · Parent(s): 3838620

Upload 5 files
README.md CHANGED
@@ -1,3 +1,102 @@
- ---
- license: bsd-3-clause
- ---
# EdgeTAM

An image segmentation pipeline based on EdgeSAM, supporting multiple input prompts (box, point, mask) and model inference on 650N-series platforms.

Supported chips:
- AX650N

Supported hardware:

- [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
- [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)

For the original model, see:
- [EdgeTAM Github](https://github.com/facebookresearch/EdgeTAM)
- [EdgeTAM Huggingface](https://huggingface.co/facebook/EdgeTAM)

## Performance

- Input image size: 512x512

| Model                       | Latency (ms) | CMM Usage (MB) |
| --------------------------- | ------------ | -------------- |
| edgetam_image_encoder       | 22.348       | 29.124         |
| edgetam_prompt_encoder      | 0.055        | 0.023          |
| edgetam_prompt_mask_encoder | 0.457        | 0.037          |
| edgetam_mask_decoder        | 4.729        | 16.730         |

## Model Conversion
- Conversion toolchain: [Pulsar2](https://huggingface.co/AXERA-TECH/Pulsar2)
- Conversion guide: [TODO]

## Environment Setup
- NPU Python API: [pyaxengine](https://github.com/AXERA-TECH/pyaxengine)

Install the required Python packages:

```bash
pip install -r requirements.txt
```

## Running

```bash
(myenv) root@ax650:~/EdgeTAM# python3 image_prediction_ax.py --input_box 75,275,1725,850
[INFO] Available providers: ['AxEngineExecutionProvider']
Loading EdgeTAM Onnx models...
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Chip type: ChipType.MC50
[INFO] VNPU type: VNPUType.DISABLED
[INFO] Engine version: 2.12.0s
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
Get prompts:
  input_box: [  75  275 1725  850]
  input_point_coords: None
  input_point_labels: None
Only box input provided
Get dense_embeddings_no_mask
[0.9777304]
✅ Saved: ./results/mask_1.png
```

Results are saved under the `./results` directory:
![image](./results/mask_1.png)

```bash
(myenv) root@ax650:~/EdgeTAM# python3 image_prediction_ax.py --image_path ./examples/images/truck.jpg --input_box 425,600,700,875 --input_point_coords 575,750 --input_point_labels 0
[INFO] Available providers: ['AxEngineExecutionProvider']
Loading EdgeTAM Onnx models...
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Chip type: ChipType.MC50
[INFO] VNPU type: VNPUType.DISABLED
[INFO] Engine version: 2.12.0s
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
[INFO] Using provider: AxEngineExecutionProvider
[INFO] Model type: 2 (triple core)
[INFO] Compiler version: 5.0-patch1-dirty a512c95e-dirty
['575,750']
575,750
Get prompts:
  input_box: [425 600 700 875]
  input_point_coords: [[575 750]]
  input_point_labels: [0]
Get dense_embeddings_no_mask
[0.90291053]
✅ Saved: ./results/mask_1.png
```
![image](./results/mask_5.png)
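The `--input_box`, `--input_point_coords`, and `--input_point_labels` strings follow a simple comma/colon convention: commas separate values within one box or point, and colons separate multiple points or labels. A minimal sketch of that parsing (helper names are illustrative, not from the repo):

```python
import numpy as np

def parse_box(s):
    # "x1,y1,x2,y2" -> array([x1, y1, x2, y2])
    return np.array([int(v) for v in s.split(",")])

def parse_points(s):
    # "x1,y1" or "x1,y1:x2,y2" -> (N, 2) array; ':' separates points
    return np.array([[int(v) for v in p.split(",")] for p in s.split(":")])

box = parse_box("75,275,1725,850")
pts = parse_points("500,375:1125,625")
print(box.tolist())   # [75, 275, 1725, 850]
print(pts.shape)      # (2, 2)
```

This mirrors the argument handling in the prediction scripts below.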
axmodel/dense_embeddings_no_mask.npy ADDED (Git LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34b107f2e768982d45ede36d15d01e50ede546e652672578217a2b9dc0f0ac24
size 4194432
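The 4194432-byte LFS size is consistent with a float32 tensor of 256×64×64 elements plus a standard 128-byte `.npy` v1.0 header; the exact shape is an assumption here, not stated by the repo:

```python
import numpy as np

shape = (1, 256, 64, 64)            # assumed layout, not confirmed by the repo
payload = int(np.prod(shape)) * 4   # float32 = 4 bytes per element
header = 128                        # typical .npy v1.0 header size
print(payload + header)             # 4194432, matching the LFS pointer
```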
image_prediction_ax.py ADDED
@@ -0,0 +1,182 @@
import os
import argparse

import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image

from utils.EdgeTAM_image_predictor import ImagePredictor

np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Smooth the contours before drawing
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
    save_dir="./results",  # output directory
    base_name="mask",      # output filename prefix
):
    """
    Save the segmentation results to files instead of displaying them.

    Args:
        save_dir: output directory (created automatically)
        base_name: filename prefix, e.g. "mask" -> "mask_1.png"
    """
    os.makedirs(save_dir, exist_ok=True)

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)

        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())

        if box_coords is not None:
            show_box(box_coords, plt.gca())

        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)

        plt.axis('off')

        # Save the figure (no plt.show())
        save_path = os.path.join(save_dir, f"{base_name}_{i+1}.png")
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()  # release figure memory
        print(f"✅ Saved: {save_path}")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--image_path", type=str, default="./examples/images/truck.jpg", help="Path to the input image.")
    argparser.add_argument("--model_path", type=str, default="./axmodel", help="Path to the ImagePredictor model.")
    argparser.add_argument("--save_dir", type=str, default="./results", help="Directory to save the output images.")
    argparser.add_argument("--input_box", type=str, default=None, help="Input box coordinates as x1,y1,x2,y2")
    argparser.add_argument("--input_mask", type=str, default=None, help="Path to the input mask numpy file.")
    argparser.add_argument("--input_point_coords", type=str, default=None, help="Input point coordinates as x1,y1 or x1,y1:x2,y2")
    argparser.add_argument("--input_point_labels", type=str, default=None, help="Input point labels as 1 or 0 or 1:0")

    args = argparser.parse_args()

    # Load the image
    image = np.array(Image.open(args.image_path).convert("RGB"))

    predictor = ImagePredictor(args.model_path)
    predictor.set_image(image)

    # Build the input prompts
    if args.input_mask is not None:
        input_mask = np.load(args.input_mask)
    else:
        input_mask = np.zeros((1, 256, 256), dtype=np.float32)

    if args.input_box is not None:
        input_box = np.array([int(x) for x in args.input_box.split(",")])
    else:
        input_box = None

    if args.input_point_coords is not None:
        input_point_coords = np.array([[int(coord) for coord in point.split(",")] for point in args.input_point_coords.split(":")])
    else:
        input_point_coords = None

    if args.input_point_labels is not None:
        input_point_labels = np.array([int(label) for label in args.input_point_labels.split(":")])
    else:
        input_point_labels = None

    if input_box is None and input_point_coords is None:
        raise ValueError("At least one of input_box or input_point_coords must be provided.")

    print("Get prompts: ")
    print(f"  input_box: {input_box}")
    print(f"  input_point_coords: {input_point_coords}")
    print(f"  input_point_labels: {input_point_labels}")

    # Example prompts:
    # box only:
    #   input_box = np.array([75, 275, 1725, 850])
    # points only:
    #   input_point_coords = np.array([[500, 375], [1125, 625]])
    #   input_point_labels = np.array([1, 1])
    # point + box:
    #   input_box = np.array([425, 600, 700, 875])
    #   input_point_coords = np.array([[575, 750]])
    #   input_point_labels = np.array([0])

    # Predict masks
    masks, scores, logits = predictor.predict(
        point_coords=input_point_coords,
        point_labels=input_point_labels,
        box=input_box,
        mask_input=input_mask,
        multimask_output=False,
    )

    # Sort results by score, highest first
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]
    print(scores)

    # Save the visualized results
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point_coords,
        box_coords=input_box,
        input_labels=input_point_labels,
        borders=True,
    )
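The script above ranks the predicted masks by confidence before saving them. That `np.argsort(scores)[::-1]` pattern can be isolated in a small, self-contained sketch (the string "masks" are stand-ins for real (H, W) arrays):

```python
import numpy as np

scores = np.array([0.41, 0.98, 0.77])
masks = np.array(["m0", "m1", "m2"])  # stand-ins for (H, W) mask arrays

order = np.argsort(scores)[::-1]      # indices sorted by score, highest first
print(masks[order].tolist())          # ['m1', 'm2', 'm0']
print(scores[order].tolist())         # [0.98, 0.77, 0.41]
```

Applying the same `order` index to `masks`, `scores`, and `logits` keeps the three arrays aligned, so `mask_1.png` is always the highest-scoring mask.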
image_prediction_onnx.py ADDED
@@ -0,0 +1,175 @@
import os
import argparse

import numpy as np
import matplotlib.pyplot as plt
import cv2
import onnxruntime as ort
from PIL import Image

from utils.EdgeTAM_image_predictor_onnx import ImagePredictor

np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Smooth the contours before drawing
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
    save_dir="./results",  # output directory
    base_name="mask",      # output filename prefix
):
    """
    Save the segmentation results to files instead of displaying them.

    Args:
        save_dir: output directory (created automatically)
        base_name: filename prefix, e.g. "mask" -> "mask_1.png"
    """
    os.makedirs(save_dir, exist_ok=True)

    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)

        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())

        if box_coords is not None:
            show_box(box_coords, plt.gca())

        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)

        plt.axis('off')

        # Save the figure (no plt.show())
        save_path = os.path.join(save_dir, f"{base_name}_{i+1}.png")
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()  # release figure memory
        print(f"✅ Saved: {save_path}")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--image_path", type=str, default="./examples/images/truck.jpg", help="Path to the input image.")
    argparser.add_argument("--model_path", type=str, default="./onnx_models", help="Path to the ImagePredictor model.")
    argparser.add_argument("--save_dir", type=str, default="./results", help="Directory to save the output images.")
    argparser.add_argument("--input_box", type=str, default="425,600,700,875", help="Input box coordinates as x1,y1,x2,y2")
    argparser.add_argument("--input_mask", type=str, default=None, help="Path to the input mask numpy file.")
    argparser.add_argument("--input_point_coords", type=str, default="575,750", help="Input point coordinates as x1,y1 or x1,y1:x2,y2")
    argparser.add_argument("--input_point_labels", type=str, default="0", help="Input point labels as 1 or 0 or 1:0")

    args = argparser.parse_args()

    # Load the image
    image = np.array(Image.open(args.image_path).convert("RGB"))

    predictor = ImagePredictor(args.model_path)
    predictor.set_image(image)

    # Build the input prompts
    if args.input_mask is not None:
        input_mask = np.load(args.input_mask)
    else:
        input_mask = np.zeros((1, 256, 256), dtype=np.float32)

    if args.input_box is not None:
        input_box = np.array([int(x) for x in args.input_box.split(",")])
    else:
        input_box = None

    if args.input_point_coords is not None:
        input_point_coords = np.array([[int(coord) for coord in point.split(",")] for point in args.input_point_coords.split(":")])
    else:
        input_point_coords = None

    if args.input_point_labels is not None:
        input_point_labels = np.array([int(label) for label in args.input_point_labels.split(":")])
    else:
        input_point_labels = None

    if input_box is None and input_point_coords is None:
        raise ValueError("At least one of input_box or input_point_coords must be provided.")

    # Example prompts:
    # box only:
    #   input_box = np.array([75, 275, 1725, 850])
    # points only:
    #   input_point_coords = np.array([[500, 375], [1125, 625]])
    #   input_point_labels = np.array([1, 1])
    # point + box:
    #   input_box = np.array([425, 600, 700, 875])
    #   input_point_coords = np.array([[575, 750]])
    #   input_point_labels = np.array([0])

    # Predict masks
    masks, scores, logits = predictor.predict(
        point_coords=input_point_coords,
        point_labels=input_point_labels,
        box=input_box,
        mask_input=input_mask,
        multimask_output=False,
    )

    # Sort results by score, highest first
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]
    print(scores)

    # Save the visualized results
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point_coords,
        box_coords=input_box,
        input_labels=input_point_labels,
        borders=True,
    )
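Both scripts color the binary mask in `show_mask()` with a single NumPy broadcast: an (H, W, 1) mask times a (1, 1, 4) RGBA color gives an (H, W, 4) overlay in which unmasked pixels stay fully transparent. A toy 2x2 example shows the shapes involved:

```python
import numpy as np

mask = np.array([[1, 0],
                 [0, 1]], dtype=np.uint8)
color = np.array([30/255, 144/255, 255/255, 0.6])  # RGBA used by the scripts

h, w = mask.shape
overlay = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
print(overlay.shape)           # (2, 2, 4)
print(overlay[0, 1].tolist())  # [0.0, 0.0, 0.0, 0.0]  (unmasked pixel)
```

The 0.6 alpha channel is what lets the underlying image remain visible through the colored mask when matplotlib composites the two `imshow` layers.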
requirements.txt ADDED
@@ -0,0 +1,5 @@
numpy
opencv-python
onnxruntime
albumentations
matplotlib