# Author: Rausda6 — commit 6788bd3 ("Update app.py", verified)
# app.py — DINOv3 two‑image patch similarity (click on Image 1 → show similarities on both images)
# Runs on CPU or CUDA. No external image URLs.
import os
from typing import Tuple
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
#from transformers import AutoModel # trust_remote_code=True
from transformers import AutoModel
# ============================
# Config
# ============================
# Backbones offered in the "Backbone" dropdown: a small default and a huge alternative.
# (ONNX exports onnx-community/dinov3-vits16-pretrain-lvd1689m-ONNX and
#  onnx-community/dinov3-vith16-pretrain-lvd1689m-ONNX are not wired up.)
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]

# Edge length (pixels) of one ViT patch; images are resized to multiples of it.
PATCH_SIZE = 16

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Per-channel (R, G, B) normalisation statistics applied before the forward pass.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Many DINOv3 HF ports put 1 [CLS] + 4 register tokens in front of the patch tokens.
N_SPECIAL_TOKENS = 5
# Colormap lookup that works on both new Matplotlib (>= 3.5, `matplotlib.colormaps`
# registry) and old Matplotlib (legacy `matplotlib.cm.get_cmap`, removed in 3.9).
try:
    from matplotlib import colormaps as _mpl_colormaps

    def _get_cmap(name: str):
        """Return the registered colormap called ``name`` (KeyError if unknown)."""
        return _mpl_colormaps[name]

except ImportError:  # Matplotlib < 3.5: no `colormaps` registry yet
    import matplotlib.cm as _cm

    def _get_cmap(name: str):
        """Return the colormap called ``name`` via the legacy ``cm.get_cmap`` API."""
        return _cm.get_cmap(name)
# ============================
# Model loading / cache
# ============================
# model_id -> loaded backbone; populated on demand by get_model().
_model_cache: dict = {}
# `_current_model_id` names the backbone bound to `model`; `model` is the backbone
# extract_image_features() falls back to when no explicit model is passed.
# Both stay None until the startup load assigns them.
_current_model_id = model = None
def load_model_from_hubold(model_id: str):
    """Legacy alias of ``load_model_from_hub``.

    The original body was a line-for-line copy of ``load_model_from_hub``
    (same log lines, same ``AutoModel.from_pretrained`` call, same move to
    DEVICE + eval), so it simply delegates. Not called anywhere in this file.
    """
    return load_model_from_hub(model_id)
def load_model_from_hubold2(model_id: str):
    """Legacy loader: build a transformers "image-feature-extraction" pipeline.

    Not called anywhere in this file. It returns a pipeline object rather than a
    bare model, so it is presumably not a drop-in replacement for the model that
    ``extract_image_features`` expects — verify before wiring it up.

    Args:
        model_id: Hugging Face Hub repository id.

    Returns:
        The pipeline, placed on GPU 0 when DEVICE is "cuda", else on CPU (-1).
    """
    # Imported locally: the module only imports `AutoModel` from transformers, so
    # without this line every call raised NameError on `pipeline`.
    from transformers import pipeline

    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    extractor = pipeline(
        "image-feature-extraction",
        model=model_id,
        token=token,
        trust_remote_code=True,
        device=0 if DEVICE == "cuda" else -1,
    )
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return extractor
def load_model_from_hub(model_id: str):
    """Load ``model_id`` with ``AutoModel.from_pretrained`` and return it on DEVICE in eval mode.

    An access token is taken from the ``HF_TOKEN`` environment variable if set
    (``None`` otherwise). ``trust_remote_code=True`` allows the repository's own
    modelling code to execute, so only pass ids you trust.
    """
    print(f"Loading model '{model_id}' from HF Hub…")
    backbone = AutoModel.from_pretrained(
        model_id,
        token=os.environ.get("HF_TOKEN"),
        trust_remote_code=True,
    )
    backbone.to(DEVICE)  # nn.Module.to moves parameters in place
    backbone.eval()      # inference only: disable dropout etc.
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return backbone
def get_model(model_id: str):
    """Return the backbone for ``model_id``, loading it on first request and memoising it."""
    try:
        return _model_cache[model_id]
    except KeyError:
        pass
    loaded = _model_cache[model_id] = load_model_from_hub(model_id)
    return loaded
