Anigor66 committed on
Commit
6b32938
·
1 Parent(s): f61a56b

Update API to match backend format - add segment_points, segment_box, segment_multiple_boxes

Browse files
Files changed (1) hide show
  1. app.py +483 -158
app.py CHANGED
@@ -1,5 +1,7 @@
1
  """
2
- HuggingFace Space for MedSAM Inference with Point Prompts
 
 
3
  Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference
4
  """
5
  import gradio as gr
@@ -33,61 +35,331 @@ def patched_torch_load(f, *args, **kwargs):
33
  torch.load = patched_torch_load
34
 
35
  try:
36
- sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
37
- sam.to(device=device)
38
- sam.eval()
39
- predictor = SamPredictor(sam)
40
- print("✓ MedSAM model loaded successfully!")
41
  finally:
42
  # Restore original torch.load
43
  torch.load = original_torch_load
44
 
45
 
46
- def segment_with_points(image, points_json):
 
 
 
 
47
  """
48
- Segment image with point prompts
 
 
 
49
 
50
  Args:
51
  image: PIL Image
52
- points_json: JSON string with format:
53
  {
54
- "coords": [[x1, y1], [x2, y2], ...],
55
- "labels": [1, 0, ...], # 1=foreground, 0=background
56
- "multimask_output": true/false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  }
58
 
59
  Returns:
60
- JSON string with masks and scores
 
 
 
 
 
 
61
  """
62
  try:
63
  # Parse input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  points_data = json.loads(points_json)
65
  coords = np.array(points_data["coords"])
66
  labels = np.array(points_data["labels"])
67
  multimask_output = points_data.get("multimask_output", True)
68
 
69
- # Convert PIL to numpy
70
  image_array = np.array(image)
71
-
72
- # Set image in predictor
73
  predictor.set_image(image_array)
74
 
75
- # Run prediction
76
  masks, scores, logits = predictor.predict(
77
  point_coords=coords,
78
  point_labels=labels,
79
  multimask_output=multimask_output
80
  )
81
 
82
- # Convert masks to lists (JSON serializable)
83
  masks_list = []
84
  scores_list = []
85
 
86
  for i, (mask, score) in enumerate(zip(masks, scores)):
87
- # Convert boolean mask to uint8
88
  mask_uint8 = (mask * 255).astype(np.uint8)
89
-
90
- # Encode mask as base64 PNG
91
  mask_image = Image.fromarray(mask_uint8)
92
  buffer = io.BytesIO()
93
  mask_image.save(buffer, format='PNG')
@@ -96,55 +368,37 @@ def segment_with_points(image, points_json):
96
  masks_list.append({
97
  'mask_base64': mask_base64,
98
  'mask_shape': mask.shape,
99
- 'mask_data': mask.tolist() # Also include raw data for processing
100
  })
101
  scores_list.append(float(score))
102
 
103
- result = {
104
  'success': True,
105
  'masks': masks_list,
106
  'scores': scores_list,
107
  'num_masks': len(masks_list)
108
- }
109
-
110
- return json.dumps(result)
111
 
112
  except Exception as e:
113
- error_result = {
114
- 'success': False,
115
- 'error': str(e)
116
- }
117
- return json.dumps(error_result)
118
 
119
 
120
- def segment_with_box(image, box_json):
121
  """
122
- Segment image with box prompt
123
 
124
  Args:
125
- image: PIL Image
126
  box_json: JSON string with format:
127
- {
128
- "box": [x1, y1, x2, y2], # Top-left and bottom-right corners
129
- "multimask_output": true/false
130
- }
131
-
132
- Returns:
133
- JSON string with masks and scores
134
  """
135
  try:
136
- # Parse input
137
  box_data = json.loads(box_json)
138
- box = np.array(box_data["box"]) # [x1, y1, x2, y2]
139
- multimask_output = box_data.get("multimask_output", False) # Usually False for box
140
 
141
- # Convert PIL to numpy
142
  image_array = np.array(image)
143
-
144
- # Set image in predictor
145
  predictor.set_image(image_array)
146
 
