# Spaces: Sleeping — HF Space status banner captured during page scrape; not source code.
| """ | |
| HuggingFace Space for SAM ViT-H Automatic Mask Generation | |
| Generates all segment masks for a given image. | |
| """ | |
| import warnings | |
| import sys | |
| import asyncio | |
| # Suppress known warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*") | |
| warnings.filterwarnings("ignore", category=UserWarning, module="torch") | |
| # Fix Python 3.13 asyncio cleanup error (ValueError: Invalid file descriptor: -1) | |
| # The error occurs when garbage-collected event loops try to close already-closed file descriptors. | |
| if sys.version_info >= (3, 13): | |
| _original_del = asyncio.BaseEventLoop.__del__ if hasattr(asyncio.BaseEventLoop, '__del__') else None | |
| def _safe_del(self): | |
| try: | |
| if _original_del is not None: | |
| _original_del(self) | |
| except (ValueError, OSError): | |
| pass # Suppress invalid file descriptor errors during cleanup | |
| asyncio.BaseEventLoop.__del__ = _safe_del | |
import gradio as gr
import torch
import numpy as np
from PIL import Image
import json
import os  # NOTE(review): not referenced in this file — confirm before removing
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Device setup: prefer CUDA when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# -----------------------------------------------------------------------------
# SAM ViT-H configuration (lazy loaded)
# -----------------------------------------------------------------------------
# HF Hub repository and checkpoint filename for the SAM ViT-H weights.
MODEL_REPO_ID = "Aniketg6/dense-captioning-models"
MODEL_FILENAME = "sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"

