# NOTE: removed page-scrape residue ("Spaces: Sleeping" status banner from the
# Hugging Face Spaces UI) — it was not part of the Python source.
| """ | |
| HuggingFace Space for SAM / MedSAM Inference | |
| API-compatible with Dense-Captioning-Toolkit backend | |
| Deploy this to: https://huggingface.co/spaces/YOUR_USERNAME/medsam-inference | |
| """ | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import json | |
| import base64 | |
| import os | |
| import uuid | |
| from huggingface_hub import hf_hub_download | |
| # Import SAM components | |
| from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator | |
# Initialize model
# Pick GPU when available; every model and tensor below follows this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# -----------------------------------------------------------------------------
# Model configuration
# -----------------------------------------------------------------------------
# 1) MedSAM (ViT-B) for interactive segmentation (points / boxes / multiple boxes)
# We assume medsam_vit_b.pth is committed in this repo (small enough for Spaces).
MEDSAM_CHECKPOINT = os.path.join(os.path.dirname(__file__), "medsam_vit_b.pth")
print("Loading MedSAM model (vit_b) for interactive segmentation...")
try:
    # MedSAM checkpoints are typically state_dicts; load and apply to a vit_b SAM backbone.
    # NOTE(review): torch.load without weights_only unpickles arbitrary code —
    # acceptable only because this checkpoint ships with the repo itself.
    state_dict = torch.load(MEDSAM_CHECKPOINT, map_location=device)
    medsam = sam_model_registry["vit_b"](checkpoint=None)  # build the architecture without weights
    medsam.load_state_dict(state_dict)
    medsam.to(device=device)
    medsam.eval()  # inference only
    print("✓ MedSAM model (vit_b) loaded successfully")
except Exception as e:
    # Fail fast: the Space cannot serve interactive segmentation without this model.
    print(f"✗ Failed to load MedSAM model from {MEDSAM_CHECKPOINT}: {e}")
    raise
# SamPredictor for interactive segmentation (point/box prompts) using MedSAM
predictor = SamPredictor(medsam)
print("✓ SamPredictor (MedSAM) initialized for interactive segmentation")
# 2) SAM ViT-H for automatic mask generation and embedding (encode_image)
# We download this large checkpoint from a separate model repo using hf_hub_download.
MODEL_REPO_ID = "Aniketg6/dense-captioning-models"
MODEL_FILENAME = "sam_vit_h_4b8939.pth"  # change if your filename is different
MODEL_TYPE = "vit_h"  # using SAM ViT-H (general-purpose SAM)
print(f"Downloading SAM (vit_h) checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...")
SAM_CHECKPOINT = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename=MODEL_FILENAME,
)
print(f"✓ SAM (vit_h) checkpoint downloaded to: {SAM_CHECKPOINT}")
print("Loading SAM model (vit_h) for auto masks and embeddings...")

# sam_model_registry loads the checkpoint internally without a map_location,
# which fails on CPU-only hosts. Temporarily patch torch.load to force CPU
# mapping and restore the original in a finally block.
original_torch_load = torch.load


def patched_torch_load(f, *args, **kwargs):
    # Force CPU mapping only when the caller did not request one explicitly.
    if "map_location" not in kwargs and device == "cpu":
        kwargs["map_location"] = "cpu"
    return original_torch_load(f, *args, **kwargs)


try:
    # Fix: assign the patch *inside* the try so the finally always undoes the
    # global mutation (previously torch.load was also assigned before the try,
    # redundantly and outside the guard).
    torch.load = patched_torch_load
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
finally:
    torch.load = original_torch_load

