|
|
import spaces |
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import matplotlib |
|
|
from transformers import Sam3Processor, Sam3Model |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
# Run on GPU when available; fp16 on CUDA halves memory, fp32 keeps CPU numerics safe.
device = "cuda" if torch.cuda.is_available() else "cpu"


# NOTE(review): weights are fetched from the Hugging Face Hub on first run — requires network.
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)


# Processor handles image preprocessing, text tokenization, and mask post-processing.
processor = Sam3Processor.from_pretrained("facebook/sam3")
|
|
|
|
|
def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    """
    Composite instance masks over *image*, one rainbow color per mask.

    Returns the untouched image (as RGB) when there are no masks,
    otherwise the composited RGBA image.
    """
    base = image.convert("RGBA")
    # Binary masks -> uint8 arrays with values 0 / 255 (PIL-friendly).
    mask_arrays = 255 * masks.cpu().numpy().astype(np.uint8)

    total = mask_arrays.shape[0]
    if total == 0:
        return base.convert("RGB")

    # Evenly spread colors across the rainbow colormap, one per instance.
    palette = matplotlib.colormaps.get_cmap("rainbow").resampled(total)
    rgb_colors = [
        tuple(int(channel * 255) for channel in palette(idx)[:3])
        for idx in range(total)
    ]

    composited = base
    for mask_array, rgb in zip(mask_arrays, rgb_colors):
        layer = Image.new("RGBA", composited.size, rgb + (0,))
        # Half-intensity alpha (255 -> 127) keeps the underlying image visible.
        layer.putalpha(Image.fromarray(mask_array).point(lambda v: int(v * 0.5)))
        composited = Image.alpha_composite(composited, layer)
    return composited
|
|
|
|
|
# FIX: the original called spaces.GPU() as a bare statement, which is a no-op.
# On Hugging Face ZeroGPU Spaces it must decorate the inference function so a
# GPU is actually attached while this function runs.
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.

    Args:
        image: Input PIL image, or None when nothing is uploaded.
        text: Natural-language concept prompt (e.g. "a person").
        threshold: Objectness confidence cutoff for detections.
        mask_threshold: Per-pixel cutoff used to binarize masks.

    Returns:
        Tuple of (output image or None, markdown status string) for the UI.
    """
    if image is None:
        return None, "❌ Please upload an image."
    # Guard against missing/blank prompts instead of failing inside the try
    # block with an opaque AttributeError from None.strip().
    if not text or not text.strip():
        return image, "❌ Please enter a text prompt."

    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Resize masks back to the original image resolution.
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        n_masks = len(results['masks'])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{text}' (try adjusting thresholds or changing prompt)."

        overlaid_image = overlay_masks(image, results["masks"])

        # Report at most the first five confidence scores to keep the info short.
        scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]])
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"

        return overlaid_image, info

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return image, f"❌ Error during segmentation: {str(e)}"
|
|
|
|
|
def clear_all():
    """Reset every widget to its initial state.

    Order matches the clear button's outputs list:
    image_input, text_input, image_output, thresh_slider, mask_thresh_slider.
    """
    return (None, "", None, 0.5, 0.5)
|
|
|
|
|
def segment_example(image_path: str, prompt: str):
    """Run segmentation for an example (file path + prompt) with default thresholds."""
    example_image = Image.open(image_path) if image_path else None
    # Examples always use the sliders' default value of 0.5 / 0.5.
    return segment(example_image, prompt, 0.5, 0.5)
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI layout. FIX: removed the dead `.then(fn=lambda: None)` chained after the
# segment click handler — it registered a no-op event with no inputs/outputs.
# ---------------------------------------------------------------------------
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css="""
    .gradio-container {max-width: 1400px !important;}
    """
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)

        **SAM3** performs zero-shot instance segmentation using natural language prompts on images.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.

        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )

    gr.Markdown("### Inputs")
    # Input image and segmented output side by side.
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False
        )

    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3
        )
        clear_btn = gr.Button("🔍 Clear", size="sm", variant="secondary")

    # Threshold controls feeding directly into the segment() call.
    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)"
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks"
        )

    info_output = gr.Markdown(
        value="📝 Enter a prompt and click **Segment** to start.",
        label="Info / Results"
    )

    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Reset widgets; output order must match clear_all()'s return tuple.
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider]
    )

    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output]
    )

    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Supports natural language prompts like "a red car" or simple nouns.
        - GPU recommended for faster inference.
        - Thresholds control detection sensitivity and mask quality.
        """
    )
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)