CodeJackR
committed on
Commit
·
e0fb0e6
1
Parent(s):
064d94e
Input image as image
Browse files- handler.py +38 -52
handler.py
CHANGED
|
@@ -24,111 +24,97 @@ class EndpointHandler():
|
|
| 24 |
self.processor = SamProcessor.from_pretrained(path)
|
| 25 |
except Exception as e:
|
| 26 |
# Fallback to loading from a known SAM model if local loading fails
|
| 27 |
-
print(
|
| 28 |
print("Attempting to load from facebook/sam-vit-base")
|
| 29 |
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
| 30 |
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
| 31 |
|
| 32 |
-
def __call__(self, data
|
| 33 |
"""
|
| 34 |
Called on every HTTP request.
|
| 35 |
-
|
| 36 |
-
data (:obj:):
|
| 37 |
-
includes the input data and the parameters for the inference.
|
| 38 |
"""
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 45 |
-
# img = raw_images[0]
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
-
#
|
|
|
|
| 51 |
input_points = [[[width // 2, height // 2]]] # Center point
|
| 52 |
input_labels = [[1]] # Positive prompt
|
| 53 |
|
| 54 |
-
|
| 55 |
-
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt")
|
| 56 |
|
| 57 |
-
# Generate masks
|
| 58 |
with torch.no_grad():
|
| 59 |
outputs = self.model(**inputs)
|
| 60 |
|
|
|
|
| 61 |
try:
|
| 62 |
-
# Get original image size
|
| 63 |
original_height, original_width = inputs["original_sizes"][0].tolist()
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
# Get predicted masks and scores
|
| 66 |
-
pred_masks = outputs.pred_masks.cpu() # (batch, num_masks, H, W)
|
| 67 |
-
iou_scores = outputs.iou_scores.cpu()[0] # (num_masks,)
|
| 68 |
-
|
| 69 |
-
# The model might return 4D or 5D tensors. Squeeze if 5D.
|
| 70 |
if pred_masks.ndim == 5:
|
| 71 |
pred_masks = pred_masks.squeeze(1)
|
| 72 |
|
| 73 |
-
# Select the best mask
|
| 74 |
best_mask_idx = torch.argmax(iou_scores)
|
| 75 |
-
best_mask_tensor = pred_masks[0, best_mask_idx, :, :]
|
| 76 |
|
| 77 |
-
# Upscale the mask to original image size
|
| 78 |
-
# Add batch and channel dims for interpolate
|
| 79 |
upscaled_mask = F.interpolate(
|
| 80 |
best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
|
| 81 |
size=(original_height, original_width),
|
| 82 |
mode='bilinear',
|
| 83 |
align_corners=False
|
| 84 |
-
).squeeze()
|
| 85 |
|
| 86 |
-
# Convert to binary mask
|
| 87 |
mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
|
| 88 |
|
| 89 |
except Exception as e:
|
| 90 |
-
print(
|
| 91 |
-
# Fallback
|
| 92 |
-
height, width = img.size[1], img.size[0]
|
| 93 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 94 |
center_x, center_y = width // 2, height // 2
|
| 95 |
size = min(width, height) // 8
|
| 96 |
mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255
|
| 97 |
|
| 98 |
-
#
|
| 99 |
out = io.BytesIO()
|
| 100 |
Image.fromarray(mask_binary).save(out, format="PNG")
|
| 101 |
out.seek(0)
|
| 102 |
mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
|
| 103 |
-
|
| 104 |
-
# Decode the returned mask and save
|
| 105 |
-
mask_bytes = base64.b64decode(mask_base64)
|
| 106 |
-
mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
|
| 107 |
-
# mask_img.save(output_path, format="JPEG")
|
| 108 |
-
# print(f"Wrote mask to {output_path}")
|
| 109 |
|
| 110 |
-
# Return
|
| 111 |
-
return
|
| 112 |
|
| 113 |
def main():
|
| 114 |
-
#
|
| 115 |
input_path = "/Users/rp7/Downloads/test.jpeg"
|
| 116 |
-
output_path = "output.
|
| 117 |
|
| 118 |
-
#
|
| 119 |
with open(input_path, "rb") as f:
|
| 120 |
img_bytes = f.read()
|
| 121 |
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 122 |
-
|
| 123 |
|
|
|
|
| 124 |
handler = EndpointHandler(path=".")
|
| 125 |
-
result = handler(
|
| 126 |
-
|
| 127 |
-
#
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
|
| 133 |
if __name__ == "__main__":
|
| 134 |
main()
|
|
|
|
| 24 |
self.processor = SamProcessor.from_pretrained(path)
|
| 25 |
except Exception as e:
|
| 26 |
# Fallback to loading from a known SAM model if local loading fails
|
| 27 |
+
print("Failed to load from local path: {}".format(e))
|
| 28 |
print("Attempting to load from facebook/sam-vit-base")
|
| 29 |
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
| 30 |
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
| 31 |
|
| 32 |
+
def __call__(self, data):
    """
    Called on every HTTP request.

    Decodes the base64 image found in ``data["inputs"]``, prompts the SAM
    model with a single positive point at the image centre, and returns the
    mask with the highest predicted IoU as a base64-encoded PNG.

    Args:
        data: Request payload. ``data["inputs"]`` must hold the image as a
            base64 string, optionally prefixed with a ``data:...,`` URI
            header. The payload is read, not modified.

    Returns:
        A one-element list ``[{"mask_png_base64": <str>}]``. The PNG is a
        single-channel image with 255 inside the mask and 0 elsewhere.

    Raises:
        ValueError: If ``inputs`` is missing/empty, or is a ``data:`` URI
            without a ``,`` separating the header from the payload.
            Errors raised by ``base64.b64decode`` or ``Image.open`` on bad
            image data propagate unchanged.
    """
    # 1. Parse and decode the input image
    # Use .get() rather than .pop() so the caller's payload is left intact.
    image_data = data.get("inputs")
    if not image_data:
        raise ValueError("Missing 'inputs' key with a base64 image string.")

    if isinstance(image_data, str) and image_data.startswith("data:"):
        # Strip the "data:<mime>;base64," header; partition() lets us
        # report a malformed URI instead of failing with an IndexError.
        _, separator, image_data = image_data.partition(",")
        if not separator:
            raise ValueError(
                "Malformed data URI in 'inputs': no ',' between header and payload."
            )

    image_bytes = base64.b64decode(image_data)
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # 2. Prepare prompts and process the image
    width, height = img.size  # PIL reports (width, height)
    input_points = [[[width // 2, height // 2]]]  # Center point
    input_labels = [[1]]  # Positive prompt

    inputs = self.processor(
        img,
        input_points=input_points,
        input_labels=input_labels,
        return_tensors="pt",
    ).to(device)

    # 3. Generate masks
    with torch.no_grad():
        outputs = self.model(**inputs)

    # 4. Process and select the best mask
    try:
        pred_masks = outputs.pred_masks.cpu()
        # post_process_masks expects one (point_batch, num_masks, H, W)
        # entry per image; add the point-batch axis if the model returned
        # a 4D tensor. NOTE(review): current SamModel returns 5D - confirm.
        if pred_masks.ndim == 4:
            pred_masks = pred_masks.unsqueeze(1)

        # Let the processor undo its own resize + padding. Interpolating the
        # low-resolution mask straight to the original size would stretch
        # the padded region into the result for non-square images.
        restored = self.processor.image_processor.post_process_masks(
            pred_masks,
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )

        # Flatten (point_batch, num_masks) so masks and scores share one
        # index; explicit indexing (not squeeze) keeps 1-pixel-wide or
        # 1-pixel-high images two-dimensional.
        image_masks = restored[0].reshape(-1, *restored[0].shape[-2:])
        iou_scores = outputs.iou_scores.cpu()[0].reshape(-1)
        best_mask_idx = int(torch.argmax(iou_scores))

        mask_binary = image_masks[best_mask_idx].numpy().astype(np.uint8) * 255

    except Exception as e:
        # Best-effort fallback: report the failure and return a centred
        # square so the endpoint still answers with a valid PNG.
        print("Error processing masks: {}".format(e))
        mask_binary = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        size = min(width, height) // 8
        mask_binary[center_y-size:center_y+size, center_x-size:center_x+size] = 255

    # 5. Encode the output image to base64
    out = io.BytesIO()
    Image.fromarray(mask_binary).save(out, format="PNG")
    mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')

    # 6. Return the response
    return [{"mask_png_base64": mask_base64}]
|
| 95 |
|
| 96 |
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.png"):
    """
    Show how a client would call the endpoint, using a local image file.

    Reads ``input_path``, wraps it in the ``{"inputs": <data URI>}`` payload
    the handler expects, runs the handler in-process, and writes the
    returned mask to ``output_path``.

    Args:
        input_path: Image file to segment. Defaults to the original
            hard-coded sample path.
        output_path: Where the mask image is written. Defaults to
            ``output.png``.
    """
    # Local import: only this demo needs it, and it keeps the block
    # self-contained.
    import mimetypes

    # 1. Prepare the payload
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    # Label the data URI with the file's real type instead of assuming JPEG.
    mime_type = mimetypes.guess_type(input_path)[0] or "application/octet-stream"
    payload = {"inputs": "data:{};base64,{}".format(mime_type, img_b64)}

    # 2. Instantiate handler and call it
    handler = EndpointHandler(path=".")
    result = handler(payload)

    # 3. Process the response
    mask_b64 = result[0]["mask_png_base64"]
    mask_bytes = base64.b64decode(mask_b64)

    # Round-trip through PIL so the saved format follows output_path's
    # extension rather than always being PNG bytes.
    mask_img = Image.open(io.BytesIO(mask_bytes))
    mask_img.save(output_path)
    print("Wrote mask to {}".format(output_path))
|
| 118 |
|
| 119 |
if __name__ == "__main__":
|
| 120 |
main()
|