""" HuggingFace Space for SAM / MedSAM Inference API-compatible with Dense-Captioning-Toolkit backend Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference """ import gradio as gr import torch import numpy as np from PIL import Image import io import json import base64 import os import uuid from huggingface_hub import hf_hub_download # Import SAM components from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator # Initialize model device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # ----------------------------------------------------------------------------- # Model configuration # ----------------------------------------------------------------------------- # 1) MedSAM (ViT-B) for interactive segmentation (points / boxes / multiple boxes) # We assume medsam_vit_b.pth is committed in this repo (small enough for Spaces). MEDSAM_CHECKPOINT = os.path.join(os.path.dirname(__file__), "medsam_vit_b.pth") print("Loading MedSAM model (vit_b) for interactive segmentation...") try: # MedSAM checkpoints are typically state_dicts; load and apply to a vit_b SAM backbone. state_dict = torch.load(MEDSAM_CHECKPOINT, map_location=device) medsam = sam_model_registry["vit_b"](checkpoint=None) medsam.load_state_dict(state_dict) medsam.to(device=device) medsam.eval() print("✓ MedSAM model (vit_b) loaded successfully") except Exception as e: print(f"✗ Failed to load MedSAM model from {MEDSAM_CHECKPOINT}: {e}") raise # SamPredictor for interactive segmentation (point/box prompts) using MedSAM predictor = SamPredictor(medsam) print("✓ SamPredictor (MedSAM) initialized for interactive segmentation") # 2) SAM ViT-H for automatic mask generation and embedding (encode_image) # We download this large checkpoint from a separate model repo using hf_hub_download. 
MODEL_REPO_ID = "Aniketg6/dense-captioning-models" MODEL_FILENAME = "sam_vit_h_4b8939.pth" # change if your filename is different MODEL_TYPE = "vit_h" # using SAM ViT-H (general-purpose SAM) print(f"Downloading SAM (vit_h) checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...") SAM_CHECKPOINT = hf_hub_download( repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME, ) print(f"✓ SAM (vit_h) checkpoint downloaded to: {SAM_CHECKPOINT}") print("Loading SAM model (vit_h) for auto masks and embeddings...") # Monkey-patch torch.load to use CPU mapping when needed original_torch_load = torch.load def patched_torch_load(f, *args, **kwargs): if "map_location" not in kwargs and device == "cpu": kwargs["map_location"] = "cpu" return original_torch_load(f, *args, **kwargs) torch.load = patched_torch_load try: # Ensure we always load onto CPU when no GPU is available torch.load = patched_torch_load sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT) finally: torch.load = original_torch_load sam.to(device=device) sam.eval() print("✓ SAM model (vit_h) loaded successfully") # SamAutomaticMaskGenerator for automatic mask generation (SAM ViT-H) mask_generator = SamAutomaticMaskGenerator( model=sam, points_per_side=16, # Lighter grid (16x16) for faster CPU + smaller responses pred_iou_thresh=0.7, # IoU threshold for filtering stability_score_thresh=0.7, # Stability threshold crop_n_layers=0, # Disable multi-scale crops to avoid IndexError crop_n_points_downscale_factor=2, min_mask_region_area=0 # Allow small masks (backend can filter if needed) ) print("✓ SamAutomaticMaskGenerator (SAM vit-h) initialized for automatic segmentation") # ============================================================================= # HELPER FUNCTIONS FOR EMBEDDINGS (STATELESS) # ============================================================================= def set_predictor_features_from_embedding(embedding_tensor: torch.Tensor, image_shape: tuple): """ Set SamPredictor's internal features using 
precomputed embedding Args: embedding_tensor: Precomputed embedding tensor [1, C, H, W] image_shape: Original image shape (height, width) """ # SamPredictor stores features in self.features # We need to set it directly (this is a bit of a hack but necessary) predictor.features = embedding_tensor predictor.original_image_size = image_shape predictor.input_size = (1024, 1024) # SAM default input size predictor.is_image_set = True # ============================================================================= # API FUNCTIONS - MATCHING BACKEND FORMAT (backend/app.py) # ============================================================================= def encode_image(image, request_json): """ Encode image using SAM image encoder and return embedding to the client. This is now a stateless API: it does NOT talk to Supabase. The caller (your backend) is responsible for storing the embedding if desired. Args: image: PIL Image request_json: JSON string with optional fields: { "image_id": "uuid-string" # Optional: image ID from your DB } Returns: JSON string: { "success": true/false, "image_id": "uuid-string" or null, "embedding_npy_base64": "...", # base64-encoded .npy of [C,H,W] "embedding_shape": [1, C, H, W] } """ try: # Parse input (image_id is optional and just echoed back) data = json.loads(request_json) if request_json else {} image_id = data.get("image_id") # Convert PIL to numpy image_array = np.array(image) H, W = image_array.shape[:2] # Resize image to SAM's expected input size (1024x1024) from skimage import transform img_resized = transform.resize( image_array, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True, ).astype(np.uint8) # Normalize image (SAM expects normalized input) img_norm = (img_resized - img_resized.min()) / np.clip( img_resized.max() - img_resized.min(), 1e-8, None ) # Convert to tensor and add batch dimension tensor = ( torch.tensor(img_norm) .float() .permute(2, 0, 1) .unsqueeze(0) .to(device) ) # Encode image using SAM image encoder 
print(f"Encoding image (image_id={image_id}) original size: {W}x{H} -> 1024x1024") with torch.no_grad(): embedding = sam.image_encoder(tensor) # Convert embedding to numpy [C, Hf, Wf] arr = embedding.squeeze(0).cpu().numpy().astype(np.float32) # Serialize as .npy in memory and base64-encode it buf = io.BytesIO() np.save(buf, arr) buf.seek(0) embedding_b64 = base64.b64encode(buf.read()).decode("utf-8") return json.dumps( { "success": True, "image_id": image_id, "embedding_npy_base64": embedding_b64, "embedding_shape": list(embedding.shape), } ) except Exception as e: import traceback return json.dumps( { "success": False, "error": str(e), "traceback": traceback.format_exc(), } ) def segment_points(image, request_json): """ Segment image with point prompts - MATCHES BACKEND /api/medsam/segment_points Each point gets its own small segment (converted to small bounding box). This matches the backend behavior where points are converted to small boxes. Args: image: PIL Image request_json: JSON string with format: { "points": [[x1, y1], [x2, y2], ...], "labels": [1, 0, ...] 
# 1=foreground, 0=background } Returns: JSON string matching backend response format: { "success": true, "masks": [{"mask": [[...]], "confidence": 0.95}, ...], "confidences": [0.95, ...], "method": "medsam_points_individual" } """ try: # Parse input data = json.loads(request_json) points = data.get("points", []) labels = data.get("labels", []) image_id = data.get("image_id") # Optional: if provided, use precomputed embedding if not points: return json.dumps({'success': False, 'error': 'At least one point is required'}) # Convert PIL to numpy image_array = np.array(image) H, W = image_array.shape[:2] # For now, always compute embedding from image (stateless API) predictor.set_image(image_array) # Process each point individually (like backend does) box_size = 20 # Small box size for point-based segmentation masks_list = [] confidences_list = [] for i, pt in enumerate(points): x, y = pt # Create a small bounding box centered on the point (matching backend behavior) x1 = max(0, x - box_size // 2) y1 = max(0, y - box_size // 2) x2 = min(W - 1, x + box_size // 2) y2 = min(H - 1, y + box_size // 2) bbox = np.array([x1, y1, x2, y2]) print(f"Processing point {i+1}/{len(points)}: ({x}, {y}) -> bbox: {bbox.tolist()}") # Run prediction with box masks, scores, logits = predictor.predict( point_coords=None, point_labels=None, box=bbox, multimask_output=False ) if len(masks) > 0: # Take the best mask best_idx = np.argmax(scores) mask = masks[best_idx] score = float(scores[best_idx]) masks_list.append({ 'mask': mask.astype(np.uint8).tolist(), 'confidence': score }) confidences_list.append(score) print(f"Point {i+1} segmentation successful, confidence: {score:.4f}") else: print(f"Point {i+1} segmentation failed") if masks_list: result = { 'success': True, 'masks': masks_list, 'confidences': confidences_list, 'method': 'medsam_points_individual' } else: result = {'success': False, 'error': 'All point segmentations failed'} return json.dumps(result) except Exception as e: import 
traceback return json.dumps({ 'success': False, 'error': str(e), 'traceback': traceback.format_exc() }) def segment_box(image, request_json): """ Segment image with a single bounding box - MATCHES BACKEND /api/medsam/segment_box Args: image: PIL Image request_json: JSON string with format: { "bbox": [x1, y1, x2, y2] # Can be array or object with x1,y1,x2,y2 } Returns: JSON string matching backend response format: { "success": true, "mask": [[...]], "confidence": 0.95, "method": "medsam_box" } """ try: # Parse input data = json.loads(request_json) bbox = data.get("bbox", []) # Handle both array format [x1,y1,x2,y2] and object format {x1,y1,x2,y2} if isinstance(bbox, dict): bbox = [bbox.get('x1', 0), bbox.get('y1', 0), bbox.get('x2', 0), bbox.get('y2', 0)] if not bbox or len(bbox) != 4: return json.dumps({'success': False, 'error': 'Valid bounding box required [x1, y1, x2, y2]'}) box = np.array(bbox) # Convert PIL to numpy image_array = np.array(image) # Stateless: always compute embedding from image predictor.set_image(image_array) # Run prediction with box masks, scores, logits = predictor.predict( point_coords=None, point_labels=None, box=box, multimask_output=False ) if len(masks) > 0: best_idx = np.argmax(scores) mask = masks[best_idx] score = float(scores[best_idx]) result = { 'success': True, 'mask': mask.astype(np.uint8).tolist(), 'confidence': score, 'method': 'medsam_box' } else: result = {'success': False, 'error': 'Segmentation failed'} return json.dumps(result) except Exception as e: import traceback return json.dumps({ 'success': False, 'error': str(e), 'traceback': traceback.format_exc() }) def segment_multiple_boxes(image, request_json): """ Segment image with multiple bounding boxes - MATCHES BACKEND /api/medsam/segment_multiple_boxes This is the main API endpoint used by the frontend for box-based segmentation. 
Args: image: PIL Image request_json: JSON string with format: { "bboxes": [ [x1, y1, x2, y2], # Array format {"x1": 10, "y1": 20, "x2": 100, "y2": 200} # Object format (also supported) ] } Returns: JSON string matching backend response format: { "success": true, "masks": [{"mask": [[...]], "confidence": 0.95}, ...], "confidences": [0.95, ...], "method": "medsam_multiple_boxes" } """ try: # Parse input data = json.loads(request_json) bboxes = data.get("bboxes", []) if not bboxes: return json.dumps({'success': False, 'error': 'At least one bounding box is required'}) # Convert PIL to numpy image_array = np.array(image) # Stateless: always compute embedding from image predictor.set_image(image_array) print(f"Processing {len(bboxes)} boxes for segmentation") masks_list = [] confidences_list = [] for i, bbox in enumerate(bboxes): # Handle both array format [x1,y1,x2,y2] and object format {x1,y1,x2,y2} if isinstance(bbox, dict): box = np.array([ bbox.get('x1', 0), bbox.get('y1', 0), bbox.get('x2', 0), bbox.get('y2', 0) ]) else: box = np.array(bbox) print(f"Processing box {i+1}/{len(bboxes)}: {box.tolist()}") # Run prediction with box masks, scores, logits = predictor.predict( point_coords=None, point_labels=None, box=box, multimask_output=False ) if len(masks) > 0: best_idx = np.argmax(scores) mask = masks[best_idx] score = float(scores[best_idx]) masks_list.append({ 'mask': mask.astype(np.uint8).tolist(), 'confidence': score }) confidences_list.append(score) print(f"Box {i+1} segmentation successful, confidence: {score:.4f}") else: print(f"Box {i+1} segmentation failed") if masks_list: result = { 'success': True, 'masks': masks_list, 'confidences': confidences_list, 'method': 'medsam_multiple_boxes' } else: result = {'success': False, 'error': 'All segmentations failed'} return json.dumps(result) except Exception as e: import traceback return json.dumps({ 'success': False, 'error': str(e), 'traceback': traceback.format_exc() }) # 
# =============================================================================
# AUTO MASK GENERATION API (replaces local mask_generator.generate())
# =============================================================================

def generate_auto_masks(image, request_json):
    """
    Automatically generate all masks for an image using the SAM ViT-H model.

    This is equivalent to `mask_generator.generate(img_np)` in
    enhanced_preprocessing.py.

    Args:
        image: PIL Image
        request_json: JSON string with optional parameters (parameters the
            code actually reads):
            {
              "resize_longest": 0,   # Downscale longest side before generation (0 = off)
              "max_masks": 10        # Keep only top-K masks by predicted_iou (0 = no limit)
            }

    Returns:
        JSON string with format matching SamAutomaticMaskGenerator output:
        {
          "success": true,
          "masks": [
            {
              "segmentation": [[...2D array...]],
              "area": 12345,
              "bbox": [x, y, width, height],
              "predicted_iou": 0.95,
              "point_coords": [[x, y]],
              "stability_score": 0.98,
              "crop_box": [x, y, width, height]
            },
            ...
          ],
          "num_masks": 42,
          "image_size": [height, width]
        }
    """
    try:
        if mask_generator is None:
            # FIX: this generator is backed by SAM ViT-H, not the MedSAM
            # vit_b checkpoint; report the model that actually failed.
            return json.dumps({
                'success': False,
                'error': 'SAM (vit_h) model not loaded. Automatic mask generation is unavailable.',
                'available': False
            })

        # Parse optional parameters; malformed JSON falls back to defaults.
        params = {}
        if request_json:
            try:
                params = json.loads(request_json) if request_json.strip() else {}
            except json.JSONDecodeError:
                params = {}

        # Convert PIL to numpy
        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Optional downscaling to keep masks smaller / faster
        resize_longest = int(params.get("resize_longest", 0) or 0)
        if resize_longest > 0 and max(H, W) > resize_longest:
            scale = resize_longest / float(max(H, W))
            new_w = max(1, int(W * scale))
            new_h = max(1, int(H * scale))
            print(f"Resizing image from {W}x{H} to {new_w}x{new_h} for auto masks...")
            image_array = np.array(Image.fromarray(image_array).resize((new_w, new_h)))
            H, W = image_array.shape[:2]

        print(f"Generating automatic masks for image of size {W}x{H}...")

        # Generate masks using SAM automatic mask generator
        masks = mask_generator.generate(image_array)
        print(f"Generated {len(masks)} masks")

        if len(masks) > 0:
            # Log some stats about the masks
            areas = [m['area'] for m in masks]
            ious = [m['predicted_iou'] for m in masks]
            stabilities = [m['stability_score'] for m in masks]
            print(f"  Area range: {min(areas)} - {max(areas)} pixels")
            print(f"  IoU range: {min(ious):.3f} - {max(ious):.3f}")
            print(f"  Stability range: {min(stabilities):.3f} - {max(stabilities):.3f}")
        else:
            print("  WARNING: No masks generated! This could mean:")
            print("    - Image is too uniform/simple")
            print("    - Thresholds are still too strict")
            print("    - Image size is too small or too large")

        # Optionally limit number of masks returned to keep JSON payload reasonable
        max_masks = int(params.get("max_masks", 10))
        if max_masks > 0 and len(masks) > max_masks:
            # Sort by predicted IoU (descending) and keep top-K
            print(f"Limiting masks from {len(masks)} to top {max_masks} by predicted_iou")
            masks = sorted(
                masks,
                key=lambda m: float(m.get("predicted_iou", 0.0)),
                reverse=True,
            )[:max_masks]

        print(f"Preparing {len(masks)} masks to return to client...")

        # Convert masks to JSON-serializable format
        masks_output = []
        for m in masks:
            mask_data = {
                "segmentation": m["segmentation"].astype(np.uint8).tolist(),
                "area": int(m["area"]),
                "bbox": [int(x) for x in m["bbox"]],  # [x, y, width, height]
                "predicted_iou": float(m["predicted_iou"]),
                "point_coords": [
                    [int(p[0]), int(p[1])] for p in m["point_coords"]
                ] if m["point_coords"] is not None else [],
                "stability_score": float(m["stability_score"]),
                "crop_box": [int(x) for x in m["crop_box"]],  # [x, y, width, height]
            }
            masks_output.append(mask_data)

        result = {
            'success': True,
            'masks': masks_output,
            'num_masks': len(masks_output),
            'image_size': [H, W]
        }

        print(f"Auto mask generation complete: {len(masks_output)} masks")
        return json.dumps(result)

    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })


def check_auto_mask_status():
    """Report whether automatic mask generation is available and which model backs it."""
    return json.dumps({
        'available': mask_generator is not None,
        'model': MODEL_FILENAME if mask_generator else None,
        'model_type': MODEL_TYPE,
        'device': str(device)
    })


# =============================================================================
# LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts)
# =============================================================================

def segment_with_points_legacy(image, points_json):
    """
    Legacy API - Segment with point prompts using true point-based segmentation.

    Args:
        image: PIL Image
        points_json: JSON string with format:
            {
              "coords": [[x1, y1], [x2, y2], ...],
              "labels": [1, 0, ...],
              "multimask_output": true/false
            }

    Returns:
        JSON string with base64 PNG masks, raw mask data, and scores.
    """
    try:
        points_data = json.loads(points_json)
        coords = np.array(points_data["coords"])
        labels = np.array(points_data["labels"])
        multimask_output = points_data.get("multimask_output", True)

        image_array = np.array(image)
        predictor.set_image(image_array)

        masks, scores, logits = predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=multimask_output
        )

        masks_list = []
        scores_list = []

        for i, (mask, score) in enumerate(zip(masks, scores)):
            # Encode each mask both as a base64 PNG (for display) and as raw
            # nested lists (for programmatic use).
            mask_uint8 = (mask * 255).astype(np.uint8)
            mask_image = Image.fromarray(mask_uint8)
            buffer = io.BytesIO()
            mask_image.save(buffer, format='PNG')
            mask_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

            masks_list.append({
                'mask_base64': mask_base64,
                'mask_shape': mask.shape,
                'mask_data': mask.tolist()
            })
            scores_list.append(float(score))

        return json.dumps({
            'success': True,
            'masks': masks_list,
            'scores': scores_list,
            'num_masks': len(masks_list)
        })

    except Exception as e:
        return json.dumps({'success': False, 'error': str(e)})


