sajabdoli commited on
Commit
39a28f1
·
verified ·
1 Parent(s): 870465f

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +10 -11
  2. main.py +38 -0
  3. requirements.txt +8 -0
  4. sam_vit_b.pth +3 -0
README.md CHANGED
@@ -1,11 +1,10 @@
1
- ---
2
- title: SAM
3
- emoji: 📉
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- short_description: SAM
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # SAM FastAPI for Hugging Face Space
2
+
3
+ This is a FastAPI wrapper for Meta's Segment Anything Model (SAM), ready to deploy on Hugging Face Spaces.
4
+
5
+ ## Setup
6
+
7
+ 1. Upload the `sam_vit_b.pth` model checkpoint to the root of your Space manually.
8
+ 2. Hugging Face will automatically install dependencies and run the app.
9
+ 3. Use the `/segment` endpoint to perform segmentation.
10
+
 
main.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from segment_anything import sam_model_registry, SamPredictor
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ import io
7
+
8
+ app = FastAPI()
9
+
10
+ # Load SAM Model
11
+ sam_checkpoint = "sam_vit_b.pth" # Add the weights file manually in the Space
12
+ model_type = "vit_b"
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
16
+ predictor = SamPredictor(sam)
17
+
18
+ @app.post("/segment")
19
+ async def segment_image(file: UploadFile = File(...)):
20
+ image_bytes = await file.read()
21
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
22
+ image_np = np.array(image)
23
+
24
+ predictor.set_image(image_np)
25
+
26
+ input_point = np.array([[100, 100]])
27
+ input_label = np.array([1])
28
+
29
+ masks, scores, _ = predictor.predict(
30
+ point_coords=input_point,
31
+ point_labels=input_label,
32
+ multimask_output=False
33
+ )
34
+
35
+ return {
36
+ "score": float(scores[0]),
37
+ "mask": masks[0].tolist()
38
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ opencv-python
4
+ numpy
5
+ torch
6
+ torchvision
7
+ git+https://github.com/facebookresearch/segment-anything.git
8
+ pillow
sam_vit_b.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383