CodeJackR committed on
Commit
d816a26
·
1 Parent(s): d05bd8d

Fix errors

Browse files
Files changed (2) hide show
  1. handler.py +32 -30
  2. requirements.txt +1 -1
handler.py CHANGED
@@ -4,35 +4,26 @@ import io
4
  import base64
5
  import numpy as np
6
  from PIL import Image
7
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
 
8
  from typing import Dict, List, Any
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
  """
13
  Called once at startup.
14
- The model files are mounted under /opt/ml/model by default in Inference Endpoints.
15
  """
16
- # Try different possible checkpoint paths
17
- import os
18
- possible_paths = [
19
- os.path.join(path, "pytorch_model.bin"),
20
- os.path.join(path, "model.safetensors"),
21
- "/opt/ml/model/pytorch_model.bin",
22
- "/opt/ml/model/model.safetensors"
23
- ]
24
-
25
- checkpoint = None
26
- for p in possible_paths:
27
- if os.path.exists(p):
28
- checkpoint = p
29
- break
30
-
31
- if checkpoint is None:
32
- raise FileNotFoundError("Could not find model checkpoint in any of the expected locations")
33
-
34
- sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
35
- self.mask_generator = SamAutomaticMaskGenerator(sam)
36
 
37
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
38
  """
@@ -57,17 +48,28 @@ class EndpointHandler():
57
 
58
  # Process the image
59
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
60
- img_np = np.array(img)
61
 
62
- # Generate masks
63
- masks = self.mask_generator.generate(img_np)
64
- combined = np.zeros(img_np.shape[:2], dtype=np.uint8)
65
- for m in masks:
66
- combined[m["segmentation"]] = 255
67
-
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Convert result to base64
69
  out = io.BytesIO()
70
- Image.fromarray(combined).save(out, format="PNG")
71
  out.seek(0)
72
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
73
 
 
4
  import base64
5
  import numpy as np
6
  from PIL import Image
7
+ import torch
8
+ from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
10
 
11
class EndpointHandler():
    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.
        """
        try:
            # Prefer the checkpoint files mounted at the local path.
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as load_err:
            # Local load failed; fall back to the public SAM checkpoint on the Hub.
            print(f"Failed to load from local path: {load_err}")
            print("Attempting to load from facebook/sam-vit-base")
            fallback_repo = "facebook/sam-vit-base"
            self.model = SamModel.from_pretrained(fallback_repo)
            self.processor = SamProcessor.from_pretrained(fallback_repo)
 
 
 
 
 
 
 
 
 
 
27
 
28
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
29
  """
 
48
 
49
  # Process the image
50
  img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
51
 
52
+ # Prepare inputs for the model
53
+ inputs = self.processor(img, return_tensors="pt")
54
+
55
+ # Generate masks using the model
56
+ with torch.no_grad():
57
+ outputs = self.model(**inputs)
58
+
59
+ # Process the outputs to get masks
60
+ masks = self.processor.image_processor.post_process_masks(
61
+ outputs.pred_masks.cpu(),
62
+ inputs["original_sizes"].cpu(),
63
+ inputs["reshaped_input_sizes"].cpu()
64
+ )[0]
65
+
66
+ # Convert the first mask to a binary mask
67
+ mask = masks[0].squeeze().numpy()
68
+ mask_binary = (mask > 0.0).astype(np.uint8) * 255
69
+
70
  # Convert result to base64
71
  out = io.BytesIO()
72
+ Image.fromarray(mask_binary).save(out, format="PNG")
73
  out.seek(0)
74
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
75
 
requirements.txt CHANGED
@@ -2,4 +2,4 @@
2
  torch
3
  numpy
4
  Pillow
5
- segment-anything
 
2
  torch
3
  numpy
4
  Pillow
5
+ transformers[vision]