# Module-level singletons populated by ensure_sam_loaded() on first use.
sam = None              # SAM model instance
mask_generator = None   # SamAutomaticMaskGenerator bound to `sam`
_sam_loaded = False     # True once both objects above are initialized
def ensure_sam_loaded():
    """Lazy load SAM ViT-H model and mask generator on first use.

    Populates the module-level ``sam`` and ``mask_generator`` singletons.

    Raises:
        RuntimeError: if the checkpoint cannot be downloaded or the model
            fails to initialize.
    """
    global sam, mask_generator, _sam_loaded

    # Fast path: everything is already initialized.
    if _sam_loaded and sam is not None and mask_generator is not None:
        return

    banner = "=" * 60
    print(banner)
    print("Loading SAM ViT-H model (lazy loading)...")
    print(banner)
    try:
        print(f"Downloading SAM (vit_h) checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...")
        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
        )
        print(f" Checkpoint downloaded to: {checkpoint_path}")

        print("Loading SAM model (vit_h)...")
        # segment_anything calls torch.load() without map_location, so
        # temporarily wrap it to force CPU mapping on CPU-only hosts.
        _unpatched_load = torch.load

        def _cpu_mapped_load(f, *args, **kwargs):
            if device == "cpu" and "map_location" not in kwargs:
                kwargs["map_location"] = "cpu"
            return _unpatched_load(f, *args, **kwargs)

        torch.load = _cpu_mapped_load
        try:
            sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)
        finally:
            # Always restore the real torch.load, even if loading fails.
            torch.load = _unpatched_load

        sam.to(device=device)
        sam.eval()
        print(" SAM model (vit_h) loaded successfully")

        mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=16,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.7,
            crop_n_layers=0,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=0,
        )
        print(" SamAutomaticMaskGenerator initialized")

        _sam_loaded = True
        print(banner)
        print("SAM ViT-H model loading complete!")
        print(banner)
    except Exception as e:
        print(f"Error loading SAM ViT-H model: {e}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(f"Failed to load SAM ViT-H model: {e}") from e
# =============================================================================
# API FUNCTION
# =============================================================================
def generate_auto_masks(image, request_json):
    """
    Automatically generate all masks for an image using SAM ViT-H.

    Args:
        image: PIL Image. ``None`` yields an error payload instead of raising.
        request_json: JSON string with optional parameters. Currently honored:
            {
                "resize_longest": 0,   # shrink so the longest side <= this (0 = off)
                "max_masks": 10        # keep top-N masks by predicted_iou (0 = all)
            }
            NOTE: points_per_side / pred_iou_thresh / stability_score_thresh /
            min_mask_region_area are fixed at model-load time (see
            ensure_sam_loaded) and cannot be overridden per request.

    Returns:
        JSON string:
        {
            "success": true,
            "masks": [
                {
                    "segmentation": [[...2D 0/1 array...]],
                    "area": 12345,
                    "bbox": [x, y, width, height],
                    "predicted_iou": 0.95,
                    "point_coords": [[x, y]],
                    "stability_score": 0.98,
                    "crop_box": [x, y, width, height]
                }
            ],
            "num_masks": 42,
            "image_size": [height, width]
        }
        On failure: {"success": false, "error": "...", ...}
    """
    try:
        # Load lazily INSIDE the try so a load failure is reported as a JSON
        # error payload (the API contract) instead of an unhandled exception.
        ensure_sam_loaded()

        if mask_generator is None:
            return json.dumps({
                'success': False,
                'error': 'SAM ViT-H model not loaded. Failed to initialize mask generator.',
                'available': False
            })
        if image is None:
            return json.dumps({
                'success': False,
                'error': 'No input image provided.'
            })

        # Parse optional parameters; malformed or non-object JSON falls back
        # to defaults instead of failing the request.
        params = {}
        if request_json:
            try:
                params = json.loads(request_json) if request_json.strip() else {}
            except Exception:
                params = {}
        if not isinstance(params, dict):
            params = {}

        # Convert PIL to numpy
        image_array = np.array(image)
        H, W = image_array.shape[:2]

        # Optional downscaling so the longest side fits within resize_longest.
        resize_longest = int(params.get("resize_longest", 0) or 0)
        if resize_longest > 0 and max(H, W) > resize_longest:
            scale = resize_longest / float(max(H, W))
            new_w = max(1, int(W * scale))
            new_h = max(1, int(H * scale))
            print(f"Resizing image from {W}x{H} to {new_w}x{new_h} for auto masks...")
            image_array = np.array(Image.fromarray(image_array).resize((new_w, new_h)))
            H, W = image_array.shape[:2]

        print(f"Generating automatic masks for image of size {W}x{H}...")
        masks = mask_generator.generate(image_array)
        print(f"Generated {len(masks)} masks")

        if masks:
            areas = [m['area'] for m in masks]
            ious = [m['predicted_iou'] for m in masks]
            stabilities = [m['stability_score'] for m in masks]
            print(f" Area range: {min(areas)} - {max(areas)} pixels")
            print(f" IoU range: {min(ious):.3f} - {max(ious):.3f}")
            print(f" Stability range: {min(stabilities):.3f} - {max(stabilities):.3f}")
        else:
            print(" WARNING: No masks generated!")

        # Optionally keep only the top-N masks ranked by predicted IoU.
        # `or 0` tolerates explicit null/None from the client.
        max_masks = int(params.get("max_masks", 10) or 0)
        if max_masks > 0 and len(masks) > max_masks:
            print(f"Limiting masks from {len(masks)} to top {max_masks} by predicted_iou")
            masks = sorted(
                masks,
                key=lambda m: float(m.get("predicted_iou", 0.0)),
                reverse=True,
            )[:max_masks]

        print(f"Preparing {len(masks)} masks to return to client...")
        masks_output = []
        for m in masks:
            masks_output.append({
                # Boolean mask serialized as nested 0/1 lists (JSON-safe).
                "segmentation": m["segmentation"].astype(np.uint8).tolist(),
                "area": int(m["area"]),
                "bbox": [int(x) for x in m["bbox"]],
                "predicted_iou": float(m["predicted_iou"]),
                "point_coords": (
                    [[int(p[0]), int(p[1])] for p in m["point_coords"]]
                    if m["point_coords"] is not None
                    else []
                ),
                "stability_score": float(m["stability_score"]),
                "crop_box": [int(x) for x in m["crop_box"]],
            })

        result = {
            'success': True,
            'masks': masks_output,
            'num_masks': len(masks_output),
            'image_size': [H, W]
        }
        print(f"Auto mask generation complete: {len(masks_output)} masks")
        return json.dumps(result)
    except Exception as e:
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def check_auto_mask_status():
    """Report availability of automatic mask generation as a JSON string."""
    generator_ready = mask_generator is not None
    status = {
        'available': generator_ready and _sam_loaded,
        # Preserve original truthiness check on the generator instance.
        'model': MODEL_FILENAME if mask_generator else None,
        'model_type': MODEL_TYPE,
        'device': str(device),
        'loaded': _sam_loaded,
    }
    return json.dumps(status)
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Declarative UI: statement order below defines both the visual layout and
# the exposed API endpoints (via api_name=...).
with gr.Blocks(title="SAM Auto Mask Generation API") as demo:
    gr.Markdown("# SAM Automatic Mask Generation API")
    gr.Markdown("Generate all segment masks for an image using SAM ViT-H")
    with gr.Row():
        with gr.Column():
            # Inputs: the image plus an optional JSON parameter string.
            auto_image = gr.Image(type="pil", label="Input Image")
            auto_params = gr.Textbox(
                label="Parameters (optional)",
                placeholder='{"max_masks": 10, "resize_longest": 1024}',
                lines=2,
            )
            with gr.Row():
                auto_button = gr.Button("Generate All Masks", variant="primary")
                status_button = gr.Button("Check Status", variant="secondary")
        with gr.Column():
            # Outputs: raw JSON payloads from the two API functions.
            auto_output = gr.Textbox(label="Result JSON", lines=20)
            status_output = gr.Textbox(label="Status", lines=3)

    # api_name exposes these handlers to gradio_client as
    # /generate_auto_masks and /check_auto_mask_status.
    auto_button.click(
        fn=generate_auto_masks,
        inputs=[auto_image, auto_params],
        outputs=auto_output,
        api_name="generate_auto_masks",
    )
    status_button.click(
        fn=check_auto_mask_status,
        inputs=[],
        outputs=status_output,
        api_name="check_auto_mask_status",
    )

    gr.Markdown("""
---
### API Usage from Python
```python
from gradio_client import Client, handle_file
import json
client = Client("Aniketg6/medsam-inference")
result = client.predict(
    image=handle_file("image.jpg"),
    request_json=json.dumps({"max_masks": 10}),
    api_name="/generate_auto_masks"
)
data = json.loads(result)
print(f"Success: {data['success']}")
print(f"Masks: {data['num_masks']}")
```
""")
# Launch
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port (7860).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,  # surface handler exceptions in the UI
        ssr_mode=False,   # NOTE(review): presumably disables Gradio SSR — confirm for this Gradio version
    )