Rausda6 committed on
Commit
b0a5be5
·
verified ·
1 Parent(s): 5356d23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -13
app.py CHANGED
@@ -1,31 +1,216 @@
1
- # The error comes from trying to set a remote image URL (`value=...`) in `gr.Image`, which Gradio tries to download and cache.
2
- # In Spaces with restricted networking, this fails with 404. Fix: use `value=None` or a local placeholder.
3
-
4
- import os
5
- import numpy as np
6
- from PIL import Image, ImageDraw
7
-
8
- import torch
9
  import torch.nn.functional as F
10
  import torchvision.transforms.functional as TF
11
- from transformers import AutoModel # trust_remote_code=True
 
12
 
13
  import gradio as gr
14
 
 
15
  # --- config
16
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
17
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
18
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
19
 
 
20
  PATCH_SIZE = 16
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
23
- IMAGENET_STD = (0.229, 0.224, 0.225)
24
  N_SPECIAL_TOKENS = 5
25
 
26
- # (rest of code identical to previous, omitted here for brevity)
27
- # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
29
  with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
30
  gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
31
  gr.Markdown("Upload two images, process, then click on image 1 to see similarities on both.")
@@ -49,4 +234,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarit
49
  # (rest of app: outputs, event wiring, functions, unchanged)
50
 
51
  if __name__ == "__main__":
52
- demo.launch()
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw
from transformers import AutoModel  # trust_remote_code=True
7
 
8
+
9
  # --- config
10
  DEFAULT_MODEL_ID = "facebook/dinov3-vits16plus-pretrain-lvd1689m"
11
  ALT_MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
12
  AVAILABLE_MODELS = [DEFAULT_MODEL_ID, ALT_MODEL_ID]
13
 
14
+
15
  PATCH_SIZE = 16
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
18
+ IMAGENET_STD = (0.229, 0.224, 0.225)
19
  N_SPECIAL_TOKENS = 5
20
 
21
+ # --- robust colormap import (matplotlib new/old)
22
try:
    # Newer Matplotlib exposes the colormap registry as `matplotlib.colormaps`.
    from matplotlib import colormaps as _mpl_colormaps
except ImportError:
    # Older Matplotlib has no registry object; use the legacy lookup function.
    import matplotlib.cm as _cm

    def _get_cmap(name: str):
        """Return the colormap called *name* via the legacy `cm.get_cmap` API."""
        return _cm.get_cmap(name)
else:
    def _get_cmap(name: str):
        """Return the colormap called *name* from the `colormaps` registry."""
        return _mpl_colormaps[name]
30
+
31
+ # ----------------------------
32
+ # Model loading / cache
33
+ # ----------------------------
34
+ _model_cache = {}
35
+ _current_model_id = None
36
+ model = None
37
+
38
def load_model_from_hub(model_id: str):
    """Download *model_id* with `AutoModel`, move it to `DEVICE`, and return it in eval mode.

    The access token is read from the ``HF_TOKEN`` environment variable
    (``None`` when unset).
    """
    print(f"Loading model '{model_id}' from HF Hub…")
    hf_token = os.environ.get("HF_TOKEN")
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only use with model ids you trust.
    loaded = AutoModel.from_pretrained(
        model_id,
        token=hf_token,
        trust_remote_code=True,
    )
    loaded.to(DEVICE).eval()
    print(f"✅ Loaded '{model_id}' on {DEVICE}")
    return loaded
45
+
46
def get_model(model_id: str):
    """Return the model for *model_id*, loading it on first use.

    Loaded models are memoised in the module-level ``_model_cache`` dict so
    each id is fetched at most once per process.
    """
    if model_id not in _model_cache:
        _model_cache[model_id] = load_model_from_hub(model_id)
    return _model_cache[model_id]
