sam3 / app.py
Quyetnguyen's picture
🎨 Redesign from AnyCoder
9bc11a4 verified
raw
history blame
10.6 kB
gradio
import traceback
import warnings

# Imported before torch, as Hugging Face ZeroGPU (`spaces`) requires.
import spaces
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import Sam3Processor, Sam3Model
# Suppress all Python warnings process-wide.
warnings.filterwarnings("ignore")

# Global model and processor: loaded once at import time and shared by every
# call to `segment`.
_use_cuda = torch.cuda.is_available()
device = "cuda" if _use_cuda else "cpu"
# float16 weights on CUDA, float32 on CPU.
_model_dtype = torch.float16 if _use_cuda else torch.float32
model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=_model_dtype).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.

    Returns format compatible with gr.AnnotatedImage: (image, [(mask, label), ...])

    Args:
        image: Input PIL image, or None when nothing has been uploaded yet.
        text: Free-text concept prompt (e.g. "cat"). Surrounding whitespace is
            ignored; an empty or None prompt yields a hint instead of a run.
        threshold: Detection threshold forwarded to
            ``processor.post_process_instance_segmentation``.
        mask_threshold: Mask threshold forwarded to the same post-processing call.

    Returns:
        ``(annotated, info)``. ``annotated`` is None when there is no image,
        otherwise ``(image, annotations)``, where ``annotations`` is a list of
        ``(float32 mask array, label)`` pairs (possibly empty). ``info`` is a
        Markdown status message. Exceptions during inference are caught, and
        their message is shown in ``info``.
    """
    if image is None:
        return None, "📷 Please upload an image to begin."

    # Normalize once so the model input, labels and messages all agree;
    # tolerate None in addition to the empty string.
    prompt = (text or "").strip()
    if not prompt:
        return (image, []), "✏️ Enter a text prompt to find objects."

    try:
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
        # The model may be loaded in float16 (see module setup); cast float32
        # tensors to the model dtype so the forward pass sees matching dtypes.
        # Non-tensor or non-float32 entries are left untouched.
        for key in list(inputs.keys()):
            value = inputs[key]
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                inputs[key] = value.to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        masks = results["masks"]
        n_masks = len(masks)
        if n_masks == 0:
            return (image, []), f"🔍 No objects found for \"{prompt}\". Try adjusting the thresholds."

        # Format for AnnotatedImage: list of (mask, label) tuples, numbered from 1.
        annotations = [
            (mask.cpu().numpy().astype(np.float32), f"{prompt} #{index}")
            for index, mask in enumerate(masks, start=1)
        ]

        # Only the first three scores are shown to keep the status line short.
        scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:3])
        info = f"**{n_masks}** object(s) found for \"{prompt}\" | Scores: {scores_text}"
        return (image, annotations), info
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback in
        # the server log so failures remain debuggable.
        traceback.print_exc()
        return (image, []), f"⚠️ Segmentation error: {str(e)}"
def clear_all():
    """Reset every input and output to its initial state.

    The returned values line up, in order, with the `clear_btn.click` outputs:
    input image, prompt text, result image, detection slider, mask slider,
    and the info panel.
    """
    default_threshold = 0.5
    start_hint = "✏️ Enter a prompt and click **Segment** to start."
    return (
        None,               # input image
        "",                 # prompt text
        None,               # result image
        default_threshold,  # detection threshold slider
        default_threshold,  # mask threshold slider
        start_hint,         # info panel
    )
def segment_example(
    image_path: "str | Image.Image",
    prompt: str,
    threshold: float = 0.5,
    mask_threshold: float = 0.5,
):
    """Handle example clicks: load the example image and run ``segment`` on it.

    Args:
        image_path: An ``http://``/``https://`` URL, a local file path, or an
            already-loaded PIL image (for example, if Gradio passes the
            preprocessed value of a ``type="pil"`` image component).
        prompt: Text prompt forwarded to ``segment``.
        threshold: Detection threshold forwarded to ``segment`` (default 0.5).
        mask_threshold: Mask threshold forwarded to ``segment`` (default 0.5).

    Returns:
        Whatever ``segment`` returns: ``(annotated_image, info_markdown)``.

    Raises:
        requests.HTTPError / requests.RequestException: if a URL cannot be fetched.
        OSError: if a local file cannot be opened as an image.
    """
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    elif image_path.startswith(("http://", "https://")):
        # Bound the wait, fail loudly on HTTP errors instead of handing an
        # error page to PIL, and close the connection when done.
        with requests.get(image_path, stream=True, timeout=30) as response:
            response.raise_for_status()
            # .raw is the undecoded byte stream; ask urllib3 to undo any
            # gzip/deflate content encoding before PIL reads it.
            response.raw.decode_content = True
            # convert() forces a full read while the stream is still open.
            image = Image.open(response.raw).convert("RGB")
    else:
        # Context manager releases the file handle once the pixels are copied.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
    return segment(image, prompt, threshold, mask_threshold)
# Custom CSS for mobile-first responsive design, passed to demo.launch(css=...).
# The selectors correspond to elem_classes set on components in the Blocks
# layout (.app-layout, .control-panel, .slider-group, .image-container,
# .info-panel, .primary-btn, .slider-label, .example-grid) and to the header
# HTML (.main-header). Below 769px the layout stacks; from 769px up,
# .app-layout becomes a two-column grid with a sticky 380px control panel.
# NOTE(review): no element in the visible layout carries the .example-item
# class — confirm whether those rules are still used.
custom_css = """
/* Mobile-first responsive styles */
@media (max-width: 768px) {
.main-header { text-align: center; padding: 1rem !important; }
.control-panel { padding: 0.75rem !important; }
.slider-group { flex-direction: column; gap: 0.5rem !important; }
.example-grid { grid-template-columns: repeat(2, 1fr) !important; }
}
@media (min-width: 769px) {
.app-layout {
display: grid !important;
grid-template-columns: 1fr 380px !important;
gap: 1.5rem !important;
}
.control-panel {
position: sticky !important;
top: 1rem !important;
height: fit-content !important;
}
}
/* Smooth transitions */
.gradio-container { transition: all 0.3s ease !important; }
/* Modern slider styling */
.slider-label { font-weight: 500 !important; color: var(--body-text-color) !important; }
/* Card-like panels */
.control-panel {
background: var(--background-fill-secondary);
border-radius: var(--radius-lg);
padding: 1.25rem;
border: 1px solid var(--border-color-primary);
}
/* Button improvements */
.primary-btn {
font-weight: 600 !important;
letter-spacing: 0.02em !important;
}
/* Image container */
.image-container {
border-radius: var(--radius-lg);
overflow: hidden;
border: 1px solid var(--border-color-primary);
}
/* Info panel */
.info-panel {
background: var(--background-fill-primary);
border-radius: var(--radius-md);
padding: 1rem;
border-left: 3px solid var(--color-accent);
}
/* Example items */
.example-item {
cursor: pointer !important;
transition: transform 0.2s ease, box-shadow 0.2s ease !important;
}
.example-item:hover {
transform: translateY(-2px);
box-shadow: var(--shadow-drop-lg);
}
"""
# Gradio 6 - NO parameters in Blocks constructor!
# Theme, CSS and page title are supplied to demo.launch() instead.
with gr.Blocks() as demo:
    # Header
    gr.HTML("""
    <div class="main-header" style="text-align: center; padding: 1.5rem; background: var(--background-fill-primary); border-bottom: 1px solid var(--border-color-primary); margin-bottom: 1rem;">
        <h1 style="margin: 0; font-size: 1.75rem; font-weight: 700;">SAM3</h1>
        <p style="margin: 0.5rem 0 0; opacity: 0.8; font-size: 0.95rem;">Promptable Concept Segmentation</p>
        <div style="margin-top: 0.75rem;">
            <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: var(--color-accent); text-decoration: none; font-size: 0.85rem;">Built with anycoder ↗</a>
        </div>
    </div>
    """)

    # Two-column layout; on wide screens the CSS turns .app-layout into a grid.
    with gr.Row(elem_classes=["app-layout"]):
        # Left: Image section
        with gr.Column(scale=2):
            with gr.Group(elem_classes=["image-container"]):
                gr.Markdown("**📷 Image**", elem_classes="slider-label")
                image_input = gr.Image(
                    type="pil",  # segment() receives a PIL.Image
                    sources=["upload", "clipboard"],
                    height=400,
                    elem_id="input-image",
                )
            with gr.Group(elem_classes=["image-container", "mt-4"]):
                gr.Markdown("**🎯 Segmentation Result**", elem_classes="slider-label")
                # Shows the (image, [(mask, label), ...]) pairs returned by segment().
                image_output = gr.AnnotatedImage(
                    height=400,
                    show_legend=True,
                    elem_id="output-image",
                )

        # Right: Control panel
        with gr.Column(scale=1, elem_classes=["control-panel"]):
            gr.Markdown("### ⚙️ Settings", elem_classes="slider-label")

            # Text prompt
            text_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., person, cat, car, cup...",
                lines=2,
                autoscroll=False,
            )

            # Sliders in a row (stacked vertically on mobile via .slider-group)
            with gr.Row(elem_classes=["slider-group"]):
                thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection",
                    info="Higher = fewer",
                )
                mask_thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Mask",
                    info="Higher = sharper",
                )

            # Buttons
            with gr.Row():
                segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg", scale=2, elem_classes="primary-btn")
                clear_btn = gr.Button("Clear", variant="secondary", size="lg", scale=1)

            # Info output (Markdown status line written by segment()/clear_all())
            info_output = gr.Markdown(
                value="✏️ Enter a prompt and click **Segment** to start.",
                elem_classes="info-panel",
            )

            # Examples
            gr.Markdown("### 💡 Examples", elem_classes="slider-label")
            # NOTE(review): with cache_examples=False, Gradio only invokes `fn`
            # on click when run_on_click=True; as configured, clicking an
            # example fills the image and prompt, and the user still presses
            # Segment. Confirm whether that is intended.
            gr.Examples(
                examples=[
                    ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
                    ["http://images.cocodataset.org/val2017/000000039769.jpg", "remote"],
                    ["http://images.cocodataset.org/val2017/000000000285.jpg", "person"],
                    ["http://images.cocodataset.org/val2017/000000003899.jpg", "dog"],
                ],
                inputs=[image_input, text_input],
                outputs=[image_output, info_output],
                fn=segment_example,
                cache_examples=False,
                examples_per_page=4,
                elem_classes="example-grid",
            )

    # Model info at bottom
    gr.HTML("""
    <div style="text-align: center; padding: 1rem; opacity: 0.7; font-size: 0.85rem; border-top: 1px solid var(--border-color-primary); margin-top: 1rem;">
        Model: <a href="https://huggingface.co/facebook/sam3" target="_blank" style="color: var(--color-accent);">facebook/sam3</a>
        • Zero-shot segmentation with natural language prompts
    </div>
    """)

    # Event handlers (must be registered inside the Blocks context)
    # Output order matches the tuple returned by clear_all().
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output],
    )
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
    )
    # Also trigger on Enter key
    text_input.submit(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
    )
# Gradio 6 - ALL app parameters go in launch()!
# Soft theme: indigo primary, slate secondary, gray neutral, Inter font,
# with explicit button and block colour overrides.
_app_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("Inter"),
    text_size="md",
    spacing_size="md",
    radius_size="lg",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
    block_background_fill="*neutral_50",
    block_label_background_fill="*neutral_100",
)

# Links rendered in the app footer.
_footer_links = [
    {"label": "anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
    {"label": "Model", "url": "https://huggingface.co/facebook/sam3"},
]

# Serve on all interfaces at port 7860, without a public share link, with debug on.
demo.launch(
    theme=_app_theme,
    css=custom_css,
    css_paths=None,
    js=None,
    head=None,
    title="SAM3 - Promptable Concept Segmentation",
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    debug=True,
    footer_links=_footer_links,
)