manu02 commited on
Commit
42d4446
·
verified ·
1 Parent(s): 946d5b3

Added token to app.py

Browse files
Files changed (1) hide show
  1. app.py +624 -622
app.py CHANGED
@@ -1,622 +1,624 @@
1
- # app.py
2
- # Gradio UI for interactive DINOv3 patch similarity (single or dual image)
3
- # - No AutoImageProcessor, no resize (only pad to multiple of patch size)
4
- # - Single image: click to show self-similarity; selected cell outlined in RED
5
- # - Two images: click on one side -> self overlay on source, cross overlay on target; best match on target outlined in YELLOW
6
- # - Red selection rectangle is hidden on the non-active image
7
- # - Patch size inferred from model (no override). Patch indices are not annotated.
8
- # - Dataset selector (LVD-1689M / SAT-493M); model dropdown shows only the short name between "dinov3-" and "-pretrain".
9
- # - Sample URL dropdowns switch between LVD (COCO/Picsum) and SAT (satellite imagery) and auto-fill / clear uploads.
10
-
11
- import io
12
- import math
13
- import urllib.request
14
- from functools import lru_cache
15
- from typing import Optional, Tuple, Dict, List
16
-
17
- import gradio as gr
18
- import numpy as np
19
- from PIL import Image, ImageDraw
20
- import torch
21
- from torchvision import transforms
22
- from transformers import AutoModel
23
- from matplotlib import colormaps as cm
24
-
25
- # ---------- Provided model IDs (ground truth list) ----------
26
- MODEL_ID_LIST = [
27
- "facebook/dinov3-vits16-pretrain-lvd1689m",
28
- "facebook/dinov3-vits16plus-pretrain-lvd1689m",
29
- "facebook/dinov3-vitb16-pretrain-lvd1689m",
30
- "facebook/dinov3-vitl16-pretrain-lvd1689m",
31
- "facebook/dinov3-vith16plus-pretrain-lvd1689m",
32
- "facebook/dinov3-vit7b16-pretrain-lvd1689m",
33
- "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
34
- "facebook/dinov3-convnext-small-pretrain-lvd1689m",
35
- "facebook/dinov3-convnext-base-pretrain-lvd1689m",
36
- "facebook/dinov3-convnext-large-pretrain-lvd1689m",
37
- "facebook/dinov3-vitl16-pretrain-sat493m",
38
- "facebook/dinov3-vit7b16-pretrain-sat493m",
39
- ]
40
-
41
- DATASET_LABELS = {
42
- "LVD-1689M": "lvd1689m",
43
- "SAT-493M": "sat493m",
44
- }
45
-
46
- def build_model_maps(model_ids: List[str]):
47
- """
48
- Returns:
49
- valid_map[(dataset_key, short_name)] -> full_model_id
50
- options_by_dataset[dataset_key] -> [short_name,...] (display order preserved)
51
- """
52
- valid_map: Dict[Tuple[str, str], str] = {}
53
- options_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}
54
-
55
- for mid in model_ids:
56
- # Expect pattern: "facebook/dinov3-<short>-pretrain-<dataset>"
57
- try:
58
- prefix = "facebook/dinov3-"
59
- start = mid.index(prefix) + len(prefix)
60
- pre_idx = mid.index("-pretrain", start)
61
- short = mid[start:pre_idx]
62
- dataset = mid.split("-pretrain-")[-1].strip()
63
- except Exception:
64
- # Skip anything that doesn't match the expected pattern
65
- continue
66
-
67
- key = (dataset, short)
68
- valid_map[key] = mid
69
- if dataset in options_by_dataset and short not in options_by_dataset[dataset]:
70
- options_by_dataset[dataset].append(short)
71
-
72
- return valid_map, options_by_dataset
73
-
74
- VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)
75
-
76
- # ---------- Defaults / knobs ----------
77
- DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
78
- DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
79
- DEFAULT_OVERLAY_ALPHA = 0.55
80
- DEFAULT_SHOW_GRID = True
81
-
82
- # ---------- Sample image URLs (dependent on dataset) ----------
83
- SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
84
- # LVD: current ones
85
- "lvd1689m": [
86
- ("– choose a sample –", ""),
87
- ("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
88
- ("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
89
- ("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
90
- ("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
91
- ("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
92
- ("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
93
- ],
94
- # SAT: satellite imagery examples
95
- "sat493m": [
96
- ("– choose a satellite sample –", ""),
97
- ("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
98
- ("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
99
- ("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
100
- ],
101
- }
102
-
103
- def _sample_labels_for(dataset_label: str):
104
- key = DATASET_LABELS.get(dataset_label, "lvd1689m")
105
- return [label for label, _ in SAMPLE_URL_CHOICES.get(key, [])]
106
-
107
- def _apply_sample(dataset_label: str, sample_label: str):
108
- """Fill textbox with chosen sample URL and clear any uploaded image."""
109
- key = DATASET_LABELS.get(dataset_label, "lvd1689m")
110
- sample_map = dict(SAMPLE_URL_CHOICES.get(key, []))
111
- url = sample_map.get(sample_label, "")
112
- return gr.update(value=url), None # (textbox update, clear upload)
113
-
114
- # ---------- Utility ----------
115
- def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
116
- # Prefer URL if present
117
- if url and str(url).strip().lower().startswith(("http://", "https://")):
118
- with urllib.request.urlopen(url) as resp:
119
- data = resp.read()
120
- return Image.open(io.BytesIO(data)).convert("RGB")
121
- if isinstance(src, Image.Image):
122
- return src.convert("RGB")
123
- return None
124
-
125
- def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
126
- W, H = pil_img.size
127
- H_pad = int(math.ceil(H / multiple) * multiple)
128
- W_pad = int(math.ceil(W / multiple) * multiple)
129
- if (H_pad, W_pad) == (H, W):
130
- return pil_img, (0, 0, 0, 0)
131
- canvas = Image.new("RGB", (W_pad, H_pad), (0, 0, 0))
132
- canvas.paste(pil_img, (0, 0))
133
- return canvas, (0, 0, W_pad - W, H_pad - H)
134
-
135
- def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
136
- img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
137
- transform = transforms.Compose([
138
- transforms.ToTensor(),
139
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
140
- std =[0.229, 0.224, 0.225]),
141
- ])
142
- pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
143
- disp_np = np.array(img_padded, dtype=np.uint8)
144
- return {"pixel_values": pixel_tensor}, disp_np, pad_box
145
-
146
- def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
147
- if arr.ndim == 2:
148
- return arr.repeat(ps, 0).repeat(ps, 1)
149
- elif arr.ndim == 3:
150
- rows, cols, ch = arr.shape
151
- arr2 = arr.repeat(ps, 0).repeat(ps, 1)
152
- return arr2.reshape(rows * ps, cols * ps, ch)
153
- raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
154
-
155
- def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
156
- base = base_uint8.astype(np.float32)
157
- over = (overlay_rgb_float * 255.0).astype(np.float32)
158
- out = (1.0 - alpha) * base + alpha * over
159
- return np.clip(out, 0, 255).astype(np.uint8)
160
-
161
- def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
162
- d = ImageDraw.Draw(img)
163
- W, H = img.size
164
- for r in range(1, rows):
165
- y = r * ps
166
- d.line([(0, y), (W, y)], fill=(255, 255, 255), width=1)
167
- for c in range(1, cols):
168
- x = c * ps
169
- d.line([(x, 0), (x, H)], fill=(255, 255, 255), width=1)
170
-
171
- def rc_to_idx(r: int, c: int, cols: int) -> int:
172
- return int(r) * cols + int(c)
173
-
174
- def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
175
- return int(i) // cols, int(i) % cols
176
-
177
- # ---------- Model cache ----------
178
- @lru_cache(maxsize=3)
179
- def load_model_cached(full_model_id: str, device_str: str):
180
- device = torch.device(device_str)
181
- model = AutoModel.from_pretrained(full_model_id).to(device)
182
- model.eval()
183
- return model
184
-
185
- def infer_patch_size(model, default: int = 16) -> int:
186
- if hasattr(model, "config") and hasattr(model.config, "patch_size"):
187
- ps = model.config.patch_size
188
- if isinstance(ps, (tuple, list)): return int(ps[0])
189
- return int(ps)
190
- if hasattr(model, "patch_size"):
191
- ps = model.patch_size
192
- if isinstance(ps, (tuple, list)): return int(ps[0])
193
- return int(ps)
194
- return default
195
-
196
- # ---------- Per-image state ----------
197
- class PatchImageState:
198
- def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
199
- self.pil = pil_img
200
- self.ps = ps
201
- inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps)
202
- self.disp = disp_np
203
- pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
204
- _, _, H, W = pv.shape
205
- self.H, self.W = int(H), int(W)
206
- self.rows, self.cols = self.H // ps, self.W // ps
207
-
208
- with torch.no_grad():
209
- out = model(pixel_values=pv)
210
- hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy() # (T,D)
211
-
212
- T, D = hs.shape
213
- n_patches = self.rows * self.cols
214
- n_special = T - n_patches # class + maybe registers
215
- if n_special < 1:
216
- raise RuntimeError(
217
- f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
218
- )
219
- self.D = D
220
- patches = hs[n_special:, :].reshape(self.rows, self.cols, D)
221
- self.X = patches.reshape(-1, D)
222
- self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8)
223
-
224
- # ---------- Rendering / compute ----------
225
- def render_with_cosmap(
226
- st: PatchImageState,
227
- cos_map: Optional[np.ndarray],
228
- overlay_alpha: float,
229
- show_grid_flag: bool,
230
- select_idx: Optional[int] = None,
231
- best_idx: Optional[int] = None,
232
- ) -> Image.Image:
233
- H, W, ps = st.H, st.W, st.ps
234
- rows, cols = st.rows, st.cols
235
-
236
- if cos_map is None:
237
- disp = np.full((rows, cols), 0.5, dtype=np.float32)
238
- else:
239
- vmin, vmax = float(cos_map.min()), float(cos_map.max())
240
- rng = vmax - vmin if vmax > vmin else 1e-8
241
- disp = (cos_map - vmin) / rng
242
-
243
- cmap = cm.get_cmap("magma")
244
- rgba = cmap(disp)
245
- rgb = rgba[..., :3]
246
-
247
- if select_idx is not None:
248
- rs, cs = idx_to_rc(select_idx, cols)
249
- rgb[rs, cs, :] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
250
-
251
- over_rgb_up = upsample_nearest(rgb, H, W, ps)
252
- blended = blend_overlay(st.disp, over_rgb_up, float(overlay_alpha))
253
- pil = Image.fromarray(blended)
254
-
255
- draw = ImageDraw.Draw(pil)
256
- if show_grid_flag:
257
- draw_grid(pil, rows, cols, ps)
258
-
259
- if select_idx is not None:
260
- r, c = idx_to_rc(select_idx, cols)
261
- x0, y0 = c * ps, r * ps
262
- x1, y1 = x0 + ps - 1, y0 + ps - 1
263
- draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 0, 0), width=2)
264
-
265
- if best_idx is not None:
266
- r, c = idx_to_rc(best_idx, cols)
267
- x0, y0 = c * ps, r * ps
268
- x1, y1 = x0 + ps - 1, y0 + ps - 1
269
- draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 255, 0), width=2)
270
-
271
- return pil
272
-
273
- def compute_self_and_cross(
274
- src: PatchImageState,
275
- tgt: Optional[PatchImageState],
276
- q_idx: int,
277
- ):
278
- q = src.X[q_idx]
279
- qn = q / (np.linalg.norm(q) + 1e-8)
280
-
281
- cos_self = src.Xn @ qn
282
- cos_map_self = cos_self.reshape(src.rows, src.cols)
283
- self_stats = (float(cos_map_self.min()), float(cos_map_self.max()))
284
-
285
- cross_result = None
286
- cos_map_cross = None
287
- if tgt is not None:
288
- cos_cross = tgt.Xn @ qn
289
- cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
290
- cross_min, cross_max = float(cos_map_cross.min()), float(cos_map_cross.max())
291
- best_idx = int(np.argmax(cos_cross))
292
- cross_result = (cross_min, cross_max, best_idx)
293
-
294
- return cos_map_self, cos_map_cross, self_stats, cross_result
295
-
296
- # ---------- Gradio helpers for model & samples ----------
297
- def dataset_label_to_key(label: str) -> str:
298
- return DATASET_LABELS.get(label, "lvd1689m")
299
-
300
- def update_model_dropdown(dataset_label: str):
301
- key = dataset_label_to_key(dataset_label)
302
- opts = MODEL_OPTIONS_BY_DATASET.get(key, [])
303
- default_val = opts[0] if opts else None
304
- return gr.update(choices=opts, value=default_val)
305
-
306
- def update_model_and_samples(dataset_label: str):
307
- # Update model dropdown
308
- model_update = update_model_dropdown(dataset_label)
309
- # Update both sample dropdowns to dataset-specific options
310
- labels = _sample_labels_for(dataset_label)
311
- sample_update = gr.update(choices=labels, value=(labels[0] if labels else None))
312
- return model_update, sample_update, sample_update
313
-
314
- def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
315
- key = (dataset_label_to_key(dataset_label), short_name)
316
- return VALID_MODEL_MAP.get(key)
317
-
318
- # ---------- Gradio callbacks ----------
319
- def init_states(
320
- left_img_in: Optional[Image.Image],
321
- left_url: str,
322
- right_img_in: Optional[Image.Image],
323
- right_url: str,
324
- dataset_label: str,
325
- short_model: str,
326
- show_grid_flag: bool,
327
- overlay_alpha: float,
328
- ):
329
- # Resolve images
330
- left_img = load_image_from_any(left_img_in, left_url)
331
- right_img = load_image_from_any(right_img_in, right_url)
332
- if left_img is None and right_img is None:
333
- left_img = load_image_from_any(None, DEFAULT_URL)
334
-
335
- # Resolve model
336
- full_model_id = resolve_full_model_id(dataset_label, short_model)
337
- if not full_model_id:
338
- return (gr.update(), gr.update(), None, None, 0, -1, -1, 16,
339
- f"❌ Model not available: {dataset_label} / {short_model}")
340
-
341
- device_str = "cuda" if torch.cuda.is_available() else "cpu"
342
- model = load_model_cached(full_model_id, device_str)
343
- ps = infer_patch_size(model, 16)
344
-
345
- left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
346
- right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
347
-
348
- active_side = 0 if left_state is not None else 1
349
-
350
- status = f"✔ Loaded: {full_model_id} | ps={ps}"
351
- out_left, out_right = None, None
352
-
353
- if left_state is not None and right_state is not None:
354
- q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
355
- cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
356
- best_idx = cross_info[2] if cross_info else None
357
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
358
- select_idx=q_idx, best_idx=None)
359
- out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
360
- select_idx=None, best_idx=best_idx)
361
- status += (f" | LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] "
362
- f"| RIGHT cross best={best_idx}")
363
- left_idx, right_idx = q_idx, (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
364
- elif left_state is not None:
365
- q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
366
- cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
367
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
368
- select_idx=q_idx, best_idx=None)
369
- status += f" | Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
370
- left_idx, right_idx = q_idx, -1
371
- else:
372
- q_idx = (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
373
- cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
374
- out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
375
- select_idx=q_idx, best_idx=None)
376
- status += f" | Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
377
- left_idx, right_idx = -1, q_idx
378
-
379
- return (
380
- out_left, out_right,
381
- left_state, right_state,
382
- active_side,
383
- left_idx, right_idx,
384
- ps,
385
- status
386
- )
387
-
388
- def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
389
- r = int(np.clip(y // st.ps, 0, st.rows - 1))
390
- c = int(np.clip(x // st.ps, 0, st.cols - 1))
391
- return rc_to_idx(r, c, st.cols)
392
-
393
- def on_select_left(
394
- evt: gr.SelectData,
395
- left_state: Optional[PatchImageState],
396
- right_state: Optional[PatchImageState],
397
- show_grid_flag: bool,
398
- overlay_alpha: float,
399
- ps: int,
400
- ):
401
- if left_state is None:
402
- return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."
403
-
404
- x, y = evt.index
405
- q_idx = _coords_to_idx(x, y, left_state)
406
-
407
- if right_state is not None:
408
- cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
409
- best_idx = cross_info[2]
410
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
411
- select_idx=q_idx, best_idx=None)
412
- out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
413
- select_idx=None, best_idx=best_idx)
414
- status = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
415
- f"RIGHT cross best idx={best_idx}")
416
- return out_left, out_right, 0, q_idx, -1, status
417
- else:
418
- cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
419
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
420
- select_idx=q_idx, best_idx=None)
421
- status = f"Single LEFT idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
422
- return out_left, gr.update(), 0, q_idx, -1, status
423
-
424
- def on_select_right(
425
- evt: gr.SelectData,
426
- left_state: Optional[PatchImageState],
427
- right_state: Optional[PatchImageState],
428
- show_grid_flag: bool,
429
- overlay_alpha: float,
430
- ps: int,
431
- ):
432
- if right_state is None:
433
- return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."
434
-
435
- x, y = evt.index
436
- q_idx = _coords_to_idx(x, y, right_state)
437
-
438
- if left_state is not None:
439
- cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
440
- best_idx = cross_info[2]
441
- out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
442
- select_idx=q_idx, best_idx=None)
443
- out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
444
- select_idx=None, best_idx=best_idx)
445
- status = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
446
- f"LEFT cross best idx={best_idx}")
447
- return out_left, out_right, 1, -1, q_idx, status
448
- else:
449
- cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
450
- out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
451
- select_idx=q_idx, best_idx=None)
452
- status = f"Single RIGHT idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
453
- return gr.update(), out_right, 1, -1, q_idx, status
454
-
455
- def rebuild_with_settings(
456
- left_state: Optional[PatchImageState],
457
- right_state: Optional[PatchImageState],
458
- active_side: int,
459
- left_idx: int,
460
- right_idx: int,
461
- show_grid_flag: bool,
462
- overlay_alpha: float,
463
- ps: int,
464
- ):
465
- if left_state is None and right_state is None:
466
- return gr.update(), gr.update(), "Load an image first."
467
-
468
- if left_state is not None and right_state is not None:
469
- if active_side == 0:
470
- q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
471
- cos_self, cos_cross, _, cross_info = compute_self_and_cross(left_state, right_state, q_idx)
472
- best_idx = cross_info[2]
473
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
474
- select_idx=q_idx, best_idx=None)
475
- out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
476
- select_idx=None, best_idx=best_idx)
477
- else:
478
- q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
479
- cos_self, cos_cross, _, cross_info = compute_self_and_cross(right_state, left_state, q_idx)
480
- best_idx = cross_info[2]
481
- out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
482
- select_idx=q_idx, best_idx=None)
483
- out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
484
- select_idx=None, best_idx=best_idx)
485
- return out_left, out_right, "Updated overlays."
486
- elif left_state is not None:
487
- q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
488
- cos_self, _, _, _ = compute_self_and_cross(left_state, None, q_idx)
489
- out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
490
- select_idx=q_idx, best_idx=None)
491
- return out_left, gr.update(), "Updated overlays."
492
- else:
493
- q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
494
- cos_self, _, _, _ = compute_self_and_cross(right_state, None, q_idx)
495
- out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
496
- select_idx=q_idx, best_idx=None)
497
- return gr.update(), out_right, "Updated overlays."
498
-
499
- # ---------- Gradio UI ----------
500
- with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
501
- gr.Markdown(
502
- """
503
- # DINOv3 Patch Similarity (Self & Cross)
504
- 1) Pick **Dataset** (LVD-1689M / SAT-493M).
505
- 2) Pick **Model**.
506
- 3) Upload one or two images (or paste URLs) and press **Initialize / Update**.
507
- - Click on a patch to update overlays.
508
- - In two-image mode, the non-active image hides the red selection and shows **yellow** best match.
509
- """
510
- )
511
-
512
- with gr.Row():
513
- dataset_radio = gr.Radio(
514
- label="Dataset",
515
- choices=list(DATASET_LABELS.keys()),
516
- value=DEFAULT_DATASET_LABEL,
517
- interactive=True
518
- )
519
- initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
520
- initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
521
- model_dropdown = gr.Dropdown(
522
- label="Model name",
523
- choices=initial_models,
524
- value=(initial_models[0] if initial_models else None),
525
- interactive=True
526
- )
527
-
528
- # initial sample labels based on default dataset
529
- initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]
530
-
531
- with gr.Row():
532
- with gr.Column():
533
- left_input = gr.Image(label="Left Image (upload)", type="pil",
534
- sources=["upload", "clipboard", "webcam"], interactive=True)
535
- left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
536
- left_sample = gr.Dropdown(label="Use a sample URL",
537
- choices=initial_sample_labels,
538
- value=(initial_sample_labels[0] if initial_sample_labels else None),
539
- interactive=True)
540
- with gr.Column():
541
- right_input = gr.Image(label="Right Image (upload)", type="pil",
542
- sources=["upload", "clipboard", "webcam"], interactive=True)
543
- right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
544
- right_sample = gr.Dropdown(label="Use a sample URL",
545
- choices=initial_sample_labels,
546
- value=(initial_sample_labels[0] if initial_sample_labels else None),
547
- interactive=True)
548
-
549
- with gr.Accordion("Overlay Settings", open=True):
550
- show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
551
- overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
552
- value=DEFAULT_OVERLAY_ALPHA, step=0.01)
553
-
554
- init_btn = gr.Button("Initialize / Update", variant="primary")
555
-
556
- with gr.Row():
557
- left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
558
- right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)
559
-
560
- status = gr.Markdown("")
561
-
562
- # Hidden states
563
- left_state = gr.State(None)
564
- right_state = gr.State(None)
565
- active_side = gr.State(0)
566
- left_idx = gr.State(-1)
567
- right_idx = gr.State(-1)
568
- ps_state = gr.State(16)
569
-
570
- # Update model dropdown and sample lists when dataset changes
571
- dataset_radio.change(
572
- fn=update_model_and_samples,
573
- inputs=[dataset_radio],
574
- outputs=[model_dropdown, left_sample, right_sample]
575
- )
576
-
577
- # When a sample is chosen, set URL and clear any uploaded image (prefer URL)
578
- left_sample.change(
579
- fn=_apply_sample,
580
- inputs=[dataset_radio, left_sample],
581
- outputs=[left_url, left_input]
582
- )
583
- right_sample.change(
584
- fn=_apply_sample,
585
- inputs=[dataset_radio, right_sample],
586
- outputs=[right_url, right_input]
587
- )
588
-
589
- # Initialize / reload model + overlays
590
- init_btn.click(
591
- fn=init_states,
592
- inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
593
- outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
594
- show_progress=True
595
- )
596
-
597
- # Click handlers
598
- left_view.select(
599
- fn=on_select_left,
600
- inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
601
- outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
602
- )
603
- right_view.select(
604
- fn=on_select_right,
605
- inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
606
- outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
607
- )
608
-
609
- # Live re-render on setting changes
610
- show_grid.change(
611
- fn=rebuild_with_settings,
612
- inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
613
- outputs=[left_view, right_view, status]
614
- )
615
- overlay_alpha.change(
616
- fn=rebuild_with_settings,
617
- inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
618
- outputs=[left_view, right_view, status]
619
- )
620
-
621
- if __name__ == "__main__":
622
- demo.queue().launch()
 
 
 
