JulioContrerasH committed on
Commit
493ecbb
·
verified ·
1 Parent(s): 70e6d35

Upload 3 files

Browse files
Files changed (3) hide show
  1. load.py +358 -0
  2. unet.ckpt +3 -0
  3. unet.json +201 -0
load.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load and inference functions for MSS Cloud Detection Model
3
+ Compatible with mlstac package
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from typing import Tuple, Optional
11
+ import pytorch_lightning as pl
12
+ import segmentation_models_pytorch as smp
13
+ from tqdm import tqdm
14
+
15
+
16
+ # ============================================================================
17
+ # MODEL DEFINITION (copied from model.py)
18
+ # ============================================================================
19
+
20
class MSSSegmentationModel(pl.LightningModule):
    """UNet wrapper for cloud segmentation on MSS imagery.

    Wraps a ``segmentation_models_pytorch`` UNet inside a Lightning module.

    Args:
        in_channels: Number of input bands fed to the encoder.
        num_classes: Number of output channels (one score map per class).
        encoder: Name of the smp encoder backbone.
        lr: Learning rate; only recorded as a hyperparameter here.
        weight_decay: Weight decay; only recorded as a hyperparameter here.
    """

    def __init__(
        self,
        in_channels: int = 4,
        num_classes: int = 4,
        encoder: str = "efficientnet-b3",
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        # Record the constructor arguments on the module (Lightning hparams).
        self.save_hyperparameters()

        # No pretrained encoder weights and no output activation: the network
        # starts uninitialised and emits raw scores.
        unet_config = dict(
            encoder_name=encoder,
            encoder_weights=None,
            in_channels=in_channels,
            classes=num_classes,
            encoder_depth=5,
            activation=None,
            decoder_attention_type="scse",
        )
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Return the raw (un-activated) UNet output for ``x``."""
        return self.model(x)
46
+
47
+
48
+ # ============================================================================
49
+ # INFERENCE UTILITIES
50
+ # ============================================================================
51
+
52
def get_spline_window(size: int, power: int = 2) -> np.ndarray:
    """Build a square 2-D blending window from a 1-D Hann taper.

    Args:
        size: Side length of the window in pixels.
        power: Exponent applied element-wise to the 2-D window.

    Returns:
        Array of shape (size, size) and dtype float32.
    """
    taper = np.hanning(size)
    # Outer product of the taper with itself, written as a broadcast.
    surface = taper[:, np.newaxis] * taper[np.newaxis, :]
    return np.power(surface, power).astype(np.float32)
57
+
58
+
59
def apply_physical_rules(
    pred: np.ndarray,
    image: np.ndarray,
    merge_clouds: bool = False,
    saturation_threshold: float = 0.35,
) -> np.ndarray:
    """
    Post-process a label map with simple reflectance-based rules.

    Two rules are applied, in this order, on a copy of ``pred``:
    pixels that are zero in every band become class 0, then pixels that are
    bright in both band 0 and band 1 are forced to the cloud class.

    Args:
        pred: Predicted classes (H, W)
        image: Input image (4, H, W) in reflectance [0, 1]
        merge_clouds: If True, bright pixels get the merged cloud label (1);
            otherwise they get the thick-cloud label (2)
        saturation_threshold: Band-0 brightness threshold; band 1 is tested
            against 80% of this value

    Returns:
        New label array with the rules applied; ``pred`` is not modified.
    """
    # Label written onto saturated pixels depends on the class scheme in use.
    cloud_label = 1 if merge_clouds else 2

    # No-data: exactly zero in every band.
    is_nodata = (image == 0).all(axis=0)

    # Very bright in both of the first two bands (likely thick cloud).
    is_saturated = np.logical_and(
        image[0] > saturation_threshold,
        image[1] > saturation_threshold * 0.80,
    )

    out = pred.copy()
    out[is_nodata] = 0
    out[is_saturated] = cloud_label
    return out
93
+
94
+
95
+ # ============================================================================
96
+ # MLSTAC-COMPATIBLE FUNCTIONS
97
+ # ============================================================================
98
+
99
def compiled_model(
    model_dir: Path,
    stac_item=None,
    device: str = "cpu",
    merge_clouds: bool = False,
    **kwargs
) -> nn.Module:
    """
    Load the checkpointed model and prepare it for inference.

    Args:
        model_dir: Directory containing the .ckpt file (a ``Path`` or a plain
            string path)
        stac_item: STAC item metadata (optional, not used here)
        device: 'cpu' or 'cuda'
        merge_clouds: If True, output will have 3 classes (clear, cloud, shadow)
            If False, output will have 4 classes (clear, thin, thick, shadow)
        **kwargs: Accepted and ignored.

    Returns:
        Loaded model in eval mode with gradients disabled and a
        ``merge_clouds`` attribute that ``predict_large`` reads.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``*.ckpt`` file.
    """
    # Accept plain string paths as well as Path objects.
    model_dir = Path(model_dir)

    # Sort so that the chosen checkpoint is deterministic when the directory
    # holds more than one .ckpt file (glob order is filesystem dependent).
    ckpt_files = sorted(model_dir.glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {model_dir}")

    ckpt_path = ckpt_files[0]

    # NOTE(review): checkpoint loading presumably unpickles the file; only
    # point this at checkpoints from a trusted source.
    model = MSSSegmentationModel.load_from_checkpoint(
        ckpt_path,
        map_location=device
    )
    model.eval()
    model.to(device)

    # Inference only: freeze every parameter.
    for param in model.parameters():
        param.requires_grad = False

    # Store merge_clouds flag for predict_large
    model.merge_clouds = merge_clouds

    print(f"✅ Model loaded from {ckpt_path.name}")
    print(f" Device: {device}")
    print(f" Classes: {'3 (merged)' if merge_clouds else '4 (original)'}")

    return model
146
+
147
+
148
def _labels_from_scores(scores: np.ndarray, merge_clouds: bool) -> np.ndarray:
    """Convert (4, H, W) per-class scores into a uint8 (H, W) label map.

    When ``merge_clouds`` is true the thin (1) and thick (2) channels are
    summed into one cloud channel before the argmax, so ``scores`` must be
    probabilities in that case.
    """
    if merge_clouds:
        scores = np.stack([scores[0], scores[1] + scores[2], scores[3]])
    return np.argmax(scores, axis=0).astype(np.uint8)


def predict_large(
    image: np.ndarray,
    model: nn.Module,
    chunk_size: int = 512,
    overlap: int = 256,
    batch_size: int = 1,
    device: str = "cpu",
    nodata: float = 0.0,
    apply_rules: bool = True,
    saturation_threshold: float = 0.35,
    **kwargs
) -> np.ndarray:
    """
    Predict on large images using sliding window with overlap blending.

    Images that fit inside a single ``chunk_size`` tile are run through the
    model directly; larger images are reflect-padded, tiled, and the per-tile
    softmax outputs are blended with a Hann-based window.

    Args:
        image: Input image (C, H, W) in reflectance [0, 1]
        model: Loaded model from compiled_model()
        chunk_size: Size of inference tiles (default: 512)
        overlap: Overlap between tiles for smooth blending (default: 256)
        batch_size: Number of tiles to process in parallel (default: 1)
        device: 'cpu' or 'cuda'; the model is moved to this device
        nodata: Value representing no-data pixels. Currently not used by this
            function: no-data masking is delegated to apply_physical_rules,
            which treats all-zero pixels as no-data.
        apply_rules: Whether to apply physical rules post-processing
        saturation_threshold: Threshold for detecting bright clouds
        **kwargs: Accepted and ignored.

    Returns:
        Predicted class labels (H, W) with shape matching input
        - If merge_clouds=False: 0=clear, 1=thin, 2=thick, 3=shadow
        - If merge_clouds=True: 0=clear, 1=cloud, 2=shadow

    Raises:
        ValueError: On the sliding-window path, if ``chunk_size`` is not
            positive, ``overlap`` is not in ``[0, chunk_size)`` or
            ``batch_size`` is smaller than 1.
    """
    model.eval()
    model.to(device)

    # Flag set by compiled_model(); plain modules default to 4 classes.
    merge_clouds = getattr(model, 'merge_clouds', False)

    _, H, W = image.shape

    # Direct inference for small images
    # NOTE(review): UNet encoders usually need H and W divisible by a fixed
    # stride (32 for depth 5) — confirm small inputs satisfy that.
    if H <= chunk_size and W <= chunk_size:
        with torch.no_grad():
            img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
            logits = model(img_tensor)
            # Probabilities are only needed when thin+thick must be summed;
            # otherwise the argmax of the raw output is taken directly.
            scores = torch.softmax(logits, dim=1) if merge_clouds else logits
            # Index the batch axis instead of squeeze() so that images with
            # H == 1 or W == 1 keep their (H, W) shape.
            scores = scores[0].cpu().numpy()

        pred = _labels_from_scores(scores, merge_clouds)

        if apply_rules:
            pred = apply_physical_rules(pred, image, merge_clouds, saturation_threshold)

        return pred

    # Sliding window for large images
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 (got {batch_size})")

    step = chunk_size - overlap
    half_tile = chunk_size // 2

    # Pad image: half a tile before, half a tile plus one full tile after, so
    # every original pixel is covered by the tile grid.
    image_padded = np.pad(
        image,
        ((0, 0), (half_tile, half_tile + chunk_size), (half_tile, half_tile + chunk_size)),
        mode="reflect"
    )

    _, H_pad, W_pad = image_padded.shape

    # Initialize accumulators - ALWAYS 4 classes, merge at the end if needed
    num_classes = 4
    probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
    weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)

    # Blending window
    window = get_spline_window(chunk_size, power=2)

    # Generate tile coordinates (top-left corners in padded space)
    coords = [
        (r, c)
        for r in range(0, H_pad - chunk_size + 1, step)
        for c in range(0, W_pad - chunk_size + 1, step)
    ]

    # Process tiles in batches
    with torch.no_grad():
        for i in tqdm(range(0, len(coords), batch_size), desc=" Tiles", leave=False, disable=True):
            batch_coords = coords[i:i + batch_size]

            # Extract tiles
            tiles = np.stack([
                image_padded[:, r:r + chunk_size, c:c + chunk_size]
                for r, c in batch_coords
            ])

            # Inference
            tiles_tensor = torch.from_numpy(tiles).float().to(device)
            logits = model(tiles_tensor)
            probs = torch.softmax(logits, dim=1).cpu().numpy()

            # Accumulate with blending - ALWAYS accumulate 4 classes
            for j, (r, c) in enumerate(batch_coords):
                probs_sum[:, r:r + chunk_size, c:c + chunk_size] += probs[j] * window
                weight_sum[r:r + chunk_size, c:c + chunk_size] += window

    # Normalize; the floor avoids division by zero where the window is zero
    weight_sum = np.maximum(weight_sum, 1e-8)
    probs_final = probs_sum / weight_sum

    # Crop to original size
    probs_final = probs_final[:, half_tile:half_tile + H, half_tile:half_tile + W]

    # Merge classes if requested - AFTER normalization
    pred = _labels_from_scores(probs_final, merge_clouds)

    # Apply physical rules
    if apply_rules:
        pred = apply_physical_rules(pred, image, merge_clouds, saturation_threshold)

    return pred
279
+
280
+
281
+ # ============================================================================
282
+ # OPTIONAL: EXAMPLE DATA AND VISUALIZATION
283
+ # ============================================================================
284
+
285
def example_data(model_dir: Path, **kwargs):
    """
    Return a sample MSS image for quick testing (optional helper).

    Looks for ``example_mss.npy`` inside ``model_dir``. When the file is
    missing, a warning is printed and random data is generated instead.

    Args:
        model_dir: Directory that may contain ``example_mss.npy``
        **kwargs: Accepted and ignored.

    Returns:
        The array stored in the file, or a random float32 array of shape
        (4, 512, 512) with values in [0, 0.5)
    """
    candidate = model_dir / "example_mss.npy"

    # Prefer the bundled example when one is shipped next to the model.
    if candidate.exists():
        return np.load(candidate)

    print("⚠️ No example data found, generating synthetic")
    synthetic = np.random.rand(4, 512, 512).astype(np.float32)
    return synthetic * 0.5
301
+
302
+
303
def display_results(
    model_dir: Path,
    image: np.ndarray,
    prediction: np.ndarray,
    stac_item=None,
    **kwargs
):
    """
    Display prediction results (optional visualization function).

    Shows a band composite of the input next to the colour-coded class map.
    If matplotlib is not installed, a warning is printed and nothing is drawn.

    Args:
        model_dir: Model directory (not used here)
        image: Input image (4, H, W)
        prediction: Predicted classes (H, W)
        stac_item: STAC metadata (not used here)
        **kwargs: May contain ``merge_clouds`` (bool) to state explicitly
            whether ``prediction`` uses the 3-class merged scheme. When absent
            or None, the scheme is guessed from the largest label present,
            which mislabels a 4-class map that happens to contain no shadow.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
    except ImportError:
        print("⚠️ matplotlib not installed, skipping visualization")
        return

    # Prefer the caller's explicit statement of the class scheme; only fall
    # back to guessing from the data when none was given.
    merge_clouds = kwargs.get("merge_clouds")
    if merge_clouds is None:
        merge_clouds = prediction.max() <= 2

    # Color maps
    if merge_clouds:
        colors = ['#2E7D32', '#FFFFFF', '#424242']  # clear, cloud, shadow
        labels = ['Clear', 'Cloud', 'Shadow']
    else:
        colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
        labels = ['Clear', 'Thin Cloud', 'Thick Cloud', 'Shadow']

    cmap = ListedColormap(colors)

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Band composite: bands 1, 0, 2 are shown as R, G, B (an approximation)
    rgb = np.stack([image[1], image[0], image[2]], axis=-1)
    rgb = np.clip(rgb * 3, 0, 1)  # Brighten for visibility
    axes[0].imshow(rgb)
    axes[0].set_title("MSS RGB Composite")
    axes[0].axis('off')

    # Prediction; vmin/vmax pin label i to colour i even if a class is absent
    im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels)-1)
    axes[1].set_title("Cloud Detection")
    axes[1].axis('off')

    # Colorbar
    cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
    cbar.ax.set_yticklabels(labels)

    plt.tight_layout()
    plt.show()
