pszemraj committed on
Commit
db8cd55
·
verified ·
1 Parent(s): 32a6188

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -40
app.py CHANGED
@@ -14,8 +14,8 @@ from transformers import AutoImageProcessor, AutoModel
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  MODEL_MAP = {
17
- "DINOv3 ViT-L/16 Satellite": "facebook/dinov3-vitl16-pretrain-sat493m",
18
- "DINOv3 ViT-L/16 LVD (web data)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
19
  "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
20
  }
21
 
@@ -27,40 +27,54 @@ model = None
27
 
28
 
29
  def cleanup_memory():
 
 
 
 
 
 
 
 
 
 
30
 
31
  gc.collect()
32
 
33
  if torch.cuda.is_available():
34
  torch.cuda.empty_cache()
 
35
 
36
 
37
  def load_model(name):
38
  """Load model with proper memory management and dtype handling"""
39
  global processor, model
40
 
41
- # Clean up existing model
42
- del model
43
- cleanup_memory()
44
 
45
- model_id = MODEL_MAP[name]
46
 
47
- # Load processor
48
- processor = AutoImageProcessor.from_pretrained(model_id)
49
 
50
- model = (
51
- AutoModel.from_pretrained(
52
- model_id,
53
- torch_dtype="auto",
 
 
 
54
  )
55
- .to(DEVICE)
56
- .eval()
57
- )
58
 
59
- # Get model info
60
- param_count = sum(p.numel() for p in model.parameters()) / 1e9
61
 
62
- return f"Loaded: {name} | {param_count:.1f}B params | {DEVICE.upper()}"
63
 
 
 
 
64
 
65
 
66
  # Initialize default model
@@ -93,7 +107,8 @@ def _overlay(orig, heat01, alpha=0.55, box=None):
93
  """Create heatmap overlay"""
94
  H, W = orig.height, orig.width
95
  heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
96
- rgba = (cm.get_cmap("inferno")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
 
97
  ov = Image.fromarray(rgba, "RGBA")
98
  ov.putalpha(int(alpha * 255))
99
  base = orig.copy().convert("RGBA")
@@ -101,8 +116,13 @@ def _overlay(orig, heat01, alpha=0.55, box=None):
101
  if box:
102
  from PIL import ImageDraw
103
 
104
- ImageDraw.Draw(out, "RGBA").rectangle(
105
- box, outline=(255, 255, 255, 220), width=2
 
 
 
 
 
106
  )
107
  return out
108
 
@@ -153,34 +173,97 @@ def reset():
153
  return None, None
154
 
155
 
156
- with gr.Blocks() as demo:
157
- gr.Markdown("## DINOv3: patch similarity visualizer")
158
- gr.Markdown(
159
- "This is an app where you can upload an image, click on an object in the image and get the most similar patches to it according to DINOv3, revealing the way DINOv3 segments objects through features natively."
160
- )
161
- gr.Markdown("There's multiple model options you can pick from the dropdown.")
162
- gr.Markdown(
163
- "Please click Reset before you want to upload another image, as this app keeps features of the images."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  )
165
 
166
- with gr.Column():
167
- with gr.Row(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  model_choice = gr.Dropdown(
169
- choices=list(MODEL_MAP.keys()), value=DEFAULT_NAME, label="Model"
 
 
 
170
  )
171
  status = gr.Textbox(
172
- label="Status", value=f"Loaded: {DEFAULT_NAME}", interactive=False
 
 
 
173
  )
174
- opacity = gr.Slider(0.0, 1.0, 0.55, step=0.05, label="Opacity for the Map")
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- with gr.Row(scale=1):
177
  img = gr.Image(
178
- type="pil", label="Image", interactive=True, height=750, width=750
 
 
 
 
 
179
  )
180
 
181
  state = gr.State()
182
 
183
- model_choice.change(load_model, inputs=model_choice, outputs=status)
 
 
184
 
185
  img.upload(prepare, inputs=img, outputs=state)
186
 
@@ -191,11 +274,12 @@ with gr.Blocks() as demo:
191
  show_progress="minimal",
192
  )
193
 
194
- gr.Button("Reset").click(reset, outputs=[img, state])
 
195
 
196
  # Examples from current directory
197
  example_files = [
198
- str(f)
199
  for f in Path.cwd().iterdir()
200
  if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
201
  ]
@@ -206,10 +290,22 @@ with gr.Blocks() as demo:
206
  inputs=img,
207
  fn=prepare,
208
  outputs=[state],
209
- label="Try an example image and then click on the objects.",
210
  examples_per_page=4,
211
  cache_examples=False,
212
  )
213
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  if __name__ == "__main__":
215
  demo.launch(share=False, debug=True)
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  MODEL_MAP = {
17
+ "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
18
+ "DINOv3 ViT-L/16 LVD (1.7B web)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
19
  "⚠️ DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
20
  }
21
 
 
27
 
28
 
29
  def cleanup_memory():
30
+ """Aggressive memory cleanup for model switching"""
31
+ global processor, model
32
+
33
+ if model is not None:
34
+ del model
35
+ model = None
36
+
37
+ if processor is not None:
38
+ del processor
39
+ processor = None
40
 
41
  gc.collect()
42
 
43
  if torch.cuda.is_available():
44
  torch.cuda.empty_cache()
45
+ torch.cuda.synchronize()
46
 
47
 
48
  def load_model(name):
49
  """Load model with proper memory management and dtype handling"""
50
  global processor, model
51
 
52
+ try:
53
+ # Clean up existing model
54
+ cleanup_memory()
55
 
56
+ model_id = MODEL_MAP[name]
57
 
58
+ # Load processor
59
+ processor = AutoImageProcessor.from_pretrained(model_id)
60
 
