File size: 2,291 Bytes
0564522
d8dc92a
0564522
 
 
 
d8dc92a
 
 
0564522
d8dc92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfb3ec
 
 
d8dc92a
 
 
0564522
d8dc92a
 
8aa19f1
d8dc92a
 
 
0564522
 
8aa19f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import torch
from transformers import pipeline
import numpy as np
from PIL import Image

# Load both models once at import time (Base variants keep inference fast and
# stable on CPU). device=-1 pins the SAM pipeline to CPU explicitly; the
# CLIPSeg pipeline gets no device argument, so it falls back to the
# transformers default device selection — presumably CPU here, verify on GPU hosts.
sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
text_pipe = pipeline("image-segmentation", model="CIDAS/clipseg-rd64-refined")

def segment_logic(input_img, mode, text_query):
    """Segment *input_img* according to the selected *mode*.

    Args:
        input_img: PIL image from the Gradio input component, or ``None``
            when the user has not uploaded anything yet.
        mode: the radio-button choice; only the SAM ("Automatic") and
            CLIPSeg ("Text Prompt") modes are implemented.
        text_query: object name used by the "Text Prompt" mode.

    Returns:
        A PIL image with the segmentation overlay blended at 50% opacity,
        the unmodified input when there is nothing to segment, or ``None``
        when no image was provided.
    """
    # Guard: nothing uploaded yet -> nothing to segment (previously this
    # crashed on input_img.size with an AttributeError).
    if input_img is None:
        return None

    if mode == "Automatic (Segment Everything)":
        # SAM "segment everything"; points_per_side kept low for CPU speed.
        outputs = sam_pipe(input_img, points_per_side=10)
        masks = outputs["masks"]
        # Paint each boolean mask a random color onto a black canvas.
        # PIL .size is (width, height); numpy wants (rows=H, cols=W).
        overlay = np.zeros((input_img.size[1], input_img.size[0], 3), dtype=np.uint8)
        for mask in masks:
            # randint's upper bound is exclusive: 256 makes 255 reachable
            # (the original 255 silently capped channel values at 254).
            color = np.random.randint(0, 256, (3,))
            overlay[mask] = color
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    if mode == "Text Prompt":
        # CLIPSeg: open-vocabulary segmentation driven by a text description.
        if not text_query:
            return input_img
        result = text_pipe(input_img, prompt=text_query)
        # CLIPSeg returns a grayscale confidence mask; threshold it and
        # colorize the matching region red.
        mask = np.array(result["mask"])
        overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        overlay[mask > 100] = [255, 0, 0]  # red where the text matches
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    # Fallback for modes without an implementation (e.g. "Point Click"):
    # the original fell off the end and returned None, which makes the
    # Gradio output component error out. Echo the input image instead.
    return input_img

# Build the UI
# Build the UI: image in, mode selector (+ optional text prompt), image out.
with gr.Blocks() as demo:
    gr.Markdown("# SAM + Text Segmentation")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil")
            # NOTE(review): a "Point Click" mode used to be offered here,
            # but segment_logic has no branch for it, so selecting it
            # produced no output. Dropped until it is actually implemented.
            mode_select = gr.Radio(
                ["Automatic (Segment Everything)", "Text Prompt"],
                label="Select Mode",
                value="Automatic (Segment Everything)",
            )
            # Only meaningful in "Text Prompt" mode; hidden otherwise.
            text_box = gr.Textbox(label="Enter Object Name", visible=False)
        with gr.Column():
            img_out = gr.Image(type="pil")

    # Show/hide the textbox depending on the selected mode.
    mode_select.change(
        lambda m: gr.update(visible=(m == "Text Prompt")),
        mode_select,
        text_box,
    )

    btn = gr.Button("Run Segmentation")
    btn.click(segment_logic, inputs=[img_in, mode_select, text_box], outputs=img_out)


# Script entry point: queue() enables request queuing (helpful because the
# CPU-bound models are slow), and server_name="0.0.0.0" binds all network
# interfaces so the app is reachable from outside a container on port 7860.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)