| """ViT-Up: Faithful Feature Upsampling for Vision Transformers. |
| |
| Interactive demo that loads the ViT-Up feature upsampler, extracts dense |
| features from an input image at a user-selected output resolution, and |
| visualises them via a 3-component PCA projection to RGB. |
| |
| The DINOv3 backbone checkpoint on Hugging Face is gated, so this demo |
| loads the equivalent pretrained weights from the non-gated timm mirror |
| (`timm/vit_small_plus_patch16_dinov3.lvd1689m`) and maps them into the |
| same ``DINOv3ViT`` module structure the ViT-Up code expects. The ViT-Up |
| LoRA adapters and upsampler head are then loaded from ``Krispin/vit-up``. |
| """ |
|
|
| import os |
|
|
# Configure PyTorch's CUDA caching allocator. This is done before `torch` is
# imported below so the setting is in place before CUDA is initialised;
# setdefault keeps any value already provided by the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
| import spaces |
| import sys |
| import math |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image, ImageOps |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file as load_safetensors |
| import torchvision.transforms.v2 as T |
|
|
| |
| |
| |
# Non-gated timm mirror of the DINOv3 ViT-S+/16 backbone weights (the official
# Hugging Face DINOv3 checkpoint is gated; see the module docstring).
BACKBONE_TIMM_REPO = "timm/vit_small_plus_patch16_dinov3.lvd1689m"
# Hub repo and file holding the combined backbone-LoRA + ViT-Up weights.
VITUP_WEIGHTS_REPO = "Krispin/vit-up"
VITUP_WEIGHTS_FILE = "vit_up_dinov3_splus.safetensors"
# Backbone architecture, passed to DINOv3ViTConfig in _build_backbone.
HIDDEN_SIZE = 384
NUM_LAYERS = 12
NUM_HEADS = 6
INTERMEDIATE_SIZE = 1536
PATCH_SIZE = 16
NUM_REGISTER_TOKENS = 4
# Side length of the square the (padded) input image is resized to.
IMAGE_SIZE = 448
# Passed to ViTUp(layer_indices=...); presumably the backbone hidden-state
# indices fed to the upsampler — TODO confirm against the ViT-Up repo.
LAYER_INDICES = [0, 2, 4, 6, 8, 10, 12]
# Per-channel normalisation statistics (the standard ImageNet mean/std values)
# applied in _prepare_image.
RESNET_MEAN = torch.tensor([0.485, 0.456, 0.406])
RESNET_STD = torch.tensor([0.229, 0.224, 0.225])
|
|
| |
| from transformers import DINOv3ViTConfig |
| from vit_up.layers.backbones.dinov3_vit import DINOv3ViT |
| from vit_up.model.vit_up import ViTUp |
| from vit_up.utils.state_dict_migration import migrate_vit_up_state_dict_keys |
| from peft import LoraConfig, get_peft_model |
|
|
|
|
| |
| |
| |
| def _map_timm_to_dinov3(timm_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """Convert a timm ViT state-dict to the DINOv3ViT module key names.""" |
|
|
| mapped: Dict[str, torch.Tensor] = {} |
| for key, val in timm_sd.items(): |
| if key == "cls_token": |
| mapped["embeddings.cls_token"] = val |
| elif key == "reg_token": |
| mapped["embeddings.register_tokens"] = val |
| elif key == "patch_embed.proj.weight": |
| mapped["embeddings.patch_embeddings.weight"] = val |
| elif key == "patch_embed.proj.bias": |
| mapped["embeddings.patch_embeddings.bias"] = val |
| elif key.startswith("blocks.") and key.endswith(".attn.qkv.weight"): |
| idx = int(key.split(".")[1]) |
| qkv = val |
| q, k, v = qkv.chunk(3, dim=0) |
| mapped[f"layer.{idx}.attention.q_proj.weight"] = q |
| mapped[f"layer.{idx}.attention.k_proj.weight"] = k |
| mapped[f"layer.{idx}.attention.v_proj.weight"] = v |
| elif key.startswith("blocks.") and key.endswith(".attn.qkv.bias"): |
| idx = int(key.split(".")[1]) |
| qkv = val |
| if val is not None and val.numel() > 0: |
| q, k, v = qkv.chunk(3, dim=0) |
| mapped[f"layer.{idx}.attention.q_proj.bias"] = q |
| mapped[f"layer.{idx}.attention.k_proj.bias"] = k |
| mapped[f"layer.{idx}.attention.v_proj.bias"] = v |
| elif key.startswith("blocks.") and ".attn.proj." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".attn.proj.")[-1] |
| mapped[f"layer.{idx}.attention.o_proj.{suffix}"] = val |
| elif key.startswith("blocks.") and ".norm1." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".norm1.")[-1] |
| mapped[f"layer.{idx}.norm1.{suffix}"] = val |
| elif key.startswith("blocks.") and ".norm2." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".norm2.")[-1] |
| mapped[f"layer.{idx}.norm2.{suffix}"] = val |
| elif key.startswith("blocks.") and ".mlp.fc1_g." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".mlp.fc1_g.")[-1] |
| mapped[f"layer.{idx}.mlp.gate_proj.{suffix}"] = val |
| elif key.startswith("blocks.") and ".mlp.fc1_x." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".mlp.fc1_x.")[-1] |
| mapped[f"layer.{idx}.mlp.up_proj.{suffix}"] = val |
| elif key.startswith("blocks.") and ".mlp.fc2." in key: |
| idx = int(key.split(".")[1]) |
| suffix = key.split(".mlp.fc2.")[-1] |
| mapped[f"layer.{idx}.mlp.down_proj.{suffix}"] = val |
| elif key.startswith("blocks.") and key.endswith(".gamma_1"): |
| idx = int(key.split(".")[1]) |
| mapped[f"layer.{idx}.layer_scale1.lambda1"] = val |
| elif key.startswith("blocks.") and key.endswith(".gamma_2"): |
| idx = int(key.split(".")[1]) |
| mapped[f"layer.{idx}.layer_scale2.lambda1"] = val |
| elif key == "norm.weight": |
| mapped["norm.weight"] = val |
| elif key == "norm.bias": |
| mapped["norm.bias"] = val |
| |
| return mapped |
|
|
|
|
| |
| |
| |
def _build_backbone(device: str, dtype: torch.dtype) -> DINOv3ViT:
    """Build the DINOv3 ViT-S+/16 backbone, load timm weights and add LoRA.

    The architecture comes from the module-level constants. Pretrained weights
    are downloaded from ``BACKBONE_TIMM_REPO``, renamed with
    ``_map_timm_to_dinov3`` and loaded non-strictly; key mismatches are
    reported as warnings. LoRA adapters are then attached with PEFT.

    Args:
        device: Target device for the returned module.
        dtype: Target parameter dtype for the returned module.

    Returns:
        The backbone wrapped by ``get_peft_model`` (a PEFT wrapper around
        ``DINOv3ViT`` despite the annotation), in eval mode on
        ``device``/``dtype``.
    """
    config = DINOv3ViTConfig(
        hidden_size=HIDDEN_SIZE,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        intermediate_size=INTERMEDIATE_SIZE,
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        num_register_tokens=NUM_REGISTER_TOKENS,
        use_gated_mlp=True,
        layerscale_value=1e-5,
        query_bias=False,
        key_bias=False,
        value_bias=False,
        proj_bias=True,
        mlp_bias=True,
    )
    config._attn_implementation = "eager"
    backbone = DINOv3ViT(config)

    # Pretrained weights from the non-gated timm mirror, renamed to this
    # module's parameter layout.
    timm_safetensors_path = hf_hub_download(
        BACKBONE_TIMM_REPO, "model.safetensors"
    )
    timm_sd = load_safetensors(timm_safetensors_path, device="cpu")
    mapped_sd = _map_timm_to_dinov3(timm_sd)
    missing, unexpected = backbone.load_state_dict(mapped_sd, strict=False)

    # _map_timm_to_dinov3 never produces a mask_token, so it is expected to
    # be missing; anything else missing keeps its random initialisation.
    real_missing = [k for k in missing if "mask_token" not in k]
    if real_missing:
        print(f"[WARNING] Missing backbone keys after timm load: {real_missing[:10]}")
    # With strict=False, mapped tensors that have no slot in the module (e.g.
    # q/k/v biases while the config disables them) would otherwise be
    # discarded without any trace.
    if unexpected:
        print(f"[WARNING] Unexpected backbone keys after timm load: {list(unexpected)[:10]}")
    print(f"[INFO] Loaded backbone from timm: {len(mapped_sd)} tensors mapped")

    # LoRA on the patch embedding and the attention projections; the trained
    # adapter weights are loaded afterwards from the ViT-Up checkpoint.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=[
            "patch_embeddings",
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
        ],
    )
    backbone = get_peft_model(backbone, lora_config)
    backbone = backbone.to(device=device, dtype=dtype).eval()
    return backbone
