File size: 11,728 Bytes
bf826bc
 
 
b90fe2f
bf826bc
6788bd3
bf826bc
b90fe2f
 
 
 
0cfc510
 
16ead7c
a6e7953
16ead7c
 
d73e700
4175ab9
bf826bc
 
 
bf6787d
 
16ead7c
bf6787d
 
16ead7c
d73e700
 
4175ab9
 
 
bf826bc
 
5356d23
0cfc510
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
bf826bc
b0a5be5
 
 
 
16ead7c
b0a5be5
 
 
 
 
 
 
16ead7c
bf6787d
16ead7c
 
 
 
 
 
 
 
 
 
 
 
 
bf6787d
 
 
 
 
 
 
 
 
 
 
16ead7c
 
 
b0a5be5
 
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
bf826bc
b0a5be5
 
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
 
 
 
 
 
 
bf826bc
b0a5be5
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
bf826bc
b0a5be5
 
 
 
 
 
 
 
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
bf826bc
b0a5be5
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
 
 
 
 
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf826bc
b0a5be5
 
 
 
 
 
 
 
 
 
 
 
bf826bc
4175ab9
bf826bc
b0a5be5
bf826bc
0a1d7c0
0cfc510
bf826bc
d73e700
0cfc510
 
d73e700
4175ab9
0cfc510
5356d23
bf826bc
0cfc510
 
bf826bc
0cfc510
 
 
bf826bc
 
4175ab9
bf826bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d287a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# app.py — DINOv3 two‑image patch similarity (click on Image 1 → show similarities on both images)
# Runs on CPU or CUDA. No external image URLs.

import os
from typing import Tuple
import gradio as gr

import numpy as np
from PIL import Image, ImageDraw

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
#from transformers import AutoModel  # trust_remote_code=True
from transformers import AutoModel 
 



# ============================
# Config
# ============================
# Hub IDs for the two selectable DINOv3 backbones (small default, huge alternative).
DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"

#DEFAULT_MODEL_ID = "onnx-community/dinov3-vits16-pretrain-lvd1689m-ONNX"
#ALT_MODEL_ID = "onnx-community/dinov3-vith16-pretrain-lvd1689m-ONNX"

AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]

# ViT patch edge length (pixels); resized image sides are rounded up to multiples of this.
PATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ImageNet channel statistics used to normalize inputs before the backbone.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)
# Many DINOv3 HF ports expose 1 [CLS] + 4 registers at the front
# NOTE(review): this token count is assumed, not read from the model config —
# confirm it holds for every checkpoint in AVAILABLE_MODELS.
N_SPECIAL_TOKENS = 5

# Robust colormap import (Matplotlib new/old)
# Newer Matplotlib exposes a mapping-style `matplotlib.colormaps`; older
# versions only have the legacy `cm.get_cmap` accessor. Define one shim so the
# rest of the file can call `_get_cmap(name)` either way.
try:
    from matplotlib import colormaps as _mpl_colormaps
    def _get_cmap(name: str):
        # Mapping-style lookup (newer Matplotlib).
        return _mpl_colormaps[name]
except Exception:
    import matplotlib.cm as _cm
    def _get_cmap(name: str):
        # Legacy accessor for older Matplotlib.
        return _cm.get_cmap(name)

# ============================
# Model loading / cache
# ============================
# Cache of already-loaded backbones, keyed by Hub model id.
_model_cache = {}
# Hub id of the model currently bound to the module-level `model` global.
_current_model_id = None
model = None

def load_model_from_hubold(model_id: str):
    """Legacy loader: fetch `model_id` from the HF Hub, move it to DEVICE, set eval mode."""
    print(f"Loading model '{model_id}' from HF Hub…")
    hub_token = os.environ.get("HF_TOKEN")
    loaded = AutoModel.from_pretrained(model_id, token=hub_token, trust_remote_code=True)
    # `.to()` returns the module itself, so `loaded` stays bound to the same model.
    loaded = loaded.to(DEVICE)
    loaded.eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return loaded

    
