TIPSv2 / app.py
bingyic's picture
Revert "Fix extract_features_value_attention and add _get_all_blocks"
4b9c38c
"""TIPS Feature Explorer (GPU) β€” Hugging Face Space demo with ZeroGPU."""
import colorsys
import os
import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from fast_pytorch_kmeans import KMeans as TorchKMeans
from sklearn.decomposition import PCA
from torchvision import transforms
from transformers import AutoModel
# ── Constants ───────────────────────────────────────────────────────────────
# ViT patch side in pixels. Every selectable input side below is a whole
# multiple of it, so the encoder always sees an integer sp x sp patch grid.
PATCH_SIZE = 14
DEFAULT_IMAGE_SIZE = 896  # 64 x 64 patches
RESOLUTIONS = [
    PATCH_SIZE * n_patches for n_patches in (16, 24, 32, 48, 64, 80, 98, 128)
]
# Default input side for value-attention (zero-shot) feature extraction.
ZEROSEG_IMAGE_SIZE = 1372
# Maximum token length handed to the text tokenizer.
MAX_LEN = 64
# Dropdown label -> Hugging Face Hub repo id of the DPT checkpoint.
VARIANTS = {
    f"TIPS v2 β€” {size}/14": f"google/tipsv2-{size.lower()}14-dpt"
    for size in ("B", "L", "SO400m", "g")
}
DEFAULT_VARIANT = "TIPS v2 β€” L/14"
def _device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── Pascal Context (59 classes) ─────────────────────────────────────────────
# Prompt templates: each class name is substituted into every template, and the
# text embeddings of all templates are averaged per class by the embedding code.
TCL_PROMPTS = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
    "a photo of many {}.",
    "a photo of {}s.",
]
# The 59 Pascal Context class names (alphabetical, single tokens with no spaces),
# kept as a tuple.
PASCAL_CONTEXT_CLASSES = tuple(
    """
    aeroplane bag bed bedclothes bench bicycle bird boat book bottle
    building bus cabinet car cat ceiling chair cloth computer cow
    cup curtain dog door fence floor flower food grass ground
    horse keyboard light motorbike mountain mouse person plate platform pottedplant
    road rock sheep shelves sidewalk sign sky snow sofa table
    track train tree truck tvmonitor wall water window wood
    """.split()
)
# ADE20K's 150 class names in label order: entry k names predicted class id k
# of the supervised segmentation head. Names contain no spaces (multi-word
# names use underscores), so a whitespace split reproduces the tuple.
ADE20K_CLASSES = tuple(
    """
    wall building sky floor tree ceiling road bed windowpane grass
    cabinet sidewalk person earth door table mountain plant curtain chair
    car water painting sofa shelf house sea mirror rug field
    armchair seat fence desk rock wardrobe lamp bathtub railing cushion
    base box column signboard chest_of_drawers counter sand sink skyscraper fireplace
    refrigerator grandstand path stairs runway case pool_table pillow screen_door stairway
    river bridge bookcase blind coffee_table toilet flower book hill bench
    countertop stove palm kitchen_island computer swivel_chair boat bar arcade_machine hovel
    bus towel light truck tower chandelier awning streetlight booth television
    airplane dirt_track apparel pole land bannister escalator ottoman bottle buffet
    poster stage van ship fountain conveyer_belt canopy washer plaything swimming_pool
    stool barrel basket waterfall tent bag minibike cradle oven ball
    food step tank trade_name microwave pot animal bicycle lake dishwasher
    screen blanket sculpture hood sconce vase traffic_light tray ashcan fan
    pier crt_screen plate monitor bulletin_board shower radiator glass clock flag
    """.split()
)
NUM_ADE20K_CLASSES = 150
ADE20K_PALETTE = np.zeros((NUM_ADE20K_CLASSES + 1, 3), dtype=np.uint8)
for i in range(1, NUM_ADE20K_CLASSES + 1):
hue = (i * 0.618033988749895) % 1.0
saturation = 0.65 + 0.35 * ((i * 7) % 5) / 4.0
value = 0.70 + 0.30 * ((i * 11) % 3) / 2.0
r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
ADE20K_PALETTE[i] = [int(r * 255), int(g * 255), int(b * 255)]
# ── Model state (one model loaded at a time) ───────────────────────────────
_model = {
"name": None,
"vision": None,
"text": None,
"tokenizer": None,
"temperature": None,
"ade20k_embs": None,
"dpt": None,
}
def load_variant(name):
    """Load the DPT checkpoint behind dropdown label *name* into ``_model``.

    The DPT model bundles the TIPS backbone. Its vision encoder, text encoder,
    tokenizer and temperature are cached next to the DPT model, and any
    previously computed text embeddings are invalidated. Calling this with the
    already-loaded name is a no-op. The Hub token comes from the ``HF_TIPSv2``
    environment variable, falling back to ``HF_TOKEN``.

    Raises:
        KeyError: if *name* is not loaded yet and is not a key of ``VARIANTS``.
    """
    if name == _model["name"]:
        return
    hf_token = os.environ.get("HF_TIPSv2") or os.environ.get("HF_TOKEN")
    dpt_model = AutoModel.from_pretrained(
        VARIANTS[name], trust_remote_code=True, token=hf_token
    )
    dpt_model.eval()
    # Trigger the backbone download; dpt_model._backbone is read right after.
    dpt_model._get_backbone()
    backbone = dpt_model._backbone
    _model.update(
        name=name,
        dpt=dpt_model,
        vision=backbone.vision_encoder,
        text=backbone.text_encoder,
        tokenizer=backbone._load_tokenizer(),
        temperature=backbone.config.temperature,
        ade20k_embs=None,  # stale: were computed with the previous text encoder
    )
    print(f"Loaded {name}")
