# vit-up / app.py
# Uploaded by multimodalart (HF Staff) with huggingface_hub.
# Commit 8e0c227 (verified), 21 kB.
"""ViT-Up: Faithful Feature Upsampling for Vision Transformers.
Interactive demo that loads the ViT-Up feature upsampler, extracts dense
features from an input image at a user-selected output resolution, and
visualises them via a 3-component PCA projection to RGB.
The DINOv3 backbone checkpoint on Hugging Face is gated, so this demo
loads the equivalent pretrained weights from the non-gated timm mirror
(`timm/vit_small_plus_patch16_dinov3.lvd1689m`) and maps them into the
same ``DINOv3ViT`` module structure the ViT-Up code expects. The ViT-Up
LoRA adapters and upsampler head are then loaded from ``Krispin/vit-up``.
"""
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # MUST come before torch / any CUDA-touching import
import sys
import math
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageOps
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
import torchvision.transforms.v2 as T
# ---------------------------------------------------------------------------
# Config constants — DINOv3-S+ variant
# ---------------------------------------------------------------------------
# Hub locations: the non-gated timm mirror of the backbone, and the ViT-Up
# checkpoint that holds the LoRA adapters plus the upsampler head.
BACKBONE_TIMM_REPO = "timm/vit_small_plus_patch16_dinov3.lvd1689m"
VITUP_WEIGHTS_REPO = "Krispin/vit-up"
VITUP_WEIGHTS_FILE = "vit_up_dinov3_splus.safetensors"
# Backbone architecture values; these are fed into DINOv3ViTConfig below.
HIDDEN_SIZE = 384
NUM_LAYERS = 12
NUM_HEADS = 6
INTERMEDIATE_SIZE = 1536
PATCH_SIZE = 16
NUM_REGISTER_TOKENS = 4
# Side length the (square-padded) input is resized to before the backbone.
IMAGE_SIZE = 448
# Passed to ViTUp as ``layer_indices``.
# NOTE(review): presumably index 0 is the embedding output and 12 the last
# transformer layer — confirm against the vit_up package.
LAYER_INDICES = [0, 2, 4, 6, 8, 10, 12]
# Per-channel normalisation statistics used by the input transform.
RESNET_MEAN = torch.tensor([0.485, 0.456, 0.406])
RESNET_STD = torch.tensor([0.229, 0.224, 0.225])
# vit_up package is included in the Space repo root
from transformers import DINOv3ViTConfig
from vit_up.layers.backbones.dinov3_vit import DINOv3ViT
from vit_up.model.vit_up import ViTUp
from vit_up.utils.state_dict_migration import migrate_vit_up_state_dict_keys
from peft import LoraConfig, get_peft_model
# ---------------------------------------------------------------------------
# Weight mapping: timm -> DINOv3ViT
# ---------------------------------------------------------------------------
def _map_timm_to_dinov3(timm_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert a timm ViT state-dict to the DINOv3ViT module key names."""
mapped: Dict[str, torch.Tensor] = {}
for key, val in timm_sd.items():
if key == "cls_token":
mapped["embeddings.cls_token"] = val
elif key == "reg_token":
mapped["embeddings.register_tokens"] = val
elif key == "patch_embed.proj.weight":
mapped["embeddings.patch_embeddings.weight"] = val
elif key == "patch_embed.proj.bias":
mapped["embeddings.patch_embeddings.bias"] = val
elif key.startswith("blocks.") and key.endswith(".attn.qkv.weight"):
idx = int(key.split(".")[1])
qkv = val # (3*hidden, hidden) but timm uses fused qkv
q, k, v = qkv.chunk(3, dim=0)
mapped[f"layer.{idx}.attention.q_proj.weight"] = q
mapped[f"layer.{idx}.attention.k_proj.weight"] = k
mapped[f"layer.{idx}.attention.v_proj.weight"] = v
elif key.startswith("blocks.") and key.endswith(".attn.qkv.bias"):
idx = int(key.split(".")[1])
qkv = val
if val is not None and val.numel() > 0:
q, k, v = qkv.chunk(3, dim=0)
mapped[f"layer.{idx}.attention.q_proj.bias"] = q
mapped[f"layer.{idx}.attention.k_proj.bias"] = k
mapped[f"layer.{idx}.attention.v_proj.bias"] = v
elif key.startswith("blocks.") and ".attn.proj." in key:
idx = int(key.split(".")[1])
suffix = key.split(".attn.proj.")[-1] # weight or bias
mapped[f"layer.{idx}.attention.o_proj.{suffix}"] = val
elif key.startswith("blocks.") and ".norm1." in key:
idx = int(key.split(".")[1])
suffix = key.split(".norm1.")[-1]
mapped[f"layer.{idx}.norm1.{suffix}"] = val
elif key.startswith("blocks.") and ".norm2." in key:
idx = int(key.split(".")[1])
suffix = key.split(".norm2.")[-1]
mapped[f"layer.{idx}.norm2.{suffix}"] = val
elif key.startswith("blocks.") and ".mlp.fc1_g." in key:
idx = int(key.split(".")[1])
suffix = key.split(".mlp.fc1_g.")[-1]
mapped[f"layer.{idx}.mlp.gate_proj.{suffix}"] = val
elif key.startswith("blocks.") and ".mlp.fc1_x." in key:
idx = int(key.split(".")[1])
suffix = key.split(".mlp.fc1_x.")[-1]
mapped[f"layer.{idx}.mlp.up_proj.{suffix}"] = val
elif key.startswith("blocks.") and ".mlp.fc2." in key:
idx = int(key.split(".")[1])
suffix = key.split(".mlp.fc2.")[-1]
mapped[f"layer.{idx}.mlp.down_proj.{suffix}"] = val
elif key.startswith("blocks.") and key.endswith(".gamma_1"):
idx = int(key.split(".")[1])
mapped[f"layer.{idx}.layer_scale1.lambda1"] = val
elif key.startswith("blocks.") and key.endswith(".gamma_2"):
idx = int(key.split(".")[1])
mapped[f"layer.{idx}.layer_scale2.lambda1"] = val
elif key == "norm.weight":
mapped["norm.weight"] = val
elif key == "norm.bias":
mapped["norm.bias"] = val
# pos_embed is handled by RoPE — skip
return mapped
# ---------------------------------------------------------------------------
# Build the backbone from config + timm weights + LoRA
# ---------------------------------------------------------------------------
def _build_backbone(device: str, dtype: torch.dtype) -> DINOv3ViT:
    """Create the DINOv3 backbone, fill it from the timm mirror and wrap it in LoRA.

    Args:
        device: Device the finished module is moved to.
        dtype: Parameter dtype the finished module is cast to.

    Returns:
        The PEFT-wrapped backbone in eval mode. The LoRA adapter weights
        are freshly initialised here, not loaded.
    """
    vit_config = DINOv3ViTConfig(
        hidden_size=HIDDEN_SIZE,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        intermediate_size=INTERMEDIATE_SIZE,
        patch_size=PATCH_SIZE,
        image_size=IMAGE_SIZE,
        num_register_tokens=NUM_REGISTER_TOKENS,
        use_gated_mlp=True,
        layerscale_value=1e-5,
        query_bias=False,
        key_bias=False,
        value_bias=False,
        proj_bias=True,
        mlp_bias=True,
    )
    vit_config._attn_implementation = "eager"
    model = DINOv3ViT(vit_config)

    # Fetch the timm checkpoint and translate its keys to this module's names.
    checkpoint_file = hf_hub_download(BACKBONE_TIMM_REPO, "model.safetensors")
    converted = _map_timm_to_dinov3(load_safetensors(checkpoint_file, device="cpu"))
    missing_keys, _ = model.load_state_dict(converted, strict=False)

    # embeddings.mask_token is not part of the timm weights, so its absence
    # is expected and not reported.
    absent = [name for name in missing_keys if "mask_token" not in name]
    if absent:
        print(f"[WARNING] Missing backbone keys after timm load: {absent[:10]}")
    print(f"[INFO] Loaded backbone from timm: {len(converted)} tensors mapped")

    # Attach LoRA adapters to the patch embedding and attention projections.
    adapted_modules = ["patch_embeddings", "q_proj", "k_proj", "v_proj", "o_proj"]
    adapter_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=adapted_modules,
    )
    model = get_peft_model(model, adapter_config)
    return model.to(device=device, dtype=dtype).eval()
# ---------------------------------------------------------------------------
# Build the ViT-Up model from config
# ---------------------------------------------------------------------------
def _build_vit_up(device: str, dtype: torch.dtype) -> ViTUp:
    """Assemble the ViTUp upsampler from the hyper-parameters hard-coded below.

    The original note says this mirrors the config tree of the upstream
    repository. Sub-modules are constructed in a fixed order and then
    handed to ``ViTUp``; no weights are loaded here.

    Args:
        device: Device the finished module is moved to.
        dtype: Parameter dtype the finished module is cast to.

    Returns:
        The ViTUp module in eval mode.
    """
    from vit_up.layers.query_encoder import QueryEncoder
    from vit_up.layers.pos_enc import FourierPositionalEncoding
    from vit_up.layers.continuous_rope import ContinuousRoPE2D
    from vit_up.layers.smart_module_list import SmartModuleList
    from vit_up.layers.mlp import SimpleMLP
    from vit_up.layers.cross_attention import CrossAttention
    from vit_up.layers.film import SimpleFiLMV2

    width = HIDDEN_SIZE

    # Query-side components.
    pixel_query_encoder = QueryEncoder(
        layer_index=0,
        img_in_size=3584,
        window_size=0,
        out_proj_module=None,
    )
    fourier_rel_pos = FourierPositionalEncoding(num_bands=16, max_resolution=10.0)
    query_rope = ContinuousRoPE2D(dim=64, base=100.0, scale=2 * math.pi)

    # Pieces of one ViTUpBlock, created in the same order as before so any
    # random initialisation consumes the RNG identically.
    transition = SimpleMLP(
        dims=[width, width * 2, width],
        activation="gelu",
        input_layernorm=True,
        use_residual=True,
    )
    cross_attn = CrossAttention(
        dim=width,
        num_heads=NUM_HEADS,
        cross_attn_window_size=32,
        qkv_bias=True,
        attn_dropout=0.0,
        proj_dropout=0.0,
    )
    film_norm = nn.LayerNorm(width)
    film_gamma_beta = SimpleMLP(
        dims=[66, width, width * 2],
        activation="gelu",
        zero_init_last=True,
    )
    film_post = SimpleMLP(
        dims=[width, width * 4, width],
        activation="gelu",
        input_layernorm=True,
        use_residual=False,
    )
    film = SimpleFiLMV2(
        input_module=film_norm,
        gamma_beta_mlp=film_gamma_beta,
        post_mlp=film_post,
    )
    feed_forward = SimpleMLP(
        dims=[width, width * 4, width],
        activation="gelu",
    )

    upsampler_blocks = SmartModuleList(
        n_blocks=6,
        block_class_path="vit_up.model.vit_up.ViTUpBlock",
        block_init_args={
            "dim": width,
            "dim_h": width,
            "transition_mlp": transition,
            "cross_attention": cross_attn,
            "featx": film,
            "mlp": feed_forward,
        },
    )
    # Seven decoder MLPs, the same count as entries in LAYER_INDICES.
    per_layer_decoders = SmartModuleList(
        n_blocks=7,
        block_class_path="vit_up.layers.mlp.SimpleMLP",
        block_init_args={
            "input_layernorm": True,
            "dims": [width, width],
        },
    )

    model = ViTUp(
        layer_indices=LAYER_INDICES,
        query_embedding=pixel_query_encoder,
        rel_pos_enc=fourier_rel_pos,
        vit_up_blocks=upsampler_blocks,
        decoder_mlp=per_layer_decoders,
        q_rope_embeddings=query_rope,
    )
    return model.to(device=device, dtype=dtype).eval()
# ---------------------------------------------------------------------------
# Load ViT-Up + LoRA weights from the safetensors checkpoint
# ---------------------------------------------------------------------------
def _load_vit_up_weights(
    backbone: nn.Module,
    vit_up: ViTUp,
    device: str,
) -> None:
    """Load the combined LoRA + ViT-Up checkpoint into the two modules in place.

    Checkpoint entries whose key starts with ``backbone.`` go to
    ``backbone`` (prefix stripped); every other entry goes to ``vit_up``
    after key migration. Both loads are non-strict, so mismatches are only
    reported through the printed counts.

    Args:
        backbone: LoRA-wrapped backbone that receives the ``backbone.*`` tensors.
        vit_up: Upsampler that receives the remaining tensors.
        device: Not used by this function; tensors are read on the CPU.
    """
    checkpoint_file = hf_hub_download(VITUP_WEIGHTS_REPO, VITUP_WEIGHTS_FILE)
    checkpoint = load_safetensors(checkpoint_file, device="cpu")

    prefix = "backbone."
    lora_tensors: Dict[str, torch.Tensor] = {
        name.removeprefix(prefix): tensor
        for name, tensor in checkpoint.items()
        if name.startswith(prefix)
    }
    head_tensors: Dict[str, torch.Tensor] = {
        name: tensor
        for name, tensor in checkpoint.items()
        if not name.startswith(prefix)
    }

    # Backbone side (LoRA adapters).
    lora_missing, lora_unexpected = backbone.load_state_dict(lora_tensors, strict=False)
    print(f"[INFO] Loaded backbone LoRA: {len(lora_tensors)} tensors, "
          f"missing={len(lora_missing)}, unexpected={len(lora_unexpected)}")

    # Upsampler side; keys are migrated before loading.
    head_tensors_migrated = migrate_vit_up_state_dict_keys(head_tensors)
    head_missing, head_unexpected = vit_up.load_state_dict(head_tensors_migrated, strict=False)
    print(f"[INFO] Loaded ViT-Up: {len(head_tensors_migrated)} tensors, "
          f"missing={len(head_missing)}, unexpected={len(head_unexpected)}")
    if head_missing:
        print(f" Missing ViT-Up keys: {head_missing[:10]}")
# ---------------------------------------------------------------------------
# PCA utilities (from the repo's correspondence.py)
# ---------------------------------------------------------------------------
def _fit_pca(tokens_nc: torch.Tensor, k: int = 3) -> dict:
"""Fit a simple PCA on (N, C) feature tokens."""
tokens = tokens_nc.float()
mean = tokens.mean(dim=0)
centered = tokens - mean
_, singular_values, vh = torch.linalg.svd(centered, full_matrices=False)
components = vh[:k].T
projected = centered @ components
color_min = projected.amin(dim=0)
color_max = projected.amax(dim=0)
flat = torch.isclose(color_max, color_min)
color_max = torch.where(flat, color_min + 1.0, color_max)
return {
"pca_eig": components,
"pca_mean": mean,
"pca_color_min": color_min,
"pca_color_max": color_max,
}
def _apply_pca_rgb(feats_hwc: torch.Tensor, pca_data: dict) -> torch.Tensor:
h, w, c = feats_hwc.shape
tokens = feats_hwc.float().reshape(-1, c)
mean = pca_data["pca_mean"].to(device=tokens.device, dtype=tokens.dtype)
components = pca_data["pca_eig"].to(device=tokens.device, dtype=tokens.dtype)
color_min = pca_data["pca_color_min"].to(device=tokens.device, dtype=tokens.dtype)
color_max = pca_data["pca_color_max"].to(device=tokens.device, dtype=tokens.dtype)
projected = (tokens - mean) @ components
rgb = (projected - color_min.view(1, -1)) / (color_max - color_min).view(1, -1).add(1e-8)
rgb = rgb.clamp(0.0, 1.0).mul(255.0).to(torch.uint8)
return rgb.reshape(h, w, 3)
# ---------------------------------------------------------------------------
# Image utilities
# ---------------------------------------------------------------------------
def pad_image_to_square(img: Image.Image) -> Image.Image:
    """Pad an image with black borders so it becomes exactly square.

    The shorter side is padded up to the longer one. When the difference is
    odd, the extra pixel goes to the bottom/right edge, so the leading pad
    is always ``difference // 2`` and the result is exactly
    ``max(w, h)`` pixels on both sides. (A symmetric border would leave the
    image one pixel short of square for odd differences.)

    Args:
        img: Source image.

    Returns:
        ``img`` itself when it is already square, otherwise a new padded
        image.
    """
    w, h = img.size
    if w == h:
        return img
    side = max(w, h)
    pad_x = side - w
    pad_y = side - h
    left = pad_x // 2
    top = pad_y // 2
    # 4-tuple border is (left, top, right, bottom).
    border = (left, top, pad_x - left, pad_y - top)
    return ImageOps.expand(img, border=border, fill=0)
def crop_feature_square_to_image_aspect(
    feat_img: Image.Image,
    original_size: tuple,
) -> Image.Image:
    """Cut the padding back out of a square feature image.

    The feature image is assumed to cover the original image padded to a
    ``max(width, height)`` square with ``difference // 2`` pixels of
    leading padding on the shorter axis. The padded region is scaled to
    the feature image's resolution and cropped away.

    Args:
        feat_img: Square-domain feature visualisation.
        original_size: ``(width, height)`` of the image before padding.

    Returns:
        The crop of ``feat_img`` that corresponds to the original image.
        The crop is always at least one pixel wide and high, even when the
        original is so elongated that both edges round to the same index.
    """
    width, height = original_size
    max_size = max(width, height)
    # Leading padding that was added on each axis (zero on the longer one).
    px = max(height - width, 0) // 2
    py = max(width - height, 0) // 2

    scale_x = feat_img.width / max_size
    scale_y = feat_img.height / max_size
    left = int(round(px * scale_x))
    top = int(round(py * scale_y))
    right = int(round((px + width) * scale_x))
    bottom = int(round((py + height) * scale_y))

    # Rounding can collapse a very thin strip to zero pixels; keep one.
    right = max(right, left + 1)
    bottom = max(bottom, top + 1)
    return feat_img.crop((left, top, right, bottom))
# ---------------------------------------------------------------------------
# Build the full model at module scope
# ---------------------------------------------------------------------------
# Everything below runs once at import time; the resulting ``backbone`` and
# ``vit_up`` globals are what extract_and_visualize() uses on each request.
print("[INFO] Building ViT-Up model...")
DEVICE = "cuda"
DTYPE = torch.bfloat16
backbone = _build_backbone(DEVICE, DTYPE)
vit_up = _build_vit_up(DEVICE, DTYPE)
# Overwrites the fresh LoRA / upsampler parameters with the checkpoint values.
_load_vit_up_weights(backbone, vit_up, DEVICE)
# Both builders already return modules in eval mode; it is re-applied here
# after the weights have been loaded.
backbone = backbone.eval()
vit_up = vit_up.eval()
print("[INFO] Model ready.")
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _prepare_image(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised (1, 3, H, W) batch on ``DEVICE``.

    The image is converted to RGB, padded to a square, resized to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``, scaled to float32 and normalised with
    ``RESNET_MEAN`` / ``RESNET_STD``.
    """
    squared = pad_image_to_square(img.convert("RGB"))
    steps = [
        T.ToImage(),
        T.Resize(
            (IMAGE_SIZE, IMAGE_SIZE),
            interpolation=T.InterpolationMode.BILINEAR,
            antialias=True,
        ),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=RESNET_MEAN, std=RESNET_STD),
    ]
    pipeline = T.Compose(steps)
    batch = pipeline(squared).unsqueeze(0)  # add the batch dimension
    return batch.to(DEVICE)