def load_model_from_hubold2(model_id: str):
    """Legacy loader variant: build an HF `image-feature-extraction` pipeline.

    Args:
        model_id: Hugging Face Hub model id to load.

    Returns:
        A transformers pipeline object running on GPU 0 when CUDA is
        available, otherwise on CPU (device=-1).

    Bug fix: `pipeline` was referenced without ever being imported anywhere in
    the file, so calling this function always raised NameError. Import it
    locally so the module's top-level imports stay untouched and the cost is
    only paid if this legacy path is actually used.
    """
    from transformers import pipeline  # was missing -> NameError at call time

    print(f"Loading model '{model_id}' from HF Hub…")
    token = os.environ.get("HF_TOKEN")
    # Use pipeline instead of AutoModel
    extractor = pipeline(
        "image-feature-extraction",
        model=model_id,
        token=token,
        trust_remote_code=True,
        device=0 if DEVICE == "cuda" else -1,
    )
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return extractor

def load_model_from_hub(model_id: str):
    """Download `model_id` from the HF Hub and return it on DEVICE in eval mode.

    Reads an optional HF_TOKEN from the environment for gated checkpoints.
    """
    print(f"Loading model '{model_id}' from HF Hub…")
    auth = os.environ.get("HF_TOKEN")
    backbone = AutoModel.from_pretrained(
        model_id,
        token=auth,
        trust_remote_code=True,
    )
    backbone = backbone.to(DEVICE)
    backbone.eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return backbone



def get_model(model_id: str):
    """Return the cached model for `model_id`, loading and caching it on first use."""
    try:
        return _model_cache[model_id]
    except KeyError:
        loaded = load_model_from_hub(model_id)
        _model_cache[model_id] = loaded
        return loaded

# Load default at startup
# NOTE: runs at import time — downloads/loads the default checkpoint (network
# and device transfer) before the UI below is even constructed.
model = get_model(DEFAULT_MODEL_ID)
_current_model_id = DEFAULT_MODEL_ID

# ============================
# Helpers
# ============================

