"""TIPS Feature Explorer (GPU) — Hugging Face Space demo with ZeroGPU."""

import colorsys
import os

import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from fast_pytorch_kmeans import KMeans as TorchKMeans
from sklearn.decomposition import PCA
from torchvision import transforms
from transformers import AutoModel

# ── Constants ───────────────────────────────────────────────────────────────

DEFAULT_IMAGE_SIZE = 896
PATCH_SIZE = 14
RESOLUTIONS = [224, 336, 448, 672, 896, 1120, 1372, 1792]
ZEROSEG_IMAGE_SIZE = 1372
MAX_LEN = 64

VARIANTS = {
    "TIPS v2 — B/14": "google/tipsv2-b14-dpt",
    "TIPS v2 — L/14": "google/tipsv2-l14-dpt",
    "TIPS v2 — SO400m/14": "google/tipsv2-so400m14-dpt",
    "TIPS v2 — g/14": "google/tipsv2-g14-dpt",
}
DEFAULT_VARIANT = "TIPS v2 — L/14"

# Bold TrueType font for segmentation legends; PIL's default font is the fallback.
_LEGEND_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def _device():
    """Return the CUDA device when available (e.g. inside ``@spaces.GPU``), else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ── Pascal Context (59 classes) ─────────────────────────────────────────────

# Prompt templates; each class name is formatted into every template and the
# resulting text embeddings are averaged (prompt ensembling).
TCL_PROMPTS = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
    "a photo of many {}.",
    "a photo of {}s.",
]

PASCAL_CONTEXT_CLASSES = (
    "aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat",
    "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling",
    "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door",
    "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard",
    "light", "motorbike", "mountain", "mouse", "person", "plate", "platform",
    "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk", "sign",
    "sky", "snow", "sofa", "table", "track", "train", "tree", "truck",
    "tvmonitor", "wall", "water", "window", "wood",
)

ADE20K_CLASSES = (
    "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed",
    "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door",
    "table", "mountain", "plant", "curtain", "chair", "car", "water",
    "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field",
    "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub",
    "railing", "cushion", "base", "box", "column", "signboard",
    "chest_of_drawers", "counter", "sand", "sink", "skyscraper", "fireplace",
    "refrigerator", "grandstand", "path", "stairs", "runway", "case",
    "pool_table", "pillow", "screen_door", "stairway", "river", "bridge",
    "bookcase", "blind", "coffee_table", "toilet", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm", "kitchen_island", "computer",
    "swivel_chair", "boat", "bar", "arcade_machine", "hovel", "bus", "towel",
    "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth",
    "television", "airplane", "dirt_track", "apparel", "pole", "land",
    "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage",
    "van", "ship", "fountain", "conveyer_belt", "canopy", "washer",
    "plaything", "swimming_pool", "stool", "barrel", "basket", "waterfall",
    "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step",
    "tank", "trade_name", "microwave", "pot", "animal", "bicycle", "lake",
    "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase",
    "traffic_light", "tray", "ashcan", "fan", "pier", "crt_screen", "plate",
    "monitor", "bulletin_board", "shower", "radiator", "glass", "clock", "flag",
)

NUM_ADE20K_CLASSES = 150


def _make_ade20k_palette(num_classes):
    """Build a deterministic ``(num_classes + 1, 3)`` uint8 RGB palette.

    Row 0 is left black. Row ``i`` (for class index ``i - 1``) gets a hue stepped
    by the golden-ratio conjugate, with saturation and value cycled by ``i``.
    The loop is wrapped in a function so its temporaries do not leak into
    module scope.
    """
    palette = np.zeros((num_classes + 1, 3), dtype=np.uint8)
    for idx in range(1, num_classes + 1):
        hue = (idx * 0.618033988749895) % 1.0
        saturation = 0.65 + 0.35 * ((idx * 7) % 5) / 4.0
        value = 0.70 + 0.30 * ((idx * 11) % 3) / 2.0
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
        palette[idx] = [int(r * 255), int(g * 255), int(b * 255)]
    return palette


ADE20K_PALETTE = _make_ade20k_palette(NUM_ADE20K_CLASSES)

# ── Model state (one model loaded at a time) ────────────────────────────────

_model = {
    "name": None,
    "vision": None,
    "text": None,
    "tokenizer": None,
    "temperature": None,
    "ade20k_embs": None,
    "dpt": None,
}


def load_variant(name):
    """Load the DPT checkpoint for *name* (which bundles its TIPS backbone).

    This does nothing when *name* is already loaded. Otherwise it replaces every
    entry of ``_model`` and invalidates the cached text embeddings. Device
    placement is done separately by ``_move_models_to_device``.

    Raises:
        KeyError: if *name* is not a key of ``VARIANTS``.
    """
    if _model["name"] == name:
        return
    token = os.environ.get("HF_TIPSv2") or os.environ.get("HF_TOKEN")
    dpt = AutoModel.from_pretrained(VARIANTS[name], trust_remote_code=True, token=token)
    dpt.eval()
    dpt._get_backbone()  # trigger backbone download
    backbone = dpt._backbone
    _model.update(
        name=name,
        dpt=dpt,
        vision=backbone.vision_encoder,
        text=backbone.text_encoder,
        tokenizer=backbone._load_tokenizer(),
        temperature=backbone.config.temperature,
        ade20k_embs=None,
    )
    print(f"Loaded {name}")


def _move_models_to_device():
    """Move the loaded modules to ``_device()`` (GPU inside ``@spaces.GPU``, else CPU)."""
    dev = _device()
    for key in ("vision", "text", "dpt"):
        module = _model[key]
        if module is not None:
            module.to(dev)


def _encode_text_prompts(classes):
    """Return L2-normalised, prompt-ensembled text embeddings, one row per class.

    Each class name is formatted into every template in ``TCL_PROMPTS`` and
    encoded with the loaded text encoder. The per-template embeddings are
    averaged, and the average is L2-normalised.
    """
    dev = _device()
    model_t = _model["text"]
    tokenizer = _model["tokenizer"]
    all_embs = []
    for template in TCL_PROMPTS:
        prompts = [template.format(c) for c in classes]
        ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = model_t(
                torch.from_numpy(ids).to(dev),
                torch.from_numpy(paddings).to(dev),
            )
        all_embs.append(embs.cpu().numpy())
    return l2_normalize(np.mean(all_embs, axis=0))


def _ensure_ade20k_embs():
    """Compute and cache text embeddings once per loaded variant.

    NOTE: despite the ``ade20k_embs`` key, the cache holds embeddings for the
    59 ``PASCAL_CONTEXT_CLASSES``.
    NOTE(review): no code in this file reads the cache back. Confirm it is
    still needed, since it costs a text-encoder pass on every model load.
    """
    if _model["ade20k_embs"] is not None:
        return
    _model["ade20k_embs"] = _encode_text_prompts(PASCAL_CONTEXT_CLASSES)
    print("Pascal Context text embeddings computed.")


def _init_model():
    """Load the current variant (default if none), move it to the device, and compute text embeddings."""
    load_variant(_model["name"] or DEFAULT_VARIANT)
    _move_models_to_device()
    _ensure_ade20k_embs()


# ── Preprocessing & helpers ─────────────────────────────────────────────────


def preprocess(img, size=DEFAULT_IMAGE_SIZE):
    """Resize a PIL image to ``size x size`` and convert it to a CHW float tensor in [0, 1]."""
    return transforms.Compose(
        [
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ]
    )(img)


def l2_normalize(x, axis=-1):
    """L2-normalise *x* along *axis*; norms below 1e-3 are clipped to avoid division by ~0."""
    return x / np.linalg.norm(x, ord=2, axis=axis, keepdims=True).clip(min=1e-3)


def upsample(arr, h, w, mode="bilinear"):
    """Upsample (H, W, C) or (H, W) numpy array to (h, w, ...).

    A 2-D input gains a trailing channel axis, so the result is always (h, w, C).
    """
    t = torch.from_numpy(arr).float()
    if t.ndim == 2:
        t = t.unsqueeze(-1)
    t = t.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W) for F.interpolate
    kwargs = dict(align_corners=False) if mode == "bilinear" else {}
    up = F.interpolate(t, size=(h, w), mode=mode, **kwargs)
    return up[0].permute(1, 2, 0).numpy()


def to_uint8(x):
    """Map floats in [0, 1] to uint8 in [0, 255], clipping out-of-range values."""
    return (x * 255).clip(0, 255).astype(np.uint8)


def _load_legend_font(size):
    """Return DejaVu Sans Bold at *size*, or PIL's default font if it cannot be loaded."""
    try:
        return ImageFont.truetype(_LEGEND_FONT_PATH, size)
    except OSError:
        return ImageFont.load_default()


