Spaces:
Runtime error
Runtime error
tp53(ashish)
committed on
Commit
·
c94cf50
1
Parent(s):
49e4f07
Remove model folder - use fallback mode
Browse files- model/__init__.py +0 -0
- model/medsam3.py +0 -379
model/__init__.py
DELETED
|
File without changes
|
model/medsam3.py
DELETED
|
@@ -1,379 +0,0 @@
|
|
| 1 |
-
import os
import torch
import torch.nn as nn
from typing import Dict, Optional, List, Any

# SAM3 is an optional dependency. When it is not installed, every imported
# name is replaced by None and SAM3_AVAILABLE is set False so the rest of
# this module can run in fallback mode instead of crashing at import time.
try:
    from sam3.model_builder import build_sam3_image_model as build_sam3_model
    from sam3.model.data_misc import BatchedDatapoint, FindStage, BatchedFindTarget, BatchedInferenceMetadata
    from sam3.model import decoder as sam3_decoder
    SAM3_AVAILABLE = True
except ImportError:
    build_sam3_model = None
    BatchedDatapoint = None
    FindStage = None
    BatchedFindTarget = None
    BatchedInferenceMetadata = None
    sam3_decoder = None
    SAM3_AVAILABLE = False

from peft import LoraConfig, get_peft_model
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def _patch_sam3_decoder_for_ddp():
    """
    Monkey-patch SAM3's decoder to fix a DDP device-placement bug.

    SAM3 memoizes coords_h/coords_w inside ``compilable_cord_cache`` and the
    ``coord_cache`` dict. Under DDP these tensors are first materialized on
    cuda:0, after which every other rank reads cached coords that live on the
    wrong device. We wrap ``_get_rpb_matrix`` so cached coords are re-homed to
    the device of the incoming tensors before the original method executes.
    """
    if not SAM3_AVAILABLE or sam3_decoder is None:
        return

    # Locate the first class in the decoder module that exposes _get_rpb_matrix.
    target_cls = next(
        (
            candidate
            for candidate in (getattr(sam3_decoder, attr) for attr in dir(sam3_decoder))
            if isinstance(candidate, type) and hasattr(candidate, '_get_rpb_matrix')
        ),
        None,
    )

    if target_cls is None:
        print("[MedSAM3] Warning: Could not find decoder class to patch")
        return

    # Keep a reference to the unpatched implementation for delegation.
    unpatched_get_rpb_matrix = target_cls._get_rpb_matrix

    def patched_get_rpb_matrix(self, *args, **kwargs):
        """Patched version that ensures coords are on the correct device."""
        # Infer the target device from the first tensor found among the
        # positional arguments, falling back to keyword arguments.
        device = next(
            (t.device for t in (*args, *kwargs.values()) if torch.is_tensor(t)),
            None,
        )

        if device is not None:
            # Re-home compilable_cord_cache when its tensors sit elsewhere.
            pair = getattr(self, 'compilable_cord_cache', None)
            if pair is not None:
                coords_h, coords_w = pair
                if coords_h.device != device:
                    self.compilable_cord_cache = (
                        coords_h.to(device),
                        coords_w.to(device),
                    )

            # Same treatment for every entry in the coord_cache dict.
            cache = getattr(self, 'coord_cache', None)
            if cache:
                for key in list(cache.keys()):
                    coords_h, coords_w = cache[key]
                    if coords_h.device != device:
                        cache[key] = (
                            coords_h.to(device),
                            coords_w.to(device),
                        )

        return unpatched_get_rpb_matrix(self, *args, **kwargs)

    # Install the wrapper on the discovered decoder class.
    target_cls._get_rpb_matrix = patched_get_rpb_matrix
    print("[MedSAM3] Successfully patched SAM3 decoder for DDP compatibility")


