#!/usr/bin/env python3
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""

from __future__ import annotations

import logging
import os
import sys
import threading
from pathlib import Path
from typing import List, Optional, Sequence

import cv2
import gradio as gr
import numpy as np
import torch

try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - optional dependency on Spaces runtime
    spaces = None

REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT / "sam2"
if SAM2_REPO.exists():
    # Put the local sam2 checkout first on sys.path so it wins over any installed copy.
    sys.path.insert(0, str(SAM2_REPO))

from sam2.build_sam import build_sam2  # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: E402

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("unsamv2-gradio")


def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable *name* is "1", "true" or "yes" (case-insensitive)."""
    return os.getenv(name, default).lower() in {"1", "true", "yes"}


# Whether to run the second, mask-conditioned pass in apply_m2m_refinement (default: on).
USE_M2M_REFINEMENT = _env_flag("UNSAMV2_USE_M2M", "1")
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
CKPT_PATH = Path(
    os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
).resolve()
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
    )

GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
# Initial slider value; clamped so overriding MIN/MAX via env never yields an out-of-range default.
GRANULARITY_DEFAULT = min(max(0.2, GRANULARITY_MIN), GRANULARITY_MAX)
ZERO_GPU_ENABLED = _env_flag("UNSAMV2_ENABLE_ZEROGPU", "1")
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
# Passed to demo.launch(share=...); defaults to the previous hard-coded True.
SHARE_LINK = _env_flag("UNSAMV2_SHARE", "1")

POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
POINT_COLORS_BGR = {
    1: (72, 201, 127),  # green-ish for positives
    0: (64, 76, 225),  # red-ish for negatives
}
MASK_COLOR_BGR = (0, 196, 255)
OUTLINE_COLOR_BGR = (0, 165, 255)
DEFAULT_IMAGE_PATH = REPO_ROOT / "demo" / "bird.webp"


def _load_default_image() -> Optional[np.ndarray]:
    """Load the bundled demo image as an RGB array, or return None (with a warning) if unavailable."""
    if not DEFAULT_IMAGE_PATH.exists():
        LOGGER.warning("Default image missing at %s", DEFAULT_IMAGE_PATH)
        return None
    img_bgr = cv2.imread(str(DEFAULT_IMAGE_PATH), cv2.IMREAD_COLOR)
    if img_bgr is None:
        LOGGER.warning("Could not read default image at %s", DEFAULT_IMAGE_PATH)
        return None
    # OpenCV decodes to BGR; the rest of the app works in RGB.
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)


DEFAULT_IMAGE = _load_default_image()


class ModelManager:
    """Keeps SAM2 models on each device and spawns lightweight predictors."""

    def __init__(self) -> None:
        # Built models keyed by device string ("cpu", "cuda:<index>", ...).
        self._models: dict[str, torch.nn.Module] = {}
        # Serialises lazy construction so concurrent callers never build the same model twice.
        self._lock = threading.Lock()

    def _build(self, device: torch.device) -> torch.nn.Module:
        """Build the UnSAMv2 model from CONFIG_PATH / CKPT_PATH on *device* in eval mode."""
        LOGGER.info("Loading UnSAMv2 weights onto %s", device)
        return build_sam2(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
            mode="eval",
        )

    def get_model(self, device: torch.device) -> torch.nn.Module:
        """Return the cached model for *device*, building it on first use."""
        key = f"{device.type}:{device.index}" if device.type == "cuda" else device.type
        with self._lock:
            if key not in self._models:
                self._models[key] = self._build(device)
            return self._models[key]

    def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
        """Wrap the shared model for *device* in a fresh, per-request SAM2ImagePredictor."""
        # NOTE(review): non-default mask_threshold; its effect depends on the sam2 fork — confirm.
        return SAM2ImagePredictor(self.get_model(device), mask_threshold=-1.0)


MODEL_MANAGER = ModelManager()


def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Normalise an input image to a three-channel ``uint8`` array.

    * 2-D (grayscale) images and 1- or 2-channel images (gray, gray+alpha) have
      their first channel replicated to three channels.
    * Channels beyond the third (e.g. alpha) are dropped.
    * Floating-point images whose maximum is <= 1.0 are scaled by 255; other
      floats and non-``uint8`` integer types are clipped to [0, 255].

    Returns None when *image* is None.
    """
    if image is None:
        return None
    img = np.asarray(image)
    if img.ndim == 2:
        # Grayscale: ``[..., :3]`` would otherwise truncate the *width* axis.
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] in (1, 2):
        # Single intensity channel (optionally + alpha): replicate it to RGB.
        img = np.repeat(img[..., :1], 3, axis=-1)
    else:
        img = img[..., :3]  # drop alpha if present
    if np.issubdtype(img.dtype, np.floating):
        if img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        img = img.clip(0, 255).astype(np.uint8)
    return img


