wli1995 committed on
Commit
8c50c7c
·
verified ·
1 Parent(s): ff5b345

Upload 5 files

Browse files
utils/EdgeTAM_image_predictor.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ from PIL.Image import Image
13
+ from utils.transforms import SAM2Transforms, trunc_normal_
14
+ # import onnxruntime as ort
15
+ import axengine as ort
16
+ import cv2
17
+ import os
18
+
19
class ImagePredictor:
    """EdgeTAM image predictor running the exported ``.axmodel`` graphs.

    Computes the image embedding once with :meth:`set_image`, then serves
    repeated, efficient mask predictions from point/box/mask prompts via
    :meth:`predict`.
    """

    def __init__(
        self,
        model_path,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        resolution=1024,
        **kwargs,
    ) -> None:
        """
        Arguments:
            model_path (str): Directory holding the four EdgeTAM ``.axmodel``
                files plus the precomputed ``dense_embeddings_no_mask.npy``.
            mask_threshold (float): The threshold to use when converting mask
                logits to binary masks. Masks are thresholded at 0 by default.
            max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
                the maximum area of max_hole_area in low_res_masks.
            max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small
                sprinkles up to the maximum area of max_sprinkle_area in low_res_masks.
            resolution (int): Square side length the model input is resized to.
        """
        super().__init__()

        print("Loading EdgeTAM Onnx models...")
        self.image_encoder = ort.InferenceSession(f"{model_path}/edgetam_image_encoder.axmodel")
        self.prompt_encoder = ort.InferenceSession(f"{model_path}/edgetam_prompt_encoder.axmodel")
        self.prompt_mask_encoder = ort.InferenceSession(f"{model_path}/edgetam_prompt_mask_encoder.axmodel")
        self.mask_decoder = ort.InferenceSession(f"{model_path}/edgetam_mask_decoder.axmodel")

        self.model_path = model_path

        self._transforms = SAM2Transforms(
            resolution=resolution,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold
        self.num_feature_levels = 3
        self.no_mem_embed = np.zeros((1, 1, 256))
        trunc_normal_(self.no_mem_embed, std=0.02)

        # Spatial dims of the backbone feature maps, finest to coarsest.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def set_image(
        self,
        image: "Union[np.ndarray, Image]",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
            image (np.ndarray or PIL Image): The input image to embed in RGB
                format, HWC if np.ndarray (WHC if PIL Image), values in [0, 255].
        """
        self.reset_predictor()
        # Record the original size so masks can be resized back later.
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        else:
            # PIL images expose (width, height); store as (H, W).
            # (Fixed: _orig_hw was previously left as None for PIL input.)
            self._orig_hw = [(image.height, image.width)]

        input_image = self._transforms(image).astype(np.float32)  # 3xHxW
        input_image = input_image[None, ...]  # add batch dim -> 1x3xHxW

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        vision_feats = self.image_encoder.run(None, {"input_image": input_image.astype(np.float32)})

        # Each level comes back flattened as (HW, 1, C); restore the spatial
        # layout and move channels first -> (1, C, H, W).
        feats = [
            np.transpose(feat[:, 0, :].reshape(H, W, feat.shape[-1]), (2, 0, 1))[np.newaxis, :]
            for feat, (H, W) in zip(reversed(vision_feats), reversed(self._bb_feat_sizes))
        ][::-1]

        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
            point_coords (np.ndarray or None): Nx2 array of point prompts, (X, Y) in pixels.
            point_labels (np.ndarray or None): Length-N labels; 1 = foreground, 0 = background.
            box (np.ndarray or None): Length-4 box prompt in XYXY format.
            mask_input (np.ndarray): Low-resolution mask input (1xHxW, H=W=256),
                typically from a previous prediction iteration.
            multimask_output (bool): If true, the model returns three candidate masks.
            return_logits (bool): If true, return un-thresholded mask logits
                instead of binary masks.
            normalize_coords (bool): If true, point coordinates are normalized
                to [0, 1]; point_coords is then expected wrt. image dimensions.

        Returns:
            (np.ndarray): Output masks in CxHxW format, (H, W) = original image size.
            (np.ndarray): Length-C mask quality predictions.
            (np.ndarray): CxHxW low-resolution logits (H=W=256) usable as
                ``mask_input`` in a subsequent call.

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Cast all prompts to float32, the dtype the exported models expect.
        point_coords = point_coords.astype(np.float32) if point_coords is not None else None
        point_labels = point_labels.astype(np.float32) if point_labels is not None else None
        box = box.astype(np.float32) if box is not None else None
        mask_input = mask_input.astype(np.float32) if mask_input is not None else None

        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        # Strip the batch dimension from the per-batch model outputs.
        masks_np = masks
        iou_predictions_np = iou_predictions[0]
        low_res_masks_np = low_res_masks[0]
        return masks_np, iou_predictions_np, low_res_masks_np

    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):
        """Transform raw prompts into model-frame, batched arrays.

        Returns ``(mask_logits, unnorm_coords, labels, unnorm_box)`` where any
        prompt that was not supplied stays ``None``.
        """
        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            # Promote unbatched (N, 2)/(N,) prompts to batch size 1.
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[np.newaxis, ...], point_labels[np.newaxis, ...]
        if box is not None:
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            if len(mask_logits.shape) == 3:
                mask_logits = mask_logits[np.newaxis, :, :, :]

        return mask_logits, unnorm_coords, labels, unnorm_box

    def _predict(
        self,
        point_coords,
        point_labels,
        boxes=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
        img_idx=-1,
    ):
        """
        Predict masks from prompts already transformed to the input frame.

        Arguments:
            point_coords (np.ndarray or None): BxNx2 point prompts, (X, Y) in pixels.
            point_labels (np.ndarray or None): BxN labels; 1 = foreground, 0 = background.
            boxes (np.ndarray or None): Bx4 box prompts in XYXY format.
            mask_input (np.ndarray or None): Bx1xHxW low-res mask (H=W=256) from a
                previous iteration; needs no further transformation.
            multimask_output (bool): If true, the model returns three candidate masks.
            return_logits (bool): If true, return un-thresholded mask logits.
            img_idx (int): Index into the stored features/sizes (batch mode).

        Returns:
            (np.ndarray): Masks in BxCxHxW format at the original image size.
            (np.ndarray): BxC mask quality predictions.
            (np.ndarray): BxCxHxW low-res logits (H=W=256) for a next iteration.

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts: merge boxes and points into one point list.
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            # Box corners use the reserved point labels 2 (top-left) and
            # 3 (bottom-right), one (2, 3) row per box.
            # (Fixed: `repeat(B, 1)` repeated along axis 1, yielding a
            # (1, 2B) array instead of the intended (B, 2).)
            box_labels = np.tile(np.array([[2, 3]], dtype=np.float32), (boxes.shape[0], 1))
            # We merge "boxes" and "points" into a single "concat_points" input
            # (boxes added at the beginning) to the prompt encoder.
            if concat_points is not None:
                concat_coords = np.concatenate([box_coords, concat_points[0]], axis=1)
                concat_labels = np.concatenate([box_labels, concat_points[1]], axis=1)
                concat_points = (concat_coords, concat_labels)
            else:
                print("Only box input provided")
                concat_points = (box_coords, box_labels)

        # The exported prompt encoder takes a fixed 4-point input: tile the
        # prompts and keep the first 4.
        # (Fixed: concat_points was dereferenced unconditionally, crashing
        # when neither points nor boxes were supplied.)
        if concat_points is not None:
            input_coords = np.tile(concat_points[0], (4, 1))[:, :4, :]
            input_labels = np.tile(concat_points[1], (4))[:, :4]
        else:
            input_coords = np.array([], dtype=np.float32)
            input_labels = np.array([], dtype=np.float32)

        if mask_input is None or not mask_input.any():
            # No usable mask prompt.
            # (Fixed: `mask_input.all() == 0` raised AttributeError for None
            # and misfired whenever any single element was zero.)
            print("Get dense_embeddings_no_mask")
            sparse_embeddings = self.prompt_encoder.run(
                None,
                {
                    "point_coords": input_coords,
                    "point_labels": input_labels,
                },
            )[0]
            # The "no mask" dense embedding is a constant; load the
            # precomputed tensor instead of running the mask encoder.
            dense_embeddings = np.load(f"{self.model_path}/dense_embeddings_no_mask.npy")
        else:
            print("Get dense_embeddings_mask")
            sparse_embeddings = self.prompt_encoder.run(
                None,
                {
                    "point_coords": input_coords,
                    "point_labels": input_labels,
                },
            )[0]
            dense_embeddings = self.prompt_mask_encoder.run(
                None,
                {"input.1": mask_input},
            )[0]

        # Per-image high-resolution features for the decoder.
        high_res_features = [
            feat_level[img_idx][np.newaxis, ...]
            for feat_level in self._features["high_res_feats"]
        ]

        low_res_masks, iou_predictions = self.mask_decoder.run(
            None,
            {
                "image_embeddings": self._features["image_embed"][img_idx][np.newaxis, ...],
                "sparse_prompt_embeddings": sparse_embeddings,
                "dense_prompt_embeddings": dense_embeddings,
                "high_res_feat_0": high_res_features[0],
                "high_res_feat_1": high_res_features[1],
            },
        )

        # Upscale the masks to the original image resolution.
        mask = low_res_masks[0].transpose(1, 2, 0)  # CxHxW -> HxWxC
        resize_masks = cv2.resize(
            mask,
            (self._orig_hw[img_idx][1], self._orig_hw[img_idx][0]),  # cv2 wants (W, H)
            interpolation=cv2.INTER_LINEAR,
        )

        resize_masks = resize_masks[np.newaxis, ...]  # add batch dim
        resize_masks = np.clip(resize_masks, -32.0, 32.0)  # clamp extreme logits

        if not return_logits:
            resize_masks = resize_masks > self.mask_threshold

        return resize_masks, iou_predictions, low_res_masks

    def get_image_embedding(self):
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self._features is not None
        ), "Features must exist if an image has been set."
        return self._features["image_embed"]

    def reset_predictor(self) -> None:
        """Resets the image embeddings and other state variables."""
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features.

        Keeps the last ``num_feature_levels`` FPN levels and flattens each
        NxCxHxW map to HWxNxC. Returns ``(backbone_out, vision_feats,
        vision_pos_embeds, feat_sizes)``.
        """
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # Flatten NxCxHxW to HWxNxC.
        vision_feats = [x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
utils/EdgeTAM_image_predictor_onnx.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ from PIL.Image import Image
13
+ from utils.transforms import SAM2Transforms, trunc_normal_
14
+ import onnxruntime as ort
15
+ # import axengine as ort
16
+ import cv2
17
+ import os
18
+
19
class ImagePredictor:
    """EdgeTAM image predictor running the exported ``.onnx`` graphs.

    Computes the image embedding once with :meth:`set_image`, then serves
    repeated, efficient mask predictions from point/box/mask prompts via
    :meth:`predict`.
    """

    def __init__(
        self,
        model_path,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        resolution=1024,
        **kwargs,
    ) -> None:
        """
        Arguments:
            model_path (str): Directory holding the four EdgeTAM ``.onnx``
                files plus the precomputed ``dense_embeddings_no_mask.npy``.
            mask_threshold (float): The threshold to use when converting mask
                logits to binary masks. Masks are thresholded at 0 by default.
            max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
                the maximum area of max_hole_area in low_res_masks.
            max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small
                sprinkles up to the maximum area of max_sprinkle_area in low_res_masks.
            resolution (int): Square side length the model input is resized to.
        """
        super().__init__()

        print("Loading EdgeTAM Onnx models...")
        self.image_encoder = ort.InferenceSession(f"{model_path}/edgetam_image_encoder.onnx")
        self.prompt_encoder = ort.InferenceSession(f"{model_path}/edgetam_prompt_encoder.onnx")
        self.prompt_mask_encoder = ort.InferenceSession(f"{model_path}/edgetam_prompt_mask_encoder.onnx")
        self.mask_decoder = ort.InferenceSession(f"{model_path}/edgetam_mask_decoder.onnx")

        self.model_path = model_path

        self._transforms = SAM2Transforms(
            resolution=resolution,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
            onnx=True
        )
        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold
        self.num_feature_levels = 3
        self.no_mem_embed = np.zeros((1, 1, 256))
        trunc_normal_(self.no_mem_embed, std=0.02)

        # Spatial dims of the backbone feature maps, finest to coarsest.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def set_image(
        self,
        image: "Union[np.ndarray, Image]",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
            image (np.ndarray or PIL Image): The input image to embed in RGB
                format, HWC if np.ndarray (WHC if PIL Image), values in [0, 255].
        """
        self.reset_predictor()
        # Record the original size so masks can be resized back later.
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        else:
            # PIL images expose (width, height); store as (H, W).
            # (Fixed: _orig_hw was previously left as None for PIL input.)
            self._orig_hw = [(image.height, image.width)]

        input_image = self._transforms(image).astype(np.float32)  # 3xHxW
        input_image = input_image[None, ...]  # add batch dim -> 1x3xHxW

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        vision_feats = self.image_encoder.run(None, {"input_image": input_image.astype(np.float32)})

        # Each level comes back flattened as (HW, 1, C); restore the spatial
        # layout and move channels first -> (1, C, H, W).
        feats = [
            np.transpose(feat[:, 0, :].reshape(H, W, feat.shape[-1]), (2, 0, 1))[np.newaxis, :]
            for feat, (H, W) in zip(reversed(vision_feats), reversed(self._bb_feat_sizes))
        ][::-1]

        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
            point_coords (np.ndarray or None): Nx2 array of point prompts, (X, Y) in pixels.
            point_labels (np.ndarray or None): Length-N labels; 1 = foreground, 0 = background.
            box (np.ndarray or None): Length-4 box prompt in XYXY format.
            mask_input (np.ndarray): Low-resolution mask input (1xHxW, H=W=256),
                typically from a previous prediction iteration.
            multimask_output (bool): If true, the model returns three candidate masks.
            return_logits (bool): If true, return un-thresholded mask logits
                instead of binary masks.
            normalize_coords (bool): If true, point coordinates are normalized
                to [0, 1]; point_coords is then expected wrt. image dimensions.

        Returns:
            (np.ndarray): Output masks in CxHxW format, (H, W) = original image size.
            (np.ndarray): Length-C mask quality predictions.
            (np.ndarray): CxHxW low-resolution logits (H=W=256) usable as
                ``mask_input`` in a subsequent call.

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Cast all prompts to float32, the dtype the exported models expect.
        point_coords = point_coords.astype(np.float32) if point_coords is not None else None
        point_labels = point_labels.astype(np.float32) if point_labels is not None else None
        box = box.astype(np.float32) if box is not None else None
        mask_input = mask_input.astype(np.float32) if mask_input is not None else None

        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        # Strip the batch dimension from the per-batch model outputs.
        masks_np = masks
        iou_predictions_np = iou_predictions[0]
        low_res_masks_np = low_res_masks[0]
        return masks_np, iou_predictions_np, low_res_masks_np

    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):
        """Transform raw prompts into model-frame, batched arrays.

        Returns ``(mask_logits, unnorm_coords, labels, unnorm_box)`` where any
        prompt that was not supplied stays ``None``.
        """
        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            # Promote unbatched (N, 2)/(N,) prompts to batch size 1.
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[np.newaxis, ...], point_labels[np.newaxis, ...]
        if box is not None:
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            if len(mask_logits.shape) == 3:
                mask_logits = mask_logits[np.newaxis, :, :, :]

        return mask_logits, unnorm_coords, labels, unnorm_box

    def _predict(
        self,
        point_coords,
        point_labels,
        boxes=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
        img_idx=-1,
    ):
        """
        Predict masks from prompts already transformed to the input frame.

        Arguments:
            point_coords (np.ndarray or None): BxNx2 point prompts, (X, Y) in pixels.
            point_labels (np.ndarray or None): BxN labels; 1 = foreground, 0 = background.
            boxes (np.ndarray or None): Bx4 box prompts in XYXY format.
            mask_input (np.ndarray or None): Bx1xHxW low-res mask (H=W=256) from a
                previous iteration; needs no further transformation.
            multimask_output (bool): If true, the model returns three candidate masks.
            return_logits (bool): If true, return un-thresholded mask logits.
            img_idx (int): Index into the stored features/sizes (batch mode).

        Returns:
            (np.ndarray): Masks in BxCxHxW format at the original image size.
            (np.ndarray): BxC mask quality predictions.
            (np.ndarray): BxCxHxW low-res logits (H=W=256) for a next iteration.

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts: merge boxes and points into one point list.
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            # Box corners use the reserved point labels 2 (top-left) and
            # 3 (bottom-right), one (2, 3) row per box.
            # (Fixed: `repeat(B, 1)` repeated along axis 1, yielding a
            # (1, 2B) array instead of the intended (B, 2).)
            box_labels = np.tile(np.array([[2, 3]], dtype=np.float32), (boxes.shape[0], 1))
            # We merge "boxes" and "points" into a single "concat_points" input
            # (boxes added at the beginning) to the prompt encoder.
            if concat_points is not None:
                concat_coords = np.concatenate([box_coords, concat_points[0]], axis=1)
                concat_labels = np.concatenate([box_labels, concat_points[1]], axis=1)
                concat_points = (concat_coords, concat_labels)
            else:
                print("Only box input provided")
                concat_points = (box_coords, box_labels)

        # The exported prompt encoder takes a fixed 4-point input: tile the
        # prompts and keep the first 4.
        # (Fixed: concat_points was dereferenced unconditionally, crashing
        # when neither points nor boxes were supplied.)
        if concat_points is not None:
            input_coords = np.tile(concat_points[0], (4, 1))[:, :4, :]
            input_labels = np.tile(concat_points[1], (4))[:, :4]
        else:
            input_coords = np.array([], dtype=np.float32)
            input_labels = np.array([], dtype=np.float32)

        if mask_input is None or not mask_input.any():
            # No usable mask prompt.
            # (Fixed: `mask_input.all() == 0` raised AttributeError for None
            # and misfired whenever any single element was zero.)
            print("Get dense_embeddings_no_mask")
            sparse_embeddings = self.prompt_encoder.run(
                None,
                {
                    "point_coords": input_coords,
                    "point_labels": input_labels,
                },
            )[0]
            # The "no mask" dense embedding is a constant; load the
            # precomputed tensor instead of running the mask encoder.
            dense_embeddings = np.load(f"{self.model_path}/dense_embeddings_no_mask.npy")

            # Debug dump of the prompt tensors (kept from the original code).
            np.save(f"{self.model_path}/point_coords.npy", input_coords)
            np.save(f"{self.model_path}/point_labels.npy", input_labels)
        else:
            print("Get dense_embeddings_mask")
            sparse_embeddings = self.prompt_encoder.run(
                None,
                {
                    "point_coords": input_coords,
                    "point_labels": input_labels,
                },
            )[0]
            dense_embeddings = self.prompt_mask_encoder.run(
                None,
                {"input.1": mask_input},
            )[0]

        # Per-image high-resolution features for the decoder.
        high_res_features = [
            feat_level[img_idx][np.newaxis, ...]
            for feat_level in self._features["high_res_feats"]
        ]

        low_res_masks, iou_predictions = self.mask_decoder.run(
            None,
            {
                "image_embeddings": self._features["image_embed"][img_idx][np.newaxis, ...],
                "sparse_prompt_embeddings": sparse_embeddings,
                "dense_prompt_embeddings": dense_embeddings,
                "high_res_feat_0": high_res_features[0],
                "high_res_feat_1": high_res_features[1],
            },
        )

        # Upscale the masks to the original image resolution.
        mask = low_res_masks[0].transpose(1, 2, 0)  # CxHxW -> HxWxC
        resize_masks = cv2.resize(
            mask,
            (self._orig_hw[img_idx][1], self._orig_hw[img_idx][0]),  # cv2 wants (W, H)
            interpolation=cv2.INTER_LINEAR,
        )

        resize_masks = resize_masks[np.newaxis, ...]  # add batch dim
        resize_masks = np.clip(resize_masks, -32.0, 32.0)  # clamp extreme logits

        if not return_logits:
            resize_masks = resize_masks > self.mask_threshold

        return resize_masks, iou_predictions, low_res_masks

    def get_image_embedding(self):
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).

        Raises:
            RuntimeError: If no image has been set with :meth:`set_image`.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self._features is not None
        ), "Features must exist if an image has been set."
        return self._features["image_embed"]

    def reset_predictor(self) -> None:
        """Resets the image embeddings and other state variables."""
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features.

        Keeps the last ``num_feature_levels`` FPN levels and flattens each
        NxCxHxW map to HWxNxC. Returns ``(backbone_out, vision_feats,
        vision_pos_embeds, feat_sizes)``.
        """
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # Flatten NxCxHxW to HWxNxC.
        vision_feats = [x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
utils/__pycache__/EdgeTAM_image_predictor.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
utils/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (7.15 kB). View file
 
utils/transforms.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+
9
+ import albumentations as A
10
+ import numpy as np
11
+ from scipy.stats import truncnorm
12
+ import cv2
13
+
14
class SAM2Transforms():
    """numpy-based image / prompt transforms for SAM2 (no torch dependency)."""

    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0, onnx=False
    ):
        """
        Transforms for SAM2.

        Args:
            resolution (int): target square side fed to the model.
            mask_threshold (float): logit threshold for binarizing masks.
            max_hole_area (float): max hole area to fill in post-processing
                (kept for interface compatibility; unused here).
            max_sprinkle_area (float): max speckle area to remove
                (kept for interface compatibility; unused here).
            onnx (bool): True -> resize + ImageNet-normalize to float
                (ONNX path); False -> resize only, keep uint8 pixels
                (presumably the quantized NPU path — confirm with the model).
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        # Resize first, then normalize with ImageNet statistics
        # (input is 0-255 uint8 RGB, hence max_pixel_value=255.0).
        self.transforms = A.Compose([
            A.Resize(height=resolution, width=resolution),
            A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        max_pixel_value=255.0,
                        p=1.0)
        ])
        self.onnx = onnx

    def __call__(self, x):
        """Transform one HWC uint8 RGB image into a CHW model input."""
        if self.onnx:
            x_normal = self.transforms(image=x)['image']
            return x_normal.transpose(2, 0, 1)
        else:
            # Non-ONNX path: only resize; pixels stay uint8.
            x_normal = cv2.resize(x, (self.resolution, self.resolution), interpolation=cv2.INTER_LINEAR)
            return x_normal.transpose(2, 0, 1)

    def forward_batch(self, img_list):
        """Transform a list of HWC uint8 RGB images into one NCHW batch.

        Fix: albumentations' ``Compose`` must be called with the ``image=``
        keyword and returns a dict — the previous code passed the array
        positionally and stacked raw dicts, which raised at runtime. We now
        route each image through ``__call__`` so the batch path also honors
        ``self.onnx`` exactly like the single-image path.
        """
        chw_imgs = [self(img) for img in img_list]
        return np.stack(chw_imgs, axis=0)

    def transform_coords(
        self, coords, normalize=False, orig_hw=None
    ):
        """
        Map point prompts into the model's input coordinate frame.

        Expects an array with length 2 in the last dimension (x, y). If
        ``normalize`` is True, the coordinates are absolute pixel positions
        and ``orig_hw`` (height, width) is required: they are first scaled
        to [0, 1] and then multiplied by ``self.resolution``.

        Returns:
            Coordinates in the range [0, resolution] expected by SAM2.
        """
        if normalize:
            assert orig_hw is not None
            h, w = orig_hw
            coords = coords.copy()  # avoid mutating the caller's array
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h
        coords = coords * self.resolution
        return coords

    def transform_boxes(
        self, boxes, normalize=False, orig_hw=None
    ):
        """
        Map Bx4 (x1, y1, x2, y2) boxes into the model frame as Bx2x2 corner
        points. Same normalization semantics as ``transform_coords``.
        """
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes

    # NOTE(review): a large commented-out torch-based ``postprocess_masks``
    # (hole filling / sprinkle removal via connected components) was removed
    # here as dead code — it referenced ``torch`` and ``F`` which are not
    # imported in this file.
124
+
125
def trunc_normal_(arr, std=0.02, mean=0.0):
    """
    In-place truncated-normal initialization of a numpy array.

    Values are drawn from N(mean, std) truncated to
    [mean - 2*std, mean + 2*std] and written into ``arr``.
    """
    # truncnorm takes the bounds in units of the standard deviation,
    # so a two-sigma truncation is simply [-2, +2].
    lower, upper = -2.0, 2.0
    arr[:] = truncnorm.rvs(lower, upper, loc=mean, scale=std, size=arr.shape)