File size: 10,759 Bytes
b30e7a3
5c36daa
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c36daa
a5f8d15
5c36daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3de3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3de3df3
 
3d32b4a
 
b30e7a3
3d32b4a
 
 
 
 
 
 
 
 
 
 
 
b30e7a3
 
 
 
 
 
 
 
 
5c36daa
b30e7a3
 
 
 
 
 
 
 
 
5c36daa
 
 
 
 
 
 
 
 
 
 
 
3de3df3
 
 
 
 
 
 
 
 
5c36daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import logging
from typing import Optional, Sequence

import numpy as np
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

from .base import Segmenter, SegmentationResult


class SAM3Segmenter(Segmenter):
    """
    SAM3 (Segment Anything Model 3) segmenter.

    Performs text-prompted instance segmentation on images using the
    facebook/sam3 model from HuggingFace. Supports single-frame and
    batched prediction; multiple text prompts per image are handled by
    reusing vision embeddings instead of re-encoding the image.
    """

    name = "sam3"
    supports_batch = True
    max_batch_size = 8

    def __init__(
        self,
        model_id: str = "facebook/sam3",
        device: Optional[str] = None,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
    ):
        """
        Initialize SAM3 segmenter.

        Args:
            model_id: HuggingFace model ID
            device: Device to run on (cuda/cpu), auto-detected if None
            threshold: Confidence threshold for filtering instances
            mask_threshold: Threshold for binarizing masks

        Raises:
            Exception: re-raised (after logging) if the model or
                processor fails to load.
        """
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.threshold = threshold
        self.mask_threshold = mask_threshold

        logging.info(
            "Loading SAM3 model %s on device %s", model_id, self.device
        )

        try:
            self.model = Sam3Model.from_pretrained(model_id).to(self.device)
            self.processor = Sam3Processor.from_pretrained(model_id)
            self.model.eval()
        except Exception:
            logging.exception("Failed to load SAM3 model")
            raise

        logging.info("SAM3 model loaded successfully")

    @staticmethod
    def _to_pil(frame: np.ndarray) -> Image.Image:
        """Convert an RGB numpy frame to a PIL image.

        uint8 frames are used as-is; anything else is assumed to be a
        float image in [0, 1] and is scaled to uint8 — TODO confirm
        callers never pass other integer dtypes.
        """
        if frame.dtype == np.uint8:
            return Image.fromarray(frame)
        return Image.fromarray((frame * 255).astype(np.uint8))

    def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
        """Convert one post-processed result mapping into a SegmentationResult.

        Args:
            results: Mapping with optional "masks", "scores", "boxes"
                entries (tensors, as returned by the processor's
                post-processing).
            frame_shape: Shape of the source frame, used to size the
                empty mask stack when nothing was detected.
        """
        masks = results.get("masks", [])
        scores = results.get("scores", None)
        boxes = results.get("boxes", None)

        if len(masks) > 0:
            # Stack per-instance masks: list of (H, W) tensors -> (N, H, W).
            masks_array = np.stack([m.cpu().numpy() for m in masks])
        else:
            # No objects detected: empty (0, H, W) stack keeps the contract.
            masks_array = np.zeros(
                (0, frame_shape[0], frame_shape[1]), dtype=bool
            )

        scores_array = scores.cpu().numpy() if scores is not None else None
        boxes_array = boxes.cpu().numpy() if boxes is not None else None

        return SegmentationResult(
            masks=masks_array,
            scores=scores_array,
            boxes=boxes_array,
        )

    def _expand_inputs_if_needed(self, inputs):
        """
        Expand vision inputs so the image batch matches the text-prompt batch.

        Handles:
          1. 1 image, N texts      -> expand 1 -> N
          2. N images, N*M texts   -> expand N -> N*M

        When expansion is needed, the vision backbone runs once on the
        original images and the resulting embeddings are repeated, so the
        encoder is not re-run per prompt. ``inputs`` is mutated in place:
        ``pixel_values`` is replaced by ``vision_embeds`` (the two are
        mutually exclusive for the model call) and per-image metadata is
        expanded to match.
        """
        pixel_values = inputs.get("pixel_values")
        input_ids = inputs.get("input_ids")
        if pixel_values is None or input_ids is None:
            return

        img_batch = pixel_values.shape[0]
        text_batch = input_ids.shape[0]

        if img_batch == 1 and text_batch > 1:
            expansion_factor = text_batch
        elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
            expansion_factor = text_batch // img_batch
        else:
            # Batches already aligned (or not evenly divisible): no-op.
            return

        logging.debug(
            "Expanding SAM3 vision inputs from %d to %d (factor %d) "
            "using embeddings reuse.",
            img_batch,
            text_batch,
            expansion_factor,
        )

        # 1. Compute vision embeddings once for the original images.
        with torch.no_grad():
            vision_outputs = self.model.get_vision_features(
                pixel_values=pixel_values
            )

        # 2. Repeat every batch-first tensor in the vision outputs.
        for key in list(vision_outputs.keys()):
            value = getattr(vision_outputs, key, None)
            if value is None:
                # Fall back to mapping-style access (ModelOutput supports both).
                try:
                    value = vision_outputs[key]
                except (KeyError, TypeError, IndexError):
                    continue

            new_value = None
            if isinstance(value, torch.Tensor):
                # Only expand tensors whose leading dim is the image batch.
                if value.shape[0] == img_batch:
                    new_value = value.repeat_interleave(expansion_factor, dim=0)
            elif isinstance(value, (list, tuple)):
                expanded_items = []
                any_expanded = False
                for item in value:
                    if isinstance(item, torch.Tensor) and item.shape[0] == img_batch:
                        expanded_items.append(
                            item.repeat_interleave(expansion_factor, dim=0)
                        )
                        any_expanded = True
                    else:
                        expanded_items.append(item)
                if any_expanded:
                    # Preserve the container type (list vs tuple).
                    new_value = type(value)(expanded_items)

            if new_value is not None:
                # Update via whichever access path the output object supports.
                try:
                    vision_outputs[key] = new_value
                except (KeyError, TypeError):
                    pass
                if hasattr(vision_outputs, key):
                    setattr(vision_outputs, key, new_value)

        # 3. Swap pixel_values for the precomputed embeddings.
        inputs["vision_embeds"] = vision_outputs
        del inputs["pixel_values"]  # Mutually exclusive with vision_embeds

        # 4. Expand per-image metadata to the new batch size.
        for meta_key in ("original_sizes", "reshape_input_sizes"):
            if meta_key in inputs and inputs[meta_key].shape[0] == img_batch:
                inputs[meta_key] = inputs[meta_key].repeat_interleave(
                    expansion_factor, dim=0
                )

    def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
        """
        Run SAM3 segmentation on a frame.

        Args:
            frame: Input image (HxWx3 numpy array in RGB)
            text_prompts: List of text prompts for segmentation;
                defaults to ["object"] when None or empty.

        Returns:
            SegmentationResult with instance masks; an empty result if
            post-processing fails.

        Raises:
            RuntimeError: re-raised (after logging) if inference fails.
        """
        pil_image = self._to_pil(frame)

        # Use default prompts if none provided
        if not text_prompts:
            text_prompts = ["object"]

        # Process image with text prompts
        inputs = self.processor(
            images=pil_image, text=text_prompts, return_tensors="pt"
        ).to(self.device)

        # Reuse vision embeddings when several prompts share one image.
        self._expand_inputs_if_needed(inputs)

        # Run inference
        try:
            if "pixel_values" in inputs:
                logging.debug(
                    "SAM3 Input pixel_values shape: %s",
                    inputs["pixel_values"].shape,
                )
            with torch.no_grad():
                outputs = self.model(**inputs)
        except RuntimeError:
            logging.exception("RuntimeError during SAM3 inference")
            logging.error("Input keys: %s", inputs.keys())
            if "pixel_values" in inputs:
                logging.error(
                    "Pixel values shape: %s", inputs["pixel_values"].shape
                )
            # Re-raise to let user know
            raise

        # Post-process to get instance masks
        try:
            original_sizes = inputs.get("original_sizes")
            # Fall back to the frame's own size if the processor did not
            # report original_sizes (avoids AttributeError on None).
            target_sizes = (
                original_sizes.tolist()
                if original_sizes is not None
                else [list(frame.shape[:2])]
            )
            results = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=target_sizes,
            )[0]
            return self._parse_single_result(results, frame.shape)

        except Exception:
            logging.exception("SAM3 post-processing failed")
            # Best-effort: degrade to an empty result rather than crash.
            return SegmentationResult(
                masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
                scores=None,
                boxes=None,
            )

    def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
        """
        Run SAM3 segmentation on a batch of frames.

        Args:
            frames: Input images (each an HxWx3 numpy array in RGB)
            text_prompts: Prompts applied to every frame; defaults to
                ["object"] when None or empty.

        Returns:
            One SegmentationResult per frame; empty results for every
            frame if batch post-processing fails.
        """
        pil_images = [self._to_pil(f) for f in frames]

        prompts = text_prompts or ["object"]

        # Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
        flattened_prompts = []
        for _ in frames:
            flattened_prompts.extend(prompts)

        inputs = self.processor(
            images=pil_images, text=flattened_prompts, return_tensors="pt"
        ).to(self.device)

        # Reuse vision embeddings across the per-image prompt copies.
        self._expand_inputs_if_needed(inputs)

        with torch.no_grad():
            outputs = self.model(**inputs)

        try:
            original_sizes = inputs.get("original_sizes")
            # Fall back to each frame's own size if original_sizes is absent.
            target_sizes = (
                original_sizes.tolist()
                if original_sizes is not None
                else [list(f.shape[:2]) for f in frames]
            )
            results_list = self.processor.post_process_instance_segmentation(
                outputs,
                threshold=self.threshold,
                mask_threshold=self.mask_threshold,
                target_sizes=target_sizes,
            )
            return [
                self._parse_single_result(r, f.shape)
                for r, f in zip(results_list, frames)
            ]
        except Exception:
            logging.exception("SAM3 batch post-processing failed")
            return [
                SegmentationResult(
                    masks=np.zeros((0, f.shape[0], f.shape[1]), dtype=bool),
                    scores=None,
                    boxes=None,
                )
                for f in frames
            ]