def choose_device() -> torch.device:
    """Pick the inference device from UNSAMV2_DEVICE ("auto", "cpu", "gpu", "cuda[:N]").

    Falls back to CPU (with a warning) when CUDA is requested but unavailable;
    "auto" (the default) uses CUDA when available.
    """
    preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
    if preference == "cpu":
        return torch.device("cpu")
    if preference.startswith("cuda") or preference == "gpu":
        if torch.cuda.is_available():
            return torch.device(preference if preference.startswith("cuda") else "cuda")
        LOGGER.warning("CUDA requested but not available; defaulting to CPU")
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
    """Return *value* as a float32 tensor of shape (1, 1, 1, 1) on *device*."""
    return torch.tensor([[[[value]]]], dtype=torch.float32, device=device)


def apply_m2m_refinement(
    predictor,
    point_coords,
    point_labels,
    granularity,
    logits,
    best_mask_idx,
    use_m2m: bool = True,
):
    """Optionally run a second M2M pass using the best mask's logits.

    Re-runs ``predictor.predict`` in single-mask mode with the selected
    low-resolution logits as ``mask_input``.

    Returns:
        ``(refined_mask, refined_score)`` on success, or None when refinement is
        disabled or fails (failures are logged and the caller keeps its mask).
    """
    if not use_m2m:
        return None
    LOGGER.info("Applying M2M refinement...")
    try:
        if logits is None:
            raise ValueError("logits must be provided for M2M refinement.")
        # Slice (not index) so the leading batch-like axis of size 1 is kept.
        low_res_logits = logits[best_mask_idx : best_mask_idx + 1]
        refined_masks, refined_scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
            gra=granularity,
            mask_input=low_res_logits,
        )
        refined_mask = refined_masks[0]
        refined_score = float(refined_scores[0])
        LOGGER.info("M2M refinement completed with score: %.3f", refined_score)
        return refined_mask, refined_score
    except Exception as exc:  # pragma: no cover - best-effort, logging only
        LOGGER.exception("M2M refinement failed: %s, using original mask", exc)
        return None


def draw_overlay(
    image: np.ndarray,
    mask: Optional[np.ndarray],
    points: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float = 0.55,
) -> np.ndarray:
    """Render the mask (tinted + outlined) and click markers on a copy of an RGB image.

    Args:
        image: RGB uint8 image.
        mask: Optional mask with the image's height/width; truthy pixels are tinted.
        points: ``(x, y)`` click positions in pixel coordinates.
        labels: Click labels parallel to *points* (1 = positive, 0 = negative).
        alpha: Blend weight of MASK_COLOR_BGR over masked pixels.

    Returns:
        A new RGB uint8 image; *image* itself is not modified.
    """
    # cvtColor returns a fresh array, so in-place edits below never touch *image*.
    canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if mask is not None:
        mask_bool = mask.astype(bool)
        # Blend only the masked pixels instead of the whole frame.
        tint = np.asarray(MASK_COLOR_BGR, dtype=np.float64) * alpha
        canvas_bgr[mask_bool] = (canvas_bgr[mask_bool] * (1.0 - alpha) + tint).astype(np.uint8)
        contours, _ = cv2.findContours(
            mask_bool.astype(np.uint8),
            mode=cv2.RETR_EXTERNAL,
            method=cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(canvas_bgr, contours, -1, OUTLINE_COLOR_BGR, 2)
    for (x, y), lbl in zip(points, labels):
        color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
        center = (int(round(x)), int(round(y)))
        cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
        cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)


def handle_image_upload(image: Optional[np.ndarray]):
    """Reset state for a newly uploaded (or cleared) image.

    Returns ``(displayed_image, image_state, points, labels, status)``.
    """
    img = ensure_uint8(image)
    if img is None:
        return (
            None,
            None,
            [],
            [],
            "Upload an image to start adding clicks.",
        )
    return (
        img,
        img,
        [],
        [],
        "Image loaded. Choose click type, then tap on the image.",
    )


def handle_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    image: Optional[np.ndarray],
    evt: gr.SelectData,
):
    """Append a click from a Gradio select event and redraw the markers.

    New lists are built rather than mutating the incoming state.
    Returns ``(displayed_image, points, labels, status)``.
    """
    if image is None:
        return (
            gr.update(),
            pts,
            lbls,
            "Upload an image first.",
        )
    coord = evt.index  # (x, y)
    if coord is None:
        return (
            gr.update(),
            pts,
            lbls,
            "Couldn't read click position.",
        )
    x, y = coord
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    pts = pts + [[float(x), float(y)]]
    lbls = lbls + [label]
    overlay = draw_overlay(image, None, pts, lbls)
    status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
    return overlay, pts, lbls, status