# Apply the patch as soon as this module is imported.
_patch_sam3_decoder_for_ddp()
|
| 94 |
-
|
| 95 |
-
class MedSAM3Model(nn.Module):
    """SAM3 wrapper specialized for medical image segmentation.

    Builds the SAM3 image model (optionally loading weights from a local
    checkpoint), freezes the perception-encoder backbone, wraps the network
    with LoRA adapters via peft, and packages box/point prompts into SAM3's
    BatchedDatapoint structure inside :meth:`forward`.

    NOTE(review): ``model_id`` is accepted but never read in this block —
    presumably kept for config compatibility; confirm against callers.
    """

    def __init__(self, model_id: str = "sam3_hiera_base", lora_rank: int = 16, image_size: int = 1024, checkpoint_path: Optional[str] = None):
        """Construct the model.

        Args:
            model_id: Identifier of the SAM3 variant (currently unused here).
            lora_rank: LoRA rank ``r``; alpha is fixed at ``2 * r``.
            image_size: Nominal square input size, used for metadata's
                ``original_size`` field.
            checkpoint_path: Optional path to a local checkpoint whose state
                dict (possibly nested under a "model" key) is loaded
                non-strictly on top of the freshly built SAM3.

        Raises:
            ImportError: If the SAM3 libraries are not installed.
        """
        super().__init__()
        self._logged_shapes = False  # For one-time debug logging
        self._buffers_migrated = False  # Track if we've done buffer device migration
        self.image_size = image_size  # Store for coordinate normalization
        # --- 1. Initialize SAM 3 Architecture ---
        if build_sam3_model:
            # Initialize SAM3 architecture without downloading from HuggingFace
            # (our checkpoint already contains full weights including base SAM3)
            self.model = build_sam3_model(load_from_HF=False, eval_mode=False)

            # --- 2. Load Weights ---
            if checkpoint_path and os.path.exists(checkpoint_path):
                state_dict = torch.load(checkpoint_path, map_location="cpu")
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                # strict=False: LoRA/adapters may add keys absent from the checkpoint.
                self.model.load_state_dict(state_dict, strict=False)
        else:
            raise ImportError(
                "CRITICAL: SAM3 core libraries not found. "
                "Ensure you have installed sam3 correctly (e.g. via pip install git+...sam3.git). "
                "Check logs for previous import errors."
            )

        # --- 3. Freeze Backbone ---
        # Only the perception encoder is frozen; decoder and heads stay trainable.
        for name, param in self.model.named_parameters():
            if "perception_encoder" in name:
                param.requires_grad = False

        # --- 4. Apply LoRA ---
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank * 2,
            target_modules=["qkv", "proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=None  # Important: prevents peft from injecting 'input_ids'
        )
        self.model = get_peft_model(self.model, lora_config)

        # --- 5. Foundation Specialist Fix: Dummy Matcher ---
        # SAM3's forward_grounding path (which handles boxes/points) sometimes
        # attempts to call self.matcher even if no targets are provided.
        # We inject a dummy matcher that returns empty indices to prevent
        # 'NoneType object is not callable' crashes.
        base_model = self.model.get_base_model()
        if hasattr(base_model, 'matcher') and base_model.matcher is None:
            # Matcher expected signature: func(outputs, targets) -> list of matches
            base_model.matcher = lambda outputs, targets: []
            print("[MedSAM3] Injected dummy matcher for grounding stability")

    def forward(self, pixel_values, input_boxes=None, input_points=None, point_labels=None, text_prompt=None):
        """Run SAM3 on a batch of images with optional box/point/text prompts.

        Args:
            pixel_values: Image tensor; either (B, C, H, W) or volumetric
                (B, C, T, H, W), which is flattened to (B*T, C, H, W).
            input_boxes: Optional xyxy pixel-coordinate boxes, one per image
                ((B, 4) or (B, 1, 4); (B, T, 4) for 5-D input).
            input_points: Optional pixel-coordinate points
                ((B, N, 2) or (B, 1, N, 2)).
            point_labels: Optional per-point labels ((B, N) or (B, 1, N)).
                NOTE(review): flattened/unsqueezed below but never forwarded
                into FindStage — confirm whether this is intentional.
            text_prompt: Optional string used for every image's grounding
                text; defaults to "medical".

        Returns:
            Whatever ``self.model`` returns for the packaged BatchedDatapoint,
            or for the raw pixel tensor when SAM3 data structures are missing.
        """
        # DDP Fix: Ensure all model buffers are on the same device as input
        # SAM3 has some internal buffers that don't auto-migrate in DDP
        # Only do this once per device to avoid overhead on every forward pass
        target_device = pixel_values.device
        if not self._buffers_migrated:
            migrated_count = 0
            for name, buf in self.model.named_buffers():
                if buf.device != target_device:
                    buf.data = buf.data.to(target_device)
                    migrated_count += 1
            if migrated_count > 0:
                print(f"[MedSAM3] Migrated {migrated_count} buffers to {target_device}")
            self._buffers_migrated = True

        # Debug: Log shapes once on first forward pass
        if not self._logged_shapes:
            print(f"[MedSAM3] First forward - Input shapes:")
            print(f" pixel_values: {pixel_values.shape}")
            print(f" input_boxes: {input_boxes.shape if input_boxes is not None else None}")
            print(f" input_points: {input_points.shape if input_points is not None else None}")
            print(f" point_labels: {point_labels.shape if point_labels is not None else None}")

        # --- 1. Handle 3D to 2D Flattening (Robust) ---
        if pixel_values.dim() == 5:
            # Input: (B, C, T, H, W) -> Goal: (B*T, C, H, W)
            B_orig, C, T, H, W = pixel_values.shape
            # Permute to (B, T, C, H, W) then flatten
            pixel_values = pixel_values.permute(0, 2, 1, 3, 4).reshape(B_orig * T, C, H, W)

            if input_boxes is not None:
                # input_boxes is (B, T, 4) -> (B*T, 4)
                input_boxes = input_boxes.view(B_orig * T, 4)
            if input_points is not None:
                # input_points is (B, T, 1, 2) -> (B*T, 1, 2)
                input_points = input_points.view(B_orig * T, -1, 2)
            if point_labels is not None:
                # point_labels is (B, T, 1) -> (B*T, 1)
                point_labels = point_labels.view(B_orig * T, -1)

        # After reshaping, get the actual batch size
        B = pixel_values.shape[0]

        # --- 2. Channel Handling (Ensuring 3 channels for SAM3) ---
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            # Single-channel (e.g., CT) -> replicate to 3 channels
            pixel_values = pixel_values.repeat(1, 3, 1, 1)
        elif num_channels == 3:
            # Already 3 channels (e.g., multi-modal MRI after SelectMRIChannels)
            pass
        elif num_channels == 4:
            # 4-channel MRI (BrainTumour) - use first 3 channels [FLAIR, T1w, T1gd]
            # This is a fallback; ideally SelectMedicalChannels should handle this in transforms
            if not self._logged_shapes:
                print(f"[MedSAM3 WARNING] Received 4-channel input - using first 3 channels. "
                      f"Consider enabling SelectMedicalChannels transform.")
            pixel_values = pixel_values[:, :3, :, :]
        else:
            # Unexpected channel count - average and replicate
            if not self._logged_shapes:
                print(f"[MedSAM3 WARNING] Unexpected {num_channels} channels - averaging to single then replicating to 3.")
            pixel_values = pixel_values.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        # --- 3. Prompt Dimension Enforcement ---
        # Boxes: (B_total, 1, 4)
        if input_boxes is not None and input_boxes.dim() == 2:
            input_boxes = input_boxes.unsqueeze(1)

        # Points: (B_total, 1, N, 2), Labels: (B_total, 1, N)
        if input_points is not None and input_points.dim() == 3:
            input_points = input_points.unsqueeze(1)
        if point_labels is not None and point_labels.dim() == 2:
            point_labels = point_labels.unsqueeze(1)

        # --- 4. Package into Official SAM 3 Structure ---
        if BatchedDatapoint is not None and FindStage is not None:
            # Get device from model parameters (critical for DDP multi-GPU)
            device = pixel_values.device

            # SAM 3 expects a SINGLE FindStage object that aggregates prompts for the entire batch.
            # We must flatten the batch dimension of the prompts and create corresponding img_ids.

            # Current shapes:
            # input_boxes: (B, 1, 4)
            # input_points: (B, 1, N, 2)
            # point_labels: (B, 1, N)

            # We treat each image in the batch as having 1 prompt (since we have 1 box/point set per slice)
            # So we just flatten the first dimension.

            # img_ids: [0, 1, 2, ... B-1] (since 1 prompt per image)
            img_ids = torch.arange(B, device=device, dtype=torch.long)

            # Text ids: all 0 (dummy)
            text_ids = torch.zeros(B, device=device, dtype=torch.long)

            # SAM3 expects SEQUENCE-FIRST format for embeddings, BATCH-FIRST for masks:
            # input_boxes: [num_boxes, num_prompts, 4] - sequence first
            # input_boxes_mask: [num_prompts, num_boxes] - batch first (1=padded/invalid)
            # input_boxes_label: [num_boxes, num_prompts]
            # input_points: [num_points, num_prompts, 2] - sequence first
            # input_points_mask: [num_prompts, num_points] - batch first
            #
            # For our case: 1 box per image, B images -> num_boxes=1, num_prompts=B

            # Boxes: [B, 1, 4] -> [1, B, 4] (sequence first)
            # SAM3 expects boxes in CxCyWH format, normalized to [0, 1]
            # Our input is xyxy in pixel coordinates
            if input_boxes is not None:
                boxes_xyxy = input_boxes.squeeze(1).float().to(device)  # [B, 4] - x_min, y_min, x_max, y_max

                # Use actual tensor dimensions for normalization (more robust than stored image_size)
                actual_h, actual_w = pixel_values.shape[2], pixel_values.shape[3]

                # Convert xyxy to cxcywh and normalize to [0, 1]
                x_min, y_min, x_max, y_max = boxes_xyxy[:, 0], boxes_xyxy[:, 1], boxes_xyxy[:, 2], boxes_xyxy[:, 3]
                cx = (x_min + x_max) / 2.0 / actual_w
                cy = (y_min + y_max) / 2.0 / actual_h
                w = (x_max - x_min) / actual_w
                h = (y_max - y_min) / actual_h

                # Clamp to ensure valid boxes (min size 1% of image to avoid ROI align issues)
                min_size = 0.01
                w = torch.clamp(w, min=min_size)
                h = torch.clamp(h, min=min_size)

                boxes_cxcywh = torch.stack([cx, cy, w, h], dim=1)  # [B, 4]
                flat_boxes = boxes_cxcywh.unsqueeze(0)  # [1, B, 4]
                flat_boxes_mask = torch.zeros(B, 1, device=device, dtype=torch.bool)  # [B, 1] - 0=valid
                flat_boxes_label = torch.zeros(1, B, device=device, dtype=torch.long)  # [1, B]
            else:
                flat_boxes = torch.zeros(1, B, 4, device=device)
                flat_boxes_mask = torch.ones(B, 1, device=device, dtype=torch.bool)  # 1=invalid/padded
                flat_boxes_label = torch.zeros(1, B, device=device, dtype=torch.long)

            # Points: [B, 1, N, 2] -> [N, B, 2] (sequence first)
            # SAM3 expects points normalized to [0, 1]
            n_points = input_points.shape[2] if input_points is not None else 1
            if input_points is not None:
                points_pixel = input_points.squeeze(1).float().to(device)  # [B, N, 2] - x, y in pixel coords
                # Normalize using actual tensor dimensions
                actual_h, actual_w = pixel_values.shape[2], pixel_values.shape[3]
                points_normalized = points_pixel.clone()
                points_normalized[..., 0] = points_pixel[..., 0] / actual_w  # x normalized
                points_normalized[..., 1] = points_pixel[..., 1] / actual_h  # y normalized
                flat_points = points_normalized.permute(1, 0, 2)  # [B, N, 2] -> [N, B, 2]
                flat_points_mask = torch.zeros(B, n_points, device=device, dtype=torch.bool)  # 0=valid
            else:
                flat_points = torch.zeros(1, B, 2, device=device)
                flat_points_mask = torch.ones(B, 1, device=device, dtype=torch.bool)  # 1=invalid

            stage = FindStage(
                img_ids=img_ids,
                text_ids=text_ids,
                input_boxes=flat_boxes,
                input_boxes_mask=flat_boxes_mask,
                input_boxes_label=flat_boxes_label,
                input_points=flat_points,
                input_points_mask=flat_points_mask,
            )

            # Text batch for grounding head - use provided text_prompt or fallback
            if text_prompt is not None:
                find_text_batch = [text_prompt] * B
            else:
                find_text_batch = ["medical"] * B

            # Create dummy target structure to satisfy SAM3's internal indexing [0]
            # We use the dummy matcher injected in __init__ to ensure this doesn't
            # actually trigger any real loss computation.
            dummy_target = BatchedFindTarget(
                num_boxes=torch.zeros(B, device=device, dtype=torch.long),
                boxes=torch.zeros(B, 4, device=device),
                boxes_padded=torch.zeros(B, 1, 4, device=device),
                repeated_boxes=torch.zeros(B, 4, device=device),
                segments=None,
                semantic_segments=None,
                is_valid_segment=None,
                is_exhaustive=torch.zeros(B, device=device, dtype=torch.bool),
                object_ids=torch.zeros(B, device=device, dtype=torch.long),
                object_ids_padded=torch.zeros(B, 1, device=device, dtype=torch.long),
            )

            # Create proper metadata structure (required by SAM3's type hints)
            # BatchedInferenceMetadata requires: coco_image_id, original_image_id, original_category_id,
            # original_size, object_id, frame_index, is_conditioning_only
            dummy_metadata = BatchedInferenceMetadata(
                coco_image_id=torch.zeros(B, device=device, dtype=torch.long),
                original_image_id=torch.zeros(B, device=device, dtype=torch.long),
                original_category_id=torch.zeros(B, device=device, dtype=torch.int),
                original_size=torch.tensor([[self.image_size, self.image_size]] * B, device=device, dtype=torch.long),
                object_id=torch.zeros(B, device=device, dtype=torch.long),
                frame_index=torch.zeros(B, device=device, dtype=torch.long),
                is_conditioning_only=[None] * B,
            ) if BatchedInferenceMetadata is not None else {}

            # Package into BatchedDatapoint
            # find_targets=[dummy_target]: satisfy internal 'input.find_targets[0]' access
            find_targets_list = [dummy_target]  # Pre-create to verify it's not empty
            find_metadatas_list = [dummy_metadata]
            data = BatchedDatapoint(
                img_batch=pixel_values,
                find_text_batch=find_text_batch,
                find_inputs=[stage],
                find_targets=find_targets_list,
                find_metadatas=find_metadatas_list
            )
            # Immediate verification
            assert len(data.find_targets) == 1, f"find_targets should have 1 element, got {len(data.find_targets)}"

            # Debug: Log processed shapes once
            if not self._logged_shapes:
                print(f"[MedSAM3] Processed shapes before SAM3 call:")
                print(f" img_batch: {pixel_values.shape}")
                print(f" flat_boxes: {flat_boxes.shape} (CxCyWH normalized)")
                print(f" flat_boxes sample: {flat_boxes[0, 0, :] if flat_boxes.numel() > 0 else 'empty'}")
                print(f" flat_boxes_mask: {flat_boxes_mask.shape}")
                print(f" flat_points: {flat_points.shape} (normalized)")
                print(f" flat_points_mask: {flat_points_mask.shape}")
                print(f" img_ids: {img_ids.shape}")
                print(f" find_targets: {data.find_targets}, len={len(data.find_targets)}")
                print(f" find_inputs: {data.find_inputs}, len={len(data.find_inputs)}")
                self._logged_shapes = True

            # Validation safety: verify find_targets is not empty before passing to SAM3
            if len(data.find_targets) == 0:
                print(f"[MedSAM3 ERROR] find_targets is empty! dummy_target={dummy_target}")
                raise ValueError("find_targets list is empty - this should never happen")

            return self.model(data)
        else:
            return self.model(pixel_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|