def _move_models_to_device():
    """Put every loaded module (vision, text, DPT) on ``_device()``.

    Inside a ``@spaces.GPU`` call that is the GPU; elsewhere it is the CPU.
    Slots that are still ``None`` are skipped.
    """
    target = _device()
    for slot in ("vision", "text", "dpt"):
        module = _model[slot]
        if module is not None:
            module.to(target)
def _ensure_ade20k_embs():
    """Compute and cache template-averaged Pascal Context text embeddings.

    Each class in PASCAL_CONTEXT_CLASSES is embedded once per TCL_PROMPTS
    template. The per-template embeddings are averaged, L2-normalised and stored
    in ``_model["ade20k_embs"]``; that key name is historical, since the classes
    are Pascal Context, not ADE20K. The work is skipped when the cache is already
    filled. It should run where the text encoder lives (the GPU inside
    ``@spaces.GPU``).

    NOTE(review): nothing visible in this file reads ``_model["ade20k_embs"]``;
    confirm whether this precomputation is still needed.
    """
    if _model["ade20k_embs"] is not None:
        return
    device = _device()
    encoder, tokenizer = _model["text"], _model["tokenizer"]
    per_template = []
    for template in TCL_PROMPTS:
        ids, paddings = tokenizer.tokenize(
            [template.format(class_name) for class_name in PASCAL_CONTEXT_CLASSES],
            max_len=MAX_LEN,
        )
        with torch.no_grad():
            embs = encoder(
                torch.from_numpy(ids).to(device),
                torch.from_numpy(paddings).to(device),
            )
        per_template.append(embs.cpu().numpy())
    _model["ade20k_embs"] = l2_normalize(np.mean(per_template, axis=0))
    print("Pascal Context text embeddings computed.")
def _init_model():
    """Ensure a model is loaded, on the current device, with text embeddings ready.

    Keeps whichever variant is already loaded and falls back to DEFAULT_VARIANT
    when none is.
    """
    variant = _model["name"] or DEFAULT_VARIANT
    load_variant(variant)
    _move_models_to_device()
    _ensure_ade20k_embs()
# ── Preprocessing & helpers ─────────────────────────────────────────────────
def preprocess(img, size=DEFAULT_IMAGE_SIZE):
    """Resize PIL image *img* to ``size x size`` and convert it to a (C, H, W) tensor.

    No mean/std normalisation is applied; ``ToTensor`` only rescales 8-bit
    pixel values to [0, 1].
    """
    resize = transforms.Resize((size, size))
    to_tensor = transforms.ToTensor()
    return to_tensor(resize(img))
def l2_normalize(x, axis=-1):
    """Scale *x* to unit L2 norm along *axis*.

    Norms are floored at 1e-3, so (near-)zero vectors are shrunk rather than
    divided by zero.
    """
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    return x / np.maximum(norms, 1e-3)
def upsample(arr, h, w, mode="bilinear"):
    """Resize an (H, W, C) or (H, W) numpy array to (h, w, C) or (h, w).

    Args:
        arr: 2-D or 3-D numpy array (channel-last when 3-D).
        h, w: target spatial size.
        mode: any ``torch.nn.functional.interpolate`` mode valid for 4-D input
            ("bilinear", "bicubic", "nearest", ...).

    Returns:
        float32 numpy array with the same rank as *arr*.
    """
    # ascontiguousarray: torch.from_numpy rejects negative strides (e.g. arr[::-1]).
    t = torch.from_numpy(np.ascontiguousarray(arr)).float()
    was_2d = t.ndim == 2
    if was_2d:
        t = t.unsqueeze(-1)
    t = t.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W) for F.interpolate
    # align_corners only applies to the interpolating modes; pin it to False
    # (the PyTorch default) for both of them so no warning is raised.
    kwargs = {"align_corners": False} if mode in ("bilinear", "bicubic") else {}
    up = F.interpolate(t, size=(h, w), mode=mode, **kwargs)
    out = up[0].permute(1, 2, 0).numpy()
    # A 2-D input gets a 2-D result instead of a stray trailing channel axis.
    return out[..., 0] if was_2d else out
