pszemraj committed on
Commit
fe31b95
·
verified ·
1 Parent(s): bded4ff

Dynamic img sizing

Browse files

Key Changes:

Dynamic sizing function (compute_dynamic_size): Computes optimal dimensions up to 720px max while maintaining aspect ratio AND ensuring divisibility by patch_size (16).
Custom preprocessing (preprocess_image): Bypasses the rigid 224×224 resize. DINOv3 was trained with scale augmentation and its RoPE handles this perfectly.
Resolution feedback: Added resolution_info textbox showing actual processing resolution and patch grid dimensions.
Proper normalization: Still applies ImageNet normalization but at the dynamically computed resolution.

Why This Works:
DINOv3's 3D RoPE embeddings are computed dynamically in the forward pass from the actual image dimensions (see get_patches_center_coordinates in the modeling code), and the model was trained with random scale augmentation specifically to handle variable input sizes. Forcing a fixed 224×224 resize was therefore an unnecessary constraint.

Files changed (1) hide show
  1. app.py +116 -32
app.py CHANGED
@@ -20,6 +20,7 @@ MODEL_MAP = {
20
  }
21
 
22
  DEFAULT_NAME = list(MODEL_MAP.keys())[0]
 
23
 
24
  # Global model state
25
  processor = None
@@ -42,23 +43,44 @@ def cleanup_memory():
42
 
43
  if torch.cuda.is_available():
44
  torch.cuda.empty_cache()
45
- # torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def load_model(name):
49
  """Load model with CORRECT dtype"""
50
  global processor, model
51
-
52
  cleanup_memory()
53
  model_id = MODEL_MAP[name]
54
-
55
  processor = AutoImageProcessor.from_pretrained(model_id)
56
-
57
  model = AutoModel.from_pretrained(
58
  model_id,
59
- torch_dtype="auto",
60
  ).eval()
61
-
62
  param_count = sum(p.numel() for p in model.parameters()) / 1e9
63
  return f"Loaded: {name} | {param_count:.1f}B params | Ready"
64
 
@@ -67,35 +89,77 @@ def load_model(name):
67
  load_model(DEFAULT_NAME)
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  @spaces.GPU(duration=60)
71
  def _extract_grid(img):
72
- """Extract feature grid from image"""
73
  global model
74
-
75
  with torch.inference_mode():
76
  # Move model to GPU for this call
77
- model = model.to('cuda')
78
-
79
- # Process image and move to GPU
80
- pv = processor(images=img, return_tensors="pt").pixel_values.to(model.device)
81
-
82
- # Run inference
 
83
  out = model(pixel_values=pv)
84
  last = out.last_hidden_state[0].to(torch.float32)
85
-
86
  # Extract features
87
  num_reg = getattr(model.config, "num_register_tokens", 0)
88
  p = model.config.patch_size
89
- _, _, Ht, Wt = pv.shape
90
- gh, gw = Ht // p, Wt // p
91
-
92
- feats = last[1 + num_reg:, :].reshape(gh, gw, -1).cpu()
93
-
 
94
  # Move model back to CPU before function exits
95
  model = model.cpu()
96
  torch.cuda.empty_cache()
97
-
98
- return feats, gh, gw
99
 
100
 
101
  def _overlay(orig, heat01, alpha=0.55, box=None):
@@ -123,14 +187,21 @@ def _overlay(orig, heat01, alpha=0.55, box=None):
123
 
124
 
125
  def prepare(img):
126
- """Prepare image and extract features"""
127
  if img is None:
128
  return None
129
 
130
  base = ImageOps.exif_transpose(img.convert("RGB"))
131
- feats, gh, gw = _extract_grid(base)
132
-
133
- return {"orig": base, "feats": feats, "gh": gh, "gw": gw}
 
 
 
 
 
 
 
134
 
135
 
136
  def click(state, opacity, img_value, evt: gr.SelectData):
@@ -160,12 +231,16 @@ def click(state, opacity, img_value, evt: gr.SelectData):
160
 
161
  box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
162
  overlay = _overlay(base, heat01, alpha=opacity, box=box)