61
+ model = (
62
+ AutoModel.from_pretrained(
63
+ model_id,
64
+ torch_dtype="auto",
65
+ )
66
+ .to(DEVICE)
67
+ .eval()
68
  )
 
 
 
69
 
70
+ # Get model info
71
+ param_count = sum(p.numel() for p in model.parameters()) / 1e9
72
 
73
+ return f"Loaded: {name} | {param_count:.1f}B params | {DEVICE.upper()}"
74
 
75
+ except Exception as e:
76
+ cleanup_memory()
77
+ return f"❌ Failed to load {name}: {str(e)}"
78
 
79
 
80
  # Initialize default model
 
107
  """Create heatmap overlay"""
108
  H, W = orig.height, orig.width
109
  heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
110
+ # Use turbo colormap - better for satellite imagery
111
+ rgba = (cm.get_cmap("turbo")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
112
  ov = Image.fromarray(rgba, "RGBA")
113
  ov.putalpha(int(alpha * 255))
114
  base = orig.copy().convert("RGBA")
 
116
  if box:
117
  from PIL import ImageDraw
118
 
119
+ draw = ImageDraw.Draw(out, "RGBA")
120
+ # Enhanced box visualization
121
+ draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
122
+ draw.rectangle(
123
+ (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
124
+ outline=(0, 0, 0, 200),
125
+ width=1,
126
  )
127
  return out
128
 
 
173
  return None, None
174
 
175
 
176
+ with gr.Blocks(
177
+ theme=gr.themes.Citrus(),
178
+ css="""
179
+ .container {max-width: 1200px; margin: auto;}
180
+ .header {text-align: center; padding: 20px;}
181
+ .info-box {
182
+ background: rgba(0,0,0,0.03);
183
+ border-radius: 8px;
184
+ padding: 12px;
185
+ margin: 10px 0;
186
+ border-left: 4px solid #2563eb;
187
+ }
188
+ """,
189
+ ) as demo:
190
+ gr.HTML(
191
+ """
192
+ <div class="header">
193
+ <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
194
+ <p style="font-size: 1.1em; color: #666;">
195
+ Click any region to visualize feature similarities across the image
196
+ </p>
197
+ </div>
198
+ """
199
  )
200
 
201
+ with gr.Row():
202
+ with gr.Column(scale=3):
203
+ gr.Markdown(
204
+ """
205
+ ### How it works
206
+ 1. **Select a model** - Satellite-pretrained models optimized for aerial/satellite imagery
207
+ 2. **Upload or select an image** - Works best with satellite, aerial, or outdoor scenes
208
+ 3. **Click any region** - See how similar other patches are to your selection
209
+ 4. **Adjust opacity** - Fine-tune visualization clarity
210
+ """
211
+ )
212
+
213
+ with gr.Column(scale=2):
214
+ gr.HTML(
215
+ """
216
+ <div class="info-box">
217
+ <b>💡 Model Info:</b><br>
218
+ • <b>Satellite (493M)</b>: Trained on 493M satellite images<br>
219
+ • <b>LVD (1.7B)</b>: Trained on 1.7B diverse web images<br>
220
+ • <b>7B Satellite</b>: Massive capacity, requires high VRAM
221
+ </div>
222
+ """
223
+ )
224
+
225
+ with gr.Row():
226
+ with gr.Column(scale=1):
227
  model_choice = gr.Dropdown(
228
+ choices=list(MODEL_MAP.keys()),
229
+ value=DEFAULT_NAME,
230
+ label="🤖 Model Selection",
231
+ info="Satellite models excel at geographic and structural patterns",
232
  )
233
  status = gr.Textbox(
234
+ label="📡 Model Status",
235
+ value=f"✅ Loaded: {DEFAULT_NAME}",
236
+ interactive=False,
237
+ lines=1,
238
  )
239
+ opacity = gr.Slider(
240
+ 0.0,
241
+ 1.0,
242
+ 0.55,
243
+ step=0.05,
244
+ label="🎨 Heatmap Opacity",
245
+ info="Balance between image and similarity map",
246
+ )
247
+
248
+ with gr.Row():
249
+ reset_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)
250
+ clear_btn = gr.ClearButton(value="🗑️ Clear All", scale=1)
251
 
252
+ with gr.Column(scale=2):
253
  img = gr.Image(
254
+ type="pil",
255
+ label="Interactive Canvas (Click to explore)",
256
+ interactive=True,
257
+ height=600,
258
+ show_download_button=True,
259
+ show_share_button=False,
260
  )
261
 
262
  state = gr.State()
263
 
264
+ model_choice.change(
265
+ load_model, inputs=model_choice, outputs=status, show_progress="full"
266
+ )
267
 
268
  img.upload(prepare, inputs=img, outputs=state)
269
 
 
274
  show_progress="minimal",
275
  )
276
 
277
+ reset_btn.click(reset, outputs=[img, state])
278
+ clear_btn.add([img, state])
279
 
280
  # Examples from current directory
281
  example_files = [
282
+ f.name
283
  for f in Path.cwd().iterdir()
284
  if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
285
  ]
 
290
  inputs=img,
291
  fn=prepare,
292
  outputs=[state],
293
+ label="Example Images",
294
  examples_per_page=4,
295
  cache_examples=False,
296
  )
297
 
298
+ gr.Markdown(
299
+ """
300
+ ---
301
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
302
+ <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
303
+ and structural analysis. The 7B model provides exceptional detail but requires significant compute.
304
+ <br><br>
305
+ Built with DINOv3 | Optimized for satellite and aerial imagery analysis
306
+ </div>
307
+ """
308
+ )
309
+
310
  if __name__ == "__main__":
311
  demo.launch(share=False, debug=True)