# Anigor66
# Medsam added
# 0ee959c
"""
HuggingFace Space for SAM / MedSAM Inference
API-compatible with Dense-Captioning-Toolkit backend
Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import io
import json
import base64
import os
import uuid
from huggingface_hub import hf_hub_download
# Import SAM components
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
# Initialize model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# -----------------------------------------------------------------------------
# Model configuration
# -----------------------------------------------------------------------------
# 1) MedSAM (ViT-B) for interactive segmentation (points / boxes / multiple boxes)
#    We assume medsam_vit_b.pth is committed in this repo (small enough for Spaces).
MEDSAM_CHECKPOINT = os.path.join(os.path.dirname(__file__), "medsam_vit_b.pth")


def _load_medsam(checkpoint_path):
    """Build a ViT-B SAM backbone, load MedSAM weights from *checkpoint_path*, return it in eval mode.

    The state dict is read onto the CPU and only the finished model is moved to
    ``device``. The state dict is local to this function, so no second copy of
    the weights stays referenced (on the GPU or elsewhere) after loading.
    """
    # Assumes the checkpoint is a plain state_dict for the vit_b architecture;
    # load_state_dict is strict, so a mismatch raises.
    state = torch.load(checkpoint_path, map_location="cpu")
    model = sam_model_registry["vit_b"](checkpoint=None)
    model.load_state_dict(state)
    model.to(device=device)
    model.eval()
    return model


print("Loading MedSAM model (vit_b) for interactive segmentation...")
try:
    medsam = _load_medsam(MEDSAM_CHECKPOINT)
    print("✓ MedSAM model (vit_b) loaded successfully")
except Exception as e:
    print(f"✗ Failed to load MedSAM model from {MEDSAM_CHECKPOINT}: {e}")
    raise

# SamPredictor for interactive segmentation (point/box prompts) using MedSAM
predictor = SamPredictor(medsam)
print("✓ SamPredictor (MedSAM) initialized for interactive segmentation")
# 2) SAM ViT-H for automatic mask generation and embedding (encode_image)
#    We download this large checkpoint from a separate model repo using hf_hub_download.
MODEL_REPO_ID = "Aniketg6/dense-captioning-models"
MODEL_FILENAME = "sam_vit_h_4b8939.pth"  # change if your filename is different
MODEL_TYPE = "vit_h"  # using SAM ViT-H (general-purpose SAM)

print(f"Downloading SAM (vit_h) checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...")
SAM_CHECKPOINT = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename=MODEL_FILENAME,
)
print(f"✓ SAM (vit_h) checkpoint downloaded to: {SAM_CHECKPOINT}")

print("Loading SAM model (vit_h) for auto masks and embeddings...")
# Build the architecture without weights and load the checkpoint explicitly
# onto the CPU, then move the model to `device`. This works on CPU-only hosts
# without temporarily replacing torch.load process-wide.
sam = sam_model_registry[MODEL_TYPE](checkpoint=None)
sam.load_state_dict(torch.load(SAM_CHECKPOINT, map_location="cpu"))
sam.to(device=device)
sam.eval()
print("✓ SAM model (vit_h) loaded successfully")

# SamAutomaticMaskGenerator for automatic mask generation (SAM ViT-H)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Lighter grid (16x16) for faster CPU + smaller responses
    pred_iou_thresh=0.7,  # IoU threshold for filtering
    stability_score_thresh=0.7,  # Stability threshold
    crop_n_layers=0,  # Disable multi-scale crops to avoid IndexError
    crop_n_points_downscale_factor=2,
    min_mask_region_area=0,  # Allow small masks (backend can filter if needed)
)
print("✓ SamAutomaticMaskGenerator (SAM vit-h) initialized for automatic segmentation")
# =============================================================================
# HELPER FUNCTIONS FOR EMBEDDINGS (STATELESS)
# =============================================================================
def set_predictor_features_from_embedding(embedding_tensor: torch.Tensor, image_shape: tuple):
    """
    Prime the shared MedSAM ``predictor`` with a precomputed image embedding.

    After this call, ``predictor.predict(...)`` can run without ``set_image``.

    Args:
        embedding_tensor: Precomputed embedding as ``[1, C, H, W]`` or
            ``[C, H, W]`` (the layout ``encode_image`` serializes). A NumPy
            array is accepted too. The tensor is moved to ``device``.
        image_shape: Original image size as ``(height, width)``. The predictor
            uses it to map prompt coordinates and to resize masks back.

    NOTE(review): ``input_size`` is fixed to (1024, 1024), which corresponds to
    a square 1024x1024 resize. SamPredictor maps prompt coordinates with its
    own longest-side transform, so the two disagree for non-square images.
    Confirm how the embedding was produced before relying on this path.
    """
    if isinstance(embedding_tensor, np.ndarray):
        embedding_tensor = torch.from_numpy(embedding_tensor)
    if embedding_tensor.dim() == 3:
        # Restore the batch dimension the mask decoder expects.
        embedding_tensor = embedding_tensor.unsqueeze(0)

    # SamPredictor has no public setter for features, so populate the same
    # attributes that SamPredictor.set_torch_image fills in.
    predictor.features = embedding_tensor.to(device)
    # SamPredictor.predict reads `original_size` (not `original_image_size`)
    # to transform prompts and to postprocess masks.
    predictor.original_size = tuple(image_shape)
    predictor.input_size = (1024, 1024)  # SAM default input size
    predictor.is_image_set = True
# =============================================================================
# API FUNCTIONS - MATCHING BACKEND FORMAT (backend/app.py)
# =============================================================================
def encode_image(image, request_json):
    """
    Encode an image with the SAM image encoder and return the embedding to the client.

    This is a stateless API: it does NOT talk to Supabase. The caller
    (your backend) is responsible for storing the embedding if desired.

    Args:
        image: PIL Image. A single-channel image is replicated to three
            channels, and a fourth (alpha) channel is dropped before encoding.
        request_json: JSON string with optional fields (empty or
            whitespace-only input is treated as ``{}``):
            {
                "image_id": "uuid-string"   # Optional: echoed back unchanged
            }

    Returns:
        JSON string:
            {
                "success": true/false,
                "image_id": "uuid-string" or null,
                "embedding_npy_base64": "...",  # base64-encoded .npy of the float32 [C, H, W] array
                "embedding_shape": [1, C, H, W]
            }
        On failure: {"success": false, "error": "..."}. Unexpected exceptions
        also include "traceback".
    """
    try:
        if image is None:
            return json.dumps({"success": False, "error": "An input image is required"})

        # image_id is optional and only echoed back.
        data = json.loads(request_json) if request_json and request_json.strip() else {}
        image_id = data.get("image_id")

        image_array = np.array(image)
        if image_array.ndim == 2:
            image_array = image_array[:, :, None]
        if image_array.shape[-1] == 1:
            # Grayscale: replicate into three channels (the encoder takes 3).
            image_array = np.repeat(image_array, 3, axis=-1)
        elif image_array.shape[-1] == 4:
            # RGBA: drop the alpha channel.
            image_array = image_array[:, :, :3]
        H, W = image_array.shape[:2]

        from skimage import transform

        # Resize to the encoder's 1024x1024 input. Bicubic resampling can
        # overshoot [0, 255]. Clip before the uint8 cast so those pixels
        # saturate instead of wrapping around.
        resized = transform.resize(
            image_array,
            (1024, 1024),
            order=3,
            preserve_range=True,
            anti_aliasing=True,
        )
        img_resized = np.clip(resized, 0, 255).astype(np.uint8)

        # Min-max normalize to [0, 1].
        # NOTE(review): this is MedSAM-style preprocessing, but the tensor is
        # fed to the general-purpose SAM ViT-H encoder (`sam`), not to
        # `medsam`. Vanilla SAM normalizes with its pixel mean/std
        # (`sam.preprocess`). Confirm which encoder/preprocessing pair the
        # stored embeddings are meant for.
        img_norm = (img_resized - img_resized.min()) / np.clip(
            img_resized.max() - img_resized.min(), 1e-8, None
        )

        # HWC float -> 1xCxHxW float32 on the model device
        tensor = (
            torch.tensor(img_norm)
            .float()
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device)
        )

        print(f"Encoding image (image_id={image_id}) original size: {W}x{H} -> 1024x1024")
        with torch.no_grad():
            embedding = sam.image_encoder(tensor)

        # Drop the batch dimension and serialize as an in-memory .npy file.
        arr = embedding.squeeze(0).cpu().numpy().astype(np.float32)
        buf = io.BytesIO()
        np.save(buf, arr)
        embedding_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        return json.dumps(
            {
                "success": True,
                "image_id": image_id,
                "embedding_npy_base64": embedding_b64,
                "embedding_shape": list(embedding.shape),
            }
        )
    except Exception as e:
        import traceback
        return json.dumps(
            {
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
        )
def segment_points(image, request_json):
    """
    Segment image with point prompts - MATCHES BACKEND /api/medsam/segment_points

    Each point gets its own segment. The point is turned into a small square
    bounding box centred on it (clipped to the image), and MedSAM is prompted
    with that box, matching the backend behaviour.

    Args:
        image: PIL Image
        request_json: JSON string with format:
            {
                "points": [[x1, y1], [x2, y2], ...],
                "labels": [1, 0, ...],   # accepted for compatibility; currently not used
                "box_size": 20           # Optional: side of the per-point box in pixels (default 20)
            }

    Returns:
        JSON string matching backend response format:
            {
                "success": true,
                "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
                "confidences": [0.95, ...],
                "method": "medsam_points_individual"
            }
        Validation failures return {"success": false, "error": "..."}.
        Unexpected exceptions also include "traceback".
    """
    try:
        data = json.loads(request_json)
        points = data.get("points", [])
        if not points:
            return json.dumps({'success': False, 'error': 'At least one point is required'})

        box_size = int(data.get("box_size", 20))
        if box_size <= 0:
            return json.dumps({'success': False, 'error': 'box_size must be a positive integer'})
        half = box_size // 2

        if image is None:
            return json.dumps({'success': False, 'error': 'An input image is required'})

        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Stateless: the embedding is recomputed from the image on every call.
        # NOTE(review): "labels" (foreground/background) are not used. Every
        # point is segmented via its own box. Confirm this matches the backend.
        predictor.set_image(image_array)

        masks_list = []
        confidences_list = []
        for i, (x, y) in enumerate(points):
            # Clamp the centre into the image first, so a point outside it still
            # yields a valid box (x1 <= x2, y1 <= y2) rather than an inverted one.
            # For in-image points this is a no-op.
            cx = min(max(x, 0), W - 1)
            cy = min(max(y, 0), H - 1)
            bbox = np.array([
                max(0, cx - half),
                max(0, cy - half),
                min(W - 1, cx + half),
                min(H - 1, cy + half),
            ])
            print(f"Processing point {i+1}/{len(points)}: ({x}, {y}) -> bbox: {bbox.tolist()}")

            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=bbox,
                multimask_output=False
            )

            if len(masks) > 0:
                # Take the best mask
                best_idx = np.argmax(scores)
                score = float(scores[best_idx])
                masks_list.append({
                    'mask': masks[best_idx].astype(np.uint8).tolist(),
                    'confidence': score
                })
                confidences_list.append(score)
                print(f"Point {i+1} segmentation successful, confidence: {score:.4f}")
            else:
                print(f"Point {i+1} segmentation failed")

        if masks_list:
            result = {
                'success': True,
                'masks': masks_list,
                'confidences': confidences_list,
                'method': 'medsam_points_individual'
            }
        else:
            result = {'success': False, 'error': 'All point segmentations failed'}
        return json.dumps(result)
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def segment_box(image, request_json):
    """
    Run MedSAM on one bounding-box prompt - MATCHES BACKEND /api/medsam/segment_box

    Args:
        image: PIL Image
        request_json: JSON string such as
            {"bbox": [x1, y1, x2, y2]}
            or {"bbox": {"x1": .., "y1": .., "x2": .., "y2": ..}}
            (missing object keys default to 0).

    Returns:
        JSON string, on success:
            {"success": true, "mask": [[...]], "confidence": 0.95, "method": "medsam_box"}
        A malformed box yields {"success": false, "error": "..."}. Unexpected
        exceptions also include "traceback".
    """
    try:
        payload = json.loads(request_json)
        raw_box = payload.get("bbox", [])

        # Normalise the object form into the list form.
        if isinstance(raw_box, dict):
            raw_box = [raw_box.get(key, 0) for key in ('x1', 'y1', 'x2', 'y2')]

        if not raw_box or len(raw_box) != 4:
            return json.dumps({'success': False, 'error': 'Valid bounding box required [x1, y1, x2, y2]'})

        prompt_box = np.array(raw_box)

        # Stateless: the embedding is recomputed from the image on every call.
        predictor.set_image(np.array(image))
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=prompt_box,
            multimask_output=False,
        )

        if len(masks) == 0:
            return json.dumps({'success': False, 'error': 'Segmentation failed'})

        top = np.argmax(scores)
        return json.dumps({
            'success': True,
            'mask': masks[top].astype(np.uint8).tolist(),
            'confidence': float(scores[top]),
            'method': 'medsam_box',
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc(),
        })
def segment_multiple_boxes(image, request_json):
    """
    Run MedSAM once per bounding box - MATCHES BACKEND /api/medsam/segment_multiple_boxes

    This is the main API endpoint the frontend uses for box-based segmentation.
    The image embedding is computed once and reused for every box.

    Args:
        image: PIL Image
        request_json: JSON string such as
            {
                "bboxes": [
                    [x1, y1, x2, y2],                            # list form
                    {"x1": 10, "y1": 20, "x2": 100, "y2": 200}   # object form (missing keys -> 0)
                ]
            }

    Returns:
        JSON string, on success:
            {
                "success": true,
                "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
                "confidences": [0.95, ...],
                "method": "medsam_multiple_boxes"
            }
        Otherwise {"success": false, "error": "..."}. Unexpected exceptions
        also include "traceback".
    """
    try:
        payload = json.loads(request_json)
        boxes = payload.get("bboxes", [])
        if not boxes:
            return json.dumps({'success': False, 'error': 'At least one bounding box is required'})

        # Stateless: compute the embedding from the image for this request.
        predictor.set_image(np.array(image))

        total = len(boxes)
        print(f"Processing {total} boxes for segmentation")

        masks_list, confidences_list = [], []
        for idx, raw in enumerate(boxes, start=1):
            if isinstance(raw, dict):
                coords = [raw.get(key, 0) for key in ('x1', 'y1', 'x2', 'y2')]
            else:
                coords = raw
            prompt_box = np.array(coords)
            print(f"Processing box {idx}/{total}: {prompt_box.tolist()}")

            masks, scores, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=prompt_box,
                multimask_output=False,
            )

            if len(masks) == 0:
                print(f"Box {idx} segmentation failed")
                continue

            top = np.argmax(scores)
            confidence = float(scores[top])
            masks_list.append({
                'mask': masks[top].astype(np.uint8).tolist(),
                'confidence': confidence,
            })
            confidences_list.append(confidence)
            print(f"Box {idx} segmentation successful, confidence: {confidence:.4f}")

        if not masks_list:
            return json.dumps({'success': False, 'error': 'All segmentations failed'})
        return json.dumps({
            'success': True,
            'masks': masks_list,
            'confidences': confidences_list,
            'method': 'medsam_multiple_boxes',
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc(),
        })
# =============================================================================
# AUTO MASK GENERATION API (replaces local mask_generator.generate())
# =============================================================================
# Request keys that may override the matching SamAutomaticMaskGenerator
# constructor argument for a single call. Each maps to the converter applied
# to the JSON value.
_AUTO_MASK_OVERRIDES = {
    "points_per_side": int,
    "pred_iou_thresh": float,
    "stability_score_thresh": float,
    "min_mask_region_area": int,
}


def _parse_auto_mask_params(request_json):
    """Decode the optional JSON parameters; absent, blank, malformed or non-object input yields {}."""
    if not isinstance(request_json, str) or not request_json.strip():
        return {}
    try:
        params = json.loads(request_json)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {}
    return params if isinstance(params, dict) else {}


def _auto_mask_generator_for(params):
    """Return the shared `mask_generator`, or a one-off generator on the same model if *params* override it."""
    overrides = {
        key: convert(params[key])
        for key, convert in _AUTO_MASK_OVERRIDES.items()
        if params.get(key) is not None
    }
    if not overrides:
        return mask_generator

    # Start from the shared generator's settings so only the requested values change.
    settings = {
        "pred_iou_thresh": mask_generator.pred_iou_thresh,
        "stability_score_thresh": mask_generator.stability_score_thresh,
        "min_mask_region_area": mask_generator.min_mask_region_area,
        "crop_n_layers": mask_generator.crop_n_layers,
        "crop_n_points_downscale_factor": mask_generator.crop_n_points_downscale_factor,
    }
    settings.update(overrides)
    if "points_per_side" not in overrides:
        # The library expects exactly one of points_per_side / point_grids;
        # reuse the shared generator's sampling grid.
        settings["points_per_side"] = None
        settings["point_grids"] = mask_generator.point_grids
    print(f"Using per-request auto-mask settings: {overrides}")
    return SamAutomaticMaskGenerator(model=sam, **settings)


def _serialize_auto_mask(m):
    """Convert one SamAutomaticMaskGenerator record into JSON-serializable types."""
    return {
        "segmentation": m["segmentation"].astype(np.uint8).tolist(),
        "area": int(m["area"]),
        "bbox": [int(x) for x in m["bbox"]],  # [x, y, width, height]
        "predicted_iou": float(m["predicted_iou"]),
        "point_coords": (
            [[int(p[0]), int(p[1])] for p in m["point_coords"]]
            if m["point_coords"] is not None
            else []
        ),
        "stability_score": float(m["stability_score"]),
        "crop_box": [int(x) for x in m["crop_box"]],  # [x, y, width, height]
    }


def generate_auto_masks(image, request_json):
    """
    Automatically generate all masks for an image using the SAM ViT-H model.

    This is equivalent to `mask_generator.generate(img_np)` in enhanced_preprocessing.py

    Args:
        image: PIL Image
        request_json: JSON string with optional parameters. Blank, malformed
            or non-object JSON falls back to the defaults:
            {
                "points_per_side": 16,          # Grid density (module default: 16)
                "pred_iou_thresh": 0.7,         # Predicted-IoU filter (module default: 0.7)
                "stability_score_thresh": 0.7,  # Stability filter (module default: 0.7)
                "min_mask_region_area": 0,      # Small-region cleanup area (module default: 0)
                "resize_longest": 0,            # Downscale so the longest side <= this (0 = off)
                "max_masks": 10                 # Keep top-K by predicted IoU (<= 0 = no cap)
            }
            The first four default to the settings of the module-level
            `mask_generator`. Supplying any of them builds a one-off generator
            on the same model for this call.

    Returns:
        JSON string with format matching SamAutomaticMaskGenerator output:
            {
                "success": true,
                "masks": [
                    {
                        "segmentation": [[...2D 0/1 array...]],
                        "area": 12345,
                        "bbox": [x, y, width, height],
                        "predicted_iou": 0.95,
                        "point_coords": [[x, y]],
                        "stability_score": 0.98,
                        "crop_box": [x, y, width, height]
                    },
                    ...
                ],
                "num_masks": 42,
                "image_size": [height, width]   # size the masks refer to (after resize_longest)
            }
        Failures return {"success": false, "error": "..."}. Unexpected
        exceptions also include "traceback".
    """
    try:
        if image is None:
            return json.dumps({'success': False, 'error': 'An input image is required'})
        if mask_generator is None:
            return json.dumps({
                'success': False,
                'error': 'SAM (vit_h) model not loaded. Please ensure the SAM checkpoint is available.',
                'available': False
            })

        params = _parse_auto_mask_params(request_json)
        generator = _auto_mask_generator_for(params)

        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Optional downscaling to keep masks smaller / faster
        resize_longest = int(params.get("resize_longest", 0) or 0)
        if resize_longest > 0 and max(H, W) > resize_longest:
            scale = resize_longest / float(max(H, W))
            new_w = max(1, int(W * scale))
            new_h = max(1, int(H * scale))
            print(f"Resizing image from {W}x{H} to {new_w}x{new_h} for auto masks...")
            image_array = np.array(Image.fromarray(image_array).resize((new_w, new_h)))
            H, W = image_array.shape[:2]

        print(f"Generating automatic masks for image of size {W}x{H}...")
        masks = generator.generate(image_array)
        print(f"Generated {len(masks)} masks")

        if masks:
            # Log some stats about the masks
            areas = [m['area'] for m in masks]
            ious = [m['predicted_iou'] for m in masks]
            stabilities = [m['stability_score'] for m in masks]
            print(f" Area range: {min(areas)} - {max(areas)} pixels")
            print(f" IoU range: {min(ious):.3f} - {max(ious):.3f}")
            print(f" Stability range: {min(stabilities):.3f} - {max(stabilities):.3f}")
        else:
            print(" WARNING: No masks generated! This could mean:")
            print(" - Image is too uniform/simple")
            print(" - Thresholds are still too strict")
            print(" - Image size is too small or too large")

        # Limit the number of masks returned to keep the JSON payload reasonable.
        # An explicit null falls back to the default instead of raising.
        max_masks = params.get("max_masks")
        max_masks = 10 if max_masks is None else int(max_masks)
        if max_masks > 0 and len(masks) > max_masks:
            print(f"Limiting masks from {len(masks)} to top {max_masks} by predicted_iou")
            masks = sorted(
                masks,
                key=lambda m: float(m.get("predicted_iou", 0.0)),
                reverse=True,
            )[:max_masks]

        print(f"Preparing {len(masks)} masks to return to client...")
        masks_output = [_serialize_auto_mask(m) for m in masks]

        result = {
            'success': True,
            'masks': masks_output,
            'num_masks': len(masks_output),
            'image_size': [H, W]
        }
        print(f"Auto mask generation complete: {len(masks_output)} masks")
        return json.dumps(result)
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def check_auto_mask_status():
    """
    Report whether automatic mask generation can be served.

    Returns:
        JSON string with keys ``available`` (bool), ``model`` (checkpoint
        filename, or null when no generator is loaded), ``model_type`` and
        ``device``.
    """
    generator = mask_generator
    status = {}
    status['available'] = generator is not None
    status['model'] = MODEL_FILENAME if generator else None
    status['model_type'] = MODEL_TYPE
    status['device'] = str(device)
    return json.dumps(status)
# =============================================================================
# LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts)
# =============================================================================
def segment_with_points_legacy(image, points_json):
    """
    Legacy API - prompt MedSAM directly with points (no point-to-box conversion).

    Args:
        image: PIL Image
        points_json: JSON string such as
            {
                "coords": [[x1, y1], [x2, y2], ...],
                "labels": [1, 0, ...],
                "multimask_output": true/false   # optional, defaults to true
            }

    Returns:
        JSON string with ``success``, ``masks`` (each holding ``mask_base64``
        PNG, ``mask_shape`` and raw ``mask_data``), ``scores`` and
        ``num_masks``. On error: {"success": false, "error": "..."}.
    """
    def encode(mask):
        # Render the boolean mask as a 0/255 PNG and base64 it.
        png = io.BytesIO()
        Image.fromarray((mask * 255).astype(np.uint8)).save(png, format='PNG')
        return {
            'mask_base64': base64.b64encode(png.getvalue()).decode('utf-8'),
            'mask_shape': mask.shape,
            'mask_data': mask.tolist(),
        }

    try:
        request = json.loads(points_json)
        point_coords = np.array(request["coords"])
        point_labels = np.array(request["labels"])
        want_multiple = request.get("multimask_output", True)

        predictor.set_image(np.array(image))
        masks, scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=want_multiple,
        )

        encoded = []
        score_values = []
        for mask, score in zip(masks, scores):
            encoded.append(encode(mask))
            score_values.append(float(score))

        return json.dumps({
            'success': True,
            'masks': encoded,
            'scores': score_values,
            'num_masks': len(encoded),
        })
    except Exception as e:
        return json.dumps({'success': False, 'error': str(e)})
def segment_with_box_legacy(image, box_json):
    """
    Legacy API - prompt MedSAM with a single box.

    Args:
        image: PIL Image
        box_json: JSON string such as
            {"box": [x1, y1, x2, y2], "multimask_output": false}
            (``multimask_output`` is optional and defaults to false).

    Returns:
        JSON string with ``success``, ``masks`` (each holding ``mask_base64``
        PNG, ``mask_shape`` and raw ``mask_data``), ``scores``, ``num_masks``
        and the echoed ``box``. On error: {"success": false, "error": "...",
        "traceback": "..."}.
    """
    try:
        request = json.loads(box_json)
        prompt_box = np.array(request["box"])
        want_multiple = request.get("multimask_output", False)

        predictor.set_image(np.array(image))
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=prompt_box,
            multimask_output=want_multiple,
        )

        encoded, score_values = [], []
        for mask, score in zip(masks, scores):
            # 0/255 PNG rendering of the boolean mask, base64-encoded.
            png = io.BytesIO()
            Image.fromarray((mask * 255).astype(np.uint8)).save(png, format='PNG')
            encoded.append({
                'mask_base64': base64.b64encode(png.getvalue()).decode('utf-8'),
                'mask_shape': mask.shape,
                'mask_data': mask.tolist(),
            })
            score_values.append(float(score))

        return json.dumps({
            'success': True,
            'masks': encoded,
            'scores': score_values,
            'num_masks': len(encoded),
            'box': prompt_box.tolist(),
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc(),
        })
def segment_simple(image, x, y, label=1, multimask=True):
    """
    Single-point segmentation for the Gradio "Simple Interface" tab.

    Delegates to ``segment_with_points_legacy`` and returns the highest-scoring
    mask.

    Returns:
        ``(PIL mask image, "Score: 0.xxxx")`` on success, or
        ``(None, "Error: ...")`` on any failure.
    """
    try:
        request = {
            "coords": [[int(x), int(y)]],
            "labels": [int(label)],
            "multimask_output": multimask,
        }
        response = json.loads(segment_with_points_legacy(image, json.dumps(request)))
        if not response['success']:
            return None, f"Error: {response['error']}"

        scores = response['scores']
        top = np.argmax(scores)
        png_bytes = base64.b64decode(response['masks'][top]['mask_base64'])
        return Image.open(io.BytesIO(png_bytes)), f"Score: {scores[top]:.4f}"
    except Exception as e:
        return None, f"Error: {e}"
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Build the Gradio UI. Each handler is registered with an `api_name`, so it is
# callable both from the web UI and via gradio_client (e.g. "/segment_points").
with gr.Blocks(title="MedSAM Inference API") as demo:
    gr.Markdown("# 🏥 MedSAM Inference API")
    gr.Markdown("Point and box-based segmentation using Fine-Tuned MedSAM")
    gr.Markdown("**API-compatible with Dense-Captioning-Toolkit backend**")

    with gr.Tabs():
        # Tab 1: Backend-Compatible API (Points)
        with gr.Tab("Segment Points (Backend API)"):
            gr.Markdown("""
            ## Point-based Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_points`**
            Each point is converted to a small bounding box for segmentation.
            **Input Format:**
            ```json
            {
            "points": [[x1, y1], [x2, y2], ...],
            "labels": [1, 0, ...]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
            "success": true,
            "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
            "confidences": [0.95, ...],
            "method": "medsam_points_individual"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    points_image = gr.Image(type="pil", label="Input Image")
                    points_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"points": [[100, 150], [200, 200]], "labels": [1, 1]}',
                        lines=3
                    )
                    points_button = gr.Button("Segment Points", variant="primary")
                with gr.Column():
                    points_output = gr.Textbox(label="Result JSON", lines=15)
            points_button.click(
                fn=segment_points,
                inputs=[points_image, points_json_input],
                outputs=points_output,
                api_name="segment_points"
            )

        # Tab 2: Backend-Compatible API (Multiple Boxes)
        with gr.Tab("Segment Multiple Boxes (Backend API)"):
            gr.Markdown("""
            ## Multiple Box Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_multiple_boxes`** (main frontend API)
            **Input Format:**
            ```json
            {
            "bboxes": [
            [x1, y1, x2, y2],
            {"x1": 10, "y1": 20, "x2": 100, "y2": 200}
            ]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
            "success": true,
            "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
            "confidences": [0.95, ...],
            "method": "medsam_multiple_boxes"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    multi_box_image = gr.Image(type="pil", label="Input Image")
                    multi_box_json = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]}',
                        lines=3
                    )
                    multi_box_button = gr.Button("Segment Multiple Boxes", variant="primary")
                with gr.Column():
                    multi_box_output = gr.Textbox(label="Result JSON", lines=15)
            multi_box_button.click(
                fn=segment_multiple_boxes,
                inputs=[multi_box_image, multi_box_json],
                outputs=multi_box_output,
                api_name="segment_multiple_boxes"
            )

        # Tab 3: Backend-Compatible API (Single Box)
        with gr.Tab("Segment Box (Backend API)"):
            gr.Markdown("""
            ## Single Box Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_box`**
            **Input Format:**
            ```json
            {
            "bbox": [x1, y1, x2, y2]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
            "success": true,
            "mask": [[...]],
            "confidence": 0.95,
            "method": "medsam_box"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    box_image = gr.Image(type="pil", label="Input Image")
                    box_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bbox": [100, 100, 300, 300]}',
                        lines=3
                    )
                    box_button = gr.Button("Segment Box", variant="primary")
                with gr.Column():
                    box_output = gr.Textbox(label="Result JSON", lines=15)
            box_button.click(
                fn=segment_box,
                inputs=[box_image, box_json_input],
                outputs=box_output,
                api_name="segment_box"
            )

        # Tab 4: Legacy API (for test scripts)
        with gr.Tab("Legacy API"):
            gr.Markdown("""
            ## Legacy API (for backwards compatibility)
            Original API format with `coords`, `mask_data`, `scores`, etc.
            Use if you have existing scripts using the old format.
            """)
            with gr.Row():
                with gr.Column():
                    legacy_image = gr.Image(type="pil", label="Input Image")
                    legacy_points = gr.Textbox(
                        label="Points JSON (Legacy Format)",
                        placeholder='{"coords": [[100, 150]], "labels": [1], "multimask_output": true}',
                        lines=3
                    )
                    legacy_button = gr.Button("Run Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_output = gr.Textbox(label="Result JSON", lines=15)
            legacy_button.click(
                fn=segment_with_points_legacy,
                inputs=[legacy_image, legacy_points],
                outputs=legacy_output,
                api_name="segment_with_points"  # Keep old API name for compatibility
            )

            gr.Markdown("---")

            with gr.Row():
                with gr.Column():
                    legacy_box_image = gr.Image(type="pil", label="Input Image")
                    legacy_box_json = gr.Textbox(
                        label="Box JSON (Legacy Format)",
                        placeholder='{"box": [100, 100, 300, 300], "multimask_output": false}',
                        lines=3
                    )
                    legacy_box_button = gr.Button("Run Box Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_box_output = gr.Textbox(label="Result JSON", lines=15)
            legacy_box_button.click(
                fn=segment_with_box_legacy,
                inputs=[legacy_box_image, legacy_box_json],
                outputs=legacy_box_output,
                api_name="segment_with_box"  # Keep old API name for compatibility
            )

        # Tab 5: Auto Mask Generation (for preprocessing)
        with gr.Tab("Auto Mask Generation"):
            gr.Markdown("""
            ## Automatic Mask Generation (SAM ViT-H)
            **Replaces `mask_generator.generate(img_np)` in preprocessing pipeline**

            Uses the general-purpose SAM ViT-H model with `SamAutomaticMaskGenerator`
            to automatically segment all objects in an image. This is used for initial
            preprocessing of scientific/medical images.

            Note: this is a different model from the MedSAM (ViT-B) checkpoint used by
            the interactive point/box tabs.

            **Optional parameters:**
            ```json
            {
              "max_masks": 10,
              "resize_longest": 0
            }
            ```
            - `max_masks`: keep only the top-K masks by predicted IoU (default 10; `<= 0` disables the cap)
            - `resize_longest`: downscale the image so its longest side is at most this many pixels
              before generating masks (default 0 = no resize); masks and `image_size` then refer
              to the resized image

            **Output Format:**
            ```json
            {
              "success": true,
              "masks": [
                {
                  "segmentation": [[...2D array...]],
                  "area": 12345,
                  "bbox": [x, y, width, height],
                  "predicted_iou": 0.95,
                  "point_coords": [[x, y]],
                  "stability_score": 0.98,
                  "crop_box": [x, y, width, height]
                }
              ],
              "num_masks": 42,
              "image_size": [height, width]
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    auto_image = gr.Image(type="pil", label="Input Image")
                    auto_params = gr.Textbox(
                        label="Parameters (optional)",
                        placeholder='{"max_masks": 10, "resize_longest": 1024}',
                        lines=2
                    )
                    with gr.Row():
                        auto_button = gr.Button("Generate All Masks", variant="primary")
                        status_button = gr.Button("Check Status", variant="secondary")
                with gr.Column():
                    auto_output = gr.Textbox(label="Result JSON", lines=20)
                    status_output = gr.Textbox(label="Status", lines=3)
            auto_button.click(
                fn=generate_auto_masks,
                inputs=[auto_image, auto_params],
                outputs=auto_output,
                api_name="generate_auto_masks"
            )
            status_button.click(
                fn=check_auto_mask_status,
                inputs=[],
                outputs=status_output,
                api_name="check_auto_mask_status"
            )

        # Tab 6: Encode Image (for embedding storage)
        with gr.Tab("Encode Image"):
            gr.Markdown("""
            ## Image Encoding API
            **Encodes an image with the SAM image encoder and returns the embedding**

            Stateless: nothing is stored by this Space. The caller (e.g. your backend)
            is responsible for persisting the returned embedding. Note that the
            segmentation endpoints in this Space currently recompute the embedding
            from the image on every call.

            **Input Format:**
            ```json
            {
              "image_id": "uuid-string"
            }
            ```
            `image_id` is optional and is echoed back unchanged.

            **Output Format:**
            ```json
            {
              "success": true,
              "image_id": "uuid-string",
              "embedding_npy_base64": "...",
              "embedding_shape": [1, 256, 64, 64]
            }
            ```
            `embedding_npy_base64` is a base64-encoded `.npy` file holding the float32
            embedding without its batch dimension (`[C, H, W]`).
            """)
            with gr.Row():
                with gr.Column():
                    encode_image_input = gr.Image(type="pil", label="Input Image")
                    encode_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"image_id": "123e4567-e89b-12d3-a456-426614174000"}',
                        lines=2
                    )
                    encode_button = gr.Button("Encode Image", variant="primary")
                with gr.Column():
                    encode_output = gr.Textbox(label="Result JSON", lines=10)
            encode_button.click(
                fn=encode_image,
                inputs=[encode_image_input, encode_json_input],
                outputs=encode_output,
                api_name="encode_image"
            )

        # Tab 7: Simple UI Interface
        with gr.Tab("Simple Interface"):
            gr.Markdown("## Click-based Segmentation")
            gr.Markdown("Enter X, Y coordinates to segment")
            with gr.Row():
                with gr.Column():
                    simple_image = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        simple_x = gr.Number(label="X Coordinate", value=100)
                        simple_y = gr.Number(label="Y Coordinate", value=100)
                    with gr.Row():
                        simple_label = gr.Radio(
                            choices=[1, 0],
                            value=1,
                            label="Point Label (1=foreground, 0=background)"
                        )
                        simple_multimask = gr.Checkbox(
                            label="Multiple Masks",
                            value=True
                        )
                    simple_button = gr.Button("Segment", variant="primary")
                with gr.Column():
                    simple_mask = gr.Image(label="Output Mask")
                    simple_info = gr.Textbox(label="Info")
            simple_button.click(
                fn=segment_simple,
                inputs=[simple_image, simple_x, simple_y, simple_label, simple_multimask],
                outputs=[simple_mask, simple_info]
            )

    gr.Markdown("""
    ---
    ### 📡 API Usage from Python (Backend-Compatible)
    ```python
    from gradio_client import Client, handle_file
    import json
    client = Client("Aniketg6/medsam-inference")
    # Point-based segmentation (matches backend format)
    result = client.predict(
    image=handle_file("image.jpg"),
    request_json=json.dumps({
    "points": [[150, 200], [300, 400]],
    "labels": [1, 1]
    }),
    api_name="/segment_points"
    )
    # Multiple box segmentation (main frontend API)
    result = client.predict(
    image=handle_file("image.jpg"),
    request_json=json.dumps({
    "bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]
    }),
    api_name="/segment_multiple_boxes"
    )
    # Parse response
    data = json.loads(result)
    print(f"Success: {data['success']}")
    print(f"Masks: {len(data['masks'])}")
    print(f"Confidences: {data['confidences']}")
    print(f"Method: {data['method']}")
    ```
    """)
# Launch the Gradio server when run as a script (e.g. `python app.py` on Spaces).
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces (required inside the Space container)
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)