CodeJackR commited on
Commit
592adee
·
1 Parent(s): 804a4c8

Add custom handler for SAM Inference Endpoint

Browse files
Files changed (2) hide show
  1. handler.py +33 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import io
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
7
+
8
+ # This will be called once on startup
9
+ def initialize():
10
+ global mask_generator
11
+ # HF will mount your model files under /mnt/models
12
+ checkpoint = "/mnt/models/pytorch_model.bin"
13
+ sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
14
+ mask_generator = SamAutomaticMaskGenerator(sam)
15
+
16
+ # This handles each incoming request
17
+ def inference(request):
18
+ # expect multipart/form-data with field name "image"
19
+ image_bytes = request.files["image"].read()
20
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
21
+ img_np = np.array(img)
22
+ masks = mask_generator.generate(img_np)
23
+
24
+ # combine all masks into one binary mask
25
+ combined = np.zeros(img_np.shape[:2], dtype=np.uint8)
26
+ for m in masks:
27
+ combined[m["segmentation"]] = 255
28
+
29
+ # serialize to PNG
30
+ out = io.BytesIO()
31
+ Image.fromarray(combined).save(out, format="PNG")
32
+ out.seek(0)
33
+ return {"mask_png": out.read()}
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ Pillow
4
+ segment-anything