import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation, pipeline

# Load models (device=-1 runs on CPU; set device=0 to use the first GPU).
# Alternative SAM checkpoints:
# sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
# sam_pipe = pipeline("mask-generation", model="syscv-community/sam-hq-vit-huge", device=-1)
sam_pipe = pipeline("mask-generation", model="facebook/sam2-hiera-large", device=-1)

# CLIPSeg handles the text-prompted segmentation mode.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

POINTS_PER_BATCH = 32
POINTS_PER_SIDE = 32


def add_point(points_state, labels_state, point_type, evt: gr.SelectData):
    """Record a clicked [x, y] coordinate and its positive/negative label."""
    points_state.append(list(evt.index))
    labels_state.append(1 if point_type == "Add Object (Positive)" else 0)
    display_text = " | ".join(
        f"P{i + 1}: {p} ({'Pos' if l == 1 else 'Neg'})"
        for i, (p, l) in enumerate(zip(points_state, labels_state))
    )
    return points_state, labels_state, display_text


def handle_button(input_img, mode, text_query, points_state, labels_state):
    if input_img is None:
        return None

    # --- MODE: POINT CLICK ---
    if mode == "Point Click":
        if not points_state:
            gr.Warning("Please click on the image to add points first!")
            return input_img
        outputs = sam_pipe(input_img, input_points=[points_state], input_labels=[labels_state])
        w, h = input_img.size
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        for mask in outputs["masks"]:
            overlay[mask] = [0, 255, 0]  # Green
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    # --- MODE: AUTOMATIC ---
    elif mode == "Automatic (Segment Everything)":
        outputs = sam_pipe(input_img, points_per_batch=POINTS_PER_BATCH, points_per_side=POINTS_PER_SIDE)
        w, h = input_img.size
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        for mask in outputs["masks"]:
            overlay[mask] = np.random.randint(0, 256, (3,))  # Random color per mask
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)

    # --- MODE: TEXT PROMPT ---
    elif mode == "Text Prompt":
        if not text_query:
            return input_img
        prompts = [p.strip() for p in text_query.split(",")]
        inputs = processor(
            text=prompts,
            images=[input_img] * len(prompts),
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = model(**inputs)
        preds = torch.sigmoid(outputs.logits)
        if preds.dim() == 2:
            preds = preds.unsqueeze(0)  # Restore the batch dim for a single prompt
        w, h = input_img.size
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        for mask in preds:
            mask_np = (mask.numpy() > 0.1).astype(np.uint8)
            # CLIPSeg predicts at 352x352; upscale the binary mask to the input size.
            mask_resized = np.array(Image.fromarray(mask_np * 255).resize((w, h), resample=Image.NEAREST))
            overlay[mask_resized > 0] = [255, 0, 0]  # Red
        return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)


def reset_all():
    # Matches the 8 outputs wired to btn_clear.click below.
    return None, None, [], [], "Automatic (Segment Everything)", "", "No points selected", "Add Object (Positive)"


with gr.Blocks() as demo:
    gr.Markdown("# SAM Advanced: Points, Text, and Auto")
    points_state = gr.State([])
    labels_state = gr.State([])

    with gr.Row():
        img_in = gr.Image(type="pil", label="Input (Click to add points)", interactive=True)
        img_out = gr.Image(type="pil", label="Output")

    coord_bar = gr.Textbox(label="Selected Coordinates [x, y]", value="No points selected", interactive=False)

    with gr.Row():
        mode_select = gr.Radio(
            ["Automatic (Segment Everything)", "Text Prompt", "Point Click"],
            label="Mode",
            value="Automatic (Segment Everything)",
        )
        text_box = gr.Textbox(label="Text prompts (comma separated)", visible=False)
        # Positive/negative click selector, shown only in Point Click mode.
        point_type = gr.Radio(
            choices=["Add Object (Positive)", "Exclude Area (Negative)"],
            label="Click Type",
            value="Add Object (Positive)",
            visible=False,
        )

    # Toggle visibility of the mode-specific controls.
    mode_select.change(lambda x: gr.update(visible=(x == "Text Prompt")), mode_select, text_box)
    mode_select.change(lambda x: gr.update(visible=(x == "Point Click")), mode_select, point_type)

    with gr.Row():
        btn_run = gr.Button("Start Segmentation", variant="primary")
        btn_clear = gr.Button("Reset Everything")

    img_in.select(
        add_point,
        inputs=[points_state, labels_state, point_type],
        outputs=[points_state, labels_state, coord_bar],
    )
    btn_run.click(
        handle_button,
        inputs=[img_in, mode_select, text_box, points_state, labels_state],
        outputs=img_out,
    )
    # Restore all 8 components to their initial values.
    btn_clear.click(
        reset_all,
        outputs=[img_in, img_out, points_state, labels_state, mode_select, text_box, coord_bar, point_type],
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)