sam.to(device=device)
sam.eval()
print("✓ SAM model (vit_h) loaded successfully")
# SamAutomaticMaskGenerator for automatic mask generation (SAM ViT-H)
# Thresholds are deliberately loose and the point grid coarse so CPU-only
# Spaces hardware returns a usable mask set in reasonable time.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Lighter grid (16x16) for faster CPU + smaller responses
    pred_iou_thresh=0.7,  # IoU threshold for filtering
    stability_score_thresh=0.7,  # Stability threshold
    crop_n_layers=0,  # Disable multi-scale crops to avoid IndexError
    crop_n_points_downscale_factor=2,  # only relevant when crop_n_layers > 0
    min_mask_region_area=0  # Allow small masks (backend can filter if needed)
)
print("✓ SamAutomaticMaskGenerator (SAM vit-h) initialized for automatic segmentation")
| # ============================================================================= | |
| # HELPER FUNCTIONS FOR EMBEDDINGS (STATELESS) | |
| # ============================================================================= | |
def set_predictor_features_from_embedding(
    embedding_tensor: torch.Tensor,
    image_shape: tuple,
    target_predictor=None,
):
    """
    Prime a SamPredictor with a precomputed image embedding.

    Bug fix: SamPredictor reads `self.original_size` (not `original_image_size`)
    inside `predict()`, so the old attribute name left the predictor unusable
    with precomputed embeddings.

    Args:
        embedding_tensor: Precomputed embedding tensor [1, C, H, W]
        image_shape: Original image shape (height, width)
        target_predictor: Predictor to prime; defaults to the module-level
            MedSAM `predictor` (kept for backward compatibility).
    """
    sam_predictor = predictor if target_predictor is None else target_predictor
    # SamPredictor stores encoder output in self.features; setting it directly
    # is a hack, but there is no public API for injecting a cached embedding.
    sam_predictor.features = embedding_tensor
    sam_predictor.original_size = image_shape
    # encode_image resizes to a full 1024x1024 square, so the pre-padding
    # input size equals SAM's native input resolution.
    sam_predictor.input_size = (1024, 1024)
    sam_predictor.is_image_set = True
| # ============================================================================= | |
| # API FUNCTIONS - MATCHING BACKEND FORMAT (backend/app.py) | |
| # ============================================================================= | |
def encode_image(image, request_json):
    """
    Encode image using the SAM (vit_h) image encoder and return the embedding.

    Stateless API: it does NOT talk to Supabase. The caller (your backend) is
    responsible for storing the embedding if desired.

    Args:
        image: PIL Image (or anything np.array accepts). Robustness fix:
            grayscale (H, W) and RGBA (H, W, 4) inputs are normalized to
            3-channel RGB — previously they crashed on the channel permute.
        request_json: JSON string with optional fields:
            {"image_id": "uuid-string"}  # echoed back unchanged

    Returns:
        JSON string:
        {
          "success": true/false,
          "image_id": "uuid-string" or null,
          "embedding_npy_base64": "...",  # base64-encoded .npy of [C, Hf, Wf]
          "embedding_shape": [1, C, Hf, Wf]
        }
    """
    try:
        # Parse input (image_id is optional and just echoed back)
        data = json.loads(request_json) if request_json else {}
        image_id = data.get("image_id")

        # Convert PIL to numpy and force a 3-channel layout.
        image_array = np.array(image)
        if image_array.ndim == 2:
            # Grayscale: replicate the single channel.
            image_array = np.stack([image_array] * 3, axis=-1)
        elif image_array.shape[2] == 4:
            # RGBA: drop the alpha channel.
            image_array = image_array[..., :3]
        H, W = image_array.shape[:2]

        # Resize image to SAM's expected input size (1024x1024), bicubic + AA.
        from skimage import transform
        img_resized = transform.resize(
            image_array,
            (1024, 1024),
            order=3,
            preserve_range=True,
            anti_aliasing=True,
        ).astype(np.uint8)

        # Min-max normalize to [0, 1]; the clip guards against constant images.
        img_norm = (img_resized - img_resized.min()) / np.clip(
            img_resized.max() - img_resized.min(), 1e-8, None
        )

        # Convert to tensor [1, 3, 1024, 1024] on the active device.
        tensor = (
            torch.tensor(img_norm)
            .float()
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device)
        )

        print(f"Encoding image (image_id={image_id}) original size: {W}x{H} -> 1024x1024")
        with torch.no_grad():
            embedding = sam.image_encoder(tensor)

        # Serialize [C, Hf, Wf] as an in-memory .npy, then base64-encode it.
        arr = embedding.squeeze(0).cpu().numpy().astype(np.float32)
        buf = io.BytesIO()
        np.save(buf, arr)
        buf.seek(0)
        embedding_b64 = base64.b64encode(buf.read()).decode("utf-8")

        return json.dumps(
            {
                "success": True,
                "image_id": image_id,
                "embedding_npy_base64": embedding_b64,
                "embedding_shape": list(embedding.shape),
            }
        )
    except Exception as e:
        import traceback
        return json.dumps(
            {
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc(),
            }
        )
