Thibaut commited on
Commit
a00a9a9
·
1 Parent(s): de95768

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +73 -0
handler.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, SamModel
6
+
7
+
8
+ class EndpointHandler:
9
+
10
+ def __init__(self, path="facebook/sam3"):
11
+ self.processor = AutoProcessor.from_pretrained(path)
12
+ self.model = SamModel.from_pretrained(
13
+ path,
14
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
15
+ )
16
+ self.model.eval()
17
+ if torch.cuda.is_available():
18
+ self.model = self.model.cuda()
19
+
20
+ def __call__(self, data):
21
+ """
22
+ Expected HF pipeline request:
23
+ {
24
+ "inputs": "<base64 or URL>",
25
+ "parameters": {
26
+ "classes": ["pothole", "marking"]
27
+ }
28
+ }
29
+ """
30
+
31
+ # Extract
32
+ image_b64 = data.get("inputs", None)
33
+ params = data.get("parameters", {})
34
+ classes = params.get("classes", None)
35
+
36
+ if image_b64 is None or classes is None:
37
+ return {"error": "Required fields: inputs (image base64), parameters.classes"}
38
+
39
+ # Decode image
40
+ image_bytes = base64.b64decode(image_b64)
41
+ pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
42
+
43
+ inputs = self.processor(
44
+ images=pil_image,
45
+ text=classes,
46
+ return_tensors="pt"
47
+ )
48
+
49
+ if torch.cuda.is_available():
50
+ inputs = {k: v.cuda() for k, v in inputs.items()}
51
+
52
+ with torch.no_grad():
53
+ outputs = self.model(**inputs)
54
+
55
+ pred_masks = outputs.pred_masks.squeeze(1) # [N, H, W]
56
+
57
+ results = []
58
+ for i, cls in enumerate(classes):
59
+ mask = pred_masks[i].float().cpu()
60
+ binary_mask = (mask > 0.5).numpy().astype("uint8") * 255
61
+
62
+ pil_mask = Image.fromarray(binary_mask, mode="L")
63
+ buf = io.BytesIO()
64
+ pil_mask.save(buf, format="PNG")
65
+ mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
66
+
67
+ results.append({
68
+ "label": cls,
69
+ "mask": mask_b64,
70
+ "score": 1.0 # SAM3 does not output per-class confidence
71
+ })
72
+
73
+ return results