|
|
|
|
| |
| |
| |
def _build_vit_up(device: str, dtype: torch.dtype) -> ViTUp:
    """Instantiate the ViTUp upsampler from the config tree (same as the repo).

    All hyper-parameters are hard-coded here to mirror the upstream config.
    Trained weights are loaded afterwards (non-strictly) by
    ``_load_vit_up_weights``.

    NOTE(review): sub-modules are constructed in a fixed order, which fixes how
    they consume the global RNG. Reordering them would change the initial
    values of any parameter the checkpoint does not overwrite.

    Args:
        device: Target device for the returned module.
        dtype: Target parameter dtype for the returned module.

    Returns:
        The ``ViTUp`` module in eval mode on ``device``/``dtype``.
    """
    # Function-scope imports of the ViT-Up building blocks.
    from vit_up.layers.query_encoder import QueryEncoder
    from vit_up.layers.pos_enc import FourierPositionalEncoding
    from vit_up.layers.continuous_rope import ContinuousRoPE2D
    from vit_up.layers.smart_module_list import SmartModuleList
    from vit_up.layers.mlp import SimpleMLP
    from vit_up.layers.cross_attention import CrossAttention
    from vit_up.layers.film import SimpleFiLMV2

    # Upsampler width matches the backbone hidden size.
    dim = HIDDEN_SIZE

    # NOTE(review): img_in_size=3584 is an unexplained upstream constant
    # (presumably a reference input resolution for the query encoder) — confirm.
    query_embedding = QueryEncoder(
        layer_index=0,
        img_in_size=3584,
        window_size=0,
        out_proj_module=None,
    )

    # Fourier positional encoding with 16 bands (presumably of relative
    # query/patch positions, per the name).
    rel_pos_enc = FourierPositionalEncoding(num_bands=16, max_resolution=10.0)

    # 2-D rotary embedding for the queries. dim=64 equals HIDDEN_SIZE // NUM_HEADS
    # (384 // 6), presumably the per-head dimension.
    q_rope_embeddings = ContinuousRoPE2D(dim=64, base=100.0, scale=2 * math.pi)

    # Six blocks of class vit_up.model.vit_up.ViTUpBlock built from the init
    # args below. NOTE(review): whether SmartModuleList copies or shares the
    # module-valued args across blocks is not visible here — confirm upstream.
    vit_up_blocks = SmartModuleList(
        n_blocks=6,
        block_class_path="vit_up.model.vit_up.ViTUpBlock",
        block_init_args={
            "dim": dim,
            "dim_h": dim,
            "transition_mlp": SimpleMLP(
                dims=[dim, dim * 2, dim],
                activation="gelu",
                input_layernorm=True,
                use_residual=True,
            ),
            "cross_attention": CrossAttention(
                dim=dim,
                num_heads=NUM_HEADS,
                cross_attn_window_size=32,
                qkv_bias=True,
                attn_dropout=0.0,
                proj_dropout=0.0,
            ),
            "featx": SimpleFiLMV2(
                input_module=nn.LayerNorm(dim),
                # NOTE(review): 66 input features presumably match the
                # positional-encoding output (e.g. 2 coords + 2*2*16 sin/cos
                # terms for 16 bands); dim * 2 outputs presumably hold gamma
                # and beta — confirm against FourierPositionalEncoding/FiLM.
                gamma_beta_mlp=SimpleMLP(
                    dims=[66, dim, dim * 2],
                    activation="gelu",
                    zero_init_last=True,
                ),
                post_mlp=SimpleMLP(
                    dims=[dim, dim * 4, dim],
                    activation="gelu",
                    input_layernorm=True,
                    use_residual=False,
                ),
            ),
            "mlp": SimpleMLP(
                dims=[dim, dim * 4, dim],
                activation="gelu",
            ),
        },
    )

    # Seven decoder MLPs; len(LAYER_INDICES) is also 7, presumably one per
    # selected backbone layer.
    decoder_mlp = SmartModuleList(
        n_blocks=7,
        block_class_path="vit_up.layers.mlp.SimpleMLP",
        block_init_args={
            "input_layernorm": True,
            "dims": [dim, dim],
        },
    )

    vit_up = ViTUp(
        layer_indices=LAYER_INDICES,
        query_embedding=query_embedding,
        rel_pos_enc=rel_pos_enc,
        vit_up_blocks=vit_up_blocks,
        decoder_mlp=decoder_mlp,
        q_rope_embeddings=q_rope_embeddings,
    )
    vit_up = vit_up.to(device=device, dtype=dtype).eval()
    return vit_up