147
- # Run prediction with box
148
  masks, scores, logits = predictor.predict(
149
  point_coords=None,
150
  point_labels=None,
@@ -152,15 +406,11 @@ def segment_with_box(image, box_json):
152
  multimask_output=multimask_output
153
  )
154
 
155
- # Convert masks to lists (JSON serializable)
156
  masks_list = []
157
  scores_list = []
158
 
159
  for i, (mask, score) in enumerate(zip(masks, scores)):
160
- # Convert boolean mask to uint8
161
  mask_uint8 = (mask * 255).astype(np.uint8)
162
-
163
- # Encode mask as base64 PNG
164
  mask_image = Image.fromarray(mask_uint8)
165
  buffer = io.BytesIO()
166
  mask_image.save(buffer, format='PNG')
@@ -173,40 +423,25 @@ def segment_with_box(image, box_json):
173
  })
174
  scores_list.append(float(score))
175
 
176
- result = {
177
  'success': True,
178
  'masks': masks_list,
179
  'scores': scores_list,
180
  'num_masks': len(masks_list),
181
  'box': box.tolist()
182
- }
183
-
184
- return json.dumps(result)
185
 
186
  except Exception as e:
187
  import traceback
188
- error_result = {
189
  'success': False,
190
  'error': str(e),
191
  'traceback': traceback.format_exc()
192
- }
193
- return json.dumps(error_result)
194
 
195
 
196
  def segment_simple(image, x, y, label=1, multimask=True):
197
- """
198
- Simple single-point segmentation interface for Gradio UI
199
-
200
- Args:
201
- image: PIL Image
202
- x: X coordinate
203
- y: Y coordinate
204
- label: 1 for foreground, 0 for background
205
- multimask: Whether to output multiple masks
206
-
207
- Returns:
208
- Mask image and score
209
- """
210
  try:
211
  points_json = json.dumps({
212
  "coords": [[int(x), int(y)]],
@@ -214,18 +449,16 @@ def segment_simple(image, x, y, label=1, multimask=True):
214
  "multimask_output": multimask
215
  })
216
 
217
- result_json = segment_with_points(image, points_json)
218
  result = json.loads(result_json)
219
 
220
  if not result['success']:
221
  return None, f"Error: {result['error']}"
222
 
223
- # Get best mask (highest score)
224
  best_idx = np.argmax(result['scores'])
225
  best_mask_base64 = result['masks'][best_idx]['mask_base64']
226
  best_score = result['scores'][best_idx]
227
 
228
- # Decode mask
229
  mask_bytes = base64.b64decode(best_mask_base64)
230
  mask_image = Image.open(io.BytesIO(mask_bytes))
231
 
@@ -235,115 +468,209 @@ def segment_simple(image, x, y, label=1, multimask=True):
235
  return None, f"Error: {str(e)}"
236
 
237
 
238
- # Create Gradio interface with two tabs
 
 
 
239
  with gr.Blocks(title="MedSAM Inference API") as demo:
240
  gr.Markdown("# 🏥 MedSAM Inference API")
241
- gr.Markdown("Point-based segmentation using Fine-Tuned MedSAM")
 
242
 
243
  with gr.Tabs():
244
- # Tab 1: API Interface (for programmatic access)
245
- with gr.Tab("API Interface"):
246
  gr.Markdown("""
247
- ## JSON API for Programmatic Access
 
 
 
 
248
 
249
  **Input Format:**
250
  ```json
251
  {
252
- "coords": [[x1, y1], [x2, y2]],
253
- "labels": [1, 0],
254
- "multimask_output": true
255
  }
256
  ```
257
 
258
- **Output Format:**
259
  ```json
260
  {
261
  "success": true,
262
- "masks": [...],
263
- "scores": [0.95, 0.88, 0.76],
264
- "num_masks": 3
265
  }
266
  ```
267
  """)
268
 
269
  with gr.Row():
270
  with gr.Column():
271
- api_image = gr.Image(type="pil", label="Input Image")
272
- api_points = gr.Textbox(
273
- label="Points JSON",
274
- placeholder='{"coords": [[100, 150]], "labels": [1], "multimask_output": true}',
275
  lines=3
276
  )
277
- api_button = gr.Button("Run Segmentation", variant="primary")
278
 
