pszemraj commited on
Commit
252fe2c
·
verified ·
1 Parent(s): 43e8613

refactor to satellite-based dinoV3

Browse files

note this was done by claude, potentially overkill

Files changed (1) hide show
  1. app.py +286 -52
app.py CHANGED
@@ -1,74 +1,167 @@
 
 
 
 
 
1
  import numpy as np
 
2
  import torch
3
  import torch.nn.functional as F
4
  from PIL import Image, ImageOps
5
- import matplotlib.cm as cm
6
- import gradio as gr
7
  from transformers import AutoImageProcessor, AutoModel
8
- import spaces
9
 
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  MODEL_MAP = {
13
- "DINOv3 ViT+ Huge/16 (841M)": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
14
- "DINOv2 ViT Large/16 (303M)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
15
- "DINOv2 ConvNeXt Small (49.5M)": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
16
  }
 
17
  DEFAULT_NAME = list(MODEL_MAP.keys())[0]
18
 
 
19
  processor = None
20
  model = None
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def load_model(name):
 
23
  global processor, model
24
- model_id = MODEL_MAP[name]
25
- processor = AutoImageProcessor.from_pretrained(model_id)
26
- model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float32).to(DEVICE).eval()
27
- return f"Loaded: {name} → {model_id}"
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  load_model(DEFAULT_NAME)
30
 
31
- @spaces.GPU()
 
32
  def _extract_grid(img):
 
33
  with torch.inference_mode():
34
- pv = processor(images=img, return_tensors="pt").pixel_values.to(DEVICE)
 
 
 
 
35
  out = model(pixel_values=pv)
36
  last = out.last_hidden_state[0].to(torch.float32)
 
37
  num_reg = getattr(model.config, "num_register_tokens", 0)
38
  p = model.config.patch_size
39
  _, _, Ht, Wt = pv.shape
40
  gh, gw = Ht // p, Wt // p
41
- feats = last[1 + num_reg:, :].reshape(gh, gw, -1).cpu()
 
 
42
  return feats, gh, gw
43
 
 
44
  def _overlay(orig, heat01, alpha=0.55, box=None):
 
45
  H, W = orig.height, orig.width
