# Hugging Face Space — runs on ZeroGPU hardware.
| """TIPS Feature Explorer (GPU) β Hugging Face Space demo with ZeroGPU.""" | |
| import colorsys | |
| import os | |
| import gradio as gr | |
| import matplotlib.cm as cm | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageDraw, ImageFont | |
| from fast_pytorch_kmeans import KMeans as TorchKMeans | |
| from sklearn.decomposition import PCA | |
| from torchvision import transforms | |
| from transformers import AutoModel | |
# ── Constants ────────────────────────────────────────────────────────────────
DEFAULT_IMAGE_SIZE = 896  # default square input resolution in pixels
PATCH_SIZE = 14  # ViT patch size; the feature grid is resolution // 14
# Selectable input resolutions; every entry is a multiple of PATCH_SIZE.
RESOLUTIONS = [224, 336, 448, 672, 896, 1120, 1372, 1792]
ZEROSEG_IMAGE_SIZE = 1372  # default resolution for zero-shot segmentation
MAX_LEN = 64  # max token sequence length for the text encoder
# Display name -> Hugging Face repo id for each DPT-equipped TIPS v2 variant.
VARIANTS = {
    "TIPS v2 β B/14": "google/tipsv2-b14-dpt",
    "TIPS v2 β L/14": "google/tipsv2-l14-dpt",
    "TIPS v2 β SO400m/14": "google/tipsv2-so400m14-dpt",
    "TIPS v2 β g/14": "google/tipsv2-g14-dpt",
}
DEFAULT_VARIANT = "TIPS v2 β L/14"  # must be a key of VARIANTS
| def _device(): | |
| return torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# ── Pascal Context (59 classes) ──────────────────────────────────────────────
# Prompt templates; "{}" is replaced by a class name. Embeddings computed from
# these templates are later averaged into one embedding per class.
TCL_PROMPTS = [
    "itap of a {}.",
    "a bad photo of a {}.",
    "a origami {}.",
    "a photo of the large {}.",
    "a {} in a video game.",
    "art of the {}.",
    "a photo of the small {}.",
    "a photo of many {}.",
    "a photo of {}s.",
]
# The 59 Pascal Context class names used for zero-shot text embeddings.
PASCAL_CONTEXT_CLASSES = (
    "aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird",
    "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat",
    "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain", "dog",
    "door", "fence", "floor", "flower", "food", "grass", "ground", "horse",
    "keyboard", "light", "motorbike", "mountain", "mouse", "person", "plate",
    "platform", "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk",
    "sign", "sky", "snow", "sofa", "table", "track", "train", "tree", "truck",
    "tvmonitor", "wall", "water", "window", "wood",
)
# The 150 ADE20K class names, indexed by the supervised DPT head's class ids.
ADE20K_CLASSES = (
    "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed",
    "windowpane", "grass",
    "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain",
    "plant", "curtain", "chair",
    "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror",
    "rug", "field",
    "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp",
    "bathtub", "railing", "cushion",
    "base", "box", "column", "signboard", "chest_of_drawers", "counter",
    "sand", "sink", "skyscraper", "fireplace",
    "refrigerator", "grandstand", "path", "stairs", "runway", "case",
    "pool_table", "pillow", "screen_door", "stairway",
    "river", "bridge", "bookcase", "blind", "coffee_table", "toilet",
    "flower", "book", "hill", "bench",
    "countertop", "stove", "palm", "kitchen_island", "computer",
    "swivel_chair", "boat", "bar", "arcade_machine", "hovel",
    "bus", "towel", "light", "truck", "tower", "chandelier", "awning",
    "streetlight", "booth", "television",
    "airplane", "dirt_track", "apparel", "pole", "land", "bannister",
    "escalator", "ottoman", "bottle", "buffet",
    "poster", "stage", "van", "ship", "fountain", "conveyer_belt", "canopy",
    "washer", "plaything", "swimming_pool",
    "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike",
    "cradle", "oven", "ball",
    "food", "step", "tank", "trade_name", "microwave", "pot", "animal",
    "bicycle", "lake", "dishwasher",
    "screen", "blanket", "sculpture", "hood", "sconce", "vase",
    "traffic_light", "tray", "ashcan", "fan",
    "pier", "crt_screen", "plate", "monitor", "bulletin_board", "shower",
    "radiator", "glass", "clock", "flag",
)
NUM_ADE20K_CLASSES = 150