1
+ # app.py
2
+ # Gradio UI for interactive DINOv3 patch similarity (single or dual image)
3
+ # - No AutoImageProcessor, no resize (only pad to multiple of patch size)
4
+ # - Single image: click to show self-similarity; selected cell outlined in RED
5
+ # - Two images: click on one side -> self overlay on source, cross overlay on target; best match on target outlined in YELLOW
6
+ # - Red selection rectangle is hidden on the non-active image
7
+ # - Patch size inferred from model (no override). Patch indices are not annotated.
8
+ # - Dataset selector (LVD-1689M / SAT-493M); model dropdown shows only the short name between "dinov3-" and "-pretrain".
9
+ # - Sample URL dropdowns switch between LVD (COCO/Picsum) and SAT (satellite imagery) and auto-fill / clear uploads.
10
+
11
+ import io
12
+ import math
13
+ import urllib.request
14
+ from functools import lru_cache
15
+ from typing import Optional, Tuple, Dict, List
16
+
17
+ import gradio as gr
18
+ import numpy as np
19
+ from PIL import Image, ImageDraw
20
+ import torch
21
+ from torchvision import transforms
22
+ from transformers import AutoModel
23
+ from matplotlib import colormaps as cm
24
+
25
+ token = os.environ.get("HF_TOKEN")
26
+
27
+ # ---------- Provided model IDs (ground truth list) ----------
28
+ MODEL_ID_LIST = [
29
+ "facebook/dinov3-vits16-pretrain-lvd1689m",
30
+ "facebook/dinov3-vits16plus-pretrain-lvd1689m",
31
+ "facebook/dinov3-vitb16-pretrain-lvd1689m",
32
+ "facebook/dinov3-vitl16-pretrain-lvd1689m",
33
+ "facebook/dinov3-vith16plus-pretrain-lvd1689m",
34
+ "facebook/dinov3-vit7b16-pretrain-lvd1689m",
35
+ "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
36
+ "facebook/dinov3-convnext-small-pretrain-lvd1689m",
37
+ "facebook/dinov3-convnext-base-pretrain-lvd1689m",
38
+ "facebook/dinov3-convnext-large-pretrain-lvd1689m",
39
+ "facebook/dinov3-vitl16-pretrain-sat493m",
40
+ "facebook/dinov3-vit7b16-pretrain-sat493m",
41
+ ]
42
+
43
+ DATASET_LABELS = {
44
+ "LVD-1689M": "lvd1689m",
45
+ "SAT-493M": "sat493m",
46
+ }
47
+
48
+ def build_model_maps(model_ids: List[str]):
49
+ """
50
+ Returns:
51
+ valid_map[(dataset_key, short_name)] -> full_model_id
52
+ options_by_dataset[dataset_key] -> [short_name,...] (display order preserved)
53
+ """
54
+ valid_map: Dict[Tuple[str, str], str] = {}
55
+ options_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}
56
+
57
+ for mid in model_ids:
58
+ # Expect pattern: "facebook/dinov3-<short>-pretrain-<dataset>"
59
+ try:
60
+ prefix = "facebook/dinov3-"
61
+ start = mid.index(prefix) + len(prefix)
62
+ pre_idx = mid.index("-pretrain", start)
63
+ short = mid[start:pre_idx]
64
+ dataset = mid.split("-pretrain-")[-1].strip()
65
+ except Exception:
66
+ # Skip anything that doesn't match the expected pattern
67
+ continue
68
+
69
+ key = (dataset, short)
70
+ valid_map[key] = mid
71
+ if dataset in options_by_dataset and short not in options_by_dataset[dataset]:
72
+ options_by_dataset[dataset].append(short)
73
+
74
+ return valid_map, options_by_dataset
75
+
76
+ VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)
77
+
78
+ # ---------- Defaults / knobs ----------
79
+ DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
80
+ DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
81
+ DEFAULT_OVERLAY_ALPHA = 0.55
82
+ DEFAULT_SHOW_GRID = True
83
+
84
+ # ---------- Sample image URLs (dependent on dataset) ----------
85
+ SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
86
+ # LVD: current ones
87
+ "lvd1689m": [
88
+ (" choose a sample –", ""),
89
+ ("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
90
+ ("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
91
+ ("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
92
+ ("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
93
+ ("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
94
+ ("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
95
+ ],
96
+ # SAT: satellite imagery examples
97
+ "sat493m": [
98
+ (" choose a satellite sample –", ""),
99
+ ("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
100
+ ("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
101
+ ("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
102
+ ],
103
+ }
104
+
105
+ def _sample_labels_for(dataset_label: str):
106
+ key = DATASET_LABELS.get(dataset_label, "lvd1689m")
107
+ return [label for label, _ in SAMPLE_URL_CHOICES.get(key, [])]
108
+
109
+ def _apply_sample(dataset_label: str, sample_label: str):
110
+ """Fill textbox with chosen sample URL and clear any uploaded image."""
111
+ key = DATASET_LABELS.get(dataset_label, "lvd1689m")
112
+ sample_map = dict(SAMPLE_URL_CHOICES.get(key, []))
113
+ url = sample_map.get(sample_label, "")
114
+ return gr.update(value=url), None # (textbox update, clear upload)
115
+
116
+ # ---------- Utility ----------
117
+ def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
118
+ # Prefer URL if present
119
+ if url and str(url).strip().lower().startswith(("http://", "https://")):
120
+ with urllib.request.urlopen(url) as resp:
121
+ data = resp.read()
122
+ return Image.open(io.BytesIO(data)).convert("RGB")
123
+ if isinstance(src, Image.Image):
124
+ return src.convert("RGB")
125
+ return None
126
+
127
+ def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
128
+ W, H = pil_img.size
129
+ H_pad = int(math.ceil(H / multiple) * multiple)
130
+ W_pad = int(math.ceil(W / multiple) * multiple)
131
+ if (H_pad, W_pad) == (H, W):
132
+ return pil_img, (0, 0, 0, 0)
133
+ canvas = Image.new("RGB", (W_pad, H_pad), (0, 0, 0))
134
+ canvas.paste(pil_img, (0, 0))
135
+ return canvas, (0, 0, W_pad - W, H_pad - H)
136
+
137
+ def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
138
+ img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
139
+ transform = transforms.Compose([
140
+ transforms.ToTensor(),
141
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
142
+ std =[0.229, 0.224, 0.225]),
143
+ ])
144
+ pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
145
+ disp_np = np.array(img_padded, dtype=np.uint8)
146
+ return {"pixel_values": pixel_tensor}, disp_np, pad_box
147
+
148
+ def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
149
+ if arr.ndim == 2:
150
+ return arr.repeat(ps, 0).repeat(ps, 1)
151
+ elif arr.ndim == 3:
152
+ rows, cols, ch = arr.shape
153
+ arr2 = arr.repeat(ps, 0).repeat(ps, 1)
154
+ return arr2.reshape(rows * ps, cols * ps, ch)
155
+ raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
156
+
157
+ def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
158
+ base = base_uint8.astype(np.float32)
159
+ over = (overlay_rgb_float * 255.0).astype(np.float32)
160
+ out = (1.0 - alpha) * base + alpha * over
161
+ return np.clip(out, 0, 255).astype(np.uint8)
162
+
163
+ def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
164
+ d = ImageDraw.Draw(img)
165
+ W, H = img.size
166
+ for r in range(1, rows):
167
+ y = r * ps
168
+ d.line([(0, y), (W, y)], fill=(255, 255, 255), width=1)
169
+ for c in range(1, cols):
170
+ x = c * ps
171
+ d.line([(x, 0), (x, H)], fill=(255, 255, 255), width=1)
172
+
173
+ def rc_to_idx(r: int, c: int, cols: int) -> int:
174
+ return int(r) * cols + int(c)
175
+
176
+ def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
177
+ return int(i) // cols, int(i) % cols
178
+
179
+ # ---------- Model cache ----------
180
+ @lru_cache(maxsize=3)
181
+ def load_model_cached(full_model_id: str, device_str: str):
182
+ device = torch.device(device_str)
183
+ model = AutoModel.from_pretrained(full_model_id).to(device)
184
+ model.eval()
185
+ return model
186
+
187
+ def infer_patch_size(model, default: int = 16) -> int:
188
+ if hasattr(model, "config") and hasattr(model.config, "patch_size"):
189
+ ps = model.config.patch_size
190
+ if isinstance(ps, (tuple, list)): return int(ps[0])
191
+ return int(ps)
192
+ if hasattr(model, "patch_size"):
193
+ ps = model.patch_size
194
+ if isinstance(ps, (tuple, list)): return int(ps[0])
195
+ return int(ps)
196
+ return default
197
+
198
+ # ---------- Per-image state ----------
199
+ class PatchImageState:
200
+ def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
201
+ self.pil = pil_img
202
+ self.ps = ps
203
+ inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps)
204
+ self.disp = disp_np
205
+ pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
206
+ _, _, H, W = pv.shape
207
+ self.H, self.W = int(H), int(W)
208
+ self.rows, self.cols = self.H // ps, self.W // ps
209
+
210
+ with torch.no_grad():
211
+ out = model(pixel_values=pv)
212
+ hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy() # (T,D)
213
+
214
+ T, D = hs.shape
215
+ n_patches = self.rows * self.cols
216
+ n_special = T - n_patches # class + maybe registers
217
+ if n_special < 1:
218
+ raise RuntimeError(
219
+ f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
220
+ )
221
+ self.D = D
222
+ patches = hs[n_special:, :].reshape(self.rows, self.cols, D)
223
+ self.X = patches.reshape(-1, D)
224
+ self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8)
225
+
226
+ # ---------- Rendering / compute ----------
227
+ def render_with_cosmap(
228
+ st: PatchImageState,
229
+ cos_map: Optional[np.ndarray],
230
+ overlay_alpha: float,
231
+ show_grid_flag: bool,
232
+ select_idx: Optional[int] = None,
233
+ best_idx: Optional[int] = None,
234
+ ) -> Image.Image:
235
+ H, W, ps = st.H, st.W, st.ps
236
+ rows, cols = st.rows, st.cols
237
+
238
+ if cos_map is None:
239
+ disp = np.full((rows, cols), 0.5, dtype=np.float32)
240
+ else:
241
+ vmin, vmax = float(cos_map.min()), float(cos_map.max())
242
+ rng = vmax - vmin if vmax > vmin else 1e-8
243
+ disp = (cos_map - vmin) / rng
244
+
245
+ cmap = cm.get_cmap("magma")
246
+ rgba = cmap(disp)
247
+ rgb = rgba[..., :3]
248
+
249
+ if select_idx is not None:
250
+ rs, cs = idx_to_rc(select_idx, cols)
251
+ rgb[rs, cs, :] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
252
+
253
+ over_rgb_up = upsample_nearest(rgb, H, W, ps)
254
+ blended = blend_overlay(st.disp, over_rgb_up, float(overlay_alpha))
255
+ pil = Image.fromarray(blended)
256
+
257
+ draw = ImageDraw.Draw(pil)
258
+ if show_grid_flag:
259
+ draw_grid(pil, rows, cols, ps)
260
+
261
+ if select_idx is not None:
262
+ r, c = idx_to_rc(select_idx, cols)
263
+ x0, y0 = c * ps, r * ps
264
+ x1, y1 = x0 + ps - 1, y0 + ps - 1
265
+ draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 0, 0), width=2)
266
+
267
+ if best_idx is not None:
268
+ r, c = idx_to_rc(best_idx, cols)
269
+ x0, y0 = c * ps, r * ps
270
+ x1, y1 = x0 + ps - 1, y0 + ps - 1
271
+ draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 255, 0), width=2)
272
+
273
+ return pil
274
+
275
+ def compute_self_and_cross(
276
+ src: PatchImageState,
277
+ tgt: Optional[PatchImageState],
278
+ q_idx: int,
279
+ ):
280
+ q = src.X[q_idx]
281
+ qn = q / (np.linalg.norm(q) + 1e-8)
282
+
283
+ cos_self = src.Xn @ qn
284
+ cos_map_self = cos_self.reshape(src.rows, src.cols)
285
+ self_stats = (float(cos_map_self.min()), float(cos_map_self.max()))
286
+
287
+ cross_result = None
288
+ cos_map_cross = None
289
+ if tgt is not None:
290
+ cos_cross = tgt.Xn @ qn
291
+ cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
292
+ cross_min, cross_max = float(cos_map_cross.min()), float(cos_map_cross.max())
293
+ best_idx = int(np.argmax(cos_cross))
294
+ cross_result = (cross_min, cross_max, best_idx)
295
+
296
+ return cos_map_self, cos_map_cross, self_stats, cross_result
297
+
298
+ # ---------- Gradio helpers for model & samples ----------
299
+ def dataset_label_to_key(label: str) -> str:
300
+ return DATASET_LABELS.get(label, "lvd1689m")
301
+
302
+ def update_model_dropdown(dataset_label: str):
303
+ key = dataset_label_to_key(dataset_label)
304
+ opts = MODEL_OPTIONS_BY_DATASET.get(key, [])
305
+ default_val = opts[0] if opts else None
306
+ return gr.update(choices=opts, value=default_val)
307
+
308
+ def update_model_and_samples(dataset_label: str):
309
+ # Update model dropdown
310
+ model_update = update_model_dropdown(dataset_label)
311
+ # Update both sample dropdowns to dataset-specific options
312
+ labels = _sample_labels_for(dataset_label)
313
+ sample_update = gr.update(choices=labels, value=(labels[0] if labels else None))
314
+ return model_update, sample_update, sample_update
315
+
316
+ def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
317
+ key = (dataset_label_to_key(dataset_label), short_name)
318
+ return VALID_MODEL_MAP.get(key)
319
+
320
+ # ---------- Gradio callbacks ----------
321
+ def init_states(
322
+ left_img_in: Optional[Image.Image],
323
+ left_url: str,
324
+ right_img_in: Optional[Image.Image],
325
+ right_url: str,
326
+ dataset_label: str,
327
+ short_model: str,
328
+ show_grid_flag: bool,
329
+ overlay_alpha: float,
330
+ ):
331
+ # Resolve images
332
+ left_img = load_image_from_any(left_img_in, left_url)
333
+ right_img = load_image_from_any(right_img_in, right_url)
334
+ if left_img is None and right_img is None:
335
+ left_img = load_image_from_any(None, DEFAULT_URL)
336
+
337
+ # Resolve model
338
+ full_model_id = resolve_full_model_id(dataset_label, short_model)
339
+ if not full_model_id:
340
+ return (gr.update(), gr.update(), None, None, 0, -1, -1, 16,
341
+ f"❌ Model not available: {dataset_label} / {short_model}")
342
+
343
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
344
+ model = load_model_cached(full_model_id, device_str)
345
+ ps = infer_patch_size(model, 16)
346
+
347
+ left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
348
+ right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
349
+
350
+ active_side = 0 if left_state is not None else 1
351
+
352
+ status = f"✔ Loaded: {full_model_id} | ps={ps}"
353
+ out_left, out_right = None, None
354
+
355
+ if left_state is not None and right_state is not None:
356
+ q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
357
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
358
+ best_idx = cross_info[2] if cross_info else None
359
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
360
+ select_idx=q_idx, best_idx=None)
361
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
362
+ select_idx=None, best_idx=best_idx)
363
+ status += (f" | LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] "
364
+ f"| RIGHT cross best={best_idx}")
365
+ left_idx, right_idx = q_idx, (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
366
+ elif left_state is not None:
367
+ q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
368
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
369
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
370
+ select_idx=q_idx, best_idx=None)
371
+ status += f" | Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
372
+ left_idx, right_idx = q_idx, -1
373
+ else:
374
+ q_idx = (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
375
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
376
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
377
+ select_idx=q_idx, best_idx=None)
378
+ status += f" | Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
379
+ left_idx, right_idx = -1, q_idx
380
+
381
+ return (
382
+ out_left, out_right,
383
+ left_state, right_state,
384
+ active_side,
385
+ left_idx, right_idx,
386
+ ps,
387
+ status
388
+ )
389
+
390
+ def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
391
+ r = int(np.clip(y // st.ps, 0, st.rows - 1))
392
+ c = int(np.clip(x // st.ps, 0, st.cols - 1))
393
+ return rc_to_idx(r, c, st.cols)
394
+
395
+ def on_select_left(
396
+ evt: gr.SelectData,
397
+ left_state: Optional[PatchImageState],
398
+ right_state: Optional[PatchImageState],
399
+ show_grid_flag: bool,
400
+ overlay_alpha: float,
401
+ ps: int,
402
+ ):
403
+ if left_state is None:
404
+ return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."
405
+
406
+ x, y = evt.index
407
+ q_idx = _coords_to_idx(x, y, left_state)
408
+
409
+ if right_state is not None:
410
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
411
+ best_idx = cross_info[2]
412
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
413
+ select_idx=q_idx, best_idx=None)
414
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
415
+ select_idx=None, best_idx=best_idx)
416
+ status = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
417
+ f"RIGHT cross best idx={best_idx}")
418
+ return out_left, out_right, 0, q_idx, -1, status
419
+ else:
420
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
421
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
422
+ select_idx=q_idx, best_idx=None)
423
+ status = f"Single LEFT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
424
+ return out_left, gr.update(), 0, q_idx, -1, status
425
+
426
+ def on_select_right(
427
+ evt: gr.SelectData,
428
+ left_state: Optional[PatchImageState],
429
+ right_state: Optional[PatchImageState],
430
+ show_grid_flag: bool,
431
+ overlay_alpha: float,
432
+ ps: int,
433
+ ):
434
+ if right_state is None:
435
+ return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."
436
+
437
+ x, y = evt.index
438
+ q_idx = _coords_to_idx(x, y, right_state)
439
+
440
+ if left_state is not None:
441
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
442
+ best_idx = cross_info[2]
443
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
444
+ select_idx=q_idx, best_idx=None)
445
+ out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
446
+ select_idx=None, best_idx=best_idx)
447
+ status = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
448
+ f"LEFT cross best idx={best_idx}")
449
+ return out_left, out_right, 1, -1, q_idx, status
450
+ else:
451
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
452
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
453
+ select_idx=q_idx, best_idx=None)
454
+ status = f"Single RIGHT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
455
+ return gr.update(), out_right, 1, -1, q_idx, status
456
+
457
+ def rebuild_with_settings(
458
+ left_state: Optional[PatchImageState],
459
+ right_state: Optional[PatchImageState],
460
+ active_side: int,
461
+ left_idx: int,
462
+ right_idx: int,
463
+ show_grid_flag: bool,
464
+ overlay_alpha: float,
465
+ ps: int,
466
+ ):
467
+ if left_state is None and right_state is None:
468
+ return gr.update(), gr.update(), "Load an image first."
469
+
470
+ if left_state is not None and right_state is not None:
471
+ if active_side == 0:
472
+ q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
473
+ cos_self, cos_cross, _, cross_info = compute_self_and_cross(left_state, right_state, q_idx)
474
+ best_idx = cross_info[2]
475
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
476
+ select_idx=q_idx, best_idx=None)
477
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
478
+ select_idx=None, best_idx=best_idx)
479
+ else:
480
+ q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
481
+ cos_self, cos_cross, _, cross_info = compute_self_and_cross(right_state, left_state, q_idx)
482
+ best_idx = cross_info[2]
483
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
484
+ select_idx=q_idx, best_idx=None)
485
+ out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
486
+ select_idx=None, best_idx=best_idx)
487
+ return out_left, out_right, "Updated overlays."
488
+ elif left_state is not None:
489
+ q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
490
+ cos_self, _, _, _ = compute_self_and_cross(left_state, None, q_idx)
491
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
492
+ select_idx=q_idx, best_idx=None)
493
+ return out_left, gr.update(), "Updated overlays."
494
+ else:
495
+ q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
496
+ cos_self, _, _, _ = compute_self_and_cross(right_state, None, q_idx)
497
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
498
+ select_idx=q_idx, best_idx=None)
499
+ return gr.update(), out_right, "Updated overlays."
500
+
501
+ # ---------- Gradio UI ----------
502
+ with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
503
+ gr.Markdown(
504
+ """
505
+ # DINOv3 Patch Similarity (Self & Cross)
506
+ 1) Pick **Dataset** (LVD-1689M / SAT-493M).
507
+ 2) Pick **Model**.
508
+ 3) Upload one or two images (or paste URLs) and press **Initialize / Update**.
509
+ - Click on a patch to update overlays.
510
+ - In two-image mode, the non-active image hides the red selection and shows **yellow** best match.
511
+ """
512
+ )
513
+
514
+ with gr.Row():
515
+ dataset_radio = gr.Radio(
516
+ label="Dataset",
517
+ choices=list(DATASET_LABELS.keys()),
518
+ value=DEFAULT_DATASET_LABEL,
519
+ interactive=True
520
+ )
521
+ initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
522
+ initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
523
+ model_dropdown = gr.Dropdown(
524
+ label="Model name",
525
+ choices=initial_models,
526
+ value=(initial_models[0] if initial_models else None),
527
+ interactive=True
528
+ )
529
+
530
+ # initial sample labels based on default dataset
531
+ initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]
532
+
533
+ with gr.Row():
534
+ with gr.Column():
535
+ left_input = gr.Image(label="Left Image (upload)", type="pil",
536
+ sources=["upload", "clipboard", "webcam"], interactive=True)
537
+ left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
538
+ left_sample = gr.Dropdown(label="Use a sample URL",
539
+ choices=initial_sample_labels,
540
+ value=(initial_sample_labels[0] if initial_sample_labels else None),
541
+ interactive=True)
542
+ with gr.Column():
543
+ right_input = gr.Image(label="Right Image (upload)", type="pil",
544
+ sources=["upload", "clipboard", "webcam"], interactive=True)
545
+ right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
546
+ right_sample = gr.Dropdown(label="Use a sample URL",
547
+ choices=initial_sample_labels,
548
+ value=(initial_sample_labels[0] if initial_sample_labels else None),
549
+ interactive=True)
550
+
551
+ with gr.Accordion("Overlay Settings", open=True):
552
+ show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
553
+ overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
554
+ value=DEFAULT_OVERLAY_ALPHA, step=0.01)
555
+
556
+ init_btn = gr.Button("Initialize / Update", variant="primary")
557
+
558
+ with gr.Row():
559
+ left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
560
+ right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)
561
+
562
+ status = gr.Markdown("")
563
+
564
+ # Hidden states
565
+ left_state = gr.State(None)
566
+ right_state = gr.State(None)
567
+ active_side = gr.State(0)
568
+ left_idx = gr.State(-1)
569
+ right_idx = gr.State(-1)
570
+ ps_state = gr.State(16)
571
+
572
+ # Update model dropdown and sample lists when dataset changes
573
+ dataset_radio.change(
574
+ fn=update_model_and_samples,
575
+ inputs=[dataset_radio],
576
+ outputs=[model_dropdown, left_sample, right_sample]
577
+ )
578
+
579
+ # When a sample is chosen, set URL and clear any uploaded image (prefer URL)
580
+ left_sample.change(
581
+ fn=_apply_sample,
582
+ inputs=[dataset_radio, left_sample],
583
+ outputs=[left_url, left_input]
584
+ )
585
+ right_sample.change(
586
+ fn=_apply_sample,
587
+ inputs=[dataset_radio, right_sample],
588
+ outputs=[right_url, right_input]
589
+ )
590
+
591
+ # Initialize / reload model + overlays
592
+ init_btn.click(
593
+ fn=init_states,
594
+ inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
595
+ outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
596
+ show_progress=True
597
+ )
598
+
599
+ # Click handlers
600
+ left_view.select(
601
+ fn=on_select_left,
602
+ inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
603
+ outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
604
+ )
605
+ right_view.select(
606
+ fn=on_select_right,
607
+ inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
608
+ outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
609
+ )
610
+
611
+ # Live re-render on setting changes
612
+ show_grid.change(
613
+ fn=rebuild_with_settings,
614
+ inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
615
+ outputs=[left_view, right_view, status]
616
+ )
617
+ overlay_alpha.change(
618
+ fn=rebuild_with_settings,
619
+ inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
620
+ outputs=[left_view, right_view, status]
621
+ )
622
+
623
+ if __name__ == "__main__":
624
+ demo.queue().launch()