mmrech commited on
Commit
3233b61
·
1 Parent(s): adc6eda

Fix SAM 3 implementation to match official akhaliq/sam3

Browse files

- Update imports: Use Sam3Processor and Sam3Model (not AutoImageProcessor/AutoModel)
- Update model loading: Use facebook/sam3 with proper torch_dtype (float16 for GPU)
- Create run_sam3_inference() helper matching official implementation
- Update all inference calls to use processor.post_process_instance_segmentation()
- Fix mask handling to work with SAM 3 output format (list of masks + scores)

Matches official implementation from: https://huggingface.co/spaces/akhaliq/sam3/blob/main/app.py

Files changed (1) hide show
  1. app.py +207 -370
app.py CHANGED
@@ -15,7 +15,7 @@ import torch
15
  import pydicom
16
  import numpy as np
17
  from PIL import Image, ImageEnhance, ImageDraw
18
- from transformers import AutoImageProcessor, AutoModel
19
  import matplotlib.pyplot as plt
20
  from matplotlib.patches import Rectangle
21
  from scipy import ndimage
@@ -46,33 +46,71 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
46
  model = None
47
  processor = None
48
 
49
- # SAM 3 model identifier - using AutoImageProcessor/AutoModel for SAM 3
50
- SAM_MODEL_ID = "facebook/sam3-hiera-large"
51
 
52
  try:
53
- processor = AutoImageProcessor.from_pretrained(SAM_MODEL_ID, token=hf_token)
54
- model = AutoModel.from_pretrained(SAM_MODEL_ID, token=hf_token)
55
- model = model.to(device)
 
 
 
 
56
  model.eval()
57
  print(f"✅ SAM 3 Model Loaded Successfully! ({SAM_MODEL_ID})")
58
  except Exception as e:
