|
|
""" |
|
|
SAM2 Base Segmenter |
|
|
Adapted from MatAnyone demo |
|
|
""" |
|
|
|
|
|
import sys

# The SAM2 package is loaded from a local source checkout rather than an
# installed distribution, so its directory has to be on sys.path before the
# `sam2` import below. NOTE(review): this absolute path is machine-specific;
# adjust it (or install sam2 properly) when running elsewhere.
_SAM2_REPO_PATH = "/home/cvlab19/project/samuel/CVPR/sam2"
if _SAM2_REPO_PATH not in sys.path:
    # Guarded so re-executing this module does not pile up duplicate entries.
    sys.path.append(_SAM2_REPO_PATH)

import torch
import numpy as np
from sam2.build_sam import build_sam2_video_predictor
|
|
|
|
|
|
|
|
class BaseSegmenter:
    """Holder for a SAM2 video predictor plus the image currently being segmented.

    The constructor builds the predictor via ``build_sam2_video_predictor``.
    ``set_image`` / ``reset_image`` manage the stored image, and ``predict``
    currently returns placeholder (all-zero) results sized to that image.
    """

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM2 segmenter

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint (forwarded as ``ckpt_path``)
            model_type: SAM2 model config file (forwarded as ``config_file``)
            device: Device to run on (forwarded as ``device``)
        """
        self.device = device
        self.model_type = model_type

        self.sam_predictor = build_sam2_video_predictor(
            config_file=model_type,
            ckpt_path=SAM_checkpoint,
            device=device,
        )

        # NOTE: the attribute name is misspelled ("orignal"), but it is kept
        # as-is because code outside this class may read it directly.
        self.orignal_image = None
        self.inference_state = None

    def set_image(self, image: np.ndarray):
        """Store ``image`` as the current image for segmentation."""
        self.orignal_image = image

    def reset_image(self):
        """Forget the current image and clear the stored inference state."""
        self.orignal_image = None
        self.inference_state = None

    def predict(self, prompts, prompt_type, multimask=True):
        """
        Return placeholder segmentation results sized to the current image.

        This implementation does not run the SAM2 predictor: ``prompts``,
        ``prompt_type`` and ``multimask`` are accepted but not used, and the
        outputs are constant arrays whose spatial size comes from the first
        two dimensions of the image given to ``set_image``.

        Args:
            prompts: Dictionary with point_coords, point_labels, mask_input
                (currently unused)
            prompt_type: 'point' or 'both' (currently unused)
            multimask: Whether to return multiple masks (currently unused)

        Returns:
            masks, scores, logits:
                masks  -- bool array of shape (1, H, W), all False
                scores -- float array ``[1.0]``
                logits -- float32 array of shape (1, H, W), all zeros

        Raises:
            RuntimeError: If no image has been set (``set_image`` was never
                called, or ``reset_image`` was called since).
        """
        image = self.orignal_image
        if image is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object has no attribute 'shape'".
            raise RuntimeError(
                "No image has been set; call set_image() before predict()."
            )

        height, width = image.shape[:2]

        # Leading axis of length 1 stands for the single placeholder mask.
        masks = np.zeros((1, height, width), dtype=bool)
        scores = np.array([1.0])
        logits = np.zeros((1, height, width), dtype=np.float32)

        return masks, scores, logits
|
|
|