# sam3/app.py — SAM3 promptable concept segmentation demo (Hugging Face Space).
# Requires: gradio, spaces, torch, transformers, Pillow, requests.
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import Sam3Processor, Sam3Model
import requests
import warnings
warnings.filterwarnings("ignore")
# Global model and processor, loaded once at import time (shared by all requests).
# Prefer CUDA when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU halves memory/bandwidth; fp32 on CPU where half precision is slow/unsupported.
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Run SAM3 promptable concept segmentation on *image* for the prompt *text*.

    Returns a pair suitable for the UI: a gr.AnnotatedImage value of the form
    ``(image, [(mask, label), ...])`` plus a Markdown status string.
    """
    # Guard clauses: nothing to do without an image or a non-empty prompt.
    if image is None:
        return None, "πŸ“· Please upload an image to begin."
    prompt = text.strip()
    if not prompt:
        return (image, []), "✏️ Enter a text prompt (e.g., 'person', 'cat', 'car')."

    try:
        batch = processor(images=image, text=prompt, return_tensors="pt").to(device)
        # Cast only float32 tensors to the model dtype (fp16 on GPU), leaving
        # integer tensors such as token ids untouched.
        for name in batch:
            if batch[name].dtype == torch.float32:
                batch[name] = batch[name].to(model.dtype)

        with torch.no_grad():
            outputs = model(**batch)

        detections = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=batch.get("original_sizes").tolist(),
        )[0]

        masks, scores = detections['masks'], detections['scores']
        found = len(masks)
        if found == 0:
            return (image, []), f"πŸ” No objects found for **'{text}'** β€” try adjusting thresholds."

        # One (float mask, "#i (score)") pair per detected instance.
        annotations = [
            (mask.cpu().numpy().astype(np.float32), f"#{idx + 1} ({score:.2f})")
            for idx, (mask, score) in enumerate(zip(masks, scores))
        ]
        # Show at most the first five confidences in the status line.
        scores_text = ", ".join(f"{s:.2f}" for s in scores.cpu().numpy()[:5])
        info = f"✨ **{found}** objects found for **'{text}'**\n\nConfidence: {scores_text}{'...' if found > 5 else ''}"
        return (image, annotations), info
    except Exception as e:
        # Surface any failure to the UI instead of crashing the app.
        return (image, []), f"❌ Error: {str(e)}"
def clear_all():
    """Reset every input/output widget to its initial state.

    Order matches the clear button's outputs list:
    image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output.
    """
    hint = "πŸ’‘ Enter a prompt and click **Segment** to find objects."
    return (None, "", None, 0.5, 0.5, hint)
def segment_example(image_path: str, prompt: str):
    """Load an example image (URL or local path) and segment it with default thresholds.

    Used by gr.Examples; returns whatever `segment` returns:
    (AnnotatedImage value, status markdown).
    """
    if image_path.startswith("http"):
        # timeout prevents the demo from hanging indefinitely on a dead URL;
        # raise_for_status avoids feeding an HTML error page to Image.open.
        response = requests.get(image_path, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")
    else:
        image = Image.open(image_path).convert("RGB")
    return segment(image, prompt, 0.5, 0.5)
# Custom modern theme: glass look with a neutral slate/zinc palette and Inter font.
# NOTE(review): custom_theme is passed to demo.launch() at the bottom of this file;
# confirm the installed Gradio version accepts a theme there (classically it is a
# gr.Blocks(theme=...) constructor argument).
custom_theme = gr.themes.Glass(
    primary_hue="slate",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="md",
    spacing_size="lg",
    radius_size="md"
).set(
    # Dark primary buttons, light secondary buttons, white blocks.
    button_primary_background_fill="*neutral_800",
    button_primary_background_fill_hover="*neutral_700",
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
    block_background_fill="white",
    block_secondary_background_fill="*neutral_50",
    block_title_text_weight="600",
)
# Main application UI.
with gr.Blocks() as demo:
    # Header banner with title and attribution link.
    gr.HTML("""
        <div style="text-align: center; padding: 0.5rem 0; margin-bottom: 0.5rem;">
            <h1 style="font-size: 1.75rem; font-weight: 700; margin: 0; color: var(--neutral-800);">
                SAM3 <span style="font-weight: 400; color: var(--neutral-500);">Promptable Segmentation</span>
            </h1>
            <p style="margin: 0.25rem 0 0 0; color: var(--neutral-600); font-size: 0.875rem;">
                Zero-shot instance segmentation with natural language
            </p>
            <a href="https://huggingface.co/spaces/akhaliq/anycoder"
               style="color: var(--primary-600); text-decoration: none; font-size: 0.8rem;"
               target="_blank">Built with anycoder</a>
        </div>
    """)

    # Main content
    with gr.Column(elem_classes=["main-content"]):
        # Image section: upload on the left, annotated result on the right.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=280):
                image_input = gr.Image(
                    label="πŸ“· Upload Image",
                    type="pil",
                    height=320,
                    sources=["upload", "clipboard"],
                )
            with gr.Column(scale=1, min_width=280):
                image_output = gr.AnnotatedImage(
                    label="🎯 Segmentation Result",
                    height=320,
                    show_legend=True,
                )

        # Status / info line shown under the images.
        info_output = gr.Markdown(
            value="πŸ’‘ **Upload an image** and enter a prompt like 'person', 'cat', or 'car'",
            elem_classes=["info-box"],
        )

        # Controls section: prompt, segment button, threshold sliders, clear button.
        with gr.Group(elem_classes=["controls"]):
            with gr.Row():
                text_input = gr.Textbox(
                    label="What to find",
                    placeholder="e.g., person, cat, bicycle...",
                    scale=4,
                )
                segment_btn = gr.Button(
                    "πŸ” Segment",
                    variant="primary",
                    size="lg",
                    scale=1,
                    min_width=120,
                )
            with gr.Row():
                thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection",
                    info="Confidence threshold",
                    scale=1,
                )
                mask_thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Mask",
                    info="Edge sharpness",
                    scale=1,
                )
                clear_btn = gr.Button(
                    "β†Ί Clear",
                    variant="secondary",
                    size="lg",
                    scale=0,
                    min_width=80,
                )

        # Examples (not cached: segmentation needs the model/GPU at click time).
        gr.Markdown("### Quick Examples")
        gr.Examples(
            examples=[
                ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
                ["https://images.unsplash.com/photo-1535930483905-2c6d14342d7a", "dog"],
                ["https://images.unsplash.com/photo-1558618666-fcd25c85cd64", "car"],
            ],
            inputs=[image_input, text_input],
            outputs=[image_output, info_output],
            fn=segment_example,
            cache_examples=False,
            examples_per_page=3,
        )

        # Footer info.
        # BUG FIX: the original line was `gr.Accordion("ℹ️ About", open=False):`
        # (missing the `with` keyword), which is a SyntaxError.
        with gr.Accordion("ℹ️ About", open=False):
            gr.Markdown("""
            **SAM3** uses natural language prompts for zero-shot instance segmentation.
            - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
            - GPU recommended for faster processing
            - Works best with specific, clear object names
            """)

    # Event handlers.
    # NOTE(review): `api_visibility` is a newer Gradio event parameter — confirm
    # the pinned Gradio version supports it (older releases use api_name/show_api).
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
        api_visibility="public",
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output],
        api_visibility="private",
    )
# Custom CSS for responsive design
custom_css = """
@media (max-width: 768px) {
.main-content {
gap: 0.75rem !important;
}
.controls {
gap: 0.75rem !important;
}
.info-box {
font-size: 0.875rem !important;
padding: 0.75rem !important;
}
}
@media (max-width: 480px) {
.gradio-group {
gap: 0.5rem !important;
}
}
.info-box {
background: var(--neutral-50);
border-radius: var(--radius-lg);
padding: 1rem;
border: 1px solid var(--neutral-200);
}
.controls {
background: var(--neutral-50);
border-radius: var(--radius-lg);
padding: 1.25rem;
border: 1px solid var(--neutral-200);
}
.gradio-annotatedimage {
border: 2px dashed var(--neutral-300);
border-radius: var(--radius-lg);
}
.gradio-group {
gap: 1rem !important;
}
"""
# Script entry point: serve the demo on all interfaces, port 7860 (HF Spaces default).
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        # NOTE(review): `theme`, `css`, and `footer_links` are not launch()
        # parameters in classic Gradio releases — theme/css belong on
        # gr.Blocks(theme=..., css=...). Confirm the pinned Gradio version
        # accepts them here, otherwise launch() raises TypeError.
        theme=custom_theme,
        css=custom_css,
        footer_links=[
            {"label": "anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
            {"label": "Model", "url": "https://huggingface.co/facebook/sam3"},
        ],
    )