# Hugging Face Space: SAM segmentation demo (points / text / automatic modes)
import gradio as gr
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation, pipeline
import numpy as np
from PIL import Image

# --- Model loading (runs once at process startup) ---
# SAM 2 handles both point-prompted and automatic ("segment everything") masks.
# device=-1 pins inference to CPU; change to a CUDA device index to use a GPU.
sam_pipe = pipeline("mask-generation", model="facebook/sam2-hiera-large", device=-1)

# CLIPSeg handles free-text ("Text Prompt") segmentation.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Sampling density for SAM's automatic mask generation.
POINTS_PER_BATCH = 32
POINTS_PER_SIDE = 32
def add_point(points_state, labels_state, point_type, evt: gr.SelectData):
    """Record one image click and refresh the coordinate summary.

    Appends the clicked [x, y] position from *evt* to the running point
    list, tags it 1 (positive) or 0 (negative) based on the selected click
    type, and renders a "P1: [x, y] (Pos) | ..." summary for the textbox.

    Returns the updated points list, labels list, and summary string.
    """
    points_state.append(list(evt.index))
    labels_state.append(1 if point_type == "Add Object (Positive)" else 0)

    parts = []
    for idx, (pt, lab) in enumerate(zip(points_state, labels_state), start=1):
        tag = "Pos" if lab == 1 else "Neg"
        parts.append(f"P{idx}: {pt} ({tag})")
    return points_state, labels_state, " | ".join(parts)
def handle_button(input_img, mode, text_query, points_state, labels_state):
    """Dispatch segmentation according to the selected mode.

    Args:
        input_img: PIL image to segment, or None when nothing is uploaded.
        mode: "Point Click", "Automatic (Segment Everything)", or "Text Prompt".
        text_query: comma-separated labels, used only in Text Prompt mode.
        points_state: accumulated [x, y] click coordinates.
        labels_state: matching 1 (positive) / 0 (negative) labels.

    Returns:
        A PIL image with the segmentation overlay blended in; the unmodified
        input when there is nothing to segment; None when no image was given.
    """
    if input_img is None:  # PIL images define no __bool__; test for None explicitly
        return None

    if mode == "Point Click":
        if not points_state:
            gr.Warning("Please click on the image to add points first!")
            return input_img
        return _segment_points(input_img, points_state, labels_state)
    elif mode == "Automatic (Segment Everything)":
        return _segment_everything(input_img)
    elif mode == "Text Prompt":
        if not text_query:
            return input_img
        return _segment_by_text(input_img, text_query)


def _blend(input_img, overlay):
    # 50/50 blend of the original (forced to RGB) with the color overlay.
    return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)


def _segment_points(input_img, points_state, labels_state):
    # SAM inference guided by user clicks; every returned mask drawn in green.
    outputs = sam_pipe(input_img, input_points=[points_state], input_labels=[labels_state])
    w, h = input_img.size
    overlay = np.zeros((h, w, 3), dtype=np.uint8)
    for mask in outputs["masks"]:
        overlay[mask] = [0, 255, 0]  # Green
    return _blend(input_img, overlay)


def _segment_everything(input_img):
    # Dense automatic mask generation; each mask gets a random color.
    outputs = sam_pipe(input_img, points_per_batch=POINTS_PER_BATCH, points_per_side=POINTS_PER_SIDE)
    w, h = input_img.size
    overlay = np.zeros((h, w, 3), dtype=np.uint8)
    for mask in outputs["masks"]:
        # high=256 so full-intensity (255) channels are reachable
        overlay[mask] = np.random.randint(0, 256, (3,))
    return _blend(input_img, overlay)


def _segment_by_text(input_img, text_query):
    # CLIPSeg: one forward pass over all comma-separated prompts, drawn in red.
    prompts = [p.strip() for p in text_query.split(",")]
    inputs = processor(text=prompts, images=[input_img] * len(prompts), padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.sigmoid(outputs.logits)
    if len(prompts) == 1:
        preds = preds.unsqueeze(0)  # CLIPSeg drops the batch dim for a single prompt
    w, h = input_img.size
    overlay = np.zeros((h, w, 3), dtype=np.uint8)
    for mask in preds:
        mask_np = (mask.numpy() > 0.1).astype(np.uint8)
        # Model output is low-resolution; upscale the binary mask to image size.
        mask_resized = np.array(Image.fromarray(mask_np * 255).resize((w, h), resample=Image.NEAREST))
        overlay[mask_resized > 0] = [255, 0, 0]  # Red
    return _blend(input_img, overlay)
def reset_all():
    """Restore every UI component to its initial state.

    Returns one value per output wired to ``btn_clear.click`` — 8 in total:
    input image, output image, points state, labels state, mode radio,
    text box, coordinate bar, and click-type radio.
    """
    return None, None, [], [], "Automatic (Segment Everything)", "", "No points selected", "Add Object (Positive)"
with gr.Blocks() as demo:
    gr.Markdown("# SAM Advanced: Points, Text, and Auto")

    # Per-session click history: coordinates plus matching 1/0 labels.
    points_state = gr.State([])
    labels_state = gr.State([])

    with gr.Row():
        img_in = gr.Image(type="pil", label="Input (Click to add points)", interactive=True)
        img_out = gr.Image(type="pil", label="Output")

    coord_bar = gr.Textbox(label="Selected Coordinates [x, y]", value="No points selected", interactive=False)

    with gr.Row():
        mode_select = gr.Radio(
            ["Automatic (Segment Everything)", "Text Prompt", "Point Click"],
            label="Mode", value="Automatic (Segment Everything)"
        )
        # Mode-specific controls, hidden until their mode is selected below.
        text_box = gr.Textbox(label="Labels (comma separated)", visible=False)
        point_type = gr.Radio(
            choices=["Add Object (Positive)", "Exclude Area (Negative)"],
            label="Click Type",
            value="Add Object (Positive)",
            visible=False
        )

    # Reveal only the controls relevant to the chosen mode.
    mode_select.change(lambda m: gr.update(visible=(m == "Text Prompt")), mode_select, text_box)
    mode_select.change(lambda m: gr.update(visible=(m == "Point Click")), mode_select, point_type)

    with gr.Row():
        btn_run = gr.Button("Start Segmentation", variant="primary")
        btn_clear = gr.Button("Reset Everything")

    # Wiring: image clicks accumulate points, run dispatches by mode,
    # reset returns all 8 components to their initial values.
    img_in.select(add_point, inputs=[points_state, labels_state, point_type], outputs=[points_state, labels_state, coord_bar])
    btn_run.click(handle_button, inputs=[img_in, mode_select, text_box, points_state, labels_state], outputs=img_out)
    btn_clear.click(reset_all, outputs=[img_in, img_out, points_state, labels_state, mode_select, text_box, coord_bar, point_type])

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)