59
- print(f"⚠️ Model Load Warning: {e}")
60
- print("Trying alternative SAM 3 model identifier...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  try:
62
- # Fallback: try without hiera suffix
63
- SAM_MODEL_ID = "facebook/sam3"
64
- processor = AutoImageProcessor.from_pretrained(SAM_MODEL_ID, token=hf_token)
65
- model = AutoModel.from_pretrained(SAM_MODEL_ID, token=hf_token)
66
- model = model.to(device)
67
- model.eval()
68
- print(f"✅ SAM 3 Model Loaded Successfully! ({SAM_MODEL_ID})")
69
- except Exception as e2:
70
- print(f"❌ Failed to load SAM 3 model: {e2}")
71
- print("Ensure you have:")
72
- print(" 1. transformers>=4.45.0 for SAM 3 support")
73
- print(" 2. Valid Hugging Face token with access to SAM 3")
74
- print(" 3. Sufficient memory for the model")
75
- raise
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Create Sample DICOM File for Demo
78
  demo_dicom_path = "demo_brain_mri.dcm"
@@ -304,90 +342,40 @@ def process_medical_image(image_file, prompt_text, modality, window_type, return
304
 
305
  pil_image = Image.fromarray(img_uint8.astype(np.uint8))
306
 
307
- # Run SAM 3 Inference
308
- try:
309
- # Prepare inputs
310
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
311
- # Move inputs to device
312
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
313
-
314
- with torch.no_grad():
315
- outputs = model(**inputs)
316
-
317
- # Extract masks from outputs - handle different output formats
318
- masks = None
319
- if hasattr(outputs, 'pred_masks'):
320
- masks = outputs.pred_masks
321
- elif isinstance(outputs, dict):
322
- # Try common mask keys
323
- masks = outputs.get('pred_masks') or outputs.get('masks') or outputs.get('segmentation_masks')
324
- if masks is None and len(outputs) > 0:
325
- # Get first tensor value if no standard key found
326
- first_value = list(outputs.values())[0]
327
- if isinstance(first_value, torch.Tensor) and len(first_value.shape) >= 2:
328
- masks = first_value
329
- elif isinstance(outputs, (list, tuple)) and len(outputs) > 0:
330
- masks = outputs[0]
331
- else:
332
- masks = outputs
333
-
334
- # Convert to numpy and process
335
- if masks is not None:
336
- if isinstance(masks, torch.Tensor):
337
- masks = masks.cpu().numpy()
338
-
339
- # Handle batch dimension if present
340
- if len(masks.shape) == 4: # [batch, num_masks, H, W]
341
- masks = masks[0] # Take first batch
342
- elif len(masks.shape) == 3: # [num_masks, H, W] or [H, W, channels]
343
- if masks.shape[0] < masks.shape[-1]: # Likely [num_masks, H, W]
344
- masks = masks # Keep as is
345
- else: # Likely [H, W, channels]
346
- masks = masks[..., 0] if masks.shape[-1] == 1 else masks
347
-
348
- # Ensure boolean mask - threshold if needed
349
- if masks.dtype != bool:
350
- if len(masks.shape) == 3: # Multiple masks
351
- masks = masks > 0.5
352
- # Combine all masks into one
353
- masks = np.any(masks, axis=0)
354
- else: # Single mask
355
- masks = masks > 0.5
356
-
357
- results = {'masks': masks}
358
- else:
359
- print("⚠️ Warning: No masks found in model output")
360
- results = {'masks': None}
361
-
362
- except Exception as e:
363
- print(f"❌ Error during model inference: {e}")
364
- import traceback
365
- traceback.print_exc()
366
  return None
367
 
368
- # Draw Masks on Image
369
  plt.figure(figsize=(10, 10))
370
  plt.imshow(pil_image)
371
 
372
  final_mask = None
373
  if 'masks' in results and results['masks'] is not None:
374
- masks = results['masks']
375
- # Handle different mask formats
376
- if isinstance(masks, np.ndarray):
377
- if len(masks.shape) == 3: # Multiple masks [num_masks, H, W]
378
- final_mask = np.any(masks, axis=0)
379
- elif len(masks.shape) == 2: # Single mask [H, W]
380
- final_mask = masks
381
- else:
382
- print(f"⚠️ Warning: Unexpected mask shape: {masks.shape}")
383
- final_mask = None
 
 
 
384
 
385
- if final_mask is not None:
 
 
386
  plt.imshow(final_mask, alpha=0.5, cmap='spring')
387
  else:
388
- print("⚠️ Warning: Could not process mask format.")
389
  else:
390
- print(f"⚠️ Warning: Masks is not a numpy array: {type(masks)}")
391
  else:
392
  print("⚠️ Warning: No masks in results.")
393
 
@@ -549,90 +537,39 @@ def process_medical_image_enhanced(image_file, prompt_text, modality, window_typ
549
  enhancer = ImageEnhance.Contrast(pil_image)
550
  pil_image = enhancer.enhance(contrast)
551
 
552
- # Run SAM 3 Inference
553
- try:
554
- # Prepare inputs
555
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
556
- # Move inputs to device
557
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
558
-
559
- with torch.no_grad():
560
- outputs = model(**inputs)
561
-
562
- # Extract masks from outputs - handle different output formats
563
- masks = None
564
- if hasattr(outputs, 'pred_masks'):
565
- masks = outputs.pred_masks
566
- elif isinstance(outputs, dict):
567
- # Try common mask keys
568
- masks = outputs.get('pred_masks') or outputs.get('masks') or outputs.get('segmentation_masks')
569
- if masks is None and len(outputs) > 0:
570
- # Get first tensor value if no standard key found
571
- first_value = list(outputs.values())[0]
572
- if isinstance(first_value, torch.Tensor) and len(first_value.shape) >= 2:
573
- masks = first_value
574
- elif isinstance(outputs, (list, tuple)) and len(outputs) > 0:
575
- masks = outputs[0]
576
- else:
577
- masks = outputs
578
-
579
- # Convert to numpy and process
580
- if masks is not None:
581
- if isinstance(masks, torch.Tensor):
582
- masks = masks.cpu().numpy()
583
-
584
- # Handle batch dimension if present
585
- if len(masks.shape) == 4: # [batch, num_masks, H, W]
586
- masks = masks[0] # Take first batch
587
- elif len(masks.shape) == 3: # [num_masks, H, W] or [H, W, channels]
588
- if masks.shape[0] < masks.shape[-1]: # Likely [num_masks, H, W]
589
- masks = masks # Keep as is
590
- else: # Likely [H, W, channels]
591
- masks = masks[..., 0] if masks.shape[-1] == 1 else masks
592
-
593
- # Ensure boolean mask - threshold if needed
594
- if masks.dtype != bool:
595
- if len(masks.shape) == 3: # Multiple masks
596
- masks = masks > 0.5
597
- # Combine all masks into one
598
- masks = np.any(masks, axis=0)
599
- else: # Single mask
600
- masks = masks > 0.5
601
-
602
- results = {'masks': masks}
603
- else:
604
- print("⚠️ Warning: No masks found in model output")
605
- results = {'masks': None}
606
-
607
- except Exception as e:
608
- print(f"❌ Error during model inference: {e}")
609
- import traceback
610
- traceback.print_exc()
611
  return None
612
 
613
- # Draw Masks on Image with enhanced visualization
614
  plt.figure(figsize=(10, 10))
615
  plt.imshow(pil_image)
616
 
617
  final_mask = None
618
  if 'masks' in results and results['masks'] is not None:
619
- masks = results['masks']
620
- # Handle different mask formats
621
- if isinstance(masks, np.ndarray):
622
- if len(masks.shape) == 3: # Multiple masks [num_masks, H, W]
623
- final_mask = np.any(masks, axis=0)
624
- elif len(masks.shape) == 2: # Single mask [H, W]
625
- final_mask = masks
626
- else:
627
- print(f"⚠️ Warning: Unexpected mask shape: {masks.shape}")
628
- final_mask = None
 
 
629
 
630
- if final_mask is not None:
 
 
631
  plt.imshow(final_mask, alpha=transparency, cmap=colormap)
632
  else:
633
- print("⚠️ Warning: Could not process mask format.")
634
  else:
635
- print(f"⚠️ Warning: Masks is not a numpy array: {type(masks)}")
636
  else:
637
  print("⚠️ Warning: No masks in results.")
638
 
@@ -925,51 +862,35 @@ def process_with_point_prompt(image_file, point_x, point_y, modality, window_typ
925
  point_y = max(0, min(int(point_y), h - 1))
926
 
927
  # Create a prompt based on the point location
928
- # Use the point's neighborhood intensity as a hint for segmentation
929
  prompt_text = f"segment region at point"
930
 
931
- # Process with SAM
932
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
933
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
934
-
935
- with torch.no_grad():
936
- outputs = model(**inputs)
937
-
938
- # Extract masks
939
- masks = None
940
- if hasattr(outputs, 'pred_masks'):
941
- masks = outputs.pred_masks
942
- elif isinstance(outputs, dict):
943
- masks = outputs.get('pred_masks') or outputs.get('masks')
944
 
945
- if masks is not None:
946
- if isinstance(masks, torch.Tensor):
947
- masks = masks.cpu().numpy()
948
-
949
- if len(masks.shape) == 4:
950
- masks = masks[0]
951
-
952
- if masks.dtype != bool:
953
- masks = masks > 0.5
954
-
955
- if len(masks.shape) == 3:
956
- # Select mask containing the point
957
- best_mask = None
958
- for i in range(masks.shape[0]):
959
- mask_resized = np.array(Image.fromarray(masks[i].astype(np.uint8) * 255).resize((w, h))) > 127
960
- if mask_resized[point_y, point_x]:
961
- best_mask = mask_resized
962
- break
963
-
964
- if best_mask is None:
965
- best_mask = np.any(masks, axis=0)
966
- best_mask = np.array(Image.fromarray(best_mask.astype(np.uint8) * 255).resize((w, h))) > 127
967
 
968
- final_mask = best_mask
969
- else:
970
- final_mask = np.array(Image.fromarray(masks.astype(np.uint8) * 255).resize((w, h))) > 127
971
- else:
972
- final_mask = None
 
 
 
 
 
 
 
 
 
973
 
974
  # Draw result with point marker
975
  plt.figure(figsize=(10, 10))
@@ -1039,43 +960,30 @@ def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, c
1039
 
1040
  prompt_text = "segment region in bounding box"
1041
 
1042
- # Process with SAM
1043
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
1044
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
1045
-
1046
- with torch.no_grad():
1047
- outputs = model(**inputs)
1048
-
1049
- # Extract and filter masks by box region
1050
- masks = None
1051
- if hasattr(outputs, 'pred_masks'):
1052
- masks = outputs.pred_masks
1053
- elif isinstance(outputs, dict):
1054
- masks = outputs.get('pred_masks') or outputs.get('masks')
1055
 
1056
  final_mask = None
1057
- if masks is not None:
1058
- if isinstance(masks, torch.Tensor):
1059
- masks = masks.cpu().numpy()
1060
-
1061
- if len(masks.shape) == 4:
1062
- masks = masks[0]
1063
-
1064
- if masks.dtype != bool:
1065
- masks = masks > 0.5
1066
-
1067
- if len(masks.shape) == 3:
1068
- combined = np.any(masks, axis=0)
1069
- else:
1070
- combined = masks
1071
-
1072
- # Resize to image size
1073
- combined_resized = np.array(Image.fromarray(combined.astype(np.uint8) * 255).resize((w, h))) > 127
1074
 
1075
- # Create box mask and intersect
1076
- box_mask = np.zeros((h, w), dtype=bool)
1077
- box_mask[y1:y2, x1:x2] = True
1078
- final_mask = combined_resized & box_mask
 
 
 
1079
 
1080
  # Draw result with box
1081
  plt.figure(figsize=(10, 10))
@@ -1136,79 +1044,40 @@ def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks
1136
  if not prompt_text or not prompt_text.strip():
1137
  prompt_text = "brain"
1138
 
1139
- # Process with SAM
1140
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
1141
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
1142
-
1143
- with torch.no_grad():
1144
- outputs = model(**inputs)
1145
-
1146
- # Extract masks
1147
- masks = None
1148
- scores = None
1149
-
1150
- if hasattr(outputs, 'pred_masks'):
1151
- masks = outputs.pred_masks
1152
- elif isinstance(outputs, dict):
1153
- masks = outputs.get('pred_masks') or outputs.get('masks')
1154
- scores = outputs.get('iou_scores') or outputs.get('scores')
1155
 
1156
  results = []
1157
  mask_info = []
1158
 
1159
- if masks is not None:
1160
- if isinstance(masks, torch.Tensor):
1161
- masks = masks.cpu().numpy()
1162
- if scores is not None and isinstance(scores, torch.Tensor):
1163
- scores = scores.cpu().numpy().flatten()
1164
 
1165
- if len(masks.shape) == 4:
1166
- masks = masks[0]
1167
 
1168
- if len(masks.shape) == 3:
1169
- num_available = masks.shape[0]
1170
- num_to_show = min(num_masks, num_available)
1171
-
1172
- # Generate confidence scores if not available
1173
- if scores is None:
1174
- scores = [1.0 / (i + 1) for i in range(num_available)] # Simulated scores
 
1175
 
1176
- colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma']
 
 
1177
 
1178
- for i in range(num_to_show):
1179
- mask = masks[i]
1180
- if mask.dtype != bool:
1181
- mask = mask > 0.5
1182
-
1183
- score = scores[i] if i < len(scores) else 0.5
1184
-
1185
- # Create visualization
1186
- plt.figure(figsize=(8, 8))
1187
- plt.imshow(pil_image)
1188
- plt.imshow(mask, alpha=0.5, cmap=colormaps[i % len(colormaps)])
1189
- plt.axis('off')
1190
- plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12)
1191
-
1192
- output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
1193
- output_path = output_file.name
1194
- output_file.close()
1195
-
1196
- plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
1197
- plt.close()
1198
-
1199
- results.append(output_path)
1200
- mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask):,} pixels")
1201
- else:
1202
- # Single mask case
1203
- mask = masks
1204
- if mask.dtype != bool:
1205
- mask = mask > 0.5
1206
 
 
1207
  plt.figure(figsize=(8, 8))
1208
  plt.imshow(pil_image)
1209
- plt.imshow(mask, alpha=0.5, cmap='spring')
1210
  plt.axis('off')
1211
- plt.title(f"Single Mask Output", fontsize=12)
1212
 
1213
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
1214
  output_path = output_file.name
@@ -1218,7 +1087,7 @@ def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks
1218
  plt.close()
1219
 
1220
  results.append(output_path)
1221
- mask_info.append(f"Single mask: {np.sum(mask):,} pixels")
1222
 
1223
  status = f"✅ Generated {len(results)} mask candidate(s)"
1224
  info = "\n".join(mask_info) if mask_info else "No mask information available"
@@ -1582,48 +1451,28 @@ def automatic_mask_generator(image_file, modality, window_type,
1582
  progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")
1583
 
1584
  try:
1585
- inputs = processor(images=pil_image, text=prompt, return_tensors="pt")
1586
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
1587
 
1588
- with torch.no_grad():
1589
- outputs = model(**inputs)
1590
-
1591
- masks = None
1592
- if hasattr(outputs, 'pred_masks'):
1593
- masks = outputs.pred_masks
1594
- elif isinstance(outputs, dict):
1595
- masks = outputs.get('pred_masks') or outputs.get('masks')
1596
-
1597
- if masks is not None:
1598
- if isinstance(masks, torch.Tensor):
1599
- masks = masks.cpu().numpy()
1600
-
1601
- if len(masks.shape) == 4:
1602
- masks = masks[0]
1603
-
1604
- if len(masks.shape) == 3:
1605
- for i in range(masks.shape[0]):
1606
- mask = masks[i]
1607
- if mask.dtype != bool:
1608
- mask = mask > 0.5
1609
-
1610
- # Filter by minimum area
1611
- mask_area = np.sum(mask)
1612
- if mask_area >= min_mask_area:
1613
- # Resize mask to image size
1614
- mask_resized = np.array(
1615
- Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))
1616
- ) > 127
1617
- all_masks.append(mask_resized)
1618
- all_scores.append(mask_area)
1619
- elif len(masks.shape) == 2:
1620
- mask = masks
1621
- if mask.dtype != bool:
1622
- mask = mask > 0.5
1623
- mask_area = np.sum(mask)
1624
  if mask_area >= min_mask_area:
 
