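"""SAM3 promptable concept segmentation demo.

A Gradio app that takes an image and a free-form text prompt and overlays
instance masks for every matching object, using facebook/sam3 via transformers.
"""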
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
from transformers import Sam3Processor, Sam3Model
import warnings
warnings.filterwarnings("ignore")

# Global model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
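# Load once at import time and share across requests; fp16 halves GPU memory,
# while CPU inference stays in fp32, where float16 support is limited.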
model = Sam3Model.from_pretrained(
    "facebook/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

def overlay_masks(image: Image.Image, masks: torch.Tensor) -> Image.Image:
    """
    Overlay segmentation masks on the input image using rainbow colormap.
    """
    image = image.convert("RGBA")
    masks = 255 * masks.cpu().numpy().astype(np.uint8)
    
    n_masks = masks.shape[0]
    if n_masks == 0:
        return image.convert("RGB")
    
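    # Sample one evenly spaced rainbow color (RGB 0-255) per instance.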
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [
        tuple(int(c * 255) for c in cmap(i)[:3])
        for i in range(n_masks)
    ]

    for mask, color in zip(masks, colors):
        # Use the mask itself, scaled to 50% opacity, as the overlay's alpha
        # channel, so each instance is tinted with its own color.
        mask_img = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        alpha = mask_img.point(lambda v: int(v * 0.5))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)
    return image.convert("RGB")

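# On ZeroGPU Spaces, @spaces.GPU() allocates a GPU for the duration of each call.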
@spaces.GPU()
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Perform promptable concept segmentation using SAM3.
    """
    if image is None:
        return None, "❌ Please upload an image."
    if not text or not text.strip():
        return image, "❌ Please enter a text prompt."
    
    try:
        # Preprocess the image and tokenize the text prompt
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
        
        # Cast floating-point tensors (e.g., pixel_values) to the model dtype;
        # integer tensors such as input_ids must stay as they are.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                inputs[key] = value.to(model.dtype)
        
        with torch.no_grad():
            outputs = model(**inputs)
        
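        # Keep detections above `threshold`, binarize mask logits at
        # `mask_threshold`, and resize masks to the original image size;
        # [0] selects the single image in the batch.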
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs["original_sizes"].tolist()
        )[0]
        
        n_masks = len(results["masks"])
        if n_masks == 0:
            return image, f"❌ No objects found matching '{text}' (try adjusting the thresholds or changing the prompt)."
        
        overlaid_image = overlay_masks(image, results["masks"])
        
        scores_text = ", ".join(f"{s:.2f}" for s in results["scores"].cpu().numpy()[:5])  # top 5 scores
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"
        
        return overlaid_image, info
        
    except Exception as e:
        return image, f"❌ Error during segmentation: {str(e)}"

def clear_all():
    """Reset inputs, outputs, sliders, and the info panel to their defaults."""
    return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start."

def segment_example(image_path: str, prompt: str):
    """Handle example clicks"""
    image = Image.open(image_path) if image_path else None
    return segment(image, prompt, 0.5, 0.5)

# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css="""
    .gradio-container {max-width: 1400px !important;}
    """
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)
        
        **SAM3** performs zero-shot instance segmentation using natural language prompts on images.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks for all matching objects.
        
        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
    
    gr.Markdown("### Input & Output")
    with gr.Row(variant="panel"):
        image_input = gr.Image(
            label="Input Image",
            type="pil",
            height=400,
        )
        image_output = gr.Image(
            label="Output (Segmented Image)",
            height=400,
            interactive=False
        )
    
    with gr.Row():
        text_input = gr.Textbox(
            label="Text Prompt",
            placeholder="e.g., a person, car, cat, bicycle...",
            scale=3
        )
        clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")
    
    with gr.Row():
        thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Detection Threshold",
            info="Higher values = fewer detections (objectness confidence)"
        )
        mask_thresh_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.01,
            label="Mask Threshold",
            info="Higher values = sharper masks"
        )
    
    info_output = gr.Markdown(
        value="πŸ“ Enter a prompt and click **Segment** to start.",
        label="Info / Results"
    )
    
    segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")
    
    # Add some example prompts
    gr.Examples(
        examples=[
            ["examples/person.jpg", "person"],
            ["examples/car.jpg", "car"],
            ["examples/dog.jpg", "dog"],
            ["examples/building.jpg", "building"],
        ],
        inputs=[image_input, text_input],
        outputs=[image_output, info_output],
        fn=segment_example,
        cache_examples=True,
    )
    
    # Clear button handler
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output]
    )
    
    # Segment button handler
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output]
    )
    
    
    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Supports natural language prompts like "a red car" or simple nouns.
        - GPU recommended for faster inference.
        - Thresholds control detection sensitivity and mask quality.
        """
    )

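# Run locally with `python app.py`, then open http://localhost:7860.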
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)