JulioContrerasH commited on
Commit
919f17a
·
verified ·
1 Parent(s): caa117e

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +31 -96
load.py CHANGED
@@ -7,16 +7,11 @@ import torch
7
  import torch.nn as nn
8
  import numpy as np
9
  from pathlib import Path
10
- from typing import Tuple, Optional
11
  import pytorch_lightning as pl
12
  import segmentation_models_pytorch as smp
13
  from tqdm import tqdm
14
 
15
 
16
- # ============================================================================
17
- # MODEL DEFINITION (copied from your model.py)
18
- # ============================================================================
19
-
20
  class MSSSegmentationModel(pl.LightningModule):
21
  """UNet para cloud segmentation en MSS."""
22
 
@@ -45,12 +40,8 @@ class MSSSegmentationModel(pl.LightningModule):
45
  return self.model(x)
46
 
47
 
48
- # ============================================================================
49
- # INFERENCE UTILITIES
50
- # ============================================================================
51
-
52
  def get_spline_window(size: int, power: int = 2) -> np.ndarray:
53
- """Generate Hann window for smooth blending."""
54
  intersection = np.hanning(size)
55
  window_2d = np.outer(intersection, intersection)
56
  return (window_2d ** power).astype(np.float32)
@@ -60,42 +51,28 @@ def apply_physical_rules(
60
  pred: np.ndarray,
61
  image: np.ndarray,
62
  merge_clouds: bool = False,
63
- saturation_threshold: float = 0.35,
64
  ) -> np.ndarray:
65
- """
66
- Apply physical rules for better cloud detection.
67
 
68
- Args:
69
- pred: Predicted classes (H, W)
70
- image: Input image (4, H, W) in reflectance [0, 1]
71
- merge_clouds: If True, merge thin+thick into single cloud class
72
- saturation_threshold: Threshold for detecting saturated bright clouds
73
- """
74
  pred = pred.copy()
75
 
76
- # Mask nodata pixels
77
  nodata_mask = np.all(image == 0, axis=0)
78
- pred[nodata_mask] = 0
79
 
80
- # Detect very bright pixels (likely thick clouds)
81
  bright_b0 = image[0] > saturation_threshold
82
  bright_b1 = image[1] > saturation_threshold * 0.80
83
  saturated_mask = bright_b0 & bright_b1
84
 
85
  if merge_clouds:
86
- # Set to cloud (1)
87
  pred[saturated_mask] = 1
88
  else:
89
- # Set to thick cloud (2)
90
  pred[saturated_mask] = 2
91
 
 
 
92
  return pred
93
 
94
 
