JulioContrerasH commited on
Commit
f5f9f9f
·
verified ·
1 Parent(s): 1f9f4d5

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -66,3 +66,4 @@ single/spot_1dpwunet.pt2 filter=lfs diff=lfs merge=lfs -text
66
  single/spot_1dpwunetpp.pt2 filter=lfs diff=lfs merge=lfs -text
67
  single/spot_segformer.pt2 filter=lfs diff=lfs merge=lfs -text
68
  single/spot_unetpp.pt2 filter=lfs diff=lfs merge=lfs -text
 
 
66
  single/spot_1dpwunetpp.pt2 filter=lfs diff=lfs merge=lfs -text
67
  single/spot_segformer.pt2 filter=lfs diff=lfs merge=lfs -text
68
  single/spot_unetpp.pt2 filter=lfs diff=lfs merge=lfs -text
69
+ ensemble/ensemble_4.pt2 filter=lfs diff=lfs merge=lfs -text
ensemble/ensemble_4.json ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "Feature",
3
+ "stac_version": "1.1.0",
4
+ "stac_extensions": [
5
+ "https://stac-extensions.github.io/mlm/v1.5.0/schema.json",
6
+ "https://stac-extensions.github.io/file/v2.1.0/schema.json"
7
+ ],
8
+ "id": "ENSEMBLE_4MODELS_MEAN_UNCERTAINTY_2025-10-27",
9
+ "geometry": {
10
+ "type": "Polygon",
11
+ "coordinates": [
12
+ [
13
+ [
14
+ -180.0,
15
+ -90.0
16
+ ],
17
+ [
18
+ -180.0,
19
+ 90.0
20
+ ],
21
+ [
22
+ 180.0,
23
+ 90.0
24
+ ],
25
+ [
26
+ 180.0,
27
+ -90.0
28
+ ],
29
+ [
30
+ -180.0,
31
+ -90.0
32
+ ]
33
+ ]
34
+ ]
35
+ },
36
+ "bbox": [
37
+ -180,
38
+ -90,
39
+ 180,
40
+ 90
41
+ ],
42
+ "properties": {
43
+ "datetime": "2025-10-27T11:08:23Z",
44
+ "created": "2025-10-27T11:08:23Z",
45
+ "updated": "2025-12-01T10:57:16.283159Z",
46
+ "description": "Ensemble of 4 models (1dpwdeeplabv3, 1dpwunetpp, 1dpwseg, unet) with Mean aggregation and uncertainty quantification for cloud detection in VGT-1, VGT-2, and PROBA-V satellite imagery.",
47
+ "title": "Ensemble Cloud Detection Model (4 Models + Uncertainty) - VGT1/VGT2/Proba-V",
48
+ "mlm:name": "ensemble_4models_mean_uncertainty_fdr4vgt_cloudmask",
49
+ "mlm:architecture": "Ensemble (Mean+Uncertainty): DeepLabV3+PW, UNet+++PW, SegFormer+PW, UNet",
50
+ "mlm:tasks": [
51
+ "semantic-segmentation",
52
+ "uncertainty-quantification"
53
+ ],
54
+ "mlm:framework": "pytorch",
55
+ "mlm:framework_version": "2.5.1+cu121",
56
+ "mlm:accelerator": "cuda",
57
+ "mlm:accelerator_constrained": false,
58
+ "mlm:accelerator_summary": "NVIDIA GPU with CUDA support (compute capability >= 7.0)",
59
+ "mlm:accelerator_count": 1,
60
+ "mlm:memory_size": 187574737,
61
+ "mlm:batch_size_suggestion": 4,
62
+ "mlm:total_parameters": 29030983,
63
+ "mlm:pretrained": true,
64
+ "mlm:pretrained_source": "Global VGT-1/VGT-2/PROBA-V cloud detection models (100k+ training samples)",
65
+ "mlm:input": [
66
+ {
67
+ "name": "VGT_PROBA_TOC_reflectance",
68
+ "bands": [
69
+ "Blue (B0, ~450nm)",
70
+ "Red (B2, ~645nm)",
71
+ "Near-Infrared (B3, ~835nm)",
72
+ "SWIR (MIR, ~1665nm)"
73
+ ],
74
+ "input": {
75
+ "shape": [
76
+ -1,
77
+ 4,
78
+ 512,
79
+ 512
80
+ ],
81
+ "dim_order": [
82
+ "batch",
83
+ "channel",
84
+ "height",
85
+ "width"
86
+ ],
87
+ "data_type": "float32"
88
+ },
89
+ "norm": {
90
+ "type": "raw_toc_reflectance",
91
+ "range": [
92
+ 0,
93
+ 10000
94
+ ],
95
+ "description": "Raw Top-of-Canopy reflectance values scaled by 10000"
96
+ },
97
+ "pre_processing_function": null
98
+ }
99
+ ],
100
+ "mlm:output": [
101
+ {
102
+ "name": "cloud_probability",
103
+ "tasks": [
104
+ "semantic-segmentation"
105
+ ],
106
+ "result": {
107
+ "shape": [
108
+ -1,
109
+ 1,
110
+ 512,
111
+ 512
112
+ ],
113
+ "dim_order": [
114
+ "batch",
115
+ "channel",
116
+ "height",
117
+ "width"
118
+ ],
119
+ "data_type": "float32"
120
+ },
121
+ "classification:classes": [
122
+ {
123
+ "value": 0.0,
124
+ "name": "clear",
125
+ "description": "Clear sky (may contain cloud shadows)",
126
+ "color_hint": "00000000"
127
+ },
128
+ {
129
+ "value": 1.0,
130
+ "name": "cloud",
131
+ "description": "Cloud present",
132
+ "color_hint": "FFFF00"
133
+ }
134
+ ],
135
+ "post_processing_function": "Apply threshold to get binary mask. Recommended threshold: 0.4. Returns tuple: (probabilities, uncertainty)",
136
+ "standard_threshold": 0.5,
137
+ "recommended_threshold": 0.4,
138
+ "value_range": [
139
+ 0.0,
140
+ 1.0
141
+ ],
142
+ "description": "Per-pixel mean probability across ensemble models. Built-in sigmoid activation. Values close to 1.0 indicate high confidence of cloud."
143
+ },
144
+ {
145
+ "name": "prediction_uncertainty",
146
+ "tasks": [
147
+ "uncertainty-quantification"
148
+ ],
149
+ "result": {
150
+ "shape": [
151
+ -1,
152
+ 1,
153
+ 512,
154
+ 512
155
+ ],
156
+ "dim_order": [
157
+ "batch",
158
+ "channel",
159
+ "height",
160
+ "width"
161
+ ],
162
+ "data_type": "float32"
163
+ },
164
+ "value_range": [
165
+ 0.0,
166
+ 1.0
167
+ ],
168
+ "description": "Normalized standard deviation across 4 ensemble members. Values close to 1.0 indicate high disagreement between models (high uncertainty). Automatically returned as second element of output tuple."
169
+ }
170
+ ],
171
+ "mlm:hyperparameters": {
172
+ "ensemble_size": 4,
173
+ "ensemble_members": [
174
+ "1dpwdeeplabv3",
175
+ "1dpwunetpp",
176
+ "1dpwseg",
177
+ "unet"
178
+ ],
179
+ "aggregation_method": "mean",
180
+ "uncertainty_method": "normalized_std",
181
+ "avg_val_loss": 0.0616,
182
+ "member_details": [
183
+ {
184
+ "model": "1dpwdeeplabv3",
185
+ "epoch": 25,
186
+ "val_loss": 0.0611
187
+ },
188
+ {
189
+ "model": "1dpwunetpp",
190
+ "epoch": 22,
191
+ "val_loss": 0.0625
192
+ },
193
+ {
194
+ "model": "1dpwseg",
195
+ "epoch": 23,
196
+ "val_loss": 0.0622
197
+ },
198
+ {
199
+ "model": "unet",
200
+ "epoch": 20,
201
+ "val_loss": 0.0606
202
+ }
203
+ ]
204
+ },
205
+ "file:size": 125049825,
206
+ "custom:export_format": "torch.export.pt2",
207
+ "custom:has_sigmoid": true,
208
+ "custom:sigmoid_location": "built-in per-model wrapper",
209
+ "custom:export_datetime": "2025-12-01T10:57:16.283159Z",
210
+ "custom:training_stage": "ensemble-mean-uncertainty",
211
+ "custom:project": "FDR4VGT",
212
+ "custom:project_url": "https://fdr4vgt.eu/",
213
+ "custom:sensors": [
214
+ "VGT-1",
215
+ "VGT-2",
216
+ "PROBA-V"
217
+ ],
218
+ "custom:sensor_notes": "Model applicable to SPOT-VGT1, SPOT-VGT2, and PROBA-V imagery",
219
+ "custom:spatial_resolution": "1km",
220
+ "custom:tile_size": 512,
221
+ "custom:recommended_overlap": 64,
222
+ "custom:applicable_start": "1998-03-01T00:00:00Z",
223
+ "custom:applicable_end": null,
224
+ "custom:returns_tuple": true,
225
+ "custom:tuple_format": "(probabilities, uncertainty)",
226
+ "dependencies": [
227
+ "torch>=2.0.0",
228
+ "segmentation-models-pytorch>=0.3.0",
229
+ "pytorch-lightning>=2.0.0",
230
+ "numpy>=1.20.0"
231
+ ]
232
+ },
233
+ "links": [
234
+ {
235
+ "rel": "about",
236
+ "href": "https://fdr4vgt.eu/",
237
+ "type": "text/html",
238
+ "title": "FDR4VGT Project - Harmonized VGT Data Record"
239
+ },
240
+ {
241
+ "rel": "license",
242
+ "href": "https://creativecommons.org/licenses/by/4.0/",
243
+ "type": "text/html",
244
+ "title": "CC-BY-4.0 License"
245
+ }
246
+ ],
247
+ "assets": {
248
+ "model": {
249
+ "href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/ensemble_4.pt2",
250
+ "type": "application/octet-stream; application=pytorch",
251
+ "title": "PyTorch ensemble model weights",
252
+ "description": "Ensemble of 4 models in torch.export .pt2 format. Returns tuple: (probabilities, uncertainty).",
253
+ "mlm:artifact_type": "torch.export.pt2",
254
+ "roles": [
255
+ "mlm:model",
256
+ "mlm:weights",
257
+ "data"
258
+ ]
259
+ },
260
+ "example_data": {
261
+ "href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/example_data.safetensor",
262
+ "type": "application/octet-stream; application=safetensors",
263
+ "title": "Example VGT/PROBA-V image",
264
+ "description": "Example VGT/PROBA-V Top-of-Canopy reflectance image for model inference.",
265
+ "roles": [
266
+ "mlm:example_data",
267
+ "data"
268
+ ]
269
+ },
270
+ "load": {
271
+ "href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/load.py",
272
+ "type": "application/x-python-code",
273
+ "title": "PyTorch Ensemble Loader",
274
+ "description": "Python helper code to load the exported .pt2 ensemble model. Includes predict_large() function for large images.",
275
+ "roles": [
276
+ "code"
277
+ ]
278
+ }
279
+ },
280
+ "collection": "ENSEMBLE_4MODELS_FDR4VGT_CloudMask_MeanUncertainty"
281
+ }
ensemble/ensemble_4.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca47822614547e57b105edee92cda3f8fd080c0139523e7febd47fce809d69a2
3
+ size 125049825
ensemble/example_data.safetensor ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a66d52bb558f756d105b41ead9386cdd6f04b4ac9cdc0173b5632aa00f35b244
3
+ size 524504
ensemble/load.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn
3
+ import pathlib
4
+ import pystac
5
+ from typing import Literal, Tuple
6
+ import numpy as np
7
+ import itertools
8
+ from tqdm import tqdm
9
+ import math
10
+
11
+ # Ensemble model for combining multiple models' outputs
12
+ class EnsembleModel(torch.nn.Module):
13
+ def __init__(self, *models, mode="max"):
14
+ super(EnsembleModel, self).__init__()
15
+ self.models = torch.nn.ModuleList(models)
16
+ self.mode = mode
17
+ if mode not in ["min", "mean", "median", "max", "none"]:
18
+ raise ValueError("Mode must be 'none', 'min', 'mean', 'median', or 'max'.")
19
+
20
+ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
21
+ """
22
+ Forward pass for ensemble.
23
+
24
+ Returns:
25
+ Tuple of (probabilities, uncertainty):
26
+ - probabilities: (B, 1, H, W) - aggregated predictions
27
+ - uncertainty: (B, 1, H, W) - normalized std deviation
28
+ """
29
+ outputs = []
30
+ for model in self.models:
31
+ output = model(x)
32
+ outputs.append(output)
33
+
34
+ if not outputs:
35
+ return None, None
36
+
37
+ # Stack all model outputs: (B, N, H, W) where N = number of models
38
+ stacked_outputs = torch.stack(outputs, dim=1) # (B, N, 1, H, W)
39
+ stacked_outputs = stacked_outputs.squeeze(2) # (B, N, H, W)
40
+
41
+ # Calculate aggregated probabilities
42
+ if self.mode == "max":
43
+ output_probs = torch.max(stacked_outputs, dim=1, keepdim=True)[0]
44
+ elif self.mode == "mean":
45
+ output_probs = torch.mean(stacked_outputs, dim=1, keepdim=True)
46
+ elif self.mode == "median":
47
+ output_probs = torch.median(stacked_outputs, dim=1, keepdim=True)[0]
48
+ elif self.mode == "min":
49
+ output_probs = torch.min(stacked_outputs, dim=1, keepdim=True)[0]
50
+ elif self.mode == "none":
51
+ # Return all predictions without aggregation
52
+ return stacked_outputs, None
53
+ else:
54
+ raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
55
+
56
+ # Calculate uncertainty (normalized standard deviation)
57
+ N = len(outputs)
58
+ if N > 1:
59
+ # Calculate std across models (dim=1)
60
+ std_output = torch.std(stacked_outputs, dim=1, keepdim=True)
61
+
62
+ # Normalize the standard deviation [0 - 1]
63
+ # Formula: std_max = sqrt(0.25 * N / (N - 1))
64
+ std_max = math.sqrt(0.25 * N / (N - 1))
65
+ uncertainty = std_output / std_max
66
+
67
+ # Clamp to [0, 1] to avoid numerical issues
68
+ uncertainty = torch.clamp(uncertainty, 0.0, 1.0)
69
+ else:
70
+ # Single model: no uncertainty
71
+ uncertainty = torch.zeros_like(output_probs)
72
+
73
+ return output_probs, uncertainty # Both (B, 1, H, W)
74
+
75
+ def compiled_model(
76
+ path: pathlib.Path,
77
+ stac_item: pystac.Item,
78
+ mode: Literal["min", "mean", "median", "max"] = "max",
79
+ *args, **kwargs
80
+ ):
81
+ """
82
+ Loads model(s) dynamically based on STAC metadata.
83
+
84
+ - If single .pt2 → returns single model
85
+ - If multiple .pt2 → returns EnsembleModel
86
+
87
+ Args:
88
+ mode: Aggregation mode for ensembles (ignored for single models)
89
+
90
+ Returns:
91
+ Single model or EnsembleModel
92
+ """
93
+ model_paths = []
94
+ for asset_key, asset in stac_item.assets.items():
95
+ if asset.href.endswith(".pt2"):
96
+ model_paths.append(asset.href)
97
+
98
+ if not model_paths:
99
+ raise ValueError("No .pt2 files found in STAC item assets.")
100
+
101
+ model_paths.sort()
102
+
103
+ if len(model_paths) == 1:
104
+ # Single model
105
+ return torch.export.load(model_paths[0]).module()
106
+ else:
107
+ # Ensemble model
108
+ models = [torch.export.load(p).module() for p in model_paths]
109
+ return EnsembleModel(*models, mode=mode)
110
+
111
+ def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
112
+ """
113
+ Defines iteration strategy to traverse the image with overlap.
114
+ """
115
+ dimy, dimx = dimension
116
+ if chunk_size > max(dimx, dimy):
117
+ return [(0, 0)]
118
+ y_step = chunk_size - overlap
119
+ x_step = chunk_size - overlap
120
+ iterchunks = list(itertools.product(range(0, dimy, y_step), range(0, dimx, x_step)))
121
+ iterchunks_fixed = fix_lastchunk(
122
+ iterchunks=iterchunks, s2dim=dimension, chunk_size=chunk_size
123
+ )
124
+ return iterchunks_fixed
125
+
126
+
127
+ def fix_lastchunk(iterchunks, s2dim, chunk_size):
128
+ """
129
+ Adjusts last chunks to prevent them from exceeding boundaries.
130
+ """
131
+ itercontainer = []
132
+ for index_i, index_j in iterchunks:
133
+ if index_i + chunk_size > s2dim[0]:
134
+ index_i = max(s2dim[0] - chunk_size, 0)
135
+ if index_j + chunk_size > s2dim[1]:
136
+ index_j = max(s2dim[1] - chunk_size, 0)
137
+ itercontainer.append((index_i, index_j))
138
+ return list(set(itercontainer)) # Returns unique values just in case
139
+
140
+
141
+ def predict_large(
142
+ image: np.ndarray,
143
+ model: torch.nn.Module,
144
+ chunk_size: int = 512,
145
+ overlap: int = 64,
146
+ device: str = "cpu",
147
+ nodata: float = 0.0
148
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
149
+ """
150
+ Predict a full 'image' (C, H, W) using overlapping patches.
151
+
152
+ Args:
153
+ image: Input array (C, H, W)
154
+ model: Compiled PyTorch model
155
+ chunk_size: Tile size for inference
156
+ overlap: Overlap between tiles
157
+ device: 'cpu' or 'cuda'
158
+ nodata: No-data value
159
+
160
+ Returns:
161
+ - For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
162
+ - For single models: probabilities array (1, H, W)
163
+
164
+ Compatible with:
165
+ - Normal models (with .eval()) - returns probabilities only
166
+ - Exported models (.pt2) - returns probabilities only
167
+ - Ensembles (EnsembleModel) - returns (probabilities, uncertainty)
168
+ """
169
+
170
+ # Validate input array dimensions
171
+ if image.ndim != 3:
172
+ raise ValueError(f"Input array must be (C, H, W). Received {image.shape}")
173
+
174
+ bands, height, width = image.shape
175
+
176
+ # Prepare model (compatibility logic for .pt2 models)
177
+ try:
178
+ model.eval()
179
+ for p in model.parameters():
180
+ p.requires_grad = False
181
+ model = model.to(device)
182
+ except (NotImplementedError, AttributeError):
183
+ # Exported model (.pt2) or EnsembleModel
184
+ model = model.to(device)
185
+
186
+ test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
187
+ with torch.no_grad():
188
+ test_output = model(test_input)
189
+
190
+ is_ensemble = isinstance(test_output, tuple) and len(test_output) == 2
191
+
192
+ # Initialize output arrays
193
+ output_probs = np.full((1, height, width), nodata, dtype=np.float32)
194
+
195
+ if is_ensemble:
196
+ output_uncertainty = np.full((1, height, width), nodata, dtype=np.float32)
197
+
198
+ # Get the list of tile offsets
199
+ coords = define_iteration(
200
+ dimension=(height, width),
201
+ chunk_size=chunk_size,
202
+ overlap=overlap
203
+ )
204
+
205
+ # Iterate over tiles
206
+ for idx, (row_off, col_off) in enumerate(tqdm(coords, desc="Inference")):
207
+
208
+ # Read chunk (numpy slicing)
209
+ patch = image[
210
+ :,
211
+ row_off : row_off + chunk_size,
212
+ col_off : col_off + chunk_size
213
+ ]
214
+
215
+ # Convert to tensor and handle padding if tile is smaller than chunk_size
216
+ patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).to(device)
217
+ _, _, h_tile, w_tile = patch_tensor.shape
218
+
219
+ # Calculate padding needed
220
+ pad_h = chunk_size - h_tile
221
+ pad_w = chunk_size - w_tile
222
+
223
+ # Apply padding if necessary
224
+ if pad_h > 0 or pad_w > 0:
225
+ patch_tensor = torch.nn.functional.pad(
226
+ patch_tensor, (0, pad_w, 0, pad_h), "constant", nodata
227
+ )
228
+
229
+ # Create mask for nodata areas (all bands are nodata)
230
+ mask_all = (patch_tensor == nodata).all(dim=1, keepdim=True)
231
+
232
+ # Forward pass
233
+ with torch.no_grad():
234
+ model_output = model(patch_tensor)
235
+
236
+ if is_ensemble:
237
+ probs, uncertainty = model_output
238
+ probs = probs.masked_fill(mask_all, nodata)
239
+ uncertainty = uncertainty.masked_fill(mask_all, nodata)
240
+ else:
241
+ probs = model_output
242
+ probs = probs.masked_fill(mask_all, nodata)
243
+
244
+ # Remove batch dimension and ensure (1, H, W)
245
+ if probs.ndim == 4:
246
+ probs = probs.squeeze(0) # (1, H, W)
247
+
248
+ # Convert to numpy
249
+ result_probs = probs.cpu().numpy() # (1, H, W)
250
+
251
+ if is_ensemble:
252
+ if uncertainty.ndim == 4:
253
+ uncertainty = uncertainty.squeeze(0)
254
+ result_uncertainty = uncertainty.cpu().numpy()
255
+
256
+ # Logic for partial writing
257
+ if col_off == 0:
258
+ offset_x = 0
259
+ else:
260
+ offset_x = col_off + overlap // 2
261
+
262
+ if row_off == 0:
263
+ offset_y = 0
264
+ else:
265
+ offset_y = row_off + overlap // 2
266
+
267
+ if (offset_x + chunk_size) == width:
268
+ length_x = chunk_size
269
+ sub_x_start = 0
270
+ else:
271
+ length_x = chunk_size - (overlap // 2)
272
+ sub_x_start = overlap // 2 if col_off != 0 else 0
273
+
274
+ if (offset_y + chunk_size) == height:
275
+ length_y = chunk_size
276
+ sub_y_start = 0
277
+ else:
278
+ length_y = chunk_size - (overlap // 2)
279
+ sub_y_start = overlap // 2 if row_off != 0 else 0
280
+
281
+ # Ensure we don't exceed array bounds
282
+ if offset_y + length_y > height:
283
+ length_y = height - offset_y
284
+ if offset_x + length_x > width:
285
+ length_x = width - offset_x
286
+
287
+ # Extract the valid region from the result
288
+ to_write_probs = result_probs[
289
+ :,
290
+ sub_y_start : sub_y_start + length_y,
291
+ sub_x_start : sub_x_start + length_x
292
+ ]
293
+
294
+ # Write to the output numpy array
295
+ output_probs[
296
+ :,
297
+ offset_y : offset_y + length_y,
298
+ offset_x : offset_x + length_x
299
+ ] = to_write_probs
300
+
301
+ if is_ensemble:
302
+ to_write_uncertainty = result_uncertainty[
303
+ :,
304
+ sub_y_start : sub_y_start + length_y,
305
+ sub_x_start : sub_x_start + length_x
306
+ ]
307
+ output_uncertainty[
308
+ :,
309
+ offset_y : offset_y + length_y,
310
+ offset_x : offset_x + length_x
311
+ ] = to_write_uncertainty
312
+
313
+ if is_ensemble:
314
+ return output_probs, output_uncertainty
315
+ else:
316
+ return output_probs