caball21 committed on
Commit
fc03d28
·
verified ·
1 Parent(s): 6a79cbc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -11
handler.py CHANGED
@@ -6,6 +6,8 @@ from PIL import Image
6
  import io
7
  import json
8
 
 
 
9
  # Define class labels (same order as training)
10
  CLASS_LABELS = [
11
  "glove_outline",
@@ -16,39 +18,50 @@ CLASS_LABELS = [
16
  "glove_exterior"
17
  ]
18
 
19
- # Load model from disk
 
 
20
  def load_model():
21
- model = torch.load("pytorch_model.bin", map_location="cpu")
 
 
 
 
 
 
 
22
  model.eval()
23
  return model
24
 
25
  model = load_model()
26
 
27
- # Preprocessing transform
 
 
28
  transform = T.Compose([
29
- T.Resize((720, 1280)), # or whatever input size model expects
30
  T.ToTensor()
31
  ])
32
 
33
- # Input: raw image bytes
34
  def preprocess(input_bytes):
35
  image = Image.open(io.BytesIO(input_bytes)).convert("RGB")
36
  tensor = transform(image).unsqueeze(0) # [1, 3, H, W]
37
  return tensor
38
 
39
- # Postprocess output: convert logits to mask
 
 
40
  def postprocess(output_tensor):
41
- # Argmax over channel dimension (assumes shape [1, C, H, W])
42
  pred = torch.argmax(output_tensor, dim=1)[0].cpu().numpy()
43
- return pred.tolist() # List of H x W values from 0 to 5
44
 
45
- # TorchServe/HF entrypoint
 
 
46
  def infer(payload):
47
- # If input is multipart/form-data, raw bytes
48
  if isinstance(payload, bytes):
49
  image_tensor = preprocess(payload)
50
  elif isinstance(payload, dict) and "inputs" in payload:
51
- # Hugging Face Inference API passes {"inputs": "base64 image data"}
52
  from base64 import b64decode
53
  image_tensor = preprocess(b64decode(payload["inputs"]))
54
  else:
 
6
  import io
7
  import json
8
 
9
+ from sam2_model_stub import SAM2Hierarchical # 👈 stub class we define separately
10
+
11
  # Define class labels (same order as training)
12
  CLASS_LABELS = [
13
  "glove_outline",
 
18
  "glove_exterior"
19
  ]
20
 
21
# ----------------------------
# Load model weights + class
# ----------------------------
def load_model():
    """Build the SAM2Hierarchical segmentation model and load trained weights.

    Returns:
        torch.nn.Module: the model, on CPU and in eval mode, ready for inference.
    """
    model = SAM2Hierarchical(
        num_classes=len(CLASS_LABELS),
        in_channels=3,
        backbone="vit_b",  # must match the backbone used at training time (config.yaml)
        freeze_backbone=True,
        use_cls_head=True,
    )
    # Only a state dict (plain tensors) is expected here, so load with
    # weights_only=True: it restricts unpickling to tensor data and prevents
    # arbitrary code execution from a tampered checkpoint file.
    state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Instantiate once at import time so every request reuses the loaded model.
model = load_model()
37
 
38
# ----------------------------
# Preprocessing
# ----------------------------
# The model expects a fixed 720x1280 RGB input; ToTensor converts the PIL
# image to a float tensor with pixel values scaled into [0, 1].
transform = T.Compose(
    [
        T.Resize((720, 1280)),  # (height, width)
        T.ToTensor(),
    ]
)
45
 
 
46
def preprocess(input_bytes):
    """Decode raw image bytes into a [1, 3, H, W] float tensor for the model."""
    pil_image = Image.open(io.BytesIO(input_bytes))
    rgb_image = pil_image.convert("RGB")  # drop alpha / palette modes
    # Apply the module-level transform, then add the batch dimension.
    batched = transform(rgb_image).unsqueeze(0)  # [1, 3, H, W]
    return batched
50
 
51
# ----------------------------
# Postprocessing
# ----------------------------
def postprocess(output_tensor):
    """Convert logits of shape [1, C, H, W] into an H x W list of class indices."""
    # Pick the highest-scoring class per pixel, then drop the batch dimension.
    class_map = output_tensor.argmax(dim=1).squeeze(0)
    # Move to host memory and return plain nested Python lists (JSON-friendly).
    return class_map.cpu().numpy().tolist()
57
 
58
+ # ----------------------------
59
+ # Inference Entry Point
60
+ # ----------------------------
61
  def infer(payload):
 
62
  if isinstance(payload, bytes):
63
  image_tensor = preprocess(payload)
64
  elif isinstance(payload, dict) and "inputs" in payload:
 
65
  from base64 import b64decode
66
  image_tensor = preprocess(b64decode(payload["inputs"]))
67
  else: