CodeJackR
Fix errors
d816a26
raw
history blame
3.04 kB
# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any
class EndpointHandler():
    """Inference handler that segments an image with SAM and returns one mask.

    The model and processor are loaded once in ``__init__``; every call decodes
    the request image, runs SAM without point/box prompts, and returns a single
    binary mask as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.

        Args:
            path: Location passed to ``from_pretrained``. If loading from it
                fails for any reason, ``facebook/sam-vit-base`` is loaded instead.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fallback to loading from a known SAM model if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        # Use the GPU when one is available; otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.

        Accepted request shapes:
            {"inputs": "<base64 image>"}
            {"inputs": {"image": "<base64 image>"}}
            {"image": "<base64 image>"}
            {"inputs": <PIL.Image.Image>} or {"inputs": <raw image bytes>}
        Base64 strings may carry a ``data:<mime>;base64,`` prefix.

        Returns:
            A one-element list with ``mask_png_base64`` (single-channel PNG,
            255 = foreground, 0 = background), ``num_masks`` (number of
            candidate masks the model produced) and ``score`` (predicted IoU of
            the returned mask as a float, or None when unavailable).

        Raises:
            ValueError: If no image field is present, its type is unsupported,
                or the base64 payload is malformed (``binascii.Error`` is a
                ``ValueError`` subclass).
        """
        payload = data.get("inputs")
        if isinstance(payload, Image.Image):
            # Accept an already-decoded image (e.g. if the serving layer
            # deserialized the request body itself).
            img = payload.convert("RGB")
        else:
            image_bytes = self._extract_image_bytes(data)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Prepare inputs for the model; no point or box prompts are supplied.
        inputs = self.processor(img, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize the predicted mask logits back to the original image size.
        # The per-image result still stacks several candidate masks (leading
        # dims + H, W); _select_mask flattens and picks one of them.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]
        iou_scores = getattr(outputs, "iou_scores", None)
        if iou_scores is not None:
            iou_scores = iou_scores[0].cpu()

        mask_binary, num_masks, score = self._select_mask(masks, iou_scores)

        # Encode the 2-D uint8 mask as a grayscale PNG, then base64.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
        return [{"mask_png_base64": mask_base64, "num_masks": num_masks, "score": score}]

    @staticmethod
    def _b64decode(payload):
        """Base64-decode ``payload``, dropping a leading ``data:<mime>;base64,`` prefix."""
        if isinstance(payload, str) and payload.startswith("data:"):
            _, sep, body = payload.partition(",")
            if sep:
                payload = body
        return base64.b64decode(payload)

    @classmethod
    def _extract_image_bytes(cls, data: Dict[str, Any]) -> bytes:
        """Locate the encoded image in a request dict and return its raw bytes.

        Raises:
            ValueError: If no supported image field is present or its type is
                not supported.
        """
        if "inputs" in data:
            inputs = data["inputs"]
            if isinstance(inputs, str):
                # Base64 encoded image
                return cls._b64decode(inputs)
            if isinstance(inputs, (bytes, bytearray)):
                # Raw binary image body, no base64 layer
                return bytes(inputs)
            if isinstance(inputs, dict) and "image" in inputs:
                # Nested structure with image field
                return cls._b64decode(inputs["image"])
            raise ValueError("Invalid input format. Expected base64 encoded image string.")
        if "image" in data:
            # Direct image field
            return cls._b64decode(data["image"])
        raise ValueError("No image found in request. Expected 'inputs' or 'image' field with base64 encoded image.")

    @staticmethod
    def _select_mask(masks, iou_scores=None):
        """Pick one candidate mask and convert it to an 8-bit binary image.

        Args:
            masks: Tensor of candidate masks for one image; every dimension
                except the last two (H, W) is flattened into a candidate axis.
            iou_scores: Optional tensor of predicted IoU scores. When it holds
                exactly one score per candidate, the highest-scoring candidate
                is chosen; otherwise the first candidate is used.

        Returns:
            ``(mask, num_candidates, score)`` where ``mask`` is an (H, W)
            uint8 array (255 where the mask is positive, else 0) and ``score``
            is the chosen candidate's IoU as a float, or None.
        """
        candidates = masks.reshape(-1, *masks.shape[-2:])
        num_candidates = candidates.shape[0]
        best, score = 0, None
        if iou_scores is not None:
            flat_scores = iou_scores.reshape(-1)
            if flat_scores.numel() == num_candidates:
                best = int(flat_scores.argmax())
                score = float(flat_scores[best])
        mask = candidates[best].numpy()
        # Works for both boolean (binarized) and float (logit) masks.
        mask_binary = (mask > 0).astype(np.uint8) * 255
        return mask_binary, num_candidates, score