def _render_legend_canvas(
    blend_img, entries, *, font_size, row_h, swatch_w, pad, text_dy,
    legend_w=450, max_rows=10,
):
    """Paste *blend_img* on a white canvas with a colour legend to its right.

    Args:
        blend_img: PIL image placed at the top-left corner.
        entries: sequence of ``(rgb_tuple, label)``. Only the first *max_rows*
            are drawn.
        font_size, row_h, swatch_w, pad, text_dy: legend layout in pixels.
        legend_w: width of the legend strip in pixels.

    Returns:
        The composed canvas as a uint8 numpy array (H, W, 3).
    """
    w, h = blend_img.size
    font = _load_legend_font(font_size)
    shown = list(entries)[:max_rows]
    legend_h = max(h, len(shown) * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for row, (color, label) in enumerate(shown):
        y_top = pad + row * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=color,
            outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + text_dy),
            label,
            fill="black",
            font=font,
        )
    return np.array(canvas)


# ── Feature extraction (GPU-accelerated) ────────────────────────────────────


@torch.no_grad()
def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) as numpy. sp = resolution // 14."""
    dev = _device()
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    _, _, patch_tokens = _model["vision"](tensor)
    sp = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(sp, sp, -1).numpy()


@torch.no_grad()
def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) using Value Attention on GPU.

    This follows the Colab reference implementation:
    - run every block except the last one normally;
    - for the last block, take V from the fused QKV projection (skipping the
      attention weights);
    - then apply, by hand, ``attn.proj``, layer scale ``ls1``, the residual,
      ``norm2``, the MLP with layer scale ``ls2``, the second residual, and the
      final norm.

    The CLS token and ``num_register_tokens`` register tokens are dropped, so
    only patch tokens remain.
    """
    dev = _device()
    model_image = _model["vision"]
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)

    x = model_image.prepare_tokens_with_masks(tensor)
    for blk in model_image.blocks[:-1]:
        x = blk(x)

    blk = model_image.blocks[-1]
    num_reg = getattr(model_image, "num_register_tokens", 1)
    b_dim, n_dim, c_dim = x.shape
    num_heads = blk.attn.num_heads

    qkv = blk.attn.qkv(blk.norm1(x))
    qkv = qkv.reshape(b_dim, n_dim, 3, num_heads, c_dim // num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D_head)
    v = qkv[2]  # (B, H, N, D_head)

    v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)
    v_out = blk.attn.proj(v_out)
    v_out = blk.ls1(v_out)
    x_val = v_out + x

    y_val = blk.norm2(x_val)
    y_val = blk.ls2(blk.mlp(y_val))
    x_val = x_val + y_val
    x_val = model_image.norm(x_val)

    patch_tokens = x_val[:, 1 + num_reg :, :]
    sp = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(sp, sp, -1).numpy()


# ── PCA Visualisations ──────────────────────────────────────────────────────


def vis_pca(spatial):
    """PCA of spatial features → RGB image (whitened components squashed by a sigmoid)."""
    feat = spatial.reshape(-1, spatial.shape[-1])
    pca = PCA(n_components=3, whiten=True)
    h, w = spatial.shape[0], spatial.shape[1]
    rgb = pca.fit_transform(feat).reshape(h, w, 3)
    rgb = 1 / (1 + np.exp(-2.0 * rgb))
    return to_uint8(rgb)


def vis_depth(spatial):
    """1st PCA component, min-max normalised and coloured with the inferno colormap."""
    feat = spatial.reshape(-1, spatial.shape[-1])
    h, w = spatial.shape[0], spatial.shape[1]
    depth = PCA(n_components=1).fit_transform(feat).reshape(h, w)
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    # plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colored = plt.get_cmap("inferno")(depth)[:, :, :3].astype(np.float32)
    return to_uint8(colored)


def vis_kmeans(spatial, h, w, n_clusters=6):
    """K-means clustering of spatial features, rendered at (h, w).

    Negative centroid distances are bilinearly upsampled before the argmax,
    which gives smoother cluster boundaries than upsampling hard labels.
    """
    sp_h, sp_w = spatial.shape[:2]
    feat = torch.from_numpy(spatial.reshape(-1, spatial.shape[-1])).to(_device())
    km = TorchKMeans(n_clusters=n_clusters, max_iter=20)
    km.fit(feat)
    scores = -torch.cdist(feat, km.centroids)  # (H*W, k)
    scores = scores.cpu().numpy().reshape(sp_h, sp_w, n_clusters)
    labels = upsample(scores, h, w, mode="bilinear").argmax(axis=-1)
    palette = plt.cm.tab20(np.linspace(0, 1, n_clusters))[:, :3]
    return to_uint8(palette[labels].astype(np.float32))


# ── Zero-shot Segmentation ──────────────────────────────────────────────────


def vis_custom_semseg(spatial, orig_image, classes, class_embs):
    """Zero-shot semantic segmentation with user-defined classes.

    Args:
        spatial: (sp_h, sp_w, D) patch features.
        orig_image: (h, w, 3) uint8 image; outputs are rendered at its size.
        classes: class names, one per row of *class_embs*.
        class_embs: (len(classes), D) text embeddings (L2-normalised by callers here).

    Returns:
        ``(overlay_with_legend, mask, detected_str, undetected_str)``.
        Classes covering at least 2% of pixels are listed as detected. Smaller
        and absent classes (shown as 0.0%) are listed as not detected.
    """
    h, w = orig_image.shape[:2]
    sp_h, sp_w = spatial.shape[:2]
    n = len(classes)

    # Cosine similarity between every patch and every class embedding.
    feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1]))
    sim_map = (feat @ class_embs.T).reshape(sp_h, sp_w, n)
    labels = upsample(sim_map, h, w, mode="bilinear").argmax(axis=-1)

    # tab20 has only 20 distinct colours; beyond that, colours repeat.
    palette = (plt.cm.tab20(np.linspace(0, 1, max(n, 2)))[:n, :3] * 255).astype(
        np.uint8
    )
    seg_rgb = palette[labels].astype(np.float32) / 255.0
    mask_img = to_uint8(seg_rgb)
    blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))

    # Classes sorted by pixel area, largest first.
    unique_ids, counts = np.unique(labels, return_counts=True)
    order = np.argsort(-counts)
    unique_ids, counts = unique_ids[order], counts[order]
    total = counts.sum()

    overlay_out = _render_legend_canvas(
        blend_img,
        [(tuple(palette[cid].tolist()), classes[cid]) for cid in unique_ids],
        font_size=60,
        row_h=80,
        swatch_w=60,
        pad=12,
        text_dy=6,
    )

    detected_parts, minor_parts = [], []
    for cid, count in zip(unique_ids, counts):
        pct = count / total * 100
        entry = f"{classes[cid]} ({pct:.1f}%)"
        if pct >= 2:
            detected_parts.append(entry)
        else:
            minor_parts.append(entry)
    present = set(unique_ids.tolist())  # built once, not per class
    absent = [f"{classes[i]} (0.0%)" for i in range(n) if i not in present]

    detected_str = ", ".join(detected_parts)
    undetected_str = ", ".join(minor_parts + absent)
    return overlay_out, mask_img, detected_str, undetected_str


# ── DPT Depth Inference ─────────────────────────────────────────────────────


def vis_depth_dpt(depth_map, h, w):
    """Colour a depth map with the turbo colormap and resize it to (h, w); returns uint8 numpy."""
    d = depth_map.squeeze()
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    # plt.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colored = plt.get_cmap("turbo")(d)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))


def vis_normals_dpt(normals_map, h, w):
    """Map normals from [-1, 1] to [0, 1] and resize to (h, w).

    *normals_map* is a (3, H, W) tensor (channel-first, given the transpose below).
    """
    n = normals_map.cpu().numpy()
    n = (n + 1.0) / 2.0
    n = np.transpose(n, (1, 2, 0))  # (H, W, 3)
    return to_uint8(upsample(n, h, w))


def vis_segmentation_dpt(seg_map, orig_image):
    """Colour a (C, H, W) segmentation-logit tensor with the ADE20K palette and add a legend.

    Only classes covering at least 2% of pixels appear in the legend (at most 10).
    """
    h, w = orig_image.shape[:2]
    logits = seg_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 150)
    pred = upsample(logits, h, w, mode="bilinear").argmax(axis=-1)  # (h, w)
    # Palette row 0 is unused, hence the +1 offset.
    seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
    blend = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))

    unique_ids, counts = np.unique(pred, return_counts=True)
    total_pixels = counts.sum()
    order = np.argsort(-counts)
    unique_ids, counts = unique_ids[order], counts[order]
    pcts = counts / total_pixels * 100
    unique_ids = unique_ids[pcts >= 2.0]

    entries = [
        (
            tuple(ADE20K_PALETTE[cid + 1].tolist()),
            ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}",
        )
        for cid in unique_ids
    ]
    return _render_legend_canvas(
        blend_img,
        entries,
        font_size=36,
        row_h=50,
        swatch_w=40,
        pad=10,
        text_dy=4,
    )


# ── Gradio callbacks ────────────────────────────────────────────────────────


@spaces.GPU
def on_variant_change(variant_name):
    """Load *variant_name* onto the device and clear the outputs that depend on the model."""
    load_variant(variant_name)
    _move_models_to_device()
    _ensure_ade20k_embs()
    return (
        None, None, None,  # pca_out, depth_out, kmeans_out
        None,  # pca_state
        None, None, "", "",  # custom outputs
    )


@spaces.GPU
def on_pca_extract(image, resolution, _pca_state):
    """Extract features and return (PCA RGB, 1st-component map, K-means, new state).

    The returned state caches the features for ``on_recluster``.
    """
    if image is None:
        return None, None, None, None
    _init_model()
    resolution = int(resolution)
    spatial = extract_features(image, resolution)
    h, w = image.shape[:2]
    pca = vis_pca(spatial)
    depth = vis_depth(spatial)
    kmeans = vis_kmeans(spatial, h, w)
    state = {
        "spatial": spatial,
        "orig_image": image,
        "variant": _model["name"],
        "resolution": resolution,
    }
    return pca, depth, kmeans, state


@spaces.GPU
def on_recluster(image, resolution, n_clusters, pca_state):
    """Re-run K-means with *n_clusters*, reusing cached features when they still apply.

    Cached features are reused only when they were computed for the same
    variant, resolution AND image. Otherwise features are re-extracted and the
    state is refreshed.
    """
    if image is None:
        gr.Warning("Upload an image first.")
        return None, pca_state
    _init_model()
    resolution = int(resolution)
    cached_image = None if pca_state is None else pca_state.get("orig_image")
    cache_valid = (
        pca_state is not None
        and pca_state.get("variant") == _model["name"]
        and pca_state.get("resolution") == resolution
        and cached_image is not None
        and np.array_equal(cached_image, image)  # a new upload invalidates the cache
    )
    if cache_valid:
        spatial = pca_state["spatial"]
    else:
        spatial = extract_features(image, resolution)
        pca_state = {
            "spatial": spatial,
            "orig_image": image,
            "variant": _model["name"],
            "resolution": resolution,
        }
    h, w = image.shape[:2]
    return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state


@spaces.GPU
def on_zeroseg_custom(image, resolution, class_names_str):
    """Zero-shot segmentation of *image* into comma-separated *class_names_str*.

    Returns ``(overlay, mask, detected_str, undetected_str)``. When the image or
    the classes are missing, returns ``(None, None, "", "")``.
    """
    if image is None or not class_names_str or not class_names_str.strip():
        gr.Warning("Upload an image and enter at least one class name.")
        return None, None, "", ""
    _init_model()
    resolution = int(resolution)

    classes = [c.strip() for c in class_names_str.split(",") if c.strip()]
    if not classes:
        return None, None, "", ""

    class_embs = _encode_text_prompts(classes)
    spatial = extract_features_value_attention(image, resolution)
    overlay, mask, detected, undetected = vis_custom_semseg(
        spatial, image, classes, class_embs,
    )
    return overlay, mask, detected, undetected


@spaces.GPU
def on_depth_normals_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT depth and normals prediction.

    *dpt_variant* is unused. The currently loaded variant (the default if none
    is loaded) is the one that runs.
    """
    if image is None:
        return None, None
    _init_model()
    dev = _device()
    h, w = image.shape[:2]
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    # Inference only: no autograd graph, and outputs can be converted with .numpy().
    with torch.no_grad():
        depth_map = _model["dpt"].predict_depth(tensor)
        normals_map = _model["dpt"].predict_normals(tensor)
    return (
        vis_depth_dpt(depth_map[0, 0].cpu().numpy(), h, w),
        vis_normals_dpt(normals_map[0], h, w),
    )