def _build_ade20k_palette(num_classes: int = NUM_ADE20K_CLASSES) -> np.ndarray:
    """Build a (num_classes + 1, 3) uint8 RGB palette for segmentation maps.

    Index 0 is reserved black; class colors use golden-ratio hue spacing with
    small deterministic saturation/value jitter so neighbouring ids differ.
    Wrapping the loop in a function keeps the loop temporaries (hue, rgb, …)
    out of the module namespace, which the original module-level loop leaked.
    """
    palette = np.zeros((num_classes + 1, 3), dtype=np.uint8)
    for idx in range(1, num_classes + 1):
        hue = (idx * 0.618033988749895) % 1.0  # golden-ratio conjugate spacing
        saturation = 0.65 + 0.35 * ((idx * 7) % 5) / 4.0
        value = 0.70 + 0.30 * ((idx * 11) % 3) / 2.0
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
        palette[idx] = [int(r * 255), int(g * 255), int(b * 255)]
    return palette


ADE20K_PALETTE = _build_ade20k_palette()
# ── Model state (one model loaded at a time) ─────────────────────────────────
# Mutable module-level cache describing the currently loaded variant.
_model = {
    "name": None,  # display name of the loaded variant (a key of VARIANTS)
    "vision": None,  # backbone vision encoder
    "text": None,  # backbone text encoder
    "tokenizer": None,  # tokenizer paired with the text encoder
    "temperature": None,  # backbone.config.temperature (set on load)
    "ade20k_embs": None,  # cached Pascal Context text embeddings (numpy)
    "dpt": None,  # full DPT model (depth / normals / segmentation heads)
}
def load_variant(name):
    """Load a DPT model variant from HuggingFace (includes the backbone).

    No-op when *name* is already loaded. Otherwise populates the module-level
    ``_model`` cache and clears the cached text embeddings, which depend on
    the variant's text encoder.
    """
    global _model
    if _model["name"] == name:
        return
    # Variant-specific token for gated repos, falling back to HF_TOKEN.
    token = os.environ.get("HF_TIPSv2") or os.environ.get("HF_TOKEN")
    dpt = AutoModel.from_pretrained(VARIANTS[name], trust_remote_code=True, token=token)
    dpt.eval()
    dpt._get_backbone()  # trigger backbone download
    backbone = dpt._backbone
    _model.update(
        name=name,
        dpt=dpt,
        vision=backbone.vision_encoder,
        text=backbone.text_encoder,
        tokenizer=backbone._load_tokenizer(),
        temperature=backbone.config.temperature,
        ade20k_embs=None,  # invalidate: embeddings are text-encoder specific
    )
    print(f"Loaded {name}")
def _move_models_to_device():
    """Move models to the current device (GPU inside @spaces.GPU, else CPU)."""
    target = _device()
    for key in ("vision", "text", "dpt"):
        module = _model[key]
        if module is not None:
            module.to(target)
def _ensure_ade20k_embs():
    """Pre-compute Pascal Context text embeddings if not yet done (must run on GPU).

    NOTE(review): despite the key name, ``_model["ade20k_embs"]`` holds
    embeddings for the 59 Pascal Context classes, averaged over the
    TCL_PROMPTS templates and L2-normalized.
    """
    if _model["ade20k_embs"] is not None:
        return  # already cached for the current variant
    dev = _device()
    model_t = _model["text"]
    tokenizer = _model["tokenizer"]
    all_embs = []
    for template in TCL_PROMPTS:
        # One batch per template: every class name formatted into the prompt.
        prompts = [template.format(c) for c in PASCAL_CONTEXT_CLASSES]
        ids, paddings = tokenizer.tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = model_t(
                torch.from_numpy(ids).to(dev),
                torch.from_numpy(paddings).to(dev),
            )
        all_embs.append(embs.cpu().numpy())
    # Average across templates, then re-normalize to unit length.
    _model["ade20k_embs"] = l2_normalize(np.mean(all_embs, axis=0))
    print("Pascal Context text embeddings computed.")
def _init_model():
    """Load model + move to GPU + compute text embeddings."""
    # Keep the currently loaded variant if any; otherwise load the default.
    load_variant(_model["name"] or DEFAULT_VARIANT)
    _move_models_to_device()
    _ensure_ade20k_embs()