def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
    """Resize `img` so its longer side is ~`long_side`, with each side rounded UP
    to a multiple of `patch` (and at least one patch). Returns a CHW float
    tensor in [0, 1]."""
    width, height = img.size
    factor = long_side / max(height, width)

    def _snap(side: int) -> int:
        # Scale, clamp to at least one patch, then round up to a patch multiple.
        scaled = max(patch, int(round(side * factor)))
        return ((scaled + patch - 1) // patch) * patch

    grid_h, grid_w = _snap(height), _snap(width)
    rgb = img.convert("RGB")
    return TF.to_tensor(TF.resize(rgb, (grid_h, grid_w)))

def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
    """Min-max normalize a 2-D similarity map and render it through a colormap."""
    arr = sim_map_up.astype(np.float32)
    lo, hi = arr.min(), arr.max()
    normed = (arr - lo) / (hi - lo + 1e-6)  # epsilon guards a constant map
    rgba = _get_cmap(cmap_name)(normed)
    rgb_u8 = (rgba[..., :3] * 255).astype(np.uint8)  # drop alpha channel
    return Image.fromarray(rgb_u8)

def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    """Alpha-composite `heat` over `base` at constant opacity `alpha`; returns RGB."""
    bottom = base.convert("RGBA")
    top = heat.convert("RGBA")
    top.putalpha(Image.new("L", top.size, int(255 * alpha)))
    return Image.alpha_composite(bottom, top).convert("RGB")

def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
    """Return a copy of `img` with a red cross (half-length `radius`) centered at (x, y)."""
    half = max(2, PATCH_SIZE // 2) if radius is None else radius
    marked = img.copy()
    pen = ImageDraw.Draw(marked)
    pen.line([(x - half, y), (x + half, y)], fill="red", width=3)  # horizontal bar
    pen.line([(x, y - half), (x, y + half)], fill="red", width=3)  # vertical bar
    return marked

# ============================
# Feature extraction
# ============================
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    """Run the backbone on a patch-grid-aligned resize of `image_pil`.

    Returns a dict with:
      X   — L2-normalized patch embeddings, shape (Hp*Wp, D)
      Hp, Wp — patch-grid height and width
      img — the resized PIL image actually fed to the model
    """
    backbone = mdl or model  # fall back to the module-level default model
    chw = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    batch = TF.normalize(chw, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    grid_h = batch.shape[2] // PATCH_SIZE
    grid_w = batch.shape[3] // PATCH_SIZE

    tokens = backbone(batch).last_hidden_state.squeeze(0)
    # Drop the leading [CLS] + register tokens; keep only patch tokens.
    patch_tokens = tokens[N_SPECIAL_TOKENS:, :]
    # Unit-norm rows so a plain dot product later equals cosine similarity.
    feats = F.normalize(patch_tokens, p=2, dim=-1)

    return {"X": feats, "Hp": grid_h, "Wp": grid_w, "img": TF.to_pil_image(chw)}

# ============================
# Similarity utilities
# ============================

def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
    row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
    return row, col

@torch.inference_mode()
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int):
    """Dot-product similarity of every patch embedding against `q_vec`.

    Returns:
        (sim_map, sim_up): the (Hp, Wp) similarity tensor and its bicubic
        upsample to (img_h, img_w) as a numpy array.
    """
    scores = X @ q_vec           # (Hp*Wp,) — cosine similarity if rows are unit-norm
    grid = scores.reshape(Hp, Wp)
    upsampled = F.interpolate(
        grid[None, None],        # add batch and channel dims for interpolate
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    return grid, upsampled.squeeze().detach().cpu().numpy()

# ============================
# Core: click on image 1 → heatmaps on image 1 and image 2
# ============================

def click_two_image_similarity(state1: dict, state2: dict, click_xy: Tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    """Handle a click on Image 1: build similarity heatmaps/overlays for both images.

    Args:
        state1, state2: dicts from extract_image_features ({"X", "Hp", "Wp", "img"}).
        click_xy: (x, y) pixel position of the click on the resized Image 1.
        exclude_radius_patches: half-width (in patches) of a square window
            around the clicked patch excluded from Image 1's own heatmap.
        alpha: overlay opacity for blend().
        cmap_name: Matplotlib colormap name for colorize().

    Returns:
        (marked1, heat1, overlay1, heat2, overlay2, max_score2), or six Nones
        when either state is missing.
    """
    if not state1 or not state2:
        return (None,)*6

    X1, Hp1, Wp1, img1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    X2, Hp2, Wp2, img2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]

    img1_w, img1_h = img1.size
    img2_w, img2_h = img2.size

    # Query vector from the clicked patch on image 1; use the shared helper so
    # the pixel→patch mapping stays consistent across the file.
    row, col = row_col_from_xy(int(click_xy[0]), int(click_xy[1]), Hp1, Wp1)
    idx = row * Wp1 + col
    q = X1[idx]  # (D,)

    # Similarity on image 1 (+ small exclusion mask around click if requested)
    sims1 = torch.matmul(X1, q)
    sim_map1 = sims1.view(Hp1, Wp1)
    if exclude_radius_patches > 0:
        rr, cc = torch.meshgrid(
            torch.arange(Hp1, device=sims1.device),
            torch.arange(Wp1, device=sims1.device),
            indexing="ij",
        )
        mask1 = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
        # BUG FIX: the old code filled the excluded window with -inf. Bicubic
        # interpolation then propagated -inf/NaN into sim1_up, and colorize()'s
        # (x - min)/(max - min) normalization divided by infinity, yielding NaN
        # pixels and a corrupted heatmap. Fill with the finite minimum instead:
        # the excluded region simply renders as the coldest color.
        if bool(mask1.all()):
            fill = float(sim_map1.min().item())
        else:
            fill = float(sim_map1[~mask1].min().item())
        sim_map1 = sim_map1.masked_fill(mask1, fill)

    sim1_up = F.interpolate(
        sim_map1.unsqueeze(0).unsqueeze(0),
        size=(img1_h, img1_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()

    heat1 = colorize(sim1_up, cmap_name)
    overlay1 = blend(img1, heat1, alpha)
    marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)

    # Similarity on image 2 (no exclusion) — reuse the shared helper, which
    # performs exactly the same matmul + bicubic upsample.
    _, sim2_up = similarity_map(X2, Hp2, Wp2, q, img2_h, img2_w)

    heat2 = colorize(sim2_up, cmap_name)
    overlay2 = blend(img2, heat2, alpha)

    return marked1, heat1, overlay1, heat2, overlay2, float(sim2_up.max())

# ============================
# Gradio UI
# ============================
with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
    gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
    gr.Markdown("Upload two images and press **Process both**. Then click on **Image 1** to see similar regions on **both** images.")

    # Per-session feature state: dicts produced by extract_image_features.
    state1 = gr.State()
    state2 = gr.State()

    with gr.Row():
        # Left column: backbone / rendering controls.
        with gr.Column():
            model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
            # step=16 keeps the long side a multiple of PATCH_SIZE.
            target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution (long side)")
            alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
            cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
            exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches) for Image 1")
            start_btn = gr.Button("▶️ Process both", variant="primary")

        # Right column: the two input images.
        with gr.Column():
            img1 = gr.Image(label="Image 1 (clickable)", type="pil", sources=["upload", "clipboard"], value=None)
            img2 = gr.Image(label="Image 2", type="pil", sources=["upload", "clipboard"], value=None)

    with gr.Row():
        with gr.Column():
            marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
            heat1   = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
            overlay1= gr.Image(label="Image 1 — overlay", interactive=False)
        with gr.Column():
            heat2   = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
            overlay2= gr.Image(label="Image 2 — overlay", interactive=False)
            score2  = gr.Number(label="Image 2 — max similarity score", precision=6)

    # Utilities
    def _ensure_model(model_id: str):
        # Swap the module-level model only when the dropdown selection changed.
        global model, _current_model_id
        if model_id != _current_model_id:
            model = get_model(model_id)
            _current_model_id = model_id

    # Process button → extract features for both images and store in state
    def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
        if im1 is None or im2 is None:
            raise gr.Error("Please provide both images before processing.")
        _ensure_model(model_id)
        progress(0, desc="Extracting features for Image 1…")
        st1 = extract_image_features(im1, int(long_side), mdl=model)
        progress(0.5, desc="Extracting features for Image 2…")
        st2 = extract_image_features(im2, int(long_side), mdl=model)
        progress(1, desc="Done")
        # Show quick previews to confirm processing
        return st1["img"], st2["img"], st1, st2

    # NOTE(review): the Image-2 preview is routed into the `overlay2` slot —
    # presumably a placeholder until the first click replaces it; confirm.
    start_btn.click(
        _run_both,
        inputs=[img1, img2, target_long_side, model_choice],
        outputs=[marked1, overlay2, state1, state2],
    )

    # Clicking on Image 1 → compute similarities on both images
    def _on_click(st1, st2, a: float, m: str, excl: int, evt: gr.SelectData):
        # evt.index is the (x, y) pixel position of the click in the displayed image.
        if not st1 or not st2 or evt is None:
            return (None,)*6
        return click_two_image_similarity(
            st1, st2,
            click_xy=evt.index,
            exclude_radius_patches=int(excl),
            alpha=float(a), cmap_name=m,
        )

    img1.select(
        _on_click,
        inputs=[state1, state2, alpha, cmap, exclude_r],
        outputs=[marked1, heat1, overlay1, heat2, overlay2, score2],
    )

if __name__ == "__main__":
    # Launch the Gradio server only when run as a script, not on import.
    demo.launch()