File size: 3,560 Bytes
30158dd
 
 
 
 
 
 
 
 
 
 
 
 
9dbafd4
30158dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976e547
 
 
 
 
 
 
 
30158dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import base64
import io
import os
import tempfile
import traceback
from typing import Any, Dict, List

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for EdgeSAM (ONNX).

    Downloads the EdgeSAM encoder/decoder ONNX models from the Hub once at
    startup and serves point-prompted segmentation requests. Requests may
    ask for just the image embeddings, or for a binary mask (optionally
    PNG-encoded and base64'd into the response).
    """

    def __init__(self, path=""):
        """Download the EdgeSAM ONNX models and open inference sessions.

        Args:
            path: Unused; present to satisfy the Inference Endpoints
                handler constructor contract.
        """
        repo_id = "chongzhou/EdgeSAM"

        encoder_path = hf_hub_download(
            repo_id=repo_id,
            filename="edge_sam_3x_encoder.onnx",
        )
        decoder_path = hf_hub_download(
            repo_id=repo_id,
            filename="edge_sam_3x_decoder.onnx",
        )

        # CUDA first; onnxruntime falls back to CPU if CUDA is unavailable.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.encoder = ort.InferenceSession(encoder_path, providers=providers)
        self.decoder = ort.InferenceSession(decoder_path, providers=providers)

    @staticmethod
    def _load_image(inputs):
        """Decode the request payload into a PIL image.

        A string payload is treated as base64-encoded image bytes;
        anything else is assumed to already be a PIL image.
        """
        if isinstance(inputs, str):
            return Image.open(io.BytesIO(base64.b64decode(inputs)))
        return inputs

    @staticmethod
    def _preprocess(image):
        """Resize/normalize an image into the encoder's NCHW float32 input.

        Returns an array of shape (1, 3, 1024, 1024) with values in [0, 1].
        Note: the resize does not preserve aspect ratio.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize((1024, 1024), Image.BILINEAR)
        img_array = np.array(image).astype(np.float32) / 255.0
        return img_array.transpose(2, 0, 1)[np.newaxis, :]

    @staticmethod
    def _export_embeddings(embeddings):
        """Dump embeddings to a unique temp file as raw float32 bytes.

        Uses tempfile.mkstemp rather than a fixed path so that concurrent
        requests cannot overwrite each other's embeddings file (the
        previous fixed /tmp/embeddings.bin was a shared-file race).
        """
        fd, temp_file = tempfile.mkstemp(prefix="embeddings_", suffix=".bin")
        os.close(fd)  # tofile reopens by path; close the raw descriptor
        embeddings.astype(np.float32).tofile(temp_file)
        return {
            "file": temp_file,
            "shape": list(embeddings.shape),
            "dtype": "float32",
        }

    @staticmethod
    def _postprocess_mask(masks):
        """Turn the decoder's low-res mask logits into a 1024x1024 uint8 mask.

        Takes the first of the candidate masks, upsamples it bilinearly,
        then thresholds logits at 0.0 into a {0, 255} binary image.
        """
        mask = masks[0, 0]  # shape: (256, 256), float logits
        mask = Image.fromarray(mask).resize((1024, 1024), Image.BILINEAR)
        mask = np.array(mask)
        return (mask > 0.0).astype(np.uint8) * 255

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request dict. ``inputs`` is a base64 image string (or a
                PIL image); ``parameters`` may contain
                ``return_embeddings_only`` (bool), ``point_coords``
                (list of [x, y]), ``point_labels`` (list of int), and
                ``return_mask_image`` (bool, default True).

        Returns:
            A single-element list with either an embeddings-file
            descriptor, a mask result, or an error descriptor. Never
            raises: failures are reported in the payload so the endpoint
            returns a structured error instead of a bare 500.
        """
        try:
            inputs = data.get("inputs", data)
            params = data.get("parameters", {})

            image = self._load_image(inputs)
            img_array = self._preprocess(image)

            # Encode once; the embeddings feed either path below.
            embeddings = self.encoder.run(None, {'image': img_array})[0]

            if params.get("return_embeddings_only", False):
                return [self._export_embeddings(embeddings)]

            # NOTE(review): prompts appear to be expected in the resized
            # 1024x1024 pixel space, not original-image coordinates —
            # confirm with callers before scaling on their behalf.
            coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
            labels = np.array(params.get("point_labels", [1]), dtype=np.float32)

            decoder_outputs = self.decoder.run(None, {
                'image_embeddings': embeddings,
                'point_coords': coords.reshape(1, -1, 2),
                'point_labels': labels.reshape(1, -1),
            })

            # decoder_outputs[0] is IoU scores (1, 4);
            # decoder_outputs[1] is candidate masks (1, 4, 256, 256).
            mask = self._postprocess_mask(decoder_outputs[1])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }

            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                Image.fromarray(mask, mode='L').save(buffer, format='PNG')
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()

            return [result]

        except Exception as e:
            # Endpoint boundary: surface the failure to the client rather
            # than letting the serving stack swallow it.
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]