Zhen Ye committed on
Commit
3de3df3
·
1 Parent(s): 3d32b4a

Fix SAM3 batch prediction shape mismatch and add InternVL2 to frontend

Browse files
LaserPerception/LaserPerception.html CHANGED
@@ -81,6 +81,9 @@
81
  <optgroup label="Drone Detection Models">
82
  <option value="drone_yolo" data-kind="drone">Drone</option>
83
  </optgroup>
 
 
 
84
  </select>
85
  </div>
86
  <div>
 
81
  <optgroup label="Drone Detection Models">
82
  <option value="drone_yolo" data-kind="drone">Drone</option>
83
  </optgroup>
84
+ <optgroup label="Vision-Language Models">
85
+ <option value="internvl2_military" data-kind="object">InternVL2 (Military)</option>
86
+ </optgroup>
87
  </select>
88
  </div>
89
  <div>
models/segmenters/sam3.py CHANGED
@@ -87,6 +87,95 @@ class SAM3Segmenter(Segmenter):
87
  boxes=boxes_array,
88
  )
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
91
  """
92
  Run SAM3 segmentation on a frame.
@@ -115,51 +204,8 @@ class SAM3Segmenter(Segmenter):
115
  images=pil_image, text=text_prompts, return_tensors="pt"
116
  ).to(self.device)
117
 
118
- # Handle batch size mismatch between image (1) and prompts (N) structure
119
- pixel_values = inputs.get("pixel_values")
120
- input_ids = inputs.get("input_ids")
121
-
122
- if (
123
- pixel_values is not None
124
- and input_ids is not None
125
- and pixel_values.shape[0] == 1
126
- and input_ids.shape[0] > 1
127
- ):
128
- target_batch_size = input_ids.shape[0]
129
- logging.debug(f"Expanding SAM3 vision inputs from 1 to {target_batch_size} using embeddings reuse.")
130
-
131
- # 1. Compute vision embeddings once
132
- with torch.no_grad():
133
- vision_outputs = self.model.get_vision_features(
134
- pixel_values=pixel_values
135
- )
136
-
137
- # 2. Expand vision embeddings
138
- # vision_outputs is a ModelOutput (dict-like)
139
- for key, value in vision_outputs.items():
140
- if isinstance(value, torch.Tensor):
141
- if value.shape[0] == 1:
142
- vision_outputs[key] = value.repeat(target_batch_size, *([1]*(value.dim()-1)))
143
- elif isinstance(value, (list, tuple)):
144
- new_list = []
145
- for v in value:
146
- if isinstance(v, torch.Tensor) and v.shape[0] == 1:
147
- new_list.append(v.repeat(target_batch_size, *([1]*(v.dim()-1))))
148
- else:
149
- new_list.append(v)
150
- # Preserve type (tuple vs list)
151
- vision_outputs[key] = type(value)(new_list)
152
-
153
- # 3. Update inputs for model call
154
- inputs["vision_embeds"] = vision_outputs
155
- del inputs["pixel_values"] # Mutually exclusive with vision_embeds
156
-
157
- # 4. Expand other metadata
158
- if "original_sizes" in inputs and inputs["original_sizes"].shape[0] == 1:
159
- inputs["original_sizes"] = inputs["original_sizes"].repeat(target_batch_size, 1)
160
-
161
- if "reshape_input_sizes" in inputs and inputs["reshape_input_sizes"].shape[0] == 1:
162
- inputs["reshape_input_sizes"] = inputs["reshape_input_sizes"].repeat(target_batch_size, 1)
163
 
164
 
165
  # Run inference
@@ -206,8 +252,15 @@ class SAM3Segmenter(Segmenter):
206
 
207
  prompts = text_prompts or ["object"]
208
 
209
- # Same prompts for all images
210
- inputs = self.processor(images=pil_images, text=[prompts]*len(frames), return_tensors="pt").to(self.device)
 
 
 
 
 
 
 
211
 
212
  with torch.no_grad():
213
  outputs = self.model(**inputs)
 
87
  boxes=boxes_array,
88
  )
89
 
90
+ def _expand_inputs_if_needed(self, inputs):
91
+ """
92
+ Helper to expand vision inputs (pixel_values or vision_embeds) to match text prompts.
93
+ Handles:
94
+ 1. 1 image, N texts (Expand 1 -> N)
95
+ 2. N images, N*M texts (Expand N -> N*M)
96
+ """
97
+ pixel_values = inputs.get("pixel_values")
98
+ input_ids = inputs.get("input_ids")
99
+
100
+ if (
101
+ pixel_values is not None
102
+ and input_ids is not None
103
+ ):
104
+ img_batch = pixel_values.shape[0]
105
+ text_batch = input_ids.shape[0]
106
+
107
+ should_expand = False
108
+ expansion_factor = 1
109
+
110
+ if img_batch == 1 and text_batch > 1:
111
+ should_expand = True
112
+ expansion_factor = text_batch
113
+ elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
114
+ should_expand = True
115
+ expansion_factor = text_batch // img_batch
116
+
117
+ if should_expand:
118
+ logging.debug(f"Expanding SAM3 vision inputs from {img_batch} to {text_batch} (factor {expansion_factor}) using embeddings reuse.")
119
+
120
+ # 1. Compute vision embeddings once for original images
121
+ with torch.no_grad():
122
+ vision_outputs = self.model.get_vision_features(
123
+ pixel_values=pixel_values
124
+ )
125
+
126
+
127
+ # Iterate over keys to expand
128
+ keys_to_expand = list(vision_outputs.keys())
129
+ for key in keys_to_expand:
130
+ value = getattr(vision_outputs, key, None)
131
+ if value is None:
132
+ # Try getItem
133
+ try:
134
+ value = vision_outputs[key]
135
+ except:
136
+ continue
137
+
138
+ new_value = None
139
+ if isinstance(value, torch.Tensor):
140
+ # Ensure we only expand the batch dimension (dim 0)
141
+ if value.shape[0] == img_batch:
142
+ new_value = value.repeat_interleave(expansion_factor, dim=0)
143
+ elif isinstance(value, (list, tuple)):
144
+ new_list = []
145
+ valid_expansion = False
146
+ for i, v in enumerate(value):
147
+ if isinstance(v, torch.Tensor) and v.shape[0] == img_batch:
148
+ new_list.append(v.repeat_interleave(expansion_factor, dim=0))
149
+ valid_expansion = True
150
+ else:
151
+ new_list.append(v)
152
+
153
+ if valid_expansion:
154
+ # Preserve type
155
+ new_value = type(value)(new_list)
156
+
157
+ if new_value is not None:
158
+ # Update dict item if possible
159
+ try:
160
+ vision_outputs[key] = new_value
161
+ except:
162
+ pass
163
+ # Update attribute explicitly if it exists
164
+ if hasattr(vision_outputs, key):
165
+ setattr(vision_outputs, key, new_value)
166
+
167
+
168
+ # 3. Update inputs for model call
169
+ inputs["vision_embeds"] = vision_outputs
170
+ del inputs["pixel_values"] # Mutually exclusive with vision_embeds
171
+
172
+ # 4. Expand other metadata
173
+ if "original_sizes" in inputs and inputs["original_sizes"].shape[0] == img_batch:
174
+ inputs["original_sizes"] = inputs["original_sizes"].repeat_interleave(expansion_factor, dim=0)
175
+
176
+ if "reshape_input_sizes" in inputs and inputs["reshape_input_sizes"].shape[0] == img_batch:
177
+ inputs["reshape_input_sizes"] = inputs["reshape_input_sizes"].repeat_interleave(expansion_factor, dim=0)
178
+
179
  def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
180
  """
181
  Run SAM3 segmentation on a frame.
 
204
  images=pil_image, text=text_prompts, return_tensors="pt"
205
  ).to(self.device)
206
 
207
+ # Handle batch expansion
208
+ self._expand_inputs_if_needed(inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
 
211
  # Run inference
 
252
 
253
  prompts = text_prompts or ["object"]
254
 
255
+ # Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
256
+ flattened_prompts = []
257
+ for _ in frames:
258
+ flattened_prompts.extend(prompts)
259
+
260
+ inputs = self.processor(images=pil_images, text=flattened_prompts, return_tensors="pt").to(self.device)
261
+
262
+ # Handle batch expansion
263
+ self._expand_inputs_if_needed(inputs)
264
 
265
  with torch.no_grad():
266
  outputs = self.model(**inputs)