# sam3 / app.py
# Author: akhaliq (HF Staff)
# Last commit: 30a638e (verified) "Update app.py" (6.4 kB)
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
from transformers import Sam3Processor, Sam3Model
import warnings
warnings.filterwarnings("ignore")

# Global model and processor: loaded once at import time and shared by every
# call to the segmentation handler. Half precision is used only when CUDA is
# available; on CPU the model stays in float32.
_use_cuda = torch.cuda.is_available()
device = "cuda" if _use_cuda else "cpu"
_model_dtype = torch.float16 if _use_cuda else torch.float32
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=_model_dtype).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
def overlay_masks(image: Image.Image, masks: torch.Tensor, alpha: float = 0.5) -> Image.Image:
    """
    Overlay segmentation masks on the input image using a rainbow colormap.

    Each mask is drawn in its own colour, sampled evenly from matplotlib's
    "rainbow" colormap, and alpha-composited over the image.

    Args:
        image: Image to draw on. It is converted to RGBA for compositing.
        masks: Mask stack shaped (N, H, W), as a torch tensor or anything
            ``np.asarray`` accepts; a single (H, W) mask is also accepted.
            A pixel is foreground when its value cast to uint8 is non-zero,
            so boolean and 0/1 masks work as expected.
        alpha: Opacity of each coloured overlay, in [0, 1]. The default 0.5
            reproduces the previous hard-coded blend.

    Returns:
        The RGBA composite. When there are no masks, an RGB copy of the
        image is returned instead.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or the mask size does not
            match the image size.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    image = image.convert("RGBA")

    if isinstance(masks, torch.Tensor):
        # detach() so tensors that still track gradients can be converted.
        mask_array = masks.detach().cpu().numpy()
    else:
        mask_array = np.asarray(masks)
    if mask_array.ndim == 2:
        # Treat a lone (H, W) mask as a stack of one.
        mask_array = mask_array[np.newaxis]

    n_masks = mask_array.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    expected_hw = (image.height, image.width)
    if tuple(mask_array.shape[1:]) != expected_hw:
        raise ValueError(
            f"mask size {tuple(mask_array.shape[1:])} does not match "
            f"image size (H, W) = {expected_hw}"
        )

    # Boolean foreground first, then scale to the alpha value. This avoids the
    # uint8 overflow of `255 * mask` when masks are already 0/255.
    foreground = mask_array.astype(np.uint8) != 0
    alpha_value = int(255 * alpha)
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)

    for index, fg in enumerate(foreground):
        color = tuple(int(c * 255) for c in cmap(index)[:3])
        alpha_band = Image.fromarray(np.where(fg, alpha_value, 0).astype(np.uint8))
        overlay = Image.new("RGBA", image.size, color + (0,))
        overlay.putalpha(alpha_band)
        image = Image.alpha_composite(image, overlay)
    return image
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.

    Runs the global SAM3 model on ``image`` with the natural-language prompt
    ``text`` and overlays every instance mask that survives post-processing.

    Args:
        image: Input image from the UI, or ``None`` when nothing is uploaded.
        text: Text prompt describing the concept to segment. Surrounding
            whitespace is stripped; ``None`` is treated as empty.
        threshold: Score threshold forwarded to
            ``processor.post_process_instance_segmentation``.
        mask_threshold: Mask threshold forwarded to the same post-processing.

    Returns:
        A ``(image, markdown_message)`` tuple for the Gradio outputs. Per-request
        failures do not raise: the input image is returned with an error message.
    """
    if image is None:
        return None, "❌ Please upload an image."
    prompt = (text or "").strip()
    if not prompt:
        return image, "❌ Please enter a text prompt."
    try:
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
        # The model may be float16 (CUDA). Cast float32 tensors to match and
        # leave tensors of other dtypes (e.g. integer ids/sizes) untouched.
        for key in list(inputs.keys()):
            value = inputs[key]
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                inputs[key] = value.to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)
            results = processor.post_process_instance_segmentation(
                outputs,
                threshold=threshold,
                mask_threshold=mask_threshold,
                target_sizes=inputs["original_sizes"].tolist(),
            )[0]

        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{prompt}' (try adjusting thresholds or changing prompt)."

        overlaid_image = overlay_masks(image, results["masks"])
        # Sort explicitly so the five shown really are the top scores; nothing
        # here guarantees the post-processor returns them in score order.
        top_scores = sorted(results["scores"].float().cpu().tolist(), reverse=True)[:5]
        scores_text = ", ".join(f"{s:.2f}" for s in top_scores)
        more = "..." if n_masks > 5 else ""
        info = f"✅ Found **{n_masks}** objects matching **'{prompt}'**\nConfidence scores: {scores_text}{more}"
        return overlaid_image, info
    except Exception as e:
        # UI boundary: surface the error in the results panel rather than
        # failing the request, but keep the full traceback in the server logs.
        import traceback

        traceback.print_exc()
        return image, f"❌ Error during segmentation: {str(e)}"