1625
  mask_resized = np.array(
1626
- Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))
1627
  ) > 127
1628
  all_masks.append(mask_resized)
1629
  all_scores.append(mask_area)
@@ -1781,35 +1630,23 @@ def process_with_advanced_transforms(image_file, prompt_text, modality, window_t
1781
  if not prompt_text or not prompt_text.strip():
1782
  prompt_text = "brain"
1783
 
1784
- # Process with SAM
1785
- inputs = processor(images=pil_image, text=prompt_text, return_tensors="pt")
1786
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
1787
-
1788
- with torch.no_grad():
1789
- outputs = model(**inputs)
1790
-
1791
- # Extract masks
1792
- masks = None
1793
- if hasattr(outputs, 'pred_masks'):
1794
- masks = outputs.pred_masks
1795
- elif isinstance(outputs, dict):
1796
- masks = outputs.get('pred_masks') or outputs.get('masks')
1797
 
1798
  final_mask = None
1799
- if masks is not None:
1800
- if isinstance(masks, torch.Tensor):
1801
- masks = masks.cpu().numpy()
1802
-
1803
- if len(masks.shape) == 4:
1804
- masks = masks[0]
1805
-
1806
- if masks.dtype != bool:
1807
- masks = masks > 0.5
 
1808
 
1809
- if len(masks.shape) == 3:
1810
- final_mask = np.any(masks, axis=0)
1811
- else:
1812
- final_mask = masks
1813
 
1814
  # Visualize
1815
  plt.figure(figsize=(12, 6))
 
15
  import pydicom
16
  import numpy as np
17
  from PIL import Image, ImageEnhance, ImageDraw
18
+ from transformers import Sam3Processor, Sam3Model
19
  import matplotlib.pyplot as plt
20
  from matplotlib.patches import Rectangle
21
  from scipy import ndimage
 
46
  model = None
47
  processor = None
48
 
49
+ # SAM 3 model identifier - matching official implementation
50
+ SAM_MODEL_ID = "facebook/sam3"
51
 
52
  try:
53
+ # Load model with proper dtype (float16 for GPU, float32 for CPU) - matching official implementation
54
+ model = Sam3Model.from_pretrained(
55
+ SAM_MODEL_ID,
56
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
57
+ token=hf_token
58
+ ).to(device)
59
+ processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=hf_token)
60
  model.eval()