279
  with gr.Column():
280
- api_output = gr.Textbox(label="Result JSON", lines=15)
281
 
282
- api_button.click(
283
- fn=segment_with_points,
284
- inputs=[api_image, api_points],
285
- outputs=api_output
 
286
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
- # Example
289
- gr.Examples(
290
- examples=[
291
- [
292
- "example_image.jpg",
293
- '{"coords": [[200, 200]], "labels": [1], "multimask_output": true}'
294
- ]
295
- ],
296
- inputs=[api_image, api_points]
297
  )
298
 
299
- # Tab 2: Box-based Segmentation
300
- with gr.Tab("Box Segmentation"):
301
  gr.Markdown("""
302
- ## Box-based Segmentation
303
 
304
- Segment using a bounding box (rectangle).
305
 
306
  **Input Format:**
307
  ```json
308
  {
309
- "box": [x1, y1, x2, y2],
310
- "multimask_output": false
 
 
 
 
 
 
 
 
 
311
  }
312
  ```
313
- Where (x1, y1) is top-left corner and (x2, y2) is bottom-right corner.
314
  """)
315
 
316
  with gr.Row():
317
  with gr.Column():
318
  box_image = gr.Image(type="pil", label="Input Image")
319
- box_json = gr.Textbox(
320
- label="Box JSON",
321
- placeholder='{"box": [100, 100, 300, 300], "multimask_output": false}',
322
  lines=3
323
  )
324
- box_button = gr.Button("Run Box Segmentation", variant="primary")
325
 
326
  with gr.Column():
327
  box_output = gr.Textbox(label="Result JSON", lines=15)
328
 
329
  box_button.click(
330
- fn=segment_with_box,
331
- inputs=[box_image, box_json],
332
- outputs=box_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  )
334
 
335
- # Example
336
- gr.Examples(
337
- examples=[
338
- [
339
- "example_image.jpg",
340
- '{"box": [150, 150, 350, 350], "multimask_output": false}'
341
- ]
342
- ],
343
- inputs=[box_image, box_json]
 
 
 
 
 
 
 
 
 
 
 
344
  )
345
 
346
- # Tab 3: Simple UI Interface
347
  with gr.Tab("Simple Interface"):
348
  gr.Markdown("## Click-based Segmentation")
349
  gr.Markdown("Enter X, Y coordinates to segment")
@@ -378,41 +705,39 @@ with gr.Blocks(title="MedSAM Inference API") as demo:
378
 
379
  gr.Markdown("""
380
  ---
381
- ### 📡 API Usage from Python
382
 
383
  ```python
384
- import requests
385
  import json
386
- import base64
387
- from PIL import Image
388
 
389
- # Your Space URL
390
- API_URL = "https://YOUR_USERNAME-medsam-inference.hf.space/api/predict"
391
 
392
- # Prepare image
393
- with open("image.jpg", "rb") as f:
394
- img_base64 = base64.b64encode(f.read()).decode()
395
-
396
- # Prepare points
397
- points_json = json.dumps({
398
- "coords": [[150, 200]],
399
- "labels": [1],
400
- "multimask_output": True
401
- })
402
 
403
- # Call API
404
- response = requests.post(
405
- API_URL,
406
- json={
407
- "data": [
408
- f"data:image/jpeg;base64,{img_base64}",
409
- points_json
410
- ]
411
- }
412
  )
413
 
414
- result = response.json()
415
- print(result)
 
 
 
 
416
  ```
417
  """)
418
 
@@ -422,5 +747,5 @@ if __name__ == "__main__":
422
  server_name="0.0.0.0",
423
  server_port=7860,
424
  share=False,
425
- show_error=True # Enable verbose error reporting
426
  )
 
1
  """
2
+ HuggingFace Space for MedSAM Inference
3
+ API-compatible with Dense-Captioning-Toolkit backend
4
+
5
  Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference
6
  """
7
  import gradio as gr
 
35
  torch.load = patched_torch_load
36
 
37
  try:
38
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
39
+ sam.to(device=device)
40
+ sam.eval()
41
+ predictor = SamPredictor(sam)
42
+ print("✓ MedSAM model loaded successfully!")
43
  finally:
