"""
SAM2 Base Segmenter
Adapted from MatAnyone demo
"""

import sys
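# Make the local SAM2 checkout importable. This path is machine-specific;
# adjust it to wherever the sam2 repository lives on your system.
sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")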

import torch
import numpy as np
from sam2.build_sam import build_sam2_video_predictor


class BaseSegmenter:
    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM2 segmenter
        
        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        self.device = device
        self.model_type = model_type
        
        # Build SAM2 video predictor
        self.sam_predictor = build_sam2_video_predictor(
            config_file=model_type,
            ckpt_path=SAM_checkpoint,
            device=device
        )
        
        self.orignal_image = None
        self.inference_state = None
    
    def set_image(self, image: np.ndarray):
        """Set the current image for segmentation"""
        self.orignal_image = image
    
    def reset_image(self):
        """Reset the current image"""
        self.orignal_image = None
        self.inference_state = None
    
    def predict(self, prompts, prompt_type, multimask=True):
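        """
        Predict masks from prompts.

        Note: this is currently a placeholder that returns an empty mask.

        Args:
            prompts: Dictionary with point_coords, point_labels, mask_input
            prompt_type: 'point' or 'both'
            multimask: Whether to return multiple masks

        Returns:
            masks: bool array of shape (1, H, W)
            scores: float array of shape (1,)
            logits: float32 array of shape (1, H, W)
        """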
        # SAM2 handles prompts through its video predictor, so this method is
        # a placeholder for now: `prompts`, `prompt_type`, and `multimask` are
        # ignored and an all-zero mask is returned.
        if self.orignal_image is None:
            raise RuntimeError("No image set; call set_image() before predict().")

        h, w = self.orignal_image.shape[:2]
        dummy_mask = np.zeros((h, w), dtype=bool)
        dummy_score = np.array([1.0])
        dummy_logit = np.zeros((h, w), dtype=np.float32)

        return np.array([dummy_mask]), dummy_score, np.array([dummy_logit])
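

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original implementation.
#
# The `predict` docstring describes `prompts` as a dictionary with
# point_coords, point_labels, and mask_input, and `prompt_type` as 'point' or
# 'both'. The helper below is one possible way to validate and normalize such
# a dictionary before handing it to a real predictor. The function name
# `normalize_point_prompts`, the chosen dtypes, and the (N, 2) coordinate
# layout are assumptions made for illustration only; they are not taken from
# SAM2 or from the MatAnyone demo.
# ---------------------------------------------------------------------------
import numpy as np  # already imported above; repeated so the sketch stands alone


def normalize_point_prompts(prompts, prompt_type="point"):
    """Return (coords, labels, mask_input) as NumPy arrays (hypothetical helper)."""
    if prompt_type not in ("point", "both"):
        raise ValueError(f"Unsupported prompt_type: {prompt_type!r}")

    # Assumed layout: one (x, y) pair per point, one integer label per point.
    coords = np.asarray(prompts["point_coords"], dtype=np.float32).reshape(-1, 2)
    labels = np.asarray(prompts["point_labels"], dtype=np.int32).reshape(-1)
    if coords.shape[0] != labels.shape[0]:
        raise ValueError("point_coords and point_labels must have the same length")

    # A mask prompt is only read when both prompt kinds are requested.
    mask_input = None
    if prompt_type == "both":
        mask_input = np.asarray(prompts["mask_input"], dtype=np.float32)

    return coords, labels, mask_input


# Example usage (assumed prompt values):
#   coords, labels, mask_input = normalize_point_prompts(
#       {"point_coords": [[10, 20], [30, 40]], "point_labels": [1, 0]}
#   )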