UnSAMv2 / app.py
yjwnb6
Initial HF Space upload
7b25808
raw
history blame
13.2 kB
#!/usr/bin/env python3
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""
from __future__ import annotations
import logging
import os
import sys
import threading
from pathlib import Path
from typing import List, Optional, Sequence, Tuple
import cv2
import gradio as gr
import numpy as np
import torch
# ``spaces`` provides the Hugging Face ZeroGPU helpers; when it cannot be
# imported, fall back to ``None`` so ``segment_fn`` is left undecorated.
try:
    import spaces  # type: ignore
except ImportError:  # pragma: no cover - optional dependency on Spaces runtime
    spaces = None

# If a ``sam2`` repository checkout sits next to this file, put it at the
# front of ``sys.path`` so its package is found before any installed copy.
REPO_ROOT = Path(__file__).resolve().parent
SAM2_REPO = REPO_ROOT.joinpath("sam2")
if SAM2_REPO.exists():
    sys.path.insert(0, os.fspath(SAM2_REPO))
from sam2.build_sam import build_sam2 # noqa: E402
from sam2.sam2_image_predictor import SAM2ImagePredictor # noqa: E402
# ``logging`` only accepts upper-case level names (or integers), so normalise
# values such as "debug", " INFO " or "10" instead of failing at import time.
_LOG_LEVEL = os.getenv("UNSAMV2_LOGLEVEL", "INFO").strip().upper()
logging.basicConfig(level=int(_LOG_LEVEL) if _LOG_LEVEL.isdigit() else _LOG_LEVEL)
LOGGER = logging.getLogger("unsamv2-gradio")

