import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
from transformers import Sam3Processor, Sam3Model
import warnings

warnings.filterwarnings("ignore")

# Global model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained(
    "facebook/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
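
# The weights load once at import time and are shared across all requests; on
# CUDA they are kept in float16, roughly halving GPU memory use versus float32.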


def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    """Overlay segmentation masks on the input image using a rainbow colormap."""
    image = image.convert("RGBA")
    # Binary masks -> uint8 arrays of 0/255, usable as PIL alpha sources
    masks = 255 * masks.cpu().numpy().astype(np.uint8)
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")
    # One distinct color per mask, sampled evenly from the rainbow colormap
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]
    for mask, color in zip(masks, colors):
        mask_img = Image.fromarray(mask)
        # Solid-color layer: 50% opaque where the mask is active, transparent elsewhere
        overlay = Image.new("RGBA", image.size, color + (0,))
        alpha = mask_img.point(lambda v: int(v * 0.5))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)
    return image
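

# This Space targets ZeroGPU ("Running on Zero"), where GPU work must run
# inside a function decorated with @spaces.GPU. The decorator below is inferred
# from the runtime tag and the otherwise-unused `spaces` import above.
@spaces.GPU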
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """Perform promptable concept segmentation using SAM3."""
    if image is None:
        return None, "⚠️ Please upload an image."
    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
        # Cast floating-point inputs (pixel values) to the model dtype
        # (float16 on GPU); integer token ids are left untouched
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist(),
        )[0]
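        # `results` is the per-image dict returned by the post-processing step;
        # this app uses its "masks" (one binary mask per detected instance) and
        # "scores" entries below.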
        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"⚠️ No objects found matching '{text}' (try adjusting the thresholds or rephrasing the prompt)."
        overlaid_image = overlay_masks(image, results["masks"])
        scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:5])  # top 5 scores
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
        return overlaid_image, info
    except Exception as e:
        return image, f"⚠️ Error during segmentation: {e}"


def clear_all():
    """Clear all inputs and outputs, including the info panel."""
    return None, "", None, 0.5, 0.5, "💡 Enter a prompt and click **Segment** to start."


def segment_example(image_path: str, prompt: str):
    """Handle example clicks using the default thresholds."""
    image = Image.open(image_path) if image_path else None
    return segment(image, prompt, 0.5, 0.5)


# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1400px !important;}",
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)

        **SAM3** performs zero-shot instance segmentation on images from natural language prompts.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.

        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
| gr.Markdown("### Inputs") | |
| with gr.Row(variant="panel"): | |
| image_input = gr.Image( | |
| label="Input Image", | |
| type="pil", | |
| height=400, | |
| ) | |
| image_output = gr.Image( | |
| label="Output (Segmented Image)", | |
| height=400, | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| text_input = gr.Textbox( | |
| label="Text Prompt", | |
| placeholder="e.g., a person, ear, cat, bicycle...", | |
| scale=3 | |
| ) | |
| clear_btn = gr.Button("π Clear", size="sm", variant="secondary") | |
    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)",
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks",
        )
    info_output = gr.Markdown(
        value="💡 Enter a prompt and click **Segment** to start.",
        label="Info / Results",
    )
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

    # Example prompts
    gr.Examples(
        examples=[
            ["examples/person.jpg", "person"],
            ["examples/car.jpg", "car"],
            ["examples/dog.jpg", "dog"],
            ["examples/building.jpg", "building"],
        ],
        inputs=[image_input, text_input],
        outputs=[image_output, info_output],
        fn=segment_example,
        cache_examples=True,
    )
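    # With cache_examples=True, Gradio runs segment_example on every example at
    # startup and caches the outputs, so example clicks return instantly; the
    # files under examples/ must exist in the repo for caching to succeed.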

    # Clear button handler (also resets the info panel)
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output],
    )

    # Segment button handler
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output],
    )

    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Supports natural language prompts like "a red car" or simple nouns.
        - GPU recommended for faster inference.
        - Thresholds control detection sensitivity and mask quality.
        """
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
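
# Minimal client-side sketch (not part of the app): querying the running demo
# through the Gradio API with gradio_client. The "/segment" endpoint name is an
# assumption based on the handler's function name; verify with client.view_api().
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   overlaid, info = client.predict(
#       handle_file("examples/dog.jpg"),  # input image
#       "dog",                            # text prompt
#       0.5,                              # detection threshold
#       0.5,                              # mask threshold
#       api_name="/segment",
#   )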