def segment_with_box_legacy(image, box_json):
    """
    Legacy API - Segment with box prompt.

    Args:
        image: PIL Image
        box_json: JSON string with format:
            {"box": [x1, y1, x2, y2], "multimask_output": false}

    Returns:
        JSON string with base64 PNG masks, raw mask data, scores, and the box.
    """
    try:
        box_data = json.loads(box_json)
        box = np.array(box_data["box"])
        multimask_output = box_data.get("multimask_output", False)

        image_array = np.array(image)
        predictor.set_image(image_array)

        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box,
            multimask_output=multimask_output
        )

        masks_list = []
        scores_list = []

        for i, (mask, score) in enumerate(zip(masks, scores)):
            mask_uint8 = (mask * 255).astype(np.uint8)
            mask_image = Image.fromarray(mask_uint8)
            buffer = io.BytesIO()
            mask_image.save(buffer, format='PNG')
            mask_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

            masks_list.append({
                'mask_base64': mask_base64,
                'mask_shape': mask.shape,
                'mask_data': mask.tolist()
            })
            scores_list.append(float(score))

        return json.dumps({
            'success': True,
            'masks': masks_list,
            'scores': scores_list,
            'num_masks': len(masks_list),
            'box': box.tolist()
        })

    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })


def segment_simple(image, x, y, label=1, multimask=True):
    """Simple single-point segmentation for the Gradio UI.

    Wraps segment_with_points_legacy and returns the best mask as a PIL
    image plus a short info string.
    """
    try:
        points_json = json.dumps({
            "coords": [[int(x), int(y)]],
            "labels": [int(label)],
            "multimask_output": multimask
        })

        result_json = segment_with_points_legacy(image, points_json)
        result = json.loads(result_json)

        if not result['success']:
            return None, f"Error: {result['error']}"

        # Pick the highest-scoring mask and decode its PNG for display.
        best_idx = np.argmax(result['scores'])
        best_mask_base64 = result['masks'][best_idx]['mask_base64']
        best_score = result['scores'][best_idx]

        mask_bytes = base64.b64decode(best_mask_base64)
        mask_image = Image.open(io.BytesIO(mask_bytes))

        return mask_image, f"Score: {best_score:.4f}"

    except Exception as e:
        return None, f"Error: {str(e)}"