# Model config handed to ``build_sam2``; override with UNSAMV2_CONFIG.
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
# Absolute checkpoint path; defaults to <repo>/sam2/checkpoints/unsamv2_plus_ckpt.pt.
CKPT_PATH = Path(
    os.getenv("UNSAMV2_CKPT", str(SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt"))
).resolve()
# Fail fast at import time so a missing checkpoint is reported immediately
# rather than on the first segmentation request.
if not CKPT_PATH.exists():
    raise FileNotFoundError(
        f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
    )
# Bounds of the granularity slider (environment-overridable).
GRANULARITY_MIN = float(os.environ.get("UNSAMV2_GRAN_MIN", 0.1))
GRANULARITY_MAX = float(os.environ.get("UNSAMV2_GRAN_MAX", 1.0))

# ZeroGPU switch ("1"/"true"/"yes", case-insensitive) and the ``duration``
# value passed to ``spaces.GPU`` when the switch is on.
ZERO_GPU_ENABLED = os.environ.get("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in ("1", "true", "yes")
ZERO_GPU_DURATION = int(os.environ.get("UNSAMV2_ZEROGPU_DURATION", "60"))

# Click-type radio choice -> point label (1 = positive, 0 = negative).
POINT_MODE_TO_LABEL = {
    "Foreground (+)": 1,
    "Background (-)": 0,
}

# Drawing colours, in OpenCV's BGR channel order.
POINT_COLORS_BGR = {
    1: (72, 201, 127),  # positive clicks: green-ish
    0: (64, 76, 225),  # negative clicks: red-ish
}
MASK_COLOR_BGR = (0, 196, 255)  # tint applied inside the mask
OUTLINE_COLOR_BGR = (0, 165, 255)  # mask contour colour
class ModelManager:
    """Cache one UnSAMv2/SAM2 model per device and create predictors on demand.

    A model is built lazily the first time its device is requested and is
    reused afterwards. ``make_predictor`` wraps the cached model in a new
    ``SAM2ImagePredictor`` on every call.
    """

    def __init__(self) -> None:
        # Normalized device key (see ``_device_key``) -> loaded model.
        self._models: dict[str, torch.nn.Module] = {}
        # Guards the check-then-build in ``get_model`` so concurrent requests
        # for the same device trigger only one weight load.
        self._lock = threading.Lock()

    @staticmethod
    def _device_key(device: torch.device) -> str:
        """Return the cache key for ``device``.

        Non-CUDA devices are keyed by their type (e.g. ``"cpu"``). CUDA devices
        are keyed as ``"cuda:<index>"``; a bare ``cuda`` device (index ``None``)
        is resolved to the current CUDA device so that it shares a cache entry
        with its explicit-index spelling instead of loading the weights twice.
        """
        if device.type != "cuda":
            return device.type
        index = device.index
        if index is None:
            index = torch.cuda.current_device() if torch.cuda.is_available() else 0
        return f"cuda:{index}"

    def _build(self, device: torch.device) -> torch.nn.Module:
        """Load the checkpoint at ``CKPT_PATH`` onto ``device`` with ``mode="eval"``."""
        LOGGER.info("Loading UnSAMv2 weights onto %s", device)
        return build_sam2(
            CONFIG_PATH,
            ckpt_path=str(CKPT_PATH),
            device=device,
            mode="eval",
        )

    def get_model(self, device: torch.device) -> torch.nn.Module:
        """Return the cached model for ``device``, building it on first use."""
        key = self._device_key(device)
        with self._lock:
            if key not in self._models:
                self._models[key] = self._build(device)
            return self._models[key]

    def make_predictor(self, device: torch.device) -> "SAM2ImagePredictor":
        """Return a new predictor wrapping the cached model for ``device``."""
        return SAM2ImagePredictor(self.get_model(device))


# Module-wide instance used by the segmentation handler.
MODEL_MANAGER = ModelManager()
def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Normalize an image array to a contiguous three-channel ``uint8`` array.

    Args:
        image: Channels-last image array, or ``None``. 2-D grayscale and
            single-channel inputs are replicated to three channels; any
            channels beyond the third (e.g. alpha) are dropped.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise an ``(H, W, 3)`` ``uint8``
        array. Floating-point images whose maximum is <= 1.0 are scaled by 255
        first; every non-``uint8`` input is then clipped to [0, 255].
    """
    if image is None:
        return None
    img = np.asarray(image)
    if img.ndim == 2:
        # Grayscale (H, W): add a channel axis so the channel slice below can
        # never crop the width, and downstream RGB code sees three channels.
        img = np.repeat(img[..., None], 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    else:
        img = img[..., :3]  # drop alpha if present
    if np.issubdtype(img.dtype, np.floating):
        # Covers float16 as well; an empty array has no max, so treat it as [0, 1].
        if img.size == 0 or img.max() <= 1.0:
            img = img * 255
        img = img.clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        img = img.clip(0, 255).astype(np.uint8)
    # The channel slice yields a strided view; OpenCV prefers contiguous buffers.
    return np.ascontiguousarray(img)
def choose_device() -> torch.device:
    """Resolve the torch device from the ``UNSAMV2_DEVICE`` environment variable.

    ``cpu`` forces the CPU; ``gpu`` or ``cuda[:N]`` request CUDA (falling back
    to CPU with a warning when CUDA is unavailable); any other value, including
    the default ``auto``, picks CUDA when available and CPU otherwise.
    """
    wanted = os.getenv("UNSAMV2_DEVICE", "auto").lower()
    if wanted == "cpu":
        return torch.device("cpu")

    wants_cuda = wanted == "gpu" or wanted.startswith("cuda")
    if not wants_cuda:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not torch.cuda.is_available():
        LOGGER.warning("CUDA requested but not available; defaulting to CPU")
        return torch.device("cpu")
    # "gpu" is an alias for the default CUDA device; "cuda..." is used verbatim.
    return torch.device("cuda" if wanted == "gpu" else wanted)
def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
    """Return ``value`` as a ``(1, 1, 1, 1)`` float32 tensor placed on ``device``."""
    return torch.full((1, 1, 1, 1), value, dtype=torch.float32, device=device)
def draw_overlay(
    image: np.ndarray,
    mask: Optional[np.ndarray],
    points: Sequence[Sequence[float]],
    labels: Sequence[int],
    alpha: float = 0.55,
) -> np.ndarray:
    """Render a mask tint/outline and click markers on top of an RGB image.

    Args:
        image: RGB image; it is converted to BGR for OpenCV drawing and back.
        mask: Optional mask with the same height and width as ``image``; truthy
            pixels are tinted with ``MASK_COLOR_BGR`` and the external contours
            are outlined with ``OUTLINE_COLOR_BGR``.
        points: ``(x, y)`` click coordinates in image pixels.
        labels: Click labels aligned with ``points``; colours come from
            ``POINT_COLORS_BGR`` and unknown labels are drawn white.
        alpha: Weight of the tint colour inside the mask.

    Returns:
        A new RGB image; ``image`` itself is not modified.
    """
    # cvtColor returns a fresh array, so the in-place edits below never touch ``image``.
    canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if mask is not None:
        mask_bool = mask.astype(bool)
        # Blend only the masked pixels. This gives the same values as blending
        # the whole frame and selecting with np.where, without full-size temporaries.
        tint = np.asarray(MASK_COLOR_BGR, dtype=np.float64) * alpha
        canvas_bgr[mask_bool] = (canvas_bgr[mask_bool] * (1.0 - alpha) + tint).astype(np.uint8)
        # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
        # (contours, hierarchy); index -2 is the contour list in both.
        contours = cv2.findContours(
            mask_bool.astype(np.uint8),
            mode=cv2.RETR_EXTERNAL,
            method=cv2.CHAIN_APPROX_SIMPLE,
        )[-2]
        cv2.drawContours(canvas_bgr, contours, -1, OUTLINE_COLOR_BGR, 2)
    for (x, y), lbl in zip(points, labels):
        color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
        center = (int(round(x)), int(round(y)))
        # Filled coloured dot with a white ring so markers stay visible on any background.
        cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
        cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
    return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)
def points_table(points: Sequence[Sequence[float]], labels: Sequence[int]) -> List[List[object]]:
    """Build the rows shown in the click-history Dataframe.

    Each row is ``[index, x, y, kind]``:
    - ``index`` is 1-based (int);
    - ``x`` and ``y`` are floats rounded to one decimal place;
    - ``kind`` is ``"fg"`` for label 1 and ``"bg"`` for anything else.

    As with ``zip``, entries beyond the shorter of ``points``/``labels`` are ignored.
    """
    return [
        [idx, round(float(x), 1), round(float(y), 1), "fg" if lbl == 1 else "bg"]
        for idx, ((x, y), lbl) in enumerate(zip(points, labels), start=1)
    ]
def handle_image_upload(image: Optional[np.ndarray]):
    """Reset all per-image state after the input image is uploaded or cleared.

    Returns a 7-tuple in the order the event outputs expect: preview image,
    binary mask, cached image, points, labels, click-table rows and status text.
    """
    normalized = ensure_uint8(image)
    if normalized is not None:
        message = "Image loaded. Choose click type, then tap on the image."
    else:
        message = "Upload an image to start adding clicks."
    # The preview and the cached image both start as the normalized image
    # (or None); the mask and all click bookkeeping are cleared.
    return normalized, None, normalized, [], [], [], message
def handle_click(
    point_mode: str,
    pts: List[Sequence[float]],
    lbls: List[int],
    image: Optional[np.ndarray],
    evt: gr.SelectData,
):
    """Record a click from a select event on the input image and redraw markers.

    Returns ``(preview, mask, points, labels, table_rows, status)``. The incoming
    ``pts``/``lbls`` lists are not mutated; extended copies are returned.
    """

    def _unchanged(message: str):
        # Keep the current preview, blank the mask, and report ``message``.
        return gr.update(), None, pts, lbls, points_table(pts, lbls), message

    if image is None:
        return _unchanged("Upload an image first.")
    coord = evt.index  # (x, y)
    if coord is None:
        return _unchanged("Couldn't read click position.")

    x, y = coord
    label = POINT_MODE_TO_LABEL.get(point_mode, 1)
    new_pts = pts + [[float(x), float(y)]]
    new_lbls = lbls + [label]
    preview = draw_overlay(image, None, new_pts, new_lbls)
    kind = "positive" if label == 1 else "negative"
    return (
        preview,
        None,
        new_pts,
        new_lbls,
        points_table(new_pts, new_lbls),
        f"Added {kind} click at ({int(x)}, {int(y)}).",
    )
def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
    """Drop the most recent click and redraw the remaining markers.

    Returns ``(preview, mask, points, labels, table_rows, status)``. When there
    are no clicks, the preview is left untouched.
    """
    if not pts:
        return gr.update(), None, pts, lbls, points_table(pts, lbls), "No clicks to undo."
    kept_pts, kept_lbls = pts[:-1], lbls[:-1]
    preview = None if image is None else draw_overlay(image, None, kept_pts, kept_lbls)
    return (
        preview,
        None,
        kept_pts,
        kept_lbls,
        points_table(kept_pts, kept_lbls),
        "Removed the last click.",
    )
def clear_clicks(image: Optional[np.ndarray]):
    """Forget every click: show the bare image again and blank the mask and table.

    Returns ``(preview, mask, points, labels, table_rows, status)``.
    """
    return image, None, [], [], [], "Cleared all clicks."
def _run_segmentation(
    image: Optional[np.ndarray],
    pts: List[Sequence[float]],
    lbls: List[int],
    granularity: float,
):
    """Segment ``image`` from the recorded clicks at the requested granularity.

    Returns ``(preview, binary_mask, status)``. ``binary_mask`` is a ``uint8``
    image that is 255 inside the highest-scoring candidate mask and 0 elsewhere.
    """
    rgb = ensure_uint8(image)
    if rgb is None:
        return None, None, "Upload an image to segment."
    if not pts:
        return draw_overlay(rgb, None, [], []), None, "Add at least one click before running segmentation."

    # NOTE(review): the first call on a device also builds the model via
    # MODEL_MANAGER; when wrapped by spaces.GPU that load counts toward
    # ZERO_GPU_DURATION — confirm the budget leaves room for it.
    predictor = MODEL_MANAGER.make_predictor(choose_device())
    predictor.set_image(rgb)
    masks, scores, _ = predictor.predict(
        point_coords=np.asarray(pts, dtype=np.float32),
        point_labels=np.asarray(lbls, dtype=np.int32),
        multimask_output=True,
        gra=float(granularity),
        granularity=build_granularity_tensor(granularity, predictor.device),
    )

    # Keep the candidate with the highest predicted score.
    top = int(np.argmax(scores))
    chosen = masks[top].astype(bool)
    preview = draw_overlay(rgb, chosen, pts, lbls)
    binary = chosen.astype(np.uint8) * 255
    status = (
        f"Best mask #{top + 1} IoU score: {float(scores[top]):.3f} "
        f"| granularity={granularity:.2f}"
    )
    return preview, binary, status
# Decorate the segmentation entry point with ``spaces.GPU`` (ZeroGPU) only
# when the ``spaces`` package imported successfully and the flag is enabled.
segment_fn = (
    spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
    if spaces is not None and ZERO_GPU_ENABLED
    else _run_segmentation
)
def build_demo() -> gr.Blocks:
    """Create the Gradio Blocks app and wire all event handlers.

    Layout:
    - input image, segmentation preview and binary mask side by side;
    - click-type radio, granularity slider and segment button;
    - undo/clear buttons;
    - the click-history table;
    - a status line.

    The returned app has its request queue enabled with ``max_size=8``.
    """
    # GRANULARITY_MIN/MAX are environment-configurable, so clamp the default
    # slider value into that range instead of assuming it contains 0.2.
    initial_granularity = min(max(0.2, GRANULARITY_MIN), GRANULARITY_MAX)

    with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "## UnSAMv2 · Interactive Granularity Control\n"
            "Upload an image, add positive/negative clicks, tune granularity, and run segmentation.\n"
            "ZeroGPU automatically pulls a GPU when available; otherwise the app falls back to CPU."
        )
        # Per-session state: the normalized image plus parallel click lists.
        image_state = gr.State()
        points_state = gr.State([])
        labels_state = gr.State([])

        with gr.Row():
            image_input = gr.Image(
                label="1 · Upload image & click to add prompts",
                type="numpy",
                height=480,
            )
            overlay_output = gr.Image(
                label="Segmentation preview",
                interactive=False,
                height=480,
            )
            mask_output = gr.Image(
                label="Binary mask",
                interactive=False,
                height=480,
            )
        with gr.Row():
            point_mode = gr.Radio(
                choices=list(POINT_MODE_TO_LABEL.keys()),
                value="Foreground (+)",
                label="Click type",
            )
            granularity_slider = gr.Slider(
                minimum=GRANULARITY_MIN,
                maximum=GRANULARITY_MAX,
                value=initial_granularity,
                step=0.05,
                label="Granularity",
                info="Lower = finer details, Higher = coarser regions",
            )
            segment_button = gr.Button("3 · Segment", variant="primary")
        with gr.Row():
            undo_button = gr.Button("Undo last click")
            clear_button = gr.Button("Clear clicks")
        points_table_output = gr.Dataframe(
            headers=["#", "x", "y", "type"],
            datatype=["number", "number", "number", "str"],
            interactive=False,
            label="2 · Click history",
        )
        status_markdown = gr.Markdown(" Ready.")

        # Outputs of the handlers that (re)initialise everything tied to the image.
        image_reset_outputs = [
            overlay_output,
            mask_output,
            image_state,
            points_state,
            labels_state,
            points_table_output,
            status_markdown,
        ]
        image_input.upload(handle_image_upload, inputs=[image_input], outputs=image_reset_outputs)
        image_input.clear(handle_image_upload, inputs=[image_input], outputs=image_reset_outputs)

        # Outputs of the handlers that edit the click list and refresh the preview.
        click_outputs = [
            overlay_output,
            mask_output,
            points_state,
            labels_state,
            points_table_output,
            status_markdown,
        ]
        image_input.select(
            handle_click,
            inputs=[point_mode, points_state, labels_state, image_state],
            outputs=click_outputs,
        )
        undo_button.click(
            undo_last_click,
            inputs=[image_state, points_state, labels_state],
            outputs=click_outputs,
        )
        clear_button.click(clear_clicks, inputs=[image_state], outputs=click_outputs)

        segment_button.click(
            segment_fn,
            inputs=[image_state, points_state, labels_state, granularity_slider],
            outputs=[overlay_output, mask_output, status_markdown],
        )
    demo.queue(max_size=8)
    return demo


demo = build_demo()

if __name__ == "__main__":
    # Keep the historical default (share=True) but allow disabling the public
    # share link with UNSAMV2_SHARE=0 for local/private runs.
    _share = os.getenv("UNSAMV2_SHARE", "1").lower() in {"1", "true", "yes"}
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=_share)