44
  # Restore original torch.load
45
  torch.load = original_torch_load
46
 
47
 
48
+ # =============================================================================
49
+ # API FUNCTIONS - MATCHING BACKEND FORMAT (backend/app.py)
50
+ # =============================================================================
51
+
52
def segment_points(image, request_json):
    """
    Point-prompted segmentation, API-compatible with the backend's
    /api/medsam/segment_points endpoint.

    Every point is expanded into a small bounding box centred on it and
    segmented independently, mirroring the backend behaviour.

    Args:
        image: PIL Image to segment.
        request_json: JSON string of the form
            {"points": [[x1, y1], [x2, y2], ...], "labels": [1, 0, ...]}
            (labels are accepted for API compatibility but not used here).

    Returns:
        JSON string matching the backend response format:
            {"success": true,
             "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
             "confidences": [0.95, ...],
             "method": "medsam_points_individual"}
        or {"success": false, "error": "..."} on failure.
    """
    try:
        request = json.loads(request_json)
        points = request.get("points", [])
        _labels = request.get("labels", [])  # parsed for API parity; not used by this path

        if not points:
            return json.dumps({'success': False, 'error': 'At least one point is required'})

        frame = np.array(image)
        height, width = frame.shape[:2]
        predictor.set_image(frame)

        box_size = 20  # side length of the per-point pseudo-box (matches backend)
        half = box_size // 2
        collected_masks = []
        collected_confidences = []

        for idx, (x, y) in enumerate(points, start=1):
            # Build a point-centred box, clamped to the image bounds.
            bbox = np.array([
                max(0, x - half),
                max(0, y - half),
                min(width - 1, x + half),
                min(height - 1, y + half),
            ])

            print(f"Processing point {idx}/{len(points)}: ({x}, {y}) -> bbox: {bbox.tolist()}")

            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=bbox,
                multimask_output=False
            )

            if len(masks) == 0:
                print(f"Point {idx} segmentation failed")
                continue

            best = int(np.argmax(scores))
            confidence = float(scores[best])
            collected_masks.append({
                'mask': masks[best].astype(np.uint8).tolist(),
                'confidence': confidence
            })
            collected_confidences.append(confidence)
            print(f"Point {idx} segmentation successful, confidence: {confidence:.4f}")

        if collected_masks:
            payload = {
                'success': True,
                'masks': collected_masks,
                'confidences': collected_confidences,
                'method': 'medsam_points_individual'
            }
        else:
            payload = {'success': False, 'error': 'All point segmentations failed'}

        return json.dumps(payload)

    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
151
+
152
+
153
def segment_box(image, request_json):
    """
    Single-box segmentation, API-compatible with the backend's
    /api/medsam/segment_box endpoint.

    Args:
        image: PIL Image to segment.
        request_json: JSON string of the form
            {"bbox": [x1, y1, x2, y2]}
            or the object form {"bbox": {"x1": ..., "y1": ..., "x2": ..., "y2": ...}}.

    Returns:
        JSON string matching the backend response format:
            {"success": true,
             "mask": [[...]],
             "confidence": 0.95,
             "method": "medsam_box"}
        or {"success": false, "error": "..."} on failure.
    """
    try:
        data = json.loads(request_json)
        bbox = data.get("bbox", [])

        # Handle both array format [x1, y1, x2, y2] and object format {x1, y1, x2, y2}
        if isinstance(bbox, dict):
            bbox = [bbox.get('x1', 0), bbox.get('y1', 0), bbox.get('x2', 0), bbox.get('y2', 0)]

        if not bbox or len(bbox) != 4:
            return json.dumps({'success': False, 'error': 'Valid bounding box required [x1, y1, x2, y2]'})

        # Fix: an object-format bbox with missing keys used to fall through as
        # [0, 0, 0, 0] and get sent to the model. Reject degenerate / inverted
        # boxes up front with a clear error instead.
        x1, y1, x2, y2 = bbox
        if x2 <= x1 or y2 <= y1:
            return json.dumps({
                'success': False,
                'error': f'Degenerate bounding box {bbox}: require x2 > x1 and y2 > y1'
            })

        box = np.array(bbox)

        # Convert PIL to numpy and embed the image once.
        image_array = np.array(image)
        predictor.set_image(image_array)

        # Box prompts conventionally use a single mask output.
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box,
            multimask_output=False
        )

        if len(masks) > 0:
            best_idx = int(np.argmax(scores))
            mask = masks[best_idx]
            score = float(scores[best_idx])

            result = {
                'success': True,
                'mask': mask.astype(np.uint8).tolist(),
                'confidence': score,
                'method': 'medsam_box'
            }
        else:
            result = {'success': False, 'error': 'Segmentation failed'}

        return json.dumps(result)

    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