def clear_all():
    """
    Reset the UI to its initial state.

    Returns, in order: the input image, the prompt text, the output image,
    the detection-threshold slider value and the mask-threshold slider value.
    """
    default_threshold = 0.5
    no_image = None
    return (no_image, "", no_image, default_threshold, default_threshold)
def segment_example(
    image_path: "str | Image.Image | None",
    prompt: str,
    threshold: float = 0.5,
    mask_threshold: float = 0.5,
):
    """
    Handle example clicks (and example caching) by delegating to ``segment``.

    Args:
        image_path: The example image. This can be a file path, or an
            already-decoded ``PIL.Image.Image``. The latter is presumably what
            a ``type="pil"`` Image input hands to ``gr.Examples`` functions
            after preprocessing (TODO confirm for the Gradio version in use).
            Falsy values mean "no image".
        prompt: Text prompt forwarded to ``segment``.
        threshold: Detection threshold forwarded to ``segment`` (default 0.5,
            as before).
        mask_threshold: Mask threshold forwarded to ``segment`` (default 0.5,
            as before).

    Returns:
        Whatever ``segment`` returns: an ``(image, markdown_message)`` tuple.
    """
    if isinstance(image_path, Image.Image):
        image = image_path
    elif image_path:
        # Decode eagerly and close the file rather than keeping a lazy handle open.
        with Image.open(image_path) as opened:
            image = opened.copy()
    else:
        image = None
    return segment(image, prompt, threshold, mask_threshold)
# Gradio Interface
# Initial text of the results panel; the Clear handler restores it.
# NOTE(review): the emoji in this string and in the Clear button label appear
# mis-decoded (mojibake) in this copy of the file. The intended glyphs could not
# be recovered with certainty, so the bytes are left unchanged.
INITIAL_INFO = "πŸ“ Enter a prompt and click **Segment** to start."

with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css="""
.gradio-container {max-width: 1400px !important;}
""",
) as demo:
    gr.Markdown(
        """
# SAM3 - Promptable Concept Segmentation (PCS)
**SAM3** performs zero-shot instance segmentation using natural language prompts on images.
Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.
Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
"""
    )

    gr.Markdown("### Inputs")
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False,
        )

    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3,
        )
        clear_btn = gr.Button("πŸ” Clear", size="sm", variant="secondary")

    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)",
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks",
        )

    info_output = gr.Markdown(
        value=INITIAL_INFO,
        label="Info / Results",
    )
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Add some example prompts
    gr.Examples(
        examples=[
            ["examples/person.jpg", "person"],
            ["examples/car.jpg", "car"],
            ["examples/dog.jpg", "dog"],
            ["examples/building.jpg", "building"],
        ],
        inputs=[image_input, text_input],
        outputs=[image_output, info_output],
        fn=segment_example,
        cache_examples=True,
    )

    # Clear button: reset inputs, output and sliders via clear_all(), and also
    # restore the results panel so stale results are not left on screen.
    clear_btn.click(
        fn=lambda: (*clear_all(), INITIAL_INFO),
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output],
    )

    # Run segmentation on button click, or on Enter in the prompt box.
    segment_inputs = [image_input, text_input, thresh_slider, mask_thresh_slider]
    segment_outputs = [image_output, info_output]
    segment_btn.click(fn=segment, inputs=segment_inputs, outputs=segment_outputs)
    text_input.submit(fn=segment, inputs=segment_inputs, outputs=segment_outputs)

    gr.Markdown(
        """
### Notes
- **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
- Supports natural language prompts like "a red car" or simple nouns.
- GPU recommended for faster inference.
- Thresholds control detection sensitivity and mask quality.
"""
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link,
    # with Gradio's debug mode enabled.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )
    demo.launch(**launch_options)