# ── Preprocessing & helpers ──────────────────────────────────────────────────
def preprocess(img, size=DEFAULT_IMAGE_SIZE):
    """Resize *img* to (size, size) and convert it to a CHW float tensor."""
    pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    return pipeline(img)
def l2_normalize(x, axis=-1):
    """Normalize *x* to unit L2 norm along *axis* (norms floored at 1e-3)."""
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    # Flooring the norm avoids division by ~zero for all-zero vectors.
    return x / norms.clip(min=1e-3)
def upsample(arr, h, w, mode="bilinear"):
    """Resize a (H, W, C) or (H, W) numpy array to (h, w, C) via torch."""
    tensor = torch.from_numpy(arr).float()
    if tensor.ndim == 2:
        # Promote (H, W) to (H, W, 1) so a channel axis always exists.
        tensor = tensor.unsqueeze(-1)
    batched = tensor.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
    extra = {"align_corners": False} if mode == "bilinear" else {}
    resized = F.interpolate(batched, size=(h, w), mode=mode, **extra)
    return resized.squeeze(0).permute(1, 2, 0).numpy()
def to_uint8(x):
    """Scale a float array in [0, 1] to uint8 in [0, 255], clipping overflow."""
    scaled = x * 255
    return scaled.clip(0, 255).astype(np.uint8)
# ── Feature extraction (GPU-accelerated) ─────────────────────────────────────
def extract_features(image_np, resolution=DEFAULT_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) as numpy. sp = resolution // 14."""
    dev = _device()
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    # The vision encoder's third output holds the patch tokens; the first two
    # outputs are unused here (presumably CLS/pooled outputs — not verified).
    _, _, patch_tokens = _model["vision"](tensor)
    sp = resolution // PATCH_SIZE
    return patch_tokens.cpu().reshape(sp, sp, -1).numpy()
def extract_features_value_attention(image_np, resolution=ZEROSEG_IMAGE_SIZE):
    """Return spatial features (sp, sp, D) using Value Attention on GPU.

    This follows the Colab reference implementation: run all blocks except the
    last normally, then for the last block extract V from QKV and manually
    apply out_proj, layer scale, residual, norm2, MLP + layer scale, second
    residual, and final norm.
    """
    dev = _device()
    model_image = _model["vision"]
    img = Image.fromarray(image_np).convert("RGB")
    tensor = preprocess(img, resolution).unsqueeze(0).to(dev)
    x = model_image.prepare_tokens_with_masks(tensor)
    # All transformer blocks except the last run unmodified.
    for blk in model_image.blocks[:-1]:
        x = blk(x)
    blk = model_image.blocks[-1]
    num_reg = getattr(model_image, "num_register_tokens", 1)
    b_dim, n_dim, c_dim = x.shape
    num_heads = blk.attn.num_heads
    # Compute QKV for the last block but keep only V — the attention map is
    # deliberately skipped ("value attention").
    qkv = blk.attn.qkv(blk.norm1(x))
    qkv = qkv.reshape(b_dim, n_dim, 3, num_heads, c_dim // num_heads)
    qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D_head)
    v = qkv[2]  # (B, H, N, D_head)
    v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)  # merge heads
    v_out = blk.attn.proj(v_out)  # attention output projection
    v_out = blk.ls1(v_out)  # layer scale 1
    x_val = v_out + x  # first residual
    y_val = blk.norm2(x_val)
    y_val = blk.ls2(blk.mlp(y_val))  # MLP + layer scale 2
    x_val = x_val + y_val  # second residual
    x_val = model_image.norm(x_val)  # final norm
    # Drop the CLS token and register tokens; keep spatial patch tokens only.
    patch_tokens = x_val[:, 1 + num_reg :, :]
    sp = resolution // PATCH_SIZE
    spatial = patch_tokens.cpu().reshape(sp, sp, -1).numpy()
    return spatial
# ── PCA Visualisations ───────────────────────────────────────────────────────
def vis_pca(spatial):
    """Project spatial features onto 3 PCA components and render as RGB."""
    grid_h, grid_w = spatial.shape[0], spatial.shape[1]
    flat = spatial.reshape(-1, spatial.shape[-1])
    components = PCA(n_components=3, whiten=True).fit_transform(flat)
    scores = components.reshape(grid_h, grid_w, 3)
    # Sigmoid squashes the whitened scores into (0, 1) for display.
    squashed = 1 / (1 + np.exp(-2.0 * scores))
    return to_uint8(squashed)
