CodeJackR
committed on
Commit
·
16a5f8c
1
Parent(s):
2ea60f3
handle boxes and points for SAM input
Browse files - handler.py +20 -5
handler.py
CHANGED
|
@@ -62,8 +62,15 @@ class EndpointHandler():
|
|
| 62 |
# Process the image
|
| 63 |
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Generate masks using the model
|
| 69 |
with torch.no_grad():
|
|
@@ -76,9 +83,17 @@ class EndpointHandler():
|
|
| 76 |
inputs["reshaped_input_sizes"].cpu()
|
| 77 |
)[0]
|
| 78 |
|
| 79 |
-
# Convert the
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Convert result to base64
|
| 84 |
out = io.BytesIO()
|
|
|
|
| 62 |
# Process the image
|
| 63 |
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 64 |
|
| 65 |
+
# SAM requires input prompts, so we'll generate a center point prompt
|
| 66 |
+
height, width = img.size[1], img.size[0] # PIL returns (width, height)
|
| 67 |
+
|
| 68 |
+
# Create a center point prompt for automatic segmentation
|
| 69 |
+
input_points = [[[width // 2, height // 2]]] # Center point
|
| 70 |
+
input_labels = [[1]] # Positive prompt
|
| 71 |
+
|
| 72 |
+
# Prepare inputs for the model with prompts
|
| 73 |
+
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt")
|
| 74 |
|
| 75 |
# Generate masks using the model
|
| 76 |
with torch.no_grad():
|
|
|
|
| 83 |
inputs["reshaped_input_sizes"].cpu()
|
| 84 |
)[0]
|
| 85 |
|
| 86 |
+
# Convert the best mask to a binary mask
|
| 87 |
+
# SAM returns multiple masks, take the first one
|
| 88 |
+
if len(masks) > 0:
|
| 89 |
+
mask = masks[0].squeeze().numpy()
|
| 90 |
+
mask_binary = (mask > 0.5).astype(np.uint8) * 255
|
| 91 |
+
else:
|
| 92 |
+
# Fallback: create a simple center mask
|
| 93 |
+
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 94 |
+
center_x, center_y = width // 2, height // 2
|
| 95 |
+
size = min(width, height) // 8
|
| 96 |
+
mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
|
| 97 |
|
| 98 |
# Convert result to base64
|
| 99 |
out = io.BytesIO()
|