# app.py — DINOv3 two‑image patch similarity (click on Image 1 → show similarities on both images)
# Runs on CPU or CUDA. No external image URLs.
import os
from typing import Tuple
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
#from transformers import AutoModel # trust_remote_code=True
from transformers import AutoModel, pipeline  # pipeline is needed by the legacy loader below
# ============================
# Config
# ============================
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
#DEFAULT_MODEL_ID = "onnx-community/dinov3-vits16-pretrain-lvd1689m-ONNX"
#ALT_MODEL_ID = "onnx-community/dinov3-vith16-pretrain-lvd1689m-ONNX"
AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
PATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Many DINOv3 HF ports expose 1 [CLS] + 4 registers at the front
N_SPECIAL_TOKENS = 5
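# Illustrative token layout, assuming the 5-token prefix above (values are an
# example, not from this file): a padded 768x512 input yields
# (768/16) * (512/16) = 48 * 32 = 1536 patch tokens, so last_hidden_state holds
# 5 + 1536 = 1541 tokens and we keep only the trailing 1536.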
# Robust colormap import (Matplotlib new/old)
try:
    from matplotlib import colormaps as _mpl_colormaps

    def _get_cmap(name: str):
        return _mpl_colormaps[name]
except Exception:
    import matplotlib.cm as _cm

    def _get_cmap(name: str):
        return _cm.get_cmap(name)
# ============================
# Model loading / cache
# ============================
_model_cache = {}
_current_model_id = None
model = None
# Legacy loader (kept for reference; not called anywhere below)
def load_model_from_hubold(model_id: str):
    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
    mdl.to(DEVICE).eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return mdl
# Legacy pipeline-based loader (also unused); relies on the `pipeline` import above
def load_model_from_hubold2(model_id: str):
    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    # Use pipeline instead of AutoModel
    extractor = pipeline(
        "image-feature-extraction",
        model=model_id,
        token=token,
        trust_remote_code=True,
        device=0 if DEVICE == "cuda" else -1,
    )
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return extractor
def load_model_from_hub(model_id: str):
    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    mdl = AutoModel.from_pretrained(
        model_id,
        token=token,
        trust_remote_code=True,
    )
    mdl.to(DEVICE).eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return mdl
def get_model(model_id: str):
    if model_id in _model_cache:
        return _model_cache[model_id]
    mdl = load_model_from_hub(model_id)
    _model_cache[model_id] = mdl
    return mdl
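# Note: the cache keeps every loaded backbone resident on DEVICE; switching models
# in the UI does not evict the previously loaded one.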
# Load default at startup
model = get_model(DEFAULT_MODEL_ID)
_current_model_id = DEFAULT_MODEL_ID
# ============================
# Helpers
# ============================
def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
    """Resize so max(h, w) == long_side (aspect roughly kept), then round each
    side up to a multiple of `patch`. Returns a CHW float tensor in [0, 1]."""
    w, h = img.size
    scale = long_side / max(h, w)
    new_h = max(patch, int(round(h * scale)))
    new_w = max(patch, int(round(w * scale)))
    new_h = ((new_h + patch - 1) // patch) * patch
    new_w = ((new_w + patch - 1) // patch) * patch
    return TF.to_tensor(TF.resize(img.convert("RGB"), (new_h, new_w)))
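# Worked example of the grid rounding above (values are an assumption for
# illustration): a 1000x750 image with long_side=768 resizes to 768x576 (w x h),
# both already multiples of PATCH_SIZE=16, giving an Hp x Wp grid of 36 x 48.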
def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
    """Min-max normalize a similarity map and render it through a colormap."""
    x = sim_map_up.astype(np.float32)
    x = (x - x.min()) / (x.max() - x.min() + 1e-6)
    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)


def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    """Alpha-composite the heatmap over the base image."""
    base = base.convert("RGBA")
    heat = heat.convert("RGBA")
    a = Image.new("L", heat.size, int(255 * alpha))
    heat.putalpha(a)
    out = Image.alpha_composite(base, heat)
    return out.convert("RGB")


def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
    """Mark the clicked pixel with a red crosshair."""
    r = radius if radius is not None else max(2, PATCH_SIZE // 2)
    out = img.copy()
    draw = ImageDraw.Draw(out)
    draw.line([(x - r, y), (x + r, y)], fill="red", width=3)
    draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
    return out
# ============================
# Feature extraction
# ============================
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    mdl = mdl or model
    t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    outputs = mdl(t_norm)
    patch_emb = outputs.last_hidden_state.squeeze(0)[N_SPECIAL_TOKENS:, :]  # skip special tokens
    X = F.normalize(patch_emb, p=2, dim=-1)  # (Hp*Wp, D), L2 norm for cosine
    img_resized = TF.to_pil_image(t)
    return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
# ============================
# Similarity utilities
# ============================
def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
    row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
    return row, col
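# Patch tokens are flattened row-major, so a (row, col) pair from the helper
# above maps to token index row * Wp + col in X, consistent with the
# .view(Hp, Wp) reshapes used below.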
@torch.inference_mode()
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int):
    sims = torch.matmul(X, q_vec)  # (Hp*Wp)
    sim_map = sims.view(Hp, Wp)
    sim_up = F.interpolate(
        sim_map.unsqueeze(0).unsqueeze(0),
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()
    return sim_map, sim_up
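# Minimal sketch of how these helpers compose outside the UI (hypothetical
# usage, not part of the original app flow):
#   st = extract_image_features(Image.open("query.jpg"), 768)
#   row, col = row_col_from_xy(120, 64, st["Hp"], st["Wp"])
#   q = st["X"][row * st["Wp"] + col]
#   w, h = st["img"].size
#   _, sim_up = similarity_map(st["X"], st["Hp"], st["Wp"], q, h, w)
#   blend(st["img"], colorize(sim_up)).save("overlay.jpg")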
# ============================
# Core: click on image 1 → heatmaps on image 1 and image 2
# ============================
def click_two_image_similarity(state1: dict, state2: dict, click_xy: Tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    if not state1 or not state2:
        return (None,) * 6
    X1, Hp1, Wp1, img1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    X2, Hp2, Wp2, img2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]
    img1_w, img1_h = img1.size
    img2_w, img2_h = img2.size

    # Query vector from clicked patch on image 1
    col = int(np.clip(click_xy[0] // PATCH_SIZE, 0, Wp1 - 1))
    row = int(np.clip(click_xy[1] // PATCH_SIZE, 0, Hp1 - 1))
    idx = row * Wp1 + col
    q = X1[idx]  # (D,)

    # Similarity on image 1 (+ small exclusion mask around click if requested)
    sims1 = torch.matmul(X1, q)
    sim_map1 = sims1.view(Hp1, Wp1)
    if exclude_radius_patches > 0:
        rr, cc = torch.meshgrid(
            torch.arange(Hp1, device=sims1.device),
            torch.arange(Wp1, device=sims1.device),
            indexing="ij",
        )
        mask1 = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
        # Fill with the finite minimum rather than -inf, so the bicubic
        # upsampling and the min-max normalisation in colorize() stay well
        # defined (with -inf the interpolation produces NaNs).
        sim_map1 = sim_map1.masked_fill(mask1, sim_map1.min().item())
    sim1_up = F.interpolate(
        sim_map1.unsqueeze(0).unsqueeze(0),
        size=(img1_h, img1_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()
    heat1 = colorize(sim1_up, cmap_name)
    overlay1 = blend(img1, heat1, alpha)
    marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)

    # Similarity on image 2
    sims2 = torch.matmul(X2, q)
    sim_map2 = sims2.view(Hp2, Wp2)
    sim2_up = F.interpolate(
        sim_map2.unsqueeze(0).unsqueeze(0),
        size=(img2_h, img2_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()
    heat2 = colorize(sim2_up, cmap_name)
    overlay2 = blend(img2, heat2, alpha)
    return marked1, heat1, overlay1, heat2, overlay2, float(sim2_up.max())
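# The 6-tuple above feeds, in order, the marked1 / heat1 / overlay1 / heat2 /
# overlay2 / score2 components wired up in the UI below.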
# ============================
# Gradio UI
# ============================
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
    gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
    gr.Markdown("Upload two images and press **Process both**. Then click on **Image 1** to see similar regions on **both** images.")
    state1 = gr.State()
    state2 = gr.State()

    with gr.Row():
        with gr.Column():
            model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
            target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution (long side)")
            alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
            cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches) for Image 1")
            start_btn = gr.Button("▶️ Process both", variant="primary")
        with gr.Column():
            img1 = gr.Image(label="Image 1 (clickable)", type="pil", sources=["upload", "clipboard"], value=None)
            img2 = gr.Image(label="Image 2", type="pil", sources=["upload", "clipboard"], value=None)

    with gr.Row():
        with gr.Column():
            marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
            heat1 = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
            overlay1 = gr.Image(label="Image 1 — overlay", interactive=False)
        with gr.Column():
            heat2 = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
            overlay2 = gr.Image(label="Image 2 — overlay", interactive=False)
            score2 = gr.Number(label="Image 2 — max similarity score", precision=6)
    # Utilities
    def _ensure_model(model_id: str):
        global model, _current_model_id
        if model_id != _current_model_id:
            model = get_model(model_id)
            _current_model_id = model_id
    # Process button → extract features for both images and store them in state
    def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
        if im1 is None or im2 is None:
            raise gr.Error("Please provide both images before processing.")
        _ensure_model(model_id)
        progress(0, desc="Extracting features for Image 1…")
        st1 = extract_image_features(im1, int(long_side), mdl=model)
        progress(0.5, desc="Extracting features for Image 2…")
        st2 = extract_image_features(im2, int(long_side), mdl=model)
        progress(1, desc="Done")
        # Show quick previews to confirm processing (they land in the marked1
        # and overlay2 slots until the first click replaces them)
        return st1["img"], st2["img"], st1, st2

    start_btn.click(
        _run_both,
        inputs=[img1, img2, target_long_side, model_choice],
        outputs=[marked1, overlay2, state1, state2],
    )
    # Clicking on Image 1 → compute similarities on both images
    def _on_click(st1, st2, a: float, m: str, excl: int, evt: gr.SelectData):
        if not st1 or not st2 or evt is None:
            return (None,) * 6
        return click_two_image_similarity(
            st1, st2,
            click_xy=evt.index,
            exclude_radius_patches=int(excl),
            alpha=float(a), cmap_name=m,
        )

    img1.select(
        _on_click,
        inputs=[state1, state2, alpha, cmap, exclude_r],
        outputs=[marked1, heat1, overlay1, heat2, overlay2, score2],
    )
if __name__ == "__main__":
    demo.launch()