163
- return overlay, state
 
 
 
 
164
 
165
 
166
  def reset():
167
  """Reset the interface"""
168
- return None, None
169
 
170
 
171
  with gr.Blocks(
@@ -207,6 +282,12 @@ with gr.Blocks(
207
  interactive=False,
208
  lines=1,
209
  )
 
 
 
 
 
 
210
  opacity = gr.Slider(
211
  0.0,
212
  1.0,
@@ -241,12 +322,12 @@ with gr.Blocks(
241
  img.select(
242
  click,
243
  inputs=[state, opacity, img],
244
- outputs=[img, state],
245
  show_progress="minimal",
246
  )
247
 
248
- reset_btn.click(reset, outputs=[img, state])
249
- clear_btn.add([img, state])
250
 
251
  # Examples from current directory
252
  example_files = [
@@ -267,9 +348,12 @@ with gr.Blocks(
267
  )
268
 
269
  gr.Markdown(
270
- """
271
  ---
272
  <div style="text-align: center; color: #666; font-size: 0.9em;">
 
 
 
273
  <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
274
  and structural analysis. The 7B model provides exceptional detail but requires significant compute.
275
  <br><br>
 
20
  }
21
 
22
  DEFAULT_NAME = list(MODEL_MAP.keys())[0]
23
+ MAX_IMAGE_DIM = 720 # Maximum dimension for longer side
24
 
25
  # Global model state
26
  processor = None
 
43
 
44
  if torch.cuda.is_available():
45
  torch.cuda.empty_cache()
46
+
47
+
48
+ def compute_dynamic_size(height, width, max_dim=720, patch_size=16):
49
+ """
50
+ Compute new dimensions preserving aspect ratio with max_dim constraint.
51
+ Ensures dimensions are divisible by patch_size for clean patch extraction.
52
+ """
53
+ # Determine scaling factor
54
+ if height > width:
55
+ scale = min(1.0, max_dim / height)
56
+ else:
57
+ scale = min(1.0, max_dim / width)
58
+
59
+ # Compute new dimensions
60
+ new_height = int(height * scale)
61
+ new_width = int(width * scale)
62
+
63
+ # Round to nearest multiple of patch_size for clean patches
64
+ new_height = (new_height // patch_size) * patch_size
65
+ new_width = (new_width // patch_size) * patch_size
66
+
67
+ return new_height, new_width
68
 
69
 
70
  def load_model(name):
71
  """Load model with CORRECT dtype"""
72
  global processor, model
73
+
74
  cleanup_memory()
75
  model_id = MODEL_MAP[name]
76
+
77
  processor = AutoImageProcessor.from_pretrained(model_id)
78
+
79
  model = AutoModel.from_pretrained(
80
  model_id,
81
+ torch_dtype="auto",
82
  ).eval()
83
+
84
  param_count = sum(p.numel() for p in model.parameters()) / 1e9
85
  return f"Loaded: {name} | {param_count:.1f}B params | Ready"
86
 
 
89
  load_model(DEFAULT_NAME)
90
 
91
 
92
+ def preprocess_image(img):
93
+ """
94
+ Custom preprocessing that respects aspect ratio and uses dynamic sizing.
95
+ DINOv3's RoPE handles variable sizes beautifully - no need to constrain to 224x224!
96
+ """
97
+ # Convert to RGB if needed
98
+ if img.mode != "RGB":
99
+ img = img.convert("RGB")
100
+
101
+ # Compute dynamic size
102
+ orig_h, orig_w = img.height, img.width
103
+ patch_size = model.config.patch_size
104
+ new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)
105
+
106
+ # Resize image
107
+ img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)
108
+
109
+ # Convert to tensor and normalize using the processor's normalization params
110
+ img_array = np.array(img_resized).astype(np.float32) / 255.0
111
+
112
+ # Apply ImageNet normalization (from processor config)
113
+ mean = (
114
+ processor.image_mean
115
+ if hasattr(processor, "image_mean")
116
+ else [0.485, 0.456, 0.406]
117
+ )
118
+ std = (
119
+ processor.image_std
120
+ if hasattr(processor, "image_std")
121
+ else [0.229, 0.224, 0.225]
122
+ )
123
+
124
+ img_array = (img_array - mean) / std
125
+
126
+ # Convert to tensor with correct shape: [1, C, H, W]
127
+ img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
128
+
129
+ return img_tensor, new_h, new_w
130
+
131
+
132
  @spaces.GPU(duration=60)
133
  def _extract_grid(img):
134
+ """Extract feature grid from image - now with dynamic sizing!"""
135
  global model
136
+
137
  with torch.inference_mode():
138
  # Move model to GPU for this call
139
+ model = model.to("cuda")
140
+
141
+ # Preprocess with dynamic sizing
142
+ pv, img_h, img_w = preprocess_image(img)
143
+ pv = pv.to(model.device)
144
+
145
+ # Run inference - the model handles variable sizes perfectly!
146
  out = model(pixel_values=pv)
147
  last = out.last_hidden_state[0].to(torch.float32)
148
+
149
  # Extract features
150
  num_reg = getattr(model.config, "num_register_tokens", 0)
151
  p = model.config.patch_size
152
+
153
+ # Calculate grid dimensions based on actual image size
154
+ gh, gw = img_h // p, img_w // p
155
+
156
+ feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
157
+
158
  # Move model back to CPU before function exits
159
  model = model.cpu()
160
  torch.cuda.empty_cache()
161
+
162
+ return feats, gh, gw, img_h, img_w
163
 
164
 
165
  def _overlay(orig, heat01, alpha=0.55, box=None):
 
187
 
188
 
189
  def prepare(img):
190
+ """Prepare image and extract features with dynamic sizing"""
191
  if img is None:
192
  return None
193
 
194
  base = ImageOps.exif_transpose(img.convert("RGB"))
195
+ feats, gh, gw, img_h, img_w = _extract_grid(base)
196
+
197
+ return {
198
+ "orig": base,
199
+ "feats": feats,
200
+ "gh": gh,
201
+ "gw": gw,
202
+ "processed_h": img_h,
203
+ "processed_w": img_w,
204
+ }
205
 
206
 
207
  def click(state, opacity, img_value, evt: gr.SelectData):
 
231
 
232
  box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
233
  overlay = _overlay(base, heat01, alpha=opacity, box=box)
234
+
235
+ # Add info about resolution being processed
236
+ info_text = f"Processing at: {state['processed_w']}×{state['processed_h']} ({gh}×{gw} patches)"
237
+
238
+ return overlay, state, info_text
239
 
240
 
241
  def reset():
242
  """Reset the interface"""
243
+ return None, None, ""
244
 
245
 
246
  with gr.Blocks(
 
282
  interactive=False,
283
  lines=1,
284
  )
285
+ resolution_info = gr.Textbox(
286
+ label="Processing Resolution",
287
+ value="",
288
+ interactive=False,
289
+ lines=1,
290
+ )
291
  opacity = gr.Slider(
292
  0.0,
293
  1.0,
 
322
  img.select(
323
  click,
324
  inputs=[state, opacity, img],
325
+ outputs=[img, state, resolution_info],
326
  show_progress="minimal",
327
  )
328
 
329
+ reset_btn.click(reset, outputs=[img, state, resolution_info])
330
+ clear_btn.add([img, state, resolution_info])
331
 
332
  # Examples from current directory
333
  example_files = [
 
348
  )
349
 
350
  gr.Markdown(
351
+ f"""
352
  ---
353
  <div style="text-align: center; color: #666; font-size: 0.9em;">
354
+ <b>Dynamic Resolution:</b> Images are processed at up to {MAX_IMAGE_DIM}px (longer side) while preserving aspect ratio.
355
+ DINOv3's RoPE embeddings handle variable sizes perfectly - no need to squash to 224×224!
356
+ <br><br>
357
  <b>Performance Notes:</b> Satellite models are optimized for geographic patterns, land use classification,
358
  and structural analysis. The 7B model provides exceptional detail but requires significant compute.
359
  <br><br>