# Load the default backbone at import time; later requests reuse it via the cache.
model, _current_model_id = get_model(DEFAULT_MODEL_ID), DEFAULT_MODEL_ID
# ============================
# Helpers
# ============================
def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
    """Resize ``img`` so both sides are multiples of ``patch`` with long side ≈ ``long_side``.

    The image is scaled (aspect ratio kept) so its longer side equals
    ``long_side``. Each side is then rounded *up* to the next multiple of
    ``patch`` (minimum one patch), and the image is resized to exactly that
    size. No padding is added. The aspect ratio can therefore shift by up to
    one patch per side, and the long side can slightly exceed ``long_side``
    when ``long_side`` is not itself a multiple of ``patch``.

    Args:
        img: input PIL image of any mode (converted to RGB).
        long_side: target length of the longer side, in pixels.
        patch: grid cell size; output H and W are multiples of it.

    Returns:
        (3, H, W) float tensor with values in [0, 1].

    Raises:
        ValueError: if the image has zero width or height.
    """
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError(f"Cannot resize an empty image of size {w}x{h}.")
    scale = long_side / max(h, w)
    new_h = max(patch, int(round(h * scale)))
    new_w = max(patch, int(round(w * scale)))
    # Round up so every pixel row/column belongs to a full patch.
    new_h = ((new_h + patch - 1) // patch) * patch
    new_w = ((new_w + patch - 1) // patch) * patch
    return TF.to_tensor(TF.resize(img.convert("RGB"), (new_h, new_w)))
def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
    """Min-max normalise a 2-D map and render it with a Matplotlib colormap.

    Non-finite entries (NaN / ±inf) are excluded from the min/max and drawn at
    the bottom of the colormap. Without this, a single NaN makes ``min()``/``max()``
    NaN, the whole normalised map becomes NaN, and the colormap paints every
    pixel in its "bad" colour. That happened whenever the click exclusion mask
    was used.

    Args:
        sim_map_up: 2-D array of scores.
        cmap_name: Matplotlib colormap name.

    Returns:
        RGB PIL image with the same height/width as ``sim_map_up``.
    """
    x = np.asarray(sim_map_up, dtype=np.float32)
    finite = np.isfinite(x)
    if finite.any():
        lo = x[finite].min()
        hi = x[finite].max()
        # The epsilon avoids division by zero for constant maps.
        x = np.where(finite, (x - lo) / (hi - lo + 1e-6), 0.0).astype(np.float32)
    else:
        x = np.zeros_like(x)
    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)
def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    """Composite ``heat`` over ``base`` at uniform opacity ``alpha`` and return an RGB image.

    Both images must have the same size (required by ``Image.alpha_composite``).
    Neither input is modified; ``convert`` hands back new images.
    """
    overlay = heat.convert("RGBA")
    overlay.putalpha(int(255 * alpha))  # constant alpha layer over the whole heatmap
    return Image.alpha_composite(base.convert("RGBA"), overlay).convert("RGB")
def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
    """Return a copy of ``img`` with a red "+" marker centred on (x, y).

    ``radius`` is the half-length of each arm in pixels. It defaults to half a
    patch (at least 2).
    """
    half = max(2, PATCH_SIZE // 2) if radius is None else radius
    marked = img.copy()
    pen = ImageDraw.Draw(marked)
    arms = (
        ((x - half, y), (x + half, y)),  # horizontal
        ((x, y - half), (x, y + half)),  # vertical
    )
    for arm in arms:
        pen.line(list(arm), fill="red", width=3)
    return marked
# ============================
# Feature extraction
# ============================
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    """Run the backbone on one image and return unit-norm patch embeddings.

    Args:
        image_pil: input image (converted to RGB and resized by ``resize_to_grid``).
        target_long_side: long-side size passed to ``resize_to_grid``.
        mdl: backbone to use. Falls back to the module-level ``model`` when None.

    Returns:
        dict with
          ``X``  - (Hp*Wp, D) tensor of L2-normalised patch embeddings, row-major
                   over the patch grid, on DEVICE;
          ``Hp``/``Wp`` - patch-grid height/width;
          ``img`` - the resized RGB PIL image the grid corresponds to.

    Raises:
        RuntimeError: if the model returns fewer tokens than there are patches.
    """
    # Explicit None check: `mdl or model` relied on the module's truthiness.
    mdl = model if mdl is None else mdl
    pixels = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    batch = TF.normalize(pixels, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, H, W = batch.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    n_patches = Hp * Wp

    tokens = mdl(batch).last_hidden_state.squeeze(0)  # (n_tokens, D)
    n_special = tokens.shape[0] - n_patches
    if n_special < 0:
        raise RuntimeError(
            f"Model returned {tokens.shape[0]} tokens for a {Hp}x{Wp} patch grid; "
            "expected at least one token per patch."
        )
    # Special tokens ([CLS] + registers) come before the patch tokens. Their count
    # is derived from the sequence length instead of trusting N_SPECIAL_TOKENS,
    # so the slice stays aligned with the Hp x Wp grid for backbones with a
    # different number of register tokens.
    patch_emb = tokens[n_special:, :]
    X = F.normalize(patch_emb, p=2, dim=-1)  # unit vectors -> dot product = cosine
    img_resized = TF.to_pil_image(pixels)
    return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
# ============================
# Similarity utilities
# ============================
def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    """Map a pixel position to the (row, col) of the patch containing it, clamped to the grid."""
    def _clamp(value, upper):
        return int(min(max(value, 0), upper))

    return _clamp(y_pix // PATCH_SIZE, Hp - 1), _clamp(x_pix // PATCH_SIZE, Wp - 1)
@torch.inference_mode()
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int):
    """Score every patch embedding against ``q_vec`` and upsample the scores.

    Args:
        X: (Hp*Wp, D) patch embeddings, row-major over the patch grid.
        Hp, Wp: patch-grid height/width.
        q_vec: (D,) query embedding.
        img_h, img_w: output size of the upsampled map.

    Returns:
        (grid, upsampled): the (Hp, Wp) score tensor and an (img_h, img_w)
        NumPy array obtained by bicubic interpolation.
    """
    grid = (X @ q_vec).reshape(Hp, Wp)
    upsampled = F.interpolate(
        grid[None, None],
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    return grid, upsampled.squeeze().detach().cpu().numpy()
# ============================
# Core: click on image 1 → heatmaps on image 1 and image 2
# ============================
def _mask_around(sim_map: torch.Tensor, row: int, col: int, radius: int) -> torch.Tensor:
    """Suppress the square patch neighbourhood (Chebyshev ``radius``) around (row, col).

    Suppressed cells are set to the minimum of the remaining cells, which is a
    finite value, rather than -inf. A -inf turns into ±inf/NaN under bicubic
    upsampling and then blanks the whole heatmap during min/max normalisation.
    If every cell is suppressed, the map's own minimum is used.
    """
    Hp, Wp = sim_map.shape
    rows = torch.arange(Hp, device=sim_map.device).unsqueeze(1)
    cols = torch.arange(Wp, device=sim_map.device).unsqueeze(0)
    mask = ((rows - row).abs() <= radius) & ((cols - col).abs() <= radius)
    keep = ~mask
    floor = sim_map[keep].min() if bool(keep.any()) else sim_map.min()
    return sim_map.masked_fill(mask, float(floor))


def _upsample_to(sim_map: torch.Tensor, height: int, width: int) -> np.ndarray:
    """Bicubically upsample a (Hp, Wp) map to (height, width) and return it as NumPy."""
    return F.interpolate(
        sim_map[None, None],
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()


@torch.inference_mode()
def click_two_image_similarity(state1: dict, state2: dict, click_xy: Tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    """Render similarity to the clicked Image-1 patch on both images.

    Args:
        state1, state2: feature dicts from ``extract_image_features``
            (keys ``X``, ``Hp``, ``Wp``, ``img``).
        click_xy: (x, y) pixel position in the frame of ``state1["img"]``
            (the resized image the features were computed on).
        exclude_radius_patches: radius, in patches, of the neighbourhood
            around the click to suppress on Image 1. 0 disables it.
        alpha: overlay opacity passed to ``blend``.
        cmap_name: Matplotlib colormap name.

    Returns:
        (marked1, heat1, overlay1, heat2, overlay2, max_sim2), or six ``None``
        when either state is missing. ``max_sim2`` is the highest patch-level
        cosine similarity on Image 2. It is taken from the patch grid rather
        than the bicubic upsample, which can overshoot the true maximum.
    """
    if not state1 or not state2:
        return (None,) * 6
    X1, Hp1, Wp1, img1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    X2, Hp2, Wp2, img2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]
    img1_w, img1_h = img1.size
    img2_w, img2_h = img2.size

    # Query vector from the clicked patch on image 1 (X is row-major over the grid).
    row, col = row_col_from_xy(click_xy[0], click_xy[1], Hp1, Wp1)
    q = X1[row * Wp1 + col]  # (D,)

    # Image 1: self-similarity, optionally hiding the trivially similar neighbourhood.
    sim_map1 = (X1 @ q).view(Hp1, Wp1)
    if exclude_radius_patches > 0:
        sim_map1 = _mask_around(sim_map1, row, col, exclude_radius_patches)
    heat1 = colorize(_upsample_to(sim_map1, img1_h, img1_w), cmap_name)
    overlay1 = blend(img1, heat1, alpha)
    marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)

    # Image 2: cross-image similarity to the same query.
    sim_map2 = (X2 @ q).view(Hp2, Wp2)
    heat2 = colorize(_upsample_to(sim_map2, img2_h, img2_w), cmap_name)
    overlay2 = blend(img2, heat2, alpha)
    return marked1, heat1, overlay1, heat2, overlay2, float(sim_map2.max())
# ============================
# Gradio UI
# ============================
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
    gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
    gr.Markdown("Upload two images and press **Process both**. Then click on **Image 1** to see similar regions on **both** images.")
    # Per-session feature dicts filled by "Process both" (see extract_image_features).
    state1 = gr.State()
    state2 = gr.State()

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
            target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution (long side)")
            alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
            cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches) for Image 1")
            start_btn = gr.Button("▶️ Process both", variant="primary")
        with gr.Column():
            img1 = gr.Image(label="Image 1 (clickable)", type="pil", sources=["upload", "clipboard"], value=None)
            img2 = gr.Image(label="Image 2", type="pil", sources=["upload", "clipboard"], value=None)

    with gr.Row():
        with gr.Column():
            marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
            heat1 = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
            overlay1 = gr.Image(label="Image 1 — overlay", interactive=False)
        with gr.Column():
            heat2 = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
            overlay2 = gr.Image(label="Image 2 — overlay", interactive=False)
            score2 = gr.Number(label="Image 2 — max similarity score", precision=6)

    # Utilities
    def _ensure_model(model_id: str):
        """Make ``model_id`` the active global backbone, loading it (cached) if it changed."""
        global model, _current_model_id
        if model_id != _current_model_id:
            model = get_model(model_id)
            _current_model_id = model_id

    # Process button → extract features for both images and store in state
    def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
        """Extract features for both images and reset results from any previous run."""
        if im1 is None or im2 is None:
            raise gr.Error("Please provide both images before processing.")
        _ensure_model(model_id)
        progress(0, desc="Extracting features for Image 1…")
        st1 = extract_image_features(im1, int(long_side), mdl=model)
        progress(0.5, desc="Extracting features for Image 2…")
        st2 = extract_image_features(im2, int(long_side), mdl=model)
        progress(1, desc="Done")
        # Show the resized inputs as previews and clear heatmaps/score, which
        # otherwise keep showing results for the previously processed images.
        return st1["img"], None, None, None, st2["img"], None, st1, st2

    start_btn.click(
        _run_both,
        inputs=[img1, img2, target_long_side, model_choice],
        outputs=[marked1, heat1, overlay1, heat2, overlay2, score2, state1, state2],
    )

    # Clicking on Image 1 → compute similarities on both images
    def _on_click(st1, st2, a: float, m: str, excl: int, im1: Image.Image | None, evt: gr.SelectData):
        """Translate the click into the resized feature frame and render both heatmaps."""
        if not st1 or not st2 or evt is None:
            return (None,) * 6
        x, y = evt.index[0], evt.index[1]
        # NOTE: evt.index is reported in pixels of the image held by `img1` (the
        # original upload), but features were extracted on the resized copy
        # st1["img"]. Rescale so the queried patch is the one under the cursor.
        shown = st1["img"]
        if im1 is not None and im1.width > 0 and im1.height > 0:
            x = x * shown.width / im1.width
            y = y * shown.height / im1.height
        return click_two_image_similarity(
            st1, st2,
            click_xy=(int(x), int(y)),
            exclude_radius_patches=int(excl),
            alpha=float(a), cmap_name=m,
        )

    img1.select(
        _on_click,
        inputs=[state1, state2, alpha, cmap, exclude_r, img1],
        outputs=[marked1, heat1, overlay1, heat2, overlay2, score2],
    )

if __name__ == "__main__":
    demo.launch()