224
+
225
+
226
def segment_multiple_boxes(image, request_json):
    """
    Multi-box segmentation, API-compatible with the backend's
    /api/medsam/segment_multiple_boxes endpoint (the main frontend API).

    Each bounding box is segmented independently against the same embedded
    image.

    Args:
        image: PIL Image to segment.
        request_json: JSON string of the form
            {"bboxes": [[x1, y1, x2, y2],
                        {"x1": 10, "y1": 20, "x2": 100, "y2": 200}]}
            — array and object box formats may be mixed.

    Returns:
        JSON string matching the backend response format:
            {"success": true,
             "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
             "confidences": [0.95, ...],
             "method": "medsam_multiple_boxes"}
        or {"success": false, "error": "..."} on failure.
    """
    try:
        payload = json.loads(request_json)
        bboxes = payload.get("bboxes", [])

        if not bboxes:
            return json.dumps({'success': False, 'error': 'At least one bounding box is required'})

        frame = np.array(image)
        predictor.set_image(frame)

        print(f"Processing {len(bboxes)} boxes for segmentation")

        segment_masks = []
        segment_confidences = []

        for idx, raw_box in enumerate(bboxes, start=1):
            # Accept both [x1, y1, x2, y2] arrays and {x1, y1, x2, y2} objects.
            if isinstance(raw_box, dict):
                box = np.array([raw_box.get(key, 0) for key in ('x1', 'y1', 'x2', 'y2')])
            else:
                box = np.array(raw_box)

            print(f"Processing box {idx}/{len(bboxes)}: {box.tolist()}")

            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=box,
                multimask_output=False
            )

            if len(masks) == 0:
                print(f"Box {idx} segmentation failed")
                continue

            best = int(np.argmax(scores))
            confidence = float(scores[best])
            segment_masks.append({
                'mask': masks[best].astype(np.uint8).tolist(),
                'confidence': confidence
            })
            segment_confidences.append(confidence)
            print(f"Box {idx} segmentation successful, confidence: {confidence:.4f}")

        if segment_masks:
            result = {
                'success': True,
                'masks': segment_masks,
                'confidences': segment_confidences,
                'method': 'medsam_multiple_boxes'
            }
        else:
            result = {'success': False, 'error': 'All segmentations failed'}

        return json.dumps(result)

    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
325
+
326
+
327
+ # =============================================================================
328
+ # LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts)
329
+ # =============================================================================
330
+
331
+ def segment_with_points_legacy(image, points_json):
332
+ """
333
+ Legacy API - Segment with point prompts using true point-based segmentation
334
+
335
+ Args:
336
+ points_json: JSON string with format:
337
+ {
338
+ "coords": [[x1, y1], [x2, y2], ...],
339
+ "labels": [1, 0, ...],
340
+ "multimask_output": true/false
341
+ }
342
+ """
343
+ try:
344
  points_data = json.loads(points_json)
345
  coords = np.array(points_data["coords"])
346
  labels = np.array(points_data["labels"])
347
  multimask_output = points_data.get("multimask_output", True)
348
 
 
349
  image_array = np.array(image)
 
 
350
  predictor.set_image(image_array)
351
 
 
352
  masks, scores, logits = predictor.predict(
353
  point_coords=coords,
354
  point_labels=labels,
355
  multimask_output=multimask_output
356
  )
357
 
 
358
  masks_list = []
359
  scores_list = []
360
 
361
  for i, (mask, score) in enumerate(zip(masks, scores)):
 
