# SAM + Text Segmentation — Gradio demo (Hugging Face Space, CPU).
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import pipeline
# Load both models once at startup (base checkpoints keep this fast/stable
# on CPU).  device=-1 pins BOTH pipelines to CPU so they behave the same
# regardless of available hardware — previously only sam_pipe was pinned.
sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
text_pipe = pipeline("image-segmentation", model="CIDAS/clipseg-rd64-refined", device=-1)
def segment_logic(input_img, mode, text_query):
    """Segment *input_img* according to the selected UI *mode*.

    Parameters
    ----------
    input_img : PIL.Image.Image | None
        Image from the Gradio input component (``type="pil"``); ``None``
        when the user presses Run before uploading.
    mode : str
        Radio-button choice.  Only "Automatic (Segment Everything)" and
        "Text Prompt" are implemented; any other mode is a no-op.
    text_query : str
        Free-text object name for the CLIPSeg branch; ignored otherwise.

    Returns
    -------
    PIL.Image.Image | None
        The input blended 50/50 with a colored mask overlay, or the input
        unchanged (or ``None``) when nothing can be segmented.
    """
    # Guard: "Run" can be clicked before any image is uploaded.
    if input_img is None:
        return None

    if mode == "Automatic (Segment Everything)":
        # SAM "segment everything": paint each proposed mask a random color.
        outputs = sam_pipe(input_img, points_per_side=10)
        overlay = np.zeros((input_img.size[1], input_img.size[0], 3), dtype=np.uint8)
        for mask in outputs["masks"]:
            color = np.random.randint(0, 255, (3,))
            # Masks are expected to be (H, W) boolean arrays — coerce defensively.
            overlay[np.asarray(mask, dtype=bool)] = color
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    if mode == "Text Prompt":
        # CLIPSeg understands free-text queries ("dog", "shirt", ...).
        if not text_query:
            return input_img
        result = text_pipe(input_img, prompt=text_query)
        # The image-segmentation pipeline commonly returns a LIST of
        # {"mask": ...} dicts; the original indexed a dict directly and
        # would raise TypeError in that case — handle both shapes.
        if isinstance(result, list):
            result = result[0]
        # CLIPSeg yields a grayscale confidence mask; colorize strong matches red.
        mask = np.array(result["mask"])
        overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        overlay[mask > 100] = [255, 0, 0]  # red marks the text match
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    # Unimplemented modes (e.g. "Point Click") fall through here: return the
    # image untouched instead of implicitly returning None, which Gradio
    # would surface as an error.
    return input_img
# ---------------------------------------------------------------------------
# Build the Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# SAM + Text Segmentation")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil")
            # NOTE(review): a third choice, "Point Click", used to be listed
            # here, but segment_logic has no branch for it (selecting it made
            # the handler return None and Gradio reported an error).  Only the
            # implemented modes are offered.
            mode_select = gr.Radio(
                ["Automatic (Segment Everything)", "Text Prompt"],
                label="Select Mode",
                value="Automatic (Segment Everything)",
            )
            text_box = gr.Textbox(label="Enter Object Name", visible=False)
        with gr.Column():
            img_out = gr.Image(type="pil")

    # The text box only matters for the CLIPSeg path, so reveal it only
    # while "Text Prompt" is the selected mode.
    mode_select.change(
        lambda m: gr.update(visible=(m == "Text Prompt")),
        inputs=mode_select,
        outputs=text_box,
    )

    btn = gr.Button("Run Segmentation")
    btn.click(segment_logic, inputs=[img_in, mode_select, text_box], outputs=img_out)

if __name__ == "__main__":
    # queue() serializes requests so heavy CPU inference jobs don't overlap;
    # 0.0.0.0 makes the server reachable from outside the container.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)