"""Segment an image with SAM2 from a single point prompt and draw the result.

The script loads a SAM2 checkpoint, prompts it with one foreground point,
keeps the highest-scoring mask and renders its outline, its filled area and
the prompt point with an mmengine ``Visualizer``.
"""
import torch
import numpy as np
import mmcv
from mmengine.visualization import Visualizer

from third_parts.sam2.build_sam import build_sam2
from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor

from mmdet.structures.mask import bitmap_to_polygon

IMG_PATH = 'assets/view.jpg'
MODEL_CKPT = "work_dirs/ckpt/sam2_hiera_large.pt"
MODEL_CFG = "sam2_hiera_l.yaml"


def prepare() -> None:
    """Enable bfloat16 autocast on CUDA and allow TF32 kernels.

    The autocast context is entered manually and never exited, so it stays
    active for the rest of the process; this is intentional for a one-shot
    script.
    """
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


if __name__ == '__main__':
    prepare()

    sam2_model = build_sam2(MODEL_CFG, MODEL_CKPT, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)

    # mmcv.imread defaults to BGR channel order, while both
    # SAM2ImagePredictor.set_image and mmengine's Visualizer document RGB
    # input, so request RGB explicitly.
    image = mmcv.imread(IMG_PATH, channel_order='rgb')
    predictor.set_image(image)

    # A single prompt point; label 1 is passed as its point label.
    input_point = np.array([[500, 475]])
    input_label = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    # Reorder all outputs from highest to lowest score.
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    visualizer = Visualizer(image=image)

    # Keep only the best-scoring mask, preserving the leading axis so the
    # result can still be iterated and drawn as a batch of one.
    masks = masks.astype(bool)
    masks = masks[0:1]

    polygons = []
    for mask in masks:
        contours, _ = bitmap_to_polygon(mask)
        polygons.extend(contours)

    visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
    visualizer.draw_binary_masks(masks, alphas=0.8)
    visualizer.draw_points(input_point, 'r', marker='*')
    result = visualizer.get_image()