Update load.py
Browse files
load.py
CHANGED
|
@@ -125,29 +125,30 @@ def compiled_model(
|
|
| 125 |
def predict_large(
|
| 126 |
image: np.ndarray,
|
| 127 |
model: nn.Module,
|
| 128 |
-
chunk_size: int =
|
| 129 |
overlap: int = None,
|
| 130 |
batch_size: int = 1,
|
| 131 |
device: str = "cpu",
|
| 132 |
merge_clouds: bool = False,
|
| 133 |
apply_rules: bool = False,
|
| 134 |
-
max_direct_size: int = 4096, # Max size for direct inference
|
| 135 |
**kwargs
|
| 136 |
) -> np.ndarray:
|
| 137 |
"""
|
| 138 |
Predict on images of any size.
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
Args:
|
| 142 |
image: Input image (C, H, W) in reflectance [0, 1]
|
| 143 |
model: Loaded model from compiled_model()
|
| 144 |
-
chunk_size: Size of inference tiles (
|
| 145 |
-
overlap: Overlap between tiles (
|
| 146 |
batch_size: Tiles per batch (default: 1)
|
| 147 |
device: 'cpu' or 'cuda'
|
| 148 |
merge_clouds: If True, merge thin+thick into single cloud class
|
| 149 |
apply_rules: If True, apply physical rules for bright clouds
|
| 150 |
-
max_direct_size: Maximum dimension for direct inference (default: 4096)
|
| 151 |
|
| 152 |
Returns:
|
| 153 |
Predicted class labels (H, W)
|
|
@@ -162,24 +163,15 @@ def predict_large(
|
|
| 162 |
merge_clouds = model.merge_clouds
|
| 163 |
|
| 164 |
C, H, W = image.shape
|
| 165 |
-
max_dim = max(H, W)
|
| 166 |
-
|
| 167 |
-
# === AUTO CHUNK SIZE ===
|
| 168 |
-
if chunk_size is None:
|
| 169 |
-
if max_dim <= 1024:
|
| 170 |
-
chunk_size = max_dim # Process entire image
|
| 171 |
-
elif max_dim <= 2048:
|
| 172 |
-
chunk_size = 1024
|
| 173 |
-
else:
|
| 174 |
-
chunk_size = 512
|
| 175 |
|
| 176 |
# Set default overlap
|
| 177 |
if overlap is None:
|
| 178 |
overlap = chunk_size // 2
|
| 179 |
|
| 180 |
-
# ===
|
| 181 |
-
#
|
| 182 |
-
if
|
|
|
|
| 183 |
with torch.no_grad():
|
| 184 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 185 |
logits = model(img_tensor)
|
|
@@ -199,8 +191,7 @@ def predict_large(
|
|
| 199 |
|
| 200 |
return pred
|
| 201 |
|
| 202 |
-
# === SLIDING WINDOW FOR
|
| 203 |
-
print(f" Using sliding window: {H}x{W} -> chunks={chunk_size}, overlap={overlap}")
|
| 204 |
|
| 205 |
step = chunk_size - overlap
|
| 206 |
|
|
@@ -230,7 +221,7 @@ def predict_large(
|
|
| 230 |
# Create blending window
|
| 231 |
window = get_spline_window(chunk_size, power=2)
|
| 232 |
|
| 233 |
-
# Generate tile coordinates
|
| 234 |
coords = []
|
| 235 |
for r in range(0, H_pad - chunk_size + 1, step):
|
| 236 |
for c in range(0, W_pad - chunk_size + 1, step):
|
|
|
|
| 125 |
def predict_large(
|
| 126 |
image: np.ndarray,
|
| 127 |
model: nn.Module,
|
| 128 |
+
chunk_size: int = 512,
|
| 129 |
overlap: int = None,
|
| 130 |
batch_size: int = 1,
|
| 131 |
device: str = "cpu",
|
| 132 |
merge_clouds: bool = False,
|
| 133 |
apply_rules: bool = False,
|
|
|
|
| 134 |
**kwargs
|
| 135 |
) -> np.ndarray:
|
| 136 |
"""
|
| 137 |
Predict on images of any size.
|
| 138 |
+
|
| 139 |
+
Strategy:
|
| 140 |
+
- Images ≤ 2048px in any dimension: direct inference (no tiling)
|
| 141 |
+
- Images > 2048px: sliding window with specified chunk_size
|
| 142 |
|
| 143 |
Args:
|
| 144 |
image: Input image (C, H, W) in reflectance [0, 1]
|
| 145 |
model: Loaded model from compiled_model()
|
| 146 |
+
chunk_size: Size of inference tiles for large images (default: 512)
|
| 147 |
+
overlap: Overlap between tiles (default: chunk_size // 2)
|
| 148 |
batch_size: Tiles per batch (default: 1)
|
| 149 |
device: 'cpu' or 'cuda'
|
| 150 |
merge_clouds: If True, merge thin+thick into single cloud class
|
| 151 |
apply_rules: If True, apply physical rules for bright clouds
|
|
|
|
| 152 |
|
| 153 |
Returns:
|
| 154 |
Predicted class labels (H, W)
|
|
|
|
| 163 |
merge_clouds = model.merge_clouds
|
| 164 |
|
| 165 |
C, H, W = image.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# Set default overlap
|
| 168 |
if overlap is None:
|
| 169 |
overlap = chunk_size // 2
|
| 170 |
|
| 171 |
+
# === STRATEGY: Use direct inference for images ≤ 2048px ===
|
| 172 |
+
# This avoids tiling artifacts on small/medium images
|
| 173 |
+
if max(H, W) <= 2048:
|
| 174 |
+
# Direct inference - no tiling
|
| 175 |
with torch.no_grad():
|
| 176 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 177 |
logits = model(img_tensor)
|
|
|
|
| 191 |
|
| 192 |
return pred
|
| 193 |
|
| 194 |
+
# === SLIDING WINDOW FOR LARGE IMAGES (> 2048px) ===
|
|
|
|
| 195 |
|
| 196 |
step = chunk_size - overlap
|
| 197 |
|
|
|
|
| 221 |
# Create blending window
|
| 222 |
window = get_spline_window(chunk_size, power=2)
|
| 223 |
|
| 224 |
+
# Generate tile coordinates
|
| 225 |
coords = []
|
| 226 |
for r in range(0, H_pad - chunk_size + 1, step):
|
| 227 |
for c in range(0, W_pad - chunk_size + 1, step):
|