Rausda6 committed on
Commit
0cfc510
·
verified ·
1 Parent(s): d55a3e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -208
app.py CHANGED
@@ -1,85 +1,74 @@
1
- # app.py
 
 
2
 
3
  import os
4
- import torch
5
- import torch.nn.functional as F
6
- import gradio as gr
7
  import numpy as np
8
  from PIL import Image, ImageDraw
9
- import torchvision.transforms.functional as TF
10
 
11
- # --- Robust colormap import (Matplotlib ≥3.5 and older versions) ---
12
- try:
13
- from matplotlib import colormaps as _mpl_colormaps
14
- def _get_cmap(name: str):
15
- return _mpl_colormaps[name]
16
- except Exception:
17
- import matplotlib.cm as _cm
18
- def _get_cmap(name: str):
19
- return _cm.get_cmap(name)
20
 
21
- from transformers import AutoModel # uses trust_remote_code for DINOv3
22
 
23
  # ----------------------------
24
- # Configuration
25
  # ----------------------------
26
- # Default to smaller/faster ViT-S/16+; offer ViT-H/16+ as alternative.
27
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
28
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
29
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
30
 
31
  PATCH_SIZE = 16
32
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
-
34
- # Normalization constants (standard for ImageNet)
35
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
36
- IMAGENET_STD = (0.229, 0.224, 0.225)
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # ----------------------------
39
- # Model Loading (Hugging Face Hub) with caching
40
  # ----------------------------
41
  _model_cache = {}
42
  _current_model_id = None
43
- model = None # global reference used by extract_image_features()
44
 
45
  def load_model_from_hub(model_id: str):
46
- """Loads a DINOv3 model from the Hugging Face Hub."""
47
- print(f"Loading model '{model_id}' from Hugging Face Hub...")
48
- try:
49
- token = os.environ.get("HF_TOKEN") # optional, for gated models
50
- mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
51
- mdl.to(DEVICE).eval()
52
- print(f"✅ Model '{model_id}' loaded successfully on device: {DEVICE}")
53
- return mdl
54
- except Exception as e:
55
- print(f"❌ Failed to load model '{model_id}': {e}")
56
- raise gr.Error(
57
- f"Could not load model '{model_id}'. "
58
- "If the model is gated, please accept the terms on its Hugging Face page "
59
- "and set HF_TOKEN in your environment. "
60
- f"Original error: {e}"
61
- )
62
 
63
  def get_model(model_id: str):
64
- """Return a cached model if available, otherwise load and cache it."""
65
  if model_id in _model_cache:
66
  return _model_cache[model_id]
67
  mdl = load_model_from_hub(model_id)
68
  _model_cache[model_id] = mdl
69
  return mdl
70
 
71
- # Load default model at startup
72
  model = get_model(DEFAULT_MODEL_ID)
73
  _current_model_id = DEFAULT_MODEL_ID
74
 
75
  # ----------------------------
76
- # Helper Functions (resize, viz)
77
  # ----------------------------
78
- def resize_to_grid(img: Image.Image, long_side: int, patch: int) -> torch.Tensor:
79
- """
80
- Resizes so max(h,w)=long_side (keeping aspect), then rounds each side UP to a multiple of 'patch'.
81
- Returns CHW float tensor in [0,1].
82
- """
83
  w, h = img.size
84
  scale = long_side / max(h, w)
85
  new_h = max(patch, int(round(h * scale)))
@@ -95,7 +84,6 @@ def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
95
  return Image.fromarray(rgb)
96
 
97
  def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
98
- # Put alpha on heatmap and composite for a crisp overlay
99
  base = base.convert("RGBA")
100
  heat = heat.convert("RGBA")
101
  a = Image.new("L", heat.size, int(255 * alpha))
@@ -111,80 +99,45 @@ def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Imag
111
  draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
112
  return out
113
 