|
|
|
|
| |
| |
| |
def _load_vit_up_weights(
    backbone: nn.Module,
    vit_up: ViTUp,
    device: str,
) -> None:
    """Load the combined LoRA + ViT-Up weights from Krispin/vit-up.

    Checkpoint keys prefixed with ``backbone.`` are loaded (prefix stripped)
    into ``backbone``; all other keys go through
    ``migrate_vit_up_state_dict_keys`` and are loaded into ``vit_up``. Both
    loads are non-strict, so mismatches are reported instead of raised.

    Args:
        backbone: The PEFT-wrapped backbone from ``_build_backbone``.
        vit_up: The upsampler from ``_build_vit_up``.
        device: Unused; kept for interface compatibility.
    """
    weights_path = hf_hub_download(VITUP_WEIGHTS_REPO, VITUP_WEIGHTS_FILE)
    state_dict = load_safetensors(weights_path, device="cpu")

    backbone_sd: Dict[str, torch.Tensor] = {}
    vit_up_sd: Dict[str, torch.Tensor] = {}
    for key, val in state_dict.items():
        if key.startswith("backbone."):
            backbone_sd[key.removeprefix("backbone.")] = val
        else:
            vit_up_sd[key] = val

    missing_b, unexpected_b = backbone.load_state_dict(backbone_sd, strict=False)
    print(f"[INFO] Loaded backbone LoRA: {len(backbone_sd)} tensors, "
          f"missing={len(missing_b)}, unexpected={len(unexpected_b)}")
    # The base weights were already loaded from timm in _build_backbone, so
    # most "missing" keys here are expected. Missing LoRA parameters, or
    # checkpoint keys that matched nothing, mean the adapters were not (fully)
    # applied and the features would silently differ from the trained model.
    missing_lora = [k for k in missing_b if "lora_" in k]
    if missing_lora:
        print(f"[WARNING] Missing backbone LoRA keys: {missing_lora[:10]}")
    if unexpected_b:
        print(f"[WARNING] Unexpected backbone keys: {list(unexpected_b)[:10]}")

    # Checkpoint keys may follow an older naming scheme; migrate before loading.
    migrated_vit_up_sd = migrate_vit_up_state_dict_keys(vit_up_sd)
    missing_v, unexpected_v = vit_up.load_state_dict(migrated_vit_up_sd, strict=False)
    print(f"[INFO] Loaded ViT-Up: {len(migrated_vit_up_sd)} tensors, "
          f"missing={len(missing_v)}, unexpected={len(unexpected_v)}")
    if missing_v:
        print(f"  Missing ViT-Up keys: {missing_v[:10]}")
    if unexpected_v:
        print(f"  Unexpected ViT-Up keys: {list(unexpected_v)[:10]}")
