Upload 4 files
Browse files- .gitattributes +1 -0
- ensemble/ensemble_4.json +281 -0
- ensemble/ensemble_4.pt2 +3 -0
- ensemble/example_data.safetensor +3 -0
- ensemble/load.py +316 -0
.gitattributes
CHANGED
|
@@ -66,3 +66,4 @@ single/spot_1dpwunet.pt2 filter=lfs diff=lfs merge=lfs -text
|
|
| 66 |
single/spot_1dpwunetpp.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 67 |
single/spot_segformer.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 68 |
single/spot_unetpp.pt2 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 66 |
single/spot_1dpwunetpp.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 67 |
single/spot_segformer.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 68 |
single/spot_unetpp.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
ensemble/ensemble_4.pt2 filter=lfs diff=lfs merge=lfs -text
|
ensemble/ensemble_4.json
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"type": "Feature",
|
| 3 |
+
"stac_version": "1.1.0",
|
| 4 |
+
"stac_extensions": [
|
| 5 |
+
"https://stac-extensions.github.io/mlm/v1.5.0/schema.json",
|
| 6 |
+
"https://stac-extensions.github.io/file/v2.1.0/schema.json"
|
| 7 |
+
],
|
| 8 |
+
"id": "ENSEMBLE_4MODELS_MEAN_UNCERTAINTY_2025-10-27",
|
| 9 |
+
"geometry": {
|
| 10 |
+
"type": "Polygon",
|
| 11 |
+
"coordinates": [
|
| 12 |
+
[
|
| 13 |
+
[
|
| 14 |
+
-180.0,
|
| 15 |
+
-90.0
|
| 16 |
+
],
|
| 17 |
+
[
|
| 18 |
+
-180.0,
|
| 19 |
+
90.0
|
| 20 |
+
],
|
| 21 |
+
[
|
| 22 |
+
180.0,
|
| 23 |
+
90.0
|
| 24 |
+
],
|
| 25 |
+
[
|
| 26 |
+
180.0,
|
| 27 |
+
-90.0
|
| 28 |
+
],
|
| 29 |
+
[
|
| 30 |
+
-180.0,
|
| 31 |
+
-90.0
|
| 32 |
+
]
|
| 33 |
+
]
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
"bbox": [
|
| 37 |
+
-180,
|
| 38 |
+
-90,
|
| 39 |
+
180,
|
| 40 |
+
90
|
| 41 |
+
],
|
| 42 |
+
"properties": {
|
| 43 |
+
"datetime": "2025-10-27T11:08:23Z",
|
| 44 |
+
"created": "2025-10-27T11:08:23Z",
|
| 45 |
+
"updated": "2025-12-01T10:57:16.283159Z",
|
| 46 |
+
"description": "Ensemble of 4 models (1dpwdeeplabv3, 1dpwunetpp, 1dpwseg, unet) with Mean aggregation and uncertainty quantification for cloud detection in VGT-1, VGT-2, and PROBA-V satellite imagery.",
|
| 47 |
+
"title": "Ensemble Cloud Detection Model (4 Models + Uncertainty) - VGT1/VGT2/Proba-V",
|
| 48 |
+
"mlm:name": "ensemble_4models_mean_uncertainty_fdr4vgt_cloudmask",
|
| 49 |
+
"mlm:architecture": "Ensemble (Mean+Uncertainty): DeepLabV3+PW, UNet+++PW, SegFormer+PW, UNet",
|
| 50 |
+
"mlm:tasks": [
|
| 51 |
+
"semantic-segmentation",
|
| 52 |
+
"uncertainty-quantification"
|
| 53 |
+
],
|
| 54 |
+
"mlm:framework": "pytorch",
|
| 55 |
+
"mlm:framework_version": "2.5.1+cu121",
|
| 56 |
+
"mlm:accelerator": "cuda",
|
| 57 |
+
"mlm:accelerator_constrained": false,
|
| 58 |
+
"mlm:accelerator_summary": "NVIDIA GPU with CUDA support (compute capability >= 7.0)",
|
| 59 |
+
"mlm:accelerator_count": 1,
|
| 60 |
+
"mlm:memory_size": 187574737,
|
| 61 |
+
"mlm:batch_size_suggestion": 4,
|
| 62 |
+
"mlm:total_parameters": 29030983,
|
| 63 |
+
"mlm:pretrained": true,
|
| 64 |
+
"mlm:pretrained_source": "Global VGT-1/VGT-2/PROBA-V cloud detection models (100k+ training samples)",
|
| 65 |
+
"mlm:input": [
|
| 66 |
+
{
|
| 67 |
+
"name": "VGT_PROBA_TOC_reflectance",
|
| 68 |
+
"bands": [
|
| 69 |
+
"Blue (B0, ~450nm)",
|
| 70 |
+
"Red (B2, ~645nm)",
|
| 71 |
+
"Near-Infrared (B3, ~835nm)",
|
| 72 |
+
"SWIR (MIR, ~1665nm)"
|
| 73 |
+
],
|
| 74 |
+
"input": {
|
| 75 |
+
"shape": [
|
| 76 |
+
-1,
|
| 77 |
+
4,
|
| 78 |
+
512,
|
| 79 |
+
512
|
| 80 |
+
],
|
| 81 |
+
"dim_order": [
|
| 82 |
+
"batch",
|
| 83 |
+
"channel",
|
| 84 |
+
"height",
|
| 85 |
+
"width"
|
| 86 |
+
],
|
| 87 |
+
"data_type": "float32"
|
| 88 |
+
},
|
| 89 |
+
"norm": {
|
| 90 |
+
"type": "raw_toc_reflectance",
|
| 91 |
+
"range": [
|
| 92 |
+
0,
|
| 93 |
+
10000
|
| 94 |
+
],
|
| 95 |
+
"description": "Raw Top-of-Canopy reflectance values scaled by 10000"
|
| 96 |
+
},
|
| 97 |
+
"pre_processing_function": null
|
| 98 |
+
}
|
| 99 |
+
],
|
| 100 |
+
"mlm:output": [
|
| 101 |
+
{
|
| 102 |
+
"name": "cloud_probability",
|
| 103 |
+
"tasks": [
|
| 104 |
+
"semantic-segmentation"
|
| 105 |
+
],
|
| 106 |
+
"result": {
|
| 107 |
+
"shape": [
|
| 108 |
+
-1,
|
| 109 |
+
1,
|
| 110 |
+
512,
|
| 111 |
+
512
|
| 112 |
+
],
|
| 113 |
+
"dim_order": [
|
| 114 |
+
"batch",
|
| 115 |
+
"channel",
|
| 116 |
+
"height",
|
| 117 |
+
"width"
|
| 118 |
+
],
|
| 119 |
+
"data_type": "float32"
|
| 120 |
+
},
|
| 121 |
+
"classification:classes": [
|
| 122 |
+
{
|
| 123 |
+
"value": 0.0,
|
| 124 |
+
"name": "clear",
|
| 125 |
+
"description": "Clear sky (may contain cloud shadows)",
|
| 126 |
+
"color_hint": "00000000"
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"value": 1.0,
|
| 130 |
+
"name": "cloud",
|
| 131 |
+
"description": "Cloud present",
|
| 132 |
+
"color_hint": "FFFF00"
|
| 133 |
+
}
|
| 134 |
+
],
|
| 135 |
+
"post_processing_function": "Apply threshold to get binary mask. Recommended threshold: 0.4. Returns tuple: (probabilities, uncertainty)",
|
| 136 |
+
"standard_threshold": 0.5,
|
| 137 |
+
"recommended_threshold": 0.4,
|
| 138 |
+
"value_range": [
|
| 139 |
+
0.0,
|
| 140 |
+
1.0
|
| 141 |
+
],
|
| 142 |
+
"description": "Per-pixel mean probability across ensemble models. Built-in sigmoid activation. Values close to 1.0 indicate high confidence of cloud."
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"name": "prediction_uncertainty",
|
| 146 |
+
"tasks": [
|
| 147 |
+
"uncertainty-quantification"
|
| 148 |
+
],
|
| 149 |
+
"result": {
|
| 150 |
+
"shape": [
|
| 151 |
+
-1,
|
| 152 |
+
1,
|
| 153 |
+
512,
|
| 154 |
+
512
|
| 155 |
+
],
|
| 156 |
+
"dim_order": [
|
| 157 |
+
"batch",
|
| 158 |
+
"channel",
|
| 159 |
+
"height",
|
| 160 |
+
"width"
|
| 161 |
+
],
|
| 162 |
+
"data_type": "float32"
|
| 163 |
+
},
|
| 164 |
+
"value_range": [
|
| 165 |
+
0.0,
|
| 166 |
+
1.0
|
| 167 |
+
],
|
| 168 |
+
"description": "Normalized standard deviation across 4 ensemble members. Values close to 1.0 indicate high disagreement between models (high uncertainty). Automatically returned as second element of output tuple."
|
| 169 |
+
}
|
| 170 |
+
],
|
| 171 |
+
"mlm:hyperparameters": {
|
| 172 |
+
"ensemble_size": 4,
|
| 173 |
+
"ensemble_members": [
|
| 174 |
+
"1dpwdeeplabv3",
|
| 175 |
+
"1dpwunetpp",
|
| 176 |
+
"1dpwseg",
|
| 177 |
+
"unet"
|
| 178 |
+
],
|
| 179 |
+
"aggregation_method": "mean",
|
| 180 |
+
"uncertainty_method": "normalized_std",
|
| 181 |
+
"avg_val_loss": 0.0616,
|
| 182 |
+
"member_details": [
|
| 183 |
+
{
|
| 184 |
+
"model": "1dpwdeeplabv3",
|
| 185 |
+
"epoch": 25,
|
| 186 |
+
"val_loss": 0.0611
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"model": "1dpwunetpp",
|
| 190 |
+
"epoch": 22,
|
| 191 |
+
"val_loss": 0.0625
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"model": "1dpwseg",
|
| 195 |
+
"epoch": 23,
|
| 196 |
+
"val_loss": 0.0622
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"model": "unet",
|
| 200 |
+
"epoch": 20,
|
| 201 |
+
"val_loss": 0.0606
|
| 202 |
+
}
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
"file:size": 125049825,
|
| 206 |
+
"custom:export_format": "torch.export.pt2",
|
| 207 |
+
"custom:has_sigmoid": true,
|
| 208 |
+
"custom:sigmoid_location": "built-in per-model wrapper",
|
| 209 |
+
"custom:export_datetime": "2025-12-01T10:57:16.283159Z",
|
| 210 |
+
"custom:training_stage": "ensemble-mean-uncertainty",
|
| 211 |
+
"custom:project": "FDR4VGT",
|
| 212 |
+
"custom:project_url": "https://fdr4vgt.eu/",
|
| 213 |
+
"custom:sensors": [
|
| 214 |
+
"VGT-1",
|
| 215 |
+
"VGT-2",
|
| 216 |
+
"PROBA-V"
|
| 217 |
+
],
|
| 218 |
+
"custom:sensor_notes": "Model applicable to SPOT-VGT1, SPOT-VGT2, and PROBA-V imagery",
|
| 219 |
+
"custom:spatial_resolution": "1km",
|
| 220 |
+
"custom:tile_size": 512,
|
| 221 |
+
"custom:recommended_overlap": 64,
|
| 222 |
+
"custom:applicable_start": "1998-03-01T00:00:00Z",
|
| 223 |
+
"custom:applicable_end": null,
|
| 224 |
+
"custom:returns_tuple": true,
|
| 225 |
+
"custom:tuple_format": "(probabilities, uncertainty)",
|
| 226 |
+
"dependencies": [
|
| 227 |
+
"torch>=2.0.0",
|
| 228 |
+
"segmentation-models-pytorch>=0.3.0",
|
| 229 |
+
"pytorch-lightning>=2.0.0",
|
| 230 |
+
"numpy>=1.20.0"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
"links": [
|
| 234 |
+
{
|
| 235 |
+
"rel": "about",
|
| 236 |
+
"href": "https://fdr4vgt.eu/",
|
| 237 |
+
"type": "text/html",
|
| 238 |
+
"title": "FDR4VGT Project - Harmonized VGT Data Record"
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"rel": "license",
|
| 242 |
+
"href": "https://creativecommons.org/licenses/by/4.0/",
|
| 243 |
+
"type": "text/html",
|
| 244 |
+
"title": "CC-BY-4.0 License"
|
| 245 |
+
}
|
| 246 |
+
],
|
| 247 |
+
"assets": {
|
| 248 |
+
"model": {
|
| 249 |
+
"href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/ensemble_4.pt2",
|
| 250 |
+
"type": "application/octet-stream; application=pytorch",
|
| 251 |
+
"title": "PyTorch ensemble model weights",
|
| 252 |
+
"description": "Ensemble of 4 models in torch.export .pt2 format. Returns tuple: (probabilities, uncertainty).",
|
| 253 |
+
"mlm:artifact_type": "torch.export.pt2",
|
| 254 |
+
"roles": [
|
| 255 |
+
"mlm:model",
|
| 256 |
+
"mlm:weights",
|
| 257 |
+
"data"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
"example_data": {
|
| 261 |
+
"href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/example_data.safetensor",
|
| 262 |
+
"type": "application/octet-stream; application=safetensors",
|
| 263 |
+
"title": "Example VGT/PROBA-V image",
|
| 264 |
+
"description": "Example VGT/PROBA-V Top-of-Canopy reflectance image for model inference.",
|
| 265 |
+
"roles": [
|
| 266 |
+
"mlm:example_data",
|
| 267 |
+
"data"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
"load": {
|
| 271 |
+
"href": "https://huggingface.co/isp-uv-es/FDR4VGT-CLOUD/resolve/main/ensemble/load.py",
|
| 272 |
+
"type": "application/x-python-code",
|
| 273 |
+
"title": "PyTorch Ensemble Loader",
|
| 274 |
+
"description": "Python helper code to load the exported .pt2 ensemble model. Includes predict_large() function for large images.",
|
| 275 |
+
"roles": [
|
| 276 |
+
"code"
|
| 277 |
+
]
|
| 278 |
+
}
|
| 279 |
+
},
|
| 280 |
+
"collection": "ENSEMBLE_4MODELS_FDR4VGT_CloudMask_MeanUncertainty"
|
| 281 |
+
}
|
ensemble/ensemble_4.pt2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca47822614547e57b105edee92cda3f8fd080c0139523e7febd47fce809d69a2
|
| 3 |
+
size 125049825
|
ensemble/example_data.safetensor
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a66d52bb558f756d105b41ead9386cdd6f04b4ac9cdc0173b5632aa00f35b244
|
| 3 |
+
size 524504
|
ensemble/load.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn
|
| 3 |
+
import pathlib
|
| 4 |
+
import pystac
|
| 5 |
+
from typing import Literal, Tuple
|
| 6 |
+
import numpy as np
|
| 7 |
+
import itertools
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
# Ensemble model for combining multiple models' outputs
|
| 12 |
+
# Ensemble model for combining multiple models' outputs
class EnsembleModel(torch.nn.Module):
    """Wrap several segmentation models and aggregate their predictions.

    Each member model is expected to emit a (B, 1, H, W) tensor for an
    input batch ``x``. The ensemble reduces the member outputs with the
    configured ``mode`` and additionally reports a normalized standard
    deviation across members as a per-pixel uncertainty map.
    """

    def __init__(self, *models, mode="max"):
        """Store the member models and the aggregation mode.

        Args:
            *models: Member ``torch.nn.Module`` instances.
            mode: One of 'min', 'mean', 'median', 'max', or 'none'
                ('none' skips aggregation and returns the raw stack).

        Raises:
            ValueError: If ``mode`` is not a recognized aggregation mode.
        """
        super(EnsembleModel, self).__init__()
        self.models = torch.nn.ModuleList(models)
        self.mode = mode
        if mode not in ["min", "mean", "median", "max", "none"]:
            raise ValueError("Mode must be 'none', 'min', 'mean', 'median', or 'max'.")

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every member on ``x`` and aggregate the results.

        Returns:
            Tuple of (probabilities, uncertainty):
            - probabilities: (B, 1, H, W) aggregated predictions
              (or (B, N, H, W) with uncertainty None when mode='none').
            - uncertainty: (B, 1, H, W) normalized std deviation; zeros
              when the ensemble has a single member.
            An empty ensemble yields (None, None).
        """
        member_preds = [member(x) for member in self.models]

        if not member_preds:
            return None, None

        # Stack member outputs along a new axis, then drop the singleton
        # channel: (B, N, 1, H, W) -> (B, N, H, W), N = number of members.
        preds = torch.stack(member_preds, dim=1).squeeze(2)

        if self.mode == "none":
            # Return all predictions without aggregation
            return preds, None

        # Dispatch table replaces the if/elif chain; each reducer collapses
        # the member axis while keeping it as a singleton dim.
        reducers = {
            "max": lambda t: torch.max(t, dim=1, keepdim=True)[0],
            "mean": lambda t: torch.mean(t, dim=1, keepdim=True),
            "median": lambda t: torch.median(t, dim=1, keepdim=True)[0],
            "min": lambda t: torch.min(t, dim=1, keepdim=True)[0],
        }
        if self.mode not in reducers:
            raise ValueError("Mode must be 'min', 'mean', 'median', or 'max'.")
        output_probs = reducers[self.mode](preds)

        # Uncertainty: standard deviation across members, scaled into [0, 1].
        n_members = len(member_preds)
        if n_members > 1:
            spread = torch.std(preds, dim=1, keepdim=True)
            # Largest achievable std for values bounded in [0, 1]:
            # std_max = sqrt(0.25 * N / (N - 1))
            max_spread = math.sqrt(0.25 * n_members / (n_members - 1))
            # Clamp guards against numerical overshoot past 1.0.
            uncertainty = torch.clamp(spread / max_spread, 0.0, 1.0)
        else:
            # Single member: disagreement is undefined, report zeros.
            uncertainty = torch.zeros_like(output_probs)

        return output_probs, uncertainty  # Both (B, 1, H, W)
|
| 74 |
+
|
| 75 |
+
def compiled_model(
    path: pathlib.Path,
    stac_item: pystac.Item,
    mode: Literal["min", "mean", "median", "max"] = "max",
    *args, **kwargs
):
    """Load model(s) dynamically based on STAC metadata.

    Every asset whose href ends in ``.pt2`` is loaded with
    ``torch.export.load``; one asset yields a single model, several
    yield an :class:`EnsembleModel`.

    Args:
        path: Accepted for interface compatibility.
            NOTE(review): not read anywhere in this function — confirm
            whether callers rely on it.
        stac_item: STAC item whose assets reference the .pt2 artifacts.
        mode: Aggregation mode for ensembles (ignored for single models).

    Returns:
        Single exported model, or an EnsembleModel over all .pt2 assets.

    Raises:
        ValueError: If no .pt2 asset is present.
    """
    pt2_hrefs = [
        asset.href
        for asset in stac_item.assets.values()
        if asset.href.endswith(".pt2")
    ]

    if not pt2_hrefs:
        raise ValueError("No .pt2 files found in STAC item assets.")

    # Sort so ensemble member order is deterministic across runs.
    pt2_hrefs.sort()

    if len(pt2_hrefs) == 1:
        # Single model
        return torch.export.load(pt2_hrefs[0]).module()

    # Ensemble model
    members = [torch.export.load(href).module() for href in pt2_hrefs]
    return EnsembleModel(*members, mode=mode)
|
| 110 |
+
|
| 111 |
+
def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
    """Compute the (row, col) tile offsets that traverse an image with overlap.

    Args:
        dimension: (height, width) of the image.
        chunk_size: Side length of each square tile.
        overlap: Pixels shared between neighbouring tiles.

    Returns:
        List of unique (row_offset, col_offset) tuples covering the image.
    """
    height, width = dimension

    # Image smaller than a single tile: one offset at the origin suffices.
    if chunk_size > max(width, height):
        return [(0, 0)]

    stride = chunk_size - overlap
    offsets = list(
        itertools.product(range(0, height, stride), range(0, width, stride))
    )
    # Pull edge tiles back inside the image so no tile exceeds the bounds.
    return fix_lastchunk(iterchunks=offsets, s2dim=dimension, chunk_size=chunk_size)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def fix_lastchunk(iterchunks, s2dim, chunk_size):
    """Clamp tile offsets so no tile extends past the image boundary.

    Args:
        iterchunks: Iterable of (row_offset, col_offset) tuples.
        s2dim: (height, width) of the image.
        chunk_size: Tile side length.

    Returns:
        List of unique clamped offsets (clamping can create duplicates,
        which are collapsed via a set; order is therefore unspecified).
    """
    # Largest valid offsets: a tile anchored here ends exactly at the edge.
    max_row = max(s2dim[0] - chunk_size, 0)
    max_col = max(s2dim[1] - chunk_size, 0)

    clamped = []
    for row_off, col_off in iterchunks:
        if row_off + chunk_size > s2dim[0]:
            row_off = max_row
        if col_off + chunk_size > s2dim[1]:
            col_off = max_col
        clamped.append((row_off, col_off))

    return list(set(clamped))  # Returns unique values just in case
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def predict_large(
    image: np.ndarray,
    model: torch.nn.Module,
    chunk_size: int = 512,
    overlap: int = 64,
    device: str = "cpu",
    nodata: float = 0.0
) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
    """
    Predict a full 'image' (C, H, W) using overlapping patches.

    The image is tiled into chunk_size x chunk_size patches with the given
    overlap; each patch is run through the model and the central (non-overlap)
    region of each prediction is stitched into the output array, so seams
    between tiles fall inside the discarded overlap margin.

    Args:
        image: Input array (C, H, W)
        model: Compiled PyTorch model
        chunk_size: Tile size for inference
        overlap: Overlap between tiles
        device: 'cpu' or 'cuda'
        nodata: No-data value

    Returns:
        - For ensembles: Tuple of (probabilities, uncertainty), both (1, H, W)
        - For single models: probabilities array (1, H, W)

    Compatible with:
        - Normal models (with .eval()) - returns probabilities only
        - Exported models (.pt2) - returns probabilities only
        - Ensembles (EnsembleModel) - returns (probabilities, uncertainty)

    Raises:
        ValueError: If `image` is not 3-dimensional.
    """

    # Validate input array dimensions
    if image.ndim != 3:
        raise ValueError(f"Input array must be (C, H, W). Received {image.shape}")

    bands, height, width = image.shape

    # Prepare model (compatibility logic for .pt2 models)
    # torch.export-ed modules raise NotImplementedError/AttributeError on
    # .eval()/.parameters(), so those fall through to the plain .to(device).
    try:
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        model = model.to(device)
    except (NotImplementedError, AttributeError):
        # Exported model (.pt2) or EnsembleModel
        model = model.to(device)

    # Probe with one dummy forward pass to detect the output contract:
    # a 2-tuple output means (probabilities, uncertainty) ensemble mode.
    # NOTE(review): this costs one extra inference on a zero tile.
    test_input = torch.zeros(1, bands, chunk_size, chunk_size).to(device)
    with torch.no_grad():
        test_output = model(test_input)

    is_ensemble = isinstance(test_output, tuple) and len(test_output) == 2

    # Initialize output arrays (pre-filled with nodata so unwritten
    # pixels remain flagged as nodata)
    output_probs = np.full((1, height, width), nodata, dtype=np.float32)

    if is_ensemble:
        output_uncertainty = np.full((1, height, width), nodata, dtype=np.float32)

    # Get the list of tile offsets
    coords = define_iteration(
        dimension=(height, width),
        chunk_size=chunk_size,
        overlap=overlap
    )

    # Iterate over tiles
    for idx, (row_off, col_off) in enumerate(tqdm(coords, desc="Inference")):

        # Read chunk (numpy slicing; may be smaller than chunk_size at edges)
        patch = image[
            :,
            row_off : row_off + chunk_size,
            col_off : col_off + chunk_size
        ]

        # Convert to tensor and handle padding if tile is smaller than chunk_size
        patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).to(device)
        _, _, h_tile, w_tile = patch_tensor.shape

        # Calculate padding needed
        pad_h = chunk_size - h_tile
        pad_w = chunk_size - w_tile

        # Apply padding if necessary (pad right/bottom with the nodata value,
        # so padded pixels are also caught by the nodata mask below)
        if pad_h > 0 or pad_w > 0:
            patch_tensor = torch.nn.functional.pad(
                patch_tensor, (0, pad_w, 0, pad_h), "constant", nodata
            )

        # Create mask for nodata areas (all bands are nodata)
        mask_all = (patch_tensor == nodata).all(dim=1, keepdim=True)

        # Forward pass
        with torch.no_grad():
            model_output = model(patch_tensor)

        # Force predictions back to nodata wherever the input was nodata
        if is_ensemble:
            probs, uncertainty = model_output
            probs = probs.masked_fill(mask_all, nodata)
            uncertainty = uncertainty.masked_fill(mask_all, nodata)
        else:
            probs = model_output
            probs = probs.masked_fill(mask_all, nodata)

        # Remove batch dimension and ensure (1, H, W)
        if probs.ndim == 4:
            probs = probs.squeeze(0)  # (1, H, W)

        # Convert to numpy
        result_probs = probs.cpu().numpy()  # (1, H, W)

        if is_ensemble:
            if uncertainty.ndim == 4:
                uncertainty = uncertainty.squeeze(0)
            result_uncertainty = uncertainty.cpu().numpy()

        # Logic for partial writing: for non-edge tiles, skip the leading
        # overlap//2 pixels so each output pixel is written from the tile
        # where it is furthest from a tile border.
        if col_off == 0:
            offset_x = 0
        else:
            offset_x = col_off + overlap // 2

        if row_off == 0:
            offset_y = 0
        else:
            offset_y = row_off + overlap // 2

        # Tiles ending exactly at the image edge keep their full extent;
        # others drop the trailing overlap//2 margin.
        if (offset_x + chunk_size) == width:
            length_x = chunk_size
            sub_x_start = 0
        else:
            length_x = chunk_size - (overlap // 2)
            sub_x_start = overlap // 2 if col_off != 0 else 0

        if (offset_y + chunk_size) == height:
            length_y = chunk_size
            sub_y_start = 0
        else:
            length_y = chunk_size - (overlap // 2)
            sub_y_start = overlap // 2 if row_off != 0 else 0

        # Ensure we don't exceed array bounds
        if offset_y + length_y > height:
            length_y = height - offset_y
        if offset_x + length_x > width:
            length_x = width - offset_x

        # Extract the valid region from the result
        to_write_probs = result_probs[
            :,
            sub_y_start : sub_y_start + length_y,
            sub_x_start : sub_x_start + length_x
        ]

        # Write to the output numpy array
        output_probs[
            :,
            offset_y : offset_y + length_y,
            offset_x : offset_x + length_x
        ] = to_write_probs

        if is_ensemble:
            # Same stitching window as the probabilities
            to_write_uncertainty = result_uncertainty[
                :,
                sub_y_start : sub_y_start + length_y,
                sub_x_start : sub_x_start + length_x
            ]
            output_uncertainty[
                :,
                offset_y : offset_y + length_y,
                offset_x : offset_x + length_x
            ] = to_write_uncertainty

    if is_ensemble:
        return output_probs, output_uncertainty
    else:
        return output_probs
|