61
  print(f"✅ SAM 3 Model Loaded Successfully! ({SAM_MODEL_ID})")
62
  except Exception as e:
63
+ print(f" Failed to load SAM 3 model: {e}")
64
+ print("Ensure you have:")
65
+ print(" 1. transformers>=4.45.0 for SAM 3 support")
66
+ print(" 2. Valid Hugging Face token with access to SAM 3")
67
+ print(" 3. Sufficient memory for the model")
68
+ raise
69
+
70
+ def run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5):
71
+ """
72
+ Run SAM 3 inference - matching official implementation from akhaliq/sam3.
73
+
74
+ Args:
75
+ pil_image: PIL Image to segment
76
+ prompt_text: Text prompt for segmentation
77
+ threshold: Detection threshold (higher = fewer detections)
78
+ mask_threshold: Mask threshold (higher = sharper masks)
79
+
80
+ Returns:
81
+ results dict with 'masks' and 'scores' keys, or None if failed
82
+ """
83
+ if model is None or processor is None:
84
+ print("❌ Model not loaded")
85
+ return None
86
+
87
  try:
88
+ # Prepare inputs - matching official implementation
89
+ inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)
90
+
91
+ # Convert float32 inputs to model dtype (float16 for GPU) - matching official implementation
92
+ for key in inputs:
93
+ if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
94
+ inputs[key] = inputs[key].to(model.dtype)
95
+
96
+ with torch.no_grad():
97
+ outputs = model(**inputs)
98
+
99
+ # Post-process using processor method - matching official implementation
100
+ results = processor.post_process_instance_segmentation(
101
+ outputs,
102
+ threshold=threshold,
103
+ mask_threshold=mask_threshold,
104
+ target_sizes=inputs.get("original_sizes").tolist() if "original_sizes" in inputs else [pil_image.size[::-1]]
105
+ )[0] # Get first batch result
106
+
107
+ return results
108
+
109
+ except Exception as e:
110
+ print(f"❌ Error during SAM 3 inference: {e}")
111
+ import traceback
112
+ traceback.print_exc()
113
+ return None
114
 
