JingShiang Yang commited on
Commit
976e547
·
1 Parent(s): 9dbafd4

Return embedding bin file

Browse files
Files changed (2) hide show
  1. app.py +11 -6
  2. handler.py +8 -0
app.py CHANGED
@@ -5,23 +5,28 @@ from handler import EndpointHandler
5
  # 初始化你的 ONNX 模型
6
  handler = EndpointHandler(path=".")
7
 
8
- def api_call(image, parameters):
9
  data = {
10
  "inputs": image, # PIL.Image
11
- "parameters": parameters or {}
12
  }
13
  result = handler(data)[0]
 
 
 
 
 
14
  return result # JSON only(含 base64 mask)
15
 
16
  demo = gr.Interface(
17
  fn=api_call,
18
  inputs=[
19
  gr.Image(type="pil", label="image"),
20
- gr.JSON(label="parameters (point_coords, point_labels, return_mask_image)")
21
  ],
22
- outputs=gr.JSON(label="result"),
23
- title="Edge SAM API",
24
- description="Pure API Space send image + parameters, get mask JSON."
25
  )
26
 
27
  if __name__ == "__main__":
 
5
  # 初始化你的 ONNX 模型
6
  handler = EndpointHandler(path=".")
7
 
8
+ def api_call(image, return_embeddings_only):
9
  data = {
10
  "inputs": image, # PIL.Image
11
+ "parameters": {"return_embeddings_only": return_embeddings_only}
12
  }
13
  result = handler(data)[0]
14
+
15
+ # If embeddings file is returned
16
+ if "file" in result:
17
+ return result["file"]
18
+
19
  return result # JSON only(含 base64 mask)
20
 
21
  demo = gr.Interface(
22
  fn=api_call,
23
  inputs=[
24
  gr.Image(type="pil", label="image"),
25
+ gr.Checkbox(label="Return embeddings only (as .bin file)", value=True)
26
  ],
27
+ outputs=[gr.File(label="embeddings.bin or result JSON")],
28
+ title="Edge SAM Encoder API",
29
+ description="Send image, get embeddings in float32 binary format (.bin file)."
30
  )
31
 
32
  if __name__ == "__main__":
handler.py CHANGED
@@ -49,6 +49,14 @@ class EndpointHandler:
49
  # Encode
50
  embeddings = self.encoder.run(None, {'image': img_array})[0]
51
 
 
 
 
 
 
 
 
 
52
  # Prepare prompts
53
  coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
54
  labels = np.array(params.get("point_labels", [1]), dtype=np.float32)
 
49
  # Encode
50
  embeddings = self.encoder.run(None, {'image': img_array})[0]
51
 
52
+ # Check if only embeddings are requested
53
+ if params.get("return_embeddings_only", False):
54
+ # Convert embeddings to float32 binary and save to temp file
55
+ embeddings_float32 = embeddings.astype(np.float32)
56
+ temp_file = "/tmp/embeddings.bin"
57
+ embeddings_float32.tofile(temp_file)
58
+ return [{"file": temp_file, "shape": list(embeddings.shape), "dtype": "float32"}]
59
+
60
  # Prepare prompts
61
  coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
62
  labels = np.array(params.get("point_labels", [1]), dtype=np.float32)