362
  mask_uint8 = (mask * 255).astype(np.uint8)
 
 
363
  mask_image = Image.fromarray(mask_uint8)
364
  buffer = io.BytesIO()
365
  mask_image.save(buffer, format='PNG')
 
368
  masks_list.append({
369
  'mask_base64': mask_base64,
370
  'mask_shape': mask.shape,
371
+ 'mask_data': mask.tolist()
372
  })
373
  scores_list.append(float(score))
374
 
375
+ return json.dumps({
376
  'success': True,
377
  'masks': masks_list,
378
  'scores': scores_list,
379
  'num_masks': len(masks_list)
380
+ })
 
 
381
 
382
  except Exception as e:
383
+ return json.dumps({'success': False, 'error': str(e)})
 
 
 
 
384
 
385
 
386
+ def segment_with_box_legacy(image, box_json):
387
  """
388
+ Legacy API - Segment with box prompt
389
 
390
  Args:
 
391
  box_json: JSON string with format:
392
+ {"box": [x1, y1, x2, y2], "multimask_output": false}
 
 
 
 
 
 
393
  """
394
  try:
 
395
  box_data = json.loads(box_json)
396
+ box = np.array(box_data["box"])
397
+ multimask_output = box_data.get("multimask_output", False)
398
 
 
399
  image_array = np.array(image)
 
 
400
  predictor.set_image(image_array)
401
 
 
402
  masks, scores, logits = predictor.predict(
403
  point_coords=None,
404
  point_labels=None,
 
406
  multimask_output=multimask_output
407
  )
408
 
 
409
  masks_list = []
410
  scores_list = []
411
 
412
  for i, (mask, score) in enumerate(zip(masks, scores)):
 
413
  mask_uint8 = (mask * 255).astype(np.uint8)
 
 
414
  mask_image = Image.fromarray(mask_uint8)
415
  buffer = io.BytesIO()
416
  mask_image.save(buffer, format='PNG')
 
423
  })
424
  scores_list.append(float(score))
425
 
426
+ return json.dumps({
427
  'success': True,
428
  'masks': masks_list,
429
  'scores': scores_list,
430
  'num_masks': len(masks_list),
431
  'box': box.tolist()
432
+ })
 
 
433
 
434
  except Exception as e:
435
  import traceback
436
+ return json.dumps({
437
  'success': False,
438
  'error': str(e),
439
  'traceback': traceback.format_exc()
440
+ })
 
441
 
442
 
443
  def segment_simple(image, x, y, label=1, multimask=True):
444
+ """Simple single-point segmentation for Gradio UI"""
 
 
 
 
 
 
 
 
 
 
 
 
445
  try:
446
  points_json = json.dumps({
447
  "coords": [[int(x), int(y)]],
 
449
  "multimask_output": multimask
450
  })
451
 
452
+ result_json = segment_with_points_legacy(image, points_json)
453
  result = json.loads(result_json)
454
 
455
  if not result['success']:
456
  return None, f"Error: {result['error']}"
457
 
 
458
  best_idx = np.argmax(result['scores'])
459
  best_mask_base64 = result['masks'][best_idx]['mask_base64']
460
  best_score = result['scores'][best_idx]
461
 
 
462
  mask_bytes = base64.b64decode(best_mask_base64)
463
  mask_image = Image.open(io.BytesIO(mask_bytes))
464
 
 
468
  return None, f"Error: {str(e)}"
469
 
470
 
471
+ # =============================================================================
472
+ # GRADIO INTERFACE
473
+ # =============================================================================
474
+
475
  with gr.Blocks(title="MedSAM Inference API") as demo:
476
  gr.Markdown("# 🏥 MedSAM Inference API")
477
+ gr.Markdown("Point and box-based segmentation using Fine-Tuned MedSAM")
478
+ gr.Markdown("**API-compatible with Dense-Captioning-Toolkit backend**")
479
 
480
  with gr.Tabs():