115
  # Create Sample DICOM File for Demo
116
  demo_dicom_path = "demo_brain_mri.dcm"
 
342
 
343
  pil_image = Image.fromarray(img_uint8.astype(np.uint8))
344
 
345
+ # Run SAM 3 Inference - using helper function matching official implementation
346
+ results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
347
+
348
+ if results is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  return None
350
 
351
+ # Draw Masks on Image - matching official implementation format
352
  plt.figure(figsize=(10, 10))
353
  plt.imshow(pil_image)
354
 
355
  final_mask = None
356
  if 'masks' in results and results['masks'] is not None:
357
+ masks = results['masks'] # List of mask tensors from post_process_instance_segmentation
358
+ scores = results.get('scores', [])
359
+
360
+ if len(masks) > 0:
361
+ # Combine all masks into one (or use first mask)
362
+ # Convert tensors to numpy and combine
363
+ mask_arrays = []
364
+ for mask in masks:
365
+ if isinstance(mask, torch.Tensor):
366
+ mask_np = mask.cpu().numpy()
367
+ else:
368
+ mask_np = np.array(mask)
369
+ mask_arrays.append(mask_np)
370
 
371
+ # Combine all masks
372
+ if len(mask_arrays) > 0:
373
+ final_mask = np.any(mask_arrays, axis=0)
374
  plt.imshow(final_mask, alpha=0.5, cmap='spring')
