"""SAM3 promptable concept segmentation demo (Gradio app).

Loads the facebook/sam3 model and processor once at import time and serves
a Gradio Blocks UI that segments every instance matching a free-form text
prompt (zero-shot, natural-language instance segmentation).
"""

import warnings

import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

warnings.filterwarnings("ignore")

# Global model and processor — loaded once; fp16 on GPU, fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained(
    "facebook/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")


def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    """Overlay segmentation masks on the input image using a rainbow colormap.

    Args:
        image: Source image; converted to RGBA for alpha compositing.
        masks: Binary instance masks, shape (n_masks, H, W) — assumed to match
            the image size (TODO confirm against post-processing output).

    Returns:
        Composited image with each mask blended at 50% opacity in a distinct
        color (returned as RGB when there are no masks).
    """
    image = image.convert("RGBA")
    masks = 255 * masks.cpu().numpy().astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")

    # One distinct rainbow color per detected instance.
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]

    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        # Fully transparent base; alpha comes from the mask (255 -> ~50% opacity).
        overlay = Image.new("RGBA", image.size, color + (0,))
        alpha = mask_img.point(lambda v: int(v * 0.5))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)
    return image


# FIX: original had a bare `spaces.GPU()` expression statement — the `@` was
# missing, so the ZeroGPU decorator was never applied to `segment`.
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """Perform promptable concept segmentation using SAM3.

    Args:
        image: Input PIL image, or None if nothing was uploaded.
        text: Natural-language concept prompt (e.g. "person", "a red car").
        threshold: Detection (objectness) confidence threshold.
        mask_threshold: Per-pixel mask binarization threshold.

    Returns:
        Tuple of (output image or None, markdown status string for the UI).
    """
    if image is None:
        return None, "❌ Please upload an image."
    if not text or not text.strip():
        # Guard added: an empty prompt previously surfaced as a raw model error.
        return image, "❌ Please enter a text prompt."
    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]

        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{text}' (try adjusting thresholds or changing prompt)."

        overlaid_image = overlay_masks(image, results["masks"])
        scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:5])  # Top 5 scores
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
        return overlaid_image, info
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return image, f"❌ Error during segmentation: {str(e)}"


def clear_all():
    """Clear all inputs and outputs (resets sliders to their 0.5 defaults)."""
    return None, "", None, 0.5, 0.5


def segment_example(image_path: str, prompt: str):
    """Handle example clicks: load the example image and run segmentation
    with default thresholds."""
    image = Image.open(image_path) if image_path else None
    return segment(image, prompt, 0.5, 0.5)


# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css="""
    .gradio-container {max-width: 1400px !important;}
    """,
) as demo:
    gr.Markdown(
        """
    # SAM3 - Promptable Concept Segmentation (PCS)

    **SAM3** performs zero-shot instance segmentation using natural language prompts on images.
    Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.

    Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
    """
    )

    gr.Markdown("### Inputs")
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False,
        )

    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, ear, cat, bicycle...",
            scale=3,
        )
        clear_btn = gr.Button("🔍 Clear", size="sm", variant="secondary")

    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)",
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks",
        )

    info_output = gr.Markdown(
        value="📝 Enter a prompt and click **Segment** to start.",
        label="Info / Results",
    )
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Clear button handler
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider],
    )

    # Segment button handler.  (A dead `.then(fn=lambda: None)` no-op that was
    # chained here in the original has been removed.)
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
    )

    gr.Markdown(
        """
    ### Notes
    - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
    - Supports natural language prompts like "a red car" or simple nouns.
    - GPU recommended for faster inference.
    - Thresholds control detection sensitivity and mask quality.
    """
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)