Improve model card: Add library_name, GitHub link, and usage example

#1
by nielsr HF Staff - opened
Files changed (1)
  1. README.md +41 -6
README.md CHANGED
@@ -2,12 +2,17 @@
 pipeline_tag: image-segmentation
 tags:
 - medical
+license: mit
+library_name: transformers
 ---
+
 # MCP-MedSAM
 
 Pytorch Implementation of the paper:
 "[MCP-MedSAM: A Powerful Lightweight Medical Segment Anything Model Trained with a Single GPU in Just One Day](https://arxiv.org/abs/2412.05888)"
 
+Code: [https://github.com/Leo-Lyu/MCP-MedSAM](https://github.com/Leo-Lyu/MCP-MedSAM)
+
 ![MCP-MedSAM Architecture](docs/MCP-MedSAM.png)
 
 ## 📄 Overview
@@ -24,11 +29,6 @@ To further improve performance across imaging modalities, we introduce a **modal
 
 With these enhancements, our model achieves strong multi-modality segmentation performance, and can be trained in approximately **1 day on a single A100 (40GB)** GPU.
 
-<!--
-We are currently releasing the inference code along with the model weight. You can download from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing).
-
-The training code has been released and you can train your . -->
-
 ## Requirements
 
 * Python==3.10.14
@@ -38,7 +38,42 @@ The training code has been released and you can train your . -->
 
 ## Training and Inference
 
-Training and inference can be done by running train.py and infer.py. Model weights are stored in the pytorch_model.bin file, which can be loaded for inference.
+Training and inference can be done by running `train.py` and `infer.py` from the [official repository](https://github.com/Leo-Lyu/MCP-MedSAM). The model weight for inference can be downloaded from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing); the pretrained weights are also available on the [Hugging Face Hub](https://huggingface.co/Leo-Lyu/MCP-MedSAM).
+
+## Usage
+
+MCP-MedSAM follows the SAM prompt-based design. If the Hub checkpoint is exported in the `transformers` SAM format, it can be loaded with `SamModel`/`SamProcessor` (`pip install transformers`); otherwise, use `infer.py` from the official repository.
+
+```python
+from transformers import SamModel, SamProcessor
+from PIL import Image
+import torch
+
+# Load model and processor from the Hugging Face Hub
+model = SamModel.from_pretrained("Leo-Lyu/MCP-MedSAM")
+processor = SamProcessor.from_pretrained("Leo-Lyu/MCP-MedSAM")
+
+# Example: load an image and define a bounding-box prompt
+# Replace "path/to/your/medical_image.jpg" with the actual path to your image file
+image = Image.open("path/to/your/medical_image.jpg").convert("RGB")
+# SamProcessor expects boxes nested per image: [[[x_min, y_min, x_max, y_max]]]
+input_boxes = [[[100, 200, 300, 400]]]
+
+# Prepare inputs for the model
+inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt")
+
+# Perform inference
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# Access predicted masks (raw logits); the exact output structure may vary
+pred_masks = outputs.pred_masks
+
+# To get a binary mask, apply sigmoid and threshold
+binary_mask = (torch.sigmoid(pred_masks) > 0.5).float()
+
+print("Generated mask shape:", binary_mask.shape)
+```
 
 ## Citation
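
The removed inference note above said that the model weights live in `pytorch_model.bin` and can be loaded for inference. A minimal sketch of that loading pattern, assuming the file is a plain PyTorch state dict; the tiny network here is a hypothetical stand-in, since the real MCP-MedSAM model class is built by the repository's `train.py`/`infer.py`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real MCP-MedSAM network, which is
# constructed by the official repository's training/inference scripts.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1))

# pytorch_model.bin is typically a plain state dict saved with torch.save
torch.save(model.state_dict(), "pytorch_model.bin")

# Loading for inference: restore the weights and switch to eval mode
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

with torch.no_grad():
    mask_logits = model(torch.randn(1, 3, 64, 64))
print(tuple(mask_logits.shape))  # (1, 1, 64, 64)
```

The same `torch.load` + `load_state_dict` call works once the real architecture is instantiated from the repository code.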