def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
    """Drop the most recent click and redraw; returns ``(displayed_image, points, labels, status)``."""
    if not pts:
        return (
            gr.update(),
            pts,
            lbls,
            "No clicks to undo.",
        )
    pts = pts[:-1]
    lbls = lbls[:-1]
    overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
    status = "Removed the last click."
    return overlay, pts, lbls, status


def clear_clicks(image: Optional[np.ndarray]):
    """Remove all clicks and show the clean image; returns ``(image, [], [], status)``."""
    return image, [], [], "Cleared all clicks."


def _run_segmentation(
    image: Optional[np.ndarray],
    pts: List[Sequence[float]],
    lbls: List[int],
    granularity: float,
):
    """Segment *image* from the clicks at the requested granularity.

    Picks the highest-scoring mask from a multimask prediction, then optionally
    refines it with apply_m2m_refinement.

    Returns:
        ``(overlay_image_or_None, status_markdown)``.
    """
    img = ensure_uint8(image)
    if img is None:
        return None, "Upload an image to segment."
    if not pts:
        return draw_overlay(img, None, [], []), "Add at least one click before running segmentation."

    device = choose_device()
    predictor = MODEL_MANAGER.make_predictor(device)
    predictor.set_image(img)

    coords = np.asarray(pts, dtype=np.float32)
    labels = np.asarray(lbls, dtype=np.int32)
    gran_tensor = build_granularity_tensor(granularity, predictor.device)
    # NOTE(review): granularity is passed both as a float (`gra`) and as a tensor
    # (`granularity`); confirm the sam2 fork's predict() needs both.
    masks, scores, logits = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
        gra=float(granularity),
        granularity=gran_tensor,
    )
    best_idx = int(np.argmax(scores))
    best_mask = masks[best_idx].astype(bool)
    status = (
        f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | "
        f"granularity={granularity:.2f}"
    )

    refinement = apply_m2m_refinement(
        predictor=predictor,
        point_coords=coords,
        point_labels=labels,
        granularity=float(granularity),
        logits=logits,
        best_mask_idx=best_idx,
        use_m2m=USE_M2M_REFINEMENT,
    )
    if refinement is not None:
        refined_mask, refined_score = refinement
        best_mask = refined_mask.astype(bool)
        status += f" | M2M IoU: {refined_score:.3f}"

    overlay = draw_overlay(img, best_mask, pts, lbls)
    return overlay, status


# On Hugging Face ZeroGPU, the `spaces.GPU` decorator requests a GPU for each call.
if spaces is not None and ZERO_GPU_ENABLED:
    segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
else:
    segment_fn = _run_segmentation


def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers."""
    with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """## UnSAMv2 · Interactive Granularity Control
Upload an image, add positive/negative clicks, tune granularity, and run segmentation."""
        )

        # The clean source image; image_input may be showing an overlay on top of it.
        image_state = gr.State(DEFAULT_IMAGE)
        points_state = gr.State([])
        labels_state = gr.State([])

        image_input = gr.Image(
            label="Image · clicks & mask",
            type="numpy",
            height=480,
            value=DEFAULT_IMAGE,
            sources=["upload"],
        )

        with gr.Row():
            point_mode = gr.Radio(
                choices=list(POINT_MODE_TO_LABEL.keys()),
                value="Foreground (+)",
                label="Click type",
            )
            granularity_slider = gr.Slider(
                minimum=GRANULARITY_MIN,
                maximum=GRANULARITY_MAX,
                value=GRANULARITY_DEFAULT,
                step=0.01,
                label="Granularity",
                info="Lower = finer details, Higher = coarser regions",
            )
            segment_button = gr.Button("Segment", variant="primary")

        with gr.Row():
            undo_button = gr.Button("Undo last click")
            clear_button = gr.Button("Clear clicks")

        status_markdown = gr.Markdown(" Ready.")

        reset_outputs = [
            image_input,
            image_state,
            points_state,
            labels_state,
            status_markdown,
        ]
        click_outputs = [
            image_input,
            points_state,
            labels_state,
            status_markdown,
        ]

        image_input.upload(
            handle_image_upload,
            inputs=[image_input],
            outputs=reset_outputs,
        )
        image_input.clear(
            handle_image_upload,
            inputs=[image_input],
            outputs=reset_outputs,
        )
        image_input.select(
            handle_click,
            inputs=[
                point_mode,
                points_state,
                labels_state,
                image_state,
            ],
            outputs=click_outputs,
        )
        undo_button.click(
            undo_last_click,
            inputs=[image_state, points_state, labels_state],
            outputs=click_outputs,
        )
        clear_button.click(
            clear_clicks,
            inputs=[image_state],
            outputs=click_outputs,
        )
        segment_button.click(
            segment_fn,
            inputs=[image_state, points_state, labels_state, granularity_slider],
            outputs=[image_input, status_markdown],
        )

    demo.queue(max_size=8)
    return demo


demo = build_demo()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        share=SHARE_LINK,
    )