Rausda6 committed on
Commit
bf826bc
·
verified ·
1 Parent(s): b90fe2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -57
app.py CHANGED
@@ -1,30 +1,34 @@
 
 
 
1
  import os
 
 
2
  import numpy as np
3
  from PIL import Image, ImageDraw
4
 
5
-
6
  import torch
7
  import torch.nn.functional as F
8
  import torchvision.transforms.functional as TF
9
- from transformers import AutoModel # trust_remote_code=True
10
-
11
 
12
  import gradio as gr
13
 
14
-
15
- # --- config
 
16
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
17
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
18
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
19
 
20
-
21
  PATCH_SIZE = 16
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
24
- IMAGENET_STD = (0.229, 0.224, 0.225)
 
25
  N_SPECIAL_TOKENS = 5
26
 
27
- # --- robust colormap import (matplotlib new/old)
28
  try:
29
  from matplotlib import colormaps as _mpl_colormaps
30
  def _get_cmap(name: str):
@@ -34,9 +38,9 @@ except Exception:
34
  def _get_cmap(name: str):
35
  return _cm.get_cmap(name)
36
 
37
- # ----------------------------
38
  # Model loading / cache
39
- # ----------------------------
40
  _model_cache = {}
41
  _current_model_id = None
42
  model = None
@@ -60,12 +64,12 @@ def get_model(model_id: str):
60
  model = get_model(DEFAULT_MODEL_ID)
61
  _current_model_id = DEFAULT_MODEL_ID
62
 
63
- # ----------------------------
64
  # Helpers
65
- # ----------------------------
66
 
67
  def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
68
- """Resize so max(h,w)=long_side with aspect kept; then pad up to multiples of patch.
69
  Return CHW float tensor in [0,1]."""
70
  w, h = img.size
71
  scale = long_side / max(h, w)
@@ -89,7 +93,7 @@ def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Im
89
  out = Image.alpha_composite(base, heat)
90
  return out.convert("RGB")
91
 
92
- def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Image.Image:
93
  r = radius if radius is not None else max(2, PATCH_SIZE // 2)
94
  out = img.copy()
95
  draw = ImageDraw.Draw(out)
@@ -97,12 +101,11 @@ def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Imag
97
  draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
98
  return out
99
 
100
- # ----------------------------
101
  # Feature extraction
102
- # ----------------------------
103
  @torch.inference_mode()
104
  def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
105
- global model
106
  mdl = mdl or model
107
  t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
108
  t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
@@ -116,16 +119,9 @@ def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=No
116
 
117
  return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
118
 
119
- # ----------------------------
120
  # Similarity utilities
121
- # ----------------------------
122
-
123
- def index_from_xy(x_pix: int, y_pix: int, Wp: int) -> int:
124
- col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
125
- row = int(np.clip(y_pix // PATCH_SIZE, 0, (x_pix*0 + y_pix) // PATCH_SIZE)) # placeholder row calc replaced below
126
- return row * Wp + col
127
-
128
- # Corrected row/col computation helper
129
 
130
  def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
131
  col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
@@ -134,19 +130,9 @@ def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
134
 
135
  @torch.inference_mode()
136
  def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
137
- img_h: int, img_w: int, exclude_radius_patches: int = 1):
138
  sims = torch.matmul(X, q_vec) # (Hp*Wp)
139
  sim_map = sims.view(Hp, Wp)
140
-
141
- if exclude_radius_patches > 0:
142
- rr, cc = torch.meshgrid(
143
- torch.arange(Hp, device=sims.device),
144
- torch.arange(Wp, device=sims.device),
145
- indexing="ij",
146
- )
147
- # We'll mask later at the click location per-image if needed
148
- mask_template = (rr * 0) # kept for API parity
149
-
150
  sim_up = F.interpolate(
151
  sim_map.unsqueeze(0).unsqueeze(0),
152
  size=(img_h, img_w),
@@ -155,11 +141,11 @@ def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
155
  ).squeeze().detach().cpu().numpy()
156
  return sim_map, sim_up
157
 
158
- # ----------------------------
159
  # Core: click on image 1 → heatmaps on image 1 and image 2
160
- # ----------------------------
161
 
162
- def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int, int],
163
  exclude_radius_patches: int, alpha: float, cmap_name: str):
164
  if not state1 or not state2:
165
  return (None,)*6
@@ -170,13 +156,13 @@ def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int,
170
  img1_w, img1_h = img1.size
171
  img2_w, img2_h = img2.size
172
 
