#!/usr/bin/env python3 """Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support.""" from __future__ import annotations import logging import os import sys import threading from pathlib import Path from typing import List, Optional, Sequence, Tuple import cv2 import gradio as gr import numpy as np import torch try: import spaces # type: ignore except ImportError: # pragma: no cover - optional dependency on Spaces runtime spaces = None REPO_ROOT = Path(__file__).resolve().parent SAM2_REPO = REPO_ROOT / "sam2" if SAM2_REPO.exists(): sys.path.insert(0, str(SAM2_REPO)) from sam2.build_sam import build_sam2 # noqa: E402 from sam2.sam2_image_predictor import SAM2ImagePredictor # noqa: E402 logging.basicConfig(level=os.getenv("UNSAMV2_LOGLEVEL", "INFO")) LOGGER = logging.getLogger("unsamv2-gradio") CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml") CKPT_PATH = Path( os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt") ).resolve() if not CKPT_PATH.exists(): raise FileNotFoundError( f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file." ) GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1)) GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0)) ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"} ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60")) POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0} POINT_COLORS_BGR = { 1: (72, 201, 127), # green-ish for positives 0: (64, 76, 225), # red-ish for negatives } MASK_COLOR_BGR = (0, 196, 255) OUTLINE_COLOR_BGR = (0, 165, 255) class ModelManager: """Keeps SAM2 models on each device and spawns lightweight predictors.""" def __init__(self) -> None: self._models: dict[str, torch.nn.Module] = {} self._lock = threading.Lock() def _build(self, device: torch.device) -> torch.nn.Module: LOGGER.info("Loading UnSAMv2 weights onto %s", device) return build_sam2( CONFIG_PATH, ckpt_path=str(CKPT_PATH), device=device, mode="eval", ) def get_model(self, device: torch.device) -> torch.nn.Module: key = ( f"{device.type}:{device.index}" if device.type == "cuda" else device.type ) with self._lock: if key not in self._models: self._models[key] = self._build(device) return self._models[key] def make_predictor(self, device: torch.device) -> SAM2ImagePredictor: return SAM2ImagePredictor(self.get_model(device)) MODEL_MANAGER = ModelManager() def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]: if image is None: return None img = image[..., :3] # drop alpha if present if img.dtype == np.float32 or img.dtype == np.float64: if img.max() <= 1.0: img = (img * 255).clip(0, 255).astype(np.uint8) else: img = img.clip(0, 255).astype(np.uint8) elif img.dtype != np.uint8: img = img.clip(0, 255).astype(np.uint8) return img def choose_device() -> torch.device: preference = os.getenv("UNSAMV2_DEVICE", "auto").lower() if preference == "cpu": return torch.device("cpu") if preference.startswith("cuda") or preference == "gpu": if torch.cuda.is_available(): return torch.device(preference if preference.startswith("cuda") else "cuda") LOGGER.warning("CUDA requested but not available; defaulting to CPU") return torch.device("cpu") return torch.device("cuda" if torch.cuda.is_available() else "cpu") def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor: tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device) return tensor def draw_overlay( image: np.ndarray, mask: Optional[np.ndarray], points: Sequence[Sequence[float]], labels: Sequence[int], alpha: float = 0.55, ) -> np.ndarray: canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if mask is not None: mask_bool = mask.astype(bool) overlay = np.zeros_like(canvas_bgr, dtype=np.uint8) overlay[mask_bool] = MASK_COLOR_BGR canvas_bgr = np.where( mask_bool[..., None], (canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8), canvas_bgr, ) contours, _ = cv2.findContours( mask_bool.astype(np.uint8), mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE, ) cv2.drawContours(canvas_bgr, contours, -1, OUTLINE_COLOR_BGR, 2) for (x, y), lbl in zip(points, labels): color = POINT_COLORS_BGR.get(lbl, (255, 255, 255)) center = (int(round(x)), int(round(y))) cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA) cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA) return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB) def points_table(points: Sequence[Sequence[float]], labels: Sequence[int]) -> List[List[str]]: table = [] for idx, ((x, y), lbl) in enumerate(zip(points, labels), start=1): table.append([ idx, round(float(x), 1), round(float(y), 1), "fg" if lbl == 1 else "bg", ]) return table def handle_image_upload(image: Optional[np.ndarray]): img = ensure_uint8(image) if img is None: return ( None, None, None, [], [], [], "Upload an image to start adding clicks.", ) return ( img, None, img, [], [], [], "Image loaded. Choose click type, then tap on the image.", ) def handle_click( point_mode: str, pts: List[Sequence[float]], lbls: List[int], image: Optional[np.ndarray], evt: gr.SelectData, ): if image is None: return ( gr.update(), None, pts, lbls, points_table(pts, lbls), "Upload an image first.", ) coord = evt.index # (x, y) if coord is None: return ( gr.update(), None, pts, lbls, points_table(pts, lbls), "Couldn't read click position.", ) x, y = coord label = POINT_MODE_TO_LABEL.get(point_mode, 1) pts = pts + [[float(x), float(y)]] lbls = lbls + [label] overlay = draw_overlay(image, None, pts, lbls) status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})." return overlay, None, pts, lbls, points_table(pts, lbls), status def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]): if not pts: return ( gr.update(), None, pts, lbls, points_table(pts, lbls), "No clicks to undo.", ) pts = pts[:-1] lbls = lbls[:-1] overlay = draw_overlay(image, None, pts, lbls) if image is not None else None status = "Removed the last click." return overlay, None, pts, lbls, points_table(pts, lbls), status def clear_clicks(image: Optional[np.ndarray]): overlay = image if image is not None else None return overlay, None, [], [], [], "Cleared all clicks." def _run_segmentation( image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int], granularity: float, ): img = ensure_uint8(image) if img is None: return None, None, "Upload an image to segment." if not pts: return draw_overlay(img, None, [], []), None, "Add at least one click before running segmentation." device = choose_device() predictor = MODEL_MANAGER.make_predictor(device) predictor.set_image(img) coords = np.asarray(pts, dtype=np.float32) labels = np.asarray(lbls, dtype=np.int32) gran_tensor = build_granularity_tensor(granularity, predictor.device) masks, scores, _ = predictor.predict( point_coords=coords, point_labels=labels, multimask_output=True, gra=float(granularity), granularity=gran_tensor, ) best_idx = int(np.argmax(scores)) best_mask = masks[best_idx].astype(bool) overlay = draw_overlay(img, best_mask, pts, lbls) mask_vis = (best_mask.astype(np.uint8) * 255) status = f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | granularity={granularity:.2f}" return overlay, mask_vis, status if spaces is not None and ZERO_GPU_ENABLED: segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation) else: segment_fn = _run_segmentation def build_demo() -> gr.Blocks: with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo: gr.Markdown( """## UnSAMv2 · Interactive Granularity Control Upload an image, add positive/negative clicks, tune granularity, and run segmentation. ZeroGPU automatically pulls a GPU when available; otherwise the app falls back to CPU.""" ) image_state = gr.State() points_state = gr.State([]) labels_state = gr.State([]) with gr.Row(): image_input = gr.Image( label="1 · Upload image & click to add prompts", type="numpy", height=480, ) overlay_output = gr.Image( label="Segmentation preview", interactive=False, height=480, ) mask_output = gr.Image( label="Binary mask", interactive=False, height=480, ) with gr.Row(): point_mode = gr.Radio( choices=list(POINT_MODE_TO_LABEL.keys()), value="Foreground (+)", label="Click type", ) granularity_slider = gr.Slider( minimum=GRANULARITY_MIN, maximum=GRANULARITY_MAX, value=0.2, step=0.05, label="Granularity", info="Lower = finer details, Higher = coarser regions", ) segment_button = gr.Button("3 · Segment", variant="primary") with gr.Row(): undo_button = gr.Button("Undo last click") clear_button = gr.Button("Clear clicks") points_table_output = gr.Dataframe( headers=["#", "x", "y", "type"], datatype=["number", "number", "number", "str"], interactive=False, label="2 · Click history", ) status_markdown = gr.Markdown(" Ready.") image_input.upload( handle_image_upload, inputs=[image_input], outputs=[ overlay_output, mask_output, image_state, points_state, labels_state, points_table_output, status_markdown, ], ) image_input.clear( handle_image_upload, inputs=[image_input], outputs=[ overlay_output, mask_output, image_state, points_state, labels_state, points_table_output, status_markdown, ], ) image_input.select( handle_click, inputs=[ point_mode, points_state, labels_state, image_state, ], outputs=[ overlay_output, mask_output, points_state, labels_state, points_table_output, status_markdown, ], ) undo_button.click( undo_last_click, inputs=[image_state, points_state, labels_state], outputs=[ overlay_output, mask_output, points_state, labels_state, points_table_output, status_markdown, ], ) clear_button.click( clear_clicks, inputs=[image_state], outputs=[ overlay_output, mask_output, points_state, labels_state, points_table_output, status_markdown, ], ) segment_button.click( segment_fn, inputs=[image_state, points_state, labels_state, granularity_slider], outputs=[overlay_output, mask_output, status_markdown], ) demo.queue(max_size=8) return demo demo = build_demo() if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)