def vis_depth(spatial):
    """Render the first PCA component with the inferno colormap."""
    grid_h, grid_w = spatial.shape[0], spatial.shape[1]
    flat = spatial.reshape(-1, spatial.shape[-1])
    component = PCA(n_components=1).fit_transform(flat).reshape(grid_h, grid_w)
    # Min-max normalize to [0, 1] before applying the colormap.
    lo, hi = component.min(), component.max()
    normed = (component - lo) / (hi - lo + 1e-8)
    colored = cm.get_cmap("inferno")(normed)[:, :, :3].astype(np.float32)
    return to_uint8(colored)
def vis_kmeans(spatial, h, w, n_clusters=6):
    """K-means clustering of spatial features.

    Clusters the flattened patch features on the current device, upsamples the
    per-cluster score maps to (h, w), and colors the argmax labels with tab20.
    """
    sp_h, sp_w = spatial.shape[:2]
    feat = torch.from_numpy(spatial.reshape(-1, spatial.shape[-1])).to(_device())
    km = TorchKMeans(n_clusters=n_clusters, max_iter=20)
    km.fit(feat)
    # Negative distance acts as a similarity score; upsampling scores (rather
    # than hard labels) yields smooth cluster boundaries at full resolution.
    dists = -torch.cdist(feat, km.centroids)  # (H*W, k)
    scores = dists.cpu().numpy().reshape(sp_h, sp_w, n_clusters)
    scores_up = upsample(scores, h, w, mode="bilinear")
    labels = scores_up.argmax(axis=-1)
    palette = plt.cm.tab20(np.linspace(0, 1, n_clusters))[:, :3]
    seg = palette[labels].astype(np.float32)
    return to_uint8(seg)
# ── Zero-shot Segmentation ───────────────────────────────────────────────────
def vis_custom_semseg(spatial, orig_image, classes, class_embs):
    """Zero-shot semantic segmentation with user-defined classes.

    Args:
        spatial: (sp_h, sp_w, D) patch features.
        orig_image: original RGB image as a (H, W, 3) numpy array.
        classes: class-name strings, one per row of *class_embs*.
        class_embs: (len(classes), D) L2-normalized text embeddings.

    Returns:
        Tuple of (overlay-with-legend array, mask array, detected classes
        string, not-detected classes string).
    """
    h, w = orig_image.shape[:2]
    sp_h, sp_w = spatial.shape[:2]
    n = len(classes)
    # Cosine similarity: both features and text embeddings are unit-norm.
    feat = l2_normalize(spatial.reshape(-1, spatial.shape[-1]))
    sim = feat @ class_embs.T
    sim_map = sim.reshape(sp_h, sp_w, n)
    # Upsample per-class similarity maps, then argmax for per-pixel labels.
    sim_up = upsample(sim_map, h, w, mode="bilinear")
    labels = sim_up.argmax(axis=-1)
    # max(n, 2) keeps tab20 sampling valid when only one class is given.
    palette = (plt.cm.tab20(np.linspace(0, 1, max(n, 2)))[:n, :3] * 255).astype(
        np.uint8
    )
    seg_rgb = palette[labels].astype(np.float32) / 255.0
    mask_img = to_uint8(seg_rgb)
    # Mostly-opaque overlay: 10% original image, 90% segmentation colors.
    blend = 0.1 * orig_image.astype(np.float32) / 255.0 + 0.9 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))
    # Sort the classes that appear by pixel count, largest area first.
    unique_ids, counts = np.unique(labels, return_counts=True)
    order = np.argsort(-counts)
    unique_ids, counts = unique_ids[order], counts[order]
    total = counts.sum()
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            60,
        )
    except OSError:
        # DejaVu not installed — fall back to PIL's built-in bitmap font.
        font = ImageFont.load_default()
    # Legend layout: up to 10 rows of color swatch + class name, drawn on a
    # white strip pasted to the right of the blended overlay.
    n_legend = min(len(unique_ids), 10)
    row_h = 80
    swatch_w = 60
    pad = 12
    legend_w = 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for i in range(n_legend):
        cid = unique_ids[i]
        color = tuple(palette[cid].tolist())
        y_top = pad + i * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=color,
            outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + 6),
            classes[cid],
            fill="black",
            font=font,
        )
    overlay_out = np.array(canvas)
    # Classes covering >= 2% of pixels are "detected"; smaller ones and classes
    # with no pixels at all are reported in the "not detected" string.
    detected_parts, minor_parts = [], []
    for i, cid in enumerate(unique_ids):
        pct = counts[i] / total * 100
        if pct >= 2:
            detected_parts.append(f"{classes[cid]} ({pct:.1f}%)")
        else:
            minor_parts.append(f"{classes[cid]} ({pct:.1f}%)")
    absent = [
        f"{classes[i]} (0.0%)" for i in range(n) if i not in set(unique_ids.tolist())
    ]
    detected_str = ", ".join(detected_parts)
    undetected_str = ", ".join(minor_parts + absent)
    return overlay_out, mask_img, detected_str, undetected_str
