JulioContrerasH committed on
Commit
6699c52
·
verified ·
1 Parent(s): f122baf

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +15 -12
load.py CHANGED
@@ -121,7 +121,6 @@ def compiled_model(
121
 
122
  return model
123
 
124
-
125
  def predict_large(
126
  image: np.ndarray,
127
  model: nn.Module,
@@ -131,24 +130,29 @@ def predict_large(
131
  device: str = "cpu",
132
  merge_clouds: bool = False,
133
  apply_rules: bool = False,
 
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
  Predict on images of any size.
138
 
139
  Strategy:
140
- - Images ≤ 2048px in any dimension: direct inference (no tiling)
141
- - Images > 2048px: sliding window with specified chunk_size
 
 
142
 
143
  Args:
144
  image: Input image (C, H, W) in reflectance [0, 1]
145
  model: Loaded model from compiled_model()
146
- chunk_size: Size of inference tiles for large images (default: 512)
147
  overlap: Overlap between tiles (default: chunk_size // 2)
148
  batch_size: Tiles per batch (default: 1)
149
  device: 'cpu' or 'cuda'
150
  merge_clouds: If True, merge thin+thick into single cloud class
151
  apply_rules: If True, apply physical rules for bright clouds
 
 
152
 
153
  Returns:
154
  Predicted class labels (H, W)
@@ -168,10 +172,9 @@ def predict_large(
168
  if overlap is None:
169
  overlap = chunk_size // 2
170
 
171
- # === STRATEGY: Use direct inference for images ≤ 2048px ===
172
- # This avoids tiling artifacts on small/medium images
173
- if max(H, W) <= 2048:
174
- # Direct inference - no tiling
175
  with torch.no_grad():
176
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
177
  logits = model(img_tensor)
@@ -179,9 +182,9 @@ def predict_large(
179
  if merge_clouds:
180
  probs = torch.softmax(logits, dim=1)
181
  probs_merged = torch.zeros(1, 3, H, W, device=device)
182
- probs_merged[:, 0] = probs[:, 0]
183
- probs_merged[:, 1] = probs[:, 1] + probs[:, 2]
184
- probs_merged[:, 2] = probs[:, 3]
185
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
186
  else:
187
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
@@ -191,7 +194,7 @@ def predict_large(
191
 
192
  return pred
193
 
194
- # === SLIDING WINDOW FOR LARGE IMAGES (> 2048px) ===
195
 
196
  step = chunk_size - overlap
197
 
 
121
 
122
  return model
123
 
 
124
  def predict_large(
125
  image: np.ndarray,
126
  model: nn.Module,
 
130
  device: str = "cpu",
131
  merge_clouds: bool = False,
132
  apply_rules: bool = False,
133
+ max_direct_size: int = 1024, # Safe for 2GB GPU
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
  Predict on images of any size.
138
 
139
  Strategy:
140
+ - Small images (≤ max_direct_size): direct inference without tiling
141
+ Examples: 256x256, 512x512, 1024x1024 (safe for 2GB GPU)
142
+ - Large images (> max_direct_size): sliding window with overlapping tiles
143
+ Examples: 2048x2048, 5000x5000, 22000x22000
144
 
145
  Args:
146
  image: Input image (C, H, W) in reflectance [0, 1]
147
  model: Loaded model from compiled_model()
148
+ chunk_size: Tile size for large images (default: 512)
149
  overlap: Overlap between tiles (default: chunk_size // 2)
150
  batch_size: Tiles per batch (default: 1)
151
  device: 'cpu' or 'cuda'
152
  merge_clouds: If True, merge thin+thick into single cloud class
153
  apply_rules: If True, apply physical rules for bright clouds
154
+ max_direct_size: Max dimension for direct inference (default: 1024)
155
+ Set to 2048 for GPUs with ≥8GB VRAM
156
 
157
  Returns:
158
  Predicted class labels (H, W)
 
172
  if overlap is None:
173
  overlap = chunk_size // 2
174
 
175
+ # === DIRECT INFERENCE FOR SMALL IMAGES ===
176
+ # Safe for GPUs with limited VRAM (2-4GB)
177
+ if max(H, W) <= max_direct_size:
 
178
  with torch.no_grad():
179
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
180
  logits = model(img_tensor)
 
182
  if merge_clouds:
183
  probs = torch.softmax(logits, dim=1)
184
  probs_merged = torch.zeros(1, 3, H, W, device=device)
185
+ probs_merged[:, 0] = probs[:, 0] # Clear
186
+ probs_merged[:, 1] = probs[:, 1] + probs[:, 2] # Cloud
187
+ probs_merged[:, 2] = probs[:, 3] # Shadow
188
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
189
  else:
190
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
 
194
 
195
  return pred
196
 
197
+ # === SLIDING WINDOW FOR LARGE IMAGES ===
198
 
199
  step = chunk_size - overlap
200