375
  else:
376
+ print("⚠️ Warning: No valid masks found.")
377
  else:
378
+ print("⚠️ Warning: No masks in results.")
379
  else:
380
  print("⚠️ Warning: No masks in results.")
381
 
 
537
  enhancer = ImageEnhance.Contrast(pil_image)
538
  pil_image = enhancer.enhance(contrast)
539
 
540
+ # Run SAM 3 Inference - using helper function matching official implementation
541
+ results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
542
+
543
+ if results is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  return None
545
 
546
+ # Draw Masks on Image with enhanced visualization - matching official implementation format
547
  plt.figure(figsize=(10, 10))
548
  plt.imshow(pil_image)
549
 
550
  final_mask = None
551
  if 'masks' in results and results['masks'] is not None:
552
+ masks = results['masks'] # List of mask tensors from post_process_instance_segmentation
553
+ scores = results.get('scores', [])
554
+
555
+ if len(masks) > 0:
556
+ # Combine all masks into one
557
+ mask_arrays = []
558
+ for mask in masks:
559
+ if isinstance(mask, torch.Tensor):
560
+ mask_np = mask.cpu().numpy()
561
+ else:
562
+ mask_np = np.array(mask)
563
+ mask_arrays.append(mask_np)
564
 
565
+ # Combine all masks
566
+ if len(mask_arrays) > 0:
567
+ final_mask = np.any(mask_arrays, axis=0)
568
  plt.imshow(final_mask, alpha=transparency, cmap=colormap)
