Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
3de3df3
1
Parent(s):
3d32b4a
Fix SAM3 batch prediction shape mismatch and add InternVL2 to frontend
Browse files- LaserPerception/LaserPerception.html +3 -0
- models/segmenters/sam3.py +100 -47
LaserPerception/LaserPerception.html
CHANGED
|
@@ -81,6 +81,9 @@
|
|
| 81 |
<optgroup label="Drone Detection Models">
|
| 82 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
| 83 |
</optgroup>
|
|
|
|
|
|
|
|
|
|
| 84 |
</select>
|
| 85 |
</div>
|
| 86 |
<div>
|
|
|
|
| 81 |
<optgroup label="Drone Detection Models">
|
| 82 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
| 83 |
</optgroup>
|
| 84 |
+
<optgroup label="Vision-Language Models">
|
| 85 |
+
<option value="internvl2_military" data-kind="object">InternVL2 (Military)</option>
|
| 86 |
+
</optgroup>
|
| 87 |
</select>
|
| 88 |
</div>
|
| 89 |
<div>
|
models/segmenters/sam3.py
CHANGED
|
@@ -87,6 +87,95 @@ class SAM3Segmenter(Segmenter):
|
|
| 87 |
boxes=boxes_array,
|
| 88 |
)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
|
| 91 |
"""
|
| 92 |
Run SAM3 segmentation on a frame.
|
|
@@ -115,51 +204,8 @@ class SAM3Segmenter(Segmenter):
|
|
| 115 |
images=pil_image, text=text_prompts, return_tensors="pt"
|
| 116 |
).to(self.device)
|
| 117 |
|
| 118 |
-
# Handle batch
|
| 119 |
-
|
| 120 |
-
input_ids = inputs.get("input_ids")
|
| 121 |
-
|
| 122 |
-
if (
|
| 123 |
-
pixel_values is not None
|
| 124 |
-
and input_ids is not None
|
| 125 |
-
and pixel_values.shape[0] == 1
|
| 126 |
-
and input_ids.shape[0] > 1
|
| 127 |
-
):
|
| 128 |
-
target_batch_size = input_ids.shape[0]
|
| 129 |
-
logging.debug(f"Expanding SAM3 vision inputs from 1 to {target_batch_size} using embeddings reuse.")
|
| 130 |
-
|
| 131 |
-
# 1. Compute vision embeddings once
|
| 132 |
-
with torch.no_grad():
|
| 133 |
-
vision_outputs = self.model.get_vision_features(
|
| 134 |
-
pixel_values=pixel_values
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
# 2. Expand vision embeddings
|
| 138 |
-
# vision_outputs is a ModelOutput (dict-like)
|
| 139 |
-
for key, value in vision_outputs.items():
|
| 140 |
-
if isinstance(value, torch.Tensor):
|
| 141 |
-
if value.shape[0] == 1:
|
| 142 |
-
vision_outputs[key] = value.repeat(target_batch_size, *([1]*(value.dim()-1)))
|
| 143 |
-
elif isinstance(value, (list, tuple)):
|
| 144 |
-
new_list = []
|
| 145 |
-
for v in value:
|
| 146 |
-
if isinstance(v, torch.Tensor) and v.shape[0] == 1:
|
| 147 |
-
new_list.append(v.repeat(target_batch_size, *([1]*(v.dim()-1))))
|
| 148 |
-
else:
|
| 149 |
-
new_list.append(v)
|
| 150 |
-
# Preserve type (tuple vs list)
|
| 151 |
-
vision_outputs[key] = type(value)(new_list)
|
| 152 |
-
|
| 153 |
-
# 3. Update inputs for model call
|
| 154 |
-
inputs["vision_embeds"] = vision_outputs
|
| 155 |
-
del inputs["pixel_values"] # Mutually exclusive with vision_embeds
|
| 156 |
-
|
| 157 |
-
# 4. Expand other metadata
|
| 158 |
-
if "original_sizes" in inputs and inputs["original_sizes"].shape[0] == 1:
|
| 159 |
-
inputs["original_sizes"] = inputs["original_sizes"].repeat(target_batch_size, 1)
|
| 160 |
-
|
| 161 |
-
if "reshape_input_sizes" in inputs and inputs["reshape_input_sizes"].shape[0] == 1:
|
| 162 |
-
inputs["reshape_input_sizes"] = inputs["reshape_input_sizes"].repeat(target_batch_size, 1)
|
| 163 |
|
| 164 |
|
| 165 |
# Run inference
|
|
@@ -206,8 +252,15 @@ class SAM3Segmenter(Segmenter):
|
|
| 206 |
|
| 207 |
prompts = text_prompts or ["object"]
|
| 208 |
|
| 209 |
-
#
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
with torch.no_grad():
|
| 213 |
outputs = self.model(**inputs)
|
|
|
|
| 87 |
boxes=boxes_array,
|
| 88 |
)
|
| 89 |
|
| 90 |
+
def _expand_inputs_if_needed(self, inputs):
    """
    Expand vision inputs in place so their batch size matches the text prompts.

    Handles:
    1. 1 image, N texts (expand 1 -> N)
    2. N images, N*M texts (expand N -> N*M); each image's features are
       repeated M times consecutively, i.e. [img1, img1, ..., img2, img2, ...].

    When expansion applies, the vision features are computed once through
    ``self.model.get_vision_features``, stored under ``inputs["vision_embeds"]``,
    and ``inputs["pixel_values"]`` is removed. ``original_sizes`` and
    ``reshape_input_sizes`` are expanded the same way when present.

    Args:
        inputs: Dict-like processor output; read via ``.get`` / ``in`` and
            modified in place.

    Returns:
        None. ``inputs`` is left untouched when ``pixel_values`` or
        ``input_ids`` is missing, when the batches already match, or when
        the text batch is not a whole multiple of the image batch.
    """
    pixel_values = inputs.get("pixel_values")
    input_ids = inputs.get("input_ids")
    if pixel_values is None or input_ids is None:
        return

    img_batch = pixel_values.shape[0]
    text_batch = input_ids.shape[0]
    if img_batch < 1 or text_batch <= img_batch:
        return

    if text_batch % img_batch != 0:
        # Cannot map prompts onto images evenly; leave inputs as they are but
        # make the mismatch visible instead of ignoring it silently.
        logging.warning(
            "SAM3 text batch (%d) is not a multiple of image batch (%d); "
            "vision inputs left unexpanded.",
            text_batch,
            img_batch,
        )
        return

    # Covers both cases: with a single image the factor equals text_batch.
    expansion_factor = text_batch // img_batch
    logging.debug(
        "Expanding SAM3 vision inputs from %d to %d (factor %d) using embeddings reuse.",
        img_batch,
        text_batch,
        expansion_factor,
    )

    def is_batched(obj):
        # Only tensors whose leading dimension is the image batch are expanded.
        return (
            isinstance(obj, torch.Tensor)
            and obj.dim() > 0
            and obj.shape[0] == img_batch
        )

    def expand(tensor):
        return tensor.repeat_interleave(expansion_factor, dim=0)

    # 1. Compute vision embeddings once for the original images.
    with torch.no_grad():
        vision_outputs = self.model.get_vision_features(pixel_values=pixel_values)

    # 2. Expand every batched tensor (or list/tuple of tensors) in the output.
    for key in list(vision_outputs.keys()):
        # Item access comes first: on a plain dict, attribute lookup for keys
        # such as "values" or "items" would return the dict method instead.
        try:
            value = vision_outputs[key]
        except (KeyError, TypeError):
            value = getattr(vision_outputs, key, None) if isinstance(key, str) else None

        if is_batched(value):
            new_value = expand(value)
        elif isinstance(value, (list, tuple)) and any(is_batched(v) for v in value):
            # Preserve the container type (tuple vs list).
            new_value = type(value)(
                [expand(v) if is_batched(v) else v for v in value]
            )
        else:
            continue

        try:
            vision_outputs[key] = new_value
        except (TypeError, KeyError, AttributeError):
            # Container does not support item assignment; fall back to attribute.
            setattr(vision_outputs, key, new_value)
            continue

        # If an attribute mirrors the entry and still points at the old value,
        # bring it in sync as well.
        if isinstance(key, str) and getattr(vision_outputs, key, None) is value:
            setattr(vision_outputs, key, new_value)

    # 3. Update inputs for model call
    inputs["vision_embeds"] = vision_outputs
    del inputs["pixel_values"]  # Mutually exclusive with vision_embeds

    # 4. Expand other per-image metadata
    for meta_key in ("original_sizes", "reshape_input_sizes"):
        if meta_key in inputs and inputs[meta_key].shape[0] == img_batch:
            inputs[meta_key] = expand(inputs[meta_key])
|
| 178 |
+
|
| 179 |
def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
|
| 180 |
"""
|
| 181 |
Run SAM3 segmentation on a frame.
|
|
|
|
| 204 |
images=pil_image, text=text_prompts, return_tensors="pt"
|
| 205 |
).to(self.device)
|
| 206 |
|
| 207 |
+
# Handle batch expansion
|
| 208 |
+
self._expand_inputs_if_needed(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
|
| 211 |
# Run inference
|
|
|
|
| 252 |
|
| 253 |
prompts = text_prompts or ["object"]
|
| 254 |
|
| 255 |
+
# Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
|
| 256 |
+
flattened_prompts = []
|
| 257 |
+
for _ in frames:
|
| 258 |
+
flattened_prompts.extend(prompts)
|
| 259 |
+
|
| 260 |
+
inputs = self.processor(images=pil_images, text=flattened_prompts, return_tensors="pt").to(self.device)
|
| 261 |
+
|
| 262 |
+
# Handle batch expansion
|
| 263 |
+
self._expand_inputs_if_needed(inputs)
|
| 264 |
|
| 265 |
with torch.no_grad():
|
| 266 |
outputs = self.model(**inputs)
|