@spaces.GPU
def on_segmentation_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT segmentation prediction.

    *dpt_variant* is unused; the currently loaded variant is the one that runs.
    """
    if image is None:
        return None
    _init_model()
    dev = _device()
    img = Image.fromarray(image).convert("RGB")
    tensor = preprocess(img, int(resolution)).unsqueeze(0).to(dev)
    # Inference only: no autograd graph, and outputs can be converted with .numpy().
    with torch.no_grad():
        seg_map = _model["dpt"].predict_segmentation(tensor)
    return vis_segmentation_dpt(seg_map[0], image)
# ── UI ──────────────────────────────────────────────────────────────────────

custom_css = """
#pca_output_image img, #depth_output_image img {
    image-rendering: pixelated;
    object-fit: contain;
}
"""

head = """ """


def _clear_dpt_outputs():
    """Return empty values for the DPT depth, normals and segmentation outputs.

    Chained after a model-variant switch so results from the previous variant
    are not left on screen. This matches how ``on_variant_change`` clears the
    PCA and zero-shot outputs.
    """
    return None, None, None


with gr.Blocks(head=head, title="TIPSv2 Feature Explorer", css=custom_css) as demo:
    gr.Markdown(
        "## TIPSv2 Feature Explorer\n"
        "Explore TIPSv2 representations here! For more information, see: "
        "https://gdm-tipsv2.github.io/",
    )

    with gr.Row():
        variant_dd = gr.Dropdown(
            choices=list(VARIANTS.keys()),
            value=DEFAULT_VARIANT,
            label="Model variant",
        )
        resolution_dd = gr.Dropdown(
            choices=RESOLUTIONS,
            value=DEFAULT_IMAGE_SIZE,
            label="Resolution (higher = better quality, slower)",
        )

    # ── PCA / Feature Visualization Tab ─────────────────────────────────
    with gr.Tab("🎨 PCA & Feature Visualization"):
        # Caches extracted features so "Re-cluster" can skip re-extraction.
        pca_state = gr.State(None)
        with gr.Row():
            with gr.Column():
                pca_input = gr.Image(type="numpy", label="Input image")
                pca_btn = gr.Button("Extract Features", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("PCA"):
                        pca_out = gr.Image(
                            label="PCA (3 components → RGB)",
                            height=448,
                            elem_id="pca_output_image",
                        )
                    with gr.Tab("PCA (1st component)"):
                        depth_out = gr.Image(
                            label="1st PCA component",
                            height=448,
                            elem_id="depth_output_image",
                        )
                    with gr.Tab("K-means Clustering"):
                        n_clusters = gr.Slider(
                            2, 20, value=6, step=1, label="Clusters",
                        )
                        recluster_btn = gr.Button("Re-cluster")
                        kmeans_out = gr.Image(label="K-means clusters")
        gr.Markdown("👇 **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/pca/hike.jpeg"],
                ["examples/pca/cph.jpeg"],
                ["examples/pca/angus.jpeg"],
                ["examples/pca/dadaocheng.jpeg"],
            ],
            inputs=[pca_input],
        )

    # ── Zero-shot Segmentation Tab ──────────────────────────────────────
    with gr.Tab("✏️ Zero-shot Segmentation"):
        gr.Markdown(
            "Define your own classes for zero-shot segmentation. "
            "Enter class names separated by commas.",
        )
        with gr.Row():
            with gr.Column():
                custom_input = gr.Image(type="numpy", label="Input image", height=448)
                custom_classes = gr.Textbox(
                    label="Class names (comma-separated)",
                    value="class1, class2, class3",
                    placeholder="e.g. cat, dog, sky, grass",
                )
                custom_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        custom_overlay = gr.Image(
                            label="Segmentation overlay", height=448,
                        )
                    with gr.Tab("Mask"):
                        custom_mask = gr.Image(
                            label="Segmentation mask", height=448,
                        )
                custom_detected = gr.Textbox(
                    label="Detected classes (sorted by area)", lines=2,
                )
                custom_undetected = gr.Textbox(label="Not detected", lines=2)
        gr.Markdown("👇 **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/zeroseg/voc_2008_000891.jpg", "dog, cage, cloth, dog bowl"],
                [
                    "examples/zeroseg/pascal_context_00000_image.png",
                    "bike, tree, fence, soccer, floor, chair, cushion",
                ],
                [
                    "examples/zeroseg/pascal_context_00007_image.png",
                    "dog, table, chair, carpet, shoes",
                ],
                [
                    "examples/zeroseg/pascal_context_00049_image.png",
                    "bus, snow, mountain, house, road",
                ],
            ],
            inputs=[custom_input, custom_classes],
        )

    # ── Depth/Normals Visualization Tab ─────────────────────────────────
    with gr.Tab("🏔️ Depth/Normals Visualization"):
        gr.Markdown(
            "Monocular depth and surface normals estimation using a **DPT "
            "(Dense Prediction Transformer)** head on top of a **frozen** "
            "TIPS v2 vision encoder. Trained on the **NYU Depth V2** dataset.",
        )
        with gr.Row():
            with gr.Column():
                depth_input = gr.Image(type="numpy", label="Input image", height=448)
                depth_btn = gr.Button("Predict Depth & Normals", variant="primary")
            with gr.Column():
                dpt_depth_out = gr.Image(label="DPT Depth Map", height=448)
            with gr.Column():
                dpt_normals_out = gr.Image(
                    label="DPT Surface Normals", height=448,
                )
        gr.Markdown("👇 **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/nyuv2/bedroom_00280.jpg"],
                ["examples/nyuv2/kitchen_00249.jpg"],
                ["examples/nyuv2/living_room_01260.jpg"],
                ["examples/nyuv2/office_kitchen_00413.jpg"],
                ["examples/nyuv2/study_room_00272.jpg"],
            ],
            inputs=[depth_input],
        )

    # ── Supervised Segmentation Tab ─────────────────────────────────────
    with gr.Tab("🎭 Supervised Segmentation"):
        gr.Markdown(
            "Semantic segmentation using a **DPT (Dense Prediction "
            "Transformer)** head on top of a **frozen** TIPS v2 vision "
            "encoder. Trained on ADE20K (150 classes).",
        )
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(type="numpy", label="Input image", height=448)
                seg_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                seg_out = gr.Image(label="DPT Segmentation (ADE20K)", height=448)
        gr.Markdown("👇 **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/depth/ade20k_00003.png"],
                ["examples/depth/ade20k_00007.png"],
                ["examples/depth/ade20k_00014.png"],
                ["examples/depth/ade20k_00022.png"],
            ],
            inputs=[seg_input],
        )

    # ── Wiring ──────────────────────────────────────────────────────────
    # Switching variant clears every model-dependent output: on_variant_change
    # handles the PCA/zero-shot outputs, and the chained step clears the DPT ones.
    variant_dd.change(
        fn=on_variant_change,
        inputs=[variant_dd],
        outputs=[
            pca_out,
            depth_out,
            kmeans_out,
            pca_state,
            custom_overlay,
            custom_mask,
            custom_detected,
            custom_undetected,
        ],
    ).then(
        fn=_clear_dpt_outputs,
        inputs=None,
        outputs=[dpt_depth_out, dpt_normals_out, seg_out],
    )

    pca_btn.click(
        fn=on_pca_extract,
        inputs=[pca_input, resolution_dd, pca_state],
        outputs=[pca_out, depth_out, kmeans_out, pca_state],
    )
    recluster_btn.click(
        fn=on_recluster,
        inputs=[pca_input, resolution_dd, n_clusters, pca_state],
        outputs=[kmeans_out, pca_state],
    )
    depth_btn.click(
        fn=on_depth_normals_predict,
        inputs=[depth_input, variant_dd, resolution_dd],
        outputs=[dpt_depth_out, dpt_normals_out],
    )
    seg_btn.click(
        fn=on_segmentation_predict,
        inputs=[seg_input, variant_dd, resolution_dd],
        outputs=[seg_out],
    )
    custom_btn.click(
        fn=on_zeroseg_custom,
        inputs=[custom_input, resolution_dd, custom_classes],
        outputs=[custom_overlay, custom_mask, custom_detected, custom_undetected],
    )

if __name__ == "__main__":
    demo.launch()