569
  else:
570
+ print("⚠️ Warning: No valid masks found.")
571
  else:
572
+ print("⚠️ Warning: No masks in results.")
573
  else:
574
  print("⚠️ Warning: No masks in results.")
575
 
 
862
  point_y = max(0, min(int(point_y), h - 1))
863
 
864
  # Create a prompt based on the point location
 
865
  prompt_text = f"segment region at point"
866
 
867
+ # Process with SAM 3 - using helper function
868
+ results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
 
 
 
 
 
 
 
 
 
 
 
869
 
870
+ final_mask = None
871
+ if results and 'masks' in results and results['masks'] is not None:
872
+ masks = results['masks']
873
+ # Select mask containing the point
874
+ for mask in masks:
875
+ if isinstance(mask, torch.Tensor):
876
+ mask_np = mask.cpu().numpy()
877
+ else:
878
+ mask_np = np.array(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
879
 
880
+ # Resize to image size
881
+ mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
882
+ if mask_resized[point_y, point_x]:
883
+ final_mask = mask_resized
884
+ break
885
+
886
+ # If no mask contains the point, use first mask
887
+ if final_mask is None and len(masks) > 0:
888
+ mask = masks[0]
889
+ if isinstance(mask, torch.Tensor):
890
+ mask_np = mask.cpu().numpy()
891
+ else:
892
+ mask_np = np.array(mask)
893
+ final_mask = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
894
 
895
  # Draw result with point marker
896
  plt.figure(figsize=(10, 10))
 
960
 
961
  prompt_text = "segment region in bounding box"
962
 
963
+ # Process with SAM 3 - using helper function
964
+ results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
 
 
 
 
 
 
 
 
 
 
 
965
 
966
  final_mask = None
967
+ if results and 'masks' in results and results['masks'] is not None:
968
+ masks = results['masks']
969
+ # Combine all masks
970
+ mask_arrays = []
971
+ for mask in masks:
972
+ if isinstance(mask, torch.Tensor):
973
+ mask_np = mask.cpu().numpy()
974
+ else:
975
+ mask_np = np.array(mask)
976
+ # Resize to image size
977
+ mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
978
+ mask_arrays.append(mask_resized)
 
 
 
 
 
979
 
980
+ if len(mask_arrays) > 0:
981
+ combined = np.any(mask_arrays, axis=0)
982
+
983
+ # Create box mask and intersect
984
+ box_mask = np.zeros((h, w), dtype=bool)
985
+ box_mask[y1:y2, x1:x2] = True
986
+ final_mask = combined & box_mask
987
 
988
  # Draw result with box
989
  plt.figure(figsize=(10, 10))
 
1044
  if not prompt_text or not prompt_text.strip():
1045
  prompt_text = "brain"
1046
 
1047
+ # Process with SAM 3 - using helper function
1048
+ sam_results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1049
 
1050
  results = []
1051
  mask_info = []
1052
 
1053
+ if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
1054
+ masks = sam_results['masks'] # List of mask tensors
1055
+ scores = sam_results.get('scores', []) # List of scores
 
 
1056
 
1057
+ num_available = len(masks)
1058
+ num_to_show = min(num_masks, num_available)
1059
 
1060
+ colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma']
1061
+
1062
+ for i in range(num_to_show):
1063
+ mask = masks[i]
1064
+ if isinstance(mask, torch.Tensor):
1065
+ mask_np = mask.cpu().numpy()
1066
+ else:
1067
+ mask_np = np.array(mask)
1068
 
1069
+ # Convert to boolean
1070
+ if mask_np.dtype != bool:
1071
+ mask_np = mask_np > 0.5
1072
 
1073
+ score = scores[i].item() if i < len(scores) and isinstance(scores[i], torch.Tensor) else (scores[i] if i < len(scores) else 0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1074
 
1075
+ # Create visualization
1076
  plt.figure(figsize=(8, 8))
1077
  plt.imshow(pil_image)
1078
+ plt.imshow(mask_np, alpha=0.5, cmap=colormaps[i % len(colormaps)])
1079
  plt.axis('off')
1080
+ plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12)
1081
 
1082
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
1083
  output_path = output_file.name
 
1087
  plt.close()
1088
 
1089
  results.append(output_path)
1090
+ mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask_np):,} pixels")
1091
 
