JulioContrerasH commited on
Commit
38bf451
·
verified ·
1 Parent(s): 58ad7ff

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +44 -31
load.py CHANGED
@@ -41,7 +41,7 @@ class MSSSegmentationModel(pl.LightningModule):
41
 
42
 
43
  def get_spline_window(size: int, power: int = 2) -> np.ndarray:
44
- """Ventana Hann 2D para blending suave."""
45
  intersection = np.hanning(size)
46
  window_2d = np.outer(intersection, intersection)
47
  return (window_2d ** power).astype(np.float32)
@@ -52,22 +52,26 @@ def apply_physical_rules(
52
  image: np.ndarray,
53
  merge_clouds: bool = False,
54
  ) -> np.ndarray:
55
- """Regla física para nubes gruesas saturadas."""
56
  saturation_threshold = 0.35
57
 
58
  pred = pred.copy()
59
 
 
60
  nodata_mask = np.all(image == 0, axis=0)
61
 
 
62
  bright_b0 = image[0] > saturation_threshold
63
  bright_b1 = image[1] > saturation_threshold * 0.80
64
  saturated_mask = bright_b0 & bright_b1
65
 
 
66
  if merge_clouds:
67
- pred[saturated_mask] = 1
68
  else:
69
- pred[saturated_mask] = 2
70
 
 
71
  pred[nodata_mask] = 0
72
 
73
  return pred