def to_uint8(x):
    """Map a float array in [0, 1] to uint8 in [0, 255].

    Out-of-range values saturate at 0/255; the fractional part is truncated.
    """
    return (255 * x).clip(min=0, max=255).astype(np.uint8)
# ── Feature extraction (GPU-accelerated) ────────────────────────────────────
@torch.no_grad()
def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
    """Encode *image_np* and return its patch features as an (sp, sp, D) array.

    The image is converted to RGB, resized to ``resolution x resolution`` and
    passed through the vision encoder. The encoder's third output holds the
    patch tokens. ``sp`` is ``resolution // PATCH_SIZE``, and the result is a
    CPU numpy array.
    """
    pil_rgb = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(pil_rgb, resolution).unsqueeze(0).to(_device())
    _, _, patch_tokens = _model["vision"](batch)
    grid = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(grid, grid, -1).numpy()
@torch.no_grad()
def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) using Value Attention on GPU.

    This follows the Colab reference implementation: run all blocks except the
    last normally, then for the last block extract V from QKV and manually
    apply out_proj, layer scale, residual, norm2, MLP + layer scale, second
    residual, and final norm.

    The softmax attention of the last block is never computed. Each token's
    attention output is replaced by its own projected value vector.

    Args:
        image_np: image array accepted by ``Image.fromarray``; converted to RGB.
        resolution: square side the image is resized to; sp = resolution // 14
            (PATCH_SIZE).

    Returns:
        CPU numpy array of shape (sp, sp, D).

    NOTE(review): assumes ``model_image.blocks`` is a flat sequence of
    transformer blocks exposing ``.attn``/``.norm1``/``.ls1``/``.mlp``. If the
    blocks were grouped into chunks, ``blocks[-1]`` would be a chunk; confirm
    for every variant in VARIANTS.
    """
    dev = _device()
    model_image = _model["vision"]
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # Token embedding; presumably patch embedding plus prepended class/register
    # tokens (DINOv2-style API) — TODO confirm against the encoder implementation.
    x = model_image.prepare_tokens_with_masks(tensor)
    # All blocks but the last run unmodified.
    for blk in model_image.blocks[:-1]:
        x = blk(x)
    blk = model_image.blocks[-1]
    # Falls back to 1 when the encoder does not expose num_register_tokens.
    # NOTE(review): confirm that default matches the released checkpoints.
    num_reg = getattr(model_image, "num_register_tokens", 1)
    b_dim, n_dim, c_dim = x.shape
    num_heads = blk.attn.num_heads
    # Fused QKV projection of the pre-normed tokens, split into heads.
    qkv = blk.attn.qkv(blk.norm1(x))
    qkv = qkv.reshape(b_dim, n_dim, 3, num_heads, c_dim // num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D_head)
    v = qkv[2]  # (B, H, N, D_head)
    # Merge heads back to (B, N, C) and apply the attention output projection.
    v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)
    v_out = blk.attn.proj(v_out)
    v_out = blk.ls1(v_out)
    # First residual uses the block input x (not the normed tokens).
    x_val = v_out + x
    y_val = blk.norm2(x_val)
    y_val = blk.ls2(blk.mlp(y_val))
    x_val = x_val + y_val
    x_val = model_image.norm(x_val)
    # Drop the leading 1 + num_reg non-patch tokens; assumes the layout
    # [CLS, registers..., patches] — TODO confirm.
    patch_tokens = x_val[:, 1 + num_reg :, :]
    # preprocess resizes to a square, so the patch grid is sp x sp.
    sp = resolution // PATCH_SIZE
    spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
    return spatial
# ── PCA Visualisations ──────────────────────────────────────────────────────
def vis_pca(spatial):
    """Project (H, W, D) features onto their top 3 whitened PCA axes as RGB.

    Each whitened component is squashed into (0, 1) with a slope-2 sigmoid,
    and the result is returned as an (H, W, 3) uint8 image.
    """
    height, width = spatial.shape[0], spatial.shape[1]
    flat = spatial.reshape(-1, spatial.shape[-1])
    projected = PCA(n_components=3, whiten=True).fit_transform(flat)
    rgb = 1 / (1 + np.exp(-2.0 * projected.reshape(height, width, 3)))
    return to_uint8(rgb)
def vis_depth(spatial):
    """Visualise the 1st PCA component of (H, W, D) features with ``inferno``.

    The component is min-max normalised to [0, 1] (with 1e-8 added to the range
    to avoid division by zero on constant maps) and colour-mapped. The result is
    an (H, W, 3) uint8 image.
    """
    height, width = spatial.shape[0], spatial.shape[1]
    flat = spatial.reshape(-1, spatial.shape[-1])
    first = PCA(n_components=1).fit_transform(flat).reshape(height, width)
    first = (first - first.min()) / (first.max() - first.min() + 1e-8)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # the supported spelling and works on older releases too.
    colored = plt.get_cmap("inferno")(first)[:, :, :3].astype(np.float32)
    return to_uint8(colored)
def vis_kmeans(spatial, h, w, n_clusters=6):
    """Cluster (sp_h, sp_w, D) features with k-means and paint the clusters.

    Clustering runs on ``_device()``. Instead of upsampling hard labels, the
    per-cluster negative Euclidean distances are bilinearly upsampled to (h, w)
    and the argmax is taken, which gives smooth boundaries. The result is an
    (h, w, 3) uint8 image coloured from ``tab20``.
    """
    grid_h, grid_w = spatial.shape[:2]
    points = torch.from_numpy(spatial.reshape(-1, spatial.shape[-1])).to(_device())
    kmeans = TorchKMeans(n_clusters=n_clusters, max_iter=20)
    kmeans.fit(points)
    affinity = -torch.cdist(points, kmeans.centroids)  # (grid_h * grid_w, k)
    affinity_map = affinity.cpu().numpy().reshape(grid_h, grid_w, n_clusters)
    labels = upsample(affinity_map, h, w, mode="bilinear").argmax(axis=-1)
    colors = plt.cm.tab20(np.linspace(0, 1, n_clusters))[:, :3]
    return to_uint8(colors[labels].astype(np.float32))
# ── Zero-shot Segmentation ──────────────────────────────────────────────────
def vis_custom_semseg(spatial, orig_image, classes, class_embs):
    """Zero-shot semantic segmentation with user-defined classes.

    Args:
        spatial: (sp_h, sp_w, D) patch features.
        orig_image: RGB image array (h, w, 3); the blend assumes values in 0..255.
        classes: class names; index k matches row k of *class_embs*.
        class_embs: (len(classes), D) text embeddings, one row per class.

    Returns:
        (overlay, mask, detected, undetected):
        overlay: uint8 array of a 90 % colour / 10 % image blend with a legend
            of the (up to) ten largest classes on the right.
        mask: (h, w, 3) uint8 colour mask.
        detected: "name (pct%)" entries covering >= 2 % of pixels, largest first.
        undetected: the smaller entries followed by every class with 0 pixels.
    """
    h, w = orig_image.shape[:2]
    sp_h, sp_w = spatial.shape[:2]
    n = len(classes)

    # Patch/class similarity (cosine similarity when class_embs rows are
    # unit-norm). Upsample the scores, not the labels, for smooth edges.
    feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1]))
    sim_map = (feat @ class_embs.T).reshape(sp_h, sp_w, n)
    labels = upsample(sim_map, h, w, mode="bilinear").argmax(axis=-1)

    if n <= 20:
        # tab20 holds 20 distinct colours; even sampling keeps them distinct.
        colors = plt.cm.tab20(np.linspace(0, 1, max(n, 2)))[:n, :3]
    else:
        # Past 20 classes tab20 would repeat colours; spread hues instead.
        colors = plt.cm.hsv(np.linspace(0, 1, n, endpoint=False))[:, :3]
    palette = (colors * 255).astype(np.uint8)

    seg_rgb = palette[labels].astype(np.float32) / 255.0
    mask_img = to_uint8(seg_rgb)
    blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))

    unique_ids, counts = np.unique(labels, return_counts=True)
    # Stable sort: classes with equal area keep ascending class-id order.
    order = np.argsort(-counts, kind="stable")
    unique_ids, counts = unique_ids[order], counts[order]
    total = counts.sum()

    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            60,
        )
    except OSError:
        font = ImageFont.load_default()

    n_legend = min(len(unique_ids), 10)
    row_h = 80
    swatch_w = 60
    pad = 12
    legend_w = 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for row, cid in enumerate(unique_ids[:n_legend]):
        y_top = pad + row * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=tuple(palette[cid].tolist()),
            outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + 6),
            classes[cid],
            fill="black",
            font=font,
        )
    overlay_out = np.array(canvas)

    detected_parts, minor_parts = [], []
    for cid, count in zip(unique_ids, counts):
        pct = count / total * 100
        entry = f"{classes[cid]} ({pct:.1f}%)"
        if pct >= 2:
            detected_parts.append(entry)
        else:
            minor_parts.append(entry)
    # Build the membership set once, not once per class.
    present = set(unique_ids.tolist())
    absent = [f"{classes[i]} (0.0%)" for i in range(n) if i not in present]
    detected_str = ", ".join(detected_parts)
    undetected_str = ", ".join(minor_parts + absent)
    return overlay_out, mask_img, detected_str, undetected_str
# ── DPT Depth Inference ─────────────────────────────────────────────────────
def vis_depth_dpt(depth_map, h, w):
    """Colour a depth map with the ``turbo`` colormap, resized to (h, w).

    *depth_map* is squeezed to 2-D and min-max normalised (1e-8 added to the
    range to avoid division by zero on a constant map). The result is an
    (h, w, 3) uint8 numpy array, not a PIL Image.
    """
    d = depth_map.squeeze()
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # the supported spelling and works on older releases too.
    colored = plt.get_cmap("turbo")(d)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_normals_dpt(normals_map, h, w):
    """Render a surface-normal tensor as an (h, w, 3) uint8 RGB image.

    *normals_map* is channel-first (3, H, W) with components in [-1, 1]. The
    components are shifted to [0, 1], moved channel-last and resized.
    """
    shifted = (normals_map.cpu().numpy() + 1.0) / 2.0
    return to_uint8(upsample(shifted.transpose(1, 2, 0), h, w))
def vis_segmentation_dpt(seg_map, orig_image):
    """Render ADE20K logits as a colour overlay on *orig_image* with a legend.

    Args:
        seg_map: per-class logits tensor of shape (150, H, W).
        orig_image: RGB image array (h, w, 3); the blend assumes values in 0..255.

    Returns:
        uint8 array of shape (max(h, legend height), w + 450, 3). The left part
        is an 85 % colour / 15 % image blend. The right part is a legend of up
        to ten classes that each cover at least 2 % of the pixels, largest first.
    """
    h, w = orig_image.shape[:2]
    class_logits = seg_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 150)
    pred = upsample(class_logits, h, w, mode="bilinear").argmax(axis=-1)
    # Palette row 0 is reserved, so class id k is drawn with row k + 1.
    seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
    blend = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))

    class_ids, pixel_counts = np.unique(pred, return_counts=True)
    n_pixels = pixel_counts.sum()
    by_area = np.argsort(-pixel_counts)
    class_ids, pixel_counts = class_ids[by_area], pixel_counts[by_area]
    shares = pixel_counts / n_pixels * 100
    class_ids = class_ids[shares >= 2.0]

    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            36,
        )
    except OSError:
        font = ImageFont.load_default()

    row_h, swatch_w, pad, legend_w = 50, 40, 10, 450
    legend_rows = class_ids[:10]
    canvas_h = max(h, len(legend_rows) * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, canvas_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    x0 = w + pad
    for row, cid in enumerate(legend_rows):
        top = pad + row * row_h
        label = ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}"
        draw.rectangle(
            [x0, top, x0 + swatch_w, top + swatch_w],
            fill=tuple(ADE20K_PALETTE[cid + 1].tolist()),
            outline=(0, 0, 0),
        )
        draw.text((x0 + swatch_w + 8, top + 4), label, fill="black", font=font)
    return np.array(canvas)
# ── Gradio callbacks ────────────────────────────────────────────────────────
@spaces.GPU
def on_variant_change(variant_name):
    """Switch to *variant_name* and blank every output tied to the old model.

    Returns eight values in the order of the wired outputs: the three PCA-tab
    images, the cached ``pca_state``, and the four zero-shot outputs (two
    images, two strings).
    """
    load_variant(variant_name)
    _move_models_to_device()
    _ensure_ade20k_embs()
    cleared_pca = (None, None, None)  # pca_out, depth_out, kmeans_out
    cleared_state = (None,)  # pca_state
    cleared_custom = (None, None, "", "")  # overlay, mask, detected, undetected
    return cleared_pca + cleared_state + cleared_custom
@spaces.GPU
def on_pca_extract(image, resolution, _pca_state):
    """Extract features once and render the PCA, 1st-component and k-means views.

    *_pca_state* is ignored; it is accepted only to match the wired inputs. The
    returned state caches the features together with the image, variant and
    resolution they were computed for.
    """
    if image is None:
        return None, None, None, None
    _init_model()
    res = int(resolution)
    spatial = extract_features(image, res)
    height, width = image.shape[:2]
    views = (
        vis_pca(spatial),
        vis_depth(spatial),
        vis_kmeans(spatial, height, width),
    )
    new_state = dict(
        spatial=spatial,
        orig_image=image,
        variant=_model["name"],
        resolution=res,
    )
    return (*views, new_state)
@spaces.GPU
def on_recluster(image, resolution, n_clusters, pca_state):
    """Re-run k-means with *n_clusters*, reusing cached features when valid.

    Cached features in *pca_state* are reused only when they were computed for
    the same variant, the same resolution AND the same image. Otherwise the
    features are re-extracted and the state is refreshed. Without the image
    check, uploading a new picture and pressing "Re-cluster" would cluster the
    previous picture's features.

    Returns:
        (kmeans image or None, pca_state).
    """
    if image is None:
        gr.Warning("Upload an image first.")
        return None, pca_state
    _init_model()
    resolution = int(resolution)
    cache_valid = (
        pca_state is not None
        and pca_state.get("variant") == _model["name"]
        and pca_state.get("resolution") == resolution
        # array_equal returns False for None or a different shape.
        and np.array_equal(pca_state.get("orig_image"), image)
    )
    if cache_valid:
        spatial = pca_state["spatial"]
    else:
        spatial = extract_features(image, resolution)
        pca_state = {
            "spatial": spatial,
            "orig_image": image,
            "variant": _model["name"],
            "resolution": resolution,
        }
    h, w = image.shape[:2]
    return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
@spaces.GPU
def on_zeroseg_custom(image, resolution, class_names_str):
    """Zero-shot segment *image* into the comma-separated classes given.

    Each class name is embedded with every TCL_PROMPTS template. The averaged,
    L2-normalised text embeddings are matched against value-attention patch
    features extracted at *resolution*. Blank entries are dropped, and repeated
    names keep only their first occurrence, since a duplicate can only compete
    with itself.

    Returns:
        (overlay, mask, detected_text, undetected_text), or
        ``(None, None, "", "")`` with a warning when there is no image or no
        usable class name (including inputs like ", ,").
    """
    raw_names = class_names_str.split(",") if class_names_str else []
    # dict.fromkeys de-duplicates while preserving first-seen order.
    classes = list(dict.fromkeys(c.strip() for c in raw_names if c.strip()))
    if image is None or not classes:
        gr.Warning("Upload an image and enter at least one class name.")
        return None, None, "", ""
    _init_model()
    resolution = int(resolution)

    dev = _device()
    tokenizer, text_encoder = _model["tokenizer"], _model["text"]
    per_template = []
    with torch.no_grad():
        for template in TCL_PROMPTS:
            prompts = [template.format(c) for c in classes]
            ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
            embs = text_encoder(
                torch.from_numpy(ids).to(dev),
                torch.from_numpy(paddings).to(dev),
            )
            per_template.append(embs.cpu().numpy())
    class_embs = l2_normalize(np.mean(per_template, axis=0))

    spatial = extract_features_value_attention(image, resolution)
    return vis_custom_semseg(spatial, image, classes, class_embs)
@spaces.GPU
def on_depth_normals_predict(image, dpt_variant, resolution):
    """Run DPT depth and surface-normal prediction on *image*.

    *dpt_variant* is the model-variant dropdown value. It is honoured (loaded
    if it differs from the current model) instead of silently ignored, so the
    prediction always uses the variant the user selected. Unknown labels fall
    back to whatever ``_init_model`` loads. Inference runs under
    ``torch.no_grad``, like the other feature paths in this app.

    Returns:
        (depth image, normals image), each (h, w, 3) uint8; ``(None, None)``
        when there is no image.
    """
    if image is None:
        return None, None
    if dpt_variant in VARIANTS:
        load_variant(dpt_variant)
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    with torch.no_grad():
        depth_map = _model["dpt"].predict_depth(tensor)
        normals_map = _model["dpt"].predict_normals(tensor)
    return (
        vis_depth_dpt(depth_map[0, 0].cpu().numpy(), h, w),
        vis_normals_dpt(normals_map[0], h, w),
    )
@spaces.GPU
def on_segmentation_predict(image, dpt_variant, resolution):
    """Run DPT ADE20K segmentation on *image* and render it with a legend.

    *dpt_variant* is the model-variant dropdown value. It is honoured (loaded
    if it differs from the current model) instead of silently ignored. Unknown
    labels fall back to whatever ``_init_model`` loads. Inference runs under
    ``torch.no_grad``.

    Returns:
        uint8 overlay-with-legend array, or ``None`` when there is no image.
    """
    if image is None:
        return None
    if dpt_variant in VARIANTS:
        load_variant(dpt_variant)
    _init_model()
    dev = _device()
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    with torch.no_grad():
        seg_map = _model["dpt"].predict_segmentation(tensor)
    return vis_segmentation_dpt(seg_map[0], image)
# ── UI ──────────────────────────────────────────────────────────────────────
# Page CSS: the elements with ids pca_output_image / depth_output_image are
# rendered with nearest-neighbour ("pixelated") scaling instead of smoothing,
# and scaled to fit ("contain").
custom_css = """
#pca_output_image img, #depth_output_image img {
image-rendering: pixelated;
object-fit: contain;
}
"""
# Extra <head> markup for the page: loads Google Analytics (gtag.js, property
# G-P13E18K71N) and reports the page title/location of the TIPSv2 Space.
head = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-P13E18K71N"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-P13E18K71N', {
'page_title': 'TIPSv2',
'page_location': 'https://huggingface.co/spaces/google/TIPSv2'
});
</script>
"""
# Build the Gradio app: a shared row of controls (model variant, resolution)
# above four tabs, each wired to one of the @spaces.GPU callbacks.
with gr.Blocks(head=head, title="TIPSv2 Feature Explorer", css=custom_css) as demo:
    gr.Markdown(
        "## TIPSv2 Feature Explorer\n"
        "Explore TIPSv2 representations here! For more information, see: "
        "https://gdm-tipsv2.github.io/",
    )
    # Controls shared by every tab.
    with gr.Row():
        variant_dd = gr.Dropdown(
            choices=list(VARIANTS.keys()),
            value=DEFAULT_VARIANT,
            label="Model variant",
        )
        resolution_dd = gr.Dropdown(
            choices=RESOLUTIONS,
            value=DEFAULT_IMAGE_SIZE,
            label="Resolution (higher = better quality, slower)",
        )
    # ── PCA / Feature Visualization Tab ─────────────────────────────────
    with gr.Tab("🎨 PCA & Feature Visualization"):
        # Holds the last extracted features (written by the Extract and
        # Re-cluster handlers, read by Re-cluster, cleared on variant change).
        pca_state = gr.State(None)
        with gr.Row():
            with gr.Column():
                pca_input = gr.Image(type="numpy", label="Input image")
                pca_btn = gr.Button("Extract Features", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("PCA"):
                        pca_out = gr.Image(
                            label="PCA (3 components β†’ RGB)",
                            height=448,
                            elem_id="pca_output_image",
                        )
                    with gr.Tab("PCA (1st component)"):
                        depth_out = gr.Image(
                            label="1st PCA component",
                            height=448,
                            elem_id="depth_output_image",
                        )
                    with gr.Tab("K-means Clustering"):
                        n_clusters = gr.Slider(
                            2,
                            20,
                            value=6,
                            step=1,
                            label="Clusters",
                        )
                        recluster_btn = gr.Button("Re-cluster")
                        kmeans_out = gr.Image(label="K-means clusters")
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/pca/hike.jpeg"],
                ["examples/pca/cph.jpeg"],
                ["examples/pca/angus.jpeg"],
                ["examples/pca/dadaocheng.jpeg"],
            ],
            inputs=[pca_input],
        )
    # ── Zero-shot Segmentation Tab ──────────────────────────────────────
    with gr.Tab("✏️ Zero-shot Segmentation"):
        gr.Markdown(
            "Define your own classes for zero-shot segmentation. "
            "Enter class names separated by commas.",
        )
        with gr.Row():
            with gr.Column():
                custom_input = gr.Image(type="numpy", label="Input image", height=448)
                custom_classes = gr.Textbox(
                    label="Class names (comma-separated)",
                    value="class1, class2, class3",
                    placeholder="e.g. cat, dog, sky, grass",
                )
                custom_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        custom_overlay = gr.Image(
                            label="Segmentation overlay",
                            height=448,
                        )
                    with gr.Tab("Mask"):
                        custom_mask = gr.Image(
                            label="Segmentation mask",
                            height=448,
                        )
                custom_detected = gr.Textbox(
                    label="Detected classes (sorted by area)",
                    lines=2,
                )
                custom_undetected = gr.Textbox(label="Not detected", lines=2)
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        # Each example pre-fills both the image and the class-name textbox.
        gr.Examples(
            examples=[
                ["examples/zeroseg/voc_2008_000891.jpg", "dog, cage, cloth, dog bowl"],
                [
                    "examples/zeroseg/pascal_context_00000_image.png",
                    "bike, tree, fence, soccer, floor, chair, cushion",
                ],
                [
                    "examples/zeroseg/pascal_context_00007_image.png",
                    "dog, table, chair, carpet, shoes",
                ],
                [
                    "examples/zeroseg/pascal_context_00049_image.png",
                    "bus, snow, mountain, house, road",
                ],
            ],
            inputs=[custom_input, custom_classes],
        )
    # ── Depth/Normals Visualization Tab ─────────────────────────────────
    with gr.Tab("πŸ”οΈ Depth/Normals Visualization"):
        gr.Markdown(
            "Monocular depth and surface normals estimation using a **DPT "
            "(Dense Prediction Transformer)** head on top of a **frozen** "
            "TIPS v2 vision encoder. Trained on the **NYU Depth V2** dataset.",
        )
        with gr.Row():
            with gr.Column():
                depth_input = gr.Image(type="numpy", label="Input image", height=448)
                depth_btn = gr.Button("Predict Depth & Normals", variant="primary")
            with gr.Column():
                dpt_depth_out = gr.Image(label="DPT Depth Map", height=448)
            with gr.Column():
                dpt_normals_out = gr.Image(
                    label="DPT Surface Normals",
                    height=448,
                )
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/nyuv2/bedroom_00280.jpg"],
                ["examples/nyuv2/kitchen_00249.jpg"],
                ["examples/nyuv2/living_room_01260.jpg"],
                ["examples/nyuv2/office_kitchen_00413.jpg"],
                ["examples/nyuv2/study_room_00272.jpg"],
            ],
            inputs=[depth_input],
        )
    # ── Supervised Segmentation Tab ──────────────────────────────────────
    with gr.Tab("🎭 Supervised Segmentation"):
        gr.Markdown(
            "Semantic segmentation using a **DPT (Dense Prediction "
            "Transformer)** head on top of a **frozen** TIPS v2 vision "
            "encoder. Trained on ADE20K (150 classes).",
        )
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(type="numpy", label="Input image", height=448)
                seg_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                seg_out = gr.Image(label="DPT Segmentation (ADE20K)", height=448)
        gr.Markdown("πŸ‘‡ **Click the examples below to explore!**")
        # NOTE(review): ADE20K examples live under examples/depth/; the folder
        # name looks historical — confirm before reorganising assets.
        gr.Examples(
            examples=[
                ["examples/depth/ade20k_00003.png"],
                ["examples/depth/ade20k_00007.png"],
                ["examples/depth/ade20k_00014.png"],
                ["examples/depth/ade20k_00022.png"],
            ],
            inputs=[seg_input],
        )
    # ── Wiring ──────────────────────────────────────────────────────────
    # A variant change reloads the model and clears the PCA-tab and zero-shot
    # outputs plus the cached pca_state. The DPT tab outputs are not in this
    # list and are left as they are.
    variant_dd.change(
        fn=on_variant_change,
        inputs=[variant_dd],
        outputs=[
            pca_out,
            depth_out,
            kmeans_out,
            pca_state,
            custom_overlay,
            custom_mask,
            custom_detected,
            custom_undetected,
        ],
    )
    pca_btn.click(
        fn=on_pca_extract,
        inputs=[pca_input, resolution_dd, pca_state],
        outputs=[pca_out, depth_out, kmeans_out, pca_state],
    )
    # Re-cluster receives and returns pca_state so cached features can be reused.
    recluster_btn.click(
        fn=on_recluster,
        inputs=[pca_input, resolution_dd, n_clusters, pca_state],
        outputs=[kmeans_out, pca_state],
    )
    # The DPT handlers also receive the selected variant name.
    depth_btn.click(
        fn=on_depth_normals_predict,
        inputs=[depth_input, variant_dd, resolution_dd],
        outputs=[dpt_depth_out, dpt_normals_out],
    )
    seg_btn.click(
        fn=on_segmentation_predict,
        inputs=[seg_input, variant_dd, resolution_dd],
        outputs=[seg_out],
    )
    custom_btn.click(
        fn=on_zeroseg_custom,
        inputs=[custom_input, resolution_dd, custom_classes],
        outputs=[custom_overlay, custom_mask, custom_detected, custom_undetected],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()