# Trying SAM2 — Hugging Face Space by DmitryRu777 (commit 4191ebf)
import gradio as gr
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation, pipeline
import numpy as np
from PIL import Image
# Load models
# Segmentation backbone: Hugging Face "mask-generation" pipeline.
# Earlier checkpoints tried, kept for reference:
# sam_pipe = pipeline("mask-generation", model="facebook/sam-vit-base", device=-1)
# sam_pipe = pipeline("mask-generation", model="syscv-community/sam-hq-vit-huge", device=-1)
# device=-1 forces CPU inference.
sam_pipe = pipeline("mask-generation", model="facebook/sam2-hiera-large", device=-1)
# CLIPSeg powers the "Text Prompt" mode: text-conditioned image segmentation.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
# Sampling parameters for automatic "segment everything" mode:
# how many point prompts to run per forward pass, and the density of the
# point grid laid over the image.
POINTS_PER_BATCH = 32
POINTS_PER_SIDE = 32
def add_point(points_state, labels_state, point_type, evt: gr.SelectData):
    """Record one image click and refresh the coordinate summary text.

    Appends the clicked [x, y] position from *evt* to ``points_state`` and a
    1/0 label to ``labels_state`` (1 when *point_type* is
    "Add Object (Positive)", 0 otherwise), then rebuilds the human-readable
    list of every point collected so far.

    Returns the (mutated) points list, the (mutated) labels list, and the
    display string for the coordinate textbox.
    """
    points_state.append(list(evt.index))
    labels_state.append(1 if point_type == "Add Object (Positive)" else 0)

    summary_parts = []
    for idx, (coords, flag) in enumerate(zip(points_state, labels_state), start=1):
        tag = "Pos" if flag == 1 else "Neg"
        summary_parts.append(f"P{idx}: {coords} ({tag})")

    return points_state, labels_state, " | ".join(summary_parts)
def _blank_overlay(input_img):
    """Return a black H x W x 3 uint8 canvas matching *input_img*'s size."""
    w, h = input_img.size
    return np.zeros((h, w, 3), dtype=np.uint8)


def _blend(input_img, overlay):
    """Blend the uint8 mask *overlay* onto *input_img* at 50% opacity."""
    return Image.blend(input_img.convert("RGB"), Image.fromarray(overlay), alpha=0.5)


def _run_point_mode(input_img, points_state, labels_state):
    """SAM segmentation guided by user clicks; paints every mask green."""
    outputs = sam_pipe(input_img, input_points=[points_state], input_labels=[labels_state])
    overlay = _blank_overlay(input_img)
    for mask in outputs["masks"]:
        overlay[mask] = [0, 255, 0]  # Green
    return _blend(input_img, overlay)


def _run_auto_mode(input_img):
    """SAM "segment everything" over a point grid; random color per mask."""
    outputs = sam_pipe(input_img, points_per_batch=POINTS_PER_BATCH, points_per_side=POINTS_PER_SIDE)
    overlay = _blank_overlay(input_img)
    for mask in outputs["masks"]:
        overlay[mask] = np.random.randint(0, 255, (3,))
    return _blend(input_img, overlay)


def _run_text_mode(input_img, text_query):
    """CLIPSeg segmentation from comma-separated prompts; paints matches red."""
    prompts = [p.strip() for p in text_query.split(",")]
    inputs = processor(text=prompts, images=[input_img] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.sigmoid(outputs.logits)
    # With a single prompt CLIPSeg drops the batch dimension; restore it so
    # the loop below always iterates per-prompt masks.
    if len(prompts) == 1:
        preds = preds.unsqueeze(0)
    w, h = input_img.size
    overlay = _blank_overlay(input_img)
    for mask in preds:
        # 0.1 sigmoid threshold, then upscale the low-res logit map to the
        # original image size with nearest-neighbor to keep the mask binary.
        mask_np = (mask.numpy() > 0.1).astype(np.uint8)
        mask_resized = np.array(Image.fromarray(mask_np * 255).resize((w, h), resample=Image.NEAREST))
        overlay[mask_resized > 0] = [255, 0, 0]  # Red
    return _blend(input_img, overlay)


def handle_button(input_img, mode, text_query, points_state, labels_state):
    """Run the segmentation mode selected in the UI and return the blend.

    Args:
        input_img: PIL input image (or None when nothing is uploaded).
        mode: one of the mode-radio strings.
        text_query: comma-separated prompts for "Text Prompt" mode.
        points_state / labels_state: accumulated clicks for "Point Click".

    Returns the blended PIL result, the unmodified input when prerequisites
    are missing, or None for no image / unknown mode.
    """
    if not input_img:
        return None
    if mode == "Point Click":
        if not points_state:
            gr.Warning("Please click on the image to add points first!")
            return input_img
        return _run_point_mode(input_img, points_state, labels_state)
    elif mode == "Automatic (Segment Everything)":
        return _run_auto_mode(input_img)
    elif mode == "Text Prompt":
        if not text_query:
            return input_img
        return _run_text_mode(input_img, text_query)
def reset_all():
    """Restore every UI component to its initial state.

    Returns one value per entry in the 8-item outputs list wired to
    btn_clear.click, in order: input image, output image, points state,
    labels state, mode radio, text box, coordinate display, click-type radio.
    (The original comment claimed 7 outputs; there are 8.)
    """
    return (None, None, [], [],
            "Automatic (Segment Everything)", "",
            "No points selected", "Add Object (Positive)")
# UI layout and event wiring.
with gr.Blocks() as demo:
    gr.Markdown("# SAM Advanced: Points, Text, and Auto")
    # Per-session state: clicked [x, y] points and their matching 1/0 labels.
    points_state = gr.State([])
    labels_state = gr.State([])
    with gr.Row():
        img_in = gr.Image(type="pil", label="Input (Click to add points)", interactive=True)
        img_out = gr.Image(type="pil", label="Output")
    coord_bar = gr.Textbox(label="Selected Coordinates [x, y]", value="No points selected", interactive=False)
    with gr.Row():
        mode_select = gr.Radio(
            ["Automatic (Segment Everything)", "Text Prompt", "Point Click"],
            label="Mode", value="Automatic (Segment Everything)"
        )
    # Mode-specific controls, hidden until their mode is selected below.
    text_box = gr.Textbox(label="Labels (comma separated)", visible=False)
    # Positive/negative click selector for "Point Click" mode.
    point_type = gr.Radio(
        choices=["Add Object (Positive)", "Exclude Area (Negative)"],
        label="Click Type",
        value="Add Object (Positive)",
        visible=False
    )
    # Toggle visibility of the mode-specific controls when the mode changes.
    mode_select.change(lambda x: gr.update(visible=(x == "Text Prompt")), mode_select, text_box)
    mode_select.change(lambda x: gr.update(visible=(x == "Point Click")), mode_select, point_type)
    with gr.Row():
        btn_run = gr.Button("Start Segmentation", variant="primary")
        btn_clear = gr.Button("Reset Everything")
    # Clicking the input image records a point (see add_point).
    img_in.select(add_point, inputs=[points_state, labels_state, point_type], outputs=[points_state, labels_state, coord_bar])
    btn_run.click(handle_button, inputs=[img_in, mode_select, text_box, points_state, labels_state], outputs=img_out)
    # reset_all returns one value per component in this 8-item outputs list.
    btn_clear.click(reset_all, outputs=[img_in, img_out, points_state, labels_state, mode_select, text_box, coord_bar, point_type])
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)