def segment_points(image, request_json):
    """
    Segment image with point prompts - MATCHES BACKEND /api/medsam/segment_points

    Each point gets its own small segment: it is converted to a small bounding
    box centred on the point, matching the backend behavior.

    Fix: the request fields "labels" and "image_id" were parsed into unused
    locals; they are still accepted for API compatibility but are documented
    as ignored (every point becomes its own box prompt, and this endpoint is
    stateless so precomputed embeddings are not used).

    Args:
        image: PIL Image
        request_json: JSON string with format:
            {
              "points": [[x1, y1], [x2, y2], ...],
              "labels": [1, 0, ...]   # accepted but ignored (see above)
            }

    Returns:
        JSON string matching backend response format:
        {
          "success": true,
          "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
          "confidences": [0.95, ...],
          "method": "medsam_points_individual"
        }
    """
    try:
        # Parse input ("labels"/"image_id" intentionally not read — see docstring)
        data = json.loads(request_json)
        points = data.get("points", [])
        if not points:
            return json.dumps({'success': False, 'error': 'At least one point is required'})

        # Convert PIL to numpy
        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Stateless: always compute the embedding from the image.
        predictor.set_image(image_array)

        # Process each point individually (like backend does)
        box_size = 20  # Small box size for point-based segmentation
        masks_list = []
        confidences_list = []
        for i, pt in enumerate(points):
            x, y = pt
            # Create a small bounding box centered on the point, clamped to the image.
            x1 = max(0, x - box_size // 2)
            y1 = max(0, y - box_size // 2)
            x2 = min(W - 1, x + box_size // 2)
            y2 = min(H - 1, y + box_size // 2)
            bbox = np.array([x1, y1, x2, y2])
            print(f"Processing point {i+1}/{len(points)}: ({x}, {y}) -> bbox: {bbox.tolist()}")

            # Run prediction with box
            masks, scores, logits = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=bbox,
                multimask_output=False
            )
            if len(masks) > 0:
                # Take the best mask
                best_idx = np.argmax(scores)
                score = float(scores[best_idx])
                masks_list.append({
                    'mask': masks[best_idx].astype(np.uint8).tolist(),
                    'confidence': score
                })
                confidences_list.append(score)
                print(f"Point {i+1} segmentation successful, confidence: {score:.4f}")
            else:
                print(f"Point {i+1} segmentation failed")

        if masks_list:
            result = {
                'success': True,
                'masks': masks_list,
                'confidences': confidences_list,
                'method': 'medsam_points_individual'
            }
        else:
            result = {'success': False, 'error': 'All point segmentations failed'}
        return json.dumps(result)
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def segment_box(image, request_json):
    """
    Segment an image against one bounding box.

    Mirrors the backend endpoint `/api/medsam/segment_box`.

    Args:
        image: PIL Image (anything np.array accepts).
        request_json: JSON string: {"bbox": [x1, y1, x2, y2]} — an
            {"x1": ..., "y1": ..., "x2": ..., "y2": ...} object also works.

    Returns:
        JSON string: {"success": true, "mask": [[...]], "confidence": ...,
        "method": "medsam_box"} on success, or {"success": false, "error": ...}.
    """
    try:
        payload = json.loads(request_json)
        raw_box = payload.get("bbox", [])

        # Accept both list [x1, y1, x2, y2] and object {x1, y1, x2, y2} forms.
        if isinstance(raw_box, dict):
            raw_box = [raw_box.get(key, 0) for key in ("x1", "y1", "x2", "y2")]
        if not raw_box or len(raw_box) != 4:
            return json.dumps({'success': False, 'error': 'Valid bounding box required [x1, y1, x2, y2]'})

        box_arr = np.array(raw_box)
        frame = np.array(image)

        # Stateless: recompute the image embedding on every request.
        predictor.set_image(frame)
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_arr,
            multimask_output=False
        )

        if not len(masks):
            return json.dumps({'success': False, 'error': 'Segmentation failed'})

        winner = np.argmax(scores)
        return json.dumps({
            'success': True,
            'mask': masks[winner].astype(np.uint8).tolist(),
            'confidence': float(scores[winner]),
            'method': 'medsam_box'
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def segment_multiple_boxes(image, request_json):
    """
    Segment an image against several bounding boxes in one call.

    Mirrors the backend endpoint `/api/medsam/segment_multiple_boxes`, the main
    API used by the frontend for box-based segmentation.

    Args:
        image: PIL Image (anything np.array accepts).
        request_json: JSON string:
            {"bboxes": [[x1, y1, x2, y2], {"x1": 10, "y1": 20, "x2": 100, "y2": 200}, ...]}

    Returns:
        JSON string: {"success": true, "masks": [{"mask", "confidence"}, ...],
        "confidences": [...], "method": "medsam_multiple_boxes"} on success.
    """
    try:
        payload = json.loads(request_json)
        requested_boxes = payload.get("bboxes", [])
        if not requested_boxes:
            return json.dumps({'success': False, 'error': 'At least one bounding box is required'})

        frame = np.array(image)
        # Stateless: recompute the image embedding on every request.
        predictor.set_image(frame)
        print(f"Processing {len(requested_boxes)} boxes for segmentation")

        collected = []
        scores_out = []
        for idx, raw in enumerate(requested_boxes):
            # Accept both list [x1, y1, x2, y2] and object {x1, y1, x2, y2} forms.
            if isinstance(raw, dict):
                coords = np.array([
                    raw.get('x1', 0),
                    raw.get('y1', 0),
                    raw.get('x2', 0),
                    raw.get('y2', 0)
                ])
            else:
                coords = np.array(raw)
            print(f"Processing box {idx+1}/{len(requested_boxes)}: {coords.tolist()}")

            masks, scores, logits = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=coords,
                multimask_output=False
            )
            if not len(masks):
                print(f"Box {idx+1} segmentation failed")
                continue

            winner = np.argmax(scores)
            confidence = float(scores[winner])
            collected.append({
                'mask': masks[winner].astype(np.uint8).tolist(),
                'confidence': confidence
            })
            scores_out.append(confidence)
            print(f"Box {idx+1} segmentation successful, confidence: {confidence:.4f}")

        if not collected:
            return json.dumps({'success': False, 'error': 'All segmentations failed'})
        return json.dumps({
            'success': True,
            'masks': collected,
            'confidences': scores_out,
            'method': 'medsam_multiple_boxes'
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
| # ============================================================================= | |
| # AUTO MASK GENERATION API (replaces local mask_generator.generate()) | |
| # ============================================================================= | |
def generate_auto_masks(image, request_json):
    """
    Automatically generate all masks for an image using the SAM ViT-H model.

    Equivalent to `mask_generator.generate(img_np)` in enhanced_preprocessing.py.

    Fixes:
      - bare `except:` around the parameter parsing also trapped SystemExit /
        KeyboardInterrupt; narrowed to the exceptions that can actually occur.
      - the "not loaded" error message referred to MedSAM / medsam_vit_b.pth,
        but this generator is built on SAM ViT-H (see module setup).
      - the docstring documented generator parameters (points_per_side, ...)
        that this function never reads; the real options are listed below.

    Args:
        image: PIL Image (anything np.array accepts).
        request_json: JSON string with optional parameters:
            {
              "resize_longest": 0,  # if > 0, downscale longest side to this many px
              "max_masks": 10       # keep top-K masks by predicted_iou (0 = keep all)
            }

    Returns:
        JSON string with format matching SamAutomaticMaskGenerator output:
        {
          "success": true,
          "masks": [
            {
              "segmentation": [[...2D 0/1 array...]],
              "area": 12345,
              "bbox": [x, y, width, height],
              "predicted_iou": 0.95,
              "point_coords": [[x, y]],
              "stability_score": 0.98,
              "crop_box": [x, y, width, height]
            },
            ...
          ],
          "num_masks": 42,
          "image_size": [height, width]
        }
    """
    try:
        # Defensive guard: module import would normally have failed already if
        # the generator could not be built, but keep the API response graceful.
        if mask_generator is None:
            return json.dumps({
                'success': False,
                'error': 'SAM (vit_h) model not loaded. Automatic mask generation is unavailable.',
                'available': False
            })

        # Parse optional parameters; malformed or empty JSON falls back to defaults.
        params = {}
        if request_json:
            try:
                params = json.loads(request_json) if request_json.strip() else {}
            except (ValueError, AttributeError):
                # ValueError covers json.JSONDecodeError; AttributeError covers
                # non-string payloads (no .strip()).
                params = {}

        # Convert PIL to numpy
        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Optional downscaling to keep masks smaller / faster
        resize_longest = int(params.get("resize_longest", 0) or 0)
        if resize_longest > 0 and max(H, W) > resize_longest:
            scale = resize_longest / float(max(H, W))
            new_w = max(1, int(W * scale))
            new_h = max(1, int(H * scale))
            print(f"Resizing image from {W}x{H} to {new_w}x{new_h} for auto masks...")
            image_array = np.array(Image.fromarray(image_array).resize((new_w, new_h)))
            H, W = image_array.shape[:2]

        print(f"Generating automatic masks for image of size {W}x{H}...")
        masks = mask_generator.generate(image_array)
        print(f"Generated {len(masks)} masks")

        if len(masks) > 0:
            # Log some stats about the masks
            areas = [m['area'] for m in masks]
            ious = [m['predicted_iou'] for m in masks]
            stabilities = [m['stability_score'] for m in masks]
            print(f" Area range: {min(areas)} - {max(areas)} pixels")
            print(f" IoU range: {min(ious):.3f} - {max(ious):.3f}")
            print(f" Stability range: {min(stabilities):.3f} - {max(stabilities):.3f}")
        else:
            print(" WARNING: No masks generated! This could mean:")
            print(" - Image is too uniform/simple")
            print(" - Thresholds are still too strict")
            print(" - Image size is too small or too large")

        # Optionally limit number of masks returned to keep JSON payload reasonable
        max_masks = int(params.get("max_masks", 10))
        if max_masks > 0 and len(masks) > max_masks:
            # Sort by predicted IoU (descending) and keep top-K
            print(f"Limiting masks from {len(masks)} to top {max_masks} by predicted_iou")
            masks = sorted(
                masks,
                key=lambda m: float(m.get("predicted_iou", 0.0)),
                reverse=True,
            )[:max_masks]

        print(f"Preparing {len(masks)} masks to return to client...")
        # Convert masks to JSON-serializable format
        masks_output = []
        for m in masks:
            mask_data = {
                "segmentation": m["segmentation"].astype(np.uint8).tolist(),
                "area": int(m["area"]),
                "bbox": [int(x) for x in m["bbox"]],  # [x, y, width, height]
                "predicted_iou": float(m["predicted_iou"]),
                "point_coords": [
                    [int(p[0]), int(p[1])] for p in m["point_coords"]
                ]
                if m["point_coords"] is not None
                else [],
                "stability_score": float(m["stability_score"]),
                "crop_box": [int(x) for x in m["crop_box"]],  # [x, y, width, height]
            }
            masks_output.append(mask_data)

        result = {
            'success': True,
            'masks': masks_output,
            'num_masks': len(masks_output),
            'image_size': [H, W]
        }
        print(f"Auto mask generation complete: {len(masks_output)} masks")
        return json.dumps(result)
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def check_auto_mask_status():
    """Report whether automatic mask generation is available on this Space."""
    generator = mask_generator
    return json.dumps({
        'available': generator is not None,
        'model': MODEL_FILENAME if generator else None,
        'model_type': MODEL_TYPE,
        'device': str(device)
    })
| # ============================================================================= | |
| # LEGACY API FUNCTIONS (kept for backwards compatibility with test scripts) | |
| # ============================================================================= | |
def segment_with_points_legacy(image, points_json):
    """
    Legacy API - Segment with point prompts using true point-based segmentation.

    Consistency fix: the error payload now includes a "traceback" field like
    every other handler in this module (it previously returned only the error
    string).

    Args:
        image: PIL Image.
        points_json: JSON string with format:
            {
              "coords": [[x1, y1], [x2, y2], ...],
              "labels": [1, 0, ...],
              "multimask_output": true/false
            }

    Returns:
        JSON string with base64-PNG masks, raw mask data, and scores; on
        failure: {"success": false, "error": ..., "traceback": ...}.
    """
    try:
        points_data = json.loads(points_json)
        coords = np.array(points_data["coords"])
        labels = np.array(points_data["labels"])
        multimask_output = points_data.get("multimask_output", True)

        image_array = np.array(image)
        predictor.set_image(image_array)
        masks, scores, logits = predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=multimask_output
        )

        masks_list = []
        scores_list = []
        for i, (mask, score) in enumerate(zip(masks, scores)):
            # Encode each mask as a PNG for transport alongside the raw data.
            mask_uint8 = (mask * 255).astype(np.uint8)
            mask_image = Image.fromarray(mask_uint8)
            buffer = io.BytesIO()
            mask_image.save(buffer, format='PNG')
            mask_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            masks_list.append({
                'mask_base64': mask_base64,
                'mask_shape': mask.shape,
                'mask_data': mask.tolist()
            })
            scores_list.append(float(score))

        return json.dumps({
            'success': True,
            'masks': masks_list,
            'scores': scores_list,
            'num_masks': len(masks_list)
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def segment_with_box_legacy(image, box_json):
    """
    Legacy API - segment with a single box prompt.

    Args:
        image: PIL Image.
        box_json: JSON string: {"box": [x1, y1, x2, y2], "multimask_output": false}

    Returns:
        JSON string with base64-PNG masks, raw mask data, scores, and the box.
    """
    try:
        request = json.loads(box_json)
        box_arr = np.array(request["box"])
        want_multi = request.get("multimask_output", False)

        predictor.set_image(np.array(image))
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_arr,
            multimask_output=want_multi
        )

        encoded_masks = []
        score_values = []
        for mask, score in zip(masks, scores):
            # Encode each mask as a PNG for transport alongside the raw data.
            as_bytes = io.BytesIO()
            Image.fromarray((mask * 255).astype(np.uint8)).save(as_bytes, format='PNG')
            encoded_masks.append({
                'mask_base64': base64.b64encode(as_bytes.getvalue()).decode('utf-8'),
                'mask_shape': mask.shape,
                'mask_data': mask.tolist()
            })
            score_values.append(float(score))

        return json.dumps({
            'success': True,
            'masks': encoded_masks,
            'scores': score_values,
            'num_masks': len(encoded_masks),
            'box': box_arr.tolist()
        })
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def segment_simple(image, x, y, label=1, multimask=True):
    """Single-point segmentation helper backing the simple Gradio tab."""
    try:
        request = json.dumps({
            "coords": [[int(x), int(y)]],
            "labels": [int(label)],
            "multimask_output": multimask
        })
        parsed = json.loads(segment_with_points_legacy(image, request))
        if not parsed['success']:
            return None, f"Error: {parsed['error']}"

        # Pick the highest-scoring mask and decode its PNG payload.
        winner = np.argmax(parsed['scores'])
        best_score = parsed['scores'][winner]
        png_bytes = base64.b64decode(parsed['masks'][winner]['mask_base64'])
        return Image.open(io.BytesIO(png_bytes)), f"Score: {best_score:.4f}"
    except Exception as e:
        return None, f"Error: {str(e)}"
| # ============================================================================= | |
| # GRADIO INTERFACE | |
| # ============================================================================= | |
with gr.Blocks(title="MedSAM Inference API") as demo:
    gr.Markdown("# 🏥 MedSAM Inference API")
    gr.Markdown("Point and box-based segmentation using Fine-Tuned MedSAM")
    gr.Markdown("**API-compatible with Dense-Captioning-Toolkit backend**")

    with gr.Tabs():
        # Tab 1: Backend-Compatible API (Points)
        with gr.Tab("Segment Points (Backend API)"):
            gr.Markdown("""
            ## Point-based Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_points`**
            Each point is converted to a small bounding box for segmentation.
            **Input Format:**
            ```json
            {
                "points": [[x1, y1], [x2, y2], ...],
                "labels": [1, 0, ...]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
                "success": true,
                "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
                "confidences": [0.95, ...],
                "method": "medsam_points_individual"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    points_image = gr.Image(type="pil", label="Input Image")
                    points_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"points": [[100, 150], [200, 200]], "labels": [1, 1]}',
                        lines=3
                    )
                    points_button = gr.Button("Segment Points", variant="primary")
                with gr.Column():
                    points_output = gr.Textbox(label="Result JSON", lines=15)
            points_button.click(
                fn=segment_points,
                inputs=[points_image, points_json_input],
                outputs=points_output,
                api_name="segment_points"
            )

        # Tab 2: Backend-Compatible API (Multiple Boxes)
        with gr.Tab("Segment Multiple Boxes (Backend API)"):
            gr.Markdown("""
            ## Multiple Box Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_multiple_boxes`** (main frontend API)
            **Input Format:**
            ```json
            {
                "bboxes": [
                    [x1, y1, x2, y2],
                    {"x1": 10, "y1": 20, "x2": 100, "y2": 200}
                ]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
                "success": true,
                "masks": [{"mask": [[...]], "confidence": 0.95}, ...],
                "confidences": [0.95, ...],
                "method": "medsam_multiple_boxes"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    multi_box_image = gr.Image(type="pil", label="Input Image")
                    multi_box_json = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]}',
                        lines=3
                    )
                    multi_box_button = gr.Button("Segment Multiple Boxes", variant="primary")
                with gr.Column():
                    multi_box_output = gr.Textbox(label="Result JSON", lines=15)
            multi_box_button.click(
                fn=segment_multiple_boxes,
                inputs=[multi_box_image, multi_box_json],
                outputs=multi_box_output,
                api_name="segment_multiple_boxes"
            )

        # Tab 3: Backend-Compatible API (Single Box)
        with gr.Tab("Segment Box (Backend API)"):
            gr.Markdown("""
            ## Single Box Segmentation - Backend Compatible
            **Matches `/api/medsam/segment_box`**
            **Input Format:**
            ```json
            {
                "bbox": [x1, y1, x2, y2]
            }
            ```
            **Output Format (matches backend):**
            ```json
            {
                "success": true,
                "mask": [[...]],
                "confidence": 0.95,
                "method": "medsam_box"
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    box_image = gr.Image(type="pil", label="Input Image")
                    box_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"bbox": [100, 100, 300, 300]}',
                        lines=3
                    )
                    box_button = gr.Button("Segment Box", variant="primary")
                with gr.Column():
                    box_output = gr.Textbox(label="Result JSON", lines=15)
            box_button.click(
                fn=segment_box,
                inputs=[box_image, box_json_input],
                outputs=box_output,
                api_name="segment_box"
            )

        # Tab 4: Legacy API (for test scripts)
        with gr.Tab("Legacy API"):
            gr.Markdown("""
            ## Legacy API (for backwards compatibility)
            Original API format with `coords`, `mask_data`, `scores`, etc.
            Use if you have existing scripts using the old format.
            """)
            with gr.Row():
                with gr.Column():
                    legacy_image = gr.Image(type="pil", label="Input Image")
                    legacy_points = gr.Textbox(
                        label="Points JSON (Legacy Format)",
                        placeholder='{"coords": [[100, 150]], "labels": [1], "multimask_output": true}',
                        lines=3
                    )
                    legacy_button = gr.Button("Run Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_output = gr.Textbox(label="Result JSON", lines=15)
            legacy_button.click(
                fn=segment_with_points_legacy,
                inputs=[legacy_image, legacy_points],
                outputs=legacy_output,
                api_name="segment_with_points"  # Keep old API name for compatibility
            )
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    legacy_box_image = gr.Image(type="pil", label="Input Image")
                    legacy_box_json = gr.Textbox(
                        label="Box JSON (Legacy Format)",
                        placeholder='{"box": [100, 100, 300, 300], "multimask_output": false}',
                        lines=3
                    )
                    legacy_box_button = gr.Button("Run Box Segmentation (Legacy)", variant="secondary")
                with gr.Column():
                    legacy_box_output = gr.Textbox(label="Result JSON", lines=15)
            legacy_box_button.click(
                fn=segment_with_box_legacy,
                inputs=[legacy_box_image, legacy_box_json],
                outputs=legacy_box_output,
                api_name="segment_with_box"  # Keep old API name for compatibility
            )

        # Tab 5: Auto Mask Generation (for preprocessing)
        # Doc fix: this endpoint runs on SAM ViT-H, not the MedSAM ViT-B model
        # (see the module-level mask_generator built on `sam`).
        with gr.Tab("Auto Mask Generation"):
            gr.Markdown("""
            ## Automatic Mask Generation (SAM ViT-H)
            **Replaces `mask_generator.generate(img_np)` in preprocessing pipeline**
            Uses the general-purpose SAM (ViT-H) model with `SamAutomaticMaskGenerator`
            to automatically segment all objects in an image. This is used for initial
            preprocessing of scientific/medical images.
            Uses the `sam_vit_h_4b8939.pth` checkpoint downloaded at startup; interactive
            segmentation uses the separate MedSAM (ViT-B) model.
            **Output Format:**
            ```json
            {
                "success": true,
                "masks": [
                    {
                        "segmentation": [[...2D array...]],
                        "area": 12345,
                        "bbox": [x, y, width, height],
                        "predicted_iou": 0.95,
                        "point_coords": [[x, y]],
                        "stability_score": 0.98,
                        "crop_box": [x, y, width, height]
                    }
                ],
                "num_masks": 42
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    auto_image = gr.Image(type="pil", label="Input Image")
                    auto_params = gr.Textbox(
                        label="Parameters (optional)",
                        placeholder='{"resize_longest": 1024, "max_masks": 10}',
                        lines=2
                    )
                    with gr.Row():
                        auto_button = gr.Button("Generate All Masks", variant="primary")
                        status_button = gr.Button("Check Status", variant="secondary")
                with gr.Column():
                    auto_output = gr.Textbox(label="Result JSON", lines=20)
                    status_output = gr.Textbox(label="Status", lines=3)
            auto_button.click(
                fn=generate_auto_masks,
                inputs=[auto_image, auto_params],
                outputs=auto_output,
                api_name="generate_auto_masks"
            )
            status_button.click(
                fn=check_auto_mask_status,
                inputs=[],
                outputs=status_output,
                api_name="check_auto_mask_status"
            )

        # Tab 6: Encode Image (for embedding storage)
        # Doc fix: encode_image is stateless and returns the embedding to the
        # caller; it does NOT save to Supabase and needs no credentials.
        with gr.Tab("Encode Image"):
            gr.Markdown("""
            ## Image Encoding API
            **Encodes an image with the SAM image encoder and returns the embedding to the caller**
            This endpoint is stateless: it computes the embedding and returns it in the
            response. The caller (e.g. your backend) is responsible for storing it so that
            later segmentation calls can reuse precomputed embeddings instead of
            recomputing them on each request.
            **Input Format:**
            ```json
            {
                "image_id": "uuid-string"
            }
            ```
            (`image_id` is optional and simply echoed back in the response.)
            **Output Format:**
            ```json
            {
                "success": true,
                "image_id": "uuid-string",
                "embedding_npy_base64": "...",
                "embedding_shape": [1, 256, 64, 64]
            }
            ```
            """)
            with gr.Row():
                with gr.Column():
                    encode_image_input = gr.Image(type="pil", label="Input Image")
                    encode_json_input = gr.Textbox(
                        label="Request JSON",
                        placeholder='{"image_id": "123e4567-e89b-12d3-a456-426614174000"}',
                        lines=2
                    )
                    encode_button = gr.Button("Encode Image", variant="primary")
                with gr.Column():
                    encode_output = gr.Textbox(label="Result JSON", lines=10)
            encode_button.click(
                fn=encode_image,
                inputs=[encode_image_input, encode_json_input],
                outputs=encode_output,
                api_name="encode_image"
            )

        # Tab 7: Simple UI Interface
        with gr.Tab("Simple Interface"):
            gr.Markdown("## Click-based Segmentation")
            gr.Markdown("Enter X, Y coordinates to segment")
            with gr.Row():
                with gr.Column():
                    simple_image = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        simple_x = gr.Number(label="X Coordinate", value=100)
                        simple_y = gr.Number(label="Y Coordinate", value=100)
                    with gr.Row():
                        simple_label = gr.Radio(
                            choices=[1, 0],
                            value=1,
                            label="Point Label (1=foreground, 0=background)"
                        )
                        simple_multimask = gr.Checkbox(
                            label="Multiple Masks",
                            value=True
                        )
                    simple_button = gr.Button("Segment", variant="primary")
                with gr.Column():
                    simple_mask = gr.Image(label="Output Mask")
                    simple_info = gr.Textbox(label="Info")
            simple_button.click(
                fn=segment_simple,
                inputs=[simple_image, simple_x, simple_y, simple_label, simple_multimask],
                outputs=[simple_mask, simple_info]
            )

    gr.Markdown("""
    ---
    ### 📡 API Usage from Python (Backend-Compatible)
    ```python
    from gradio_client import Client, handle_file
    import json
    client = Client("Aniketg6/medsam-inference")
    # Point-based segmentation (matches backend format)
    result = client.predict(
        image=handle_file("image.jpg"),
        request_json=json.dumps({
            "points": [[150, 200], [300, 400]],
            "labels": [1, 1]
        }),
        api_name="/segment_points"
    )
    # Multiple box segmentation (main frontend API)
    result = client.predict(
        image=handle_file("image.jpg"),
        request_json=json.dumps({
            "bboxes": [[100, 100, 300, 300], [400, 400, 600, 600]]
        }),
        api_name="/segment_multiple_boxes"
    )
    # Parse response
    data = json.loads(result)
    print(f"Success: {data['success']}")
    print(f"Masks: {len(data['masks'])}")
    print(f"Confidences: {data['confidences']}")
    print(f"Method: {data['method']}")
    ```
    """)
# Launch
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required inside the Spaces container)
        server_port=7860,  # the port Hugging Face Spaces expects
        share=False,  # Spaces provides the public URL; no tunnel needed
        show_error=True  # surface Python tracebacks in the UI for easier debugging
    )