95
- # ============================================================================
96
- # MLSTAC-COMPATIBLE FUNCTIONS
97
- # ============================================================================
98
-
99
  def compiled_model(
100
  model_dir: Path,
101
  stac_item=None,
@@ -110,20 +87,18 @@ def compiled_model(
110
  model_dir: Directory containing the .ckpt file
111
  stac_item: STAC item metadata (optional)
112
  device: 'cpu' or 'cuda'
113
- merge_clouds: If True, output will have 3 classes (clear, cloud, shadow)
114
- If False, output will have 4 classes (clear, thin, thick, shadow)
115
 
116
  Returns:
117
  Loaded model in eval mode
118
  """
119
- # Find checkpoint file
120
  ckpt_files = list(model_dir.glob("*.ckpt"))
121
  if not ckpt_files:
122
  raise FileNotFoundError(f"No .ckpt file found in {model_dir}")
123
 
124
  ckpt_path = ckpt_files[0]
125
 
126
- # Load model
127
  model = MSSSegmentationModel.load_from_checkpoint(
128
  ckpt_path,
129
  map_location=device
@@ -131,11 +106,9 @@ def compiled_model(
131
  model.eval()
132
  model.to(device)
133
 
134
- # Disable gradients
135
  for param in model.parameters():
136
  param.requires_grad = False
137
 
138
- # Store merge_clouds flag for predict_large
139
  model.merge_clouds = merge_clouds
140
 
141
  print(f"✅ Model loaded from {ckpt_path.name}")
@@ -152,9 +125,8 @@ def predict_large(
152
  overlap: int = 256,
153
  batch_size: int = 1,
154
  device: str = "cpu",
155
- nodata: float = 0.0,
156
- apply_rules: bool = True,
157
- saturation_threshold: float = 0.35,
158
  **kwargs
159
  ) -> np.ndarray:
160
  """
@@ -163,53 +135,51 @@ def predict_large(
163
  Args:
164
  image: Input image (C, H, W) in reflectance [0, 1]
165
  model: Loaded model from compiled_model()
166
- chunk_size: Size of inference tiles (default: 1024)
167
- overlap: Overlap between tiles for smooth blending (default: 256)
168
- batch_size: Number of tiles to process in parallel (default: 1)
169
  device: 'cpu' or 'cuda'
170
- nodata: Value representing no-data pixels
171
- apply_rules: Whether to apply physical rules post-processing
172
- saturation_threshold: Threshold for detecting bright clouds
173
 
174
  Returns:
175
- Predicted class labels (H, W) with shape matching input
176
  - If merge_clouds=False: 0=clear, 1=thin, 2=thick, 3=shadow
177
  - If merge_clouds=True: 0=clear, 1=cloud, 2=shadow
178
  """
179
  model.eval()
180
  model.to(device)
181
 
182
- merge_clouds = getattr(model, 'merge_clouds', False)
 
 
 
183
 
184
  C, H, W = image.shape
185
 
186
- # Direct inference for small images
187
  if H <= chunk_size and W <= chunk_size:
188
  with torch.no_grad():
189
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
190
  logits = model(img_tensor)
191
 
192
  if merge_clouds:
193
- # Merge thin(1) + thick(2) probabilities
194
  probs = torch.softmax(logits, dim=1)
195
  probs_merged = torch.zeros(1, 3, H, W, device=device)
196
- probs_merged[:, 0] = probs[:, 0] # clear
197
- probs_merged[:, 1] = probs[:, 1] + probs[:, 2] # cloud
198
- probs_merged[:, 2] = probs[:, 3] # shadow
199
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
200
  else:
201
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
202
 
203
  if apply_rules:
204
- pred = apply_physical_rules(pred, image, merge_clouds, saturation_threshold)
205
 
206
  return pred
207
 
208
- # Sliding window for large images
209
  step = chunk_size - overlap
210
  half_tile = chunk_size // 2
211
 
212
- # Pad image
213
  image_padded = np.pad(
214
  image,
215
  ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
@@ -218,82 +188,60 @@ def predict_large(
218
 
219
  _, H_pad, W_pad = image_padded.shape
220
 
221
- # Initialize accumulators - ALWAYS 4 classes, merge at the end if needed
222
  num_classes = 4
223
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
224
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
225
 
226
- # Blending window
227
  window = get_spline_window(chunk_size, power=2)
228
 
229
- # Generate tile coordinates
230
  coords = [
231
  (r, c)
232
  for r in range(0, H_pad - chunk_size + 1, step)
233
  for c in range(0, W_pad - chunk_size + 1, step)
234
  ]
235
 
236
- # Process tiles in batches
237
  with torch.no_grad():
238
  for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
239
  batch_coords = coords[i:i + batch_size]
240
 
241
- # Extract tiles
242
  tiles = np.stack([
243
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
244
  for r, c in batch_coords
245
  ])
246
 
247
- # Inference
248
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
249
  logits = model(tiles_tensor)
250
  probs = torch.softmax(logits, dim=1).cpu().numpy()
251
 
252
- # Accumulate with blending - ALWAYS accumulate 4 classes
253
  for j, (r, c) in enumerate(batch_coords):
254
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
255
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
256
 
257
- # Normalize
258
  weight_sum = np.maximum(weight_sum, 1e-8)
259
  probs_final = probs_sum / weight_sum
260
 
261
- # Crop to original size
262
  probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]
263
 
264
- # Merge classes if requested - AFTER normalization
265
  if merge_clouds:
266
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
267
- probs_merged[0] = probs_final[0] # clear
268
- probs_merged[1] = probs_final[1] + probs_final[2] # cloud = thin + thick
269
- probs_merged[2] = probs_final[3] # shadow
270
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
271
  else:
272
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
273
 
274
- # Apply physical rules
275
  if apply_rules:
276
- pred = apply_physical_rules(pred, image, merge_clouds, saturation_threshold)
277
 
278
  return pred
279
 
280
 
281
- # ============================================================================
282
- # OPTIONAL: EXAMPLE DATA AND VISUALIZATION
283
- # ============================================================================
284
-
285
  def example_data(model_dir: Path, **kwargs):
286
- """
287
- Load example data for testing (optional function).
288
-
289
- Returns:
290
- Example MSS image as numpy array (4, H, W)
291
- """
292
- # This is optional - you can provide a small example .npy file
293
  example_path = model_dir / "example_mss.npy"
294
 
295
  if not example_path.exists():
296
- # Return synthetic data if no example file
297
  print("⚠️ No example data found, generating synthetic")
298
  return np.random.rand(4, 512, 512).astype(np.float32) * 0.5
299
 
@@ -307,15 +255,7 @@ def display_results(
307
  stac_item=None,
308
  **kwargs
309
  ):
310
- """
311
- Display prediction results (optional visualization function).
312
-
313
- Args:
314
- model_dir: Model directory
315
- image: Input image (4, H, W)
316
- prediction: Predicted classes (H, W)
317
- stac_item: STAC metadata
318
- """
319
  try:
320
  import matplotlib.pyplot as plt
321
  from matplotlib.colors import ListedColormap
@@ -325,9 +265,8 @@ def display_results(
325
 
326
  merge_clouds = prediction.max() <= 2
327
 
328
- # Color maps
329
  if merge_clouds:
330
- colors = ['#2E7D32', '#FFFFFF', '#424242'] # clear, cloud, shadow
331
  labels = ['Clear', 'Cloud', 'Shadow']
332
  else:
333
  colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
@@ -335,22 +274,18 @@ def display_results(
335
 
336
  cmap = ListedColormap(colors)
337
 
338
- # Plot
339
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
340
 
341
- # RGB composite (use bands 1, 0, 2 as RGB approximation)
342
  rgb = np.stack([image[1], image[0], image[2]], axis=-1)
343
- rgb = np.clip(rgb * 3, 0, 1) # Brighten for visibility
344
  axes[0].imshow(rgb)
345
  axes[0].set_title("MSS RGB Composite")
346
  axes[0].axis('off')
347
 
348
- # Prediction
349
  im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
350
  axes[1].set_title("Cloud Detection")
351
  axes[1].axis('off')
352
 
353
- # Colorbar
354
  cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
355
  cbar.ax.set_yticklabels(labels)
356
 
 
7
  import torch.nn as nn
8
  import numpy as np
9
  from pathlib import Path
 
10
  import pytorch_lightning as pl
11
  import segmentation_models_pytorch as smp
12
  from tqdm import tqdm
13
 
14
 
 
 
 
 
15
  class MSSSegmentationModel(pl.LightningModule):
16
  """UNet para cloud segmentation en MSS."""
17
 
 
40
  return self.model(x)
41
 
42
 
 
 
 
 
43
  def get_spline_window(size: int, power: int = 2) -> np.ndarray:
44
+ """Ventana Hann 2D para blending suave."""
45
  intersection = np.hanning(size)
46
  window_2d = np.outer(intersection, intersection)
47
  return (window_2d ** power).astype(np.float32)
 
51
  pred: np.ndarray,
52
  image: np.ndarray,
53
  merge_clouds: bool = False,
 
54
  ) -> np.ndarray:
55
+ """Regla física para nubes gruesas saturadas."""
56
+ saturation_threshold = 0.35
57
 
 
 
 
 
 
 
58
  pred = pred.copy()
59
 
 
60
  nodata_mask = np.all(image == 0, axis=0)
 
61
 
 
62
  bright_b0 = image[0] > saturation_threshold
63
  bright_b1 = image[1] > saturation_threshold * 0.80
64
  saturated_mask = bright_b0 & bright_b1
65
 
66
  if merge_clouds:
 
67
  pred[saturated_mask] = 1
68
  else:
 
69
  pred[saturated_mask] = 2
70
 
71
+ pred[nodata_mask] = 0
72
+
73
  return pred
74
 
75
 
 
 
 
 
76
  def compiled_model(
77
  model_dir: Path,
78
  stac_item=None,
 
87
  model_dir: Directory containing the .ckpt file
88
  stac_item: STAC item metadata (optional)
89
  device: 'cpu' or 'cuda'
90
+ merge_clouds: If True, output 3 classes (clear, cloud, shadow)
91
+ If False, output 4 classes (clear, thin, thick, shadow)
92
 
93
  Returns:
94
  Loaded model in eval mode
95
  """
 
96
  ckpt_files = list(model_dir.glob("*.ckpt"))
97
  if not ckpt_files:
98
  raise FileNotFoundError(f"No .ckpt file found in {model_dir}")
99
 
100
  ckpt_path = ckpt_files[0]
101
 
 
102
  model = MSSSegmentationModel.load_from_checkpoint(
103
  ckpt_path,
104
  map_location=device
 
106
  model.eval()
107
  model.to(device)
108
 
 
109
  for param in model.parameters():
110
  param.requires_grad = False
111
 
 
112
  model.merge_clouds = merge_clouds
113
 
114
  print(f"✅ Model loaded from {ckpt_path.name}")
 
125
  overlap: int = 256,
126
  batch_size: int = 1,
127
  device: str = "cpu",
128
+ merge_clouds: bool = False,
129
+ apply_rules: bool = False,
 
130
  **kwargs
131
  ) -> np.ndarray:
132
  """
 
135
  Args:
136
  image: Input image (C, H, W) in reflectance [0, 1]
137
  model: Loaded model from compiled_model()
138
+ chunk_size: Size of inference tiles (default: 512)
139
+ overlap: Overlap between tiles (default: 256)
140
+ batch_size: Tiles per batch (default: 1)
141
  device: 'cpu' or 'cuda'
142
+ merge_clouds: If True, merge thin+thick into single cloud class
143
+ apply_rules: If True, apply physical rules for bright clouds (default: False)
 
144
 
145
  Returns:
146
+ Predicted class labels (H, W)
147
  - If merge_clouds=False: 0=clear, 1=thin, 2=thick, 3=shadow
148
  - If merge_clouds=True: 0=clear, 1=cloud, 2=shadow
149
  """
150
  model.eval()
151
  model.to(device)
152
 
153
+ if not hasattr(model, 'merge_clouds'):
154
+ model.merge_clouds = merge_clouds
155
+ else:
156
+ merge_clouds = model.merge_clouds
157
 
158
  C, H, W = image.shape
159
 
 
160
  if H <= chunk_size and W <= chunk_size:
161
  with torch.no_grad():
162
  img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
163
  logits = model(img_tensor)
164
 
165
  if merge_clouds:
 
166
  probs = torch.softmax(logits, dim=1)
167
  probs_merged = torch.zeros(1, 3, H, W, device=device)
168
+ probs_merged[:, 0] = probs[:, 0]
169
+ probs_merged[:, 1] = probs[:, 1] + probs[:, 2]
170
+ probs_merged[:, 2] = probs[:, 3]
171
  pred = probs_merged.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
172
  else:
173
  pred = logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
174
 
175
  if apply_rules:
176
+ pred = apply_physical_rules(pred, image, merge_clouds)
177
 
178
  return pred
179
 
 
180
  step = chunk_size - overlap
181
  half_tile = chunk_size // 2
182
 
 
183
  image_padded = np.pad(
184
  image,
185
  ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
 
188
 
189
  _, H_pad, W_pad = image_padded.shape
190
 
 
191
  num_classes = 4
192
  probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
193
  weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)
194
 
 
195
  window = get_spline_window(chunk_size, power=2)
196
 
 
197
  coords = [
198
  (r, c)
199
  for r in range(0, H_pad - chunk_size + 1, step)
200
  for c in range(0, W_pad - chunk_size + 1, step)
201
  ]
202
 
 
203
  with torch.no_grad():
204
  for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
205
  batch_coords = coords[i:i + batch_size]
206
 
 
207
  tiles = np.stack([
208
  image_padded[:, r:r + chunk_size, c:c + chunk_size]
209
  for r, c in batch_coords
210
  ])
211
 
 
212
  tiles_tensor = torch.from_numpy(tiles).float().to(device)
213
  logits = model(tiles_tensor)
214
  probs = torch.softmax(logits, dim=1).cpu().numpy()
215
 
 
216
  for j, (r, c) in enumerate(batch_coords):
217
  probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
218
  weight_sum[r:r + chunk_size, c:c + chunk_size] += window
219
 
 
220
  weight_sum = np.maximum(weight_sum, 1e-8)
221
  probs_final = probs_sum / weight_sum
222
 
 
223
  probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]
224
 
 
225
  if merge_clouds:
226
  probs_merged = np.zeros((3, H, W), dtype=np.float32)
227
+ probs_merged[0] = probs_final[0]
228
+ probs_merged[1] = probs_final[1] + probs_final[2]
229
+ probs_merged[2] = probs_final[3]
230
  pred = np.argmax(probs_merged, axis=0).astype(np.uint8)
231
  else:
232
  pred = np.argmax(probs_final, axis=0).astype(np.uint8)
233
 
 
234
  if apply_rules:
235
+ pred = apply_physical_rules(pred, image, merge_clouds)
236
 
237
  return pred
238
 
239
 
 
 
 
 
240
  def example_data(model_dir: Path, **kwargs):
241
+ """Load example data for testing."""
 
 
 
 
 
 
242
  example_path = model_dir / "example_mss.npy"
243
 
244
  if not example_path.exists():
 
245
  print("⚠️ No example data found, generating synthetic")
246
  return np.random.rand(4, 512, 512).astype(np.float32) * 0.5
247
 
 
255
  stac_item=None,
256
  **kwargs
257
  ):
258
+ """Display prediction results."""
 
 
 
 
 
 
 
 
259
  try:
260
  import matplotlib.pyplot as plt
261
  from matplotlib.colors import ListedColormap
 
265
 
266
  merge_clouds = prediction.max() <= 2
267
 
 
268
  if merge_clouds:
269
+ colors = ['#2E7D32', '#FFFFFF', '#424242']
270
  labels = ['Clear', 'Cloud', 'Shadow']
271
  else:
272
  colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
 
274
 
275
  cmap = ListedColormap(colors)
276
 
 
277
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
278
 
 
279
  rgb = np.stack([image[1], image[0], image[2]], axis=-1)
280
+ rgb = np.clip(rgb * 3, 0, 1)
281
  axes[0].imshow(rgb)
282
  axes[0].set_title("MSS RGB Composite")
283
  axes[0].axis('off')
284
 
 
285
  im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
286
  axes[1].set_title("Cloud Detection")
287
  axes[1].axis('off')
288
 
 
289
  cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
290
  cbar.ax.set_yticklabels(labels)
291