|
|
|
|
| |
| |
| |
| def _fit_pca(tokens_nc: torch.Tensor, k: int = 3) -> dict: |
| """Fit a simple PCA on (N, C) feature tokens.""" |
| tokens = tokens_nc.float() |
| mean = tokens.mean(dim=0) |
| centered = tokens - mean |
| _, singular_values, vh = torch.linalg.svd(centered, full_matrices=False) |
| components = vh[:k].T |
| projected = centered @ components |
| color_min = projected.amin(dim=0) |
| color_max = projected.amax(dim=0) |
| flat = torch.isclose(color_max, color_min) |
| color_max = torch.where(flat, color_min + 1.0, color_max) |
| return { |
| "pca_eig": components, |
| "pca_mean": mean, |
| "pca_color_min": color_min, |
| "pca_color_max": color_max, |
| } |
|
|
|
|
| def _apply_pca_rgb(feats_hwc: torch.Tensor, pca_data: dict) -> torch.Tensor: |
| h, w, c = feats_hwc.shape |
| tokens = feats_hwc.float().reshape(-1, c) |
| mean = pca_data["pca_mean"].to(device=tokens.device, dtype=tokens.dtype) |
| components = pca_data["pca_eig"].to(device=tokens.device, dtype=tokens.dtype) |
| color_min = pca_data["pca_color_min"].to(device=tokens.device, dtype=tokens.dtype) |
| color_max = pca_data["pca_color_max"].to(device=tokens.device, dtype=tokens.dtype) |
| projected = (tokens - mean) @ components |
| rgb = (projected - color_min.view(1, -1)) / (color_max - color_min).view(1, -1).add(1e-8) |
| rgb = rgb.clamp(0.0, 1.0).mul(255.0).to(torch.uint8) |
| return rgb.reshape(h, w, 3) |
|
|
|
|
| |
| |
| |
def pad_image_to_square(img: Image.Image) -> Image.Image:
    """Centre ``img`` on a black square of side ``max(width, height)``.

    The leading (top or left) border is ``diff // 2`` pixels, which is the
    offset ``crop_feature_square_to_image_aspect`` assumes. When the size
    difference is odd, the extra pixel goes to the trailing (bottom or right)
    border so the result is always exactly square.

    Args:
        img: Input image.

    Returns:
        ``img`` itself if already square, otherwise a new padded image.
    """
    w, h = img.size
    if w == h:
        return img
    diff = abs(w - h)
    lead = diff // 2
    trail = diff - lead
    if w > h:
        border = (0, lead, 0, trail)  # (left, top, right, bottom)
    else:
        border = (lead, 0, trail, 0)
    return ImageOps.expand(img, border=border, fill=0)