481
+ # Tab 1: Backend-Compatible API (Points)
482
+ with gr.Tab("Segment Points (Backend API)"):
483
  gr.Markdown("""
484
+ ## Point-based Segmentation - Backend Compatible
485
+
486
+ **Matches `/api/medsam/segment_points`**
487
+
488
+ Each point is converted to a small bounding box for segmentation.
489
 
490
  **Input Format:**
491
  ```json
492
  {
493
+ "points": [[x1, y1], [x2, y2], ...],
494
+ "labels": [1, 0, ...]
 
495
  }
496
  ```
497
 
498
+ **Output Format (matches backend):**
499
  ```json
500
  {
501
  "success": true,
502
+ "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
503
+ "confidences": [0.95, ...],
504
+ "method": "medsam_points_individual"
505
  }
506
  ```
507
  """)
508
 
509
  with gr.Row():
510
  with gr.Column():
511
+ points_image = gr.Image(type="pil", label="Input Image")
512
+ points_json_input = gr.Textbox(
513
+ label="Request JSON",
514
+ placeholder='{"points": [[100, 150], [200, 200]], "labels": [1, 1]}',
515
  lines=3
516
  )
517
+ points_button = gr.Button("Segment Points", variant="primary")
518
 
519
  with gr.Column():
520
+ points_output = gr.Textbox(label="Result JSON", lines=15)
521
 
522
+ points_button.click(
523
+ fn=segment_points,
524
+ inputs=[points_image, points_json_input],
525
+ outputs=points_output,
526
+ api_name="segment_points"
527
  )
528
+
529
+ # Tab 2: Backend-Compatible API (Multiple Boxes)
530
+ with gr.Tab("Segment Multiple Boxes (Backend API)"):
531
+ gr.Markdown("""
532
+ ## Multiple Box Segmentation - Backend Compatible
533
+
534
+ **Matches `/api/medsam/segment_multiple_boxes`** (main frontend API)
535
+
536
+ **Input Format:**
537
+ ```json
538
+ {
539
+ "bboxes": [
540
+ [x1, y1, x2, y2],
541
+ {"x1": 10, "y1": 20, "x2": 100, "y2": 200}
542
+ ]
543
+ }
544
+ ```
545
+
546
+ **Output Format (matches backend):**
547
+ ```json
548
+ {
549
+ "success": true,
550
+ "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
551
+ "confidences": [0.95, ...],
552
+ "method": "medsam_multiple_boxes"
553
+ }
554
+ ```
555
+ """)
556
+
557
+ with gr.Row():
558
+ with gr.Column():
559
+ multi_box_image = gr.Image(type="pil", label="Input Image")
560
+ multi_box_json = gr.Textbox(
561
+ label="Request JSON",
562
+ placeholder='{"bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]}',
563
+ lines=3
564
+ )
565
+ multi_box_button = gr.Button("Segment Multiple Boxes", variant="primary")
566
+
567
+ with gr.Column():
568
+ multi_box_output = gr.Textbox(label="Result JSON", lines=15)
569
 
570
+ multi_box_button.click(
571
+ fn=segment_multiple_boxes,
572
+ inputs=[multi_box_image, multi_box_json],
573
+ outputs=multi_box_output,
574
+ api_name="segment_multiple_boxes"
 
 
 
 
575
  )
576
 
577
+ # Tab 3: Backend-Compatible API (Single Box)
578
+ with gr.Tab("Segment Box (Backend API)"):
579
  gr.Markdown("""
580
+ ## Single Box Segmentation - Backend Compatible
581
 
582
+ **Matches `/api/medsam/segment_box`**
583
 
584
  **Input Format:**
585
  ```json
586
  {
587
+ "bbox": [x1, y1, x2, y2]
588
+ }
589
+ ```
590
+
591
+ **Output Format (matches backend):**
592
+ ```json
593
+ {
594
+ "success": true,
595
+ "mask": [[...]],
596
+ "confidence": 0.95,
597
+ "method": "medsam_box"
598
  }
599
  ```
 
600
  """)
601
 
602
  with gr.Row():
603
  with gr.Column():
604
  box_image = gr.Image(type="pil", label="Input Image")
605
+ box_json_input = gr.Textbox(
606
+ label="Request JSON",
607
+ placeholder='{"bbox": [100, 100, 300, 300]}',
608
  lines=3
609
  )
610
+ box_button = gr.Button("Segment Box", variant="primary")
611
 
