"""Gradio demo: SAM automatic segmentation + CLIPSeg text-prompted segmentation."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import pipeline

# Load both models once at startup (base checkpoints to keep it fast/stable on CPU).
sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
text_pipe = pipeline("image-segmentation", model="CIDAS/clipseg-rd64-refined")


def segment_logic(input_img, mode, text_query):
    """Run segmentation on *input_img* according to *mode*.

    Args:
        input_img: PIL image from the Gradio input (None if nothing was uploaded).
        mode: one of the Radio choices ("Automatic (Segment Everything)",
            "Text Prompt", "Point Click").
        text_query: object name used by the "Text Prompt" mode.

    Returns:
        A PIL image with the segmentation overlay blended in, or the
        unmodified input when there is nothing to segment.

    Raises:
        gr.Error: if no image was uploaded.
    """
    # Guard: the user clicked "Run" without uploading an image. The original
    # crashed with AttributeError on input_img.size here.
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    if mode == "Automatic (Segment Everything)":
        # SAM "segment everything": paint one random color per proposed mask.
        outputs = sam_pipe(input_img, points_per_side=10)
        masks = outputs["masks"]
        height, width = input_img.size[1], input_img.size[0]
        overlay = np.zeros((height, width, 3), dtype=np.uint8)
        for mask in masks:
            overlay[mask] = np.random.randint(0, 255, (3,))
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    if mode == "Text Prompt":
        # CLIPSeg understands open-vocabulary queries ("dog", "shirt", ...).
        if not text_query:
            return input_img
        result = text_pipe(input_img, prompt=text_query)
        # NOTE(review): HF image-segmentation pipelines normally return a *list*
        # of {"mask", "label", "score"} dicts; the original indexed a dict
        # directly (result["mask"]), which raises TypeError on a list return.
        # Unwrap defensively — confirm against the installed transformers version.
        if isinstance(result, list):
            result = result[0]
        mask = np.array(result["mask"])
        overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        overlay[mask > 100] = [255, 0, 0]  # Red for the text match
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    # "Point Click" is offered in the Radio but has no backend implementation;
    # the original fell through and returned None (blank output panel). Return
    # the input unchanged so the UI stays coherent.
    return input_img


# Build the UI
with gr.Blocks() as demo:
    gr.Markdown("# SAM + Text Segmentation")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil")
            mode_select = gr.Radio(
                ["Automatic (Segment Everything)", "Text Prompt", "Point Click"],
                label="Select Mode",
                value="Automatic (Segment Everything)",
            )
            text_box = gr.Textbox(label="Enter Object Name", visible=False)
        with gr.Column():
            img_out = gr.Image(type="pil")

    # Show/Hide textbox based on mode
    mode_select.change(
        lambda x: gr.update(visible=(x == "Text Prompt")), mode_select, text_box
    )

    btn = gr.Button("Run Segmentation")
    btn.click(segment_logic, inputs=[img_in, mode_select, text_box], outputs=img_out)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)