114
- def draw_boxes(img: Image.Image, boxes, outline="yellow", width=3, labels=True):
115
- out = img.copy()
116
- draw = ImageDraw.Draw(out)
117
- for i, (x0, y0, x1, y1) in enumerate(boxes, start=1):
118
- draw.rectangle([x0, y0, x1, y1], outline=outline, width=width)
119
- if labels:
120
- tx, ty = x0 + 2, y0 + 2
121
- draw.text((tx, ty), str(i), fill=outline)
122
- return out
123
-
124
- def patch_neighborhood_box(r: int, c: int, Hp: int, Wp: int, rad: int, patch: int = PATCH_SIZE):
125
- r0 = max(0, r - rad)
126
- r1 = min(Hp - 1, r + rad)
127
- c0 = max(0, c - rad)
128
- c1 = min(Wp - 1, c + rad)
129
- x0 = int(c0 * patch)
130
- y0 = int(r0 * patch)
131
- x1 = int((c1 + 1) * patch) - 1
132
- y1 = int((r1 + 1) * patch) - 1
133
- return (x0, y0, x1, y1)
134
-
135
  # ----------------------------
136
- # Feature Extraction (using transformers)
137
  # ----------------------------
138
  @torch.inference_mode()
139
- def extract_image_features(image_pil: Image.Image, target_long_side: int):
140
- """
141
- Extracts patch features from an image using the loaded Hugging Face model.
142
- """
143
  t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
144
  t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
145
  _, _, H, W = t_norm.shape
146
  Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
147
 
148
- # Models output: [CLS] + 4 register tokens + patches
149
- outputs = model(t_norm)
150
-
151
- # Skip the 5 special tokens to get only patch embeddings
152
- n_special_tokens = 5
153
- patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
154
-
155
- # L2-normalize features for cosine similarity
156
- X = F.normalize(patch_embeddings, p=2, dim=-1)
157
-
158
  img_resized = TF.to_pil_image(t)
 
159
  return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
160
 
161
  # ----------------------------
162
- # Similarity inside the same image
163
  # ----------------------------
164
- def click_to_similarity_in_same_image(
165
- state: dict,
166
- click_xy: tuple[int, int],
167
- exclude_radius_patches: int = 1,
168
- topk: int = 10,
169
- alpha: float = 0.55,
170
- cmap_name: str = "viridis",
171
- box_radius_patches: int = 4,
172
- ):
173
- if not state:
174
- return None, None, None, None
175
-
176
- X = state["X"]
177
- Hp, Wp = state["Hp"], state["Wp"]
178
- base_img = state["img"]
179
- img_w, img_h = base_img.size
180
-
181
- x_pix, y_pix = click_xy
182
  col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