52
+
53
+ # Load default at startup
54
+ model = get_model(DEFAULT_MODEL_ID)
55
+ _current_model_id = DEFAULT_MODEL_ID
56
+
57
+ # ----------------------------
58
+ # Helpers
59
+ # ----------------------------
60
+
61
def resize_to_grid(img: Image.Image, long_side: int, patch: int = PATCH_SIZE) -> torch.Tensor:
    """Resize *img* so both sides are multiples of *patch*.

    The image is scaled so that its longer side is about *long_side*
    (aspect ratio kept), then each side is rounded UP to the next multiple of
    *patch* (never below one patch). The final size is reached by resizing,
    not padding, so the aspect ratio may shift slightly.

    Returns the RGB image converted with ``TF.to_tensor``.
    """
    width, height = img.size
    factor = long_side / max(height, width)

    def _snap(extent: int) -> int:
        # Scale, enforce a minimum of one patch, then round up to the grid.
        scaled = max(patch, int(round(extent * factor)))
        return ((scaled + patch - 1) // patch) * patch

    target_h = _snap(height)
    target_w = _snap(width)
    rgb = img.convert("RGB")
    return TF.to_tensor(TF.resize(rgb, (target_h, target_w)))
71
+
72
def colorize(sim_map_up: np.ndarray, cmap_name: str = "viridis") -> Image.Image:
    """Map a 2-D similarity array to an RGB heatmap image.

    Values are min-max normalised to [0, 1] and passed through the colormap
    named *cmap_name*. Normalisation uses only the finite entries, so a map
    containing ``NaN`` or ``±inf`` (for example from a masked region that was
    interpolated) still yields a usable image: ``NaN`` and ``-inf`` are shown
    as the lowest colour, ``+inf`` as the highest.
    """
    x = np.asarray(sim_map_up, dtype=np.float32)
    finite = np.isfinite(x)
    if finite.any():
        lo = x[finite].min()
        hi = x[finite].max()
    else:
        # Nothing usable to scale by: render a flat map.
        lo = hi = np.float32(0.0)
    x = np.nan_to_num(x, nan=lo, posinf=hi, neginf=lo)
    x = (x - lo) / (hi - lo + 1e-6)
    rgb = (_get_cmap(cmap_name)(x)[..., :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)
77
+
78
def blend(base: Image.Image, heat: Image.Image, alpha: float = 0.55) -> Image.Image:
    """Overlay *heat* on *base* with uniform opacity *alpha* and return an RGB image.

    Both inputs are converted to RGBA first; the heat layer receives a
    constant alpha channel of ``int(255 * alpha)`` before compositing.
    """
    background = base.convert("RGBA")
    overlay = heat.convert("RGBA")
    overlay.putalpha(Image.new("L", overlay.size, int(255 * alpha)))
    return Image.alpha_composite(background, overlay).convert("RGB")
85
+
86
def draw_crosshair(img: Image.Image, x: int, y: int, radius: int = None) -> Image.Image:
    """Return a copy of *img* with a red cross centred on ``(x, y)``.

    Each arm extends *radius* pixels from the centre; when *radius* is
    ``None`` it defaults to ``max(2, PATCH_SIZE // 2)``. The input image is
    not modified.
    """
    half = max(2, PATCH_SIZE // 2) if radius is None else radius
    marked = img.copy()
    pen = ImageDraw.Draw(marked)
    arms = (
        ((x - half, y), (x + half, y)),  # horizontal stroke
        ((x, y - half), (x, y + half)),  # vertical stroke
    )
    for start, end in arms:
        pen.line([start, end], fill="red", width=3)
    return marked
93
+
94
+ # ----------------------------
95
+ # Feature extraction
96
+ # ----------------------------
97
@torch.inference_mode()
def extract_image_features(image_pil: Image.Image, target_long_side: int, mdl=None):
    """Compute L2-normalised patch embeddings for one image.

    Args:
        image_pil: Input image.
        target_long_side: Target length passed to ``resize_to_grid``.
        mdl: Model to run; when ``None`` the module-level ``model`` is used.

    Returns:
        A dict with keys ``"X"`` (normalised patch embeddings), ``"Hp"`` and
        ``"Wp"`` (patch-grid height/width, i.e. image size // PATCH_SIZE) and
        ``"img"`` (the resized image as a PIL image).
    """
    # Explicit None check: truthiness of a module object is not a reliable
    # "was an argument supplied" test.
    net = model if mdl is None else mdl

    pixels = resize_to_grid(image_pil, target_long_side, PATCH_SIZE)
    batch = TF.normalize(pixels, IMAGENET_MEAN, IMAGENET_STD).unsqueeze(0).to(DEVICE)
    _, _, height, width = batch.shape
    Hp, Wp = height // PATCH_SIZE, width // PATCH_SIZE

    outputs = net(batch)
    # Drop the leading special tokens; the remainder is presumably one token
    # per patch in row-major order (Hp*Wp, D) — TODO confirm for each model.
    patch_emb = outputs.last_hidden_state.squeeze(0)[N_SPECIAL_TOKENS:, :]
    X = F.normalize(patch_emb, p=2, dim=-1)  # unit length → dot product == cosine
    img_resized = TF.to_pil_image(pixels)

    return {"X": X, "Hp": Hp, "Wp": Wp, "img": img_resized}
112
+
113
+ # ----------------------------
114
+ # Similarity utilities
115
+ # ----------------------------
116
+
117
def index_from_xy(x_pix: int, y_pix: int, Wp: int, Hp: "int | None" = None) -> int:
    """Convert a pixel coordinate to a flat row-major patch index.

    Args:
        x_pix: Horizontal pixel position.
        y_pix: Vertical pixel position.
        Wp: Number of patch columns; the column is clamped to ``[0, Wp - 1]``.
        Hp: Optional number of patch rows. When given, the row is also clamped
            to ``Hp - 1``; when omitted only the lower bound (0) is enforced.

    Returns:
        ``row * Wp + col``.
    """
    col = int(np.clip(x_pix // PATCH_SIZE, 0, Wp - 1))
    # Never let a negative y produce a negative row (and thus a wrapped index).
    row = max(0, int(y_pix // PATCH_SIZE))
    if Hp is not None:
        row = min(row, Hp - 1)
    return row * Wp + col
121
+
122
+ # Corrected row/col computation helper
123
+
124
def row_col_from_xy(x_pix: int, y_pix: int, Hp: int, Wp: int):
    """Return the ``(row, col)`` of the patch containing pixel ``(x_pix, y_pix)``.

    Both coordinates are divided by ``PATCH_SIZE`` and clamped to the grid:
    rows to ``[0, Hp - 1]`` and columns to ``[0, Wp - 1]``.
    """
    patch_col = x_pix // PATCH_SIZE
    patch_row = y_pix // PATCH_SIZE
    clamped_col = int(min(max(patch_col, 0), Wp - 1))
    clamped_row = int(min(max(patch_row, 0), Hp - 1))
    return clamped_row, clamped_col
128
+
129
@torch.inference_mode()
def similarity_map(X: torch.Tensor, Hp: int, Wp: int, q_vec: torch.Tensor,
                   img_h: int, img_w: int, exclude_radius_patches: int = 1):
    """Score every patch against a query vector and upsample to image size.

    Args:
        X: Patch embeddings; ``X @ q_vec`` must yield ``Hp * Wp`` values.
        Hp: Patch-grid height.
        Wp: Patch-grid width.
        q_vec: Query embedding.
        img_h: Output height in pixels.
        img_w: Output width in pixels.
        exclude_radius_patches: Accepted for API compatibility only; this
            function applies no masking (callers mask around the click
            location themselves).

    Returns:
        ``(sim_map, sim_up)``: the ``(Hp, Wp)`` similarity tensor and its
        bicubic upsampling to ``(img_h, img_w)`` as a NumPy array.
    """
    sims = torch.matmul(X, q_vec)  # (Hp*Wp)
    sim_map = sims.view(Hp, Wp)

    upsampled = F.interpolate(
        sim_map.unsqueeze(0).unsqueeze(0),
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    # Index away exactly the two dims added above; a bare squeeze() would also
    # drop a genuine size-1 height or width.
    sim_up = upsampled[0, 0].detach().cpu().numpy()
    return sim_map, sim_up
151
+
152
+ # ----------------------------
153
+ # Core: click on image 1 → heatmaps on image 1 and image 2
154
+ # ----------------------------
155
+
156
def _upsample_similarity(sim_map: torch.Tensor, img_h: int, img_w: int) -> np.ndarray:
    """Bicubically resize a 2-D similarity grid to ``(img_h, img_w)`` as a NumPy array."""
    upsampled = F.interpolate(
        sim_map.unsqueeze(0).unsqueeze(0),
        size=(img_h, img_w),
        mode="bicubic",
        align_corners=False,
    )
    # Drop only the two dims added above so size-1 images stay 2-D.
    return upsampled[0, 0].detach().cpu().numpy()


def click_two_image_similarity(state1: dict, state2: dict, click_xy: tuple[int, int],
                               exclude_radius_patches: int, alpha: float, cmap_name: str):
    """Build similarity heatmaps on both images for a click on image 1.

    The patch under *click_xy* in image 1 supplies the query embedding, which
    is compared against every patch of image 1 and image 2.

    Args:
        state1: Feature dict for image 1 with keys ``"X"``, ``"Hp"``, ``"Wp"``, ``"img"``.
        state2: Feature dict for image 2, same keys.
        click_xy: ``(x, y)`` pixel position of the click on image 1.
        exclude_radius_patches: Half-width (in patches) of the square around
            the clicked patch that is suppressed on image 1; ``0`` disables it.
        alpha: Opacity of the heatmap in the overlays.
        cmap_name: Colormap name passed to ``colorize``.

    Returns:
        ``(marked1, heat1, overlay1, heat2, overlay2, max_sim2)``, or six
        ``None`` values when a state or the click position is missing.
    """
    if not state1 or not state2 or click_xy is None:
        return (None,) * 6

    X1, Hp1, Wp1, img1 = state1["X"], state1["Hp"], state1["Wp"], state1["img"]
    X2, Hp2, Wp2, img2 = state2["X"], state2["Hp"], state2["Wp"], state2["img"]

    img1_w, img1_h = img1.size
    img2_w, img2_h = img2.size

    # Build query vector from clicked patch on image 1
    col = int(np.clip(click_xy[0] // PATCH_SIZE, 0, Wp1 - 1))
    row = int(np.clip(click_xy[1] // PATCH_SIZE, 0, Hp1 - 1))
    q = X1[row * Wp1 + col]  # (D,)

    # Similarity on image 1
    sim_map1 = torch.matmul(X1, q).view(Hp1, Wp1)
    if exclude_radius_patches > 0:
        rr, cc = torch.meshgrid(
            torch.arange(Hp1, device=sim_map1.device),
            torch.arange(Wp1, device=sim_map1.device),
            indexing="ij",
        )
        mask1 = (torch.abs(rr - row) <= exclude_radius_patches) & (
            torch.abs(cc - col) <= exclude_radius_patches
        )
        keep = ~mask1
        if keep.any():
            # Suppress the trivially-similar neighbourhood with the lowest
            # remaining score. Filling with -inf would turn into NaN/inf under
            # bicubic interpolation and corrupt the whole heatmap.
            floor = sim_map1[keep].min().item()
            sim_map1 = sim_map1.masked_fill(mask1, floor)
        # If the mask covers the entire grid there is nothing left to compare
        # against, so the map is left unmasked.

    sim1_up = _upsample_similarity(sim_map1, img1_h, img1_w)

    heat1 = colorize(sim1_up, cmap_name)
    overlay1 = blend(img1, heat1, alpha)
    marked1 = draw_crosshair(img1, int(click_xy[0]), int(click_xy[1]), radius=PATCH_SIZE // 2)

    # Similarity on image 2 (no exclusion mask, since click is on image 1)
    sim_map2 = torch.matmul(X2, q).view(Hp2, Wp2)
    sim2_up = _upsample_similarity(sim_map2, img2_h, img2_w)

    heat2 = colorize(sim2_up, cmap_name)
    overlay2 = blend(img2, heat2, alpha)

    return marked1, heat1, overlay1, heat2, overlay2, sim2_up.max().item()
210
 
211
+ # ----------------------------
212
+ # Gradio UI
213
+ # ----------------------------
214
  with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 Two‑Image Patch Similarity") as demo:
215
  gr.Markdown("# DINOv3 Two‑Image Patch Similarity")
216
  gr.Markdown("Upload two images, process, then click on image 1 to see similarities on both.")
 
234
  # (rest of app: outputs, event wiring, functions, unchanged)
235
 
236
  if __name__ == "__main__":
237
+ demo.launch()