173
- # Build query vector from clicked patch on image 1
174
  col = int(np.clip(click_xy[0] // PATCH_SIZE, 0, Wp1 - 1))
175
  row = int(np.clip(click_xy[1] // PATCH_SIZE, 0, Hp1 - 1))
176
  idx = row * Wp1 + col
177
  q = X1[idx] # (D,)
178
 
179
- # Similarity on image 1
180
  sims1 = torch.matmul(X1, q)
181
  sim_map1 = sims1.view(Hp1, Wp1)
182
  if exclude_radius_patches > 0:
@@ -199,7 +185,7 @@ def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int,
199
  overlay1 = blend(img1, heat1, alpha)
200
  marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)
201
 
202
- # Similarity on image 2 (no exclusion mask, since click is on image 1)
203
  sims2 = torch.matmul(X2, q)
204
  sim_map2 = sims2.view(Hp2, Wp2)
205
  sim2_up = F.interpolate(
@@ -212,16 +198,14 @@ def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int,
212
  heat2 = colorize(sim2_up, cmap_name)
213
  overlay2 = blend(img2, heat2, alpha)
214
 
215
- return marked1, heat1, overlay1, heat2, overlay2, sim2_up.max().item()
216
 
217
- # ----------------------------
218
  # Gradio UI
219
- # ----------------------------
220
-
221
-
222
  with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
223
  gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
224
- gr.Markdown("Upload two images, process, then click on image 1 to see similarities on both.")
225
 
226
  state1 = gr.State()
227
  state2 = gr.State()
@@ -229,17 +213,52 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarit
229
  with gr.Row():
230
  with gr.Column():
231
  model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
232
- target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution")
233
  alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
234
  cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
235
- exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius")
236
  start_btn = gr.Button("▶️ Process both", variant="primary")
237
 
238
  with gr.Column():
239
- img1 = gr.Image(label="Image 1 (clickable)", type="pil", value=None)
240
- img2 = gr.Image(label="Image 2", type="pil", value=None)
241
-
242
- # (rest of app: outputs, event wiring, functions, unchanged)
243
 
244
- if __name__ == "__main__":
245
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — DINOv3 two‑image patch similarity (click on Image 1 → show similarities on both images)
2
+ # Runs on CPU or CUDA. No external image URLs.
3
+
4
  import os
5
+ from typing import Tuple
6
+
7
  import numpy as np
8
  from PIL import Image, ImageDraw
9
 
 
10
  import torch
11
  import torch.nn.functional as F
12
  import torchvision.transforms.functional as TF
13
+ from transformers import AutoModel # trust_remote_code=True
 
14
 
15
  import gradio as gr
16
 
17
+ # ============================
18
+ # Config
19
+ # ============================
20
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
21
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
22
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
23
 
 
24
  PATCH_SIZE = 16
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
27
+ IMAGENET_STD = (0.229, 0.224, 0.225)
28
+ # Many DINOv3 HF ports expose 1 [CLS] + 4 registers at the front
29
  N_SPECIAL_TOKENS = 5
30
 
31
+ # Robust colormap import (Matplotlib new/old)
32
  try:
33
  from matplotlib import colormaps as _mpl_colormaps
34
  def _get_cmap(name: str):
 
38
  def _get_cmap(name: str):
39
  return _cm.get_cmap(name)
40
 
41
+ # ============================
42
  # Model loading / cache
43
+ # ============================
44
  _model_cache = {}
45
  _current_model_id = None
46
  model = None
 
64
  model = get_model(DEFAULT_MODEL_ID)
65
  _current_model_id = DEFAULT_MODEL_ID
66
 
67
+ # ============================
68
  # Helpers
69
+ # ============================
70
 
71
  def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
72
+ """Resize so max(h,w)=long_side with aspect kept; then pad to multiples of patch.
73
  Return CHW float tensor in [0,1]."""
74
  w, h = img.size
75
  scale = long_side / max(h, w)
 
93
  out = Image.alpha_composite(base, heat)
94
  return out.convert("RGB")
95
 
96
+ def draw_crosshair(img: Image.Image, x: int, y: int, radius: int | None = None) -> Image.Image:
97
  r = radius if radius is not None else max(2, PATCH_SIZE // 2)
98
  out = img.copy()
99
  draw = ImageDraw.Draw(out)
 
101
  draw.line([(x, y - r), (x, y + r)], fill="red", width=3)
102
  return out
103
 
104
+ # ============================
105
  # Feature extraction
106
+ # ============================
107
  @torch.inference_mode()
108
  def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
 
109
  mdl = mdl or model
110
  t = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
111
  t_norm = TF.normalize(t, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
 
119
 
120
  return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
121
 
122
+ # ============================
123
  # Similarity utilities
124
+ # ============================
 
 
 
 
 
 
 
125
 
126
  def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
127
  col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
 
130
 
131
  @torch.inference_mode()
132
  def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
133
+ img_h: int, img_w: int):
134
  sims = torch.matmul(X, q_vec) # (Hp*Wp)
135
  sim_map = sims.view(Hp, Wp)
 
 
 
 
 
 
 
 
 
 
136
  sim_up = F.interpolate(
137
  sim_map.unsqueeze(0).unsqueeze(0),
138
  size=(img_h, img_w),
 
141
  ).squeeze().detach().cpu().numpy()
142
  return sim_map, sim_up
143
 
144
+ # ============================
145
  # Core: click on image 1 → heatmaps on image 1 and image 2
146
+ # ============================
147
 
148
+ def click_two_image_similarity(state1: dict, state2: dict, click_xy: Tuple[int, int],
149
  exclude_radius_patches: int, alpha: float, cmap_name: str):
150
  if not state1 or not state2:
151
  return (None,)*6
 
156
  img1_w, img1_h = img1.size
157
  img2_w, img2_h = img2.size
158
 
159
+ # Query vector from clicked patch on image 1
160
  col = int(np.clip(click_xy[0] // PATCH_SIZE, 0, Wp1 - 1))
161
  row = int(np.clip(click_xy[1] // PATCH_SIZE, 0, Hp1 - 1))
162
  idx = row * Wp1 + col
163
  q = X1[idx] # (D,)
164
 
165
+ # Similarity on image 1 (+ small exclusion mask around click if requested)
166
  sims1 = torch.matmul(X1, q)
167
  sim_map1 = sims1.view(Hp1, Wp1)
168
  if exclude_radius_patches > 0:
 
185
  overlay1 = blend(img1, heat1, alpha)
186
  marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)
187
 
188
+ # Similarity on image 2
189
  sims2 = torch.matmul(X2, q)
190
  sim_map2 = sims2.view(Hp2, Wp2)
191
  sim2_up = F.interpolate(
 
198
  heat2 = colorize(sim2_up, cmap_name)
199
  overlay2 = blend(img2, heat2, alpha)
200
 
201
+ return marked1, heat1, overlay1, heat2, overlay2, float(sim2_up.max())
202
 
203
+ # ============================
204
  # Gradio UI
205
+ # ============================
 
 
206
  with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
207
  gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
208
+ gr.Markdown("Upload two images and press **Process both**. Then click on **Image 1** to see similar regions on **both** images.")
209
 
210
  state1 = gr.State()
211
  state2 = gr.State()
 
213
  with gr.Row():
214
  with gr.Column():
215
  model_choice = gr.Dropdown(choices=AVAILABLE_MODELS, value=DEFAULT_MODEL_ID, label="Backbone")
216
+ target_long_side = gr.Slider(224, 1024, value=768, step=16, label="Resolution (long side)")
217
  alpha = gr.Slider(0.0, 1.0, value=0.55, step=0.05, label="Overlay opacity")
218
  cmap = gr.Dropdown(["viridis", "magma", "plasma", "inferno", "turbo", "cividis"], value="viridis", label="Colormap")
219
+ exclude_r = gr.Slider(0, 10, value=0, step=1, label="Exclude radius (patches) for Image 1")
220
  start_btn = gr.Button("▶️ Process both", variant="primary")
221
 
222
  with gr.Column():
223
+ img1 = gr.Image(label="Image 1 (clickable)", type="pil", sources=["upload", "clipboard"], value=None)
224
+ img2 = gr.Image(label="Image 2", type="pil", sources=["upload", "clipboard"], value=None)
 
 
225
 
226
+ with gr.Row():
227
+ with gr.Column():
228
+ marked1 = gr.Image(label="Image 1 — click marker / preview", interactive=False)
229
+ heat1 = gr.Image(label="Image 1 — similarity heatmap", interactive=False)
230
+ overlay1= gr.Image(label="Image 1 — overlay", interactive=False)
231
+ with gr.Column():
232
+ heat2 = gr.Image(label="Image 2 — similarity heatmap", interactive=False)
233
+ overlay2= gr.Image(label="Image 2 — overlay", interactive=False)
234
+ score2 = gr.Number(label="Image 2 — max similarity score", precision=6)
235
+
236
+ # Utilities
237
+ def _ensure_model(model_id: str):
238
+ global model, _current_model_id
239
+ if model_id != _current_model_id:
240
+ model = get_model(model_id)
241
+ _current_model_id = model_id
242
+
243
+ # Process button → extract features for both images and store in state
244
+ def _run_both(im1: Image.Image, im2: Image.Image, long_side: int, model_id: str, progress=gr.Progress(track_tqdm=False)):
245
+ if im1 is None or im2 is None:
246
+ raise gr.Error("Please provide both images before processing.")
247
+ _ensure_model(model_id)
248
+ progress(0, desc="Extracting features for Image 1…")
249
+ st1 = extract_image_features(im1, int(long_side), mdl=model)
250
+ progress(0.5, desc="Extracting features for Image 2…")
251
+ st2 = extract_image_features(im2, int(long_side), mdl=model)
252
+ progress(1, desc="Done")
253
+ # Show quick previews to confirm processing
254
+ return st1["img"], st2["img"], st1, st2
255
+
256
+ start_btn.click(
257
+ _run_both,
258
+ inputs=[img1, img2, target_long_side, model_choice],
259
+ outputs=[marked1, overlay2, state1, state2],
260
+ )
261
+
262
+ # Clicking on Image 1 → compute similarities on both images
263
+ def _on_click(st1, st2, a: float, m: str, excl: int, evt: gr.SelectData):
264
+ if not st1 or not st2 or evt is