183
  row = int(np.clip(y_pix // PATCH_SIZE, 0, Hp - 1))
184
- idx = row * Wp + col
185
 
186
- q = X[idx]
187
- sims = torch.matmul(X, q)
 
 
188
  sim_map = sims.view(Hp, Wp)
189
 
190
  if exclude_radius_patches > 0:
@@ -193,8 +146,8 @@ def click_to_similarity_in_same_image(
193
  torch.arange(Wp, device=sims.device),
194
  indexing="ij",
195
  )
196
- mask = (torch.abs(rr - row) <= exclude_radius_patches) & (torch.abs(cc - col) <= exclude_radius_patches)
197
- sim_map = sim_map.masked_fill(mask, float("-inf"))
198
 
199
  sim_up = F.interpolate(
200
  sim_map.unsqueeze(0).unsqueeze(0),
@@ -202,134 +155,147 @@ def click_to_similarity_in_same_image(
202
  mode="bicubic",
203
  align_corners=False,
204
  ).squeeze().detach().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- heatmap_pil = colorize(sim_up, cmap_name)
207
- overlay_pil = blend(base_img, heatmap_pil, alpha=alpha)
208
-
209
- overlay_boxes_pil = overlay_pil
210
- if topk and topk > 0:
211
- flat = sim_map.view(-1)
212
- valid = torch.isfinite(flat)
213
- if valid.any():
214
- vals = flat.clone()
215
- vals[~valid] = -1e9
216
- k = min(topk, int(valid.sum().item()))
217
- _, top_idx = torch.topk(vals, k=k, largest=True, sorted=True)
218
- boxes = [
219
- patch_neighborhood_box(
220
- r, c, Hp, Wp, rad=int(box_radius_patches), patch=PATCH_SIZE
221
- )
222
- for r, c in [divmod(j.item(), Wp) for j in top_idx]
223
- ]
224
- overlay_boxes_pil = draw_boxes(overlay_pil, boxes, outline="yellow", width=3, labels=True)
225
-
226
- marked_ref = draw_crosshair(base_img, x_pix, y_pix, radius=PATCH_SIZE // 2)
227
- return marked_ref, heatmap_pil, overlay_pil, overlay_boxes_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  # ----------------------------
230
- # Gradio UI (Manual-only processing)
231
  # ----------------------------
232
- with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Single-Image Patch Similarity") as demo:
233
- gr.Markdown("# 🦖 DINOv3 Single-Image Patch Similarity")
234
- gr.Markdown("Upload one image, adjust settings, then press **▶️ Start processing**. Click on the processed image to find similar regions.")
 
 
235
 
236
- app_state = gr.State()
 
237
 
238
  with gr.Row():
239
- with gr.Column(scale=1):
240
- input_image = gr.Image(
241
- label="Image (click anywhere after processing)",
242
- type="pil",
243
- value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg"
244
- )
245
- target_long_side = gr.Slider(
246
- minimum=224, maximum=1024, value=768, step=16,
247
- label="Processing Resolution",
248
- info="Higher values = more detail but slower processing",
249
- )
250
- with gr.Row():
251
- alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
252
- cmap = gr.Dropdown(
253
- ["viridis", "magma", "plasma", "inferno", "turbo", "cividis"],
254
- value="viridis", label="Colormap",
255
- )
256
- # Backbone selector (default = smaller/faster ViT-S/16+)
257
  model_choice = gr.Dropdown(
258
  choices=AVAILABLE_MODELS,
259
  value=DEFAULT_MODEL_ID,
260
  label="Backbone (DINOv3)",
261
- info="ViT-S/16+ is smaller & faster; ViT-H/16+ is larger.",
262
  )
263
- # Start processing button (manual trigger)
264
- with gr.Row():
265
- start_btn = gr.Button("▶️ Start processing", variant="primary")
266
-
267
- with gr.Column(scale=1):
268
- exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches)")
269
- topk = gr.Slider(0, 200, value=20, step=1, label="Top-K boxes")
270
- box_radius = gr.Slider(0, 10, value=1, step=1, label="Box radius (patches)")
 
 
 
271
 
272
  with gr.Row():
273
- marked_image = gr.Image(label="Click marker / Preview", interactive=False)
274
- heatmap_output = gr.Image(label="Similarity heatmap", interactive=False)
275
- with gr.Row():
276
- overlay_output = gr.Image(label="Overlay (image heatmap)", interactive=False)
277
- overlay_boxes_output = gr.Image(label="Overlay + top-K similar patch boxes", interactive=False)
278
-
 
 
 
 
279
  def _ensure_model(model_id: str):
280
- """Ensure the global 'model' matches the dropdown selection."""
281
  global model, _current_model_id
282
  if model_id != _current_model_id:
283
  model = get_model(model_id)
284
  _current_model_id = model_id
285
 
286
- # Manual feature extraction (only runs on Start button)
287
- def _run_extraction(img: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=True)):
288
- if img is None:
289
- return None, None
290
  _ensure_model(model_id)
291
- progress(0, desc="Extracting features...")
292
- st = extract_image_features(img, int(long_side))
293
- progress(1, desc="Done!")
294
- return st["img"], st
295
-
296
- # Clicking on processed image to compute similarities
297
- def _on_click(st, a: float, m: str, excl: int, k: int, box_rad: int, evt: gr.SelectData):
298
- if not st or evt is None:
299
- return None, None, None, None
300
- return click_to_similarity_in_same_image(
301
- st, click_xy=evt.index, exclude_radius_patches=int(excl),
302
- topk=int(k), alpha=float(a), cmap_name=m,
303
- box_radius_patches=int(box_rad),
304
- )
305
-
306
- # On image change: just preview and clear outputs/state (NO extraction)
307
- def _on_image_changed(img: Image.Image):
308
- if img is None:
309
- return None, None, None, None, None
310
- return img, None, None, None, None
311
 
312
- # ---------- Wiring (Manual mode) ----------
313
- # Do NOT auto-run on upload/slider/model change or on app load.
314
- # Only the Start button triggers extraction.
315
  start_btn.click(
316
- _run_extraction,
317
- inputs=[input_image, target_long_side, model_choice],
318
- outputs=[marked_image, app_state],
319
  )
320
 
321
- # When a new image is picked, show it as preview and clear old results.
322
- input_image.change(
323
- _on_image_changed,
324
- inputs=[input_image],
325
- outputs=[marked_image, app_state, heatmap_output, overlay_output, overlay_boxes_output],
326
- )
 
 
 
327
 
328
- # Keep click handler the same.
329
- marked_image.select(
330
  _on_click,
331
- inputs=[app_state, alpha, cmap, exclude_r, topk, box_radius],
332
- outputs=[marked_image, heatmap_output, overlay_output, overlay_boxes_output],
333
  )
334
 
335
  if __name__ == "__main__":
 
1
+ # app.py — DINOv3 two‑image patch‑similarity (click on image 1 → show similarities on both images)
2
+ # Derived from: https://huggingface.co/spaces/sayedM/DINOv3-features (single‑image version)
3
+ # Adds: second image input, dual feature extraction, and cross‑image similarity heatmaps/overlays.
4
 
5
  import os
 
 
 
6
  import numpy as np
7
  from PIL import Image, ImageDraw
 
8
 
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+ from transformers import AutoModel # trust_remote_code=True (DINOv3 on HF)
 
 
 
 
 
13
 
14
+ import gradio as gr
15
 
16
  # ----------------------------
17
+ # Config
18
  # ----------------------------
 
19
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
20
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
21
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
22
 
23
  PATCH_SIZE = 16
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
25
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
26
+ IMAGENET_STD = (0.229, 0.224, 0.225)
27
+ N_SPECIAL_TOKENS = 5 # [CLS] + 4 registers in dinov3 HF ports
28
+
29
# --- robust colormap lookup across matplotlib versions ---
# Matplotlib >= 3.5 exposes a registry at matplotlib.colormaps; older
# releases only offer the (now deprecated) cm.get_cmap API.
try:
    from matplotlib import colormaps as _mpl_colormaps
except Exception:
    _mpl_colormaps = None
    import matplotlib.cm as _cm

def _get_cmap(name: str):
    """Return the named matplotlib colormap, whichever API this matplotlib has."""
    if _mpl_colormaps is not None:
        return _mpl_colormaps[name]
    return _cm.get_cmap(name)
38
 
39
  # ----------------------------
40
+ # Model loading / cache
41
  # ----------------------------
42
  _model_cache = {}
43
  _current_model_id = None
44
+ model = None
45
 
46
def load_model_from_hub(model_id: str):
    """Load a DINOv3 backbone from the Hugging Face Hub onto DEVICE.

    Wraps failures in a ``gr.Error`` with actionable guidance (the previous
    revision lost this handling), so the UI shows a useful message for
    gated models instead of a bare traceback.

    Returns:
        The model in eval mode, moved to DEVICE.
    Raises:
        gr.Error: if the model cannot be downloaded or instantiated.
    """
    print(f"Loading model '{model_id}' from HF Hub")
    try:
        token = os.environ.get("HF_TOKEN")  # optional; required for gated models
        mdl = AutoModel.from_pretrained(model_id, token=token, trust_remote_code=True)
        mdl.to(DEVICE).eval()
        print(f"✅ Loaded '{model_id}' on {DEVICE}")
        return mdl
    except Exception as e:
        print(f"❌ Failed to load model '{model_id}': {e}")
        raise gr.Error(
            f"Could not load model '{model_id}'. "
            "If the model is gated, accept its terms on the Hugging Face page "
            "and set HF_TOKEN in your environment. "
            f"Original error: {e}"
        ) from e
 
 
 
 
 
 
 
 
 
 
53
 
54
def get_model(model_id: str):
    """Return the cached model for *model_id*, loading and caching it on first use."""
    cached = _model_cache.get(model_id)
    if cached is not None:
        return cached
    mdl = load_model_from_hub(model_id)
    _model_cache[model_id] = mdl
    return mdl
 
61
+ # Load default at startup
62
  model = get_model(DEFAULT_MODEL_ID)
63
  _current_model_id = DEFAULT_MODEL_ID
64
 
65
  # ----------------------------
66
+ # Helpers
67
  # ----------------------------
68
+
69
+ def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
70
+ """Resize so max(h,w)=long_side with aspect kept; then pad up to multiples of patch.
71
+ Return CHW float tensor in [0,1]."""
 
72
  w, h = img.size
73
  scale = long_side / max(h, w)
74
  new_h = max(patch, int(round(h * scale)))
 
84
  return Image.fromarray(rgb)
85
 
86
  def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
 
87
  base = base.convert("RGBA")
88
  heat = heat.convert("RGBA")
89
  a = Image.new("L", heat.size, int(255 * alpha))
 
99
  draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
100
  return out
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  # ----------------------------
103
+ # Feature extraction
104
  # ----------------------------
105
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    """Run DINOv3 on one image and return L2-normalized patch features.

    Args:
        image_pil: input PIL image.
        target_long_side: target size for the longer side before each side
            is rounded up to a multiple of PATCH_SIZE by resize_to_grid.
        mdl: optional model override; defaults to the module-level ``model``.

    Returns:
        dict with:
            "X":  (Hp*Wp, D) L2-normalized patch embeddings (dot == cosine),
            "Hp", "Wp": patch-grid height and width,
            "img": the resized PIL image actually fed to the model.
    """
    # Explicit None check (not `mdl = mdl or model`): truthiness is the wrong
    # test for a model object, and no `global` declaration is needed to read it.
    if mdl is None:
        mdl = model

    t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, H, W = t_norm.shape
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE

    outputs = mdl(t_norm)
    # Skip [CLS] + 4 register tokens; keep only the per-patch embeddings.
    patch_emb = outputs.last_hidden_state.squeeze(0)[N_SPECIAL_TOKENS:, :]
    X = F.normalize(patch_emb, p=2, dim=-1)  # unit rows → matmul gives cosine sim

    img_resized = TF.to_pil_image(t)
    return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
120
 
121
  # ----------------------------
122
+ # Similarity utilities
123
  # ----------------------------
124
+
125
def index_from_xy(x_pix: int, y_pix: int, Wp: int, Hp: int | None = None) -> int:
    """Map a pixel coordinate to a flat patch index (row * Wp + col).

    Fixes the previous placeholder row computation, which clipped the row
    against itself and produced garbage. The column is clamped to
    [0, Wp - 1]; pass the new optional *Hp* (backward-compatible) to clamp
    the row to [0, Hp - 1] as well.
    """
    col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
    row = int(y_pix // PATCH_SIZE)
    if Hp is not None:
        row = int(np.clip(row, 0, Hp - 1))
    return row * Wp + col
129
+
130
+ # Corrected row/col computation helper
131
+
132
def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    """Convert a pixel coordinate to its (row, col) patch-grid cell, clamped to the grid."""
    max_row, max_col = Hp - 1, Wp - 1
    col = int(np.clip(x_pix // PATCH_SIZE, 0, max_col))
    row = int(np.clip(y_pix // PATCH_SIZE, 0, max_row))
    return row, col
136
 
137
@torch.inference_mode()
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int, exclude_radius_patches: int = 1):
    """Cosine-similarity map between every patch feature in X and q_vec.

    Args:
        X: (Hp*Wp, D) L2-normalized patch features.
        q_vec: (D,) L2-normalized query feature.
        img_h, img_w: pixel size the upsampled map should match.
        exclude_radius_patches: accepted for API parity only. The exclusion
            mask is NOT applied here — this function has no click location,
            so callers must mask ``sim_map`` around the query themselves.

    Returns:
        (sim_map, sim_up): the raw (Hp, Wp) tensor and its bicubic upsample
        as an (img_h, img_w) numpy array.
    """
    sims = torch.matmul(X, q_vec)  # (Hp*Wp,)
    sim_map = sims.view(Hp, Wp)

    # NOTE(review): the previous revision built a meshgrid/mask here when
    # exclude_radius_patches > 0 but never applied it (dead code); removed.
    # The parameter is kept so existing callers keep working unchanged.

    sim_up = F.interpolate(
        sim_map.unsqueeze(0).unsqueeze(0),
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    ).squeeze().detach().cpu().numpy()
    return sim_map, sim_up
159
+
160
+ # ----------------------------
161
+ # Core: click on image 1 → heatmaps on image 1 and image 2
162
+ # ----------------------------
163
+
164
def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    """Given a click on image 1, render similarity heatmaps/overlays for both images.

    Returns (marked1, heat1, overlay1, heat2, overlay2, max_score_on_image_2),
    or six Nones when either feature state is missing.
    """
    if not state1 or not state2:
        return (None,) * 6

    def _upsample(patch_map: torch.Tensor, h: int, w: int) -> np.ndarray:
        # Bicubic upsample of a patch-grid map to full image resolution.
        four_d = patch_map.unsqueeze(0).unsqueeze(0)
        up = F.interpolate(four_d, size=(h, w), mode="bicubic", align_corners=False)
        return up.squeeze().detach().cpu().numpy()

    feats1, grid_h1, grid_w1, pil1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    feats2, grid_h2, grid_w2, pil2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]
    w1, h1 = pil1.size
    w2, h2 = pil2.size

    # Query vector = the patch under the click on image 1 (clamped to the grid).
    x_click, y_click = click_xy
    col = int(np.clip(x_click // PATCH_SIZE, 0, grid_w1 - 1))
    row = int(np.clip(y_click // PATCH_SIZE, 0, grid_h1 - 1))
    q = feats1[row * grid_w1 + col]  # (D,)

    # --- image 1: cosine map with optional self-exclusion window around the click
    map1 = torch.matmul(feats1, q).view(grid_h1, grid_w1)
    if exclude_radius_patches > 0:
        rows, cols = torch.meshgrid(
            torch.arange(grid_h1, device=map1.device),
            torch.arange(grid_w1, device=map1.device),
            indexing="ij",
        )
        near = ((rows - row).abs() <= exclude_radius_patches) & ((cols - col).abs() <= exclude_radius_patches)
        # NOTE(review): -inf values flow into colorize(); presumably its
        # normalization tolerates them — confirm colorize handles -inf.
        map1 = map1.masked_fill(near, float("-inf"))

    up1 = _upsample(map1, h1, w1)
    heat1 = colorize(up1, cmap_name)
    overlay1 = blend(pil1, heat1, alpha)
    marked1 = draw_crosshair(pil1, int(x_click), int(y_click), radius=PATCH_SIZE // 2)

    # --- image 2: plain cosine map (no exclusion — the click is on image 1)
    map2 = torch.matmul(feats2, q).view(grid_h2, grid_w2)
    up2 = _upsample(map2, h2, w2)
    heat2 = colorize(up2, cmap_name)
    overlay2 = blend(pil2, heat2, alpha)

    return marked1, heat1, overlay1, heat2, overlay2, up2.max().item()
218
 
219
  # ----------------------------
220
+ # Gradio UI
221
  # ----------------------------
222
+ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
223
+ gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
224
+ gr.Markdown(
225
+ "Upload two images and press **▶️ Process**. Click a location on **Image 1** to see similar regions on **both** images."
226
+ )
227
 
228
+ state1 = gr.State()
229
+ state2 = gr.State()
230
 
231
  with gr.Row():
232
+ with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  model_choice = gr.Dropdown(
234
  choices=AVAILABLE_MODELS,
235
  value=DEFAULT_MODEL_ID,
236
  label="Backbone (DINOv3)",
237
+ info="ViTS/16+ is smaller & faster; ViTH/16+ is larger.",
238
  )
239
+ target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Processing resolution")
240
+ alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
241
+ cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
242
+ exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches, only for Image 1)")
243
+ start_btn = gr.Button("▶️ Process both", variant="primary")
244
+
245
+ with gr.Column():
246
+ img1 = gr.Image(label="Image 1 (clickable after processing)", type="pil",
247
+ value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
248
+ img2 = gr.Image(label="Image 2", type="pil",
249
+ value="https://upload.wikimedia.org/wikipedia/commons/9/99/Golden_retriever_eating_pigs_foot.jpg")
250
 
251
  with gr.Row():
252
+ with gr.Column():
253
+ marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
254
+ heat1 = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
255
+ overlay1 = gr.Image(label="Image 1 overlay", interactive=False)
256
+ with gr.Column():
257
+ heat2 = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
258
+ overlay2 = gr.Image(label="Image 2 — overlay", interactive=False)
259
+ score2 = gr.Number(label="Image 2 — max similarity score", precision=6)
260
+
261
+ # Utilities
262
  def _ensure_model(model_id: str):
 
263
  global model, _current_model_id
264
  if model_id != _current_model_id:
265
  model = get_model(model_id)
266
  _current_model_id = model_id
267
 
268
+ # Process button: extract features for both images
269
+ def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
270
+ if im1 is None or im2 is None:
271
+ raise gr.Error("Please provide both images.")
272
  _ensure_model(model_id)
273
+ progress(0, desc="Extracting features")
274
+ st1 = extract_image_features(im1, int(long_side), mdl=model)
275
+ st2 = extract_image_features(im2, int(long_side), mdl=model)
276
+ progress(1, desc="Done")
277
+ return st1["img"], st2["img"], st1, st2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
 
 
 
279
  start_btn.click(
280
+ _run_both,
281
+ inputs=[img1, img2, target_long_side, model_choice],
282
+ outputs=[marked1, overlay2, state1, state2], # show previews in two spots to confirm processing
283
  )
284
 
285
+ # Clicking on Image 1
286
+ def _on_click(st1, st2, a: float, m: str, excl: int, evt: gr.SelectData):
287
+ if not st1 or not st2 or evt is None:
288
+ return (None,)*6
289
+ return click_two_image_similarity(
290
+ st1, st2, click_xy=evt.index,
291
+ exclude_radius_patches=int(excl),
292
+ alpha=float(a), cmap_name=m,
293
+ )
294
 
295
+ marked1.select(
 
296
  _on_click,
297
+ inputs=[state1, state2, alpha, cmap, exclude_r],
298
+ outputs=[marked1, heat1, overlay1, heat2, overlay2, score2],
299
  )
300
 
301
  if __name__ == "__main__":