# Source: CodeJackR, commit 064d94e ("Input image as image"), 5.09 kB.
# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
import torch.nn.functional as F
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
class EndpointHandler:
    """Inference Endpoints handler that segments an image with SAM.

    Each request is prompted with a single positive point at the image centre.
    The mask with the highest predicted IoU is returned as an RGB PIL image:
    255 inside the mask and 0 elsewhere, at the input image's size.
    """

    def __init__(self, path=""):
        """Load the SAM model and processor once at startup.

        Args:
            path: Location passed to ``from_pretrained`` for both the model and
                the processor. If loading from it raises for any reason, the
                ``facebook/sam-vit-base`` checkpoint is loaded instead.
        """
        try:
            # Load the model and processor from the given path.
            self.model = SamModel.from_pretrained(path).to(device)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Deliberate fallback to a known public SAM checkpoint.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    def __call__(self, data: Any) -> Any:
        """Segment the object at the centre of the request image.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                (or from ``data`` itself when it has no ``"inputs"`` key or is
                not a dict). It may be a PIL image, raw encoded image bytes, or
                a base64 string, optionally in ``data:<mime>;base64,`` URL form.
                The payload is not modified.

        Returns:
            An RGB ``PIL.Image.Image`` the size of the input image, with pixels
            255 inside the predicted mask and 0 elsewhere.

        Raises:
            TypeError: If the image payload is of an unsupported type.
        """
        # .get (not .pop) so the caller's payload dict is left untouched.
        payload = data.get("inputs", data) if isinstance(data, dict) else data
        img = self._load_image(payload)
        width, height = img.size  # PIL reports (width, height)

        # SAM needs a prompt. Use one positive point at the image centre.
        input_points = [[[width // 2, height // 2]]]
        input_labels = [[1]]
        # The model was moved to `device` in __init__, so its inputs must be
        # moved there too, or a GPU host fails with a device mismatch.
        inputs = self.processor(
            img,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        try:
            mask_binary = self._postprocess(outputs, inputs)
        except Exception as e:
            # Deliberate best-effort: log and return a placeholder mask rather
            # than failing the request.
            print(f"Error processing masks: {e}")
            mask_binary = self._fallback_mask(width, height)

        return Image.fromarray(mask_binary).convert("RGB")

    @staticmethod
    def _load_image(payload):
        """Decode the request image into an RGB PIL image.

        Accepts a PIL image, raw encoded image bytes, or a base64 string
        (optionally prefixed with a ``data:<mime>;base64,`` header).
        """
        if isinstance(payload, Image.Image):
            # Convert so RGBA / grayscale inputs match the 3-channel input SAM expects.
            return payload.convert("RGB")
        if isinstance(payload, str):
            if payload.startswith("data:"):
                # Drop the "data:<mime>;base64," header of a data URL.
                payload = payload.split(",", 1)[1]
            payload = base64.b64decode(payload)
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return Image.open(io.BytesIO(payload)).convert("RGB")
        raise TypeError(f"Unsupported image input type: {type(payload).__name__}")

    def _postprocess(self, outputs, inputs):
        """Return the best mask as a uint8 (0/255) array at the original image size."""
        # post_process_masks upsamples the low-resolution mask logits, crops
        # away the padding the processor added around the resized image,
        # resizes to the original size and binarizes at 0.0. Interpolating
        # pred_masks straight to the original size would keep that padding and
        # misalign masks for non-square images.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]  # masks for the first (only) image in the batch
        iou_scores = outputs.iou_scores[0].cpu()  # scores for that same image
        return EndpointHandler._best_mask(masks, iou_scores)

    @staticmethod
    def _best_mask(masks, iou_scores):
        """Pick the mask with the highest IoU score and return it as uint8 0/255.

        Args:
            masks: Tensor of shape (..., H, W), with one (H, W) mask per score.
                Bool masks or logits are accepted; values > 0 are foreground.
            iou_scores: Tensor with one score per mask, in the same order as
                the flattened leading dims of ``masks``.
        """
        flat_masks = masks.reshape(-1, *masks.shape[-2:])
        best = int(torch.argmax(iou_scores.reshape(-1)))
        foreground = flat_masks[best].float() > 0.0
        return foreground.numpy().astype(np.uint8) * 255

    @staticmethod
    def _fallback_mask(width, height):
        """Return a placeholder mask: a centred square with half-side min(width, height) // 8."""
        mask = np.zeros((height, width), dtype=np.uint8)
        center_x, center_y = width // 2, height // 2
        half = min(width, height) // 8
        mask[center_y - half:center_y + half, center_x - half:center_x + half] = 255
        return mask
def main(input_path="/Users/rp7/Downloads/test.jpeg", output_path="output.jpg"):
    """Run the handler once on a local image file and save the returned mask.

    ``EndpointHandler.__call__`` returns the mask as an RGB PIL image, so the
    result is saved directly rather than decoded from a response dict.

    Args:
        input_path: Image file to segment.
        output_path: Destination for the mask. The format is inferred from the
            extension (the default ``.jpg`` writes JPEG; use ``.png`` for a
            lossless binary mask).
    """
    with open(input_path, "rb") as f:
        img_bytes = f.read()

    handler = EndpointHandler(path=".")
    # Raw encoded bytes are an input form the handler decodes with Image.open.
    mask_img = handler({"inputs": img_bytes})
    mask_img.save(output_path)
    print(f"Wrote mask to {output_path}")


if __name__ == "__main__":
    import sys

    # Optional overrides: python handler.py [input_path] [output_path]
    main(*sys.argv[1:3])