|
|
""" |
|
|
SAM2 Interaction Tools |
|
|
Handles SAM2 mask generation with user clicks |
|
|
""" |
|
|
|
|
|
import sys |
|
|
sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2") |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from .base_segmenter import BaseSegmenter |
|
|
from .painter import mask_painter, point_painter |
|
|
|
|
|
|
|
|
# Overlay styling used when painting the mask and the click markers.

# Default mask colour ID. NOTE(review): first_frame_click takes its own
# ``mask_color`` argument (same default), so this module-level value is not
# read by the code visible in this file -- confirm whether other modules
# import it before removing.
mask_color = 3
# Alpha value handed to mask_painter for the mask overlay.
mask_alpha = 0.7

# Contour colour and width, passed to both mask_painter and point_painter.
contour_color = 1
contour_width = 5

# Point marker colours.
# NOTE(review): first_frame_click draws *positive* clicks with
# ``point_color_ne`` and *negative* clicks with ``point_color_ps``; the
# suffixes read as if they were the other way round -- confirm intent.
point_color_ne = 8
point_color_ps = 50
# Alpha and radius handed to point_painter for every click marker.
point_alpha = 0.9
point_radius = 15
|
|
|
|
|
|
|
|
class SamControler:
    """Turn user clicks into a SAM2 mask and a painted preview image.

    Wraps a ``BaseSegmenter`` (kept on ``self.sam_controler``) and adds the
    click-prompt workflow plus the mask/point overlay rendering.
    """

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM controller

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
        self.device = device

    def _predict_best(self, prompts, mode, multimask):
        """Run one segmenter prediction and keep the highest-scoring candidate.

        Returns:
            (mask, logit) taken at the index of the maximum score.
        """
        masks, scores, logits = self.sam_controler.predict(prompts, mode, multimask)
        best = np.argmax(scores)
        return masks[best], logits[best, :, :]

    @staticmethod
    def _paint_points(canvas, points, labels, keep, color):
        """Draw the clicks selected by the boolean array ``keep`` onto ``canvas``.

        Returns ``canvas`` unchanged when no click is selected.
        """
        # Same selection expression as before: argwhere yields (K, 1) indices,
        # so the fancy-indexed result is (K, 1, 2) and axis 1 is squeezed away.
        selected = np.squeeze(points[np.argwhere(keep)], axis=1)
        if len(selected) == 0:
            return canvas
        return point_painter(
            canvas,
            selected,
            color,
            point_alpha,
            point_radius,
            contour_color,
            contour_width,
        )

    def first_frame_click(self, image: np.ndarray, points: np.ndarray,
                          labels: np.ndarray, multimask=True, mask_color=3):
        """
        Generate mask from clicks on first frame

        ``image`` is only used for painting here; this method never hands it
        to the segmenter. NOTE(review): presumably the caller has already set
        the image on the underlying segmenter -- confirm against the caller.

        Args:
            image: np.ndarray, (H, W, 3), RGB image
            points: np.ndarray, (N, 2), [x, y] coordinates
            labels: np.ndarray, (N,), 1 for positive, 0 for negative
            multimask: bool, whether to generate multiple masks
            mask_color: int, color ID for mask overlay

        Returns:
            mask: np.ndarray, (H, W), binary mask
            logit: np.ndarray, (H, W), mask logits
            painted_image: PIL.Image, visualization with mask and points

        Raises:
            IndexError: if ``labels`` is empty (the last label is inspected).
        """
        # Accept plain lists as well as arrays; arrays pass through untouched.
        # Previously a list survived prediction and then failed at ``labels > 0``.
        points = np.asarray(points)
        labels = np.asarray(labels)

        # First pass: point prompts only. This used to be duplicated verbatim
        # in both branches of the if/else.
        point_prompts = {
            'point_coords': points,
            'point_labels': labels,
        }
        mask, logit = self._predict_best(point_prompts, 'point', multimask)

        # When the most recent click is positive, run a second pass that feeds
        # the first-pass logits back in as a mask prompt alongside the points.
        if labels[-1] == 1:
            refine_prompts = {
                'point_coords': points,
                'point_labels': labels,
                'mask_input': logit[None, :, :]
            }
            mask, logit = self._predict_best(refine_prompts, 'both', multimask)

        # Render the mask first, then overlay the click markers.
        painted_image = mask_painter(
            image,
            mask.astype('uint8'),
            mask_color,
            mask_alpha,
            contour_color,
            contour_width
        )

        # NOTE(review): positive clicks use point_color_ne and negative clicks
        # use point_color_ps, exactly as in the original code; the constant
        # names suggest the opposite pairing -- confirm before changing.
        painted_image = self._paint_points(
            painted_image, points, labels, labels > 0, point_color_ne)
        painted_image = self._paint_points(
            painted_image, points, labels, labels < 1, point_color_ps)

        painted_image = Image.fromarray(painted_image)

        return mask, logit, painted_image
|
|
|