|
|
|
|
|
|
|
|
import io |
|
|
import base64 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import SamModel, SamProcessor |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Segment Anything (SAM).

    Loads a ``SamModel``/``SamProcessor`` pair at startup and, per request,
    decodes a base64 image, runs mask prediction, and returns the best
    predicted mask as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Called once at startup.

        Load the SAM model using Hugging Face Transformers. Falls back to the
        public ``facebook/sam-vit-base`` checkpoint when the local ``path``
        cannot be loaded (e.g. the endpoint repo ships no usable weights).

        Args:
            path: Local model directory provided by the endpoint runtime.
        """
        try:
            self.model = SamModel.from_pretrained(path)
            self.processor = SamProcessor.from_pretrained(path)
        except Exception as e:
            # Deliberate best-effort fallback so the endpoint still boots
            # when the local checkpoint is missing or broken.
            print(f"Failed to load from local path: {e}")
            print("Attempting to load from facebook/sam-vit-base")
            self.model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        # Inference-only service: use a GPU when one is available and freeze
        # the model (disables dropout / batch-norm updates).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _extract_image_bytes(data: Dict[str, Any]) -> bytes:
        """Pull the base64 image payload out of the request and decode it.

        Accepted shapes: ``{"inputs": "<b64>"}``, ``{"inputs": {"image":
        "<b64>"}}``, or ``{"image": "<b64>"}``.

        Raises:
            ValueError: when no recognizable image field is present.
        """
        if "inputs" in data:
            payload = data["inputs"]
            if isinstance(payload, str):
                encoded = payload
            elif isinstance(payload, dict) and "image" in payload:
                encoded = payload["image"]
            else:
                raise ValueError("Invalid input format. Expected base64 encoded image string.")
        elif "image" in data:
            encoded = data["image"]
        else:
            raise ValueError("No image found in request. Expected 'inputs' or 'image' field with base64 encoded image.")
        return base64.b64decode(encoded)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Called on every HTTP request.

        Expecting base64 encoded image in the 'inputs' field or 'image' field.

        Returns:
            A one-element list with the best mask as ``mask_png_base64`` and
            the number of prompt-level mask groups as ``num_masks``.
        """
        image_bytes = self._extract_image_bytes(data)
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Move pixel tensors to the same device as the model.
        inputs = self.processor(img, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale predicted masks back to the original image resolution.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]

        # SAM predicts several candidate masks (multimask_output=True by
        # default), so masks has shape (prompts, num_masks, H, W). The
        # previous masks[0].squeeze() produced a (num_masks, H, W) array
        # that Image.fromarray cannot serialize; instead pick the single
        # candidate with the highest predicted IoU.
        candidates = masks.reshape(-1, masks.shape[-2], masks.shape[-1])
        scores = outputs.iou_scores.cpu().reshape(-1)
        best = int(scores.argmax().item()) if scores.numel() == candidates.shape[0] else 0
        mask = candidates[best].numpy()
        mask_binary = (mask > 0.0).astype(np.uint8) * 255

        # Encode the binary mask as a PNG and return it base64-encoded.
        out = io.BytesIO()
        Image.fromarray(mask_binary).save(out, format="PNG")
        out.seek(0)
        mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')

        return [{"mask_png_base64": mask_base64, "num_masks": len(masks)}]
|
|
|