Update load.py
Browse files
load.py
CHANGED
|
@@ -121,7 +121,6 @@ def compiled_model(
|
|
| 121 |
|
| 122 |
return model
|
| 123 |
|
| 124 |
-
|
| 125 |
def predict_large(
|
| 126 |
image: np.ndarray,
|
| 127 |
model: nn.Module,
|
|
@@ -131,24 +130,29 @@ def predict_large(
|
|
| 131 |
device: str = "cpu",
|
| 132 |
merge_clouds: bool = False,
|
| 133 |
apply_rules: bool = False,
|
|
|
|
| 134 |
**kwargs
|
| 135 |
) -> np.ndarray:
|
| 136 |
"""
|
| 137 |
Predict on images of any size.
|
| 138 |
|
| 139 |
Strategy:
|
| 140 |
-
-
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
|
| 143 |
Args:
|
| 144 |
image: Input image (C, H, W) in reflectance [0, 1]
|
| 145 |
model: Loaded model from compiled_model()
|
| 146 |
-
chunk_size:
|
| 147 |
overlap: Overlap between tiles (default: chunk_size // 2)
|
| 148 |
batch_size: Tiles per batch (default: 1)
|
| 149 |
device: 'cpu' or 'cuda'
|
| 150 |
merge_clouds: If True, merge thin+thick into single cloud class
|
| 151 |
apply_rules: If True, apply physical rules for bright clouds
|
|
|
|
|
|
|
| 152 |
|
| 153 |
Returns:
|
| 154 |
Predicted class labels (H, W)
|
|
@@ -168,10 +172,9 @@ def predict_large(
|
|
| 168 |
if overlap is None:
|
| 169 |
overlap = chunk_size // 2
|
| 170 |
|
| 171 |
-
# ===
|
| 172 |
-
#
|
| 173 |
-
if max(H, W) <=
|
| 174 |
-
# Direct inference - no tiling
|
| 175 |
with torch.no_grad():
|
| 176 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 177 |
logits = model(img_tensor)
|
|
@@ -179,9 +182,9 @@ def predict_large(
|
|
| 179 |
if merge_clouds:
|
| 180 |
probs = torch.softmax(logits, dim=1)
|
| 181 |
probs_merged = torch.zeros(1, 3, H, W, device=device)
|
| 182 |
-
probs_merged[:, 0] = probs[:, 0]
|
| 183 |
-
probs_merged[:, 1] = probs[:, 1] + probs[:, 2]
|
| 184 |
-
probs_merged[:, 2] = probs[:, 3]
|
| 185 |
pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
|
| 186 |
else:
|
| 187 |
pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
|
|
@@ -191,7 +194,7 @@ def predict_large(
|
|
| 191 |
|
| 192 |
return pred
|
| 193 |
|
| 194 |
-
# === SLIDING WINDOW FOR LARGE IMAGES
|
| 195 |
|
| 196 |
step = chunk_size - overlap
|
| 197 |
|
|
|
|
| 121 |
|
| 122 |
return model
|
| 123 |
|
|
|
|
| 124 |
def predict_large(
|
| 125 |
image: np.ndarray,
|
| 126 |
model: nn.Module,
|
|
|
|
| 130 |
device: str = "cpu",
|
| 131 |
merge_clouds: bool = False,
|
| 132 |
apply_rules: bool = False,
|
| 133 |
+
max_direct_size: int = 1024, # Safe for 2GB GPU
|
| 134 |
**kwargs
|
| 135 |
) -> np.ndarray:
|
| 136 |
"""
|
| 137 |
Predict on images of any size.
|
| 138 |
|
| 139 |
Strategy:
|
| 140 |
+
- Small images (≤ max_direct_size): direct inference without tiling
|
| 141 |
+
Examples: 256x256, 512x512, 1024x1024 (safe for 2GB GPU)
|
| 142 |
+
- Large images (> max_direct_size): sliding window with overlapping tiles
|
| 143 |
+
Examples: 2048x2048, 5000x5000, 22000x22000
|
| 144 |
|
| 145 |
Args:
|
| 146 |
image: Input image (C, H, W) in reflectance [0, 1]
|
| 147 |
model: Loaded model from compiled_model()
|
| 148 |
+
chunk_size: Tile size for large images (default: 512)
|
| 149 |
overlap: Overlap between tiles (default: chunk_size // 2)
|
| 150 |
batch_size: Tiles per batch (default: 1)
|
| 151 |
device: 'cpu' or 'cuda'
|
| 152 |
merge_clouds: If True, merge thin+thick into single cloud class
|
| 153 |
apply_rules: If True, apply physical rules for bright clouds
|
| 154 |
+
max_direct_size: Max dimension for direct inference (default: 1024)
|
| 155 |
+
Set to 2048 for GPUs with ≥8GB VRAM
|
| 156 |
|
| 157 |
Returns:
|
| 158 |
Predicted class labels (H, W)
|
|
|
|
| 172 |
if overlap is None:
|
| 173 |
overlap = chunk_size // 2
|
| 174 |
|
| 175 |
+
# === DIRECT INFERENCE FOR SMALL IMAGES ===
|
| 176 |
+
# Safe for GPUs with limited VRAM (2-4GB)
|
| 177 |
+
if max(H, W) <= max_direct_size:
|
|
|
|
| 178 |
with torch.no_grad():
|
| 179 |
img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
|
| 180 |
logits = model(img_tensor)
|
|
|
|
| 182 |
if merge_clouds:
|
| 183 |
probs = torch.softmax(logits, dim=1)
|
| 184 |
probs_merged = torch.zeros(1, 3, H, W, device=device)
|
| 185 |
+
probs_merged[:, 0] = probs[:, 0] # Clear
|
| 186 |
+
probs_merged[:, 1] = probs[:, 1] + probs[:, 2] # Cloud
|
| 187 |
+
probs_merged[:, 2] = probs[:, 3] # Shadow
|
| 188 |
pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
|
| 189 |
else:
|
| 190 |
pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
|
|
|
|
| 194 |
|
| 195 |
return pred
|
| 196 |
|
| 197 |
+
# === SLIDING WINDOW FOR LARGE IMAGES ===
|
| 198 |
|
| 199 |
step = chunk_size - overlap
|
| 200 |
|