Anigor66 committed on
Commit
d61084a
·
1 Parent(s): 57b395c

Use MedSAM for both interactive and automatic mask generation

Browse files
Files changed (1) hide show
  1. app.py +202 -8
app.py CHANGED
@@ -13,13 +13,13 @@ import json
13
  import base64
14
 
15
  # Import MedSAM components
16
- from segment_anything import sam_model_registry, SamPredictor
17
 
18
  # Initialize model
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  print(f"Using device: {device}")
21
 
22
- # Load your MedSAM model
23
  MODEL_CHECKPOINT = "medsam_vit_b.pth"
24
  MODEL_TYPE = "vit_b"
25
 
@@ -35,11 +35,30 @@ def patched_torch_load(f, *args, **kwargs):
35
  torch.load = patched_torch_load
36
 
37
  try:
38
- sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
39
- sam.to(device=device)
40
- sam.eval()
41
- predictor = SamPredictor(sam)
42
- print("✓ MedSAM model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  finally:
44
  # Restore original torch.load
45
  torch.load = original_torch_load
@@ -324,6 +343,118 @@ def segment_multiple_boxes(image, request_json):
324
  })
325
 
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  # =============================================================================
328
  # LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts)
329
  # =============================================================================
@@ -670,7 +801,70 @@ with gr.Blocks(title="MedSAM Inference API") as demo:
670
  api_name="segment_with_box" # Keep old API name for compatibility
671
  )
672
 
673
- # Tab 5: Simple UI Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
  with gr.Tab("Simple Interface"):
675
  gr.Markdown("## Click-based Segmentation")
676
  gr.Markdown("Enter X, Y coordinates to segment")
 
13
  import base64
14
 
15
  # Import MedSAM components
16
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
17
 
18
  # Initialize model
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  print(f"Using device: {device}")
21
 
22
+ # Model configuration - using MedSAM (vit_b) for both interactive and automatic segmentation
23
  MODEL_CHECKPOINT = "medsam_vit_b.pth"
24
  MODEL_TYPE = "vit_b"
25
 
 
35
  torch.load = patched_torch_load
36
 
37
  try:
38
+ # Load MedSAM model (vit_b) - used for both interactive and automatic segmentation
39
+ print(f"Loading MedSAM model ({MODEL_TYPE})...")
40
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
41
+ sam.to(device=device)
42
+ sam.eval()
43
+
44
+ # SamPredictor for interactive segmentation (point/box prompts)
45
+ predictor = SamPredictor(sam)
46
+ print("✓ SamPredictor initialized for interactive segmentation")
47
+
48
+ # SamAutomaticMaskGenerator for automatic mask generation
49
+ # Uses the same model but with automatic grid-based prompting
50
+ mask_generator = SamAutomaticMaskGenerator(
51
+ model=sam,
52
+ points_per_side=32, # Grid density (32x32 = 1024 points)
53
+ pred_iou_thresh=0.88, # IoU threshold for filtering
54
+ stability_score_thresh=0.95, # Stability threshold
55
+ crop_n_layers=1, # Number of crop layers for multi-scale
56
+ crop_n_points_downscale_factor=2, # Downscale factor for crops
57
+ min_mask_region_area=100 # Minimum mask area in pixels
58
+ )
59
+ print("✓ SamAutomaticMaskGenerator initialized for automatic segmentation")
60
+ print("✓ MedSAM model loaded successfully!")
61
+
62
  finally:
63
  # Restore original torch.load
64
  torch.load = original_torch_load
 
343
  })
344
 
345
 