46
- heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
47
- rgba = (cm.get_cmap("inferno")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
48
- ov = Image.fromarray(rgba, "RGBA"); ov.putalpha(int(alpha * 255))
 
 
 
 
 
 
49
  base = orig.copy().convert("RGBA")
50
  out = Image.alpha_composite(base, ov)
 
51
  if box:
52
  from PIL import ImageDraw
53
- ImageDraw.Draw(out, "RGBA").rectangle(box, outline=(255, 255, 255, 220), width=2)
 
 
 
 
 
 
 
 
 
54
  return out
55
 
 
56
  def prepare(img):
 
57
  if img is None:
58
  return None
 
59
  base = ImageOps.exif_transpose(img.convert("RGB"))
60
  feats, gh, gw = _extract_grid(base)
 
61
  return {"orig": base, "feats": feats, "gh": gh, "gw": gw}
62
 
63
 
64
- def click(state, opacity, img_value, evt: gr.SelectData):
65
- # If state wasn't prepared (e.g., Example selection), build it now
 
66
  if state is None and img_value is not None:
67
  state = prepare(img_value)
68
 
69
  if not state or evt.index is None:
70
- # Just show whatever is currently in the image component
71
- return img_value, state
72
 
73
  base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
74
 
@@ -85,52 +178,193 @@ def click(state, opacity, img_value, evt: gr.SelectData):
85
  smin, smax = float(sims.min()), float(sims.max())
86
  heat01 = (sims - smin) / (smax - smin + 1e-12)
87
 
 
 
 
 
 
 
 
88
  box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
89
- overlay = _overlay(base, heat01, alpha=opacity, box=box)
90
- return overlay, state
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  def reset():
95
- return None, None
96
-
97
- with gr.Blocks() as demo:
98
- gr.Markdown("## DINOv3: patch similarity visualizer")
99
- gr.Markdown("This is an app where you can upload an image, click on an object in the image and get the most similar patches to it according to DINOv3, revealing the way DINOv3 segments objects through features natively.")
100
- gr.Markdown("There's multiple model options you can pick from the dropdown.")
101
- gr.Markdown("Please click Reset before you want to upload another image, as this app keeps features of the images.")
102
-
103
- with gr.Column():
104
- with gr.Row(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  model_choice = gr.Dropdown(
106
- choices=list(MODEL_MAP.keys()), value=DEFAULT_NAME, label="Model"
 
 
 
107
  )
108
- status = gr.Textbox(label="Status", value=f"Loaded: {DEFAULT_NAME}", interactive=False)
109
- opacity = gr.Slider(0.0, 1.0, 0.55, step=0.05, label="Opacity for the Map")
110
 
111
- with gr.Row(scale=1):
112
- img = gr.Image(type="pil", label="Image", interactive=True, height=750, width=750)
 
 
 
 
113
 
 
 
 
 
 
 
 
 
 
114
 
115
- state = gr.State()
 
 
 
 
 
116
 
117
- model_choice.change(load_model, inputs=model_choice, outputs=status)
118
 
119
- img.upload(prepare, inputs=img, outputs=state)
 
 
120
 
121
- img.select(click, inputs=[state, opacity, img], outputs=[img, state], show_progress="minimal")
 
 
 
 
 
 
 
 
122
 
 
123
 
124
- gr.Button("Reset").click(reset, outputs=[img, state])
125
  gr.Examples(
126
- examples=[["flowers.PNG"], ["kedis.JPG"], ["kyoto.jpg"], ["nara.JPG"]],
127
- inputs=img,
128
- fn=prepare,
129
- outputs=[img,state],
130
- label="Try an example image and then click on the objects.",
131
- examples_per_page=4,
132
- cache_examples=False,
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
- demo.launch()
 
1
+ import gc
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ import matplotlib.cm as cm
6
  import numpy as np
7
+ import spaces
8
  import torch
9
  import torch.nn.functional as F
10
  from PIL import Image, ImageOps
 
 
11
  from transformers import AutoImageProcessor, AutoModel
 
12
 
13
+ # Device configuration with memory management
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  MODEL_MAP = {
17
+ "DINOv3 ViT-L/16 Satellite": "facebook/dinov3-vitl16-pretrain-sat493m",
18
+ "DINOv3 ViT-L/16 LVD (General Web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
19
+ "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
20
  }
21
+
22
  DEFAULT_NAME = list(MODEL_MAP.keys())[0]
23
 
24
+ # Global model state
25
  processor = None
26
  model = None
27
 
28
+
29
+ def cleanup_memory():
30
+ """Aggressive memory cleanup for model switching"""
31
+ global processor, model
32
+
33
+ if model is not None:
34
+ del model
35
+ model = None
36
+
37
+ if processor is not None:
38
+ del processor
39
+ processor = None
40
+
41
+ gc.collect()
42
+
43
+ if torch.cuda.is_available():
44
+ torch.cuda.empty_cache()
45
+ torch.cuda.synchronize()
46
+
47
+
48
  def load_model(name):
49
+ """Load model with proper memory management and dtype handling"""
50
  global processor, model
 
 
 
 
51
 
52
+ try:
53
+ # Clean up existing model
54
+ cleanup_memory()
55
+
56
+ model_id = MODEL_MAP[name]
57
+
58
+ # Load with auto dtype for optimal performance
59
+ processor = AutoImageProcessor.from_pretrained(model_id)
60
+
61
+ # Determine optimal dtype based on model size and hardware
62
+ if "7b" in model_id.lower() and torch.cuda.is_available():
63
+ # For 7B model, use bfloat16 if available for memory efficiency
64
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
65
+ else:
66
+ dtype = torch.float32
67
+
68
+ model = AutoModel.from_pretrained(
69
+ model_id,
70
+ torch_dtype=dtype,
71
+ device_map="auto" if torch.cuda.is_available() else None,
72
+ )
73
+
74
+ if DEVICE == "cuda" and not hasattr(model, "device_map"):
75
+ model = model.to(DEVICE)
76
+
77
+ model.eval()
78
+
79
+ # Get model info
80
+ param_count = sum(p.numel() for p in model.parameters()) / 1e9
81
+ dtype_str = str(dtype).split(".")[-1]
82
+
83
+ return f"✅ Loaded: {name} | {param_count:.1f}B params | {dtype_str} | {DEVICE.upper()}"
84
+
85
+ except Exception as e:
86
+ cleanup_memory()
87
+ return f"❌ Failed to load {name}: {str(e)}"
88
+
89
+
90
+ # Initialize default model
91
  load_model(DEFAULT_NAME)
92
 
93
+
94
+ @spaces.GPU(duration=60)
95
  def _extract_grid(img):
96
+ """Extract feature grid from image"""
97
  with torch.inference_mode():
98
+ pv = processor(images=img, return_tensors="pt").pixel_values
99
+
100
+ if DEVICE == "cuda":
101
+ pv = pv.to(DEVICE)
102
+
103
  out = model(pixel_values=pv)
104
  last = out.last_hidden_state[0].to(torch.float32)
105
+
106
  num_reg = getattr(model.config, "num_register_tokens", 0)
107
  p = model.config.patch_size
108
  _, _, Ht, Wt = pv.shape
109
  gh, gw = Ht // p, Wt // p
110
+
111
+ feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
112
+
113
  return feats, gh, gw
114
 
115
+
116
  def _overlay(orig, heat01, alpha=0.55, box=None):
117
+ """Create heatmap overlay with improved visualization"""
118
  H, W = orig.height, orig.width
119
+ heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize(
120
+ (W, H), resample=Image.LANCZOS
121
+ )
122
+
123
+ # Use a better colormap for satellite imagery
124
+ rgba = (cm.get_cmap("turbo")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
125
+ ov = Image.fromarray(rgba, "RGBA")
126
+ ov.putalpha(int(alpha * 255))
127
+
128
  base = orig.copy().convert("RGBA")
129
  out = Image.alpha_composite(base, ov)
130
+
131
  if box:
132
  from PIL import ImageDraw
133
+
134
+ draw = ImageDraw.Draw(out, "RGBA")
135
+ # Enhanced box visualization
136
+ draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
137
+ draw.rectangle(
138
+ (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
139
+ outline=(0, 0, 0, 200),
140
+ width=1,
141
+ )
142
+
143
  return out
144
 
145
+
146
  def prepare(img):
147
+ """Prepare image and extract features"""
148
  if img is None:
149
  return None
150
+
151
  base = ImageOps.exif_transpose(img.convert("RGB"))
152
  feats, gh, gw = _extract_grid(base)
153
+
154
  return {"orig": base, "feats": feats, "gh": gh, "gw": gw}
155
 
156
 
157
+ def click(state, opacity, colormap, img_value, evt: gr.SelectData):
158
+ """Handle click events for similarity visualization"""
159
+ # If state wasn't prepared, build it now
160
  if state is None and img_value is not None:
161
  state = prepare(img_value)
162
 
163
  if not state or evt.index is None:
164
+ return img_value, state, None
 
165
 
166
  base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
167
 
 
178
  smin, smax = float(sims.min()), float(sims.max())
179
  heat01 = (sims - smin) / (smax - smin + 1e-12)
180
 
181
+ # Update colormap dynamically
182
+ cm_func = cm.get_cmap(colormap.lower())
183
+ rgba = (cm_func(heat01) * 255).astype(np.uint8)
184
+ ov = Image.fromarray(rgba, "RGBA")
185
+ ov.putalpha(int(opacity * 255))
186
+
187
+ base_rgba = base.copy().convert("RGBA")
188
  box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
 
 
189
 
190
+ out = Image.alpha_composite(base_rgba, ov)
191
+ if box:
192
+ from PIL import ImageDraw
193
+
194
+ draw = ImageDraw.Draw(out, "RGBA")
195
+ draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
196
+ draw.rectangle(
197
+ (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
198
+ outline=(0, 0, 0, 200),
199
+ width=1,
200
+ )
201
+
202
+ # Stats for info panel
203
+ stats = f"""📊 **Similarity Statistics**
204
+ - Min: {smin:.3f}
205
+ - Max: {smax:.3f}
206
+ - Range: {smax - smin:.3f}
207
+ - Patch: ({i}, {j})
208
+ - Grid: {gw}×{gh}"""
209
+
210
+ return out, state, stats
211
 
212
 
213
  def reset():
214
+ """Reset the interface"""
215
+ return None, None, None
216
+
217
+
218
+ # Build the interface
219
+ with gr.Blocks(
220
+ theme=gr.themes.Soft(
221
+ primary_hue="blue",
222
+ secondary_hue="gray",
223
+ neutral_hue="gray",
224
+ font=gr.themes.GoogleFont("Inter"),
225
+ ),
226
+ css="""
227
+ .container {max-width: 1200px; margin: auto;}
228
+ .header {text-align: center; padding: 20px;}
229
+ .info-box {
230
+ background: rgba(0,0,0,0.03);
231
+ border-radius: 8px;
232
+ padding: 12px;
233
+ margin: 10px 0;
234
+ border-left: 4px solid #2563eb;
235
+ }
236
+ """,
237
+ ) as demo:
238
+ gr.HTML(
239
+ """
240
+ <div class="header">
241
+ <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
242
+ <p style="font-size: 1.1em; color: #666;">
243
+ Explore how DINOv3 models trained on satellite imagery understand visual patterns
244
+ </p>
245
+ </div>
246
+ """
247
+ )
248
+
249
+ with gr.Row():
250
+ with gr.Column(scale=3):
251
+ gr.Markdown(
252
+ """
253
+ ### How it works
254
+ 1. **Select a model** - Satellite-pretrained models are optimized for aerial/satellite imagery
255
+ 2. **Upload or select an image** - Works best with satellite, aerial, or outdoor scenes
256
+ 3. **Click any region** - See how similar other patches are to your selection
257
+ 4. **Adjust visualization** - Fine-tune opacity and colormap for clarity
258
+ """
259
+ )
260
+
261
+ with gr.Column(scale=2):
262
+ gr.HTML(
263
+ """
264
+ <div class="info-box">
265
+ <b>💡 Model Info:</b><br>
266
+ • <b>Satellite models</b>: Trained on 493M satellite images<br>
267
+ • <b>LVD model</b>: Trained on 1.7B diverse images<br>
268
+ • <b>7B model</b>: Massive capacity, slower but more nuanced
269
+ </div>
270
+ """
271
+ )
272
+
273
+ with gr.Row():
274
+ with gr.Column(scale=1):
275
  model_choice = gr.Dropdown(
276
+ choices=list(MODEL_MAP.keys()),
277
+ value=DEFAULT_NAME,
278
+ label="🤖 Model Selection",
279
+ info="Satellite models excel at geographic and structural patterns",
280
  )
 
 
281
 
282
+ status = gr.Textbox(
283
+ label="📡 Model Status",
284
+ value=f"Ready: {DEFAULT_NAME}",
285
+ interactive=False,
286
+ lines=1,
287
+ )
288
 
289
+ with gr.Row():
290
+ opacity = gr.Slider(
291
+ 0.2,
292
+ 0.9,
293
+ 0.55,
294
+ step=0.05,
295
+ label="🎨 Heatmap Opacity",
296
+ info="Balance between image and similarity map",
297
+ )
298
 
299
+ colormap = gr.Dropdown(
300
+ choices=["Turbo", "Inferno", "Viridis", "Plasma", "Magma", "Jet"],
301
+ value="Turbo",
302
+ label="🌈 Colormap",
303
+ info="Different maps for different contrasts",
304
+ )
305
 
306
+ info_panel = gr.Markdown(value=None, label="Statistics", visible=True)
307
 
308
+ with gr.Row():
309
+ reset_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)
310
+ clear_btn = gr.ClearButton(value="🗑️ Clear All", scale=1)
311
 
312
+ with gr.Column(scale=2):
313
+ img = gr.Image(
314
+ type="pil",
315
+ label="Interactive Canvas (Click to explore)",
316
+ interactive=True,
317
+ height=600,
318
+ show_download_button=True,
319
+ show_share_button=False,
320
+ )
321
 
322
+ state = gr.State()
323
 
324
+ # Examples focused on satellite-relevant imagery
325
  gr.Examples(
326
+ examples=[
327
+ [_filepath.name]
328
+ for _filepath in Path.cwd().iterdir()
329
+ if _filepath.suffix.lower() in [".jpg", ".png", ".webp"]
330
+ ],
331
+ inputs=img,
332
+ fn=prepare,
333
+ outputs=[state],
334
+ label="Example Images",
335
+ examples_per_page=6,
336
+ cache_examples=False,
337
+ )
338
+
339
+ # Event handlers
340
+ model_choice.change(
341
+ load_model, inputs=model_choice, outputs=status, show_progress="full"
342
+ )
343
+
344
+ img.upload(prepare, inputs=img, outputs=state, show_progress="minimal")
345
+
346
+ img.select(
347
+ click,
348
+ inputs=[state, opacity, colormap, img],
349
+ outputs=[img, state, info_panel],
350
+ show_progress="minimal",
351
+ )
352
+
353
+ reset_btn.click(reset, outputs=[img, state, info_panel], show_progress=False)
354
+
355
+ clear_btn.add([img, state, info_panel])
356
+
357
+ gr.Markdown(
358
+ """
359
+ ---
360
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
361
+ <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
362
+ and structural analysis. The 7B model provides exceptional detail but requires significant compute.
363
+ <br><br>
364
+ Built with DINOv3 | Optimized for satellite and aerial imagery analysis
365
+ </div>
366
+ """
367
+ )
368
 
369
  if __name__ == "__main__":
370
+ demo.launch(share=False, show_error=True)