tp53 (ashish) committed on
Commit
c94cf50
·
1 Parent(s): 49e4f07

Remove model folder - use fallback mode

Browse files
Files changed (2) hide show
  1. model/__init__.py +0 -0
  2. model/medsam3.py +0 -379
model/__init__.py DELETED
File without changes
model/medsam3.py DELETED
@@ -1,379 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- from typing import Dict, Optional, List, Any
5
-
6
- try:
7
- from sam3.model_builder import build_sam3_image_model as build_sam3_model
8
- from sam3.model.data_misc import BatchedDatapoint, FindStage, BatchedFindTarget, BatchedInferenceMetadata
9
- from sam3.model import decoder as sam3_decoder
10
- SAM3_AVAILABLE = True
11
- except ImportError:
12
- build_sam3_model = None
13
- BatchedDatapoint = None
14
- FindStage = None
15
- BatchedFindTarget = None
16
- BatchedInferenceMetadata = None
17
- sam3_decoder = None
18
- SAM3_AVAILABLE = False
19
-
20
- from peft import LoraConfig, get_peft_model
21
-
22
-
23
def _patch_sam3_decoder_for_ddp():
    """
    Monkey-patch SAM3's decoder to fix a DDP device-placement bug.

    The bug: SAM3 caches coords_h/coords_w in ``compilable_cord_cache`` and
    ``coord_cache``. In DDP, these get created on cuda:0 first, then other
    ranks fail because the cached coords are on the wrong device.

    The fix: patch ``_get_rpb_matrix`` to always move cached coords to the
    correct device before delegating to the original implementation.

    The patch is idempotent: calling this function more than once (e.g. on
    a module reload) is a no-op instead of stacking wrapper upon wrapper.
    """
    if not SAM3_AVAILABLE or sam3_decoder is None:
        return

    # Find the decoder class that has _get_rpb_matrix
    decoder_cls = None
    for name in dir(sam3_decoder):
        cls = getattr(sam3_decoder, name)
        if isinstance(cls, type) and hasattr(cls, '_get_rpb_matrix'):
            decoder_cls = cls
            break

    if decoder_cls is None:
        print("[MedSAM3] Warning: Could not find decoder class to patch")
        return

    # Idempotency guard: if the method is already our wrapper, do nothing.
    # Without this, a second call would wrap the wrapper, growing an
    # ever-deeper delegation chain on every invocation.
    if getattr(decoder_cls._get_rpb_matrix, "_medsam3_ddp_patched", False):
        return

    # Store original method so the wrapper can delegate to it
    original_get_rpb_matrix = decoder_cls._get_rpb_matrix

    def patched_get_rpb_matrix(self, *args, **kwargs):
        """Patched version that ensures coords are on the correct device."""
        # Get device from first tensor argument (reference_boxes)
        target_device = None
        for arg in args:
            if torch.is_tensor(arg):
                target_device = arg.device
                break
        if target_device is None:
            for v in kwargs.values():
                if torch.is_tensor(v):
                    target_device = v.device
                    break

        if target_device is not None:
            # Fix compilable_cord_cache if device mismatch
            if hasattr(self, 'compilable_cord_cache') and self.compilable_cord_cache is not None:
                cached_h, cached_w = self.compilable_cord_cache
                if cached_h.device != target_device:
                    self.compilable_cord_cache = (
                        cached_h.to(target_device),
                        cached_w.to(target_device)
                    )

            # Also fix every entry of the coord_cache dict
            if hasattr(self, 'coord_cache') and self.coord_cache:
                for key in list(self.coord_cache.keys()):
                    cached_h, cached_w = self.coord_cache[key]
                    if cached_h.device != target_device:
                        self.coord_cache[key] = (
                            cached_h.to(target_device),
                            cached_w.to(target_device)
                        )

        return original_get_rpb_matrix(self, *args, **kwargs)

    # Mark the wrapper so a repeated patch attempt becomes a no-op
    patched_get_rpb_matrix._medsam3_ddp_patched = True

    # Apply patch
    decoder_cls._get_rpb_matrix = patched_get_rpb_matrix
    print("[MedSAM3] Successfully patched SAM3 decoder for DDP compatibility")
