Spaces: Running on Zero
| #!/usr/bin/env python3 | |
| """Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support.""" | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import sys | |
| import threading | |
| from pathlib import Path | |
| from typing import List, Optional, Sequence, Tuple | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| try: | |
| import spaces # type: ignore | |
| except ImportError: # pragma: no cover - optional dependency on Spaces runtime | |
| spaces = None | |
# Make the vendored sam2 checkout importable. The sys.path insert MUST happen
# before the `from sam2...` imports below, hence the noqa: E402 markers.
REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT / "sam2"
if SAM2_REPO.exists():
    sys.path.insert(0, str(SAM2_REPO))
from sam2.build_sam import build_sam2  # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402
# Module-level configuration, overridable via environment variables.
logging.basicConfig(level=os.getenv("UNSAMV2_LOGLEVEL", "INFO"))
LOGGER = logging.getLogger("unsamv2-gradio")

# Model config is resolved relative to the sam2 package's config search path.
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
CKPT_PATH = Path(
    os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
).resolve()
# Fail fast at import time: a missing checkpoint would otherwise surface as a
# confusing error on the first segmentation request.
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
    )

# Slider bounds for the granularity control.
GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
# ZeroGPU wrapping (Hugging Face Spaces) can be disabled via env.
ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))

# UI radio-label -> SAM point label (1 = foreground, 0 = background).
POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
# Marker colors, expressed in BGR because drawing happens via OpenCV.
POINT_COLORS_BGR = {
    1: (72, 201, 127),  # green-ish for positives
    0: (64, 76, 225),  # red-ish for negatives
}
MASK_COLOR_BGR = (0, 196, 255)
OUTLINE_COLOR_BGR = (0, 165, 255)
class ModelManager:
    """Keeps SAM2 models on each device and spawns lightweight predictors."""

    def __init__(self) -> None:
        # One cached model per device key; the lock ensures concurrent Gradio
        # workers never build the same model twice.
        self._models: dict[str, torch.nn.Module] = {}
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        """Load UnSAMv2 weights in eval mode onto *device*."""
        LOGGER.info("Loading UnSAMv2 weights onto %s", device)
        return build_sam2(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
            mode="eval",
        )

    def get_model(self, device: torch.device) -> torch.nn.Module:
        """Return the cached model for *device*, building it on first use."""
        # CUDA devices are keyed per-index so multi-GPU hosts get one model each.
        key = device.type if device.type != "cuda" else f"{device.type}:{device.index}"
        with self._lock:
            try:
                return self._models[key]
            except KeyError:
                model = self._build(device)
                self._models[key] = model
                return model

    def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
        """Create a fresh predictor wrapping the shared per-device model."""
        return SAM2ImagePredictor(self.get_model(device))