|
|
|
|
def crop_feature_square_to_image_aspect(
    feat_img: Image.Image,
    original_size: tuple,
) -> Image.Image:
    """Crop a square feature visualisation back to the original image region.

    ``feat_img`` is treated as a rendering of the square-padded input (side
    ``max(width, height)``, image offset by half the size difference along
    the shorter axis), possibly at another resolution; the crop box is scaled
    from padded-image pixels to ``feat_img`` pixels and rounded.

    Args:
        feat_img: Square visualisation to crop.
        original_size: ``(width, height)`` of the original, unpadded image.

    Returns:
        The result of ``feat_img.crop`` over the region covering the image.
    """
    width, height = original_size
    side = max(width, height)
    # Only the shorter axis is padded, so one of these offsets is always 0.
    off_x = (side - width) // 2
    off_y = (side - height) // 2
    sx = feat_img.width / side
    sy = feat_img.height / side
    box = (
        off_x * sx,
        off_y * sy,
        (off_x + width) * sx,
        (off_y + height) * sy,
    )
    return feat_img.crop(tuple(int(round(edge)) for edge in box))
|
|
|
|
| |
| |
| |
# Build the models once at import time; extract_and_visualize reads the
# module-level `backbone` and `vit_up` on every request.
print("[INFO] Building ViT-Up model...")
DEVICE = "cuda"
DTYPE = torch.bfloat16