90
-
91
-
92
# Apply the patch at module load time
_patch_sam3_decoder_for_ddp()
94
-
95
class MedSAM3Model(nn.Module):
    """LoRA-adapted SAM3 wrapper for medical image segmentation.

    Builds the official SAM3 image model, optionally loads a full-model
    checkpoint, freezes the perception-encoder backbone, injects LoRA
    adapters into the attention ``qkv``/``proj`` layers, and translates
    plain (pixel_values, boxes/points) inputs into SAM3's
    ``BatchedDatapoint`` structure in ``forward``.
    """

    def __init__(self, model_id: str = "sam3_hiera_base", lora_rank: int = 16, image_size: int = 1024, checkpoint_path: Optional[str] = None):
        """
        Args:
            model_id: Kept for interface compatibility; the architecture is
                built via ``build_sam3_model`` and this id is not consumed here.
            lora_rank: LoRA rank ``r``; ``lora_alpha`` is set to ``2 * r``.
            image_size: Nominal square input size, stored for metadata and
                coordinate normalization fallback.
            checkpoint_path: Optional path to a checkpoint file; weights may
                be wrapped under a top-level ``"model"`` key.

        Raises:
            ImportError: If the SAM3 core libraries failed to import.
        """
        super().__init__()
        self._logged_shapes = False  # For one-time debug logging
        self._buffers_migrated = False  # Track if we've done buffer device migration
        self.image_size = image_size  # Store for coordinate normalization
        # --- 1. Initialize SAM 3 Architecture ---
        if build_sam3_model:
            # Initialize SAM3 architecture without downloading from HuggingFace
            # (our checkpoint already contains full weights including base SAM3)
            self.model = build_sam3_model(load_from_HF=False, eval_mode=False)

            # --- 2. Load Weights ---
            if checkpoint_path and os.path.exists(checkpoint_path):
                # NOTE(review): torch.load without weights_only=True can run
                # arbitrary pickled code - only load trusted checkpoints.
                state_dict = torch.load(checkpoint_path, map_location="cpu")
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                # strict=False tolerates LoRA/auxiliary keys, but report any
                # mismatch instead of discarding it silently.
                load_result = self.model.load_state_dict(state_dict, strict=False)
                if load_result.missing_keys or load_result.unexpected_keys:
                    print(f"[MedSAM3] Checkpoint loaded with "
                          f"{len(load_result.missing_keys)} missing / "
                          f"{len(load_result.unexpected_keys)} unexpected keys")
        else:
            raise ImportError(
                "CRITICAL: SAM3 core libraries not found. "
                "Ensure you have installed sam3 correctly (e.g. via pip install git+...sam3.git). "
                "Check logs for previous import errors."
            )

        # --- 3. Freeze Backbone ---
        for name, param in self.model.named_parameters():
            if "perception_encoder" in name:
                param.requires_grad = False

        # --- 4. Apply LoRA ---
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank * 2,
            target_modules=["qkv", "proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=None  # Important: prevents peft from injecting 'input_ids'
        )
        self.model = get_peft_model(self.model, lora_config)

        # --- 5. Foundation Specialist Fix: Dummy Matcher ---
        # SAM3's forward_grounding path (which handles boxes/points) sometimes
        # attempts to call self.matcher even if no targets are provided.
        # We inject a dummy matcher that returns empty indices to prevent
        # 'NoneType object is not callable' crashes.
        base_model = self.model.get_base_model()
        if hasattr(base_model, 'matcher') and base_model.matcher is None:
            # Matcher expected signature: func(outputs, targets) -> list of matches
            base_model.matcher = lambda outputs, targets: []
            print("[MedSAM3] Injected dummy matcher for grounding stability")

    def forward(self, pixel_values, input_boxes=None, input_points=None, point_labels=None, text_prompt=None):
        """Run SAM3 on a batch of images with optional box/point/text prompts.

        Args:
            pixel_values: (B, C, H, W) images or (B, C, T, H, W) volumes
                (volumes are flattened to (B*T, C, H, W)).
            input_boxes: Optional xyxy pixel-coordinate boxes, (B, 4),
                (B, 1, 4), or (B, T, 4) for volumetric input.
            input_points: Optional pixel-coordinate points, e.g. (B, N, 2)
                or (B, T, 1, 2) for volumetric input.
            point_labels: Optional labels matching ``input_points``.
            text_prompt: Optional text prompt broadcast to the whole batch
                (falls back to "medical").

        Returns:
            Raw SAM3 output from ``self.model(data)``, or
            ``self.model(pixel_values)`` when SAM3 data structures are
            unavailable.
        """
        # DDP Fix: Ensure all model buffers are on the same device as input
        # SAM3 has some internal buffers that don't auto-migrate in DDP
        # Only do this once per device to avoid overhead on every forward pass
        target_device = pixel_values.device
        if not self._buffers_migrated:
            migrated_count = 0
            for name, buf in self.model.named_buffers():
                if buf.device != target_device:
                    buf.data = buf.data.to(target_device)
                    migrated_count += 1
            if migrated_count > 0:
                print(f"[MedSAM3] Migrated {migrated_count} buffers to {target_device}")
            self._buffers_migrated = True

        # Debug: Log shapes once on first forward pass
        if not self._logged_shapes:
            print("[MedSAM3] First forward - Input shapes:")
            print(f"  pixel_values: {pixel_values.shape}")
            print(f"  input_boxes: {input_boxes.shape if input_boxes is not None else None}")
            print(f"  input_points: {input_points.shape if input_points is not None else None}")
            print(f"  point_labels: {point_labels.shape if point_labels is not None else None}")

        # --- 1. Handle 3D to 2D Flattening (Robust) ---
        if pixel_values.dim() == 5:
            # Input: (B, C, T, H, W) -> Goal: (B*T, C, H, W)
            B_orig, C, T, H, W = pixel_values.shape
            # Permute to (B, T, C, H, W) then flatten
            pixel_values = pixel_values.permute(0, 2, 1, 3, 4).reshape(B_orig * T, C, H, W)

            if input_boxes is not None:
                # input_boxes is (B, T, 4) -> (B*T, 4)
                input_boxes = input_boxes.view(B_orig * T, 4)
            if input_points is not None:
                # input_points is (B, T, 1, 2) -> (B*T, 1, 2)
                input_points = input_points.view(B_orig * T, -1, 2)
            if point_labels is not None:
                # point_labels is (B, T, 1) -> (B*T, 1)
                point_labels = point_labels.view(B_orig * T, -1)

        # After reshaping, get the actual batch size
        B = pixel_values.shape[0]

        # --- 2. Channel Handling (Ensuring 3 channels for SAM3) ---
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            # Single-channel (e.g., CT) -> replicate to 3 channels
            pixel_values = pixel_values.repeat(1, 3, 1, 1)
        elif num_channels == 3:
            # Already 3 channels (e.g., multi-modal MRI after SelectMRIChannels)
            pass
        elif num_channels == 4:
            # 4-channel MRI (BrainTumour) - use first 3 channels [FLAIR, T1w, T1gd]
            # This is a fallback; ideally SelectMedicalChannels should handle this in transforms
            if not self._logged_shapes:
                print(f"[MedSAM3 WARNING] Received 4-channel input - using first 3 channels. "
                      f"Consider enabling SelectMedicalChannels transform.")
            pixel_values = pixel_values[:, :3, :, :]
        else:
            # Unexpected channel count - average and replicate
            if not self._logged_shapes:
                print(f"[MedSAM3 WARNING] Unexpected {num_channels} channels - averaging to single then replicating to 3.")
            pixel_values = pixel_values.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        # --- 3. Prompt Dimension Enforcement ---
        # Boxes: (B_total, 1, 4)
        if input_boxes is not None and input_boxes.dim() == 2:
            input_boxes = input_boxes.unsqueeze(1)

        # Points: (B_total, 1, N, 2), Labels: (B_total, 1, N)
        if input_points is not None and input_points.dim() == 3:
            input_points = input_points.unsqueeze(1)
        if point_labels is not None and point_labels.dim() == 2:
            point_labels = point_labels.unsqueeze(1)

        # --- 4. Package into Official SAM 3 Structure ---
        if BatchedDatapoint is not None and FindStage is not None:
            # Get device from model parameters (critical for DDP multi-GPU)
            device = pixel_values.device

            # SAM 3 expects a SINGLE FindStage object that aggregates prompts for the entire batch.
            # We must flatten the batch dimension of the prompts and create corresponding img_ids.

            # Current shapes:
            #   input_boxes: (B, 1, 4)
            #   input_points: (B, 1, N, 2)
            #   point_labels: (B, 1, N)

            # We treat each image in the batch as having 1 prompt (since we have 1 box/point set per slice)
            # So we just flatten the first dimension.

            # img_ids: [0, 1, 2, ... B-1] (since 1 prompt per image)
            img_ids = torch.arange(B, device=device, dtype=torch.long)

            # Text ids: all 0 (dummy)
            text_ids = torch.zeros(B, device=device, dtype=torch.long)

            # SAM3 expects SEQUENCE-FIRST format for embeddings, BATCH-FIRST for masks:
            #   input_boxes: [num_boxes, num_prompts, 4] - sequence first
            #   input_boxes_mask: [num_prompts, num_boxes] - batch first (1=padded/invalid)
            #   input_boxes_label: [num_boxes, num_prompts]
            #   input_points: [num_points, num_prompts, 2] - sequence first
            #   input_points_mask: [num_prompts, num_points] - batch first
            #
            # For our case: 1 box per image, B images → num_boxes=1, num_prompts=B

            # Boxes: [B, 1, 4] → [1, B, 4] (sequence first)
            # SAM3 expects boxes in CxCyWH format, normalized to [0, 1]
            # Our input is xyxy in pixel coordinates
            if input_boxes is not None:
                boxes_xyxy = input_boxes.squeeze(1).float().to(device)  # [B, 4] - x_min, y_min, x_max, y_max

                # Use actual tensor dimensions for normalization (more robust than stored image_size)
                actual_h, actual_w = pixel_values.shape[2], pixel_values.shape[3]

                # Convert xyxy to cxcywh and normalize to [0, 1]
                x_min, y_min, x_max, y_max = boxes_xyxy[:, 0], boxes_xyxy[:, 1], boxes_xyxy[:, 2], boxes_xyxy[:, 3]
                cx = (x_min + x_max) / 2.0 / actual_w
                cy = (y_min + y_max) / 2.0 / actual_h
                w = (x_max - x_min) / actual_w
                h = (y_max - y_min) / actual_h

                # Clamp to ensure valid boxes (min size 1% of image to avoid ROI align issues)
                min_size = 0.01
                w = torch.clamp(w, min=min_size)
                h = torch.clamp(h, min=min_size)

                boxes_cxcywh = torch.stack([cx, cy, w, h], dim=1)  # [B, 4]
                flat_boxes = boxes_cxcywh.unsqueeze(0)  # [1, B, 4]
                flat_boxes_mask = torch.zeros(B, 1, device=device, dtype=torch.bool)  # [B, 1] - 0=valid
                flat_boxes_label = torch.zeros(1, B, device=device, dtype=torch.long)  # [1, B]
            else:
                flat_boxes = torch.zeros(1, B, 4, device=device)
                flat_boxes_mask = torch.ones(B, 1, device=device, dtype=torch.bool)  # 1=invalid/padded
                flat_boxes_label = torch.zeros(1, B, device=device, dtype=torch.long)

            # Points: [B, 1, N, 2] → [N, B, 2] (sequence first)
            # SAM3 expects points normalized to [0, 1]
            n_points = input_points.shape[2] if input_points is not None else 1
            if input_points is not None:
                points_pixel = input_points.squeeze(1).float().to(device)  # [B, N, 2] - x, y in pixel coords
                # Normalize using actual tensor dimensions
                actual_h, actual_w = pixel_values.shape[2], pixel_values.shape[3]
                points_normalized = points_pixel.clone()
                points_normalized[..., 0] = points_pixel[..., 0] / actual_w  # x normalized
                points_normalized[..., 1] = points_pixel[..., 1] / actual_h  # y normalized
                flat_points = points_normalized.permute(1, 0, 2)  # [B, N, 2] → [N, B, 2]
                flat_points_mask = torch.zeros(B, n_points, device=device, dtype=torch.bool)  # 0=valid
            else:
                flat_points = torch.zeros(1, B, 2, device=device)
                flat_points_mask = torch.ones(B, 1, device=device, dtype=torch.bool)  # 1=invalid

            stage = FindStage(
                img_ids=img_ids,
                text_ids=text_ids,
                input_boxes=flat_boxes,
                input_boxes_mask=flat_boxes_mask,
                input_boxes_label=flat_boxes_label,
                input_points=flat_points,
                input_points_mask=flat_points_mask,
            )

            # Text batch for grounding head - use provided text_prompt or fallback
            if text_prompt is not None:
                find_text_batch = [text_prompt] * B
            else:
                find_text_batch = ["medical"] * B

            # Create dummy target structure to satisfy SAM3's internal indexing [0]
            # We use the dummy matcher injected in __init__ to ensure this doesn't
            # actually trigger any real loss computation.
            dummy_target = BatchedFindTarget(
                num_boxes=torch.zeros(B, device=device, dtype=torch.long),
                boxes=torch.zeros(B, 4, device=device),
                boxes_padded=torch.zeros(B, 1, 4, device=device),
                repeated_boxes=torch.zeros(B, 4, device=device),
                segments=None,
                semantic_segments=None,
                is_valid_segment=None,
                is_exhaustive=torch.zeros(B, device=device, dtype=torch.bool),
                object_ids=torch.zeros(B, device=device, dtype=torch.long),
                object_ids_padded=torch.zeros(B, 1, device=device, dtype=torch.long),
            )

            # Create proper metadata structure (required by SAM3's type hints)
            # BatchedInferenceMetadata requires: coco_image_id, original_image_id, original_category_id,
            # original_size, object_id, frame_index, is_conditioning_only
            dummy_metadata = BatchedInferenceMetadata(
                coco_image_id=torch.zeros(B, device=device, dtype=torch.long),
                original_image_id=torch.zeros(B, device=device, dtype=torch.long),
                original_category_id=torch.zeros(B, device=device, dtype=torch.int),
                original_size=torch.tensor([[self.image_size, self.image_size]] * B, device=device, dtype=torch.long),
                object_id=torch.zeros(B, device=device, dtype=torch.long),
                frame_index=torch.zeros(B, device=device, dtype=torch.long),
                is_conditioning_only=[None] * B,
            ) if BatchedInferenceMetadata is not None else {}

            # Package into BatchedDatapoint
            # find_targets=[dummy_target]: satisfy internal 'input.find_targets[0]' access
            find_targets_list = [dummy_target]  # Pre-create to verify it's not empty
            find_metadatas_list = [dummy_metadata]
            data = BatchedDatapoint(
                img_batch=pixel_values,
                find_text_batch=find_text_batch,
                find_inputs=[stage],
                find_targets=find_targets_list,
                find_metadatas=find_metadatas_list
            )
            # Immediate verification. Explicit raise instead of `assert`
            # (asserts are stripped under `python -O`); this also covers the
            # empty-list case, so no separate emptiness check is needed below.
            if len(data.find_targets) != 1:
                raise ValueError(f"find_targets should have 1 element, got {len(data.find_targets)}")

            # Debug: Log processed shapes once
            if not self._logged_shapes:
                print("[MedSAM3] Processed shapes before SAM3 call:")
                print(f"  img_batch: {pixel_values.shape}")
                print(f"  flat_boxes: {flat_boxes.shape} (CxCyWH normalized)")
                print(f"  flat_boxes sample: {flat_boxes[0, 0, :] if flat_boxes.numel() > 0 else 'empty'}")
                print(f"  flat_boxes_mask: {flat_boxes_mask.shape}")
                print(f"  flat_points: {flat_points.shape} (normalized)")
                print(f"  flat_points_mask: {flat_points_mask.shape}")
                print(f"  img_ids: {img_ids.shape}")
                print(f"  find_targets: {data.find_targets}, len={len(data.find_targets)}")
                print(f"  find_inputs: {data.find_inputs}, len={len(data.find_inputs)}")
                self._logged_shapes = True

            return self.model(data)
        else:
            return self.model(pixel_values)