# handler.py
import io
import base64
import numpy as np
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from typing import Dict, List, Any


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Called once at startup.
        Load the SAM model using Hugging Face Transformers.
        """
        try:
            # Load the model and processor from the local path
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Fall back to a known SAM checkpoint if local loading fails
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Run on GPU when available; tensors are moved back to CPU before
        # post-processing in __call__
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Called on every HTTP request.
        Expects a base64-encoded image in the 'inputs' or 'image' field.
        """
        # Handle the different input formats
        if "inputs" in data:
            if isinstance(data["inputs"], str):
                # Base64-encoded image string
                image_bytes = base64.b64decode(data["inputs"])
            elif isinstance(data["inputs"], dict) and "image" in data["inputs"]:
                # Nested structure with an 'image' field
                image_bytes = base64.b64decode(data["inputs"]["image"])
            else:
                raise ValueError("Invalid input format. Expected a base64-encoded image string.")
        elif "image" in data:
            # Direct 'image' field
            image_bytes = base64.b64decode(data["image"])
        else:
            raise ValueError("No image found in request. Expected an 'inputs' or 'image' field with a base64-encoded image.")
        # Decode the image
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Prepare inputs for the model
        inputs = self.processor(img, return_tensors="pt").to(self.device)

        # Generate masks using the model
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Upscale the predicted masks to the original image size.
        # post_process_masks returns one tensor per image, shaped
        # (num_prompts, num_candidate_masks, height, width).
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]

        # Take the first candidate mask for the first (default) prompt and
        # binarize it; indexing both leading dimensions yields the 2D
        # (height, width) array that PIL can encode
        mask = masks[0][0].numpy()
        mask_binary = (mask > 0).astype(np.uint8) * 255
        # Encode the mask as a base64 PNG
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        mask_base64 = base64.b64encode(out.getvalue()).decode("utf-8")

        # Return in the expected format; report the number of candidate masks
        return [{"mask_png_base64": mask_base64, "num_masks": int(masks.shape[1])}]