1092
  status = f"✅ Generated {len(results)} mask candidate(s)"
1093
  info = "\n".join(mask_info) if mask_info else "No mask information available"
 
1451
  progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")
1452
 
1453
  try:
1454
+ # Process with SAM 3 - using helper function
1455
+ sam_results = run_sam3_inference(pil_image, prompt, threshold=0.5, mask_threshold=0.5)
1456
 
1457
+ if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
1458
+ masks = sam_results['masks'] # List of mask tensors
1459
+
1460
+ for mask in masks:
1461
+ if isinstance(mask, torch.Tensor):
1462
+ mask_np = mask.cpu().numpy()
1463
+ else:
1464
+ mask_np = np.array(mask)
1465
+
1466
+ # Convert to boolean
1467
+ if mask_np.dtype != bool:
1468
+ mask_np = mask_np > 0.5
1469
+
1470
+ # Filter by minimum area
1471
+ mask_area = np.sum(mask_np)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1472
  if mask_area >= min_mask_area:
1473
+ # Resize mask to image size
1474
  mask_resized = np.array(
1475
+ Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))
1476
  ) > 127
1477
  all_masks.append(mask_resized)
1478
  all_scores.append(mask_area)
 
1630
  if not prompt_text or not prompt_text.strip():
1631
  prompt_text = "brain"
1632
 
1633
+ # Process with SAM 3 - using helper function
1634
+ results = run_sam3_inference(pil_image, prompt_text, threshold=0.5, mask_threshold=0.5)
 
 
 
 
 
 
 
 
 
 
 
1635
 
1636
  final_mask = None
1637
+ if results and 'masks' in results and results['masks'] is not None:
1638
+ masks = results['masks']
1639
+ # Combine all masks
1640
+ mask_arrays = []
1641
+ for mask in masks:
1642
+ if isinstance(mask, torch.Tensor):
1643
+ mask_np = mask.cpu().numpy()
1644
+ else:
1645
+ mask_np = np.array(mask)
1646
+ mask_arrays.append(mask_np)
1647
 
1648
+ if len(mask_arrays) > 0:
1649
+ final_mask = np.any(mask_arrays, axis=0)
 
 
1650
 
1651
  # Visualize
1652
  plt.figure(figsize=(12, 6))