backbone = _build_backbone(DEVICE, DTYPE)
vit_up = _build_vit_up(DEVICE, DTYPE)
_load_vit_up_weights(backbone, vit_up, DEVICE)
# Both builders already return modules in eval mode; repeated here defensively
# after weight loading.
backbone = backbone.eval()
vit_up = vit_up.eval()
print("[INFO] Model ready.")
|
|
|
|
| |
| |
| |
def _prepare_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised single-image batch on ``DEVICE``.

    Steps: convert to RGB, zero-pad to a square, resize to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE`` (bilinear, antialiased), scale to float in
    ``[0, 1]``, normalise with ``RESNET_MEAN`` / ``RESNET_STD``, then add a
    leading batch dimension.

    Returns:
        Tensor of shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` on ``DEVICE``.
    """
    square = pad_image_to_square(img.convert("RGB"))
    steps = [
        T.ToImage(),
        T.Resize(
            (IMAGE_SIZE, IMAGE_SIZE),
            interpolation=T.InterpolationMode.BILINEAR,
            antialias=True,
        ),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=RESNET_MEAN, std=RESNET_STD),
    ]
    chw = T.Compose(steps)(square)
    return chw[None].to(DEVICE)
|
|
|
|
| def _compute_query_coords(out_size: int) -> torch.Tensor: |
| coords = torch.linspace(0.5, out_size - 0.5, out_size) / out_size |
| grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij") |
| return torch.stack((grid_x, grid_y), dim=-1).reshape(1, -1, 2) |
|
|
|
|
def _query_features_chunked(
    cache_data: Any,
    query_coords: torch.Tensor,
    chunk_size: int = 4096,
) -> torch.Tensor:
    """Evaluate ViT-Up at ``query_coords`` in slices of ``chunk_size`` queries.

    Slicing along dim 1 bounds the number of queries per forward call; the
    last element of each call's output is kept and the slices are
    concatenated back along dim 1.

    Args:
        cache_data: Result of ``vit_up.compute_cache_data`` for one image.
        query_coords: Normalised query coordinates, shape ``(1, Q, 2)``.
        chunk_size: Maximum number of queries per forward call.

    Returns:
        Features for all ``Q`` queries, concatenated along dim 1.
    """
    outputs = []
    for start in range(0, query_coords.shape[1], chunk_size):
        out = vit_up(
            pixel_values=None,
            q_xy_normalized=query_coords[:, start:start + chunk_size, :],
            cache_data=cache_data,
        )
        # NOTE(review): the output is indexed like a sequence; presumably the
        # last entry is the final-stage feature map — confirm upstream.
        outputs.append(out[-1])
    return torch.cat(outputs, dim=1)


def _display_size(width: int, height: int, max_side: int = 512) -> tuple[int, int]:
    """Shrink ``(width, height)`` so the longer side is at most ``max_side``.

    Sizes already within the limit are returned unchanged. Each returned side
    is at least 1 so extreme aspect ratios remain valid for ``resize``.
    """
    longest = max(width, height)
    if longest <= max_side:
        return width, height
    scale = max_side / longest
    return max(1, int(width * scale)), max(1, int(height * scale))


@spaces.GPU(duration=120)
def extract_and_visualize(
    input_image: Image.Image,
    output_resolution: int,
) -> tuple[Image.Image, Image.Image, str]:
    """Extract dense ViT-Up features and visualise them via PCA.

    Args:
        input_image: Input PIL image.
        output_resolution: Output feature map resolution (pixels per side).

    Returns:
        Tuple of (pca_visualization, input_resized, info_text). When no image
        is given, both images are ``None`` and the text asks for an input.
    """
    if input_image is None:
        return None, None, "Please provide an input image."

    out_size = int(output_resolution)
    orig_w, orig_h = input_image.size

    pixel_values = _prepare_image(input_image)

    # no_grad must also cover the chunked query loop: the ViT-Up parameters
    # are never frozen, so otherwise every chunk's output would carry an
    # autograd graph that stays alive in the list of outputs until the end.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            cache_data = vit_up.compute_cache_data(
                pixel_values=pixel_values,
                backbone=backbone,
                hidden_layer_img_size=IMAGE_SIZE,
            )
        query_coords = _compute_query_coords(out_size).to(DEVICE, dtype=DTYPE)
        features = _query_features_chunked(cache_data, query_coords)

    features_hwc = features[0].reshape(out_size, out_size, -1).float().cpu()

    pca_data = _fit_pca(features_hwc.reshape(-1, features_hwc.shape[-1]), k=3)
    pca_rgb = _apply_pca_rgb(features_hwc, pca_data)
    pca_img = Image.fromarray(pca_rgb.numpy().astype(np.uint8), mode="RGB")

    # Remove the square padding added by _prepare_image.
    pca_img = crop_feature_square_to_image_aspect(pca_img, (orig_w, orig_h))

    display_w, display_h = _display_size(orig_w, orig_h, max_side=512)
    # NEAREST keeps the individual feature cells visible when upscaling.
    pca_display = pca_img.resize((display_w, display_h), Image.Resampling.NEAREST)
    input_display = input_image.convert("RGB").resize(
        (display_w, display_h), Image.Resampling.LANCZOS
    )

    info = (f"Feature dim: {features_hwc.shape[-1]} | "
            f"Output resolution: {out_size}x{out_size} | "
            f"Total query points: {out_size * out_size}")

    return pca_display, input_display, info
|
|
|
|
| |
| |
| |
# Page styling. NOTE(review): no component below sets elem_id="col-container",
# so the first rule appears to match nothing — confirm or attach it.
CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""


# UI: input image and settings on the left, PCA visualisation and the resized
# input on the right. Gradio lays components out in creation order.
with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
    gr.Markdown("# ViT-Up: Faithful Feature Upsampling for Vision Transformers")
    gr.Markdown(
        "Upload an image to extract dense DINOv3 features at arbitrary resolution "
        "via the ViT-Up feature upsampler. The PCA visualization shows the "
        "3 principal components of the upsampled feature map as RGB."
    )
    gr.Markdown(
        "[Paper](https://huggingface.co/papers/2606.14024) | "
        "[GitHub](https://github.com/krispinwandel/vit-up) | "
        "[Model Weights](https://huggingface.co/Krispin/vit-up)"
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            with gr.Accordion("Advanced settings", open=False):
                # Output side length: multiples of 28 from 28 to 224.
                out_res = gr.Slider(
                    label="Output resolution (pixels per side)",
                    minimum=28,
                    maximum=224,
                    value=112,
                    step=28,
                )
            run_btn = gr.Button("Extract Features", variant="primary")
        with gr.Column():
            pca_out = gr.Image(label="PCA Feature Visualization")
            input_display = gr.Image(label="Input (resized)")

    info_text = gr.Textbox(label="Info", interactive=False)

    # api_name also names the endpoint for API / MCP clients.
    run_btn.click(
        fn=extract_and_visualize,
        inputs=[input_img, out_res],
        outputs=[pca_out, input_display, info_text],
        api_name="extract_features",
    )

    # Example images are expected next to this script; with cache_mode="lazy"
    # each example's outputs are cached on first use rather than at startup.
    gr.Examples(
        examples=[
            ["city_with_cars.png", 112],
            ["fruit_store.png", 112],
        ],
        inputs=[input_img, out_res],
        outputs=[pca_out, input_display, info_text],
        fn=extract_and_visualize,
        cache_examples=True,
        cache_mode="lazy",
    )

# Also exposes the app's API as an MCP server.
demo.launch(mcp_server=True)