# Spaces: Sleeping
import base64
import io
import os
import tempfile
import traceback
from typing import Any, Dict, List

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
class EndpointHandler:
    """Inference handler serving EdgeSAM point-prompted segmentation via ONNX Runtime.

    The encoder turns an image into embeddings. The decoder turns those
    embeddings plus point prompts into mask predictions. ``__call__`` reports
    any ``Exception`` raised while handling a request as an error dict in the
    returned list instead of propagating it.
    """

    # Hugging Face Hub repository and file names of the exported ONNX models.
    REPO_ID = "chongzhou/EdgeSAM"
    ENCODER_FILENAME = "edge_sam_3x_encoder.onnx"
    DECODER_FILENAME = "edge_sam_3x_decoder.onnx"
    # Images are resized to an INPUT_SIZE x INPUT_SIZE square before encoding,
    # and returned masks are upsampled to the same size.
    INPUT_SIZE = 1024
    # Execution providers in order of preference; filtered against the installed build.
    PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self, path=""):
        """Download the EdgeSAM encoder/decoder and open ONNX Runtime sessions.

        Args:
            path: Unused. The models are always fetched from ``REPO_ID`` on the
                Hub. The argument is presumably supplied by the serving toolkit
                (model directory) — kept for interface compatibility.
        """
        encoder_path = hf_hub_download(repo_id=self.REPO_ID, filename=self.ENCODER_FILENAME)
        decoder_path = hf_hub_download(repo_id=self.REPO_ID, filename=self.DECODER_FILENAME)
        # Only request providers this onnxruntime build actually ships, so a
        # CPU-only install does not get asked for CUDA.
        available = set(ort.get_available_providers())
        providers = [p for p in self.PREFERRED_PROVIDERS if p in available]
        self.encoder = ort.InferenceSession(encoder_path, providers=providers)
        self.decoder = ort.InferenceSession(decoder_path, providers=providers)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment an image from point prompts, or return only its embeddings.

        Args:
            data: Request payload. The image is ``data["inputs"]``, or ``data``
                itself when that key is missing. It may be:

                - a base64 string (a ``data:...;base64,`` prefix is tolerated),
                - raw image bytes,
                - an already-loaded PIL image.

                ``data["parameters"]`` (optional) may contain:

                - ``point_coords``: list of ``[x, y]``; default ``[[512, 512]]``.
                - ``point_labels``: one label per point; default ``[1]``.
                - ``return_embeddings_only``: default ``False``.
                - ``return_mask_image``: default ``True``.

        Returns:
            A one-element list holding one of:

            - with ``return_embeddings_only``: ``{"file", "shape", "dtype"}``
              describing a raw float32 file with the encoder output;
            - otherwise: ``{"mask_shape", "has_object"}``, plus ``"mask"`` (a
              base64 PNG of the 0/255 mask) when ``return_mask_image`` is true;
            - on failure: ``{"error", "type", "traceback"}``.
        """
        try:
            inputs = data.get("inputs", data)
            # ``or {}`` also covers an explicit ``"parameters": null`` in the payload.
            params = data.get("parameters") or {}
            embeddings_only = params.get("return_embeddings_only", False)

            # Validate the prompts before running the encoder, so malformed
            # requests fail without paying for image encoding.
            prompts = None if embeddings_only else self._prepare_prompts(params)

            image = self._load_image(inputs)
            embeddings = self.encoder.run(None, {"image": self._preprocess(image)})[0]

            if embeddings_only:
                return [self._save_embeddings(embeddings)]

            coords, labels = prompts
            decoder_outputs = self.decoder.run(None, {
                "image_embeddings": embeddings,
                "point_coords": coords,
                "point_labels": labels,
            })
            # decoder_outputs[0] is IoU scores (1, 4); decoder_outputs[1] is masks (1, 4, 256, 256).
            # NOTE(review): the first mask is always used, regardless of its score.
            mask = self._postprocess_mask(decoder_outputs[1][0, 0])

            result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}
            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A uint8 2-D array maps to mode "L" automatically; passing
                # ``mode=`` explicitly is deprecated in recent Pillow releases.
                Image.fromarray(mask).save(buffer, format="PNG")
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()
            return [result]
        except Exception as e:
            # NOTE(review): the traceback is sent back to the caller, which exposes
            # server internals; kept because existing clients may read the key.
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]

    @staticmethod
    def _prepare_prompts(params):
        """Build decoder point inputs: coords shaped (1, N, 2) and labels shaped (1, N), float32.

        Raises:
            ValueError: if the coordinates cannot be grouped into pairs, if the
                point and label counts differ, or if no point is given.
        """
        # NOTE(review): coordinates are passed through unscaled, so they are
        # presumably expected in the resized INPUT_SIZE x INPUT_SIZE frame
        # (the default 512, 512 is its centre) — confirm with clients.
        coords = np.asarray(params.get("point_coords", [[512, 512]]), dtype=np.float32).reshape(1, -1, 2)
        labels = np.asarray(params.get("point_labels", [1]), dtype=np.float32).reshape(1, -1)
        num_points, num_labels = coords.shape[1], labels.shape[1]
        if num_points != num_labels:
            raise ValueError(
                f"point_coords has {num_points} points but point_labels has {num_labels} labels"
            )
        if num_points == 0:
            raise ValueError("at least one point prompt is required")
        return coords, labels

    @staticmethod
    def _load_image(inputs):
        """Turn the request's ``inputs`` (base64 str, raw bytes, or image) into a PIL image."""
        if isinstance(inputs, str):
            # Tolerate data URLs such as "data:image/png;base64,...": the header
            # characters would otherwise corrupt the base64 decode.
            if inputs.startswith("data:") and "," in inputs:
                inputs = inputs.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(inputs)))
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs))
        # Anything else is presumably an already-decoded PIL image.
        return inputs

    def _preprocess(self, image):
        """Convert to RGB, resize to INPUT_SIZE square, scale to [0, 1]; return a (1, 3, H, W) float32 array."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        # NOTE(review): the aspect ratio is not preserved and no mean/std
        # normalization is applied — confirm this matches the exported encoder.
        image = image.resize((self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR)
        pixels = np.asarray(image, dtype=np.float32) / 255.0
        # HWC -> NCHW, made contiguous for the ONNX Runtime input buffer.
        return np.ascontiguousarray(pixels.transpose(2, 0, 1)[np.newaxis, :])

    @staticmethod
    def _save_embeddings(embeddings):
        """Write embeddings as raw float32 to a new temp file and describe it."""
        # One unique file per request: a fixed path let concurrent requests
        # overwrite each other's embeddings before the caller read them.
        # NOTE(review): the file is not deleted here; the consumer owns cleanup.
        fd, temp_file = tempfile.mkstemp(prefix="embeddings_", suffix=".bin")
        with os.fdopen(fd, "wb") as handle:
            embeddings.astype(np.float32).tofile(handle)
        return {"file": temp_file, "shape": list(embeddings.shape), "dtype": "float32"}

    def _postprocess_mask(self, low_res_mask):
        """Upsample one low-resolution mask to INPUT_SIZE and binarize it to 0/255 uint8."""
        upsampled = Image.fromarray(low_res_mask).resize(
            (self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR
        )
        # Positive values count as foreground (the output is presumably logits).
        return (np.asarray(upsampled) > 0.0).astype(np.uint8) * 255