Update load.py
Browse files
load.py
CHANGED
|
@@ -134,7 +134,9 @@ def predict_large(
|
|
| 134 |
**kwargs
|
| 135 |
) -> np.ndarray:
|
| 136 |
"""
|
| 137 |
-
Predict on images of any size
|
|
|
|
|
|
|
| 138 |
|
| 139 |
Args:
|
| 140 |
image: Input image (C, H, W) in reflectance [0, 1]
|
|
@@ -166,8 +168,9 @@ def predict_large(
|
|
| 166 |
if overlap is None:
|
| 167 |
overlap = chunk_size // 2
|
| 168 |
|
| 169 |
-
#
|
| 170 |
-
|
|
|
|
| 171 |
with torch.no_grad():
|
| 172 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 173 |
logits = model(img_tensor)
|
|
@@ -187,23 +190,16 @@ def predict_large(
|
|
| 187 |
|
| 188 |
return pred
|
| 189 |
|
| 190 |
-
# === SLIDING WINDOW FOR
|
| 191 |
|
| 192 |
-
# Calculate padding needed to make image divisible by step
|
| 193 |
step = chunk_size - overlap
|
| 194 |
|
| 195 |
-
#
|
| 196 |
-
|
| 197 |
-
pad_w = (step - (W % step)) % step
|
| 198 |
|
| 199 |
-
# Add extra overlap padding on all sides for smooth edges
|
| 200 |
-
pad_h += overlap
|
| 201 |
-
pad_w += overlap
|
| 202 |
-
|
| 203 |
-
# Pad image
|
| 204 |
image_padded = np.pad(
|
| 205 |
image,
|
| 206 |
-
((0, 0), (
|
| 207 |
mode="reflect"
|
| 208 |
)
|
| 209 |
|
|
@@ -219,9 +215,28 @@ def predict_large(
|
|
| 219 |
|
| 220 |
# Generate tile coordinates
|
| 221 |
coords = []
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
| 224 |
coords.append((r, c))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
# Process tiles in batches
|
| 227 |
with torch.no_grad():
|
|
@@ -248,8 +263,8 @@ def predict_large(
|
|
| 248 |
weight_sum = np.maximum(weight_sum, 1e-8)
|
| 249 |
probs_final = probs_sum / weight_sum
|
| 250 |
|
| 251 |
-
# Remove padding
|
| 252 |
-
probs_final = probs_final[:,
|
| 253 |
|
| 254 |
# Get final prediction
|
| 255 |
if merge_clouds:
|
|
|
|
| 134 |
**kwargs
|
| 135 |
) -> np.ndarray:
|
| 136 |
"""
|
| 137 |
+
Predict on images of any size.
|
| 138 |
+
- Small images (≤2048px): direct inference without tiling
|
| 139 |
+
- Large images (>2048px): sliding window with smooth blending
|
| 140 |
|
| 141 |
Args:
|
| 142 |
image: Input image (C, H, W) in reflectance [0, 1]
|
|
|
|
| 168 |
if overlap is None:
|
| 169 |
overlap = chunk_size // 2
|
| 170 |
|
| 171 |
+
# === DIRECT INFERENCE FOR SMALL/MEDIUM IMAGES ===
|
| 172 |
+
# Process directly without tiling to avoid artifacts
|
| 173 |
+
if max(H, W) <= 2048:
|
| 174 |
with torch.no_grad():
|
| 175 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 176 |
logits = model(img_tensor)
|
|
|
|
| 190 |
|
| 191 |
return pred
|
| 192 |
|
| 193 |
+
# === SLIDING WINDOW FOR LARGE IMAGES (>2048px) ===
|
| 194 |
|
|
|
|
| 195 |
step = chunk_size - overlap
|
| 196 |
|
| 197 |
+
# Symmetric padding: overlap on each side
|
| 198 |
+
pad_size = overlap
|
|
|
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
image_padded = np.pad(
|
| 201 |
image,
|
| 202 |
+
((0, 0), (pad_size, pad_size), (pad_size, pad_size)),
|
| 203 |
mode="reflect"
|
| 204 |
)
|
| 205 |
|
|
|
|
| 215 |
|
| 216 |
# Generate tile coordinates
|
| 217 |
coords = []
|
| 218 |
+
r = 0
|
| 219 |
+
while r <= H_pad - chunk_size:
|
| 220 |
+
c = 0
|
| 221 |
+
while c <= W_pad - chunk_size:
|
| 222 |
coords.append((r, c))
|
| 223 |
+
c += step
|
| 224 |
+
# Ensure we cover the right edge
|
| 225 |
+
if c - step + chunk_size < W_pad:
|
| 226 |
+
coords.append((r, W_pad - chunk_size))
|
| 227 |
+
r += step
|
| 228 |
+
|
| 229 |
+
# Ensure we cover the bottom edge
|
| 230 |
+
if r - step + chunk_size < H_pad:
|
| 231 |
+
c = 0
|
| 232 |
+
while c <= W_pad - chunk_size:
|
| 233 |
+
coords.append((H_pad - chunk_size, c))
|
| 234 |
+
c += step
|
| 235 |
+
if c - step + chunk_size < W_pad:
|
| 236 |
+
coords.append((H_pad - chunk_size, W_pad - chunk_size))
|
| 237 |
+
|
| 238 |
+
# Remove duplicates
|
| 239 |
+
coords = list(set(coords))
|
| 240 |
|
| 241 |
# Process tiles in batches
|
| 242 |
with torch.no_grad():
|
|
|
|
| 263 |
weight_sum = np.maximum(weight_sum, 1e-8)
|
| 264 |
probs_final = probs_sum / weight_sum
|
| 265 |
|
| 266 |
+
# Remove symmetric padding
|
| 267 |
+
probs_final = probs_final[:, pad_size:pad_size + H, pad_size:pad_size + W]
|
| 268 |
|
| 269 |
# Get final prediction
|
| 270 |
if merge_clouds:
|