# EdgeSAM / handler.py
# Last change by JingShiang Yang: "Return embedding bin file" (commit 976e547)
import base64
import io
import os
import tempfile
import traceback
from typing import Any, Dict, List

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
class EndpointHandler:
def __init__(self, path=""):
# Download models from Hugging Face Hub
repo_id = "chongzhou/EdgeSAM"
encoder_path = hf_hub_download(
repo_id=repo_id,
filename="edge_sam_3x_encoder.onnx"
)
decoder_path = hf_hub_download(
repo_id=repo_id,
filename="edge_sam_3x_decoder.onnx"
)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self.encoder = ort.InferenceSession(encoder_path, providers=providers)
self.decoder = ort.InferenceSession(decoder_path, providers=providers)
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
try:
# Parse input
inputs = data.get("inputs", data)
params = data.get("parameters", {})
# Load image
if isinstance(inputs, str):
image = Image.open(io.BytesIO(base64.b64decode(inputs)))
else:
image = inputs
# Preprocess
if image.mode != 'RGB':
image = image.convert('RGB')
image = image.resize((1024, 1024), Image.BILINEAR)
img_array = np.array(image).astype(np.float32) / 255.0
img_array = img_array.transpose(2, 0, 1)[np.newaxis, :]
# Encode
embeddings = self.encoder.run(None, {'image': img_array})[0]
# Check if only embeddings are requested
if params.get("return_embeddings_only", False):
# Convert embeddings to float32 binary and save to temp file
embeddings_float32 = embeddings.astype(np.float32)
temp_file = "/tmp/embeddings.bin"
embeddings_float32.tofile(temp_file)
return [{"file": temp_file, "shape": list(embeddings.shape), "dtype": "float32"}]
# Prepare prompts
coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
labels = np.array(params.get("point_labels", [1]), dtype=np.float32)
# Decode
decoder_outputs = self.decoder.run(None, {
'image_embeddings': embeddings,
'point_coords': coords.reshape(1, -1, 2),
'point_labels': labels.reshape(1, -1)
})
# decoder_outputs[0] is IoU scores (1, 4)
# decoder_outputs[1] is masks (1, 4, 256, 256)
masks = decoder_outputs[1]
# Take first mask and resize to 1024x1024
mask = masks[0, 0] # Shape: (256, 256)
mask = Image.fromarray(mask).resize((1024, 1024), Image.BILINEAR)
mask = np.array(mask)
mask = (mask > 0.0).astype(np.uint8) * 255
# Return result
result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}
if params.get("return_mask_image", True):
buffer = io.BytesIO()
Image.fromarray(mask, mode='L').save(buffer, format='PNG')
result["mask"] = base64.b64encode(buffer.getvalue()).decode()
return [result]
except Exception as e:
import traceback
return [{
"error": str(e),
"type": type(e).__name__,
"traceback": traceback.format_exc()
}]