Anigor66 commited on
Commit
505d087
·
1 Parent(s): 530a504

Stabilize auto mask generator (disable multi-scale crops)

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -39,7 +39,7 @@ print(f"Loading MedSAM model ({MODEL_TYPE})...")
39
 
40
  try:
41
  torch.load = patched_torch_load
42
- sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
43
  finally:
44
  torch.load = original_torch_load
45
 
@@ -51,15 +51,19 @@ predictor = SamPredictor(sam)
51
  print("✓ SamPredictor initialized for interactive segmentation")
52
 
53
  # SamAutomaticMaskGenerator for automatic mask generation
54
- # Uses the same model but with automatic grid-based prompting
 
 
 
 
55
  mask_generator = SamAutomaticMaskGenerator(
56
  model=sam,
57
- points_per_side=32, # Grid density (32x32 = 1024 points)
58
- pred_iou_thresh=0.88, # IoU threshold for filtering
59
- stability_score_thresh=0.95, # Stability threshold
60
- crop_n_layers=1, # Number of crop layers for multi-scale
61
- crop_n_points_downscale_factor=2, # Downscale factor for crops
62
- min_mask_region_area=100 # Minimum mask area in pixels
63
  )
64
  print("✓ SamAutomaticMaskGenerator initialized for automatic segmentation")
65
  print("✓ MedSAM model loaded successfully!")
 
39
 
40
  try:
41
  torch.load = patched_torch_load
42
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
43
  finally:
44
  torch.load = original_torch_load
45
 
 
51
  print("✓ SamPredictor initialized for interactive segmentation")
52
 
53
  # SamAutomaticMaskGenerator for automatic mask generation
54
+ # Uses the same model but with automatic grid-based prompting.
55
+ # NOTE: We disable multi-scale cropping (crop_n_layers=0) because it can
56
+ # sometimes produce empty crop_boxes on certain images, which leads to
57
+ # IndexError inside torchvision.boxes.box_area. Using a single-scale grid
58
+ # is more stable for our use case.
59
  mask_generator = SamAutomaticMaskGenerator(
60
  model=sam,
61
+ points_per_side=32, # Grid density (32x32 = 1024 points)
62
+ pred_iou_thresh=0.88, # IoU threshold for filtering
63
+ stability_score_thresh=0.95, # Stability threshold
64
+ crop_n_layers=0, # Disable multi-scale crops to avoid IndexError
65
+ crop_n_points_downscale_factor=2,
66
+ min_mask_region_area=100 # Minimum mask area in pixels
67
  )
68
  print("✓ SamAutomaticMaskGenerator initialized for automatic segmentation")
69
  print("✓ MedSAM model loaded successfully!")