Anigor66 commited on
Commit
94ac4ef
·
1 Parent(s): 8cb9ee2

Added hf model

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- HuggingFace Space for MedSAM Inference
3
  API-compatible with Dense-Captioning-Toolkit backend
4
 
5
  Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference
@@ -12,30 +12,49 @@ import io
12
  import json
13
  import base64
14
 
15
- # Import MedSAM components
 
 
16
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
17
 
18
  # Initialize model
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  print(f"Using device: {device}")
21
 
22
- # Model configuration - using MedSAM (vit_b) for both interactive and automatic segmentation
23
- MODEL_CHECKPOINT = "medsam_vit_b.pth"
24
- MODEL_TYPE = "vit_b"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- print("Loading MedSAM model...")
27
 
28
  # Monkey-patch torch.load to use CPU mapping when needed
29
  original_torch_load = torch.load
30
  def patched_torch_load(f, *args, **kwargs):
31
- if 'map_location' not in kwargs and device == 'cpu':
32
- kwargs['map_location'] = 'cpu'
33
  return original_torch_load(f, *args, **kwargs)
34
 
35
  torch.load = patched_torch_load
36
 
37
- # Load MedSAM model (vit_b) - used for both interactive and automatic segmentation
38
- print(f"Loading MedSAM model ({MODEL_TYPE})...")
39
 
40
  try:
41
  torch.load = patched_torch_load
@@ -66,7 +85,7 @@ mask_generator = SamAutomaticMaskGenerator(
66
  min_mask_region_area=0 # Minimum mask area (lowered from 100 to allow small masks)
67
  )
68
  print("✓ SamAutomaticMaskGenerator initialized for automatic segmentation")
69
- print("✓ MedSAM model loaded successfully!")
70
 
71
 
72
  # =============================================================================
@@ -465,12 +484,12 @@ def check_auto_mask_status():
465
  """
466
  Check if automatic mask generation is available
467
  """
468
- return json.dumps({
469
- 'available': mask_generator is not None,
470
- 'model': 'medsam_vit_b' if mask_generator else None,
471
- 'model_type': MODEL_TYPE,
472
- 'device': str(device)
473
- })
474
 
475
 
476
  # =============================================================================
 
1
  """
2
+ HuggingFace Space for SAM / MedSAM Inference
3
  API-compatible with Dense-Captioning-Toolkit backend
4
 
5
  Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference
 
12
  import json
13
  import base64
14
 
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ # Import SAM components
18
  from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
19
 
20
  # Initialize model
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print(f"Using device: {device}")
23
 
24
+ # -----------------------------------------------------------------------------
25
+ # Model configuration
26
+ # -----------------------------------------------------------------------------
27
+ # We CANNOT store the large SAM checkpoint directly in the Space repo due to
28
+ # size limits, so we fetch it from a separate model repo on HuggingFace Hub
29
+ # (e.g. `Aniketg6/dense-captioning-models`) and let Spaces cache it.
30
+ #
31
+ # You said you've uploaded your SAM model there; by default we assume the
32
+ # filename is `sam_vit_h_4b8939.pth`. If it's different, just change
33
+ # MODEL_FILENAME below.
34
+ MODEL_REPO_ID = "Aniketg6/dense-captioning-models"
35
+ MODEL_FILENAME = "sam_vit_h_4b8939.pth" # change if your filename is different
36
+ MODEL_TYPE = "vit_h" # using SAM ViT-H (general-purpose SAM)
37
+
38
+ print(f"Downloading SAM checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...")
39
+ MODEL_CHECKPOINT = hf_hub_download(
40
+ repo_id=MODEL_REPO_ID,
41
+ filename=MODEL_FILENAME,
42
+ )
43
+ print(f"✓ Checkpoint downloaded to: {MODEL_CHECKPOINT}")
44
 
45
+ print("Loading SAM model...")
46
 
47
  # Monkey-patch torch.load to use CPU mapping when needed
48
  original_torch_load = torch.load
49
  def patched_torch_load(f, *args, **kwargs):
50
+ if "map_location" not in kwargs and device == "cpu":
51
+ kwargs["map_location"] = "cpu"
52
  return original_torch_load(f, *args, **kwargs)
53
 
54
  torch.load = patched_torch_load
55
 
56
+ # Load SAM model (vit_h) - used for both interactive and automatic segmentation
57
+ print(f"Loading SAM model ({MODEL_TYPE}) from downloaded checkpoint...")
58
 
59
  try:
60
  torch.load = patched_torch_load
 
85
  min_mask_region_area=0 # Minimum mask area (lowered from 100 to allow small masks)
86
  )
87
  print("✓ SamAutomaticMaskGenerator initialized for automatic segmentation")
88
+ print("✓ SAM model loaded successfully from HuggingFace Hub!")
89
 
90
 
91
  # =============================================================================
 
484
  """
485
  Check if automatic mask generation is available
486
  """
487
+ return json.dumps({
488
+ 'available': mask_generator is not None,
489
+ 'model': MODEL_FILENAME if mask_generator else None,
490
+ 'model_type': MODEL_TYPE,
491
+ 'device': str(device)
492
+ })
493
 
494
 
495
  # =============================================================================