# ── DPT Depth Inference ──────────────────────────────────────────────────────
def vis_depth_dpt(depth_map, h, w):
    """Colour a depth map with the turbo colormap β PIL Image."""
    squeezed = depth_map.squeeze()
    # Min-max normalize to [0, 1] before applying the colormap.
    lo, hi = squeezed.min(), squeezed.max()
    normed = (squeezed - lo) / (hi - lo + 1e-8)
    colored = cm.get_cmap("turbo")(normed)[:, :, :3].astype(np.float32)
    return to_uint8(upsample(colored, h, w))
def vis_normals_dpt(normals_map, h, w):
    """Map normals from [-1, 1] to [0, 1] and resize to original size."""
    normals = normals_map.cpu().numpy()
    # Shift/scale components into [0, 1], then reorder CHW -> HWC.
    normals = np.transpose((normals + 1.0) / 2.0, (1, 2, 0))  # (H, W, 3)
    return to_uint8(upsample(normals, h, w))
def vis_segmentation_dpt(seg_map, orig_image):
    """Colour a segmentation map with the ADE20K colormap + legend.

    *seg_map* is a (150, H, W) per-class logits tensor; the result is a numpy
    canvas with the blended overlay on the left and a legend on the right.
    """
    h, w = orig_image.shape[:2]
    logits = seg_map.cpu().numpy().transpose(1, 2, 0)  # (H, W, 150)
    # Upsample logits (not labels) for smooth boundaries, then argmax.
    logits_up = upsample(logits, h, w, mode="bilinear")
    pred = logits_up.argmax(axis=-1)  # (h, w)
    # +1 because palette index 0 is the reserved black background entry.
    seg_rgb = ADE20K_PALETTE[pred.astype(np.int32) + 1].astype(np.float32) / 255.0
    blend = 0.15 * orig_image.astype(np.float32) / 255.0 + 0.85 * seg_rgb
    blend_img = Image.fromarray(to_uint8(blend))
    # Keep only classes covering >= 2% of pixels, sorted by area (descending).
    unique_ids, counts = np.unique(pred, return_counts=True)
    total_pixels = counts.sum()
    order = np.argsort(-counts)
    unique_ids, counts = unique_ids[order], counts[order]
    pcts = counts / total_pixels * 100
    mask = pcts >= 2.0
    unique_ids, counts, pcts = unique_ids[mask], counts[mask], pcts[mask]
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            36,
        )
    except OSError:
        # DejaVu not installed — fall back to PIL's built-in bitmap font.
        font = ImageFont.load_default()
    # Legend layout: up to 10 rows of color swatch + class name.
    n_legend = min(len(unique_ids), 10)
    row_h, swatch_w, pad, legend_w = 50, 40, 10, 450
    legend_h = max(h, n_legend * row_h + pad * 2)
    canvas = Image.new("RGB", (w + legend_w, legend_h), (255, 255, 255))
    canvas.paste(blend_img, (0, 0))
    draw = ImageDraw.Draw(canvas)
    for i in range(n_legend):
        cid = unique_ids[i]
        color = tuple(ADE20K_PALETTE[cid + 1].tolist())
        name = ADE20K_CLASSES[cid] if cid < len(ADE20K_CLASSES) else f"class_{cid}"
        y_top = pad + i * row_h
        draw.rectangle(
            [w + pad, y_top, w + pad + swatch_w, y_top + swatch_w],
            fill=color,
            outline=(0, 0, 0),
        )
        draw.text(
            (w + pad + swatch_w + 8, y_top + 4),
            name,
            fill="black",
            font=font,
        )
    return np.array(canvas)