@@ -121,8 +125,8 @@ def compiled_model(
121
  def predict_large(
122
  image: np.ndarray,
123
  model: nn.Module,
124
- chunk_size: int = 512,
125
- overlap: int = None,
126
  batch_size: int = 1,
127
  device: str = "cpu",
128
  merge_clouds: bool = False,
@@ -130,7 +134,7 @@ def predict_large(
130
  **kwargs
131
  ) -> np.ndarray:
132
  """
133
- Predict on large images using sliding window with overlap blending.
134
 
135
  Args:
136
  image: Input image (C, H, W) in reflectance [0, 1]
@@ -140,7 +144,7 @@ def predict_large(
140
  batch_size: Tiles per batch (default: 1)
141
  device: 'cpu' or 'cuda'
142
  merge_clouds: If True, merge thin+thick into single cloud class
143
- apply_rules: If True, apply physical rules for bright clouds (default: False)
144
 
145
  Returns:
146
  Predicted class labels (H, W)
@@ -150,7 +154,7 @@ def predict_large(
150
  model.eval()
151
  model.to(device)
152
 
153
- # Get merge_clouds from model if not specified
154
  if not hasattr(model, 'merge_clouds'):
155
  model.merge_clouds = merge_clouds
156
  else:
@@ -158,23 +162,22 @@ def predict_large(
158
 
159
  C, H, W = image.shape
160
 
161
- # Set default overlap if not specified
162
  if overlap is None:
163
  overlap = chunk_size // 2
164
 
165
- # Direct inference if image fits within chunk_size
166
  if H <= chunk_size and W <= chunk_size:
167
  with torch.no_grad():
168
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
169
  logits = model(img_tensor)
170
 
171
  if merge_clouds:
172
- # Merge thin+thick clouds into single cloud class
173
  probs = torch.softmax(logits, dim=1)
174
  probs_merged = torch.zeros(1, 3, H, W, device=device)
175
- probs_merged[:, 0] = probs[:, 0] # Clear
176
- probs_merged[:, 1] = probs[:, 1] + probs[:, 2] # Cloud (thin+thick)
177
- probs_merged[:, 2] = probs[:, 3] # Shadow
178
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
179
  else:
180
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
@@ -184,14 +187,23 @@ def predict_large(
184
 
185
  return pred
186
 
187
- # Sliding window inference for larger images
 
 
188
  step = chunk_size - overlap
189
- half_tile = chunk_size // 2
190
 
191
- # Pad image to ensure tiles cover entire image
 
 
 
 
 
 
 
 
192
  image_padded = np.pad(
193
  image,
194
- ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
195
  mode="reflect"
196
  )
197
 
@@ -206,15 +218,14 @@ def predict_large(
206
  window = get_spline_window(chunk_size, power=2)
207
 
208
  # Generate tile coordinates
209
- coords = [
210
- (r, c)
211
- for r in range(0, H_pad - chunk_size + 1, step)
212
- for c in range(0, W_pad - chunk_size + 1, step)
213
- ]
214
 
215
  # Process tiles in batches
216
  with torch.no_grad():
217
- for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
218
  batch_coords = coords[i:i + batch_size]
219
 
220
  # Extract tiles
@@ -237,16 +248,15 @@ def predict_large(
237
  weight_sum = np.maximum(weight_sum, 1e-8)
238
  probs_final = probs_sum / weight_sum
239
 
240
- # Remove padding
241
- probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]
242
 
243
  # Get final prediction
244
  if merge_clouds:
245
- # Merge thin+thick clouds into single cloud class
246
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
247
- probs_merged[0] = probs_final[0] # Clear
248
- probs_merged[1] = probs_final[1] + probs_final[2] # Cloud (thin+thick)
249
- probs_merged[2] = probs_final[3] # Shadow
250
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
251
  else:
252
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
@@ -297,16 +307,19 @@ def display_results(
297
 
298
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
299
 
 
300
  rgb = np.stack([image[1], image[0], image[2]], axis=-1)
301
  rgb = np.clip(rgb * 3, 0, 1)
302
  axes[0].imshow(rgb)
303
  axes[0].set_title("MSS RGB Composite")
304
  axes[0].axis('off')
305
 
 
306
  im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
307
  axes[1].set_title("Cloud Detection")
308
  axes[1].axis('off')
309
 
 
310
  cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
311
  cbar.ax.set_yticklabels(labels)
312
 
 
41
 
42
 
43
  def get_spline_window(size: int, power: int = 2) -> np.ndarray:
44
+ """Hann window for smooth blending."""
45
  intersection = np.hanning(size)
46
  window_2d = np.outer(intersection, intersection)
47
  return (window_2d ** power).astype(np.float32)
 
52
  image: np.ndarray,
53
  merge_clouds: bool = False,
54
  ) -> np.ndarray:
55
+ """Apply physical rules for saturated thick clouds."""
56
  saturation_threshold = 0.35
57
 
58
  pred = pred.copy()
59
 
60
+ # Nodata mask
61
  nodata_mask = np.all(image == 0, axis=0)
62
 
63
+ # Saturated clouds (high values in visible bands)
64
  bright_b0 = image[0] > saturation_threshold
65
  bright_b1 = image[1] > saturation_threshold * 0.80
66
  saturated_mask = bright_b0 & bright_b1
67
 
68
+ # Assign thick cloud class
69
  if merge_clouds:
70
+ pred[saturated_mask] = 1 # Cloud (merged)
71
  else:
72
+ pred[saturated_mask] = 2 # Thick cloud
73
 
74
+ # Set nodata to clear
75
  pred[nodata_mask] = 0
76
 
77
  return pred
 
125
  def predict_large(
126
  image: np.ndarray,
127
  model: nn.Module,
128
+ chunk_size: int = 512,
129
+ overlap: int = None,
130
  batch_size: int = 1,
131
  device: str = "cpu",
132
  merge_clouds: bool = False,
 
134
  **kwargs
135
  ) -> np.ndarray:
136
  """
137
+ Predict on images of any size using sliding window with smooth blending.
138
 
139
  Args:
140
  image: Input image (C, H, W) in reflectance [0, 1]
 
144
  batch_size: Tiles per batch (default: 1)
145
  device: 'cpu' or 'cuda'
146
  merge_clouds: If True, merge thin+thick into single cloud class
147
+ apply_rules: If True, apply physical rules for bright clouds
148
 
149
  Returns:
150
  Predicted class labels (H, W)
 
154
  model.eval()
155
  model.to(device)
156
 
157
+ # Get merge_clouds setting from model if available
158
  if not hasattr(model, 'merge_clouds'):
159
  model.merge_clouds = merge_clouds
160
  else:
 
162
 
163
  C, H, W = image.shape
164
 
165
+ # Set default overlap
166
  if overlap is None:
167
  overlap = chunk_size // 2
168
 
169
+ # Direct inference for small images
170
  if H <= chunk_size and W <= chunk_size:
171
  with torch.no_grad():
172
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
173
  logits = model(img_tensor)
174
 
175
  if merge_clouds:
 
176
  probs = torch.softmax(logits, dim=1)
177
  probs_merged = torch.zeros(1, 3, H, W, device=device)
178
+ probs_merged[:, 0] = probs[:, 0]
179
+ probs_merged[:, 1] = probs[:, 1] + probs[:, 2]
180
+ probs_merged[:, 2] = probs[:, 3]
181
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
182
  else:
183
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
 
187
 
188
  return pred
189
 
190
+ # === SLIDING WINDOW FOR LARGER IMAGES ===
191
+
192
+ # Calculate padding needed to make image divisible by step
193
  step = chunk_size - overlap
 
194
 
195
+ # Padding to ensure tiles cover the entire image
196
+ pad_h = (step - (H % step)) % step
197
+ pad_w = (step - (W % step)) % step
198
+
199
+ # Add extra overlap padding on all sides for smooth edges
200
+ pad_h += overlap
201
+ pad_w += overlap
202
+
203
+ # Pad image
204
  image_padded = np.pad(
205
  image,
206
+ ((0, 0), (overlap // 2, pad_h - overlap // 2), (overlap // 2, pad_w - overlap // 2)),
207
  mode="reflect"
208
  )
209
 
 
218
  window = get_spline_window(chunk_size, power=2)
219
 
220
  # Generate tile coordinates
221
+ coords = []
222
+ for r in range(0, H_pad - chunk_size + 1, step):
223
+ for c in range(0, W_pad - chunk_size + 1, step):
224
+ coords.append((r, c))
 
225
 
226
  # Process tiles in batches
227
  with torch.no_grad():
228
+ for i in range(0, len(coords), batch_size):
229
  batch_coords = coords[i:i + batch_size]
230
 
231
  # Extract tiles
 
248
  weight_sum = np.maximum(weight_sum, 1e-8)
249
  probs_final = probs_sum / weight_sum
250
 
251
+ # Remove padding to get back to original size
252
+ probs_final = probs_final[:, overlap // 2:overlap // 2 + H, overlap // 2:overlap // 2 + W]
253
 
254
  # Get final prediction
255
  if merge_clouds:
 
256
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
257
+ probs_merged[0] = probs_final[0]
258
+ probs_merged[1] = probs_final[1] + probs_final[2]
259
+ probs_merged[2] = probs_final[3]
260
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
261
  else:
262
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
 
307
 
308
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
309
 
310
+ # RGB composite
311
  rgb = np.stack([image[1], image[0], image[2]], axis=-1)
312
  rgb = np.clip(rgb * 3, 0, 1)
313
  axes[0].imshow(rgb)
314
  axes[0].set_title("MSS RGB Composite")
315
  axes[0].axis('off')
316
 
317
+ # Prediction
318
  im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
319
  axes[1].set_title("Cloud Detection")
320
  axes[1].axis('off')
321
 
322
+ # Colorbar
323
  cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
324
  cbar.ax.set_yticklabels(labels)
325