# Provenance (scraped page header, kept as comments so the file parses):
# DmitryRu777's picture
# Added pointing, it was forgotten
# 1dfb3ec
import gradio as gr
import torch
from transformers import pipeline
import numpy as np
from PIL import Image
# Load both models once at startup (base-size checkpoints keep CPU inference
# fast/stable). device=-1 pins the pipeline to CPU; previously only the SAM
# pipeline specified it — set it on both for consistency.
sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
# CLIPSeg maps a free-text prompt (e.g. "dog", "shirt") to a grayscale mask.
text_pipe = pipeline("image-segmentation", model="CIDAS/clipseg-rd64-refined", device=-1)
def segment_logic(input_img, mode, text_query):
    """Segment *input_img* according to the UI-selected *mode*.

    Args:
        input_img: PIL image from the Gradio input component, or ``None``
            when the user clicks Run before uploading an image.
        mode: one of the Radio choices — "Automatic (Segment Everything)",
            "Text Prompt", or "Point Click".
        text_query: object name, used only in "Text Prompt" mode.

    Returns:
        A PIL image (input blended with a colored overlay), the unmodified
        input when there is nothing to segment, or ``None`` when no image
        was supplied.
    """
    # Guard: no image uploaded yet — nothing to segment.
    if input_img is None:
        return None
    if mode == "Automatic (Segment Everything)":
        # SAM "segment everything"; points_per_side=10 keeps the prompt
        # grid small so CPU inference stays fast.
        outputs = sam_pipe(input_img, points_per_side=10)
        masks = outputs["masks"]
        # Blank canvas matching the image. PIL .size is (W, H); numpy
        # arrays are (H, W, C).
        overlay = np.zeros((input_img.size[1], input_img.size[0], 3), dtype=np.uint8)
        for mask in masks:
            # Random color per mask; later masks overwrite earlier overlaps.
            color = np.random.randint(0, 255, (3,))
            overlay[mask] = color
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)
    if mode == "Text Prompt":
        # CLIPSeg understands free text like "dog", "shirt", etc.
        if not text_query:
            return input_img
        result = text_pipe(input_img, prompt=text_query)
        # CLIPSeg returns a grayscale confidence mask; colorize matches red.
        mask = np.array(result["mask"])
        overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        overlay[mask > 100] = [255, 0, 0]  # threshold ~100/255 confidence
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)
    # Fallback for "Point Click" (offered in the Radio but not wired to any
    # point-coordinate input yet) and any future mode. Previously this fell
    # through and returned None, which raised an error in the Gradio UI.
    return input_img
# Build the UI: two-column layout — inputs (image, mode, optional text
# prompt) on the left, segmented result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# SAM + Text Segmentation")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil")
            # NOTE(review): "Point Click" is offered here but no point/click
            # input is wired to segment_logic — confirm intended behavior.
            mode_select = gr.Radio(["Automatic (Segment Everything)", "Text Prompt", "Point Click"],
                                    label="Select Mode",
                                    value="Automatic (Segment Everything)")
            # Hidden by default; revealed only when "Text Prompt" is chosen.
            text_box = gr.Textbox(label="Enter Object Name", visible=False)
        with gr.Column():
            img_out = gr.Image(type="pil")
    # Show/Hide textbox based on mode
    mode_select.change(lambda x: gr.update(visible=(x == "Text Prompt")), mode_select, text_box)
    btn = gr.Button("Run Segmentation")
    # Run segmentation on click; text_box is ignored by modes other than
    # "Text Prompt".
    btn.click(segment_logic, inputs=[img_in, mode_select, text_box], outputs=img_out)
if __name__ == "__main__":
    # queue() serializes requests (model inference is CPU-heavy);
    # 0.0.0.0:7860 is the conventional bind for a Hugging Face Space.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)