# ── Gradio callbacks ─────────────────────────────────────────────────────────
def on_variant_change(variant_name):
    """Switch the loaded variant and clear every image/text output."""
    load_variant(variant_name)
    _move_models_to_device()
    _ensure_ade20k_embs()
    # Clear: pca_out, depth_out, kmeans_out, pca_state, custom overlay/mask,
    # then the two custom text boxes.
    return (None,) * 6 + ("", "")
def on_pca_extract(image, resolution, _pca_state):
    """Extract features and render the PCA / 1st-component / k-means views."""
    if image is None:
        return None, None, None, None
    _init_model()
    res = int(resolution)
    spatial = extract_features(image, res)
    img_h, img_w = image.shape[:2]
    # Cache the features so re-clustering can skip re-extraction.
    state = {
        "spatial": spatial,
        "orig_image": image,
        "variant": _model["name"],
        "resolution": res,
    }
    return (
        vis_pca(spatial),
        vis_depth(spatial),
        vis_kmeans(spatial, img_h, img_w),
        state,
    )
def on_recluster(image, resolution, n_clusters, pca_state):
    """Re-run k-means, reusing cached features only when the cache is fresh.

    The cache is valid only when the variant, resolution, AND input image all
    match the state captured at extraction time. The original check ignored
    the image, so uploading a new image at the same settings silently
    re-clustered the previous image's features.
    """
    if image is None:
        gr.Warning("Upload an image first.")
        return None, pca_state
    _init_model()
    resolution = int(resolution)
    cache_valid = (
        pca_state is not None
        and pca_state.get("variant") == _model["name"]
        and pca_state.get("resolution") == resolution
        # np.array_equal is False for a missing (None) cached image.
        and np.array_equal(pca_state.get("orig_image"), image)
    )
    if cache_valid:
        spatial = pca_state["spatial"]
    else:
        spatial = extract_features(image, resolution)
        pca_state = {
            "spatial": spatial,
            "orig_image": image,
            "variant": _model["name"],
            "resolution": resolution,
        }
    h, w = image.shape[:2]
    return vis_kmeans(spatial, h, w, int(n_clusters)), pca_state
def on_zeroseg_custom(image, resolution, class_names_str):
    """Segment *image* against user-supplied, comma-separated class names."""
    if image is None or not class_names_str or not class_names_str.strip():
        gr.Warning("Upload an image and enter at least one class name.")
        return None, None, "", ""
    _init_model()
    resolution = int(resolution)
    classes = [c.strip() for c in class_names_str.split(",") if c.strip()]
    if not classes:
        return None, None, "", ""
    dev = _device()
    # Embed each class name with every prompt template, then average across
    # templates and L2-normalize (same recipe as _ensure_ade20k_embs).
    all_embs = []
    for template in TCL_PROMPTS:
        prompts = [template.format(c) for c in classes]
        ids, paddings = _model["tokenizer"].tokenize(prompts, max_len=MAX_LEN)
        with torch.no_grad():
            embs = _model["text"](
                torch.from_numpy(ids).to(dev),
                torch.from_numpy(paddings).to(dev),
            )
        all_embs.append(embs.cpu().numpy())
    class_embs = l2_normalize(np.mean(all_embs, axis=0))
    # Value-attention features (Colab reference) instead of the plain encoder
    # output — see extract_features_value_attention.
    spatial = extract_features_value_attention(image, resolution)
    overlay, mask, detected, undetected = vis_custom_semseg(
        spatial,
        image,
        classes,
        class_embs,
    )
    return overlay, mask, detected, undetected