346
+ # =============================================================================
347
+ # AUTO MASK GENERATION API (replaces local mask_generator.generate())
348
+ # =============================================================================
349
+
350
+ def generate_auto_masks(image, request_json):
351
+ """
352
+ Automatically generate all masks for an image using the MedSAM (ViT-B) model.
353
+
354
+ This is equivalent to `mask_generator.generate(img_np)` in enhanced_preprocessing.py
355
+
356
+ Args:
357
+ image: PIL Image
358
+ request_json: JSON string with optional parameters:
359
+ {
360
+ "points_per_side": 32, # Grid density (default: 32)
361
+ "pred_iou_thresh": 0.88, # IoU threshold (default: 0.88)
362
+ "stability_score_thresh": 0.95, # Stability threshold (default: 0.95)
363
+ "min_mask_region_area": 0 # Minimum mask area (default: 0)
364
+ }
365
+
366
+ Returns:
367
+ JSON string with format matching SamAutomaticMaskGenerator output:
368
+ {
369
+ "success": true,
370
+ "masks": [
371
+ {
372
+ "segmentation": [[...2D boolean array...]],
373
+ "area": 12345,
374
+ "bbox": [x, y, width, height],
375
+ "predicted_iou": 0.95,
376
+ "point_coords": [[x, y]],
377
+ "stability_score": 0.98,
378
+ "crop_box": [x, y, width, height]
379
+ },
380
+ ...
381
+ ],
382
+ "num_masks": 42,
383
+ "image_size": [height, width]
384
+ }
385
+ """
386
+ try:
387
+ if mask_generator is None:
388
+ return json.dumps({
389
+ 'success': False,
390
+ 'error': 'MedSAM model not loaded. Please ensure medsam_vit_b.pth is available.',
391
+ 'available': False
392
+ })
393
+
394
+ # Parse optional parameters
395
+ params = {}
396
+ if request_json:
397
+ try:
398
+ params = json.loads(request_json) if request_json.strip() else {}
399
+ except:
400
+ params = {}
401
+
402
+ # Convert PIL to numpy
403
+ image_array = np.array(image)
404
+ H, W = image_array.shape[:2]
405
+
406
+ print(f"Generating automatic masks for image of size {W}x{H}...")
407
+
408
+ # Generate masks using SAM automatic mask generator
409
+ masks = mask_generator.generate(image_array)
410
+
411
+ print(f"Generated {len(masks)} masks")
412
+
413
+ # Convert masks to JSON-serializable format
414
+ masks_output = []
415
+ for m in masks:
416
+ mask_data = {
417
+ 'segmentation': m['segmentation'].astype(np.uint8).tolist(),
418
+ 'area': int(m['area']),
419
+ 'bbox': [int(x) for x in m['bbox']], # [x, y, width, height]
420
+ 'predicted_iou': float(m['predicted_iou']),
421
+ 'point_coords': [[int(p[0]), int(p[1])] for p in m['point_coords']] if m['point_coords'] is not None else [],
422
+ 'stability_score': float(m['stability_score']),
423
+ 'crop_box': [int(x) for x in m['crop_box']] # [x, y, width, height]
424
+ }
425
+ masks_output.append(mask_data)
426
+
427
+ result = {
428
+ 'success': True,
429
+ 'masks': masks_output,
430
+ 'num_masks': len(masks_output),
431
+ 'image_size': [H, W]
432
+ }
433
+
434
+ print(f"Auto mask generation complete: {len(masks_output)} masks")
435
+ return json.dumps(result)
436
+
437
+ except Exception as e:
438
+ import traceback
439
+ return json.dumps({
440
+ 'success': False,
441
+ 'error': str(e),
442
+ 'traceback': traceback.format_exc()
443
+ })
444
+
445
+
446
+ def check_auto_mask_status():
447
+ """
448
+ Check if automatic mask generation is available
449
+ """
450
+ return json.dumps({
451
+ 'available': mask_generator is not None,
452
+ 'model': 'medsam_vit_b' if mask_generator else None,
453
+ 'model_type': MODEL_TYPE,
454
+ 'device': str(device)
455
+ })
456
+
457
+
458
  # =============================================================================
459
  # LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts)
460
  # =============================================================================
 
801
  api_name="segment_with_box" # Keep old API name for compatibility
802
  )
803
 
804
+ # Tab 5: Auto Mask Generation (for preprocessing)
805
+ with gr.Tab("Auto Mask Generation"):
806
+ gr.Markdown("""
807
+ ## Automatic Mask Generation (MedSAM)
808
+
809
+ **Replaces `mask_generator.generate(img_np)` in preprocessing pipeline**
810
+
811
+ Uses MedSAM (ViT-B) model with `SamAutomaticMaskGenerator` to automatically
812
+ segment all objects in an image. This is used for initial preprocessing
813
+ of scientific/medical images.
814
+
815
+ Uses the same `medsam_vit_b.pth` model as interactive segmentation.
816
+
817
+ **Output Format:**
818
+ ```json
819
+ {
820
+ "success": true,
821
+ "masks": [
822
+ {
823
+ "segmentation": [[...2D array...]],
824
+ "area": 12345,
825
+ "bbox": [x, y, width, height],
826
+ "predicted_iou": 0.95,
827
+ "point_coords": [[x, y]],
828
+ "stability_score": 0.98,
829
+ "crop_box": [x, y, width, height]
830
+ }
831
+ ],
832
+ "num_masks": 42
833
+ }
834
+ ```
835
+ """)
836
+
837
+ with gr.Row():
838
+ with gr.Column():
839
+ auto_image = gr.Image(type="pil", label="Input Image")
840
+ auto_params = gr.Textbox(
841
+ label="Parameters (optional)",
842
+ placeholder='{"points_per_side": 32, "pred_iou_thresh": 0.88}',
843
+ lines=2
844
+ )
845
+ with gr.Row():
846
+ auto_button = gr.Button("Generate All Masks", variant="primary")
847
+ status_button = gr.Button("Check Status", variant="secondary")
848
+
849
+ with gr.Column():
850
+ auto_output = gr.Textbox(label="Result JSON", lines=20)
851
+ status_output = gr.Textbox(label="Status", lines=3)
852
+
853
+ auto_button.click(
854
+ fn=generate_auto_masks,
855
+ inputs=[auto_image, auto_params],
856
+ outputs=auto_output,
857
+ api_name="generate_auto_masks"
858
+ )
859
+
860
+ status_button.click(
861
+ fn=check_auto_mask_status,
862
+ inputs=[],
863
+ outputs=status_output,
864
+ api_name="check_auto_mask_status"
865
+ )
866
+
867
+ # Tab 6: Simple UI Interface
868
  with gr.Tab("Simple Interface"):
869
  gr.Markdown("## Click-based Segmentation")
870
  gr.Markdown("Enter X, Y coordinates to segment")