def _compute_query_coords(out_size: int) -> torch.Tensor:
coords = torch.linspace(0.5, out_size - 0.5, out_size) / out_size
grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij")
return torch.stack((grid_x, grid_y), dim=-1).reshape(1, -1, 2)
@spaces.GPU(duration=120)
def extract_and_visualize(
    input_image: Image.Image,
    output_resolution: int,
) -> tuple[Image.Image, Image.Image, str]:
    """Extract dense ViT-Up features and visualise them via PCA.

    Args:
        input_image: Input PIL image.
        output_resolution: Output feature map resolution (pixels per side).

    Returns:
        Tuple of (pca_visualization, input_resized, info_text). Both images
        are None when no image was given or the resolution is not positive;
        info_text then explains why.
    """
    if input_image is None:
        return None, None, "Please provide an input image."
    out_size = int(output_resolution)
    if out_size < 1:
        # Reachable through the API/MCP endpoint, which bypasses the slider
        # limits; an empty query grid would otherwise fail in torch.cat.
        return None, None, "Output resolution must be at least 1."
    orig_w, orig_h = input_image.size

    # Prepare input
    pixel_values = _prepare_image(input_image)

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        # Compute cache data (backbone hidden states) once per image.
        cache_data = vit_up.compute_cache_data(
            pixel_values=pixel_values,
            backbone=backbone,
            hidden_layer_img_size=IMAGE_SIZE,
        )
        # Query coords for dense output
        query_coords = _compute_query_coords(out_size).to(DEVICE, dtype=DTYPE)
        # Query the upsampler in fixed-size chunks to bound peak memory.
        chunk_size = 4096
        q_chunks = []
        for q_start in range(0, query_coords.shape[1], chunk_size):
            q_end = min(q_start + chunk_size, query_coords.shape[1])
            q_chunk = vit_up(
                pixel_values=None,
                q_xy_normalized=query_coords[:, q_start:q_end, :],
                cache_data=cache_data,
            )
            q_chunks.append(q_chunk[-1])  # final layer
        features = torch.cat(q_chunks, dim=1)  # (1, out_size*out_size, D)

    features_hwc = features[0].reshape(out_size, out_size, -1).float().cpu()

    # PCA: fit on this image's own tokens, then colour the same tokens.
    pca_data = _fit_pca(features_hwc.reshape(-1, features_hwc.shape[-1]), k=3)
    pca_rgb = _apply_pca_rgb(features_hwc, pca_data)
    pca_img = Image.fromarray(pca_rgb.numpy().astype(np.uint8), mode="RGB")

    # Crop to original aspect ratio
    pca_img = crop_feature_square_to_image_aspect(pca_img, (orig_w, orig_h))

    # Resize for display; the longer side is capped at max_display pixels.
    display_w, display_h = orig_w, orig_h
    max_display = 512
    if max(display_w, display_h) > max_display:
        scale = max_display / max(display_w, display_h)
        # Truncation can reach 0 for extreme aspect ratios, which resize()
        # rejects, so keep each side at least one pixel.
        display_w = max(1, int(display_w * scale))
        display_h = max(1, int(display_h * scale))
    pca_display = pca_img.resize((display_w, display_h), Image.Resampling.NEAREST)

    # Also create a resized input for side-by-side comparison
    input_display = input_image.convert("RGB").resize(
        (display_w, display_h), Image.Resampling.LANCZOS
    )

    info = (f"Feature dim: {features_hwc.shape[-1]} | "
            f"Output resolution: {out_size}x{out_size} | "
            f"Total query points: {out_size * out_size}")
    return pca_display, input_display, info
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Custom styles passed to gr.Blocks below.
# NOTE(review): no component in this section sets elem_id="col-container",
# so the first rule may currently match nothing — confirm.
CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""
with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
    # Header: title, short description and external links.
    gr.Markdown("# ViT-Up: Faithful Feature Upsampling for Vision Transformers")
    gr.Markdown(
        "Upload an image to extract dense DINOv3 features at arbitrary resolution "
        "via the ViT-Up feature upsampler. The PCA visualization shows the "
        "3 principal components of the upsampled feature map as RGB."
    )
    gr.Markdown(
        "[Paper](https://huggingface.co/papers/2606.14024) | "
        "[GitHub](https://github.com/krispinwandel/vit-up) | "
        "[Model Weights](https://huggingface.co/Krispin/vit-up)"
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            with gr.Accordion("Advanced settings", open=False):
                out_res = gr.Slider(
                    label="Output resolution (pixels per side)",
                    minimum=28,
                    maximum=224,
                    value=112,
                    step=28,
                )
            run_btn = gr.Button("Extract Features", variant="primary")
        # Right column: outputs, in the order extract_and_visualize returns them.
        with gr.Column():
            pca_out = gr.Image(label="PCA Feature Visualization")
            input_display = gr.Image(label="Input (resized)")
            info_text = gr.Textbox(label="Info", interactive=False)
    # The button runs the extraction; api_name exposes the same handler
    # under the name "extract_features".
    run_btn.click(
        fn=extract_and_visualize,
        inputs=[input_img, out_res],
        outputs=[pca_out, input_display, info_text],
        api_name="extract_features",
    )
    # Example rows are (image file, output resolution) pairs; the image
    # files are referenced by relative path.
    gr.Examples(
        examples=[
            ["city_with_cars.png", 112],
            ["fruit_store.png", 112],
        ],
        inputs=[input_img, out_res],
        outputs=[pca_out, input_display, info_text],
        fn=extract_and_visualize,
        cache_examples=True,
        cache_mode="lazy",
    )
demo.launch(mcp_server=True)