def on_depth_normals_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT depth and normals prediction."""
    if image is None:
        return None, None
    _init_model()
    device = _device()
    img_h, img_w = image.shape[:2]
    pil_img = Image.fromarray(image).convert("RGB")
    batch = preprocess(pil_img, int(resolution)).unsqueeze(0).to(device)
    depth_pred = _model["dpt"].predict_depth(batch)
    normals_pred = _model["dpt"].predict_normals(batch)
    depth_vis = vis_depth_dpt(depth_pred[0, 0].cpu().numpy(), img_h, img_w)
    normals_vis = vis_normals_dpt(normals_pred[0], img_h, img_w)
    return depth_vis, normals_vis
def on_segmentation_predict(image, dpt_variant, resolution):  # noqa: ARG001
    """Run DPT segmentation prediction."""
    if image is None:
        return None
    _init_model()
    device = _device()
    pil_img = Image.fromarray(image).convert("RGB")
    batch = preprocess(pil_img, int(resolution)).unsqueeze(0).to(device)
    seg_logits = _model["dpt"].predict_segmentation(batch)
    return vis_segmentation_dpt(seg_logits[0], image)
# ── UI ───────────────────────────────────────────────────────────────────────
# Pixelated rendering keeps the low-resolution PCA grids crisp when scaled up.
custom_css = """
#pca_output_image img, #depth_output_image img {
image-rendering: pixelated;
object-fit: contain;
}
"""
# Google Analytics snippet injected into the page <head>.
head = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-P13E18K71N"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-P13E18K71N', {
'page_title': 'TIPSv2',
'page_location': 'https://huggingface.co/spaces/google/TIPSv2'
});
</script>
"""
# Top-level Gradio app: four tabs sharing the variant/resolution dropdowns.
with gr.Blocks(head=head, title="TIPSv2 Feature Explorer", css=custom_css) as demo:
    gr.Markdown(
        "## TIPSv2 Feature Explorer\n"
        "Explore TIPSv2 representations here! For more information, see: "
        "https://gdm-tipsv2.github.io/",
    )
    # Global controls used by every tab.
    with gr.Row():
        variant_dd = gr.Dropdown(
            choices=list(VARIANTS.keys()),
            value=DEFAULT_VARIANT,
            label="Model variant",
        )
        resolution_dd = gr.Dropdown(
            choices=RESOLUTIONS,
            value=DEFAULT_IMAGE_SIZE,
            label="Resolution (higher = better quality, slower)",
        )
    # ── PCA / Feature Visualization Tab ─────────────────────────────────
    with gr.Tab("π¨ PCA & Feature Visualization"):
        # Caches extracted features so re-clustering can skip re-extraction.
        pca_state = gr.State(None)
        with gr.Row():
            with gr.Column():
                pca_input = gr.Image(type="numpy", label="Input image")
                pca_btn = gr.Button("Extract Features", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("PCA"):
                        pca_out = gr.Image(
                            label="PCA (3 components β RGB)",
                            height=448,
                            elem_id="pca_output_image",
                        )
                    with gr.Tab("PCA (1st component)"):
                        depth_out = gr.Image(
                            label="1st PCA component",
                            height=448,
                            elem_id="depth_output_image",
                        )
                    with gr.Tab("K-means Clustering"):
                        n_clusters = gr.Slider(
                            2,
                            20,
                            value=6,
                            step=1,
                            label="Clusters",
                        )
                        recluster_btn = gr.Button("Re-cluster")
                        kmeans_out = gr.Image(label="K-means clusters")
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/pca/hike.jpeg"],
                ["examples/pca/cph.jpeg"],
                ["examples/pca/angus.jpeg"],
                ["examples/pca/dadaocheng.jpeg"],
            ],
            inputs=[pca_input],
        )
    # ── Zero-shot Segmentation Tab ──────────────────────────────────────
    with gr.Tab("βοΈ Zero-shot Segmentation"):
        gr.Markdown(
            "Define your own classes for zero-shot segmentation. "
            "Enter class names separated by commas.",
        )
        with gr.Row():
            with gr.Column():
                custom_input = gr.Image(type="numpy", label="Input image", height=448)
                custom_classes = gr.Textbox(
                    label="Class names (comma-separated)",
                    value="class1, class2, class3",
                    placeholder="e.g. cat, dog, sky, grass",
                )
                custom_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Overlay"):
                        custom_overlay = gr.Image(
                            label="Segmentation overlay",
                            height=448,
                        )
                    with gr.Tab("Mask"):
                        custom_mask = gr.Image(
                            label="Segmentation mask",
                            height=448,
                        )
                custom_detected = gr.Textbox(
                    label="Detected classes (sorted by area)",
                    lines=2,
                )
                custom_undetected = gr.Textbox(label="Not detected", lines=2)
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/zeroseg/voc_2008_000891.jpg", "dog, cage, cloth, dog bowl"],
                [
                    "examples/zeroseg/pascal_context_00000_image.png",
                    "bike, tree, fence, soccer, floor, chair, cushion",
                ],
                [
                    "examples/zeroseg/pascal_context_00007_image.png",
                    "dog, table, chair, carpet, shoes",
                ],
                [
                    "examples/zeroseg/pascal_context_00049_image.png",
                    "bus, snow, mountain, house, road",
                ],
            ],
            inputs=[custom_input, custom_classes],
        )
    # ── Depth/Normals Visualization Tab ─────────────────────────────────
    with gr.Tab("ποΈ Depth/Normals Visualization"):
        gr.Markdown(
            "Monocular depth and surface normals estimation using a **DPT "
            "(Dense Prediction Transformer)** head on top of a **frozen** "
            "TIPS v2 vision encoder. Trained on the **NYU Depth V2** dataset.",
        )
        with gr.Row():
            with gr.Column():
                depth_input = gr.Image(type="numpy", label="Input image", height=448)
                depth_btn = gr.Button("Predict Depth & Normals", variant="primary")
            with gr.Column():
                dpt_depth_out = gr.Image(label="DPT Depth Map", height=448)
            with gr.Column():
                dpt_normals_out = gr.Image(
                    label="DPT Surface Normals",
                    height=448,
                )
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/nyuv2/bedroom_00280.jpg"],
                ["examples/nyuv2/kitchen_00249.jpg"],
                ["examples/nyuv2/living_room_01260.jpg"],
                ["examples/nyuv2/office_kitchen_00413.jpg"],
                ["examples/nyuv2/study_room_00272.jpg"],
            ],
            inputs=[depth_input],
        )
    # ── Supervised Segmentation Tab ──────────────────────────────────────
    with gr.Tab("π Supervised Segmentation"):
        gr.Markdown(
            "Semantic segmentation using a **DPT (Dense Prediction "
            "Transformer)** head on top of a **frozen** TIPS v2 vision "
            "encoder. Trained on ADE20K (150 classes).",
        )
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(type="numpy", label="Input image", height=448)
                seg_btn = gr.Button("Segment", variant="primary")
            with gr.Column():
                seg_out = gr.Image(label="DPT Segmentation (ADE20K)", height=448)
        gr.Markdown("π **Click the examples below to explore!**")
        gr.Examples(
            examples=[
                ["examples/depth/ade20k_00003.png"],
                ["examples/depth/ade20k_00007.png"],
                ["examples/depth/ade20k_00014.png"],
                ["examples/depth/ade20k_00022.png"],
            ],
            inputs=[seg_input],
        )
    # ── Wiring ──────────────────────────────────────────────────────────
    # Switching variant reloads the model and clears all cached outputs.
    variant_dd.change(
        fn=on_variant_change,
        inputs=[variant_dd],
        outputs=[
            pca_out,
            depth_out,
            kmeans_out,
            pca_state,
            custom_overlay,
            custom_mask,
            custom_detected,
            custom_undetected,
        ],
    )
    pca_btn.click(
        fn=on_pca_extract,
        inputs=[pca_input, resolution_dd, pca_state],
        outputs=[pca_out, depth_out, kmeans_out, pca_state],
    )
    recluster_btn.click(
        fn=on_recluster,
        inputs=[pca_input, resolution_dd, n_clusters, pca_state],
        outputs=[kmeans_out, pca_state],
    )
    depth_btn.click(
        fn=on_depth_normals_predict,
        inputs=[depth_input, variant_dd, resolution_dd],
        outputs=[dpt_depth_out, dpt_normals_out],
    )
    seg_btn.click(
        fn=on_segmentation_predict,
        inputs=[seg_input, variant_dd, resolution_dd],
        outputs=[seg_out],
    )
    custom_btn.click(
        fn=on_zeroseg_custom,
        inputs=[custom_input, resolution_dd, custom_classes],
        outputs=[custom_overlay, custom_mask, custom_detected, custom_undetected],
    )
# Launch the Gradio app when executed as a script.
if __name__ == "__main__":
    demo.launch()