JulioContrerasH committed on
Commit
ebf386f
·
verified ·
1 Parent(s): 62265ed

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +20 -51
load.py CHANGED
@@ -1,8 +1,3 @@
1
- """
2
- Load and inference functions for MSS Cloud Detection Model
3
- Compatible with mlstac package
4
- """
5
-
6
  import torch
7
  import torch.nn as nn
8
  import numpy as np
@@ -130,56 +125,37 @@ def predict_large(
130
  device: str = "cpu",
131
  merge_clouds: bool = False,
132
  apply_rules: bool = False,
133
- max_direct_size: int = 1024, # Safe for 2GB GPU
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
  Predict on images of any size.
138
 
139
- Strategy:
140
- - Small images (≤ max_direct_size): direct inference without tiling
141
- Examples: 256x256, 512x512, 1024x1024 (safe for 2GB GPU)
142
- - Large images (> max_direct_size): sliding window with overlapping tiles
143
- Examples: 2048x2048, 5000x5000, 22000x22000
144
-
145
- Args:
146
- image: Input image (C, H, W) in reflectance [0, 1]
147
- model: Loaded model from compiled_model()
148
- chunk_size: Tile size for large images (default: 512)
149
- overlap: Overlap between tiles (default: chunk_size // 2)
150
- batch_size: Tiles per batch (default: 1)
151
- device: 'cpu' or 'cuda'
152
- merge_clouds: If True, merge thin+thick into single cloud class
153
- apply_rules: If True, apply physical rules for bright clouds
154
- max_direct_size: Max dimension for direct inference (default: 1024)
155
- Set to 2048 for GPUs with ≥8GB VRAM
156
-
157
- Returns:
158
- Predicted class labels (H, W)
159
  """
160
  model.eval()
161
  model.to(device)
162
 
163
- # Get merge_clouds setting from model if available
164
- if not hasattr(model, 'merge_clouds'):
165
- model.merge_clouds = merge_clouds
166
- else:
167
- merge_clouds = model.merge_clouds
168
 
169
  C, H, W = image.shape
170
 
171
- # Set default overlap
172
  if overlap is None:
173
  overlap = chunk_size // 2
174
 
175
  # === DIRECT INFERENCE FOR SMALL IMAGES ===
176
- # Safe for GPUs with limited VRAM (2-4GB)
177
  if max(H, W) <= max_direct_size:
178
  with torch.no_grad():
179
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
180
  logits = model(img_tensor)
181
 
182
- if merge_clouds:
 
 
 
 
183
  probs = torch.softmax(logits, dim=1)
184
  probs_merged = torch.zeros(1, 3, H, W, device=device)
185
  probs_merged[:, 0] = probs[:, 0] # Clear
@@ -187,10 +163,11 @@ def predict_large(
187
  probs_merged[:, 2] = probs[:, 3] # Shadow
188
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
189
  else:
 
190
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
191
 
192
  if apply_rules:
193
- pred = apply_physical_rules(pred, image, merge_clouds)
194
 
195
  return pred
196
 
@@ -198,11 +175,9 @@ def predict_large(
198
 
199
  step = chunk_size - overlap
200
 
201
- # Calculate required padding
202
  pad_h = (step - (H - chunk_size) % step) % step
203
  pad_w = (step - (W - chunk_size) % step) % step
204
 
205
- # Symmetric padding
206
  pad_top = pad_h // 2
207
  pad_bottom = pad_h - pad_top
208
  pad_left = pad_w // 2
@@ -216,50 +191,45 @@ def predict_large(
216
 
217
  _, H_pad, W_pad = image_padded.shape
218
 
219
- # Initialize accumulation buffers
220
- num_classes = 4
221
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
222
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
223
 
224
- # Create blending window
225
  window = get_spline_window(chunk_size, power=2)
226
 
227
- # Generate tile coordinates
228
  coords = []
229
  for r in range(0, H_pad - chunk_size + 1, step):
230
  for c in range(0, W_pad - chunk_size + 1, step):
231
  coords.append((r, c))
232
 
233
- # Process tiles in batches
234
  with torch.no_grad():
235
  for i in range(0, len(coords), batch_size):
236
  batch_coords = coords[i:i + batch_size]
237
 
238
- # Extract tiles
239
  tiles = np.stack([
240
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
241
  for r, c in batch_coords
242
  ])
243
 
244
- # Run inference
245
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
246
  logits = model(tiles_tensor)
247
  probs = torch.softmax(logits, dim=1).cpu().numpy()
248
 
249
- # Accumulate weighted predictions
250
  for j, (r, c) in enumerate(batch_coords):
251
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
252
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
253
 
254
- # Normalize by accumulated weights
255
  weight_sum = np.maximum(weight_sum, 1e-8)
256
  probs_final = probs_sum / weight_sum
257
 
258
- # Remove padding to restore original size
259
  probs_final = probs_final[:, pad_top:pad_top + H, pad_left:pad_left + W]
260
 
261
- # Get final prediction
262
- if merge_clouds:
 
 
 
 
263
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
264
  probs_merged[0] = probs_final[0]
265
  probs_merged[1] = probs_final[1] + probs_final[2]
@@ -268,9 +238,8 @@ def predict_large(
268
  else:
269
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
270
 
271
- # Apply physical rules if requested
272
  if apply_rules:
273
- pred = apply_physical_rules(pred, image, merge_clouds)
274
 
275
  return pred
276
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import numpy as np
 
125
  device: str = "cpu",
126
  merge_clouds: bool = False,
127
  apply_rules: bool = False,
128
+ max_direct_size: int = 1024,
129
  **kwargs
130
  ) -> np.ndarray:
131
  """
132
  Predict on images of any size.
133
 
134
+ Automatically detects if model has 3 or 4 classes.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  """
136
  model.eval()
137
  model.to(device)
138
 
139
+ # Detect number of classes in the model
140
+ num_classes = model.hparams.get('num_classes', 4)
141
+ is_3class_model = (num_classes == 3)
 
 
142
 
143
  C, H, W = image.shape
144
 
 
145
  if overlap is None:
146
  overlap = chunk_size // 2
147
 
148
  # === DIRECT INFERENCE FOR SMALL IMAGES ===
 
149
  if max(H, W) <= max_direct_size:
150
  with torch.no_grad():
151
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
152
  logits = model(img_tensor)
153
 
154
+ if is_3class_model:
155
+ # The model already has 3 classes: 0=clear, 1=cloud, 2=shadow
156
+ pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
157
+ elif merge_clouds:
158
+ # Model 4 classes → merge to 3
159
  probs = torch.softmax(logits, dim=1)
160
  probs_merged = torch.zeros(1, 3, H, W, device=device)
161
  probs_merged[:, 0] = probs[:, 0] # Clear
 
163
  probs_merged[:, 2] = probs[:, 3] # Shadow
164
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
165
  else:
166
+ # Model 4 classes without merge
167
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
168
 
169
  if apply_rules:
170
+ pred = apply_physical_rules(pred, image, merge_clouds=is_3class_model or merge_clouds)
171
 
172
  return pred
173
 
 
175
 
176
  step = chunk_size - overlap
177
 
 
178
  pad_h = (step - (H - chunk_size) % step) % step
179
  pad_w = (step - (W - chunk_size) % step) % step
180
 
 
181
  pad_top = pad_h // 2
182
  pad_bottom = pad_h - pad_top
183
  pad_left = pad_w // 2
 
191
 
192
  _, H_pad, W_pad = image_padded.shape
193
 
194
+ # Buffers according to number of classes
 
195
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
196
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
197
 
 
198
  window = get_spline_window(chunk_size, power=2)
199
 
 
200
  coords = []
201
  for r in range(0, H_pad - chunk_size + 1, step):
202
  for c in range(0, W_pad - chunk_size + 1, step):
203
  coords.append((r, c))
204
 
 
205
  with torch.no_grad():
206
  for i in range(0, len(coords), batch_size):
207
  batch_coords = coords[i:i + batch_size]
208
 
 
209
  tiles = np.stack([
210
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
211
  for r, c in batch_coords
212
  ])
213
 
 
214
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
215
  logits = model(tiles_tensor)
216
  probs = torch.softmax(logits, dim=1).cpu().numpy()
217
 
 
218
  for j, (r, c) in enumerate(batch_coords):
219
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
220
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
221
 
 
222
  weight_sum = np.maximum(weight_sum, 1e-8)
223
  probs_final = probs_sum / weight_sum
224
 
 
225
  probs_final = probs_final[:, pad_top:pad_top + H, pad_left:pad_left + W]
226
 
227
+ # Final forecast
228
+ if is_3class_model:
229
+ # It already has 3 classes
230
+ pred = np.argmax(probs_final, axis=0).astype(np.uint8)
231
+ elif merge_clouds:
232
+ # Merge 4 → 3
233
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
234
  probs_merged[0] = probs_final[0]
235
  probs_merged[1] = probs_final[1] + probs_final[2]
 
238
  else:
239
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
240
 
 
241
  if apply_rules:
242
+ pred = apply_physical_rules(pred, image, merge_clouds=is_3class_model or merge_clouds)
243
 
244
  return pred
245