JingShiang Yang committed on
Commit
594e70f
·
1 Parent(s): 4194e02

add handler.py

Browse files
Files changed (3) hide show
  1. README.md +84 -0
  2. handler.py +71 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -1,3 +1,87 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ # EdgeSAM - Efficient Segment Anything Model
6
+
7
+ EdgeSAM is an accelerated variant of the Segment Anything Model (SAM) optimized for edge devices using ONNX Runtime.
8
+
9
+ ## Model Files
10
+
11
+ - `edge_sam_3x_encoder.onnx` - Image encoder (1024x1024 input)
12
+ - `edge_sam_3x_decoder.onnx` - Mask decoder with prompt support
13
+
14
+ ## Usage
15
+
16
+ ### API Request Format
17
+
18
+ ```python
19
+ import requests
20
+ import base64
21
+
22
+ # Encode your image
23
+ with open("image.jpg", "rb") as f:
24
+ image_b64 = base64.b64encode(f.read()).decode()
25
+
26
+ # Make request
27
+ response = requests.post(
28
+ "https://YOUR-ENDPOINT-URL",
29
+ json={
30
+ "inputs": image_b64,
31
+ "parameters": {
32
+ "point_coords": [[512, 512]], # Click point in 1024x1024 space
33
+ "point_labels": [1], # 1 = foreground, 0 = background
34
+ "return_mask_image": True
35
+ }
36
+ }
37
+ )
38
+
39
+ result = response.json()
40
+ ```
41
+
42
+ ### Response Format
43
+
44
+ ```json
45
+ [
46
+ {
47
+ "mask_shape": [1024, 1024],
48
+ "has_object": true,
49
+ "mask": "<base64_encoded_png>"
50
+ }
51
+ ]
52
+ ```
53
+
54
+ ### Parameters
55
+
56
+ - **point_coords**: Array of `[x, y]` coordinates in 1024x1024 space (optional)
57
+ - **point_labels**: Array of labels (1=foreground, 0=background) corresponding to points (optional)
58
+ - **box_coords**: Bounding box `[x1, y1, x2, y2]` (optional, not yet implemented)
59
+ - **return_mask_image**: Return base64-encoded PNG mask (default: `true`)
60
+
61
+ ### Coordinate System
62
+
63
+ All coordinates should be in **1024x1024** space, regardless of original image size. The handler automatically resizes input images to 1024x1024 before processing.
64
+
65
+ Example: For a click at the center of any image, use `[512, 512]`.
66
+
67
+ ## Local Testing
68
+
69
+ ```bash
70
+ # Install dependencies
71
+ pip install -r requirements.txt
72
+
73
+ # Run test script
74
+ python test_handler.py
75
+ ```
76
+
77
+ This will create:
78
+ - `test_input.png` - Test image with red circle
79
+ - `test_output_mask.png` - Generated segmentation mask
80
+ - `test_output_overlay.png` - Overlay visualization
81
+
82
+ ## Technical Details
83
+
84
+ - **Input**: RGB images (auto-resized to 1024x1024)
85
+ - **Preprocessing**: Normalized to [0, 1] range (`/ 255.0`)
86
+ - **Hardware**: Supports CUDA GPU with automatic CPU fallback
87
+ - **Framework**: ONNX Runtime Web compatible
handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+ import os
8
+
9
+
10
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for EdgeSAM (ONNX).

    Loads the EdgeSAM image encoder and prompt decoder as ONNX Runtime
    sessions and serves point-prompted segmentation requests. All prompt
    coordinates are expected in 1024x1024 space; input images are resized
    to 1024x1024 before encoding.
    """

    def __init__(self, path=""):
        """Create the encoder and decoder ONNX Runtime sessions.

        Args:
            path: Directory containing ``edge_sam_3x_encoder.onnx`` and
                ``edge_sam_3x_decoder.onnx``. Falls back to the current
                directory when empty.
        """
        model_path = path if path else "."
        # Prefer CUDA; ONNX Runtime automatically falls back to CPU when
        # the CUDA provider is unavailable.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run point-prompted segmentation on one image.

        Args:
            data: Endpoint payload. ``data["inputs"]`` is either a base64
                string of image bytes or a PIL image; ``data["parameters"]``
                may contain ``point_coords`` (list of ``[x, y]`` in
                1024x1024 space), ``point_labels`` (1 = foreground,
                0 = background, one per point) and ``return_mask_image``
                (bool, default ``True``).

        Returns:
            A single-element list with ``mask_shape`` and ``has_object``,
            plus a base64-encoded PNG under ``mask`` when requested.
            On any failure, a single-element list ``[{"error": <message>}]``
            (errors are reported rather than raised, per the endpoint
            handler convention).
        """
        try:
            # Accept either the full payload shape or a bare image object.
            inputs = data.get("inputs", data)
            params = data.get("parameters", {})

            # A string input is assumed to be base64-encoded image bytes.
            if isinstance(inputs, str):
                image = Image.open(io.BytesIO(base64.b64decode(inputs)))
            else:
                image = inputs

            # Preprocess: force RGB, resize to the fixed 1024x1024 model
            # input, scale to [0, 1] floats, and reorder HWC -> NCHW.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize((1024, 1024), Image.BILINEAR)
            img_array = np.array(image).astype(np.float32) / 255.0
            img_array = img_array.transpose(2, 0, 1)[np.newaxis, :]

            # Encode the image once; the decoder consumes the embeddings.
            embeddings = self.encoder.run(None, {'image': img_array})[0]

            # Prepare prompts. Default: one foreground click at the center.
            coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
            n_points = max(coords.reshape(-1, 2).shape[0], 1)
            # Bug fix: the label default must match the number of points.
            # Previously a bare [1] default caused a point_coords (1, k, 2)
            # vs point_labels (1, 1) shape mismatch in the decoder whenever
            # the caller supplied several points without explicit labels.
            labels = np.array(
                params.get("point_labels", [1] * n_points), dtype=np.float32
            )

            # Decode: prompts are batched as (1, N, 2) coords / (1, N) labels.
            masks = self.decoder.run(None, {
                'image_embeddings': embeddings,
                'point_coords': coords.reshape(1, -1, 2),
                'point_labels': labels.reshape(1, -1)
            })[0]

            # Postprocess: threshold the first mask logit map at 0 and
            # convert to an 8-bit binary image (0 / 255).
            mask = (masks[0, 0] > 0.0).astype(np.uint8) * 255

            # Assemble the response.
            result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}

            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                Image.fromarray(mask, mode='L').save(buffer, format='PNG')
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()

            return [result]

        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead
            # of letting the serving process crash.
            return [{"error": str(e)}]
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ onnxruntime>=1.16.0
2
+ numpy>=1.24.0
3
+ Pillow>=10.0.0