# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
import torch.nn.functional as F
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class EndpointHandler():
    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fall back to a known SAM checkpoint if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Args:
            data (:obj:`dict`):
                includes the input data and the parameters for the inference.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})  # accepted but currently unused
        # Accept raw image bytes or a base64-encoded string / data URL
        if isinstance(inputs, str):
            if inputs.startswith("data:"):
                inputs = inputs.split(",", 1)[1]
            inputs = base64.b64decode(inputs)
        img = Image.open(io.BytesIO(inputs)).convert("RGB")
        # SAM requires input prompts, so we generate a center-point prompt
        width, height = img.size  # PIL returns (width, height)
        # Create a center point prompt for automatic segmentation
        input_points = [[[width // 2, height // 2]]]  # center point
        input_labels = [[1]]  # positive prompt
        # Prepare inputs for the model and move them to the model's device
        inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)
        # Generate masks using the model
        with torch.no_grad():
            outputs = self.model(**inputs)
        try:
            # Get the original image size (as recorded by the processor)
            original_height, original_width = inputs["original_sizes"][0].tolist()
            # Get predicted masks and IoU scores
            pred_masks = outputs.pred_masks.cpu()  # (batch, point_batch, num_masks, H, W) when 5D
            iou_scores = outputs.iou_scores.cpu()[0]  # (point_batch, num_masks)
            # The model might return 4D or 5D tensors; drop the point-batch dim if 5D
            if pred_masks.ndim == 5:
                pred_masks = pred_masks.squeeze(1)
            # Select the mask with the highest predicted IoU
            best_mask_idx = torch.argmax(iou_scores)
            best_mask_tensor = pred_masks[0, best_mask_idx, :, :]  # (H, W)
            # Upscale the mask to the original image size
            # (add batch and channel dims for interpolate)
            upscaled_mask = F.interpolate(
                best_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            ).squeeze()  # remove batch/channel dims
            # Threshold the mask logits into a binary mask
            mask_binary = (upscaled_mask > 0.0).numpy().astype(np.uint8) * 255
        except Exception as e:
            print(f"Error processing masks: {e}")
            # Fallback: a small centered square mask
            width, height = img.size
            mask_binary = np.zeros((height, width), dtype=np.uint8)
            center_x, center_y = width // 2, height // 2
            size = min(width, height) // 8
            mask_binary[center_y - size:center_y + size, center_x - size:center_x + size] = 255
        # Encode the mask as a base64 PNG
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode("utf-8")
        # Return in the format expected by callers (see main below)
        return [{"mask_png_base64": mask_base64}]
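
# Illustrative sketch (an assumption, not used by this endpoint): SamProcessor also
# accepts box prompts via its `input_boxes` argument, so a handler variant could
# prompt SAM with a bounding box instead of the center point used above.
#
#     box = [0, 0, width, height]          # x_min, y_min, x_max, y_max
#     inputs = self.processor(
#         img,
#         input_boxes=[[box]],             # nested as (batch, boxes_per_image, 4)
#         return_tensors="pt",
#     ).to(device)
#     with torch.no_grad():
#         outputs = self.model(**inputs)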

def main():
    # Hardcoded input and output paths
    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_path = "output.jpg"
    # Read and base64-encode the input image
    with open(input_path, "rb") as f:
        img_bytes = f.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{img_b64}"
    handler = EndpointHandler(path=".")
    result = handler({"inputs": data_url})[0]
    # Decode the returned mask and save it
    mask_bytes = base64.b64decode(result["mask_png_base64"])
    mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
    mask_img.save(output_path, format="JPEG")
    print(f"Wrote mask to {output_path}")


if __name__ == "__main__":
    main()
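
# Hedged usage sketch for a deployed Inference Endpoint. The URL and token below
# are placeholders/assumptions, not values from this repository; the response
# shape mirrors the list of dicts returned by EndpointHandler.__call__ above.
#
#     import requests
#     with open("test.jpeg", "rb") as f:
#         resp = requests.post(
#             "https://<your-endpoint>.endpoints.huggingface.cloud",  # hypothetical URL
#             headers={
#                 "Authorization": "Bearer <HF_TOKEN>",               # hypothetical token
#                 "Content-Type": "image/jpeg",
#             },
#             data=f.read(),
#         )
#     mask_b64 = resp.json()[0]["mask_png_base64"]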