| import torch | |
| import numpy as np | |
| import mmcv | |
| from mmengine.visualization import Visualizer | |
| from third_parts.sam2.build_sam import build_sam2 | |
| from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor | |
| from mmdet.structures.mask import bitmap_to_polygon | |
# Image the demo segments; read with mmcv.imread in the script body.
IMG_PATH = "assets/view.jpg"

# Model config name and checkpoint path, passed to build_sam2 in that order.
MODEL_CFG = "sam2_hiera_l.yaml"
MODEL_CKPT = "work_dirs/ckpt/sam2_hiera_large.pt"
| def prepare(): | |
| torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
if __name__ == '__main__':
    # Demo: segment the object under a single click with SAM2 and render the
    # best-scoring mask (outline + fill) together with the click position.
    prepare()

    sam2_model = build_sam2(MODEL_CFG, MODEL_CKPT, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)

    # mmcv.imread returns BGR by default, while both the SAM2 image predictor
    # and mmengine's Visualizer are documented as taking RGB arrays. Load the
    # image as RGB so the model and the drawing code see correct colours.
    image = mmcv.imread(IMG_PATH, channel_order='rgb')
    predictor.set_image(image)

    # A single point prompt with label 1 (a foreground click in the SAM
    # prompt convention).
    input_point = np.array([[500, 475]])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Reorder all three outputs so the highest-scoring candidate comes first.
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    visualizer = Visualizer(image=image)

    # Keep only the top candidate; the slice (not an index) preserves the
    # leading axis so the loop and draw_binary_masks still get a stack.
    masks = masks.astype(bool)
    masks = masks[0:1]

    # Collect the contour polygons of every kept mask.
    polygons = []
    for mask in masks:
        contours, _ = bitmap_to_polygon(mask)
        polygons.extend(contours)

    visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
    visualizer.draw_binary_masks(masks, alphas=0.8)
    visualizer.draw_points(input_point, 'r', marker='*')

    # NOTE: `result` is in RGB channel order (same as `image`); convert to BGR
    # before passing it to BGR-based writers such as mmcv.imwrite or cv2.
    result = visualizer.get_image()