""" SAM2 Interaction Tools Handles SAM2 mask generation with user clicks """ import sys sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2") import numpy as np from PIL import Image from .base_segmenter import BaseSegmenter from .painter import mask_painter, point_painter mask_color = 3 mask_alpha = 0.7 contour_color = 1 contour_width = 5 point_color_ne = 8 # positive points point_color_ps = 50 # negative points point_alpha = 0.9 point_radius = 15 class SamControler: def __init__(self, SAM_checkpoint, model_type, device): """ Initialize SAM controller Args: SAM_checkpoint: Path to SAM2 checkpoint model_type: SAM2 model config file device: Device to run on """ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) self.device = device def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True, mask_color=3): """ Generate mask from clicks on first frame Args: image: np.ndarray, (H, W, 3), RGB image points: np.ndarray, (N, 2), [x, y] coordinates labels: np.ndarray, (N,), 1 for positive, 0 for negative multimask: bool, whether to generate multiple masks mask_color: int, color ID for mask overlay Returns: mask: np.ndarray, (H, W), binary mask logit: np.ndarray, (H, W), mask logits painted_image: PIL.Image, visualization with mask and points """ # Check if we have positive clicks neg_flag = labels[-1] if neg_flag == 1: # Has positive click # First pass with points only prompts = { 'point_coords': points, 'point_labels': labels, } masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] # Refine with mask input prompts = { 'point_coords': points, 'point_labels': labels, 'mask_input': logit[None, :, :] } masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] else: # Only positive clicks prompts = { 'point_coords': points, 'point_labels': labels, } masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] # Paint mask on image painted_image = mask_painter( image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width ) # Paint positive points (label > 0) positive_points = np.squeeze(points[np.argwhere(labels > 0)], axis=1) if len(positive_points) > 0: painted_image = point_painter( painted_image, positive_points, point_color_ne, point_alpha, point_radius, contour_color, contour_width ) # Paint negative points (label < 1) negative_points = np.squeeze(points[np.argwhere(labels < 1)], axis=1) if len(negative_points) > 0: painted_image = point_painter( painted_image, negative_points, point_color_ps, point_alpha, point_radius, contour_color, contour_width ) painted_image = Image.fromarray(painted_image) return mask, logit, painted_image