unet.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:920ac77982e059ba6300f757d5588284cb983f4a5430d05b8103f95101e3470a
3
+ size 154913825
unet.json ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "Feature",
3
+ "stac_version": "1.1.0",
4
+ "stac_extensions": [
5
+ "https://stac-extensions.github.io/mlm/v1.5.0/schema.json",
6
+ "https://stac-extensions.github.io/file/v2.1.0/schema.json"
7
+ ],
8
+ "id": "MSS_CLOUDMASK_UNET_EFFB3",
9
+ "geometry": {
10
+ "type": "Polygon",
11
+ "coordinates": [
12
+ [
13
+ [
14
+ -180,
15
+ -90
16
+ ],
17
+ [
18
+ -180,
19
+ 90
20
+ ],
21
+ [
22
+ 180,
23
+ 90
24
+ ],
25
+ [
26
+ 180,
27
+ -90
28
+ ],
29
+ [
30
+ -180,
31
+ -90
32
+ ]
33
+ ]
34
+ ]
35
+ },
36
+ "bbox": [
37
+ -180,
38
+ -90,
39
+ 180,
40
+ 90
41
+ ],
42
+ "properties": {
43
+ "datetime": "2026-01-18T22:42:31.441233Z",
44
+ "created": "2026-01-18T22:42:31.441233Z",
45
+ "updated": "2026-01-19T01:01:38.488397Z",
46
+ "title": "MSS Cloud Detection Model (UNet-EfficientNetB3)",
47
+ "description": "UNet architecture with EfficientNet-B3 encoder for cloud detection in Landsat MSS (Multispectral Scanner) imagery. Trained on CloudSEN12 data emulated to MSS spectral bands using satharmony package. Detects 4 classes: clear, thin cloud, thick cloud, and shadow.",
48
+ "mlm:name": "mss_cloudmask_unet_effb3",
49
+ "mlm:architecture": "UNet with EfficientNet-B3 encoder + SCSE attention",
50
+ "mlm:tasks": [
51
+ "semantic-segmentation",
52
+ "cloud-detection"
53
+ ],
54
+ "mlm:framework": "pytorch",
55
+ "mlm:framework_version": "2.5.1+cu121",
56
+ "mlm:accelerator": "cuda",
57
+ "mlm:memory_size": 309827650,
58
+ "mlm:batch_size_suggestion": 8,
59
+ "mlm:total_parameters": 13223490,
60
+ "mlm:input": [
61
+ {
62
+ "name": "mss_reflectance",
63
+ "bands": [
64
+ "Green (500-600nm)",
65
+ "Red (600-700nm)",
66
+ "NIR1 (700-800nm)",
67
+ "NIR2 (800-1100nm)"
68
+ ],
69
+ "input": {
70
+ "shape": [
71
+ -1,
72
+ 4,
73
+ "H",
74
+ "W"
75
+ ],
76
+ "dim_order": [
77
+ "batch",
78
+ "channel",
79
+ "height",
80
+ "width"
81
+ ],
82
+ "data_type": "float32"
83
+ },
84
+ "norm": {
85
+ "type": "reflectance",
86
+ "range": [
87
+ 0.0,
88
+ 1.0
89
+ ],
90
+ "description": "TOA reflectance normalized to [0, 1]. DN values should be divided by 10000."
91
+ },
92
+ "preprocessing": "Divide DN by 10000 to get reflectance in [0, 1]"
93
+ }
94
+ ],
95
+ "mlm:output": [
96
+ {
97
+ "name": "cloud_mask",
98
+ "classes": [
99
+ {
100
+ "id": 0,
101
+ "name": "clear",
102
+ "description": "Clear sky"
103
+ },
104
+ {
105
+ "id": 1,
106
+ "name": "thin_cloud",
107
+ "description": "Thin/cirrus clouds"
108
+ },
109
+ {
110
+ "id": 2,
111
+ "name": "thick_cloud",
112
+ "description": "Thick/opaque clouds"
113
+ },
114
+ {
115
+ "id": 3,
116
+ "name": "shadow",
117
+ "description": "Cloud shadow"
118
+ }
119
+ ],
120
+ "result": {
121
+ "shape": [
122
+ -1,
123
+ 4,
124
+ "H",
125
+ "W"
126
+ ],
127
+ "dim_order": [
128
+ "batch",
129
+ "class",
130
+ "height",
131
+ "width"
132
+ ],
133
+ "data_type": "float32"
134
+ },
135
+ "description": "Per-pixel logits for 4 classes. Use argmax to get class labels, or softmax for probabilities.",
136
+ "postprocessing": "Apply argmax(dim=1) to get class labels (0-3), or softmax(dim=1) for probabilities"
137
+ }
138
+ ],
139
+ "mlm:hyperparameters": {
140
+ "learning_rate": 0.0003,
141
+ "weight_decay": 0.0001,
142
+ "optimizer": "AdamW",
143
+ "scheduler": "CosineAnnealingWarmRestarts",
144
+ "batch_size": 256,
145
+ "training_epochs": 55,
146
+ "final_val_iou": 0.6164,
147
+ "loss_function": "CrossEntropyLoss",
148
+ "encoder_depth": 5,
149
+ "decoder_attention": "SCSE"
150
+ },
151
+ "custom:sensor": "Landsat MSS",
152
+ "custom:spatial_resolution": "60m",
153
+ "custom:temporal_coverage": "1972-2013",
154
+ "custom:training_data": "CloudSEN12 emulated to MSS bands",
155
+ "custom:emulator": "satharmony",
156
+ "custom:project": "QA4EO-2",
157
+ "custom:project_url": "https://github.com/IPL-UV/qa4eo",
158
+ "file:size": 154913825,
159
+ "dependencies": [
160
+ "torch>=2.0.0",
161
+ "pytorch-lightning>=2.0.0",
162
+ "segmentation-models-pytorch>=0.3.0",
163
+ "rasterio>=1.3.0",
164
+ "numpy>=1.21.0"
165
+ ]
166
+ },
167
+ "assets": {
168
+ "model": {
169
+ "href": "https://huggingface.co/isp-uv-es/QA4EO-2/resolve/main/unet.ckpt",
170
+ "type": "application/octet-stream",
171
+ "title": "PyTorch Lightning checkpoint",
172
+ "roles": [
173
+ "mlm:model",
174
+ "mlm:weights"
175
+ ],
176
+ "file:size": 154913825
177
+ },
178
+ "load": {
179
+ "href": "https://huggingface.co/isp-uv-es/QA4EO-2/resolve/main/load.py",
180
+ "type": "application/x-python-code",
181
+ "title": "Model loading and inference functions",
182
+ "roles": [
183
+ "mlm:inference-code"
184
+ ]
185
+ }
186
+ },
187
+ "links": [
188
+ {
189
+ "rel": "about",
190
+ "href": "https://github.com/IPL-UV/qa4eo",
191
+ "type": "text/html",
192
+ "title": "Project repository"
193
+ },
194
+ {
195
+ "rel": "license",
196
+ "href": "https://creativecommons.org/licenses/by/4.0/",
197
+ "type": "text/html",
198
+ "title": "CC-BY-4.0"
199
+ }
200
+ ]
201
+ }