MODEL_MANAGER = ModelManager()
def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Normalize an arbitrary input image to uint8, dropping any alpha channel.

    Floats in [0, 1] are treated as normalized and rescaled to [0, 255];
    other dtypes are clipped and cast. Returns None for None input.
    """
    if image is None:
        return None
    img = image
    # Only slice away alpha when a channel axis is actually present: the
    # previous unconditional `image[..., :3]` corrupted 2-D grayscale arrays
    # by slicing *columns* instead of channels.
    if img.ndim == 3 and img.shape[-1] > 3:
        img = img[..., :3]
    # np.issubdtype covers float16/32/64 uniformly (original missed float16).
    if np.issubdtype(img.dtype, np.floating):
        # Guard img.size so an empty array doesn't crash on .max().
        if img.size and img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        img = img.clip(0, 255).astype(np.uint8)
    return img
def choose_device() -> torch.device:
    """Resolve the compute device, honouring the UNSAMV2_DEVICE override.

    Accepts "cpu", "gpu", "cuda"/"cuda:N", or "auto" (default). A GPU request
    without available CUDA logs a warning and falls back to CPU.
    """
    pref = os.getenv("UNSAMV2_DEVICE", "auto").lower()
    if pref == "cpu":
        return torch.device("cpu")
    if pref == "gpu" or pref.startswith("cuda"):
        if torch.cuda.is_available():
            return torch.device(pref if pref.startswith("cuda") else "cuda")
        LOGGER.warning("CUDA requested but not available; defaulting to CPU")
        return torch.device("cpu")
    # "auto" (or anything unrecognized): prefer CUDA when present.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
    """Wrap a scalar granularity into a (1, 1, 1, 1) float32 tensor on *device*."""
    return torch.full((1, 1, 1, 1), float(value), dtype=torch.float32, device=device)
def draw_overlay(
    image: np.ndarray,
    mask: Optional[np.ndarray],
    points: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float = 0.55,
) -> np.ndarray:
    """Composite the mask fill/outline and click markers onto *image*.

    Takes and returns RGB; all drawing happens in BGR for OpenCV.
    """
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if mask is not None:
        region = mask.astype(bool)
        # Alpha-blend the fill colour over masked pixels only.
        fill = np.zeros_like(frame, dtype=np.uint8)
        fill[region] = MASK_COLOR_BGR
        blended = (frame * (1.0 - alpha) + fill * alpha).astype(np.uint8)
        frame = np.where(region[..., None], blended, frame)
        # Trace the external boundary for a crisp outline.
        contours, _ = cv2.findContours(
            region.astype(np.uint8),
            mode=cv2.RETR_EXTERNAL,
            method=cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(frame, contours, -1, OUTLINE_COLOR_BGR, 2)
    # Filled dot (colour-coded by label) with a white ring for contrast.
    for (px, py), tag in zip(points, labels):
        dot_color = POINT_COLORS_BGR.get(tag, (255, 255, 255))
        center = (int(round(px)), int(round(py)))
        cv2.circle(frame, center, 7, dot_color, thickness=-1, lineType=cv2.LINE_AA)
        cv2.circle(frame, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def points_table(points: Sequence[Sequence[float]], labels: Sequence[int]) -> List[List[str]]:
    """Render the click history as dataframe rows: [index, x, y, 'fg'/'bg']."""
    return [
        [row_no, round(float(px), 1), round(float(py), 1), "fg" if tag == 1 else "bg"]
        for row_no, ((px, py), tag) in enumerate(zip(points, labels), start=1)
    ]
def handle_image_upload(image: Optional[np.ndarray]):
    """Initialize (or reset) all UI state when an image is uploaded/cleared.

    Returns (overlay, mask, image_state, points, labels, table, status).
    """
    img = ensure_uint8(image)
    if img is None:
        # Cleared input: blank everything out.
        return (
            None,
            None,
            None,
            [],
            [],
            [],
            "Upload an image to start adding clicks.",
        )
    # Fresh image: show it in the preview and store it as canonical state.
    return (
        img,
        None,
        img,
        [],
        [],
        [],
        "Image loaded. Choose click type, then tap on the image.",
    )
def handle_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    image: Optional[np.ndarray],
    evt: gr.SelectData,
):
    """Record one image click as a prompt point and refresh the preview.

    Returns (overlay, mask, points, labels, table, status).
    """

    def _unchanged(message: str):
        # Leave the preview as-is and just surface a status message.
        return gr.update(), None, pts, lbls, points_table(pts, lbls), message

    if image is None:
        return _unchanged("Upload an image first.")
    coord = evt.index  # (x, y) in image pixel coordinates
    if coord is None:
        return _unchanged("Couldn't read click position.")

    x, y = coord
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    # Build new lists rather than mutating — Gradio state stays immutable.
    new_pts = [*pts, [float(x), float(y)]]
    new_lbls = [*lbls, label]
    overlay = draw_overlay(image, None, new_pts, new_lbls)
    kind = "positive" if label == 1 else "negative"
    status = f"Added {kind} click at ({int(x)}, {int(y)})."
    return overlay, None, new_pts, new_lbls, points_table(new_pts, new_lbls), status
def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
    """Drop the most recent click and redraw the preview.

    Returns (overlay, mask, points, labels, table, status).
    """
    if not pts:
        # Nothing to remove: keep the preview untouched.
        return (
            gr.update(),
            None,
            pts,
            lbls,
            points_table(pts, lbls),
            "No clicks to undo.",
        )
    remaining_pts, remaining_lbls = pts[:-1], lbls[:-1]
    overlay = None if image is None else draw_overlay(image, None, remaining_pts, remaining_lbls)
    return (
        overlay,
        None,
        remaining_pts,
        remaining_lbls,
        points_table(remaining_pts, remaining_lbls),
        "Removed the last click.",
    )
def clear_clicks(image: Optional[np.ndarray]):
    """Reset all click state, keeping the raw image (or None) as the preview."""
    return image, None, [], [], [], "Cleared all clicks."
def _run_segmentation(
    image: Optional[np.ndarray],
    pts: List[Sequence[float]],
    lbls: List[int],
    granularity: float,
):
    """Run UnSAMv2 point-prompted segmentation.

    Returns (RGB overlay, binary mask visualization, status string); the
    first two are None when input is missing.
    """
    img = ensure_uint8(image)
    if img is None:
        return None, None, "Upload an image to segment."
    if not pts:
        return draw_overlay(img, None, [], []), None, "Add at least one click before running segmentation."
    device = choose_device()
    predictor = MODEL_MANAGER.make_predictor(device)
    predictor.set_image(img)
    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    # Granularity is placed on the predictor's own device, which may differ
    # from `device` after ZeroGPU migration.
    gran_tensor = build_granularity_tensor(granularity, predictor.device)
    # NOTE(review): both a scalar `gra` and a 4-D `granularity` tensor are
    # passed; presumably the UnSAMv2 fork's predict() consumes one or the
    # other depending on version — confirm against the sam2 fork's API.
    masks, scores, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
        gra=float(granularity),
        granularity=gran_tensor,
    )
    # Keep only the highest-scoring of the multimask proposals.
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(bool)
    overlay = draw_overlay(img, best_mask, pts, lbls)
    mask_vis = (best_mask.astype(np.uint8) * 255)
    status = f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | granularity={granularity:.2f}"
    return overlay, mask_vis, status
# On a Spaces ZeroGPU runtime, wrap segmentation so a GPU is allocated for
# the call's duration; everywhere else expose the plain function.
segment_fn = (
    spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
    if spaces is not None and ZERO_GPU_ENABLED
    else _run_segmentation
)
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire all event handlers.

    Returns the (queued) Blocks app; call .launch() on the result.
    """
    with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """## UnSAMv2 路 Interactive Granularity Control
Upload an image, add positive/negative clicks, tune granularity, and run segmentation.
ZeroGPU automatically pulls a GPU when available; otherwise the app falls back to CPU."""
        )
        # Server-side state: the canonical uint8 image plus click prompts.
        image_state = gr.State()
        points_state = gr.State([])
        labels_state = gr.State([])
        with gr.Row():
            # Upload target; clicks on this image add prompt points.
            image_input = gr.Image(
                label="1 路 Upload image & click to add prompts",
                type="numpy",
                height=480,
            )
            overlay_output = gr.Image(
                label="Segmentation preview",
                interactive=False,
                height=480,
            )
            mask_output = gr.Image(
                label="Binary mask",
                interactive=False,
                height=480,
            )
        with gr.Row():
            point_mode = gr.Radio(
                choices=list(POINT_MODE_TO_LABEL.keys()),
                value="Foreground (+)",
                label="Click type",
            )
            granularity_slider = gr.Slider(
                minimum=GRANULARITY_MIN,
                maximum=GRANULARITY_MAX,
                value=0.2,
                step=0.05,
                label="Granularity",
                info="Lower = finer details, Higher = coarser regions",
            )
            segment_button = gr.Button("3 路 Segment", variant="primary")
        with gr.Row():
            undo_button = gr.Button("Undo last click")
            clear_button = gr.Button("Clear clicks")
        points_table_output = gr.Dataframe(
            headers=["#", "x", "y", "type"],
            datatype=["number", "number", "number", "str"],
            interactive=False,
            label="2 路 Click history",
        )
        status_markdown = gr.Markdown(" Ready.")
        # Upload and clear share the same reset handler.
        image_input.upload(
            handle_image_upload,
            inputs=[image_input],
            outputs=[
                overlay_output,
                mask_output,
                image_state,
                points_state,
                labels_state,
                points_table_output,
                status_markdown,
            ],
        )
        image_input.clear(
            handle_image_upload,
            inputs=[image_input],
            outputs=[
                overlay_output,
                mask_output,
                image_state,
                points_state,
                labels_state,
                points_table_output,
                status_markdown,
            ],
        )
        # Clicking the uploaded image appends a prompt point.
        image_input.select(
            handle_click,
            inputs=[
                point_mode,
                points_state,
                labels_state,
                image_state,
            ],
            outputs=[
                overlay_output,
                mask_output,
                points_state,
                labels_state,
                points_table_output,
                status_markdown,
            ],
        )
        undo_button.click(
            undo_last_click,
            inputs=[image_state, points_state, labels_state],
            outputs=[
                overlay_output,
                mask_output,
                points_state,
                labels_state,
                points_table_output,
                status_markdown,
            ],
        )
        clear_button.click(
            clear_clicks,
            inputs=[image_state],
            outputs=[
                overlay_output,
                mask_output,
                points_state,
                labels_state,
                points_table_output,
                status_markdown,
            ],
        )
        # segment_fn is the (possibly ZeroGPU-wrapped) _run_segmentation.
        segment_button.click(
            segment_fn,
            inputs=[image_state, points_state, labels_state, granularity_slider],
            outputs=[overlay_output, mask_output, status_markdown],
        )
    # Bound the request queue so a flood of clicks can't pile up work.
    demo.queue(max_size=8)
    return demo
# Built at import time so `gradio app.py` / Spaces can find the app object.
demo = build_demo()

if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=True)