612
  with gr.Column():
613
  box_output = gr.Textbox(label="Result JSON", lines=15)
614
 
615
  box_button.click(
616
+ fn=segment_box,
617
+ inputs=[box_image, box_json_input],
618
+ outputs=box_output,
619
+ api_name="segment_box"
620
+ )
621
+
622
+ # Tab 4: Legacy API (for test scripts)
623
+ with gr.Tab("Legacy API"):
624
+ gr.Markdown("""
625
+ ## Legacy API (for backwards compatibility)
626
+
627
+ Original API format with `coords`, `mask_data`, `scores`, etc.
628
+ Use if you have existing scripts using the old format.
629
+ """)
630
+
631
+ with gr.Row():
632
+ with gr.Column():
633
+ legacy_image = gr.Image(type="pil", label="Input Image")
634
+ legacy_points = gr.Textbox(
635
+ label="Points JSON (Legacy Format)",
636
+ placeholder='{"coords": [[100, 150]], "labels": [1], "multimask_output": true}',
637
+ lines=3
638
+ )
639
+ legacy_button = gr.Button("Run Segmentation (Legacy)", variant="secondary")
640
+
641
+ with gr.Column():
642
+ legacy_output = gr.Textbox(label="Result JSON", lines=15)
643
+
644
+ legacy_button.click(
645
+ fn=segment_with_points_legacy,
646
+ inputs=[legacy_image, legacy_points],
647
+ outputs=legacy_output,
648
+ api_name="segment_with_points" # Keep old API name for compatibility
649
  )
650
 
651
+ gr.Markdown("---")
652
+
653
+ with gr.Row():
654
+ with gr.Column():
655
+ legacy_box_image = gr.Image(type="pil", label="Input Image")
656
+ legacy_box_json = gr.Textbox(
657
+ label="Box JSON (Legacy Format)",
658
+ placeholder='{"box": [100, 100, 300, 300], "multimask_output": false}',
659
+ lines=3
660
+ )
661
+ legacy_box_button = gr.Button("Run Box Segmentation (Legacy)", variant="secondary")
662
+
663
+ with gr.Column():
664
+ legacy_box_output = gr.Textbox(label="Result JSON", lines=15)
665
+
666
+ legacy_box_button.click(
667
+ fn=segment_with_box_legacy,
668
+ inputs=[legacy_box_image, legacy_box_json],
669
+ outputs=legacy_box_output,
670
+ api_name="segment_with_box" # Keep old API name for compatibility
671
  )
672
 
673
+ # Tab 5: Simple UI Interface
674
  with gr.Tab("Simple Interface"):
675
  gr.Markdown("## Click-based Segmentation")
676
  gr.Markdown("Enter X, Y coordinates to segment")
 
705
 
706
  gr.Markdown("""
707
  ---
708
+ ### 📡 API Usage from Python (Backend-Compatible)
709
 
710
  ```python
711
+ from gradio_client import Client, handle_file
712
  import json
 
 
713
 
714
+ client = Client("Aniketg6/medsam-inference")
 
715
 
716
+ # Point-based segmentation (matches backend format)
717
+ result = client.predict(
718
+ image=handle_file("image.jpg"),
719
+ request_json=json.dumps({
720
+ "points": [[150, 200], [300, 400]],
721
+ "labels": [1, 1]
722
+ }),
723
+ api_name="/segment_points"
724
+ )
 
725
 
726
+ # Multiple box segmentation (main frontend API)
727
+ result = client.predict(
728
+ image=handle_file("image.jpg"),
729
+ request_json=json.dumps({
730
+ "bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]
731
+ }),
732
+ api_name="/segment_multiple_boxes"
 
 
733
  )
734
 
735
+ # Parse response
736
+ data = json.loads(result)
737
+ print(f"Success: {data['success']}")
738
+ print(f"Masks: {len(data['masks'])}")
739
+ print(f"Confidences: {data['confidences']}")
740
+ print(f"Method: {data['method']}")
741
  ```
742
  """)
743
 
 
747
  server_name="0.0.0.0",
748
  server_port=7860,
749
  share=False,
750
+ show_error=True
751
  )