Anigor66
Removed medsam
a2ebcd4
"""
HuggingFace Space for SAM ViT-H Automatic Mask Generation
Generates all segment masks for a given image.
"""
import warnings
import sys
import asyncio
# Silence known, noisy warnings emitted by torch.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*")
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# Python 3.13+ workaround: when an event loop is garbage-collected, its
# finalizer may try to close file descriptors that are already closed and
# fail with "ValueError: Invalid file descriptor: -1". Wrap the finalizer so
# those cleanup errors are discarded.
if sys.version_info >= (3, 13):
    _original_del = getattr(asyncio.BaseEventLoop, "__del__", None)

    def _safe_del(self):
        """Run the original event-loop finalizer, ignoring fd cleanup errors."""
        if _original_del is None:
            return
        try:
            _original_del(self)
        except (ValueError, OSError):
            pass  # Invalid / already-closed file descriptor during cleanup

    asyncio.BaseEventLoop.__del__ = _safe_del
import gradio as gr
import torch
import numpy as np
from PIL import Image
import json
import os
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# Device setup: use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# -----------------------------------------------------------------------------
# SAM ViT-H configuration (lazy loaded)
# -----------------------------------------------------------------------------
# Hugging Face Hub repository and file holding the SAM checkpoint.
MODEL_REPO_ID = "Aniketg6/dense-captioning-models"
MODEL_FILENAME = "sam_vit_h_4b8939.pth"
# Key into segment_anything's sam_model_registry.
MODEL_TYPE = "vit_h"
# Filled in by ensure_sam_loaded() on first use; empty until then.
sam = mask_generator = None
_sam_loaded = False
def ensure_sam_loaded():
    """Lazily load the SAM ViT-H model and its automatic mask generator.

    On the first call this downloads the checkpoint ``MODEL_FILENAME`` from the
    Hugging Face Hub repo ``MODEL_REPO_ID``, builds the ``MODEL_TYPE`` model,
    moves it to ``device`` and wraps it in a ``SamAutomaticMaskGenerator``.
    Once loading has succeeded, later calls return immediately.

    The globals ``sam``, ``mask_generator`` and ``_sam_loaded`` are assigned
    only after every step has succeeded. A failed attempt therefore never
    leaves a half-initialized model behind, and the next call simply retries.

    Raises:
        RuntimeError: if downloading, building or initializing the model fails
            (the original exception is chained as ``__cause__``).
    """
    global sam, mask_generator, _sam_loaded
    if _sam_loaded and sam is not None and mask_generator is not None:
        return
    print("=" * 60)
    print("Loading SAM ViT-H model (lazy loading)...")
    print("=" * 60)
    try:
        print(f"Downloading SAM (vit_h) checkpoint `{MODEL_FILENAME}` from repo `{MODEL_REPO_ID}`...")
        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
        )
        print(f" Checkpoint downloaded to: {checkpoint_path}")
        print("Loading SAM model (vit_h)...")
        # Build the bare architecture and load the weights ourselves so that
        # map_location can be passed explicitly. This replaces a temporary
        # process-wide monkey-patch of torch.load, which would also have
        # affected any other thread calling torch.load in the meantime.
        # Weights are loaded on CPU and then moved to the target device.
        model = sam_model_registry[MODEL_TYPE]()
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
        del state_dict  # drop the extra CPU copy of the weights early
        model.to(device=device)
        model.eval()
        print(" SAM model (vit_h) loaded successfully")
        generator = SamAutomaticMaskGenerator(
            model=model,
            points_per_side=16,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.7,
            crop_n_layers=0,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=0,
        )
        print(" SamAutomaticMaskGenerator initialized")
    except Exception as e:
        print(f"Error loading SAM ViT-H model: {e}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(f"Failed to load SAM ViT-H model: {e}") from e
    # Publish the globals only now that every step has succeeded.
    sam = model
    mask_generator = generator
    _sam_loaded = True
    print("=" * 60)
    print("SAM ViT-H model loading complete!")
    print("=" * 60)
# =============================================================================
# API FUNCTION
# =============================================================================
def _parse_request_params(request_json):
    """Parse the optional ``request_json`` string into a parameter dict.

    Parsing is best effort. Empty or blank input, malformed JSON, and JSON that
    is not an object all yield ``{}``; the last two are also logged. A bad
    parameter string therefore never blocks mask generation.
    """
    if not request_json:
        return {}
    try:
        params = json.loads(request_json) if request_json.strip() else {}
    except (ValueError, TypeError, AttributeError) as e:
        print(f"Ignoring unparseable request_json: {e}")
        return {}
    if not isinstance(params, dict):
        print(f"Ignoring request_json that is not a JSON object: {params!r}")
        return {}
    return params


def _serialize_mask(m):
    """Convert one SamAutomaticMaskGenerator record into JSON-safe types."""
    point_coords = m["point_coords"]
    return {
        "segmentation": m["segmentation"].astype(np.uint8).tolist(),
        "area": int(m["area"]),
        "bbox": [int(x) for x in m["bbox"]],
        "predicted_iou": float(m["predicted_iou"]),
        "point_coords": (
            [[int(p[0]), int(p[1])] for p in point_coords]
            if point_coords is not None
            else []
        ),
        "stability_score": float(m["stability_score"]),
        "crop_box": [int(x) for x in m["crop_box"]],
    }


def generate_auto_masks(image, request_json):
    """Generate every mask SAM ViT-H finds in ``image`` and return them as JSON.

    The mask generator is configured once, when the model is loaded by
    ``ensure_sam_loaded``. Per-request parameters only control pre-processing
    of the image and post-filtering of the results.

    Args:
        image: Input image, normally a ``PIL.Image.Image`` from the
            ``gr.Image(type="pil")`` component. Non-RGB PIL images are
            converted to RGB. ``None`` yields an error response.
        request_json: Optional string holding a JSON object. Recognized keys:
            ``resize_longest`` (int, default 0): if > 0 and the longest side of
                the image exceeds it, the image is downscaled to that longest
                side before segmentation.
            ``max_masks`` (int, default 10): keep only the top-N masks by
                ``predicted_iou``. A value <= 0 keeps all masks.
            Other keys (e.g. ``points_per_side``, ``pred_iou_thresh``,
            ``stability_score_thresh``, ``min_mask_region_area``) are currently
            ignored. Blank, malformed or non-object JSON is also ignored.

    Returns:
        A JSON string. On success::

            {
              "success": true,
              "masks": [
                {
                  "segmentation": [[0, 1, ...], ...],  # rows of 0/1 ints
                  "area": 12345,
                  "bbox": [x, y, width, height],
                  "predicted_iou": 0.95,
                  "point_coords": [[x, y]],            # truncated to int
                  "stability_score": 0.98,
                  "crop_box": [x, y, width, height]
                }
              ],
              "num_masks": 42,
              "image_size": [height, width]
            }

        ``image_size`` and all mask coordinates refer to the image that was
        actually segmented, i.e. after any ``resize_longest`` downscaling.

        On failure the object has ``"success": false`` and an ``"error"``
        message. Model-loading failures also carry ``"available": false`` and
        ``"details"``; other failures carry a ``"traceback"``.
    """
    # Reject a missing image (e.g. the button clicked with no upload) before
    # paying for model loading, instead of failing deep inside numpy/SAM.
    if image is None:
        return json.dumps({
            'success': False,
            'error': 'No image provided.',
        })
    # ensure_sam_loaded() raises RuntimeError on failure. Report it through
    # the JSON error contract rather than letting it escape the API handler.
    try:
        ensure_sam_loaded()
    except RuntimeError as e:
        return json.dumps({
            'success': False,
            'error': 'SAM ViT-H model not loaded. Failed to initialize mask generator.',
            'details': str(e),
            'available': False,
        })
    try:
        params = _parse_request_params(request_json)
        # SAM expects a 3-channel RGB array. Normalize other PIL modes
        # (RGBA, L, P, ...) so they do not fail inside the model.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        image_array = np.array(image)
        H, W = image_array.shape[:2]
        # Optional downscaling so the longest side is at most resize_longest.
        resize_longest = int(params.get("resize_longest", 0) or 0)
        if resize_longest > 0 and max(H, W) > resize_longest:
            scale = resize_longest / float(max(H, W))
            new_w = max(1, int(W * scale))
            new_h = max(1, int(H * scale))
            print(f"Resizing image from {W}x{H} to {new_w}x{new_h} for auto masks...")
            image_array = np.array(Image.fromarray(image_array).resize((new_w, new_h)))
            H, W = image_array.shape[:2]
        print(f"Generating automatic masks for image of size {W}x{H}...")
        masks = mask_generator.generate(image_array)
        print(f"Generated {len(masks)} masks")
        if masks:
            areas = [m['area'] for m in masks]
            ious = [m['predicted_iou'] for m in masks]
            stabilities = [m['stability_score'] for m in masks]
            print(f" Area range: {min(areas)} - {max(areas)} pixels")
            print(f" IoU range: {min(ious):.3f} - {max(ious):.3f}")
            print(f" Stability range: {min(stabilities):.3f} - {max(stabilities):.3f}")
        else:
            print(" WARNING: No masks generated!")
        # Optionally keep only the best masks. An explicit JSON null falls
        # back to the default instead of raising TypeError in int().
        raw_max_masks = params.get("max_masks")
        max_masks = 10 if raw_max_masks is None else int(raw_max_masks)
        if max_masks > 0 and len(masks) > max_masks:
            print(f"Limiting masks from {len(masks)} to top {max_masks} by predicted_iou")
            masks = sorted(
                masks,
                key=lambda m: float(m.get("predicted_iou", 0.0)),
                reverse=True,
            )[:max_masks]
        print(f"Preparing {len(masks)} masks to return to client...")
        masks_output = [_serialize_mask(m) for m in masks]
        result = {
            'success': True,
            'masks': masks_output,
            'num_masks': len(masks_output),
            'image_size': [H, W]
        }
        print(f"Auto mask generation complete: {len(masks_output)} masks")
        return json.dumps(result)
    except Exception as e:
        # API boundary: report any processing failure as a JSON error.
        import traceback
        return json.dumps({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc()
        })
def check_auto_mask_status():
    """Report, as a JSON string, whether the SAM mask generator is ready.

    The model is loaded lazily by ``ensure_sam_loaded`` (invoked from
    ``generate_auto_masks``). Until that has succeeded, ``available`` and
    ``loaded`` are false and ``model`` is null.
    """
    ready = mask_generator is not None and _sam_loaded
    model_name = None if mask_generator is None else MODEL_FILENAME
    status = {
        'available': ready,
        'model': model_name,
        'model_type': MODEL_TYPE,
        'device': str(device),
        'loaded': _sam_loaded,
    }
    return json.dumps(status)
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# Gradio UI. Components render in the order they are created, so statement
# order inside these `with` blocks defines the layout.
with gr.Blocks(title="SAM Auto Mask Generation API") as demo:
    gr.Markdown("# SAM Automatic Mask Generation API")
    gr.Markdown("Generate all segment masks for an image using SAM ViT-H")
    with gr.Row():
        # Input column: image, optional JSON parameters, action buttons.
        with gr.Column():
            auto_image = gr.Image(type="pil", label="Input Image")
            # Passed verbatim to generate_auto_masks as `request_json`.
            auto_params = gr.Textbox(
                label="Parameters (optional)",
                placeholder='{"max_masks": 10, "resize_longest": 1024}',
                lines=2,
            )
            with gr.Row():
                auto_button = gr.Button("Generate All Masks", variant="primary")
                status_button = gr.Button("Check Status", variant="secondary")
        # Output column: JSON strings returned by the two handlers.
        with gr.Column():
            auto_output = gr.Textbox(label="Result JSON", lines=20)
            status_output = gr.Textbox(label="Status", lines=3)
    # api_name exposes each handler as a named endpoint; gradio_client callers
    # use it as api_name="/generate_auto_masks" (see the usage notes below).
    auto_button.click(
        fn=generate_auto_masks,
        inputs=[auto_image, auto_params],
        outputs=auto_output,
        api_name="generate_auto_masks",
    )
    # The status handler takes no inputs.
    status_button.click(
        fn=check_auto_mask_status,
        inputs=[],
        outputs=status_output,
        api_name="check_auto_mask_status",
    )
    # NOTE(review): the example below targets the Space id
    # "Aniketg6/medsam-inference"; confirm it is still this Space's id.
    gr.Markdown("""
---
### API Usage from Python
```python
from gradio_client import Client, handle_file
import json
client = Client("Aniketg6/medsam-inference")
result = client.predict(
image=handle_file("image.jpg"),
request_json=json.dumps({"max_masks": 10}),
api_name="/generate_auto_masks"
)
data = json.loads(result)
print(f"Success: {data['success']}")
print(f"Masks: {data['num_masks']}")
```
""")
# Launch
if __name__ == "__main__":
    # Server options for demo.launch(), collected in one place.
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,            # no public gradio.live tunnel
        "show_error": True,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)