caball21 commited on
Commit
c75120d
·
verified ·
1 Parent(s): 0271854

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +64 -0
handler.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+
3
+ import torch
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ import io
7
+ import json
8
+
9
+ # Define class labels (same order as training)
10
+ CLASS_LABELS = [
11
+ "glove_outline",
12
+ "webbing",
13
+ "thumb",
14
+ "palm_pocket",
15
+ "hand",
16
+ "glove_exterior"
17
+ ]
18
+
19
+ # Load model from disk
20
+ def load_model():
21
+ model = torch.load("pytorch_model.bin", map_location="cpu")
22
+ model.eval()
23
+ return model
24
+
25
+ model = load_model()
26
+
27
+ # Preprocessing transform
28
+ transform = T.Compose([
29
+ T.Resize((720, 1280)), # or whatever input size model expects
30
+ T.ToTensor()
31
+ ])
32
+
33
+ # Input: raw image bytes
34
+ def preprocess(input_bytes):
35
+ image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
36
+ tensor = transform(image).unsqueeze(0) # [1, 3, H, W]
37
+ return tensor
38
+
39
+ # Postprocess output: convert logits to mask
40
+ def postprocess(output_tensor):
41
+ # Argmax over channel dimension (assumes shape [1, C, H, W])
42
+ pred = torch.argmax(output_tensor, dim=1)[0].cpu().numpy()
43
+ return pred.tolist() # List of H x W values from 0 to 5
44
+
45
+ # TorchServe/HF entrypoint
46
+ def infer(payload):
47
+ # If input is multipart/form-data, raw bytes
48
+ if isinstance(payload, bytes):
49
+ image_tensor = preprocess(payload)
50
+ elif isinstance(payload, dict) and "inputs" in payload:
51
+ # Hugging Face Inference API passes {"inputs": "base64 image data"}
52
+ from base64 import b64decode
53
+ image_tensor = preprocess(b64decode(payload["inputs"]))
54
+ else:
55
+ raise ValueError("Unsupported input format")
56
+
57
+ with torch.no_grad():
58
+ output = model(image_tensor)
59
+
60
+ mask = postprocess(output)
61
+ return {
62
+ "mask": mask,
63
+ "classes": CLASS_LABELS
64
+ }