# =============================================================================
# GRADIO INTERFACE
# =============================================================================
with gr.Blocks(title="MedSAM Inference API") as demo:
    gr.Markdown("# 🏥 MedSAM Inference API")
    gr.Markdown("Point and box-based segmentation using Fine-Tuned MedSAM")
    gr.Markdown("**API-compatible with Dense-Captioning-Toolkit backend**")

    with gr.Tabs():
        # Tab 1: Backend-Compatible API (Points)
        with gr.Tab("Segment Points (Backend API)"):
            gr.Markdown("""
## Point-based Segmentation - Backend Compatible

**Matches `/api/medsam/segment_points`**

Each point is converted to a small bounding box for segmentation.

**Input Format:**
```json
{
  "points": [[x1, y1], [x2, y2], ...],
  "labels": [1, 0, ...]
}
```

**Output Format (matches backend):**
```json
{
  "success": true,
  "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
  "confidences": [0.95, ...],
  "method": "medsam_points_individual"
}
```
""")
            with gr.Row():
                with gr.Column():
                    points_image = gr.Image(type="pil", label="Input Image")
                    points_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"points": [[100, 150], [200, 200]], "labels": [1, 1]}',
                        lines=3
                    )
                    points_button = gr.Button("Segment Points", variant="primary")
                with gr.Column():
                    points_output = gr.Textbox(label="Result JSON", lines=15)

            points_button.click(
                fn=segment_points,
                inputs=[points_image, points_json_input],
                outputs=points_output,
                api_name="segment_points"
            )

        # Tab 2: Backend-Compatible API (Multiple Boxes)
        with gr.Tab("Segment Multiple Boxes (Backend API)"):
            gr.Markdown("""
## Multiple Box Segmentation - Backend Compatible

**Matches `/api/medsam/segment_multiple_boxes`** (main frontend API)

**Input Format:**
```json
{
  "bboxes": [
    [x1, y1, x2, y2],
    {"x1": 10, "y1": 20, "x2": 100, "y2": 200}
  ]
}
```

**Output Format (matches backend):**
```json
{
  "success": true,
  "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
  "confidences": [0.95, ...],
  "method": "medsam_multiple_boxes"
}
```
""")
            with gr.Row():
                with gr.Column():
                    multi_box_image = gr.Image(type="pil", label="Input Image")
                    multi_box_json = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]}',
                        lines=3
                    )
                    multi_box_button = gr.Button("Segment Multiple Boxes", variant="primary")
                with gr.Column():
                    multi_box_output = gr.Textbox(label="Result JSON", lines=15)

            multi_box_button.click(
                fn=segment_multiple_boxes,
                inputs=[multi_box_image, multi_box_json],
                outputs=multi_box_output,
                api_name="segment_multiple_boxes"
            )

        # Tab 3: Backend-Compatible API (Single Box)
        with gr.Tab("Segment Box (Backend API)"):
            gr.Markdown("""
## Single Box Segmentation - Backend Compatible

**Matches `/api/medsam/segment_box`**

**Input Format:**
```json
{
  "bbox": [x1, y1, x2, y2]
}
```

**Output Format (matches backend):**
```json
{
  "success": true,
  "mask": [[...]],
  "confidence": 0.95,
  "method": "medsam_box"
}
```
""")
            with gr.Row():
                with gr.Column():
                    box_image = gr.Image(type="pil", label="Input Image")
                    box_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bbox": [100, 100, 300, 300]}',
                        lines=3
                    )
                    box_button = gr.Button("Segment Box", variant="primary")
                with gr.Column():
                    box_output = gr.Textbox(label="Result JSON", lines=15)

            box_button.click(
                fn=segment_box,
                inputs=[box_image, box_json_input],
                outputs=box_output,
                api_name="segment_box"
            )

        # Tab 4: Legacy API (for test scripts)
        with gr.Tab("Legacy API"):
            gr.Markdown("""
## Legacy API (for backwards compatibility)

Original API format with `coords`, `mask_data`, `scores`, etc.
Use if you have existing scripts using the old format.
""")
            with gr.Row():
                with gr.Column():
                    legacy_image = gr.Image(type="pil", label="Input Image")
                    legacy_points = gr.Textbox(
                        label="Points JSON (Legacy Format)",
                        placeholder='{"coords": [[100, 150]], "labels": [1], "multimask_output": true}',
                        lines=3
                    )
                    legacy_button = gr.Button("Run Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_output = gr.Textbox(label="Result JSON", lines=15)

            legacy_button.click(
                fn=segment_with_points_legacy,
                inputs=[legacy_image, legacy_points],
                outputs=legacy_output,
                api_name="segment_with_points"  # Keep old API name for compatibility
            )

            gr.Markdown("---")

            with gr.Row():
                with gr.Column():
                    legacy_box_image = gr.Image(type="pil", label="Input Image")
                    legacy_box_json = gr.Textbox(
                        label="Box JSON (Legacy Format)",
                        placeholder='{"box": [100, 100, 300, 300], "multimask_output": false}',
                        lines=3
                    )
                    legacy_box_button = gr.Button("Run Box Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_box_output = gr.Textbox(label="Result JSON", lines=15)

            legacy_box_button.click(
                fn=segment_with_box_legacy,
                inputs=[legacy_box_image, legacy_box_json],
                outputs=legacy_box_output,
                api_name="segment_with_box"  # Keep old API name for compatibility
            )

        # Tab 5: Auto Mask Generation (for preprocessing)
        with gr.Tab("Auto Mask Generation"):
            gr.Markdown("""
## Automatic Mask Generation (SAM ViT-H)

**Replaces `mask_generator.generate(img_np)` in preprocessing pipeline**

Uses the SAM ViT-H model with `SamAutomaticMaskGenerator` to automatically
segment all objects in an image. This is used for initial preprocessing of
scientific/medical images.

**Output Format:**
```json
{
  "success": true,
  "masks": [
    {
      "segmentation": [[...2D array...]],
      "area": 12345,
      "bbox": [x, y, width, height],
      "predicted_iou": 0.95,
      "point_coords": [[x, y]],
      "stability_score": 0.98,
      "crop_box": [x, y, width, height]
    }
  ],
  "num_masks": 42
}
```
""")
            with gr.Row():
                with gr.Column():
                    auto_image = gr.Image(type="pil", label="Input Image")
                    auto_params = gr.Textbox(
                        label="Parameters (optional)",
                        placeholder='{"resize_longest": 1024, "max_masks": 10}',
                        lines=2
                    )
                    with gr.Row():
                        auto_button = gr.Button("Generate All Masks", variant="primary")
                        status_button = gr.Button("Check Status", variant="secondary")
                with gr.Column():
                    auto_output = gr.Textbox(label="Result JSON", lines=20)
                    status_output = gr.Textbox(label="Status", lines=3)

            auto_button.click(
                fn=generate_auto_masks,
                inputs=[auto_image, auto_params],
                outputs=auto_output,
                api_name="generate_auto_masks"
            )

            status_button.click(
                fn=check_auto_mask_status,
                inputs=[],
                outputs=status_output,
                api_name="check_auto_mask_status"
            )

        # Tab 6: Encode Image (for embedding storage)
        with gr.Tab("Encode Image"):
            gr.Markdown("""
## Image Encoding API

**Encodes the image with the SAM image encoder and returns the embedding to
the caller (stateless — no Supabase access from this Space).**

This endpoint is used during preprocessing to compute image embeddings once
per image. Your backend is responsible for storing the returned embedding;
later segmentation calls can then reuse it for faster inference.

**Input Format:**
```json
{
  "image_id": "uuid-string"
}
```

**Output Format:**
```json
{
  "success": true,
  "image_id": "uuid-string",
  "embedding_npy_base64": "...",
  "embedding_shape": [1, 256, 64, 64]
}
```
""")
            with gr.Row():
                with gr.Column():
                    encode_image_input = gr.Image(type="pil", label="Input Image")
                    encode_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"image_id": "123e4567-e89b-12d3-a456-426614174000"}',
                        lines=2
                    )
                    encode_button = gr.Button("Encode Image", variant="primary")
                with gr.Column():
                    encode_output = gr.Textbox(label="Result JSON", lines=10)

            encode_button.click(
                fn=encode_image,
                inputs=[encode_image_input, encode_json_input],
                outputs=encode_output,
                api_name="encode_image"
            )

        # Tab 7: Simple UI Interface
        with gr.Tab("Simple Interface"):
            gr.Markdown("## Click-based Segmentation")
            gr.Markdown("Enter X, Y coordinates to segment")

            with gr.Row():
                with gr.Column():
                    simple_image = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        simple_x = gr.Number(label="X Coordinate", value=100)
                        simple_y = gr.Number(label="Y Coordinate", value=100)
                    with gr.Row():
                        simple_label = gr.Radio(
                            choices=[1, 0],
                            value=1,
                            label="Point Label (1=foreground, 0=background)"
                        )
                        simple_multimask = gr.Checkbox(
                            label="Multiple Masks",
                            value=True
                        )
                    simple_button = gr.Button("Segment", variant="primary")
                with gr.Column():
                    simple_mask = gr.Image(label="Output Mask")
                    simple_info = gr.Textbox(label="Info")

            simple_button.click(
                fn=segment_simple,
                inputs=[simple_image, simple_x, simple_y, simple_label, simple_multimask],
                outputs=[simple_mask, simple_info]
            )

    gr.Markdown("""
---
### 📡 API Usage from Python (Backend-Compatible)

```python
from gradio_client import Client, handle_file
import json

client = Client("Aniketg6/medsam-inference")

# Point-based segmentation (matches backend format)
result = client.predict(
    image=handle_file("image.jpg"),
    request_json=json.dumps({
        "points": [[150, 200], [300, 400]],
        "labels": [1, 1]
    }),
    api_name="/segment_points"
)

# Multiple box segmentation (main frontend API)
result = client.predict(
    image=handle_file("image.jpg"),
    request_json=json.dumps({
        "bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]
    }),
    api_name="/segment_multiple_boxes"
)

# Parse response
data = json.loads(result)
print(f"Success: {data['success']}")
print(f"Masks: {len(data['masks'])}")
print(f"Confidences: {data['confidences']}")
print(f"Method: {data['method']}")
```
""")

# Launch
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )