JulioContrerasH commited on
Commit
f7c51b8
·
verified ·
1 Parent(s): 38bf451

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +33 -18
load.py CHANGED
@@ -134,7 +134,9 @@ def predict_large(
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
- Predict on images of any size using sliding window with smooth blending.
 
 
138
 
139
  Args:
140
  image: Input image (C, H, W) in reflectance [0, 1]
@@ -166,8 +168,9 @@ def predict_large(
166
  if overlap is None:
167
  overlap = chunk_size // 2
168
 
169
- # Direct inference for small images
170
- if H <= chunk_size and W <= chunk_size:
 
171
  with torch.no_grad():
172
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
173
  logits = model(img_tensor)
@@ -187,23 +190,16 @@ def predict_large(
187
 
188
  return pred
189
 
190
- # === SLIDING WINDOW FOR LARGER IMAGES ===
191
 
192
- # Calculate padding needed to make image divisible by step
193
  step = chunk_size - overlap
194
 
195
- # Padding to ensure tiles cover the entire image
196
- pad_h = (step - (H % step)) % step
197
- pad_w = (step - (W % step)) % step
198
 
199
- # Add extra overlap padding on all sides for smooth edges
200
- pad_h += overlap
201
- pad_w += overlap
202
-
203
- # Pad image
204
  image_padded = np.pad(
205
  image,
206
- ((0, 0), (overlap // 2, pad_h - overlap // 2), (overlap // 2, pad_w - overlap // 2)),
207
  mode="reflect"
208
  )
209
 
@@ -219,9 +215,28 @@ def predict_large(
219
 
220
  # Generate tile coordinates
221
  coords = []
222
- for r in range(0, H_pad - chunk_size + 1, step):
223
- for c in range(0, W_pad - chunk_size + 1, step):
 
 
224
  coords.append((r, c))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  # Process tiles in batches
227
  with torch.no_grad():
@@ -248,8 +263,8 @@ def predict_large(
248
  weight_sum = np.maximum(weight_sum, 1e-8)
249
  probs_final = probs_sum / weight_sum
250
 
251
- # Remove padding to get back to original size
252
- probs_final = probs_final[:, overlap // 2:overlap // 2 + H, overlap // 2:overlap // 2 + W]
253
 
254
  # Get final prediction
255
  if merge_clouds:
 
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
+ Predict on images of any size.
138
+ - Small images (≤2048px): direct inference without tiling
139
+ - Large images (>2048px): sliding window with smooth blending
140
 
141
  Args:
142
  image: Input image (C, H, W) in reflectance [0, 1]
 
168
  if overlap is None:
169
  overlap = chunk_size // 2
170
 
171
+ # === DIRECT INFERENCE FOR SMALL/MEDIUM IMAGES ===
172
+ # Process directly without tiling to avoid artifacts
173
+ if max(H, W) <= 2048:
174
  with torch.no_grad():
175
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
176
  logits = model(img_tensor)
 
190
 
191
  return pred
192
 
193
+ # === SLIDING WINDOW FOR LARGE IMAGES (>2048px) ===
194
 
 
195
  step = chunk_size - overlap
196
 
197
+ # Symmetric padding: overlap on each side
198
+ pad_size = overlap
 
199
 
 
 
 
 
 
200
  image_padded = np.pad(
201
  image,
202
+ ((0, 0), (pad_size, pad_size), (pad_size, pad_size)),
203
  mode="reflect"
204
  )
205
 
 
215
 
216
  # Generate tile coordinates
217
  coords = []
218
+ r = 0
219
+ while r <= H_pad - chunk_size:
220
+ c = 0
221
+ while c <= W_pad - chunk_size:
222
  coords.append((r, c))
223
+ c += step
224
+ # Ensure we cover the right edge
225
+ if c - step + chunk_size < W_pad:
226
+ coords.append((r, W_pad - chunk_size))
227
+ r += step
228
+
229
+ # Ensure we cover the bottom edge
230
+ if r - step + chunk_size < H_pad:
231
+ c = 0
232
+ while c <= W_pad - chunk_size:
233
+ coords.append((H_pad - chunk_size, c))
234
+ c += step
235
+ if c - step + chunk_size < W_pad:
236
+ coords.append((H_pad - chunk_size, W_pad - chunk_size))
237
+
238
+ # Remove duplicates
239
+ coords = list(set(coords))
240
 
241
  # Process tiles in batches
242
  with torch.no_grad():
 
263
  weight_sum = np.maximum(weight_sum, 1e-8)
264
  probs_final = probs_sum / weight_sum
265
 
266
+ # Remove symmetric padding
267
+ probs_final = probs_final[:, pad_size:pad_size + H, pad_size:pad_size + W]
268
 
269
  # Get final prediction
270
  if merge_clouds: