Spaces:
Build error
Build error
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
# ===== Load SAM2 model =====
# NOTE(review): `sam_model_registry` comes from the `segment_anything` package,
# while the checkpoint name suggests SAM 2 — confirm that this checkpoint is
# actually loadable through this registry with MODEL_TYPE below.
CHECKPOINT_PATH = "sam2_huge.pth"  # Place your checkpoint in the repo
MODEL_TYPE = "vit_h"  # Adjust based on checkpoint
# Use the GPU when one is visible and fall back to CPU otherwise; a hard-coded
# "cuda" makes `sam.to(DEVICE)` raise on CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model once at import time; every callback shares `predictor`.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(DEVICE)
predictor = SamPredictor(sam)
# ===== State variables =====
# Module-level prompt state shared by all callbacks below.
# NOTE(review): this state lives in the process, not in a per-user session —
# confirm that the app is only expected to serve one user at a time.
points_state = []   # [x, y] point prompts, in the order they were added
labels_state = []   # one label per entry in points_state
box_state = []      # [x1, y1, x2, y2], or empty when no box has been set
image_state = None  # image most recently passed to set_image, or None


def set_image(image):
    """Start a new segmentation session on ``image``.

    Clears all point and box prompts, remembers the image so that
    ``run_prediction`` can draw on it, and hands it to the predictor.

    Args:
        image: The image to segment, or ``None`` when the input was cleared.

    Returns:
        ``image`` unchanged, or ``None`` when no image was given.
    """
    global points_state, labels_state, box_state, image_state
    points_state, labels_state, box_state = [], [], []
    image_state = image
    if image is None:
        # Nothing to embed; leaving image_state as None makes run_prediction
        # a no-op until a new image arrives.
        return None
    predictor.set_image(image)
    return image


def add_point(x, y, label):
    """Append a point prompt and re-run the prediction.

    Args:
        x: Horizontal coordinate of the point.
        y: Vertical coordinate of the point.
        label: Label stored alongside the point (the UI passes 1 for a
            positive point and 0 for a negative one).

    Returns:
        The result of ``run_prediction()``.
    """
    points_state.append([x, y])
    labels_state.append(label)
    return run_prediction()


def set_box(x1, y1, x2, y2):
    """Replace the box prompt and re-run the prediction.

    Args:
        x1, y1, x2, y2: The four box coordinates, stored in this order.

    Returns:
        The result of ``run_prediction()``.
    """
    global box_state
    box_state = [x1, y1, x2, y2]
    return run_prediction()


def run_prediction():
    """Run the predictor on the current prompts and overlay the mask.

    Returns:
        A copy of the current image with the first predicted mask blended in,
        or ``None`` when no image has been set yet.
    """
    if image_state is None:
        # The predictor has nothing to work on before set_image is called.
        return None

    # Empty prompt lists are passed as None rather than as empty arrays.
    points_np = np.array(points_state) if points_state else None
    labels_np = np.array(labels_state) if labels_state else None
    box_np = np.array(box_state) if len(box_state) == 4 else None

    masks, _, _ = predictor.predict(
        point_coords=points_np,
        point_labels=labels_np,
        box=box_np[None, :] if box_np is not None else None,
        multimask_output=False,
    )
    # Draw on the image kept in image_state instead of reading it back from
    # the predictor. NOTE(review): SamPredictor does not appear to expose an
    # `original_image` attribute — confirm against the installed version.
    return overlay_mask(image_state, masks[0])
def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Return a copy of ``image`` with ``color`` blended into the masked pixels.

    Args:
        image: Array to draw on; it is copied, not modified.
        mask: Index selecting the pixels to tint (used as ``image[mask]``).
        color: Tint colour, one value per channel.
        alpha: Blend weight of ``color``; the image keeps ``1 - alpha``.

    Returns:
        The blended copy, with the tinted pixels cast to ``uint8``.
    """
    blended = image.copy()
    tint = np.array(color) * alpha
    # Weighted sum of the existing pixels and the tint, for masked pixels only.
    blended[mask] = (blended[mask] * (1 - alpha) + tint).astype(np.uint8)
    return blended
# ===== Gradio Interface =====
# Coordinates of the most recent click on the input image. Both stay None
# until the user has clicked, and are cleared again when the image changes.
latest_click_x = None
latest_click_y = None


def record_click(evt: gr.SelectData):
    """Remember where the user last clicked on the input image."""
    global latest_click_x, latest_click_y
    # NOTE(review): assumes evt.index is (x, y) for an Image select event —
    # confirm against the installed Gradio version.
    latest_click_x, latest_click_y = evt.index[0], evt.index[1]


def on_image_change(image):
    """Forget the previous click, then start a session on the new image."""
    global latest_click_x, latest_click_y
    # A click recorded on the previous image may not lie inside the new one.
    latest_click_x = latest_click_y = None
    return set_image(image)


def add_latest_click(label):
    """Add the last recorded click as a point prompt with ``label``.

    Returns the overlay produced by ``add_point``, or an empty update (output
    left as it is) when no click has been recorded yet.
    """
    if latest_click_x is None or latest_click_y is None:
        return gr.update()
    return add_point(latest_click_x, latest_click_y, label)


with gr.Blocks() as demo:
    img_input = gr.Image(label="Upload Image", type="numpy")
    img_output = gr.Image(label="Segmentation Output", type="numpy")

    img_input.change(on_image_change, inputs=img_input, outputs=img_output)
    # Clicking the input image only records the position; the buttons below
    # decide whether it becomes a positive or a negative point.
    img_input.select(record_click)

    gr.Button("Add Positive Point").click(
        lambda: add_latest_click(1), outputs=img_output
    )
    gr.Button("Add Negative Point").click(
        lambda: add_latest_click(0), outputs=img_output
    )

demo.launch()