remisek committed on
Commit
4c0fb51
·
1 Parent(s): f5d4753

Install dependencies

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .venv/
2
+ .idea/
3
+ __pycache__/
4
+ *.pyc
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Python 3.12 slim base image with the uv package manager copied in from
# the official astral-sh distroless image.
FROM python:3.12-slim

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# FIX: curl is required by the HEALTHCHECK below but is not included in
# -slim images; install it explicitly.
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Compile bytecode at install time and copy (rather than hardlink) packages
# out of the uv cache, as recommended for Docker builds.
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy

# Install dependencies first so this layer is cached independently of the
# application code.
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-cache --no-install-project

COPY . .

# FIX: `uv sync` installs into /app/.venv, which is not on PATH by default;
# without this the ENTRYPOINT executable would not be found.
ENV PATH="/app/.venv/bin:$PATH"

EXPOSE 8501
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

# FIX: typo "steamlit" -> "streamlit".
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md CHANGED
@@ -9,3 +9,6 @@ short_description: Application for Automatic Crop Type Mapping
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+
13
+ # License
14
+ This project is licensed under the terms of the MIT License. See the LICENSE file for details.
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import folium
3
+ from folium import plugins
4
+ from streamlit_folium import st_folium
5
+ import rasterio
6
+ from rasterio.warp import calculate_default_transform, reproject, Resampling
7
+ import joblib
8
+ import numpy as np
9
+ import pandas as pd
10
+ import geopandas as gpd
11
+ from pathlib import Path
12
+ from matplotlib import colors as colors
13
+ import time
14
+ from rasterio.crs import CRS
15
+
16
+ from worldcereal.job import INFERENCE_JOB_OPTIONS, create_embeddings_process_graph
17
+ from openeo_gfmap import TemporalContext, BoundingBoxExtent
18
+ from worldcereal.parameters import EmbeddingsParameters
19
+
20
+
21
# Display colour (hex) for every supported crop class. Keys are Polish crop
# names; "inne" means "other".
# NOTE: insertion order matters — the integer class ids below are derived
# from it, so do not reorder entries.
crop_classes = {
    "tuz": "#32cd32",
    "burak": "#8b008b",
    "jęczmień": "#ffd700",
    "kukurydza": "#ffa500",
    "lucerna": "#9acd32",
    "mieszanka": "#daa520",
    "owies": "#f0e68c",
    "pszenica": "#f5deb3",
    "pszenżyto": "#bdb76b",
    "rzepak": "#ffff00",
    "sad": "#228b22",
    "słonecznik": "#ff4500",
    "ziemniak": "#a0522d",
    "łubin": "#9370db",
    "żyto": "#cd853f",
    "inne": "#808080",
}

# Stable integer id per class name (0..len-1) and the inverse lookup table.
class_to_id = {name: i for i, name in enumerate(crop_classes)}
id_to_class = dict(enumerate(crop_classes))
41
+
42
# Streamlit page setup — must run before any other st.* call in the script.
st.set_page_config(page_title="Crop Map", layout="wide")

# Pre-trained random-forest classifier and the demo embeddings bundled with
# the app; temp_dir receives embeddings generated on demand in tab 2.
model_path = Path("app/crop_map_app/models/random_forest_crop_classifier_06.joblib")
demo_dir = Path("app/crop_map_app/embeddings/demo")
temp_dir = Path("embeddings/temp_analysis")  # for new files
temp_dir.mkdir(parents=True, exist_ok=True)
48
+
49
def get_class_color_rgba(class_name, alpha=180):
    """Return the (R, G, B, A) tuple (0-255 channels) for a crop class.

    Unknown class names fall back to black ("#000000").
    """
    r, g, b = colors.hex2color(crop_classes.get(class_name, "#000000"))
    return (int(r * 255), int(g * 255), int(b * 255), alpha)
53
+
54
+
55
def create_legend_html(stats_legend):
    """Build a compact HTML legend from the per-crop statistics DataFrame.

    Expects columns 'Crop' and 'Percentage'; swatch colours are looked up in
    the module-level crop_classes mapping (black for unknown crops).
    """
    rows = []
    for _, entry in stats_legend.iterrows():
        name = entry['Crop']
        swatch = crop_classes.get(name, "#000000")
        share = entry['Percentage']
        rows.append(
            f"<div style='display: flex; align-items: center; margin-bottom: 4px;'>"
            f"<div style='width: 15px; height: 15px; background-color: {swatch}; margin-right: 10px; border-radius: 3px;'></div>"
            f"<span style='font-size: 14px; flex-grow: 1;'>{name}</span>"
            f"<span style='font-weight: bold; font-size: 14px;'>{share:.1f}%</span>"
            f"</div>"
        )

    opening = "<div style='background-color: rgba(255, 255, 255, 0.1); padding: 10px; border-radius: 5px; font-family: sans-serif;'>"
    return opening + "".join(rows) + "</div>"
76
+
77
@st.cache_resource
def load_model():
    """Load the joblib classifier once per session.

    Returns None when the model file does not exist on disk.
    """
    return joblib.load(model_path) if model_path.exists() else None
81
+
82
@st.cache_data
def run_prediction(tif_path, _model):
    """Classify every pixel of an embedding GeoTIFF and reproject to WGS84.

    Parameters
    ----------
    tif_path :
        Path to a multi-band embedding raster.
    _model :
        Fitted scikit-learn classifier (the leading underscore keeps it out
        of the st.cache_data argument hash).

    Returns
    -------
    tuple
        (class_map, bounds) — a uint8 array of class ids warped to EPSG:4326
        and its bounds as [[south, west], [north, east]] for folium.
    """
    with rasterio.open(tif_path) as src:
        cube = src.read()
        in_transform = src.transform
        in_crs = src.crs
        height, width = src.height, src.width

    # (bands, y, x) -> one row of features per pixel.
    n_bands = cube.shape[0]
    pixels = cube.transpose(1, 2, 0).reshape(-1, n_bands)

    # Predict in fixed-size chunks to bound memory; NaNs are zero-filled.
    chunk = 50000
    predictions = [
        _model.predict(np.nan_to_num(pixels[start:start + chunk]))
        for start in range(0, pixels.shape[0], chunk)
    ]

    label_map = np.concatenate(predictions).reshape(height, width)

    # String class labels -> stable integer ids (uint8 raster).
    id_map = np.zeros((height, width), dtype=np.uint8)
    for name, cid in class_to_id.items():
        id_map[label_map == name] = cid

    in_crs_str = in_crs.to_string()
    wgs84 = CRS.from_string('EPSG:4326')

    left, bottom, right, top = rasterio.transform.array_bounds(height, width, in_transform)
    out_transform, out_width, out_height = calculate_default_transform(
        in_crs_str, wgs84, width, height, left=left, bottom=bottom, right=right, top=top
    )

    # Nearest-neighbour keeps categorical ids intact during the warp.
    warped = np.zeros((out_height, out_width), dtype=np.uint8)
    reproject(
        source=id_map,
        destination=warped,
        src_transform=in_transform,
        src_crs=in_crs_str,
        dst_transform=out_transform,
        dst_crs=wgs84,
        resampling=Resampling.nearest
    )

    # array_bounds gives (left, bottom, right, top); folium wants
    # [[south, west], [north, east]].
    out_bounds = rasterio.transform.array_bounds(out_height, out_width, out_transform)
    folium_bounds = [[out_bounds[1], out_bounds[0]], [out_bounds[3], out_bounds[2]]]

    return warped, folium_bounds
130
+
131
def run_openeo_job(lat, lon, size_km=1.0):
    """
    Runs WorldCereal job for a small box around lat/lon.
    Returns path to downloaded tif or None.
    """
    try:
        # One degree of latitude is roughly 111 km; half-width in degrees.
        half = (size_km / 111) / 2
        west, east = lon - half, lon + half
        south, north = lat - half, lat + half

        bbox = BoundingBoxExtent(
            west=west, south=south, east=east, north=north, epsg=4326
        )

        # changing time range
        period = TemporalContext("2025-01-01", "2025-12-31")

        st.info("Building OpenEO Process Graph...")
        inference_result = create_embeddings_process_graph(
            spatial_extent=bbox,
            temporal_extent=period,
            embeddings_parameters=EmbeddingsParameters(),
            scale_uint16=True
        )

        job_title = f"thesis_demo_{lat}_{lon}"
        st.info(f"Submitting Job: {job_title}...")
        job = inference_result.create_job(
            title=job_title,
            job_options=INFERENCE_JOB_OPTIONS,
        )

        job.start()
        job_id = job.job_id
        st.success(f"Job started. ID: {job_id}")

        # Poll the backend until the job finishes or fails.
        status_box = st.empty()
        while True:
            status = job.describe_job().get("status")
            status_box.markdown(f"**Status:** `{status}` (refreshing every 5s...)")

            if status == "finished":
                break
            if status in ["error", "canceled"]:
                st.error(f"Job failed with status: {status}")
                return None

            time.sleep(5)

        st.info("Downloading results...")
        results = job.get_results()
        output_path = temp_dir / f"embedding_{lat}_{lon}.tif"

        # Download the first TIFF asset from the job results, if any.
        tiff_asset = next(
            (
                asset
                for asset in results.get_assets()
                if asset.metadata.get("type", "").startswith("image/tiff")
            ),
            None,
        )

        if tiff_asset is None:
            st.error("No TIFF found in results.")
            return None

        tiff_asset.download(str(output_path))
        return output_path

    except Exception as e:
        st.error(f"OpenEO Error: {str(e)}")
        return None
202
+
203
+
204
st.title("Crop Map")

# --- Sidebar: pick one of the bundled demo embeddings to analyse. ---
with st.sidebar:
    st.header("Control Panel")
    tif_files = list(demo_dir.glob("*.tif"))
    if not tif_files:
        st.error(f"No .tif files in {demo_dir}")
        st.stop()

    selected_tif = st.selectbox("Select Region", tif_files, format_func=lambda x: x.name)

    # Reference field boundaries are linked by naming convention:
    # "<name>_embedding.tif" -> "<name>.geojson" in the same directory.
    possible_name = selected_tif.stem.replace("_embedding", "") + ".geojson"
    geojson_path = selected_tif.parent / possible_name
    has_geojson = geojson_path.exists()

    if has_geojson:
        st.success(f"Linked: {geojson_path.name}")

    run_btn = st.button("Run Analysis", type="primary")
223
+
224
# --- "Run Analysis" handler: classify the selected demo raster and stash
# everything tab 1 needs for rendering in session state. ---
if run_btn:
    model = load_model()
    if not model:
        st.error("Model not found")
        st.stop()

    with st.spinner("Processing..."):  # type: ignore[arg-type]
        class_map, bounds = run_prediction(selected_tif, model)

        # Colour the class-id raster into an RGBA image for the map overlay.
        h, w = class_map.shape
        rgba_img = np.zeros((h, w, 4), dtype=np.uint8)
        unique_ids = np.unique(class_map)

        for uid in unique_ids:
            if uid not in id_to_class: continue
            crop = id_to_class[uid]
            c = get_class_color_rgba(crop, alpha=255)
            rgba_img[class_map == uid] = c

        # Optional reference parcels, reprojected and simplified so folium
        # renders them quickly.
        gdf = None
        if has_geojson:
            gdf = gpd.read_file(geojson_path)
            if gdf.crs != "EPSG:4326":
                gdf = gdf.to_crs("EPSG:4326")
            gdf['geometry'] = gdf['geometry'].simplify(tolerance=0.0001)

        # Per-class pixel counts and percentage shares, largest share first.
        total = class_map.size
        counts = {id_to_class[uid]: np.sum(class_map == uid) for uid in unique_ids if uid in id_to_class}
        stats_df = pd.DataFrame([
            {"Crop": k, "Pixels": v, "Percentage": v / total * 100} for k, v in counts.items()
        ]).sort_values("Percentage", ascending=False)

        st.session_state['analysis_results'] = {
            "bounds": bounds,
            "rgba_img": rgba_img,
            "gdf": gdf,
            "stats_df": stats_df
        }
262
+
263
tab1, tab2 = st.tabs(["Pre-loaded Regions", "Analyze New Area"])

# --- Tab 1: render the results produced by the sidebar "Run Analysis". ---
with tab1:
    if 'analysis_results' in st.session_state:
        data = st.session_state['analysis_results']
        bounds = data['bounds']
        rgba_img = data['rgba_img']
        gdf = data['gdf']
        stats_df = data['stats_df']

        c1, c2 = st.columns([3, 1])

        with c1:
            # Centre the map on the middle of the prediction bounds
            # ([[south, west], [north, east]]).
            center_lat = (bounds[0][0] + bounds[1][0]) / 2
            center_lon = (bounds[0][1] + bounds[1][1]) / 2

            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab1")

            m = folium.Map(location=[center_lat, center_lon], zoom_start=14, control_scale=True)

            # Two switchable base layers: light map and satellite imagery.
            folium.TileLayer(
                tiles='CartoDB positron',
                name='Light Map',
                overlay=False
            ).add_to(m)

            folium.TileLayer(
                tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
                attr='Esri',
                name='Satellite',
                overlay=False
            ).add_to(m)

            # The classified raster as a semi-transparent overlay.
            folium.raster_layers.ImageOverlay(
                image=rgba_img,
                bounds=bounds,
                opacity=overlay_opacity,
                name='Prediction',
                pixelated=True
            ).add_to(m)

            # Reference parcel outlines; 'roslina' is the declared-crop field.
            if gdf is not None:
                folium.GeoJson(
                    gdf,
                    name="Fields",
                    style_function=lambda x: {'color': 'white', 'weight': 1, 'fillOpacity': 0, 'dashArray': '5, 5'},
                    tooltip=folium.GeoJsonTooltip(fields=['roslina'], aliases=['Crop:'])
                ).add_to(m)

            folium.LayerControl().add_to(m)
            plugins.Fullscreen().add_to(m)

            st_folium(m, height=600, use_container_width=True)

        with c2:
            st.subheader("Legend")
            st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)
            st.dataframe(stats_df[["Crop", "Percentage"]], hide_index=True)
321
+
322
# --- Tab 2: generate a fresh embedding for a user-chosen point via OpenEO,
# classify it, and render the result. ---
with tab2:
    c1, c2 = st.columns([1, 2])

    if 'tab2_results' not in st.session_state:
        st.session_state['tab2_results'] = None

    with c1:
        st.markdown("### 1. Select Area")
        lat = st.number_input("Latitude", value=50.93131691432723, format="%.4f")
        lon = st.number_input("Longitude", value=22.781513694631702, format="%.4f")

        if st.button("Generate the embedding and classify"):
            with st.spinner("Talking to Satellites... (This takes ~5 mins)"):  # type: ignore[arg-type]
                tif_path = run_openeo_job(lat, lon)

                if tif_path:
                    st.success("Embedding Generated!")

                    model = load_model()
                    # FIX: load_model() returns None when the model file is
                    # missing; guard here like the sidebar handler does
                    # instead of crashing inside run_prediction().
                    if not model:
                        st.error("Model not found")
                        st.stop()

                    class_map, bounds = run_prediction(tif_path, model)

                    # Colour the class-id raster into an RGBA overlay image.
                    h, w = class_map.shape
                    rgba_img = np.zeros((h, w, 4), dtype=np.uint8)
                    unique_ids = np.unique(class_map)

                    for uid in unique_ids:
                        if uid not in id_to_class: continue
                        crop = id_to_class[uid]
                        c = get_class_color_rgba(crop, alpha=255)
                        rgba_img[class_map == uid] = c

                    # Per-class pixel counts and percentage shares.
                    total = class_map.size
                    counts = {id_to_class[uid]: np.sum(class_map == uid) for uid in unique_ids if uid in id_to_class}
                    stats_df = pd.DataFrame([
                        {"Crop": k, "Pixels": v, "Percentage": v / total * 100} for k, v in counts.items()
                    ]).sort_values("Percentage", ascending=False)

                    st.session_state['tab2_results'] = {
                        "bounds": bounds,
                        "rgba_img": rgba_img,
                        "stats_df": stats_df
                    }

                    st.success("Classification Complete")

    with c2:
        if st.session_state['tab2_results']:
            data = st.session_state['tab2_results']
            bounds = data['bounds']
            rgba_img = data['rgba_img']
            stats_df = data['stats_df']

            st.markdown("### 2. Analysis Results")

            # Centre the map on the middle of the prediction bounds.
            center_lat = (bounds[0][0] + bounds[1][0]) / 2
            center_lon = (bounds[0][1] + bounds[1][1]) / 2

            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab2")

            m = folium.Map(location=[center_lat, center_lon], zoom_start=14, control_scale=True)

            # Two switchable base layers: light map and satellite imagery.
            folium.TileLayer(
                tiles='CartoDB positron',
                name='Light Map',
                overlay=False
            ).add_to(m)

            folium.TileLayer(
                tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
                attr='Esri',
                name='Satellite',
                overlay=False
            ).add_to(m)

            # The classified raster as a semi-transparent overlay.
            folium.raster_layers.ImageOverlay(
                image=rgba_img,
                bounds=bounds,
                opacity=overlay_opacity,
                name='Prediction',
                pixelated=True
            ).add_to(m)

            folium.LayerControl().add_to(m)
            plugins.Fullscreen().add_to(m)

            st_folium(m, height=500, use_container_width=True)

            st.divider()
            col_leg, col_df = st.columns(2)
            with col_leg:
                st.subheader("Legend")
                st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)
openeo_gfmap/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEO General Framework for Mapping.
2
+
3
+ Simplify the development of mapping applications through Remote Sensing data
4
+ by leveraging the power of OpenEO (http://openeo.org/).
5
+
6
+ More information available in the README.md file.
7
+ """
8
+
9
+ from .backend import Backend, BackendContext
10
+ from .fetching import FetchType
11
+ from .metadata import FakeMetadata
12
+ from .spatial import BoundingBoxExtent, SpatialContext
13
+ from .temporal import TemporalContext
14
+
15
+ __all__ = [
16
+ "Backend",
17
+ "BackendContext",
18
+ "SpatialContext",
19
+ "BoundingBoxExtent",
20
+ "TemporalContext",
21
+ "FakeMetadata",
22
+ "FetchType",
23
+ ]
openeo_gfmap/backend.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backend Context.
2
+
3
+ Defines on which backend the pipeline is being currently used.
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ from typing import Callable, Dict, Optional
11
+
12
+ import openeo
13
+
14
+ _log = logging.getLogger(__name__)
15
+
16
+
17
+ class Backend(Enum):
18
+ """Enumerating the backends supported by the Mapping Framework."""
19
+
20
+ TERRASCOPE = "terrascope"
21
+ EODC = "eodc" # Dask implementation. Do not test on this yet.
22
+ CDSE = "cdse" # Terrascope implementation (pyspark) #URL: openeo.dataspace.copernicus.eu (need to register)
23
+ CDSE_STAGING = "cdse-staging"
24
+ LOCAL = "local" # Based on the same components of EODc
25
+ FED = "fed" # Federation backend
26
+
27
+
28
@dataclass
class BackendContext:
    """Backend context and information.

    Containing backend related information useful for the framework to
    adapt the process graph.
    """

    # Which of the supported backends this pipeline currently targets.
    backend: Backend
37
+
38
+
39
def _create_connection(
    url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None
) -> openeo.Connection:
    """
    Generic helper to create an openEO connection
    with support for multiple client credential configurations from environment variables
    """
    connection = openeo.connect(url, **(connect_kwargs or {}))

    # Client-credentials flow is only taken when BOTH the global auth method
    # env var is set AND a backend-specific client id exists.
    if (
        os.environ.get("OPENEO_AUTH_METHOD") == "client_credentials"
        and f"OPENEO_AUTH_CLIENT_ID_{env_var_suffix}" in os.environ
    ):
        # Support for multiple client credentials configs from env vars
        client_id = os.environ[f"OPENEO_AUTH_CLIENT_ID_{env_var_suffix}"]
        client_secret = os.environ[f"OPENEO_AUTH_CLIENT_SECRET_{env_var_suffix}"]
        # provider_id may be absent (None): the client then picks a default.
        provider_id = os.environ.get(f"OPENEO_AUTH_PROVIDER_ID_{env_var_suffix}")
        _log.info(
            f"Doing client credentials from env var with {env_var_suffix=} {provider_id} {client_id=} {len(client_secret)=} "
        )

        connection.authenticate_oidc_client_credentials(
            client_id=client_id, client_secret=client_secret, provider_id=provider_id
        )
    else:
        # Standard authenticate_oidc procedure: refresh token, device code or default env var handling
        # See https://open-eo.github.io/openeo-python-client/auth.html#oidc-authentication-dynamic-method-selection

        # Use a shorter max poll time by default to alleviate the default impression that the test seem to hang
        # because of the OIDC device code poll loop.
        max_poll_time = int(
            os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30
        )
        connection.authenticate_oidc(max_poll_time=max_poll_time)
    return connection
74
+
75
+
76
def vito_connection() -> openeo.Connection:
    """Performs a connection to the VITO backend using the oidc authentication."""
    return _create_connection(
        url="openeo.vito.be",
        env_var_suffix="VITO",
    )
82
+
83
+
84
def cdse_connection() -> openeo.Connection:
    """Performs a connection to the CDSE backend using oidc authentication."""
    return _create_connection(
        url="openeo.dataspace.copernicus.eu",
        env_var_suffix="CDSE",
    )
90
+
91
+
92
def cdse_staging_connection() -> openeo.Connection:
    """Performs a connection to the CDSE staging backend using oidc authentication."""
    return _create_connection(
        url="openeo-staging.dataspace.copernicus.eu",
        env_var_suffix="CDSE_STAGING",
    )
98
+
99
+
100
def eodc_connection() -> openeo.Connection:
    """Performs a connection to the EODC backend using the oidc authentication."""
    return _create_connection(
        url="https://openeo.eodc.eu/openeo/1.1.0",
        env_var_suffix="EODC",
    )
106
+
107
+
108
def fed_connection() -> openeo.Connection:
    """Performs a connection to the OpenEO federated backend using the oidc
    authentication."""
    return _create_connection(
        url="openeofed.dataspace.copernicus.eu/",
        env_var_suffix="FED",
    )
115
+
116
+
117
# Connection factory per backend. NOTE(review): Backend.EODC and Backend.LOCAL
# have no entry here — presumably EODC is not to be used yet and LOCAL needs
# no connection; confirm before relying on a lookup for those members.
BACKEND_CONNECTIONS: Dict[Backend, Callable] = {
    Backend.TERRASCOPE: vito_connection,
    Backend.CDSE: cdse_connection,
    Backend.CDSE_STAGING: cdse_staging_connection,
    Backend.FED: fed_connection,
}
openeo_gfmap/fetching.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Main file for extractions and pre-processing of data through OpenEO
2
+ """
3
+
4
+ from enum import Enum
5
+ from typing import Callable
6
+
7
+ import openeo
8
+
9
+ from openeo_gfmap import BackendContext
10
+ from openeo_gfmap.spatial import SpatialContext
11
+ from openeo_gfmap.temporal import TemporalContext
12
+
13
+
14
class FetchType(Enum):
    """Enumerates the different types of extraction. There are three types of
    enumerations.

    * TILE: Tile based extractions, getting the data for a dense part. The
    output of such fetching process is a dense DataCube.
    * POINT: Point based extractions. From a datasets of polygons, gets sparse
    extractions and performs spatial aggregation on the selected polygons. The
    output of such fetching process is a VectorCube, that can be used to get
    a pandas.DataFrame
    * POLYGON: Patch based extractions, returning a VectorCube of sparsed
    patches. This can be retrieved as multiple NetCDF files from one job.
    """

    TILE = "tile"
    POINT = "point"
    POLYGON = "polygon"
31
+
32
+
33
class CollectionFetcher:
    """Base class to fetch a particular collection.

    Parameters
    ----------
    backend_context: BackendContext
        Information about the backend in use, useful in certain cases.
    bands: list
        List of band names to load from that collection.
    collection_fetch: Callable
        Function defining how to fetch a collection for a specific backend,
        the function accepts the following parameters: connection,
        spatial extent, temporal extent, bands and additional parameters.
    collection_preprocessing: Callable
        Function defining how to harmonize the data of a collection in a
        backend. For example, this function could rename the bands as they
        can be different for every backend/collection (SENTINEL2_L2A or
        SENTINEL2_L2A_SENTINELHUB). Accepts the following parameters:
        datacube (of pre-fetched collection) and additional parameters.
    collection_params: dict
        Additional parameters encoded within a dictionary that will be
        passed in the fetch and preprocessing function.
    """

    def __init__(
        self,
        backend_context: BackendContext,
        bands: list,
        collection_fetch: Callable,
        collection_preprocessing: Callable,
        **collection_params,
    ):
        # FIX: attribute was misspelled "backend_contect".
        self.backend_context = backend_context
        self.bands = bands
        self.fetcher = collection_fetch
        self.processing = collection_preprocessing
        self.params = collection_params

    def get_cube(
        self,
        connection: openeo.Connection,
        spatial_context: SpatialContext,
        temporal_context: TemporalContext,
    ) -> openeo.DataCube:
        """Retrieve a data cube from the given spatial and temporal context.

        Parameters
        ----------
        connection: openeo.Connection
            A connection to an OpenEO backend. The backend provided must be the
            same as the one this extractor class is configured for.
        spatial_context: SpatialContext
            Either a GeoJSON collection on which spatial filtering will be
            applied or a bounding box with an EPSG code. If a bounding box is
            provided, no filtering is applied and the entirety of the data is
            fetched for that region.
        temporal_context: TemporalContext
            The begin and end date of the extraction.
        """
        # Fetch the raw collection, then harmonize it for this backend; both
        # steps receive the same extra keyword parameters.
        collection_data = self.fetcher(
            connection, spatial_context, temporal_context, self.bands, **self.params
        )

        preprocessed_data = self.processing(collection_data, **self.params)

        return preprocessed_data
openeo_gfmap/metadata.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Metadata utilities related to the usage of a DataCube. Used to interract
2
+ with the OpenEO backends and cover some shortcomings.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+
7
+
8
@dataclass
class FakeMetadata:
    """Fake metadata object used for datacubes fetched from STAC catalogues.
    This is used as a temporary fix for OpenEO backend shortcomings, but
    will become unused with time.
    """

    # Ordered list of band labels for the band dimension.
    band_names: list

    def rename_labels(self, _, target, source):
        """Rename the labels of the band dimension."""
        # NOTE(review): this maps target -> source, i.e. bands currently
        # carrying a `target` name are rewritten to the paired `source` name.
        replacements = dict(zip(target, source))
        for position, label in enumerate(list(self.band_names)):
            if label in target:
                self.band_names[position] = replacements[label]
        return self
openeo_gfmap/spatial.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Definitions of spatial context, either point-based or spatial"""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Union
5
+
6
+ from geojson import GeoJSON
7
+ from shapely.geometry import Polygon, box
8
+
9
+
10
@dataclass
class BoundingBoxExtent:
    """Definition of a bounding box as accepted by OpenEO

    Contains the minx, miny, maxx, maxy coordinates expressed as west, south
    east, north. The EPSG is also defined.
    """

    west: float
    south: float
    east: float
    north: float
    epsg: int = 4326

    # NOTE(review): defining __dict__ as a method shadows the normal instance
    # __dict__ descriptor (e.g. vars() on an instance no longer returns a
    # mapping). Kept as-is because downstream code may call it explicitly —
    # confirm before changing.
    def __dict__(self):
        return {
            "west": self.west,
            "south": self.south,
            "east": self.east,
            "north": self.north,
            "crs": f"EPSG:{self.epsg}",
            "srs": f"EPSG:{self.epsg}",
        }

    def __iter__(self):
        # Iterating yields (key, value) pairs so dict(instance) works.
        return iter(
            [
                ("west", self.west),
                ("south", self.south),
                ("east", self.east),
                ("north", self.north),
                ("crs", f"EPSG:{self.epsg}"),
                ("srs", f"EPSG:{self.epsg}"),
            ]
        )

    def to_geometry(self) -> Polygon:
        """Return the bounding box as a shapely Polygon."""
        return box(self.west, self.south, self.east, self.north)

    def to_geojson(self) -> GeoJSON:
        """Return the bounding box as a GeoJSON-like mapping."""
        return self.to_geometry().__geo_interface__


# A spatial context is either a GeoJSON object, a bounding box, or a string
# reference (e.g. a URL / path accepted downstream).
SpatialContext = Union[GeoJSON, BoundingBoxExtent, str]
openeo_gfmap/temporal.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Definitions of temporal context"""
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
5
+
6
+
7
@dataclass
class TemporalContext:
    """Temporal context is defined by a `start_date` and `end_date` values.

    The value must be encoded on a YYYY-mm-dd format, e.g. 2020-01-01
    """

    start_date: str
    end_date: str

    def to_datetime(self):
        """Converts the temporal context to a tuple of datetime objects."""
        date_format = "%Y-%m-%d"
        parse = datetime.strptime
        return (parse(self.start_date, date_format), parse(self.end_date, date_format))
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "crop-map"
3
+ version = "0.1.0"
4
+ description = "Application for Crop Type Mapping"
5
+ requires-python = ">=3.12"
6
+ dependencies = [
7
+ "folium>=0.20.0",
8
+ "geojson>=3.2.0",
9
+ "geopandas>=1.1.2",
10
+ "joblib>=1.5.3",
11
+ "matplotlib>=3.10.8",
12
+ "openeo>=0.47.0",
13
+ "rasterio>=1.5.0",
14
+ "streamlit>=1.53.1",
15
+ "streamlit-folium>=0.26.1",
16
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
worldcereal/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ from ._version import __version__
4
+
5
+ __all__ = ["__version__"]
6
+
7
# Season identifiers accepted across the package; "tc-" prefixed entries are
# the WorldCereal temporal-composite seasons, "custom" is user-defined.
SUPPORTED_SEASONS = [
    "tc-s1",
    "tc-s2",
    "tc-annual",
    "custom",
]

# Short code used for each season identifier.
SEASONAL_MAPPING = {
    "tc-s1": "S1",
    "tc-s2": "S2",
    "tc-annual": "ANNUAL",
    "custom": "custom",
}


# Default buffer (days) prior to
# season start
SEASON_PRIOR_BUFFER = {
    "tc-s1": 0,
    "tc-s2": 0,
    "tc-annual": 0,
    "custom": 0,
}


# Default buffer (days) after
# season end
SEASON_POST_BUFFER = {
    "tc-s1": 0,
    "tc-s2": 0,
    "tc-annual": 0,
    "custom": 0,
}
worldcereal/_version.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ __version__ = "2.4.1"
worldcereal/job.py ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Executing inference jobs on the OpenEO backend.
2
+
3
+ Possible entry points for inference in this module:
4
+ - `generate_map`: This function is used to generate a map for a single patch.
5
+ It creates one OpenEO job and processes the inference for the specified
6
+ spatial and temporal extent.
7
+ - `collect_inputs`: This function is used to collect preprocessed inputs
8
+ without performing inference. It retrieves the required data for further
9
+ processing or analysis.
10
+ - `run_largescale_inference`: This function utilizes a job manager to
11
+ orchestrate and execute multiple inference jobs automatically, enabling
12
+ efficient large-scale processing.
13
+ - `setup_inference_job_manager`: This function prepares the job manager
14
+ and job database for large-scale inference jobs. It sets up the necessary
15
+ infrastructure to manage and track jobs in a notebook environment.
16
+ Used in the WorldCereal demo notebooks.
17
+
18
+ """
19
+
20
import ast
import json
import shutil
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Union

import geopandas as gpd
import openeo
import pandas as pd
from loguru import logger
from openeo import BatchJob
from openeo.extra.job_management import CsvJobDatabase, MultiBackendJobManager
from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import BACKEND_CONNECTIONS
from pydantic import BaseModel
from typing_extensions import TypedDict

from worldcereal.openeo.mapping import _cropland_map, _croptype_map, _embeddings_map
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs
from worldcereal.parameters import (
    CropLandParameters,
    CropTypeParameters,
    EmbeddingsParameters,
    WorldCerealProductType,
)
from worldcereal.utils.models import load_model_lut
47
+
48
# Publicly hosted dependency archives unpacked by the OpenEO UDF runtime
# (referenced below via the "archive_url#mount_name" convention).
ONNX_DEPS_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_deps_python311.zip"
FEATURE_DEPS_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/torch_deps_python311.zip"

# Default batch-job options for inference jobs. Callers may override any of
# these via the `job_options` parameters of the functions in this module
# (a deepcopy is taken before merging, so this dict is never mutated).
INFERENCE_JOB_OPTIONS = {
    "driver-memory": "4g",
    "executor-memory": "2g",
    "executor-memoryOverhead": "3g",
    "max-executors": 20,
    "python-memory": "disable",
    "soft-errors": 0.1,  # tolerate up to 10% of failing tasks before aborting
    "image-name": "python311",
    "udf-dependency-archives": [
        f"{ONNX_DEPS_URL}#onnx_deps",
        f"{FEATURE_DEPS_URL}#feature_deps",
    ],
}
63
+
64
+
65
class WorldCerealProduct(TypedDict):
    """Typed dictionary describing a single WorldCereal inference product.

    Instances are built in `generate_map` from the assets of a finished
    OpenEO job.

    Attributes
    ----------
    url: str
        URL to the product asset on the OpenEO backend.
    type: WorldCerealProductType
        Type of the product. Either cropland or croptype.
    temporal_extent: TemporalContext
        Period of time for which the product has been generated.
    path: Optional[Path]
        Path to the downloaded product, or None if it was not downloaded.
    lut: Optional[Dict]
        Look-up table mapping class values to labels for the product.
    """

    url: str
    type: WorldCerealProductType
    temporal_extent: TemporalContext
    path: Optional[Path]
    lut: Optional[Dict]
88
+
89
+
90
class InferenceResults(BaseModel):
    """Pydantic model holding the results of a finished WorldCereal job.

    Returned by `generate_map`.

    Attributes
    ----------
    job_id : str
        Job ID of the finished OpenEO job.
    products: Dict[str, WorldCerealProduct]
        Dictionary with the different products, keyed by asset name.
    metadata: Optional[Path]
        Path to the metadata file, if it was downloaded locally; None otherwise.
    """

    job_id: str
    products: Dict[str, WorldCerealProduct]
    metadata: Optional[Path]
106
+
107
+
108
class InferenceJobManager(MultiBackendJobManager):
    """A job manager for executing large-scale WorldCereal inference jobs on the OpenEO backend.

    Based on the official MultiBackendJobManager, extending how results are
    downloaded and named: each job's assets land in a per-tile subfolder and
    get the tile name appended to their filename.
    """

    @classmethod
    def generate_output_path_inference(
        cls,
        root_folder: Path,
        geometry_index: int,
        row: pd.Series,
        asset_id: Optional[str] = None,
    ) -> Path:
        """Return (and create) the per-tile output directory for a job.

        Parameters
        ----------
        root_folder : Path
            root folder under which a subfolder named after the tile is created
        geometry_index : int
            unused here; kept in the signature for compatibility with the
            job manager's output-path callback contract
        row : pd.Series
            the current job row; only its `tile_name` attribute is used
        asset_id : str, optional
            unused; kept for compatibility with the job manager

        Returns
        -------
        Path
            the created subfolder `root_folder / tile_name`
        """

        tile_name = row.tile_name

        # Create the subfolder to store the output
        subfolder = root_folder / str(tile_name)
        subfolder.mkdir(parents=True, exist_ok=True)

        return subfolder

    def on_job_done(self, job: BatchJob, row):
        # Callback invoked by MultiBackendJobManager once a job finishes:
        # downloads all assets into the tile subfolder and stores job/result
        # metadata as JSON next to them.
        logger.info(f"Job {job.job_id} completed")
        output_dir = self.generate_output_path_inference(self._root_dir, 0, row)

        # Get job results
        job_result = job.get_results()

        # Get the products
        assets = job_result.get_assets()
        for asset in assets:
            asset_name = asset.name.split(".")[0].split("_")[0]
            asset_type = asset_name.split("-")[0]
            # NOTE(review): asset_type is not used after this lookup; the
            # getattr acts as validation only — an unknown product prefix
            # raises AttributeError before the asset is downloaded.
            asset_type = getattr(WorldCerealProductType, asset_type.upper())
            filepath = asset.download(target=output_dir)

            # We want to add the tile name to the filename
            new_filepath = filepath.parent / f"{filepath.stem}_{row.tile_name}.tif"
            shutil.move(filepath, new_filepath)

        job_metadata = job.describe()
        result_metadata = job_result.get_metadata()
        job_metadata_path = output_dir / f"job_{job.job_id}.json"
        result_metadata_path = output_dir / f"result_{job.job_id}.json"

        with job_metadata_path.open("w", encoding="utf-8") as f:
            json.dump(job_metadata, f, ensure_ascii=False)
        with result_metadata_path.open("w", encoding="utf-8") as f:
            json.dump(result_metadata, f, ensure_ascii=False)

        # post_job_action(output_file)
        logger.success("Job completed")
182
+
183
+
184
def create_inference_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    connection: Optional[openeo.Connection] = None,
) -> List[openeo.DataCube]:
    """Wrapper function that creates the inference openEO process graph.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    product_type : WorldCerealProductType, optional
        product describer, by default WorldCerealProductType.CROPLAND
    cropland_parameters: CropLandParameters
        Parameters for the cropland product inference pipeline.
    croptype_parameters: Optional[CropTypeParameters]
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise.
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use for the inference. If not provided,
        the orbit state will be dynamically determined based on the spatial extent.
    out_format : str, optional
        Output format, by default "GTiff"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg: Optional[int] = None
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.
    connection: Optional[openeo.Connection] = None,
        Optional OpenEO connection to use. If not provided, a new connection
        will be created based on the backend_context.

    Returns
    -------
    List[openeo.DataCube]
        A list with one or more result objects or a list of DataCube objects, representing the inference
        process graph. This object can be used to execute the job on the OpenEO backend.
        The result will be a DataCube with the classification results.

    Raises
    ------
    ValueError
        if the product is not supported
    ValueError
        if the out_format is not supported
    """
    if product_type not in WorldCerealProductType:
        raise ValueError(f"Product {product_type.value} not supported.")

    # BUGFIX: previously interpolated the builtin `format` instead of the
    # `out_format` argument, producing a useless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend
    if connection is None:
        connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        # disable_meteo=True,
    )

    # Spatial filtering
    inputs = inputs.filter_bbox(dict(spatial_extent))

    # Construct the feature extraction and model inference pipeline
    if product_type == WorldCerealProductType.CROPLAND:
        results = _cropland_map(
            inputs,
            temporal_extent,
            cropland_parameters=cropland_parameters,
        )

    elif product_type == WorldCerealProductType.CROPTYPE:
        if not isinstance(croptype_parameters, CropTypeParameters):
            raise ValueError(
                f"Please provide a valid `croptype_parameters` parameter."
                f" Received: {croptype_parameters}"
            )

        # Generate crop type map with optional cropland masking
        results = _croptype_map(
            inputs,
            temporal_extent,
            cropland_parameters=cropland_parameters,
            croptype_parameters=croptype_parameters,
        )

    else:
        # BUGFIX: any enum member without a pipeline branch here (e.g. a
        # newly added product type) previously fell through and raised an
        # UnboundLocalError on `results`. Fail explicitly instead.
        raise ValueError(
            f"Product {product_type.value} has no inference pipeline in this function."
        )

    return results
292
+
293
+
294
def create_embeddings_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    embeddings_parameters: EmbeddingsParameters = EmbeddingsParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    scale_uint16: bool = True,
) -> openeo.DataCube:
    """Create an OpenEO process graph for generating embeddings.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        Spatial extent of the map.
    temporal_extent : TemporalContext
        Temporal range to consider.
    embeddings_parameters : EmbeddingsParameters, optional
        Parameters for the embeddings product inference pipeline, by default EmbeddingsParameters().
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]], optional
        Sentinel-1 orbit state to use for the inference. If not provided, the
        orbit state will be dynamically determined based on the spatial extent,
        by default None.
    out_format : str, optional
        Output format, by default "GTiff".
    backend_context : BackendContext, optional
        Backend to run the job on, by default BackendContext(Backend.CDSE).
    tile_size : Optional[int], optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg : Optional[int], optional
        EPSG code to use for the output products. If not provided, the default
        EPSG will be used.
    scale_uint16 : bool, optional
        Whether to scale the embeddings to uint16 for memory optimization,
        by default True.

    Returns
    -------
    openeo.DataCube
        DataCube object representing the embeddings process graph. This object
        can be used to execute the job on the OpenEO backend. The result will
        be a DataCube with the embeddings.

    Raises
    ------
    ValueError
        If the output format is not supported.
    """

    # BUGFIX: previously interpolated the builtin `format` instead of the
    # `out_format` argument, producing a useless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        # disable_meteo=True,
    )

    # Spatial filtering
    inputs = inputs.filter_bbox(dict(spatial_extent))

    embeddings = _embeddings_map(
        inputs,
        temporal_extent,
        embeddings_parameters=embeddings_parameters,
        scale_uint16=scale_uint16,
    )

    # Save the final result
    embeddings = embeddings.save_result(
        format=out_format,
        options=dict(
            filename_prefix=f"WorldCereal_Embeddings_{temporal_extent.start_date}_{temporal_extent.end_date}",
        ),
    )

    return embeddings
376
+
377
+
378
def create_inputs_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "NetCDF",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    compositing_window: Literal["month", "dekad"] = "month",
) -> openeo.DataCube:
    """Wrapper function that creates the inputs openEO process graph.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use for the inference. If not provided,
        the orbit state will be dynamically determined based on the spatial extent.
    out_format : str, optional
        Output format, by default "NetCDF"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg: Optional[int] = None
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.
    compositing_window: Literal["month", "dekad"]
        Compositing window to use for the data loading in OpenEO, by default
        "month".

    Returns
    -------
    openeo.DataCube
        DataCube object representing the inputs process graph.
        This object can be used to execute the job on the OpenEO backend.
        The result will be a DataCube with the preprocessed inputs.

    Raises
    ------
    ValueError
        if the out_format is not supported
    """

    # BUGFIX: previously interpolated the builtin `format` instead of the
    # `out_format` argument, producing a useless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        compositing_window=compositing_window,
        # disable_meteo=True,
    )

    # Spatial filtering
    inputs = inputs.filter_bbox(dict(spatial_extent))

    # Save the final result
    inputs = inputs.save_result(
        format=out_format,
        options=dict(
            filename_prefix=f"preprocessed-inputs_{temporal_extent.start_date}_{temporal_extent.end_date}",
        ),
    )

    return inputs
456
+
457
+
458
def create_inference_job(
    row: pd.Series,
    connection: openeo.Connection,
    provider: str,
    connection_provider: str,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPTYPE,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    target_epsg: Optional[int] = None,
    job_options: Optional[dict] = None,
) -> BatchJob:
    """Create an OpenEO batch job for WorldCereal inference.

    Parameters
    ----------
    row : pd.Series
        Job row from the job database. Contains at least the following fields:
        - start_date: str, start date of the temporal extent
        - end_date: str, end date of the temporal extent
        - geometry: shapely.geometry, geometry of the spatial extent
        - tile_name: str, name of the tile
        - epsg: int, EPSG code of the spatial extent
        - bounds_epsg: str representation of tuple,
          bounds of the spatial extent in CRS as
          specified by epsg attribute
    connection : openeo.Connection
        openEO connection to the backend
    provider : str
        unused but required for compatibility with MultiBackendJobManager
    connection_provider : str
        unused but required for compatibility with MultiBackendJobManager
    product_type : WorldCerealProductType, optional
        Type of the WorldCereal product to generate, by default WorldCerealProductType.CROPTYPE
    croptype_parameters : Optional[CropTypeParameters], optional
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise, by default None
    cropland_parameters : Optional[CropLandParameters], optional
        Parameters for the cropland product inference pipeline, by default None
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]], optional
        Sentinel-1 orbit state to use for the inference. If not provided, the
        best orbit will be dynamically derived from the catalogue.
    target_epsg : Optional[int], optional
        EPSG code to reproject the data to. If not provided, the data will be
        left in the original epsg as mentioned in the row.
    job_options : Optional[dict], optional
        Additional job options to pass to the OpenEO backend, by default None

    Returns
    -------
    BatchJob
        Batch job created on openEO backend.
    """

    # Get temporal and spatial extents from the row
    temporal_extent = TemporalContext(start_date=row.start_date, end_date=row.end_date)
    epsg = int(row.epsg)
    # SECURITY FIX: `bounds_epsg` comes from an external grid file; use
    # ast.literal_eval instead of eval so only a plain literal tuple/list is
    # accepted and no arbitrary code can be executed.
    bounds = ast.literal_eval(row.bounds_epsg)
    spatial_extent = BoundingBoxExtent(
        west=bounds[0], south=bounds[1], east=bounds[2], north=bounds[3], epsg=epsg
    )

    if target_epsg is None:
        # If no target EPSG is provided, use the EPSG from the row
        target_epsg = epsg

    # Update default job options with the provided ones
    inference_job_options = deepcopy(INFERENCE_JOB_OPTIONS)
    if job_options is not None:
        inference_job_options.update(job_options)

    inference_result = create_inference_process_graph(
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        product_type=product_type,
        croptype_parameters=croptype_parameters,
        cropland_parameters=cropland_parameters,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        connection=connection,
    )

    # Submit the job
    return connection.create_job(
        inference_result,
        title=f"WorldCereal [{product_type.value}] job_{row.tile_name}",
        description="Job that performs end-to-end WorldCereal inference",
        additional=inference_job_options,  # TODO: once openeo-python-client supports job_options, use that
    )
549
+
550
+
551
def generate_map(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    output_dir: Optional[Union[Path, str]] = None,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    job_options: Optional[dict] = None,
    target_epsg: Optional[int] = None,
) -> InferenceResults:
    """Main function to generate a WorldCereal product.

    Creates a single OpenEO batch job for the given spatial/temporal extent,
    blocks until it finishes, and (optionally) downloads the resulting
    products and metadata.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    output_dir : Optional[Union[Path, str]]
        path to directory where products should be downloaded to; if None,
        nothing is downloaded and only asset URLs are reported
    product_type : WorldCerealProductType, optional
        product describer, by default WorldCerealProductType.CROPLAND
    cropland_parameters: CropLandParameters
        Parameters for the cropland product inference pipeline.
    croptype_parameters: Optional[CropTypeParameters]
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise.
    out_format : str, optional
        Output format, by default "GTiff"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    job_options: dict, optional
        Additional job options to pass to the OpenEO backend, by default None
    target_epsg: Optional[int] = None
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.

    Returns
    -------
    InferenceResults
        Results of the finished WorldCereal job.

    Raises
    ------
    ValueError
        if the product is not supported
    ValueError
        if the out_format is not supported
    """

    # Get a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Create the process graph
    results = create_inference_process_graph(
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        out_format=out_format,
        backend_context=backend_context,
        tile_size=tile_size,
        target_epsg=target_epsg,
        connection=connection,
    )

    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Merge user-provided job options onto a copy of the defaults
    # (deepcopy so the module-level dict is never mutated)
    inference_job_options = deepcopy(INFERENCE_JOB_OPTIONS)
    if job_options is not None:
        inference_job_options.update(job_options)

    # Execute the job; start_and_wait() blocks until the job finishes
    # and returns the BatchJob itself
    job = connection.create_job(
        results,
        additional=inference_job_options,  # TODO: once openeo-python-client supports job_options, use that
        title=f"WorldCereal [{product_type.value}] job",
        description="Job that performs end-to-end WorldCereal inference",
    ).start_and_wait()

    # Get look-up tables; cropland LUT is always loaded since the croptype
    # product includes an optional cropland mask asset
    luts = {}
    luts[WorldCerealProductType.CROPLAND.value] = load_model_lut(
        cropland_parameters.classifier_parameters.classifier_url
    )
    if product_type == WorldCerealProductType.CROPTYPE:
        luts[WorldCerealProductType.CROPTYPE.value] = load_model_lut(
            croptype_parameters.classifier_parameters.classifier_url
        )

    # Get job results
    job_result = job.get_results()

    # Build one WorldCerealProduct entry per job asset.
    # NOTE(review): the product type is derived from the asset filename
    # prefix (text before the first "-" of the first "_"-separated token);
    # assumes backend asset naming like "<type>-....tif" — verify against
    # the mapping pipeline's filename_prefix conventions.
    assets = job_result.get_assets()
    products = {}
    for asset in assets:
        asset_name = asset.name.split(".")[0].split("_")[0]
        asset_type = asset_name.split("-")[0]
        asset_type = getattr(WorldCerealProductType, asset_type.upper())
        if output_dir is not None:
            filepath = asset.download(target=output_dir)
        else:
            filepath = None
        products[asset_name] = {
            "url": asset.href,
            "type": asset_type,
            "temporal_extent": temporal_extent,
            "path": filepath,
            "lut": luts[asset_type.value],
        }

    # Download job metadata if output path is provided
    if output_dir is not None:
        metadata_file = output_dir / "job-results.json"
        metadata_file.write_text(json.dumps(job_result.get_metadata()))
    else:
        metadata_file = None

    # Compile InferenceResults and return
    return InferenceResults(
        job_id=job.job_id, products=products, metadata=metadata_file
    )
684
+
685
+
686
def collect_inputs(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    output_path: Union[Path, str],
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    job_options: Optional[dict] = None,
    compositing_window: Literal["month", "dekad"] = "month",
):
    """Function to retrieve preprocessed inputs that are being
    used in the generation of WorldCereal products.

    Runs a blocking batch job on the OpenEO backend and downloads the
    resulting NetCDF file to `output_path`.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    output_path : Union[Path, str]
        output path to download the product to
    backend_context : BackendContext
        backend to run the job on, by default CDSE
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128
        so it uses the OpenEO default setting.
    job_options: dict, optional
        Additional job options to pass to the OpenEO backend, by default None
    compositing_window: Literal["month", "dekad"]
        Compositing window to use for the data loading in OpenEO, by default
        "month".
    """

    # Make a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for the inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        validate_temporal_context=False,
        compositing_window=compositing_window,
    )

    # Spatial filtering
    inputs = inputs.filter_bbox(dict(spatial_extent))

    # IDIOM FIX: this is a mutable per-call local, not a module constant,
    # so it should not carry an UPPER_SNAKE_CASE name.
    batch_job_options = {
        "driver-memory": "4g",
        "executor-memory": "1g",
        "executor-memoryOverhead": "1g",
        "python-memory": "3g",
        "soft-errors": 0.1,
    }
    if job_options is not None:
        batch_job_options.update(job_options)

    # Blocking call: submits the job, waits for it to finish and downloads
    # the result to `output_path`.
    inputs.execute_batch(
        outputfile=output_path,
        out_format="NetCDF",
        title="WorldCereal [collect_inputs] job",
        description="Job that collects inputs for WorldCereal inference",
        job_options=batch_job_options,
    )
752
+
753
+
754
def run_largescale_inference(
    production_grid: Union[Path, gpd.GeoDataFrame],
    output_dir: Union[Path, str],
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    target_epsg: Optional[int] = None,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    job_options: Optional[dict] = None,
    parallel_jobs: int = 2,
):
    """Run large-scale WorldCereal inference jobs on the OpenEO backend.

    Thin convenience wrapper: delegates all setup to
    `setup_inference_job_manager` and then lets the resulting job manager
    create, track and execute one batch job per tile in the production grid.

    Parameters
    ----------
    production_grid : Union[Path, gpd.GeoDataFrame]
        Path to the production grid file in Parquet format or a GeoDataFrame.
        The grid must contain the required attributes: 'start_date', 'end_date',
        'geometry', 'tile_name', 'epsg' and 'bounds_epsg'.
    output_dir : Union[Path, str]
        Directory where output files and job tracking information will be stored.
    product_type : WorldCerealProductType
        Type of product to generate. Defaults to WorldCerealProductType.CROPLAND.
    cropland_parameters : CropLandParameters
        Parameters for cropland inference.
    croptype_parameters : CropTypeParameters
        Parameters for crop type inference.
    backend_context : BackendContext
        Context for the backend to use. Defaults to BackendContext(Backend.CDSE).
    target_epsg : Optional[int]
        EPSG code for the target coordinate reference system.
        If None, no reprojection will be performed.
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use ('ASCENDING' or 'DESCENDING').
        If None, no specific orbit state is enforced.
    job_options : Optional[dict]
        Additional options for configuring the inference jobs. Defaults to None.
    parallel_jobs : int
        Number of parallel jobs to manage on the backend. Defaults to 2. Note
        that load balancing does not guarantee that all jobs will run in
        parallel.

    Returns
    -------
    None
    """

    # Prepare the manager, the CSV-backed job database and the per-row
    # job-creation callable.
    manager, database, job_starter = setup_inference_job_manager(
        production_grid=production_grid,
        output_dir=output_dir,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        backend_context=backend_context,
        target_epsg=target_epsg,
        s1_orbit_state=s1_orbit_state,
        job_options=job_options,
        parallel_jobs=parallel_jobs,
    )

    # Hand the prepared jobs over to the manager, which launches and tracks
    # them until every job reaches a terminal state.
    manager.run_jobs(
        df=database.df,
        start_job=job_starter,
        job_db=database.path,
    )

    logger.info("Job manager finished.")
829
+
830
+
831
def setup_inference_job_manager(
    production_grid: Union[Path, gpd.GeoDataFrame],
    output_dir: Union[Path, str],
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    target_epsg: Optional[int] = None,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    job_options: Optional[dict] = None,
    parallel_jobs: int = 2,
) -> tuple[InferenceJobManager, CsvJobDatabase, Callable]:
    """
    Prepare large-scale inference jobs on the OpenEO backend.

    This function sets up the job manager, creates job tracking information,
    and defines the job creation function for WorldCereal inference jobs.
    Used in the WorldCereal demo notebooks.

    Parameters
    ----------
    production_grid : Union[Path, gpd.GeoDataFrame]
        Path to the production grid file in Parquet format (a `str` path is
        also accepted) or a GeoDataFrame.
        The grid must contain the required attributes: 'start_date', 'end_date',
        'geometry', 'tile_name', 'epsg' and 'bounds_epsg'.
    output_dir : Union[Path, str]
        Directory where output files and job tracking information will be stored.
    product_type : WorldCerealProductType
        Type of product to generate. Defaults to WorldCerealProductType.CROPLAND.
    cropland_parameters : CropLandParameters
        Parameters for cropland inference.
    croptype_parameters : CropTypeParameters
        Parameters for crop type inference.
    backend_context : BackendContext
        Context for the backend to use. Defaults to BackendContext(Backend.CDSE).
    target_epsg : Optional[int]
        EPSG code for the target coordinate reference system.
        If None, no reprojection will be performed.
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use ('ASCENDING' or 'DESCENDING').
        If None, no specific orbit state is enforced.
    job_options : Optional[dict]
        Additional options for configuring the inference jobs. Defaults to None.
    parallel_jobs : int
        Number of parallel jobs to manage on the backend. Defaults to 2. Note that load
        balancing does not guarantee that all jobs will run in parallel.

    Returns
    -------
    tuple[InferenceJobManager, CsvJobDatabase, Callable]
        A tuple containing:
        - InferenceJobManager: The job manager for handling inference jobs.
        - CsvJobDatabase: The job database for tracking job information.
        - Callable: A function to create individual inference jobs.

    Raises
    ------
    ValueError
        If the production grid is of an unsupported type, misses required
        attributes, or yields no jobs to run.
    """

    # Setup output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Make a connection to the OpenEO backend
    backend = backend_context.backend
    connection = BACKEND_CONNECTIONS[backend]()

    # Setup the job manager
    logger.info("Setting up the job manager.")
    manager = InferenceJobManager(root_dir=output_dir)
    manager.add_backend(
        backend.value, connection=connection, parallel_jobs=parallel_jobs
    )

    # Configure job tracking CSV file
    job_tracking_csv = output_dir / "job_tracking.csv"

    job_db = CsvJobDatabase(path=job_tracking_csv)
    if not job_db.exists():
        logger.info("Job tracking file does not exist, creating new jobs.")

        # Accepting plain string paths is a backward-compatible convenience.
        if isinstance(production_grid, (str, Path)):
            production_gdf = gpd.read_parquet(production_grid)
        elif isinstance(production_grid, gpd.GeoDataFrame):
            production_gdf = production_grid
        else:
            raise ValueError("production_grid must be a Path or a GeoDataFrame.")

        REQUIRED_ATTRIBUTES = [
            "start_date",
            "end_date",
            "geometry",
            "tile_name",
            "epsg",
            "bounds_epsg",
        ]
        # ROBUSTNESS FIX: validate with an explicit exception instead of
        # `assert`, which is silently stripped when Python runs with -O.
        # Report all missing columns at once.
        missing = [
            attr for attr in REQUIRED_ATTRIBUTES if attr not in production_gdf.columns
        ]
        if missing:
            raise ValueError(
                f"The production grid is missing required column(s): {missing}"
            )

        job_df = production_gdf[REQUIRED_ATTRIBUTES].copy()

        df = manager._normalize_df(job_df)
        # Save the job tracking DataFrame to the job database
        job_db.persist(df)

    else:
        logger.info("Job tracking file already exists, skipping job creation.")

    # Define the job creation function; per-row arguments are supplied later
    # by the job manager, the static configuration is bound here.
    start_job = partial(
        create_inference_job,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        s1_orbit_state=s1_orbit_state,
        job_options=job_options,
        target_epsg=target_epsg,
    )

    # Check if there are jobs to run
    if job_db.df.empty:
        logger.warning("No jobs to run. The job tracking CSV is empty.")
        raise ValueError(
            "No jobs to run. The job tracking CSV is empty. "
            "Please check the production grid and ensure it contains valid data."
        )

    return manager, job_db, start_job
worldcereal/openeo/__init__.py ADDED
File without changes
worldcereal/openeo/feature_extractor.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """openEO UDF to compute Presto/Prometheo features."""
2
+
3
+ import copy
4
+ import functools
5
+ import logging
6
+ import random
7
+ import sys
8
+ import urllib.request
9
+ import zipfile
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import numpy as np
14
+ import xarray as xr
15
+ from openeo.metadata import CollectionMetadata
16
+ from openeo.udf import XarrayDataCube
17
+ from openeo.udf.udf_data import UdfData
18
+ from pyproj import Transformer
19
+ from pyproj.crs import CRS
20
+ from scipy.ndimage import (
21
+ convolve,
22
+ zoom,
23
+ )
24
+ from shapely.geometry import Point
25
+ from shapely.ops import transform
26
+
27
+ sys.path.append("feature_deps")
28
+
29
+ import torch # noqa: E402
30
+
31
+ PROMETHEO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/prometheo-0.0.3-py3-none-any.whl"
32
+
33
+ GFMAP_BAND_MAPPING = {
34
+ "S2-L2A-B02": "B2",
35
+ "S2-L2A-B03": "B3",
36
+ "S2-L2A-B04": "B4",
37
+ "S2-L2A-B05": "B5",
38
+ "S2-L2A-B06": "B6",
39
+ "S2-L2A-B07": "B7",
40
+ "S2-L2A-B08": "B8",
41
+ "S2-L2A-B8A": "B8A",
42
+ "S2-L2A-B11": "B11",
43
+ "S2-L2A-B12": "B12",
44
+ "S1-SIGMA0-VH": "VH",
45
+ "S1-SIGMA0-VV": "VV",
46
+ "AGERA5-TMEAN": "temperature_2m",
47
+ "AGERA5-PRECIP": "total_precipitation",
48
+ }
49
+
50
+ LAT_HARMONIZED_NAME = "GEO-LAT"
51
+ LON_HARMONIZED_NAME = "GEO-LON"
52
+ EPSG_HARMONIZED_NAME = "GEO-EPSG"
53
+
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
@functools.lru_cache(maxsize=1)
def unpack_prometheo_wheel(wheel_url: str) -> Path:
    """Download the Prometheo wheel and unpack it into a local dependency dir.

    The ``lru_cache`` decorator ensures the download/extract happens at most
    once per worker process for a given URL.

    Parameters
    ----------
    wheel_url : str
        HTTP(S) URL of the Prometheo ``.whl`` file.

    Returns
    -------
    Path
        Directory containing the unpacked wheel contents, suitable for
        appending to ``sys.path``.
    """
    destination_dir = Path.cwd() / "dependencies" / "prometheo"
    destination_dir.mkdir(exist_ok=True, parents=True)

    # Download the wheel next to the working directory, extract it, then
    # remove the archive: only the extracted contents are needed afterwards
    # (the original implementation left the .whl file behind).
    wheel_path = Path.cwd() / Path(wheel_url).name
    urllib.request.urlretrieve(wheel_url, filename=wheel_path)
    try:
        with zipfile.ZipFile(wheel_path, "r") as zip_ref:
            zip_ref.extractall(destination_dir)
    finally:
        wheel_path.unlink(missing_ok=True)
    return destination_dir
70
+
71
+
72
@functools.lru_cache(maxsize=1)
def compile_encoder(presto_encoder):
    """Helper function that compiles the encoder of a Presto model
    and performs a warm-up on dummy data. The lru_cache decorator
    ensures caching on compute nodes to be able to actually benefit
    from the compilation process.

    Parameters
    ----------
    presto_encoder : nn.Module
        Encoder part of Presto model to compile

    Returns
    -------
    The compiled encoder module (same object type as returned by
    ``torch.compile``).
    """

    presto_encoder = torch.compile(presto_encoder)  # type: ignore

    # Warm-up: run three forward passes on dummy inputs so the compiled
    # kernels are generated before real inference.
    # NOTE(review): the dummy shapes assume a (batch=1, timesteps=12,
    # bands=17) input plus (1, 12) dynamic-world labels and (1, 2) latlons —
    # confirm these match the deployed Presto encoder signature.
    for _ in range(3):
        presto_encoder(
            torch.rand((1, 12, 17)),
            torch.ones((1, 12)).long(),
            torch.rand(1, 2),
        )

    return presto_encoder
96
+
97
+
98
def evaluate_resolution(inarr: xr.DataArray, epsg: int) -> int:
    """Helper function to get the resolution in meters for
    the input array.

    Parameters
    ----------
    inarr : xr.DataArray
        input array to determine resolution for.
    epsg : int
        EPSG code of the array's CRS. For lat/lon input (EPSG:4326) the
        first two grid points are reprojected to EPSG:3857 so the spacing
        can be expressed in meters.

    Returns
    -------
    int
        resolution in meters.
    """

    if epsg == 4326:
        logger.info(
            "Converting WGS84 coordinates to EPSG:3857 to determine resolution."
        )

        transformer = Transformer.from_crs(epsg, 3857, always_xy=True)
        # Only the first two grid points are needed to measure the x-spacing;
        # the original implementation built and reprojected a Point for every
        # coordinate pair, which is O(n) wasted work on large tiles.
        points = [
            transform(transformer.transform, Point(x, y))
            for x, y in zip(inarr.x.values[:2], inarr.y.values[:2])
        ]

        resolution = abs(points[1].x - points[0].x)

    else:
        # Projected CRS: the coordinate spacing is already in meters.
        resolution = abs(inarr.x[1].values - inarr.x[0].values)

    logger.info(f"Resolution for computing slope: {resolution}")

    return resolution
130
+
131
+
132
def compute_slope(inarr: xr.DataArray, resolution: int) -> xr.DataArray:
    """Computes the slope using the scipy library. The input array should
    have the following bands: 'elevation' And no time dimension. Returns a
    new DataArray containing the new `slope` band.

    Pipeline: fill nodata (65535) gaps, downsample the DEM to 20 m, compute
    Sobel-style gradients, convert to slope in degrees, then upsample back
    to the original resolution. Invalid pixels are set to 65535 in the
    output and the result is cast to uint16.

    Parameters
    ----------
    inarr : xr.DataArray
        input array containing a band 'elevation'.
    resolution : int
        resolution of the input array in meters.

    Returns
    -------
    xr.DataArray
        output array containing 'slope' band in degrees.

    Raises
    ------
    NotImplementedError
        If the 20m/resolution downsampling factor is < 1 or odd.
        NOTE(review): this also rejects factor == 1 (native 20 m input) —
        confirm that is intended.
    """

    def _rolling_fill(darr, max_iter=2):
        """Helper function that also reflects values inside
        a patch with NaNs."""
        # Recursively fill NaNs by copying values rolled in from the four
        # neighbouring directions (shuffled to avoid directional bias),
        # stopping early when no NaNs remain or after max_iter passes.
        if max_iter == 0:
            return darr
        else:
            max_iter -= 1
            # arr of shape (rows, cols)
            mask = np.isnan(darr)

            if ~np.any(mask):
                return darr

            roll_params = [(0, 1), (0, -1), (1, 0), (-1, 0)]
            random.shuffle(roll_params)

            for roll_param in roll_params:
                rolled = np.roll(darr, roll_param, axis=(0, 1))
                darr[mask] = rolled[mask]

            return _rolling_fill(darr, max_iter=max_iter)

    def _downsample(arr: np.ndarray, factor: int) -> np.ndarray:
        """Downsamples a 2D NumPy array by a given factor with average resampling and reflect padding.

        Parameters
        ----------
        arr : np.ndarray
            The 2D input array.
        factor : int
            The factor by which to downsample. For example, factor=2 downsamples by 2x.

        Returns
        -------
        np.ndarray
            Downsampled array.
        """

        # Get the original shape of the array
        X, Y = arr.shape

        # Calculate how much padding is needed for each dimension
        pad_X = (
            factor - (X % factor)
        ) % factor  # Ensures padding is only applied if needed
        pad_Y = (
            factor - (Y % factor)
        ) % factor  # Ensures padding is only applied if needed

        # Pad the array using 'reflect' mode
        padded = np.pad(arr, ((0, pad_X), (0, pad_Y)), mode="reflect")

        # Reshape the array to form blocks of size 'factor' x 'factor'
        reshaped = padded.reshape(
            (X + pad_X) // factor, factor, (Y + pad_Y) // factor, factor
        )

        # Take the mean over the factor-sized blocks
        downsampled = np.nanmean(reshaped, axis=(1, 3))

        return downsampled

    dem = inarr.sel(bands="elevation").values
    dem_arr = dem.astype(np.float32)

    # Invalid to NaN and keep track of these pixels
    dem_arr[dem_arr == 65535] = np.nan
    idx_invalid = np.isnan(dem_arr)

    # Fill NaNs with rolling fill
    dem_arr = _rolling_fill(dem_arr)

    # We make sure DEM is at 20m for slope computation
    # compatible with global slope collection
    factor = int(20 / resolution)
    if factor < 1 or factor % 2 != 0:
        raise NotImplementedError(
            f"Unsupported resolution for slope computation: {resolution}"
        )
    dem_arr_downsampled = _downsample(dem_arr, factor)
    # Remember odd original dimensions so the upsampled result can be
    # trimmed back to the exact input shape.
    x_odd, y_odd = dem_arr.shape[0] % 2 != 0, dem_arr.shape[1] % 2 != 0

    # Mask NaN values in the DEM data
    dem_masked = np.ma.masked_invalid(dem_arr_downsampled)

    # Define convolution kernels for x and y gradients (simple finite difference approximation)
    kernel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / (
        8.0 * 20  # array is now at 20m resolution
    )  # x-derivative kernel

    kernel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / (
        8.0 * 20  # array is now at 20m resolution
    )  # y-derivative kernel

    # Apply convolution to compute gradients
    dx = convolve(dem_masked, kernel_x)  # Gradient in the x-direction
    dy = convolve(dem_masked, kernel_y)  # Gradient in the y-direction

    # Reapply the mask to the gradients
    dx = np.ma.masked_where(dem_masked.mask, dx)
    dy = np.ma.masked_where(dem_masked.mask, dy)

    # Calculate the magnitude of the gradient (rise/run)
    gradient_magnitude = np.ma.sqrt(dx**2 + dy**2)

    # Convert gradient magnitude to slope (in degrees)
    slope = np.ma.arctan(gradient_magnitude) * (180 / np.pi)

    # Upsample to original resolution with bilinear interpolation
    # (nearest-neighbour for the validity mask so it stays boolean).
    mask = slope.mask
    mask = zoom(mask, zoom=factor, order=0)
    slope = zoom(slope, zoom=factor, order=1)
    slope[mask] = 65535

    # Strip one row or column if original array was odd in that dimension
    if x_odd:
        slope = slope[:-1, :]
    if y_odd:
        slope = slope[:, :-1]

    # Fill slope values where the original DEM had NaNs
    slope[idx_invalid] = 65535
    slope[np.isnan(slope)] = 65535
    slope = slope.astype(np.uint16)

    return xr.DataArray(
        slope[None, :, :],
        dims=("bands", "y", "x"),
        coords={
            "bands": ["slope"],
            "y": inarr.y,
            "x": inarr.x,
        },
    )
284
+
285
+
286
def select_timestep_from_temporal_features(
    features: xr.DataArray, target_date: Optional[str] = None
) -> xr.DataArray:
    """Pick one timestep out of a temporal feature cube.

    Parameters
    ----------
    features : xr.DataArray
        Temporal features with time dimension preserved.
    target_date : str, optional
        Target date in ISO format (YYYY-MM-DD). If None, selects middle timestep.

    Returns
    -------
    xr.DataArray
        Features for the selected timestep with time dimension removed.

    Raises
    ------
    ValueError
        If `target_date` falls outside the cube's temporal extent.
    """
    if target_date is None:
        # No date requested: take the central timestep.
        return features.isel(t=len(features.t) // 2)

    requested = np.datetime64(target_date)

    # Reject dates outside the available time range before snapping.
    earliest = features.t.min().values
    latest = features.t.max().values
    if not (earliest <= requested <= latest):
        raise ValueError(
            f"Target date {target_date} is outside the temporal extent of features. "
            f"Available time range: {earliest} to {latest}"
        )

    # Snap to the nearest available timestep.
    return features.sel(t=requested, method="nearest")
325
+
326
+
327
def extract_presto_embeddings(
    inarr: xr.DataArray, parameters: dict, epsg: int
) -> xr.DataArray:
    """Executes the feature extraction process on the input array.

    Parameters
    ----------
    inarr : xr.DataArray
        Input cube with dims (bands, t, y, x); exactly 12 timesteps required.
    parameters : dict
        UDF user context. Recognized keys: ``presto_model_url`` (required),
        ``prometheo_wheel_url``, ``ignore_dependencies``, ``batch_size``,
        ``temporal_prediction``, ``target_date``.
    epsg : int
        EPSG code of the input cube's CRS.

    Returns
    -------
    xr.DataArray
        Presto feature cube with dims (bands, y, x).

    Raises
    ------
    ValueError
        If `epsg` is None, ``presto_model_url`` is missing, or the input
        does not have exactly 12 timesteps.
    """

    if epsg is None:
        raise ValueError(
            "EPSG code is required for Presto feature extraction, but was "
            "not correctly initialized."
        )
    if "presto_model_url" not in parameters:
        raise ValueError('Missing required parameter "presto_model_url"')

    presto_model_url = parameters.get("presto_model_url")
    logger.info(f'Loading Presto model from "{presto_model_url}"')
    prometheo_wheel_url = parameters.get("prometheo_wheel_url", PROMETHEO_WHL_URL)
    logger.info(f'Loading Prometheo wheel from "{prometheo_wheel_url}"')

    ignore_dependencies = parameters.get("ignore_dependencies", False)
    if ignore_dependencies:
        logger.info(
            "`ignore_dependencies` flag is set to True. Make sure that "
            "Presto and its dependencies are available on the runtime "
            "environment"
        )

    # The below is required to avoid flipping of the result
    # when running on OpenEO backend!
    inarr = inarr.transpose(
        "bands", "t", "x", "y"
    )  # Presto/Prometheo expects xy dimension order

    # Change the band names
    new_band_names = [GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands]
    inarr = inarr.assign_coords(bands=new_band_names)

    # Log pixel statistics
    total_pixels = inarr.size
    num_nan_pixels = np.isnan(inarr.values).sum()
    num_zero_pixels = (inarr.values == 0).sum()
    num_nodatavalue_pixels = (inarr.values == 65535).sum()
    logger.info("Band names: " + ", ".join(inarr.bands.values))
    logger.debug(
        f"Array dtype: {inarr.dtype}, "
        f"Array size: {inarr.shape}, total pixels: {total_pixels}, "
        f"Pixel statistics: NaN pixels = {num_nan_pixels} "
        f"({num_nan_pixels / total_pixels * 100:.2f}%), "
        f"0 pixels = {num_zero_pixels} "
        f"({num_zero_pixels / total_pixels * 100:.2f}%), "
        f"NODATAVALUE pixels = {num_nodatavalue_pixels} "
        f"({num_nodatavalue_pixels / total_pixels * 100:.2f}%)"
    )

    # Log mean value (ignoring NaNs) per band
    for band in inarr.bands.values:
        band_data = inarr.sel(bands=band).values
        mean_value = np.nanmean(band_data)
        logger.debug(f"Band '{band}': Mean value (ignoring NaNs) = {mean_value:.2f}")

    # Handle NaN values in Presto compatible way
    # (Presto treats 65535 as its nodata sentinel)
    inarr = inarr.fillna(65535)

    if not ignore_dependencies:
        # Unzip the Presto dependencies on the backend
        logger.info("Unpacking prometheo wheel")
        deps_dir = unpack_prometheo_wheel(prometheo_wheel_url)

        logger.info("Appending dependencies")
        sys.path.append(str(deps_dir))

    if "slope" not in inarr.bands:
        # If 'slope' is not present we need to compute it here
        logger.warning("`slope` not found in input array. Computing ...")
        resolution = evaluate_resolution(inarr.isel(t=0), epsg)
        slope = compute_slope(inarr.isel(t=0), resolution)
        # Slope is static in time: broadcast it across all timesteps.
        slope = slope.expand_dims({"t": inarr.t}, axis=0).astype("float32")

        inarr = xr.concat([inarr.astype("float32"), slope], dim="bands")

    batch_size = parameters.get("batch_size", 256)
    temporal_prediction = parameters.get("temporal_prediction", False)
    target_date = parameters.get("target_date", None)
    logger.info(
        (
            f"Extracting Presto features with batch size {batch_size}, "
            f"temporal_prediction={temporal_prediction}, "
            f"target_date={target_date}"
        )
    )

    # TODO: compile_presto not used for now?
    # compile_presto = parameters.get("compile_presto", False)
    # self.logger.info(f"Compile presto: {compile_presto}")

    logger.info("Loading Presto model for inference")

    # Imports are deferred: prometheo only becomes importable after the
    # wheel is unpacked and appended to sys.path above.
    # TODO: try to take run_model_inference from worldcereal
    from prometheo.datasets.worldcereal import run_model_inference
    from prometheo.models import Presto
    from prometheo.models.pooling import PoolingMethods
    from prometheo.models.presto.wrapper import load_presto_weights

    presto_model = Presto()
    presto_model = load_presto_weights(presto_model, presto_model_url)

    logger.info("Extracting presto features")
    # Check if we have the expected 12 timesteps
    if len(inarr.t) != 12:
        raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(inarr.t)}")

    # Determine pooling method based on temporal_prediction parameter:
    # TIME keeps per-timestep features, GLOBAL pools over time.
    pooling_method = (
        PoolingMethods.TIME if temporal_prediction else PoolingMethods.GLOBAL
    )
    logger.info(f"Using pooling method: {pooling_method}")

    features = run_model_inference(
        inarr,
        presto_model,
        epsg=epsg,
        batch_size=batch_size,
        pooling_method=pooling_method,
    )

    # If temporal prediction, select specific timestep based on target_date
    if temporal_prediction:
        features = select_timestep_from_temporal_features(features, target_date)

    features = features.transpose(
        "bands", "y", "x"
    )  # openEO expects yx order after the UDF

    return features
460
+
461
+
462
def get_latlons(inarr: xr.DataArray, epsg: int) -> xr.DataArray:
    """Returns the latitude and longitude coordinates of the given array in
    a dataarray. Returns a dataarray with the same width/height of the input
    array, but with two bands, one for latitude and one for longitude. The
    metadata coordinates of the output array are the same as the input
    array, as the array wasn't reprojected but instead new features were
    computed.

    The latitude and longitude band names are standardized to the names
    `LAT_HARMONIZED_NAME` and `LON_HARMONIZED_NAME` respectively.

    Raises
    ------
    ValueError
        If `epsg` is None, as the CRS is then unknown.
    """

    lon = inarr.coords["x"]
    lat = inarr.coords["y"]
    lon, lat = np.meshgrid(lon, lat)

    if epsg is None:
        # Fix: raise ValueError instead of a bare Exception, consistent with
        # the missing-EPSG check in `extract_presto_embeddings`. ValueError
        # is a subclass of Exception, so existing callers remain compatible.
        raise ValueError(
            "EPSG code was not defined, cannot extract lat/lon array "
            "as the CRS is unknown."
        )

    # If the coordinates are not in EPSG:4326, we need to reproject them
    if epsg != 4326:
        # Initializes a pyproj reprojection object
        transformer = Transformer.from_crs(
            crs_from=CRS.from_epsg(epsg),
            crs_to=CRS.from_epsg(4326),
            always_xy=True,
        )
        lon, lat = transformer.transform(xx=lon, yy=lat)

    # Create a two channel numpy array of the lat and lons together by stacking
    latlon = np.stack([lat, lon])

    # Repack in a dataarray
    return xr.DataArray(
        latlon,
        dims=["bands", "y", "x"],
        coords={
            "bands": [LAT_HARMONIZED_NAME, LON_HARMONIZED_NAME],
            "y": inarr.coords["y"],
            "x": inarr.coords["x"],
        },
    )
507
+
508
+
509
def rescale_s1_backscatter(arr: xr.DataArray) -> xr.DataArray:
    """Rescales the input array from uint16 to float32 decibel values.
    The input array should be in uint16 format, as this optimizes memory usage in Open-EO
    processes. This function is called automatically on the bands of the input array, except
    if the parameter `rescale_s1` is set to False.

    Only the S1 sigma0 bands present in `arr` are rescaled; all other bands
    are left untouched. Returns the same array object with the S1 bands
    replaced in place.
    """
    s1_bands = ["S1-SIGMA0-VV", "S1-SIGMA0-VH", "S1-SIGMA0-HV", "S1-SIGMA0-HH"]
    s1_bands_to_select = list(set(arr.bands.values) & set(s1_bands))

    # Nothing to do if the array carries no S1 bands.
    if len(s1_bands_to_select) == 0:
        return arr

    data_to_rescale = arr.sel(bands=s1_bands_to_select).astype(np.float32).data

    # Assert that the values are set between 1 and 65535
    if data_to_rescale.min().item() < 1 or data_to_rescale.max().item() > 65535:
        raise ValueError(
            "The input array should be in uint16 format, with values between 1 and 65535. "
            "This restriction assures that the data was processed according to the S1 fetcher "
            "preprocessor. The user can disable this scaling manually by setting the "
            "`rescale_s1` parameter to False in the feature extractor."
        )

    # Converting back to power values
    # (uint16 encodes dB as 20*log10(x) - 83; going through the power domain
    # lets non-finite results be flagged as NaN before the final dB step)
    data_to_rescale = 20.0 * np.log10(data_to_rescale) - 83.0
    data_to_rescale = np.power(10, data_to_rescale / 10.0)
    data_to_rescale[~np.isfinite(data_to_rescale)] = np.nan

    # Converting power values to decibels
    data_to_rescale = 10.0 * np.log10(data_to_rescale)

    # Change the bands within the array
    arr.loc[dict(bands=s1_bands_to_select)] = data_to_rescale
    return arr
543
+
544
+
545
+ # Below comes the actual UDF part
546
+
547
+
548
+ # Apply the Feature Extraction UDF
549
def apply_udf_data(udf_data: UdfData) -> UdfData:
    """This is the actual openeo UDF that will be executed by the backend.

    Reads the first datacube from `udf_data`, resolves the EPSG code from
    the UDF projection info, optionally rescales S1 backscatter, extracts
    Presto embeddings, and writes the resulting cube back into `udf_data`.
    """

    cube = udf_data.datacube_list[0]
    # Deep copy: the parameters dict is mutated below (pop) and must not
    # alter the backend-provided user context.
    parameters = copy.deepcopy(udf_data.user_context)

    proj = udf_data.proj
    if proj is not None:
        proj = proj["EPSG"]

    parameters[EPSG_HARMONIZED_NAME] = proj

    arr = cube.get_array().transpose("bands", "t", "y", "x")

    epsg = parameters.pop(EPSG_HARMONIZED_NAME)
    logger.info(f"EPSG code determined for feature extraction: {epsg}")

    # S1 rescaling is on by default; users can disable it via `rescale_s1`.
    if parameters.get("rescale_s1", True):
        arr = rescale_s1_backscatter(arr)

    arr = extract_presto_embeddings(inarr=arr, parameters=parameters, epsg=epsg)

    cube = XarrayDataCube(arr)

    udf_data.datacube_list = [cube]

    return udf_data
576
+
577
+
578
+ # Change band names
579
def apply_metadata(metadata: CollectionMetadata, context: dict) -> CollectionMetadata:
    """Rename the output band labels to the 128 Presto feature names."""
    presto_band_names = [f"presto_ft_{band_idx}" for band_idx in range(128)]
    return metadata.rename_labels(dimension="bands", target=presto_band_names)
worldcereal/openeo/inference.py ADDED
@@ -0,0 +1,1191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """openEO UDF to compute Presto/Prometheo features with clean code structure."""
2
+
3
+ import logging
4
+ import os
5
+ import random
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import requests
12
+ import xarray as xr
13
+ from openeo.udf import XarrayDataCube
14
+ from openeo.udf.udf_data import UdfData
15
+ from pyproj import Transformer
16
+ from scipy.ndimage import convolve, zoom
17
+ from shapely.geometry import Point
18
+ from shapely.ops import transform
19
+
20
# Prefer loguru for logging if available; otherwise fall back to the
# standard library. When loguru is present, the root stdlib logger is
# rerouted into loguru so third-party log records share one sink.
try:
    from loguru import logger

    # Reset loguru's default sink and log INFO+ to stderr.
    logger.remove()
    logger.add(sys.stderr, level="INFO")

    class InterceptHandler(logging.Handler):
        """Bridge stdlib logging records into loguru."""

        def emit(self, record):
            level = record.levelname
            # depth=6 skips the stdlib logging frames so loguru reports the
            # original call site. NOTE(review): the right depth depends on
            # the logging call chain — confirm call sites are attributed
            # correctly in backend logs.
            logger.opt(depth=6).log(level, record.getMessage())

    # Replace existing handlers
    for h in logging.root.handlers[:]:
        logging.root.removeHandler(h)

    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(InterceptHandler())

except ImportError:
    # loguru not available, use standard logging
    logger = logging.getLogger(__name__)
41
+
42
+ _MODULE_CACHE_KEY = f"__model_cache_{__name__}"
43
+
44
+ # Constants
45
+ PROMETHEO_WHL_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/prometheo-0.0.3-py3-none-any.whl"
46
+
47
+ GFMAP_BAND_MAPPING = {
48
+ "S2-L2A-B02": "B2",
49
+ "S2-L2A-B03": "B3",
50
+ "S2-L2A-B04": "B4",
51
+ "S2-L2A-B05": "B5",
52
+ "S2-L2A-B06": "B6",
53
+ "S2-L2A-B07": "B7",
54
+ "S2-L2A-B08": "B8",
55
+ "S2-L2A-B8A": "B8A",
56
+ "S2-L2A-B11": "B11",
57
+ "S2-L2A-B12": "B12",
58
+ "S1-SIGMA0-VH": "VH",
59
+ "S1-SIGMA0-VV": "VV",
60
+ "AGERA5-TMEAN": "temperature_2m",
61
+ "AGERA5-PRECIP": "total_precipitation",
62
+ }
63
+
64
+ LAT_HARMONIZED_NAME = "GEO-LAT"
65
+ LON_HARMONIZED_NAME = "GEO-LON"
66
+ EPSG_HARMONIZED_NAME = "GEO-EPSG"
67
+
68
+ S1_BANDS = ["S1-SIGMA0-VV", "S1-SIGMA0-VH", "S1-SIGMA0-HV", "S1-SIGMA0-HH"]
69
+ NODATA_VALUE = 65535
70
+
71
+ POSTPROCESSING_EXCLUDED_VALUES = [254, 255, 65535]
72
+ POSTPROCESSING_NODATA = 255
73
+
74
+ NUM_THREADS = 2
75
+
76
+ sys.path.append("feature_deps")
77
+ sys.path.append("onnx_deps")
78
+ import onnxruntime as ort # noqa: E402
79
+
80
+ _PROMETHEO_INSTALLED = False
81
+
82
+ # Global variables for Prometheo imports
83
+ Presto = None
84
+ load_presto_weights = None
85
+ run_model_inference = None
86
+ PoolingMethods = None
87
+
88
+
89
+ # =============================================================================
90
+ # STANDALONE FUNCTIONS (Work in both apply_udf_data and apply_metadata contexts)
91
+ # =============================================================================
92
def get_model_cache():
    """Return the process-wide model cache dict, creating it on first use.

    The cache is stashed as an attribute on the ``sys`` module so it
    survives re-imports of this UDF module within the same worker process.
    """
    cache = getattr(sys, _MODULE_CACHE_KEY, None)
    if cache is None:
        cache = {}
        setattr(sys, _MODULE_CACHE_KEY, cache)
    return cache
97
+
98
+
99
def _ensure_prometheo_dependencies():
    """Non-cached dependency check.

    Makes the prometheo package importable, installing it from the wheel if
    necessary, and rebinds the module-level globals (`Presto`,
    `load_presto_weights`, `run_model_inference`, `PoolingMethods`) so
    other functions in this module can use them. Idempotent via the
    `_PROMETHEO_INSTALLED` flag.
    """
    # The `global` declaration makes the `from ... import` statements below
    # bind the module-level names instead of function locals.
    global _PROMETHEO_INSTALLED, Presto, load_presto_weights, run_model_inference, PoolingMethods

    if _PROMETHEO_INSTALLED:
        return

    try:
        # Try to import first
        from prometheo.datasets.worldcereal import run_model_inference
        from prometheo.models import Presto
        from prometheo.models.pooling import PoolingMethods
        from prometheo.models.presto.wrapper import load_presto_weights

        # They're now available in the global scope
        _PROMETHEO_INSTALLED = True
        return
    except ImportError:
        pass

    # Installation required
    logger.info("Prometheo not available, installing...")
    _install_prometheo()

    # Import immediately after installation - these will be available globally
    from prometheo.datasets.worldcereal import run_model_inference
    from prometheo.models import Presto
    from prometheo.models.pooling import PoolingMethods
    from prometheo.models.presto.wrapper import load_presto_weights

    # Prometheo pulls in torch; configure its CPU threading once here.
    optimize_pytorch_cpu_performance(NUM_THREADS)
    _PROMETHEO_INSTALLED = True
131
+
132
+
133
def _install_prometheo():
    """Non-cached installation function.

    Downloads the prometheo wheel, extracts it into a fresh temp directory
    and appends that directory to ``sys.path``. On failure the temp
    directory is removed and the exception re-raised.
    """
    import shutil
    import tempfile
    import urllib.request
    import zipfile

    temp_dir = Path(tempfile.mkdtemp())
    try:
        # Download wheel
        # NOTE(review): urlretrieve without a filename stores the wheel in a
        # system temp file that is never removed here — confirm cleanup is
        # acceptable to leave to the OS/worker lifecycle.
        wheel_path, _ = urllib.request.urlretrieve(PROMETHEO_WHL_URL)

        # Extract to temp directory
        with zipfile.ZipFile(wheel_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir)

        # Add to Python path
        sys.path.append(str(temp_dir))
        logger.info(f"Prometheo installed to {temp_dir}.")

    except Exception as e:
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        logger.error(f"Failed to install prometheo: {e}")
        raise
158
+
159
+
160
def load_onnx_model_cached(model_url: str):
    """Download an ONNX classification model and build an inference session.

    Results are cached per ``model_url`` in the module cache, so each
    worker downloads and initializes a given model only once.

    Parameters
    ----------
    model_url : str
        HTTP(S) URL of the ONNX model file.

    Returns
    -------
    tuple
        ``(model, sorted_lut)`` where ``model`` is an
        ``onnxruntime.InferenceSession`` and ``sorted_lut`` maps class
        names to integer labels, sorted by label value.

    Raises
    ------
    requests.HTTPError
        If the model download returns an HTTP error status.
    """

    cache = get_model_cache()
    if model_url in cache:
        logger.debug(f"ONNX model cache hit for {model_url}.")
        return cache[model_url]

    logger.info(f"Loading ONNX model from {model_url}")
    response = requests.get(model_url, timeout=120)
    # Fix: fail fast on HTTP errors. Without this, an error page's bytes
    # would reach InferenceSession and fail with an opaque protobuf error.
    response.raise_for_status()

    session_options, providers = optimize_onnx_cpu_performance(NUM_THREADS)

    model = ort.InferenceSession(response.content, session_options, providers=providers)

    metadata = model.get_modelmeta().custom_metadata_map
    # NOTE(security): eval on downloaded model metadata — only load models
    # from trusted URLs. If `class_params` is guaranteed to be a plain
    # literal, ast.literal_eval would be the safer drop-in replacement.
    class_params = eval(metadata["class_params"], {"__builtins__": None}, {})

    lut = dict(zip(class_params["class_names"], class_params["class_to_label"]))
    sorted_lut = dict(sorted(lut.items(), key=lambda item: item[1]))

    result = (model, sorted_lut)
    cache[model_url] = result
    return result
184
+
185
+
186
def load_presto_weights_cached(presto_model_url: str):
    """Manual caching for Presto weights with dependency check.

    Ensures prometheo is importable (installing it on first use), builds a
    Presto model, loads the weights from `presto_model_url` and caches the
    loaded model per URL in the module cache.
    """
    cache = get_model_cache()
    if presto_model_url in cache:
        logger.debug(f"Presto model cache hit for {presto_model_url}")
        return cache[presto_model_url]

    # Ensure dependencies are available (not cached)
    # This also rebinds the module-level `Presto` / `load_presto_weights`.
    _ensure_prometheo_dependencies()

    logger.info(f"Loading Presto weights from: {presto_model_url}")

    model = Presto()  # type: ignore
    result = load_presto_weights(model, presto_model_url)  # type: ignore

    cache[presto_model_url] = result
    return result
203
+
204
+
205
def get_output_labels(
    lut_sorted: dict, postprocess_parameters: Optional[dict] = None
) -> list:
    """Generate output band names from LUT - works in both contexts.

    Parameters
    ----------
    lut_sorted : dict
        Sorted lookup table mapping class names to labels.
    postprocess_parameters : dict, optional
        Postprocessing parameters to determine whether to keep per-class
        probability bands. If not provided, we assume all probabilities
        are kept.

    Returns
    -------
    list
        Band names: ["classification", "probability"], optionally followed
        by one "probability_<class>" band per class.
    """
    # Fix: default was a mutable `{}` argument, which is shared across
    # calls; use None and create a fresh dict instead.
    if postprocess_parameters is None:
        postprocess_parameters = {}

    # Determine whether to remove per-class probability bands
    # based on postprocessing parameters
    postprocessing_enabled = postprocess_parameters.get("enabled", True)
    keep_class_probs = postprocess_parameters.get("keep_class_probs", True)
    if postprocessing_enabled and (not keep_class_probs):
        # Only classification and overall probability
        return ["classification", "probability"]

    # Include per-class probabilities (dicts preserve the sorted order)
    return ["classification", "probability"] + [
        f"probability_{name}" for name in lut_sorted
    ]
228
+
229
+
230
def optimize_pytorch_cpu_performance(num_threads):
    """CPU-specific optimizations for Prometheo.

    Configures PyTorch intra/inter-op threading, sets BLAS/OMP thread env
    vars, enables mkldnn where available, and disables autograd for
    inference. Returns the thread count that was applied.

    NOTE(review): setting OMP/MKL env vars after torch is already imported
    may have no effect on an already-initialized thread pool, and
    `set_num_interop_threads` raises if parallel work has already run —
    confirm this is called before any torch usage.
    """
    import torch

    # Thread configuration

    torch.set_num_threads(num_threads)
    torch.set_num_interop_threads(
        num_threads
    )  # TODO test setting to 4 due to parallel slope cal ect
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    os.environ["MKL_NUM_THREADS"] = str(num_threads)
    os.environ["OPENBLAS_NUM_THREADS"] = str(num_threads)

    logger.info(f"PyTorch CPU: using {num_threads} threads")

    # CPU-specific optimizations
    if hasattr(torch.backends, "mkldnn"):
        torch.backends.mkldnn.enabled = True

    torch.set_grad_enabled(False)  # Disable gradients for inference

    return num_threads
253
+
254
+
255
def optimize_onnx_cpu_performance(num_threads):
    """Build ONNX Runtime session options tuned for CPU inference.

    Parameters
    ----------
    num_threads : int
        Thread count applied to both intra-op and inter-op parallelism.

    Returns
    -------
    tuple
        ``(session_options, providers)`` where providers contains only
        the CPU execution provider.
    """
    opts = ort.SessionOptions()

    # Thread configuration
    opts.intra_op_num_threads = num_threads
    # TODO test setting to 1 due to sequential nature
    opts.inter_op_num_threads = num_threads

    # CPU-specific optimizations
    opts.enable_cpu_mem_arena = True
    opts.enable_mem_pattern = True
    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    return opts, ["CPUExecutionProvider"]
273
+
274
+
275
+ # =============================================================================
276
+ # POSTPROCESSING FUNCTIONS
277
+ # =============================================================================
278
+
279
+
280
def majority_vote(
    base_labels: xr.DataArray,
    max_probabilities: xr.DataArray,
    kernel_size: int,
) -> xr.DataArray:
    """Majority vote is performed using a sliding local kernel.
    For each pixel, the voting of a final class is done by counting
    neighbours values.
    Pixels that have one of the specified excluded values are
    excluded in the voting process and are unchanged.

    The prediction probabilities are reevaluated by taking, for each pixel,
    the average of probabilities of the neighbors that belong to the winning class.
    (For example, if a pixel was voted to class 2 and there are three
    neighbors of that class, then the new probability is the sum of the
    old probabilities of each pixels divided by 3)

    Parameters
    ----------
    base_labels : xr.DataArray
        The original predicted classification labels.
    max_probabilities : xr.DataArray
        The original probabilities of the winning class (ranging between 0 and 100).
    kernel_size : int
        The size of the kernel used for the neighbour around the pixel.

    Returns
    -------
    xr.DataArray
        The cleaned classification labels and associated probabilities,
        stacked as a 2-band (classification, probability) array.
    """
    from scipy.signal import convolve2d

    prediction = base_labels.values
    probability = max_probabilities.values

    # As the probabilities are in integers between 0 and 100,
    # we use uint16 matrices to store the vote scores
    assert (
        kernel_size <= 25
    ), f"Kernel value cannot be larger than 25 (currently: {kernel_size}) because it might lead to scenarios where the 16-bit count matrix is overflown"

    # Build a class mapping, so classes are converted to indexes and vice-versa
    unique_values = set(np.unique(prediction))
    unique_values = sorted(unique_values - set(POSTPROCESSING_EXCLUDED_VALUES))  # type: ignore
    index_value_lut = [(k, v) for k, v in enumerate(unique_values)]

    # Per-class vote scores (probability-weighted) and averaged probabilities
    counts = np.zeros(shape=(*prediction.shape, len(unique_values)), dtype=np.uint16)
    probabilities = np.zeros(
        shape=(*probability.shape, len(unique_values)), dtype=np.uint16
    )

    # Iterates for each classes
    for cls_idx, cls_value in index_value_lut:
        # Take the binary mask of the interest class, and multiply by the probabilities
        class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

        # Set to 0 the class scores where the label is excluded
        for excluded_value in POSTPROCESSING_EXCLUDED_VALUES:
            class_mask[prediction == excluded_value] = 0

        # Binary class mask, used to count HOW MANY neighbours pixels are used for this class
        binary_class_mask = (class_mask > 0).astype(np.uint16)

        # Creates the kernel
        kernel = np.ones(shape=(kernel_size, kernel_size), dtype=np.uint16)

        # Counts around the window the sum of probabilities for that given class
        counts[:, :, cls_idx] = convolve2d(class_mask, kernel, mode="same")

        # Counts the number of neighbors pixels that voted for that given class
        class_voters = convolve2d(binary_class_mask, kernel, mode="same")
        # Remove the 0 values because might create divide by 0 issues
        class_voters[class_voters == 0] = 1

        probabilities[:, :, cls_idx] = np.divide(counts[:, :, cls_idx], class_voters)

    # Initializes output array
    aggregated_predictions = np.zeros(
        shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16
    )
    # Initializes probabilities output array
    aggregated_probabilities = np.zeros(
        shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16
    )

    if len(unique_values) > 0:
        # Takes the indices that have the biggest scores
        aggregated_predictions_indices = np.argmax(counts, axis=2)

        # Get the new probabilities of the predictions
        aggregated_probabilities = np.take_along_axis(
            probabilities,
            aggregated_predictions_indices.reshape(
                *aggregated_predictions_indices.shape, 1
            ),
            axis=2,
        ).squeeze()

        # Check which pixels have a counts value equal to 0
        no_score_mask = np.sum(counts, axis=2) == 0

        # convert back to values from indices
        for cls_idx, cls_value in index_value_lut:
            aggregated_predictions[aggregated_predictions_indices == cls_idx] = (
                cls_value
            )
        aggregated_predictions = aggregated_predictions.astype(np.uint16)

        # Pixels with no votes at all are set to the nodata value
        aggregated_predictions[no_score_mask] = POSTPROCESSING_NODATA
        aggregated_probabilities[no_score_mask] = POSTPROCESSING_NODATA

    # Setting excluded values back to their original values
    # NOTE(review): the *probability* band is also set to the excluded label
    # value here (not to the original probability) — confirm this is intended.
    for excluded_value in POSTPROCESSING_EXCLUDED_VALUES:
        aggregated_predictions[prediction == excluded_value] = excluded_value
        aggregated_probabilities[prediction == excluded_value] = excluded_value

    return xr.DataArray(
        np.stack((aggregated_predictions, aggregated_probabilities)),
        dims=["bands", "y", "x"],
        coords={
            "bands": ["classification", "probability"],
            "y": base_labels.y,
            "x": base_labels.x,
        },
    )
406
+
407
+
408
def smooth_probabilities(
    base_labels: xr.DataArray, class_probabilities: xr.DataArray
) -> xr.DataArray:
    """Performs gaussian smoothing on the class probabilities. Requires the
    base labels to keep the pixels that are excluded away from smoothing.

    Parameters
    ----------
    base_labels : xr.DataArray
        Predicted classification labels; pixels whose label is in
        ``POSTPROCESSING_EXCLUDED_VALUES`` are zeroed out after smoothing.
    class_probabilities : xr.DataArray
        Per-class probabilities with the class dimension first.

    Returns
    -------
    xr.DataArray
        Smoothed probabilities, renormalized to sum to 100 per pixel and
        cast to uint16, with the input's coords and dims.
    """
    from scipy.signal import convolve2d

    base_labels_vals = base_labels.values
    probabilities_vals = class_probabilities.values

    # np.isin replaces the deprecated np.in1d (deprecated in NumPy 2.0)
    # and handles multi-dimensional inputs directly, so no reshape dance.
    excluded_mask = np.isin(base_labels_vals, POSTPROCESSING_EXCLUDED_VALUES)

    # Pseudo-gaussian 3x3 kernel; normalized by its sum after convolution
    conv_kernel = np.array([[1, 2, 1], [2, 3, 2], [1, 2, 1]], dtype=np.int16)

    for class_idx in range(probabilities_vals.shape[0]):
        probabilities_vals[class_idx] = (
            convolve2d(
                probabilities_vals[class_idx],
                conv_kernel,
                mode="same",
                boundary="symm",
            )
            / conv_kernel.sum()
        )
        # Excluded pixels take no part in the smoothed result
        probabilities_vals[class_idx][excluded_mask] = 0

    # Sum of probabilities should be 1, cast to uint16
    # NOTE(review): pixels where all class probabilities are zero (e.g.
    # fully excluded) divide by zero here — unchanged legacy behavior.
    probabilities_vals = np.round(
        probabilities_vals / probabilities_vals.sum(axis=0) * 100.0
    ).astype("uint16")

    return xr.DataArray(
        probabilities_vals,
        coords=class_probabilities.coords,
        dims=class_probabilities.dims,
    )
448
+
449
+
450
def reclassify(
    base_labels: xr.DataArray,
    base_max_probs: xr.DataArray,
    probabilities: xr.DataArray,
) -> xr.DataArray:
    """Re-derive labels and winning probabilities from class probabilities.

    Parameters
    ----------
    base_labels : xr.DataArray
        Original classification labels; pixels with a label in
        ``POSTPROCESSING_EXCLUDED_VALUES`` are passed through unchanged.
    base_max_probs : xr.DataArray
        Original winning-class probabilities, used for excluded pixels.
    probabilities : xr.DataArray
        (Smoothed) per-class probabilities, class dimension first.

    Returns
    -------
    xr.DataArray
        2-band (classification, probability) array. NOTE: the
        classification band holds class *indices* from argmax; the caller
        (``Postprocessor.apply``) remaps them to actual class labels.
    """
    base_labels_vals = base_labels.values
    base_max_probs_vals = base_max_probs.values

    # np.isin replaces the deprecated np.in1d (deprecated in NumPy 2.0)
    # and handles multi-dimensional inputs directly.
    excluded_mask = np.isin(base_labels_vals, POSTPROCESSING_EXCLUDED_VALUES)

    new_labels_vals = np.argmax(probabilities.values, axis=0)
    new_max_probs_vals = np.max(probabilities.values, axis=0)

    # Excluded pixels keep their original label and probability
    new_labels_vals[excluded_mask] = base_labels_vals[excluded_mask]
    new_max_probs_vals[excluded_mask] = base_max_probs_vals[excluded_mask]

    return xr.DataArray(
        np.stack((new_labels_vals, new_max_probs_vals)),
        dims=["bands", "y", "x"],
        coords={
            "bands": ["classification", "probability"],
            "y": base_labels.y,
            "x": base_labels.x,
        },
    )
478
+
479
+
480
+ # =============================================================================
481
+ # ERROR HANDLING - SIMPLE VERSION
482
+ # =============================================================================
483
+
484
+
485
def create_nan_output_array(
    inarr: xr.DataArray, num_outputs: int, error_info: str = ""
) -> xr.DataArray:
    """Creates a NaN-filled output array with proper dimensions and coordinates.

    Parameters
    ----------
    inarr : xr.DataArray
        Input array to derive dimensions from
    num_outputs : int
        Number of output bands/classes
    error_info : str
        Error information to include in attributes for debugging

    Returns
    -------
    xr.DataArray
        NaN-filled array with proper structure
    """
    logger.error(f"Creating NaN output array due to error: {error_info}")
    logger.error(f"Input array shape: {inarr.shape}, dims: {inarr.dims}")
    try:
        # Diagnostic only: this function runs on error paths where the
        # input may be missing expected coordinates (e.g. no 't' dim), so
        # the logging itself must never raise.
        logger.error(
            f"Input array coords - bands: {inarr.bands.values}, t: {len(inarr.t)}, x: {len(inarr.x)}, y: {len(inarr.y)}"
        )
    except AttributeError:
        logger.error("Input array is missing some of the expected coordinates")

    # Create NaN array with same spatial dimensions
    nan_array = np.full(
        (num_outputs, len(inarr.y), len(inarr.x)), np.nan, dtype=np.float32
    )

    # Create output array with proper coordinates; band names are plain
    # integer placeholders since no LUT is available on the error path
    output_array = xr.DataArray(
        nan_array,
        dims=["bands", "y", "x"],
        coords={
            "bands": list(range(num_outputs)),
            "y": inarr.y,
            "x": inarr.x,
        },
        attrs={"error": error_info},
    )

    return output_array
528
+
529
+
530
+ # =============================================================================
531
+ # CLASSES (Main logic for apply_udf_data)
532
+ # =============================================================================
533
+
534
+
535
class SlopeCalculator:
    """Handles slope computation from elevation data.

    Pipeline: fill DEM NaNs, downsample to 20 m, compute slope with Sobel
    gradients, then upsample back to the input resolution.
    """

    @staticmethod
    def compute(resolution: float, elevation_data: np.ndarray) -> np.ndarray:
        """Compute slope from elevation data.

        Parameters
        ----------
        resolution : float
            Pixel resolution of ``elevation_data`` in meters; must yield an
            even integer factor to 20 m (see ``_downsample_to_20m``).
        elevation_data : np.ndarray
            2D elevation array; ``NODATA_VALUE`` pixels are treated as NaN.

        Returns
        -------
        np.ndarray
            Slope in degrees, cast to uint16, at the original resolution.
        """
        dem_arr = SlopeCalculator._prepare_dem_array(elevation_data)
        dem_downsampled = SlopeCalculator._downsample_to_20m(dem_arr, resolution)
        slope = SlopeCalculator._compute_slope_gradient(dem_downsampled)
        result = SlopeCalculator._upsample_to_original(slope, dem_arr.shape, resolution)
        return result

    @staticmethod
    def _prepare_dem_array(dem: np.ndarray) -> np.ndarray:
        """Prepare DEM array by handling NaNs and invalid values."""
        dem_arr = dem.astype(np.float32)
        # NODATA pixels become NaN so they can be filled from neighbours
        dem_arr[dem_arr == NODATA_VALUE] = np.nan
        return SlopeCalculator._fill_nans(dem_arr)

    @staticmethod
    def _fill_nans(dem_arr: np.ndarray, max_iter: int = 2) -> np.ndarray:
        """Fill NaN values using rolling fill approach.

        Copies values into NaN pixels from the four 1-pixel shifts of the
        array (in random order, to avoid a directional bias), recursing up
        to ``max_iter`` times. NaNs that remain after the last iteration
        are left in place.
        """
        if max_iter == 0 or not np.any(np.isnan(dem_arr)):
            return dem_arr

        mask = np.isnan(dem_arr)
        roll_params = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        random.shuffle(roll_params)

        for roll_param in roll_params:
            rolled = np.roll(dem_arr, roll_param, axis=(0, 1))
            dem_arr[mask] = rolled[mask]

        return SlopeCalculator._fill_nans(dem_arr, max_iter - 1)

    @staticmethod
    def _downsample_to_20m(dem_arr: np.ndarray, resolution: float) -> np.ndarray:
        """Downsample DEM to 20m resolution for slope computation.

        Raises
        ------
        ValueError
            If ``20 / resolution`` is not a positive even integer.
            NOTE(review): this also rejects resolution == 20 itself
            (factor 1) — confirm that is intended.
        """
        factor = int(20 / resolution)
        if factor < 1 or factor % 2 != 0:
            raise ValueError(f"Unsupported resolution for slope: {resolution}")

        X, Y = dem_arr.shape
        # Reflect-pad so both dimensions become divisible by the factor
        pad_X, pad_Y = (
            (factor - (X % factor)) % factor,
            (factor - (Y % factor)) % factor,
        )
        padded = np.pad(dem_arr, ((0, pad_X), (0, pad_Y)), mode="reflect")

        # Block-average (NaN-aware) over factor x factor windows
        reshaped = padded.reshape(
            (X + pad_X) // factor, factor, (Y + pad_Y) // factor, factor
        )
        return np.nanmean(reshaped, axis=(1, 3))

    @staticmethod
    def _compute_slope_gradient(dem: np.ndarray) -> np.ndarray:
        """Compute slope gradient using Sobel operators.

        Kernels are scaled by 8 * cell size (20 m); the result is the
        slope angle in degrees.
        """
        kernel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / (8.0 * 20)
        kernel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / (8.0 * 20)

        dx = convolve(dem, kernel_x)
        dy = convolve(dem, kernel_y)
        gradient_magnitude = np.sqrt(dx**2 + dy**2)

        return np.arctan(gradient_magnitude) * (180 / np.pi)

    @staticmethod
    def _upsample_to_original(
        slope: np.ndarray, original_shape: Tuple[int, ...], resolution: float
    ) -> np.ndarray:
        """Upsample slope back to original resolution (bilinear zoom)."""
        factor = int(20 / resolution)
        slope_upsampled = zoom(slope, zoom=factor, order=1)

        # Handle odd dimensions
        # NOTE(review): trimming one row/column assumes exactly one pixel of
        # padding was added during downsampling — confirm for all factors.
        if original_shape[0] % 2 != 0:
            slope_upsampled = slope_upsampled[:-1, :]
        if original_shape[1] % 2 != 0:
            slope_upsampled = slope_upsampled[:, :-1]

        return slope_upsampled.astype(np.uint16)
616
+
617
+
618
class CoordinateTransformer:
    """Handles coordinate transformations and spatial operations."""

    @staticmethod
    def get_resolution(inarr: xr.DataArray, epsg: int) -> float:
        """Calculate resolution in meters.

        For projected CRSs the x-spacing of the first two pixels is used
        directly; WGS84 coordinates are reprojected first.
        """
        if epsg == 4326:
            return CoordinateTransformer._get_wgs84_resolution(inarr)
        return abs(inarr.x[1].values - inarr.x[0].values)

    @staticmethod
    def _get_wgs84_resolution(inarr: xr.DataArray) -> float:
        """Convert WGS84 coordinates to meters for resolution calculation.

        Reprojects to Web Mercator (EPSG:3857) and takes the x distance
        between the first two points.
        NOTE(review): zips inarr.x with inarr.y, pairing coordinates
        positionally — assumes both axes have at least 2 values; confirm.
        """
        transformer = Transformer.from_crs(4326, 3857, always_xy=True)
        points = [Point(x, y) for x, y in zip(inarr.x.values, inarr.y.values)]
        points = [transform(transformer.transform, point) for point in points]
        return abs(points[1].x - points[0].x)

    @staticmethod
    def get_lat_lon_array(inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Create latitude/longitude array from coordinates.

        Returns a 2-band (lat, lon) DataArray on the input's y/x grid,
        reprojecting to WGS84 when the input CRS is not already 4326.
        """
        lon, lat = np.meshgrid(inarr.x.values, inarr.y.values)

        if epsg != 4326:
            transformer = Transformer.from_crs(epsg, 4326, always_xy=True)
            lon, lat = transformer.transform(lon, lat)

        latlon = np.stack([lat, lon])
        return xr.DataArray(
            latlon,
            dims=["bands", "y", "x"],
            coords={
                "bands": [LAT_HARMONIZED_NAME, LON_HARMONIZED_NAME],
                "y": inarr.y,
                "x": inarr.x,
            },
        )
655
+
656
+
657
class DataPreprocessor:
    """Handles data preprocessing operations."""

    @staticmethod
    def rescale_s1_backscatter(arr: xr.DataArray) -> xr.DataArray:
        """Rescale Sentinel-1 backscatter from uint16 to dB values.

        Only bands listed in ``S1_BANDS`` are touched; the array is
        returned unchanged when none are present. The conversion chain is
        uint16 -> power (via 20*log10(x) - 83, then 10^(x/10)) -> dB, with
        non-finite intermediates mapped to NaN. Modifies ``arr`` in place
        for the S1 bands and returns it.
        """
        s1_bands_present = [b for b in S1_BANDS if b in arr.bands.values]
        if not s1_bands_present:
            return arr

        s1_data = arr.sel(bands=s1_bands_present).astype(np.float32)
        DataPreprocessor._validate_s1_data(s1_data.values)

        # Convert to power values then to dB
        power_values = 20.0 * np.log10(s1_data.values) - 83.0
        power_values = np.power(10, power_values / 10.0)
        power_values[~np.isfinite(power_values)] = np.nan

        db_values = 10.0 * np.log10(power_values)
        arr.loc[dict(bands=s1_bands_present)] = db_values

        return arr

    @staticmethod
    def _validate_s1_data(data: np.ndarray) -> None:
        """Validate S1 data meets preprocessing requirements.

        Raises
        ------
        ValueError
            If values fall outside the expected uint16 range [1, NODATA_VALUE],
            which would indicate the data was already rescaled.
        """
        if data.min() < 1 or data.max() > NODATA_VALUE:
            raise ValueError(
                "S1 data should be uint16 format with values 1-65535. "
                "Set 'rescale_s1' to False to disable scaling."
            )
688
+
689
+
690
class PrestoFeatureExtractor:
    """Handles Presto feature extraction pipeline."""

    def __init__(self, parameters: Dict[str, Any]):
        # Expected keys: "presto_model_url" (required), "num_outputs",
        # "temporal_prediction", "target_date", "batch_size", "rescale_s1".
        self.parameters = parameters

    def extract(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Extract Presto features from input array.

        Parameters
        ----------
        inarr : xr.DataArray
            Input cube; must contain exactly 12 timesteps.
        epsg : int
            EPSG code of the input coordinates (required).

        Returns
        -------
        xr.DataArray
            Feature array with dims (bands, y, x), or a NaN-filled array
            (via ``create_nan_output_array``) when the 12-timestep
            requirement is not met.

        Raises
        ------
        ValueError
            If ``epsg`` is None or "presto_model_url" is missing.
        """
        if epsg is None:
            raise ValueError("EPSG code required for Presto feature extraction")

        # ONLY check top level - no nested lookup
        presto_model_url = self.parameters.get("presto_model_url")
        if not presto_model_url:
            logger.error(
                f"Missing presto_model_url. Available keys: {list(self.parameters.keys())}"
            )
            raise ValueError('Missing required parameter "presto_model_url"')

        if len(inarr.t) != 12:
            error_msg = (
                f"Presto requires exactly 12 timesteps, but got {len(inarr.t)}. "
                f"Available timesteps: {inarr.t.values}. "
                f"Patch coordinates - x: {inarr.x.values.tolist()}, y: {inarr.y.values.tolist()}"
            )
            logger.error(error_msg)

            # Return NaN array instead of crashing
            return create_nan_output_array(
                inarr, self.parameters["num_outputs"], error_msg
            )

        inarr = self._preprocess_input(inarr)

        # Slope is derived from the DEM band when not already provided
        if "slope" not in inarr.bands:
            inarr = self._add_slope_band(inarr, epsg)

        return self._run_presto_inference(inarr, epsg)

    def _preprocess_input(self, inarr: xr.DataArray) -> xr.DataArray:
        """Preprocess input array for Presto.

        Transposes to (bands, t, x, y), harmonizes band names via
        ``GFMAP_BAND_MAPPING`` and fills NaNs with ``NODATA_VALUE``.
        """
        inarr = inarr.transpose("bands", "t", "x", "y")

        # Harmonize band names
        new_bands = [GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands]
        inarr = inarr.assign_coords(bands=new_bands)

        return inarr.fillna(NODATA_VALUE)

    def _add_slope_band(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Compute and add slope band to array.

        Derives slope from the "COP-DEM" band at the first timestep and
        broadcasts it over the time dimension.
        """
        logger.warning("Slope band not found, computing...")
        resolution = CoordinateTransformer.get_resolution(inarr.isel(t=0), epsg)
        elevation_data = inarr.sel(bands="COP-DEM").isel(t=0).values

        slope_array = SlopeCalculator.compute(resolution, elevation_data)
        slope_da = (
            xr.DataArray(
                slope_array[None, :, :],
                dims=("bands", "y", "x"),
                coords={"bands": ["slope"], "y": inarr.y, "x": inarr.x},
            )
            .expand_dims({"t": inarr.t})
            .astype("float32")
        )

        return xr.concat([inarr.astype("float32"), slope_da], dim="bands")

    def _run_presto_inference(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Run Presto model inference with safe dependency handling.

        Loads the (cached) model weights, runs inference under
        ``torch.inference_mode`` and returns a (bands, y, x) feature array.
        For temporal predictions a single timestep is selected afterwards.
        """
        # Dependencies are now handled by load_presto_weights_cached
        import gc

        import torch

        _ensure_prometheo_dependencies()

        presto_model_url = self.parameters["presto_model_url"]

        model = load_presto_weights_cached(presto_model_url)

        # Import here to ensure dependencies are available
        pooling_method = (
            PoolingMethods.TIME  # type: ignore
            if self.parameters.get("temporal_prediction")
            else PoolingMethods.GLOBAL  # type: ignore
        )

        logger.info("Running presto inference ...")
        try:
            with torch.inference_mode():
                features = run_model_inference(
                    inarr,
                    model,
                    epsg=epsg,
                    batch_size=self.parameters.get("batch_size", 256),  # TODO optimize?
                    pooling_method=pooling_method,
                )  # type: ignore
            logger.info("Inference completed.")

            if self.parameters.get("temporal_prediction"):
                features = self._select_temporal_features(features)
            return features.transpose("bands", "y", "x")

        finally:
            # Free model/feature intermediates regardless of outcome
            gc.collect()

    def _select_temporal_features(self, features: xr.DataArray) -> xr.DataArray:
        """Select specific timestep from temporal features.

        Without a "target_date" parameter the middle timestep is used;
        otherwise the nearest timestep to the target date is selected.

        Raises
        ------
        ValueError
            If the target date lies outside the feature time range.
        """
        target_date = self.parameters.get("target_date")

        if target_date is None:
            mid_idx = len(features.t) // 2
            return features.isel(t=mid_idx)

        target_dt = np.datetime64(target_date)
        min_time, max_time = features.t.min().values, features.t.max().values

        if target_dt < min_time or target_dt > max_time:
            raise ValueError(
                f"Target date {target_date} outside feature range: {min_time} to {max_time}"
            )

        return features.sel(t=target_dt, method="nearest")
814
+
815
+
816
class ONNXClassifier:
    """Handles ONNX model inference for classification."""

    def __init__(self, parameters: Dict[str, Any]):
        # Expected keys: "classifier_url" (required).
        self.parameters = parameters

    def predict(self, features: xr.DataArray) -> xr.DataArray:
        """Run classification prediction.

        Returns a (bands, y, x) array whose bands are produced by
        ``get_output_labels`` on the model's LUT:
        classification, probability, probability_<class>...

        Raises
        ------
        ValueError
            If "classifier_url" is missing from the parameters.
        """
        classifier_url = self.parameters.get("classifier_url")
        if not classifier_url:
            logger.error(
                f"Missing classifier_url. Available keys: {list(self.parameters.keys())}"
            )
            raise ValueError('Missing required parameter "classifier_url"')

        session, lut = load_onnx_model_cached(classifier_url)
        features_flat = self._prepare_features(features)

        logger.info("Running ONNX model inference ...")
        predictions = self._run_inference(session, lut, features_flat)
        logger.info("ONNX inference completed.")

        return self._reshape_predictions(predictions, features, lut)

    def _prepare_features(self, features: xr.DataArray) -> np.ndarray:
        """Flatten (bands, x, y) features into a (pixel, band) matrix."""
        return (
            features.transpose("bands", "x", "y")
            .stack(xy=["x", "y"])
            .transpose()
            .values
        )

    def _run_inference(
        self, session: Any, lut: Dict, features: np.ndarray
    ) -> np.ndarray:
        """Run ONNX model inference.

        Assumes the session returns [labels, per-class probability maps].
        Raw labels are remapped through the LUT and probabilities scaled
        to integer percent. Returns an array of shape
        (2 + n_classes, n_pixels): label, winning probability, then
        per-class probabilities in LUT key order.
        """
        outputs = session.run(None, {"features": features})

        labels = np.zeros(len(outputs[0]), dtype=np.uint16)
        probabilities = np.zeros(len(outputs[0]), dtype=np.uint8)

        for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])):
            labels[i] = lut[label]
            probabilities[i] = int(round(prob[label] * 100))

        class_probs = np.array(
            [[prob[label] for label in lut.keys()] for prob in outputs[1]]
        )
        class_probs = (class_probs * 100).round().astype(np.uint8)

        # hstack upcasts the uint8 probability columns to the uint16 label dtype
        return np.hstack([labels[:, None], probabilities[:, None], class_probs]).T

    def _reshape_predictions(
        self, predictions: np.ndarray, original_features: xr.DataArray, lut: Dict
    ) -> xr.DataArray:
        """Reshape predictions to match original spatial dimensions.

        The flat (bands, pixel) predictions are unstacked to (bands, x, y)
        — matching the x-major order used in ``_prepare_features`` — and
        transposed to the (bands, y, x) output convention.
        """
        output_labels = get_output_labels(lut)
        x_coords, y_coords = original_features.x.values, original_features.y.values

        reshaped = predictions.reshape(
            (len(output_labels), len(x_coords), len(y_coords))
        )

        return xr.DataArray(
            reshaped,
            dims=["bands", "x", "y"],
            coords={"bands": output_labels, "x": x_coords, "y": y_coords},
        ).transpose("bands", "y", "x")
885
+
886
+
887
class Postprocessor:
    """Handles postprocessing of classification results."""

    def __init__(self, parameters: Dict[str, Any], classifier_url: str):
        # parameters keys: "method" ("smooth_probabilities" or
        # "majority_vote"), "kernel_size", "keep_class_probs".
        self.parameters = parameters
        # The classifier URL is needed to reload the LUT for label remapping
        self.classifier_url = classifier_url

    def apply(self, inarr: xr.DataArray) -> xr.DataArray:
        """Apply the configured postprocessing method to a classification.

        Parameters
        ----------
        inarr : xr.DataArray
            Classifier output with bands [classification, probability,
            probability_<class>...].

        Returns
        -------
        xr.DataArray
            Postprocessed (bands, y, x) array; per-class probabilities
            are appended when "keep_class_probs" is set.

        Raises
        ------
        ValueError
            If "method" is not a supported postprocessing method.
        """
        inarr = inarr.transpose(
            "bands", "y", "x"
        )  # Ensure correct dimension order for openEO backend

        _, lookup_table = load_onnx_model_cached(self.classifier_url)

        if self.parameters.get("method") == "smooth_probabilities":
            # Cast to float for more accurate gaussian smoothing
            class_probabilities = (
                inarr.isel(bands=slice(2, None)).astype("float32") / 100.0
            )

            # Perform probability smoothing
            class_probabilities = smooth_probabilities(
                inarr.sel(bands="classification"), class_probabilities
            )

            # Reclassify (yields class *indices* in the classification band)
            new_labels = reclassify(
                inarr.sel(bands="classification"),
                inarr.sel(bands="probability"),
                class_probabilities,
            )

            # Re-apply labels: map argmax indices back to LUT class labels
            class_labels = list(lookup_table.values())

            # Create a final labels array with same dimensions as new_labels
            # (65535 marks pixels whose index matched no class)
            final_labels = xr.full_like(new_labels, fill_value=65535)
            for idx, label in enumerate(class_labels):
                final_labels.loc[{"bands": "classification"}] = xr.where(
                    new_labels.sel(bands="classification") == idx,
                    label,
                    final_labels.sel(bands="classification"),
                )
            # NOTE(review): assigning to .values of a .sel() result relies on
            # xarray returning a writable view here — confirm write-through.
            new_labels.sel(bands="classification").values = final_labels.sel(
                bands="classification"
            ).values

            # Append the per-class probabilities if required
            if self.parameters.get("keep_class_probs", False):
                new_labels = xr.concat([new_labels, class_probabilities], dim="bands")

        elif self.parameters.get("method") == "majority_vote":
            kernel_size = self.parameters.get("kernel_size", 5)

            new_labels = majority_vote(
                inarr.sel(bands="classification"),
                inarr.sel(bands="probability"),
                kernel_size=kernel_size,
            )

            # Append the per-class probabilities if required
            if self.parameters.get("keep_class_probs", False):
                class_probabilities = inarr.isel(bands=slice(2, None))
                new_labels = xr.concat([new_labels, class_probabilities], dim="bands")

        else:
            raise ValueError(
                f"Unknown post-processing method: {self.parameters.get('method')}"
            )

        new_labels = new_labels.transpose(
            "bands", "y", "x"
        )  # Ensure correct dimension order for openEO backend

        return new_labels
962
+
963
+
964
+ # =============================================================================
965
+ # MAIN UDF FUNCTIONS
966
+ # =============================================================================
967
+
968
+
969
def run_single_workflow(
    input_array: xr.DataArray,
    epsg: int,
    parameters: Dict[str, Any],
    mask: Optional[xr.DataArray] = None,
) -> xr.DataArray:
    """Run a single classification workflow with optional masking.

    Parameters
    ----------
    input_array : xr.DataArray
        Input cube (bands, t, y, x).
    epsg : int
        EPSG code of the input coordinates.
    parameters : Dict[str, Any]
        Must contain "feature_parameters" and "classifier_parameters";
        may contain "postprocess_parameters".
    mask : Optional[xr.DataArray]
        Optional boolean mask; False pixels are set to 254 (non-cropland).

    Returns
    -------
    xr.DataArray
        Classification result (bands, y, x).
    """

    # Preprocess data
    if parameters["feature_parameters"].get("rescale_s1", True):
        logger.info("Rescale s1 ...")
        input_array = DataPreprocessor.rescale_s1_backscatter(input_array)

    # Extract features
    logger.info("Extract Presto embeddings ...")
    feature_extractor = PrestoFeatureExtractor(parameters["feature_parameters"])
    features = feature_extractor.extract(input_array, epsg)
    logger.info("Presto embedding extraction done.")

    # Classify
    logger.info("Onnx classification ...")
    classifier = ONNXClassifier(parameters["classifier_parameters"])
    classes = classifier.predict(features)
    logger.info("Onnx classification done.")

    # Postprocess
    postprocess_parameters: Dict[str, Any] = parameters.get(
        "postprocess_parameters", {}
    )

    # NOTE(review): the key checked here is "enable", but get_output_labels
    # reads "enabled" (default True) — confirm which key is canonical, as a
    # mismatch silently skips postprocessing or mislabels output bands.
    if postprocess_parameters.get("enable"):
        logger.info("Postprocessing classification results ...")
        if postprocess_parameters.get("save_intermediate"):
            # Keep a prefixed copy of the raw classifier output
            classes_raw = classes.assign_coords(
                bands=[f"raw_{b}" for b in list(classes.bands.values)]
            )
        postprocessor = Postprocessor(
            postprocess_parameters,
            classifier_url=parameters.get("classifier_parameters", {}).get(
                "classifier_url"
            ),
        )

        classes = postprocessor.apply(classes)
        if postprocess_parameters.get("save_intermediate"):
            classes = xr.concat([classes, classes_raw], dim="bands")
        logger.info("Postprocessing done.")

    # Set masked areas to specific value
    if mask is not None:
        logger.info("`mask` provided, applying to classification results ...")
        classes = classes.where(mask, 254)  # 254 = non-cropland
    return classes
1023
+
1024
+
1025
def combine_results(
    croptype_result: xr.DataArray, cropland_result: xr.DataArray
) -> xr.DataArray:
    """Combine crop type results with ALL cropland classification bands.

    Each input's band names are prefixed (``croptype_`` / ``cropland_``)
    to avoid clashes; crop-type bands come first in the stacked output.
    """
    # Prefix the band names of both inputs so they stay distinguishable
    cropland_bands = [f"cropland_{b}" for b in cropland_result.bands.values]
    cropland_result = cropland_result.assign_coords(bands=cropland_bands)

    croptype_bands = [f"croptype_{b}" for b in croptype_result.bands.values]
    croptype_result = croptype_result.assign_coords(bands=croptype_bands)

    # Stack along the band axis, crop type first, and rebuild the DataArray
    # on the crop-type result's spatial grid
    stacked = np.concatenate(
        [croptype_result.values, cropland_result.values], axis=0
    )

    return xr.DataArray(
        stacked,
        dims=["bands", "y", "x"],
        coords={
            "bands": list(croptype_bands) + list(cropland_bands),
            "y": croptype_result.y,
            "x": croptype_result.x,
        },
    )
1059
+
1060
+
1061
def apply_udf_data(udf_data: UdfData) -> UdfData:
    """Main UDF entry point - expects cropland_params and croptype_params in context.

    When BOTH parameter sets are present, a combined workflow runs:
    cropland classification first, whose positive pixels mask the crop
    type classification, with all bands of both results concatenated.
    Otherwise the whole user context is passed to a single workflow.

    Raises
    ------
    ValueError
        If no EPSG code can be derived from the projection information.
    """

    input_cube = udf_data.datacube_list[0]
    # Copy so the workflow never mutates the caller's context dict
    parameters = udf_data.user_context.copy()

    epsg = udf_data.proj["EPSG"] if udf_data.proj else None
    if epsg is None:
        raise ValueError("EPSG code not found in projection information")

    # Prepare input array
    input_array = input_cube.get_array().transpose("bands", "t", "y", "x")

    # Extract both parameter sets directly from context
    cropland_params = parameters.get("cropland_params", {})
    croptype_params = parameters.get("croptype_params", {})

    # Check if we have both parameter sets for dual workflow
    if cropland_params and croptype_params:
        logger.info(
            "Running combined workflow: cropland masking + croptype mapping ..."
        )

        # Run cropland classification - pass the FLAT parameters
        logger.info("Running cropland classification ...")
        cropland_result = run_single_workflow(input_array, epsg, cropland_params)
        logger.info("Cropland classification done.")

        # Extract cropland mask for masking the crop type classification
        # (any positive classification value counts as cropland)
        cropland_mask = cropland_result.sel(bands="classification") > 0

        # Run crop type classification with mask
        logger.info("Running crop type classification ...")
        croptype_result = run_single_workflow(
            input_array, epsg, croptype_params, cropland_mask
        )
        logger.info("Croptype classification done.")

        # Combine ALL bands from both results
        result = combine_results(croptype_result, cropland_result)
        result_cube = XarrayDataCube(result)

    else:
        # Single workflow (fallback to original behavior)
        logger.info("Running single workflow ...")
        result = run_single_workflow(input_array, epsg, parameters)
        result_cube = XarrayDataCube(result)

    udf_data.datacube_list = [result_cube]

    return udf_data
1112
+
1113
+
1114
def apply_metadata(metadata, context: Dict) -> Any:
    """Update collection metadata for combined output with ALL bands.

    Band naming logic summary (kept for mapping module resilience):
    - Single workflow (either cropland OR croptype parameters only):
      Base bands: classification, probability, probability_<class>
      If save_intermediate: raw_<band> duplicates are appended.
    - Combined workflow (both croptype_params & cropland_params):
      Prefixed bands: croptype_<band> and cropland_<band>
      If save_intermediate: croptype_raw_<band> and cropland_raw_<band> duplicates appended.

    No renaming occurs here beyond prefixing for the combined workflow; logic in
    mapping.py must therefore accept both prefixed and unprefixed forms.

    Any failure (e.g. model download in metadata context) is logged and the
    metadata is returned unchanged rather than raising.
    """
    try:
        # For dual workflow, combine band names from both models
        if "croptype_params" in context and "cropland_params" in context:
            # Get croptype band names from the croptype model's LUT
            croptype_classifier_url = context["croptype_params"][
                "classifier_parameters"
            ].get("classifier_url")
            if croptype_classifier_url:
                _, croptype_lut = load_onnx_model_cached(croptype_classifier_url)
                postprocess_parameters = context["croptype_params"].get(
                    "postprocess_parameters", {}
                )
                croptype_bands = [
                    f"croptype_{band}"
                    for band in get_output_labels(croptype_lut, postprocess_parameters)
                ]
                if postprocess_parameters.get("save_intermediate", False):
                    # Raw duplicates replace the product prefix with
                    # <product>_raw_ (see module docstring in mapping.py)
                    croptype_bands += [
                        band.replace("croptype_", "croptype_raw_")
                        for band in croptype_bands
                    ]
            else:
                raise ValueError("No croptype LUT found")

            # Get cropland band names from the cropland model's LUT
            cropland_classifier_url = context["cropland_params"][
                "classifier_parameters"
            ].get("classifier_url")
            if cropland_classifier_url:
                _, cropland_lut = load_onnx_model_cached(cropland_classifier_url)
                postprocess_parameters = context["cropland_params"].get(
                    "postprocess_parameters", {}
                )
                cropland_bands = [
                    f"cropland_{band}"
                    for band in get_output_labels(cropland_lut, postprocess_parameters)
                ]
                if postprocess_parameters.get("save_intermediate", False):
                    cropland_bands += [
                        band.replace("cropland_", "cropland_raw_")
                        for band in cropland_bands
                    ]
            else:
                raise ValueError("No cropland LUT found")

            # Crop-type bands first, matching combine_results() output order
            output_labels = croptype_bands + cropland_bands

        else:
            # Single workflow
            classifier_url = context["classifier_parameters"].get("classifier_url")
            if classifier_url:
                _, lut_sorted = load_onnx_model_cached(classifier_url)
                postprocess_parameters = context.get("postprocess_parameters", {})
                output_labels = get_output_labels(lut_sorted, postprocess_parameters)
                if postprocess_parameters.get("save_intermediate", False):
                    output_labels += [f"raw_{band}" for band in output_labels]
            else:
                raise ValueError("No classifier URL found in context")

        return metadata.rename_labels(dimension="bands", target=output_labels)

    except Exception as e:
        logger.warning(f"Could not load model in metadata context: {e}")
        return metadata
worldcereal/openeo/mapping.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mapping helpers for cropland, croptype and embeddings products.
2
+
3
+ Band naming conventions produced by the UDF (`inference.py`):
4
+
5
+ Single workflow (only cropland OR only croptype parameters passed to UDF):
6
+ classification, probability, probability_<class>
7
+ If save_intermediate: raw_<band> duplicates (e.g. raw_classification)
8
+
9
+ Combined workflow (croptype with cropland masking: both `croptype_params` &
10
+ `cropland_params` passed):
11
+ croptype_<band>, cropland_<band>
12
+ If save_intermediate: croptype_raw_<band>, cropland_raw_<band>
13
+ Example: croptype_classification -> croptype_raw_classification
14
+
15
+ Important: Raw bands in the combined workflow do NOT duplicate the base prefix;
16
+ they simply replace the leading product prefix with <product>_raw_.
17
+
18
+ Simplification: We ignore any *save_intermediate* flags. If raw bands are
19
+ present we save them; the UDF only emits them when intermediate results were
20
+ requested upstream.
21
+ """
22
+
23
+ from pathlib import Path
24
+ from typing import List
25
+
26
+ import openeo
27
+ from openeo import DataCube
28
+ from openeo_gfmap import TemporalContext
29
+ from openeo_gfmap.preprocessing.scaling import compress_uint16
30
+
31
+ from worldcereal.openeo.inference import apply_metadata
32
+ from worldcereal.parameters import (
33
+ CropLandParameters,
34
+ CropTypeParameters,
35
+ EmbeddingsParameters,
36
+ WorldCerealProductType,
37
+ )
38
+
39
+ NEIGHBORHOOD_SPEC = dict(
40
+ size=[
41
+ {"dimension": "x", "unit": "px", "value": 128},
42
+ {"dimension": "y", "unit": "px", "value": 128},
43
+ ],
44
+ overlap=[
45
+ {"dimension": "x", "unit": "px", "value": 0},
46
+ {"dimension": "y", "unit": "px", "value": 0},
47
+ ],
48
+ )
49
+
50
+
51
def _run_udf(inputs: DataCube, udf: openeo.UDF) -> DataCube:
    """Run *udf* over *inputs* using the shared 128x128 px neighborhood spec."""
    result = inputs.apply_neighborhood(process=udf, **NEIGHBORHOOD_SPEC)
    return result
53
+
54
+
55
def _reduce_temporal_mean(cube: DataCube) -> DataCube:
    """Collapse the time dimension of *cube* by averaging over all timesteps."""
    reduced = cube.reduce_dimension(dimension="t", reducer="mean")
    return reduced
57
+
58
+
59
def _filename_prefix(
    product: WorldCerealProductType, temporal: TemporalContext, raw: bool = False
) -> str:
    """Build the output filename prefix, e.g. ``cropland_<start>_<end>``.

    When *raw* is True the product tag gets a ``-raw`` suffix
    (e.g. ``cropland-raw_<start>_<end>``).
    """
    tag = product.value + ("-raw" if raw else "")
    return f"{tag}_{temporal.start_date}_{temporal.end_date}"
64
+
65
+
66
def _save_result(cube: DataCube, prefix: str) -> DataCube:
    """Attach a GTiff save_result node to *cube* using *prefix* for filenames."""
    options = {"filename_prefix": prefix}
    return cube.save_result(format="GTiff", options=options)
68
+
69
+
70
def _cropland_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,
    cropland_parameters: CropLandParameters,
) -> List[DataCube]:
    """Produce cropland product from preprocessed inputs (single workflow).

    Saves final bands and any raw_* bands purely based on presence.

    Parameters
    ----------
    inputs : DataCube
        Preprocessed input cube the inference UDF runs on.
    temporal_extent : TemporalContext
        Temporal range; only used here to build the output filename prefixes.
    cropland_parameters : CropLandParameters
        Parameters forwarded (as a plain dict) to the inference UDF context.

    Returns
    -------
    List[DataCube]
        One save_result cube for the final cropland bands and, when the UDF
        emitted them, a second one for the raw_* intermediate bands.
    """
    # Run the classification UDF over fixed-size neighborhoods; the UDF
    # receives the full parameter dump as its context.
    inference_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "inference.py",
        context=cropland_parameters.model_dump(),
    )
    classes = _run_udf(inputs, inference_udf)
    # Client-side metadata fix: rename the band labels to what the UDF
    # actually produces (see apply_metadata in inference.py).
    classes.metadata = apply_metadata(
        classes.metadata, cropland_parameters.model_dump()
    )
    # Collapse time and compress to uint16 for the final product.
    classes = _reduce_temporal_mean(classes)
    classes = compress_uint16(classes)

    bands = classes.metadata.band_names
    result_cubes: List[DataCube] = []

    # Final product: every band without the raw_ prefix (single-workflow
    # naming convention; see module docstring).
    final_bands = [b for b in bands if not b.startswith("raw_")]
    if final_bands:
        final_cube = classes.filter_bands(final_bands)
        result_cubes.append(
            _save_result(
                final_cube,
                _filename_prefix(WorldCerealProductType.CROPLAND, temporal_extent),
            )
        )

    # Intermediate product: raw_ bands are only present when intermediate
    # results were requested upstream, so saving is purely presence-based.
    raw_bands = [b for b in bands if b.startswith("raw_")]
    if raw_bands:
        raw_cube = classes.filter_bands(raw_bands)
        result_cubes.append(
            _save_result(
                raw_cube,
                _filename_prefix(
                    WorldCerealProductType.CROPLAND, temporal_extent, raw=True
                ),
            )
        )

    return result_cubes
116
+
117
+
118
def _croptype_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,
    croptype_parameters: CropTypeParameters,
    cropland_parameters: CropLandParameters,
) -> List[DataCube]:
    """Produce crop type product. Optionally includes cropland masking.
    Cropland mask final bands saved only if `croptype_parameters.save_mask` is True.

    Parameters
    ----------
    inputs : DataCube
        Preprocessed input cube the inference UDF runs on.
    temporal_extent : TemporalContext
        Temporal range; only used here to build the output filename prefixes.
    croptype_parameters : CropTypeParameters
        Crop type parameters; `mask_cropland` switches between the combined
        (prefixed bands) and single (unprefixed bands) workflow.
    cropland_parameters : CropLandParameters
        Cropland parameters, forwarded to the UDF only when masking is on.

    Returns
    -------
    List[DataCube]
        save_result cubes: final croptype, plus (if present) raw croptype,
        plus (if requested) the cropland mask and its raw bands.
    """
    # Combined workflow: the UDF receives both parameter sets and prefixes
    # output bands with croptype_/cropland_ (see module docstring).
    if croptype_parameters.mask_cropland:
        parameters = {
            "cropland_params": cropland_parameters.model_dump(),
            "croptype_params": croptype_parameters.model_dump(),
        }
    else:
        parameters = croptype_parameters.model_dump()

    inference_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "inference.py",
        context=parameters,
    )
    classes = _run_udf(inputs, inference_udf)
    # Client-side metadata fix to match the UDF's output band labels.
    classes.metadata = apply_metadata(classes.metadata, parameters)
    classes = _reduce_temporal_mean(classes)
    classes = compress_uint16(classes)

    bands = classes.metadata.band_names
    result_cubes: List[DataCube] = []

    if croptype_parameters.mask_cropland:
        # Prefixed croptype final and raw bands
        croptype_final_bands = [
            b for b in bands if b.startswith("croptype_") and "raw" not in b
        ]
        # Raw croptype bands (presence-based)
        raw_croptype_bands = [b for b in bands if b.startswith("croptype_raw_")]
    else:
        # Single workflow: unprefixed croptype bands
        croptype_final_bands = [b for b in bands if not b.startswith("raw_")]
        raw_croptype_bands = [b for b in bands if b.startswith("raw_")]

    # Final croptype; stripping "croptype_" is a no-op in the single workflow.
    croptype_cube = classes.filter_bands(croptype_final_bands).rename_labels(
        dimension="bands",
        target=[
            b.replace("croptype_", "") for b in croptype_final_bands
        ],  # Remove prefix
    )
    result_cubes.append(
        _save_result(
            croptype_cube,
            _filename_prefix(WorldCerealProductType.CROPTYPE, temporal_extent),
        )
    )

    # Raw croptype if present
    if raw_croptype_bands:
        raw_croptype_cube = classes.filter_bands(raw_croptype_bands).rename_labels(
            dimension="bands",
            target=[
                b.replace("croptype_", "") for b in raw_croptype_bands
            ],  # Remove prefix
        )
        result_cubes.append(
            _save_result(
                raw_croptype_cube,
                _filename_prefix(
                    WorldCerealProductType.CROPTYPE, temporal_extent, raw=True
                ),
            )
        )

    # Optional cropland mask & raw cropland bands.
    # NOTE(review): this assumes save_mask implies mask_cropland — without
    # masking there are no cropland_ bands to filter; confirm upstream
    # parameter validation enforces that combination.
    if croptype_parameters.save_mask:
        cropland_final_bands = [
            b
            for b in bands
            if b.startswith("cropland_") and not b.startswith("cropland_raw_")
        ]
        cropland_cube = classes.filter_bands(cropland_final_bands).rename_labels(
            dimension="bands",
            target=[
                b.replace("cropland_", "") for b in cropland_final_bands
            ],  # Remove prefix
        )
        result_cubes.append(
            _save_result(
                cropland_cube,
                _filename_prefix(WorldCerealProductType.CROPLAND, temporal_extent),
            )
        )
        raw_cropland_bands = [b for b in bands if b.startswith("cropland_raw_")]
        if raw_cropland_bands:
            raw_cropland_cube = classes.filter_bands(raw_cropland_bands).rename_labels(
                dimension="bands",
                target=[
                    b.replace("cropland_", "") for b in raw_cropland_bands
                ],  # Remove prefix
            )
            result_cubes.append(
                _save_result(
                    raw_cropland_cube,
                    _filename_prefix(
                        WorldCerealProductType.CROPLAND, temporal_extent, raw=True
                    ),
                )
            )

    return result_cubes
227
+
228
+
229
def _embeddings_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,  # unused; kept for signature consistency
    embeddings_parameters: EmbeddingsParameters,
    scale_uint16: bool = True,
) -> DataCube:
    """Produce embeddings map using Prometheo feature extractor.

    Features are extracted per neighborhood, averaged over time and — unless
    ``scale_uint16`` is disabled — shifted/scaled into the [0, 65534] range.
    """
    feature_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "feature_extractor.py",
        context=embeddings_parameters.feature_parameters.model_dump(),
    )
    cube = _reduce_temporal_mean(_run_udf(inputs, feature_udf))

    if scale_uint16:
        # Fixed affine rescaling so embeddings fit a uint16 range, then clamp.
        offset = -6
        scale = 0.0002
        cube = (cube - offset) / scale
        cube = cube.linear_scale_range(0, 65534, 0, 65534)

    return cube
worldcereal/openeo/preprocessing.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Any, Dict, List, Literal, Optional, Union
3
+
4
+ import pandas as pd
5
+ from geojson import GeoJSON
6
+ from openeo import UDF, Connection, DataCube
7
+ from openeo_gfmap import (
8
+ Backend,
9
+ BackendContext,
10
+ BoundingBoxExtent,
11
+ FetchType,
12
+ SpatialContext,
13
+ TemporalContext,
14
+ )
15
+ from openeo_gfmap.fetching.generic import build_generic_extractor
16
+ from openeo_gfmap.fetching.s1 import build_sentinel1_grd_extractor
17
+ from openeo_gfmap.fetching.s2 import build_sentinel2_l2a_extractor
18
+ from openeo_gfmap.preprocessing.compositing import mean_compositing, median_compositing
19
+ from openeo_gfmap.preprocessing.sar import compress_backscatter_uint16
20
+ from openeo_gfmap.utils.catalogue import UncoveredS1Exception, select_s1_orbitstate_vvvh
21
+
22
+ WORLDCEREAL_S2_BANDS = [
23
+ "S2-L2A-B02",
24
+ "S2-L2A-B03",
25
+ "S2-L2A-B04",
26
+ "S2-L2A-B05",
27
+ "S2-L2A-B06",
28
+ "S2-L2A-B07",
29
+ "S2-L2A-B08",
30
+ "S2-L2A-B8A",
31
+ "S2-L2A-B11",
32
+ "S2-L2A-B12",
33
+ ]
34
+
35
+ WORLDCEREAL_S1_BANDS = [
36
+ "S1-SIGMA0-VH",
37
+ "S1-SIGMA0-VV",
38
+ ]
39
+
40
+ WORLDCEREAL_DEM_BANDS = ["elevation", "slope"]
41
+
42
+ WORLDCEREAL_METEO_BANDS = ["AGERA5-PRECIP", "AGERA5-TMEAN"]
43
+
44
+ WORLDCEREAL_BANDS = {
45
+ "SENTINEL2": WORLDCEREAL_S2_BANDS,
46
+ "SENTINEL1": WORLDCEREAL_S1_BANDS,
47
+ "DEM": WORLDCEREAL_DEM_BANDS,
48
+ "METEO": WORLDCEREAL_METEO_BANDS,
49
+ }
50
+
51
+
52
class InvalidTemporalContextError(Exception):
    """Raised when a temporal context violates the WorldCereal compositing
    requirements (see `_validate_temporal_context`)."""

    pass
54
+
55
+
56
def spatially_filter_cube(
    connection: Connection, cube: DataCube, spatial_extent: Optional[SpatialContext]
) -> DataCube:
    """
    Apply spatial filtering to a data cube based on the given spatial extent.


    Parameters
    ----------
    connection : Connection
        The connection object used to interact with the openEO backend.
    cube : DataCube
        The input data cube to be spatially filtered.
    spatial_extent : Optional[SpatialContext]
        The spatial extent used for filtering the data cube. It can be a BoundingBoxExtent,
        a GeoJSON object, or a URL to a GeoJSON or Parquet file. If set to `None`,
        no spatial filtering will be applied.

    Returns
    -------
    DataCube
        The spatially filtered data cube.

    Raises
    ------
    ValueError
        If the spatial_extent parameter is not of type BoundingBoxExtent, GeoJSON,
        str or None.

    """
    # None explicitly means "no spatial filtering".
    if spatial_extent is None:
        return cube

    if isinstance(spatial_extent, BoundingBoxExtent):
        cube = cube.filter_bbox(dict(spatial_extent))
    elif isinstance(spatial_extent, GeoJSON):
        cube = cube.filter_spatial(spatial_extent)
    elif isinstance(spatial_extent, str):
        # URL to a vector file; pick the load format from the extension.
        geometry = connection.load_url(
            spatial_extent,
            format=(
                "Parquet"
                if ".parquet" in spatial_extent or ".geoparquet" in spatial_extent
                else "GeoJSON"
            ),
        )
        cube = cube.filter_spatial(geometry)
    else:
        # Previously unsupported extent types were silently ignored, leaving
        # the cube unfiltered; raise as the docstring promises so a
        # misconfigured extent surfaces immediately.
        raise ValueError(
            "spatial_extent must be a BoundingBoxExtent, GeoJSON, str or None; "
            f"got {type(spatial_extent)!r}"
        )

    return cube
101
+
102
+
103
def select_best_s1_orbit_direction(
    backend_context: BackendContext,
    spatial_extent: SpatialContext,
    temporal_extent: TemporalContext,
) -> str:
    """Pick the Sentinel-1 orbit direction best covering the given context.

    Queries the catalogue for the orbit state with VV+VH coverage; when no
    Sentinel-1 data can be found at all, falls back to "ASCENDING".

    Parameters
    ----------
    backend_context : BackendContext
        The backend context for accessing the data.
    spatial_extent : SpatialContext
        The spatial extent of the data.
    temporal_extent : TemporalContext
        The temporal extent of the data.

    Returns
    -------
    str
        The selected orbit direction (either "ASCENDING" or "DESCENDING").
    """
    try:
        return select_s1_orbitstate_vvvh(
            backend_context, spatial_extent, temporal_extent
        )
    except UncoveredS1Exception as exc:
        print(
            f"Could not find any Sentinel-1 data for the given spatio-temporal context. "
            f"Using ASCENDING orbit direction as a last resort. Error: {exc}"
        )
        return "ASCENDING"
136
+
137
+
138
def raw_datacube_S2(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    bands: List[str],
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
    filter_tile: Optional[str] = None,
    distance_to_cloud_flag: Optional[bool] = True,
    additional_masks_flag: Optional[bool] = True,
    apply_mask_flag: Optional[bool] = False,
    tile_size: Optional[int] = None,
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Extract Sentinel-2 datacube from OpenEO using GFMAP routines.
    Raw data is extracted with no cloud masking applied by default (can be
    enabled by setting `apply_mask_flag=True`). In addition to the raw band
    values a cloud-mask computed from the dilation of the SCL layer, as well
    as a distance-to-cloud score, can be attached.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context to use for extraction.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    bands : List[str]
        List of Sentinel-2 bands to extract.
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        Spatial context to extract data from, can be a GFMAP BoundingBoxExtent,
        a GeoJSON dict or an URL to a publicly accessible GeoParquet file.
        NOTE(review): currently not referenced in this function body — spatial
        filtering appears to happen elsewhere; confirm before relying on it.
    filter_tile : Optional[str], optional
        Filter by tile ID, by default disabled. This forces the process to only
        one tile ID from the Sentinel-2 collection.
    distance_to_cloud_flag : Optional[bool], optional
        Compute a distance-to-cloud band from the SCL layer, by default True.
    additional_masks_flag : Optional[bool], optional
        Pre-merge the SCL dilated mask (and distance-to-cloud, if computed)
        into the extracted cube, by default True.
    apply_mask_flag : Optional[bool], optional
        Apply cloud masking, by default False. Can be enabled for high
        optimization of memory usage.
    tile_size : Optional[int], optional
        Processing tile size feature flag forwarded to the backend.
    target_epsg : Optional[int], optional
        Target EPSG to resample the data, by default None.
    """
    # Extract the SCL collection only; limit to scenes with <= 95% cloud cover.
    scl_cube_properties = {"eo:cloud_cover": lambda val: val <= 95.0}
    if filter_tile:
        scl_cube_properties["tileId"] = lambda val: val == filter_tile

    # Create the job to extract S2 (same cloud-cover property as the SCL cube).
    extraction_parameters: dict[str, Any] = {
        "target_resolution": 10,
        "target_crs": target_epsg,
        "load_collection": {
            "eo:cloud_cover": lambda val: val <= 95.0,
        },
    }

    scl_cube = connection.load_collection(
        collection_id="SENTINEL2_L2A",
        bands=["SCL"],
        temporal_extent=[temporal_extent.start_date, temporal_extent.end_date],
        properties=scl_cube_properties,
    )

    # Resample to 10m resolution for the SCL layer, using optional target_epsg
    scl_cube = scl_cube.resample_spatial(projection=target_epsg, resolution=10)

    # Compute the SCL dilation mask (backend-side process)
    scl_dilated_mask = scl_cube.process(
        "to_scl_dilation_mask",
        data=scl_cube,
        scl_band_name="SCL",
        kernel1_size=17,  # 17px dilation on a 10m layer
        kernel2_size=77,  # 77px dilation on a 10m layer
        mask1_values=[2, 4, 5, 6, 7],
        mask2_values=[3, 8, 9, 10, 11],
        erosion_kernel_size=3,
    ).rename_labels("bands", ["S2-L2A-SCL_DILATED_MASK"])

    additional_masks = scl_dilated_mask

    if distance_to_cloud_flag:
        # Compute the distance to cloud and add it to the cube
        distance_to_cloud = scl_cube.apply_neighborhood(
            process=UDF.from_file(Path(__file__).parent / "udf_distance_to_cloud.py"),
            size=[
                {"dimension": "x", "unit": "px", "value": 256},
                {"dimension": "y", "unit": "px", "value": 256},
                {"dimension": "t", "unit": "null", "value": "P1D"},
            ],
            overlap=[
                {"dimension": "x", "unit": "px", "value": 16},
                {"dimension": "y", "unit": "px", "value": 16},
            ],
        ).rename_labels("bands", ["S2-L2A-DISTANCE-TO-CLOUD"])

        additional_masks = scl_dilated_mask.merge_cubes(distance_to_cloud)

    if additional_masks_flag:
        extraction_parameters["pre_merge"] = additional_masks

    if filter_tile:
        extraction_parameters["load_collection"]["tileId"] = (
            lambda val: val == filter_tile
        )

    if tile_size is not None:
        extraction_parameters["update_arguments"] = {
            "featureflags": {"tilesize": tile_size}
        }

    s2_cube = build_sentinel2_l2a_extractor(
        backend_context,
        bands=bands,
        fetch_type=fetch_type,
        **extraction_parameters,
    ).get_cube(connection, None, temporal_extent)

    # Optional cloud masking with the dilated SCL mask.
    if apply_mask_flag:
        s2_cube = s2_cube.mask(scl_dilated_mask)

    return s2_cube
261
+
262
+
263
def raw_datacube_S1(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    bands: List[str],
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
    target_resolution: float = 20.0,
    orbit_direction: Optional[str] = None,
    tile_size: Optional[int] = None,
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Extract a Sentinel-1 GRD datacube from OpenEO via the GFMAP extractor.

    Only dual-polarisation (VV&VH) products are loaded; optionally restricted
    to a single orbit direction.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context to use for extraction.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    bands : List[str]
        List of Sentinel-1 bands to extract.
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        Spatial context to extract data from, can be a GFMAP BoundingBoxExtent,
        a GeoJSON dict or an URL to a publicly accessible GeoParquet file.
    target_resolution : float, optional
        Target resolution to resample the data to, by default 20.0.
    orbit_direction : Optional[str], optional
        Orbit direction to filter the data, by default None (no filter).
    tile_size : Optional[int], optional
        Processing tile size feature flag forwarded to the backend.
    target_epsg : Optional[int], optional
        Target EPSG to resample the data to, by default None.
    """
    params: Dict[str, Any] = {
        "target_resolution": target_resolution,
        "target_crs": target_epsg,
    }

    # Restrict the collection to VV&VH products; pin the orbit state only
    # when one was explicitly requested.
    if orbit_direction is not None:
        params["load_collection"] = {
            "sat:orbit_state": lambda orbit: orbit == orbit_direction,
            "polarisation": lambda pol: pol == "VV&VH",
        }
    else:
        params["load_collection"] = {
            "polarisation": lambda pol: pol == "VV&VH",
        }

    if tile_size is not None:
        params["update_arguments"] = {"featureflags": {"tilesize": tile_size}}

    extractor = build_sentinel1_grd_extractor(
        backend_context, bands=bands, fetch_type=fetch_type, **params
    )
    return extractor.get_cube(connection, None, temporal_extent)
324
+
325
+
326
def raw_datacube_DEM(
    connection: Connection,
    backend_context: BackendContext,
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
) -> DataCube:
    """Method to get the DEM datacube from the backend.
    If running on CDSE backend, the slope is also loaded from the global
    slope collection and merged with the DEM cube.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context; decides whether slope can be loaded (CDSE only).
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        NOTE(review): accepted for API symmetry but not referenced in this
        function body; confirm spatial filtering happens downstream.

    Returns
    -------
    DataCube
        openEO datacube with the DEM data (and slope if available).
    """

    # DEM is a static collection: no temporal context passed to get_cube.
    extractor = build_generic_extractor(
        backend_context=backend_context,
        bands=["COP-DEM"],
        fetch_type=fetch_type,
        collection_name="COPERNICUS_30",
    )

    cube = extractor.get_cube(connection, None, None)
    cube = cube.rename_labels(dimension="bands", target=["elevation"])

    if backend_context.backend in [Backend.CDSE, Backend.CDSE_STAGING]:
        # On CDSE we can load the slope from a global slope collection
        slope = connection.load_stac(
            "https://stac.openeo.vito.be/collections/COPERNICUS30_DEM_SLOPE",
            bands=["Slope"],
        ).rename_labels(dimension="bands", target=["slope"])
        # Client fix for CDSE, the openeo client might be unsynchronized with
        # the backend.
        if "t" not in slope.metadata.dimension_names():
            slope.metadata = slope.metadata.add_dimension("t", "2020-01-01", "temporal")
        # Collapse the (artificial) time dimension of the static slope layer.
        slope = slope.min_time()

        # Note that when slope is available we use it as the base cube
        # to merge DEM with, as it comes at 20m resolution.
        cube = slope.merge_cubes(cube)

    return cube
369
+
370
+
371
def raw_datacube_METEO(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
) -> DataCube:
    """Fetch the raw AGERA5 meteo datacube (mean temperature and precipitation)
    through the generic GFMAP extractor."""
    meteo_extractor = build_generic_extractor(
        backend_context=backend_context,
        bands=["AGERA5-TMEAN", "AGERA5-PRECIP"],
        fetch_type=fetch_type,
        collection_name="AGERA5",
    )
    return meteo_extractor.get_cube(connection, None, temporal_extent)
388
+
389
+
390
def precomposited_datacube_METEO(
    connection: Connection,
    temporal_extent: TemporalContext,
    compositing_window: Literal["month", "dekad"] = "month",
) -> DataCube:
    """Extract the precipitation and temperature AGERA5 data from a
    pre-composited and pre-processed collection. The data is stored in the
    CloudFerro S3 storage, allowing faster access and processing from the CDSE
    backend.

    Limitations:
        - Only monthly and dekadal composited data is available.
        - Only two bands are available: precipitation-flux and temperature-mean.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    compositing_window : Literal["month", "dekad"], optional
        Compositing period of the precomposited collection, by default "month".

    Returns
    -------
    DataCube
        Datacube with bands renamed to AGERA5-PRECIP and AGERA5-TMEAN.

    Raises
    ------
    ValueError
        If an unsupported compositing window is requested.
    """
    temporal_extent = [temporal_extent.start_date, temporal_extent.end_date]

    # Map the compositing window to the matching precomposited STAC collection.
    if compositing_window == "month":
        collection_url = "https://stac.openeo.vito.be/collections/agera5_monthly"
    elif compositing_window == "dekad":
        collection_url = "https://stac.openeo.vito.be/collections/agera5_dekad"
    else:
        # Previously an unsupported window fell through and caused an
        # UnboundLocalError on `cube` below; fail explicitly instead.
        raise ValueError(
            f'Unsupported compositing window "{compositing_window}"; '
            'must be "month" or "dekad".'
        )

    cube = connection.load_stac(
        url=collection_url,
        temporal_extent=temporal_extent,
        bands=["precipitation-flux", "temperature-mean"],
    )

    # cube.result_node().update_arguments(featureflags={"tilesize": 1})
    # Rename to the band names expected by the WorldCereal pipeline.
    cube = cube.rename_labels(
        dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"]
    )

    return cube
427
+
428
+
429
def worldcereal_preprocessed_inputs(
    connection: Connection,
    backend_context: BackendContext,
    spatial_extent: Union[GeoJSON, BoundingBoxExtent, str],
    temporal_extent: TemporalContext,
    fetch_type: Optional[FetchType] = FetchType.TILE,
    disable_meteo: bool = False,
    validate_temporal_context: bool = True,
    s1_orbit_state: Optional[str] = None,
    tile_size: Optional[int] = None,
    s2_tile: Optional[str] = None,
    compositing_window: Literal["month", "dekad"] = "month",
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Assemble the full preprocessed WorldCereal input datacube.

    Extracts Sentinel-2 (cloud-masked, median-composited), Sentinel-1
    (mean-composited, uint16-compressed backscatter), DEM (elevation/slope
    resampled onto the S2 grid) and — unless ``disable_meteo`` — precomposited
    AGERA5 meteo data, and merges everything into one cube.

    Most parameters are forwarded to the individual extractors; see
    `raw_datacube_S2`, `raw_datacube_S1`, `raw_datacube_DEM` and
    `precomposited_datacube_METEO` for their meaning.

    NOTE(review): `spatial_extent` is not applied in this function body;
    it is only used to pick the best S1 orbit direction. Confirm spatial
    filtering is applied by the caller (e.g. via `spatially_filter_cube`).
    """
    # First validate the temporal context
    if validate_temporal_context:
        _validate_temporal_context(temporal_extent)

    # See if requested compositing method is supported
    assert compositing_window in [
        "month",
        "dekad",
    ], 'Compositing window must be either "month" or "dekad"'

    # Extraction of S2 from GFMAP; distance-to-cloud is skipped for point
    # extractions.
    s2_data = raw_datacube_S2(
        connection=connection,
        backend_context=backend_context,
        temporal_extent=temporal_extent,
        bands=WORLDCEREAL_S2_BANDS,
        fetch_type=fetch_type,
        filter_tile=s2_tile,
        distance_to_cloud_flag=False if fetch_type == FetchType.POINT else True,
        additional_masks_flag=False,
        apply_mask_flag=True,
        tile_size=tile_size,
        target_epsg=target_epsg,
    )

    s2_data = median_compositing(s2_data, period=compositing_window)

    # Cast to uint16
    s2_data = s2_data.linear_scale_range(0, 65534, 0, 65534)

    # Extraction of the S1 data
    # Decides on the orbit direction from the maximum overlapping area of
    # available products.
    if s1_orbit_state is None and backend_context.backend in [
        Backend.CDSE,
        Backend.CDSE_STAGING,
        Backend.FED,
    ]:
        s1_orbit_state = select_best_s1_orbit_direction(
            backend_context, spatial_extent, temporal_extent
        )
    s1_data = raw_datacube_S1(
        connection=connection,
        backend_context=backend_context,
        temporal_extent=temporal_extent,
        bands=WORLDCEREAL_S1_BANDS,
        fetch_type=fetch_type,
        target_resolution=20.0,  # Compute the backscatter at 20m resolution, then upsample nearest neighbor when merging cubes
        orbit_direction=s1_orbit_state,  # If None, make the query on the catalogue for the best orbit
        tile_size=tile_size,
        target_epsg=target_epsg,
    )

    s1_data = mean_compositing(s1_data, period=compositing_window)
    s1_data = compress_backscatter_uint16(backend_context, s1_data)

    dem_data = raw_datacube_DEM(
        connection=connection,
        backend_context=backend_context,
        fetch_type=fetch_type,
    )

    # Explicitly resample DEM with bilinear interpolation and based on S2 grid
    # note: we use s2_data here as base to avoid issues at the edges because source
    # data is not in UTM projection.
    dem_data = dem_data.resample_cube_spatial(s2_data, method="bilinear")

    # Cast DEM to UINT16
    dem_data = dem_data.linear_scale_range(0, 65534, 0, 65534)

    data = s2_data.merge_cubes(s1_data)
    data = data.merge_cubes(dem_data)

    if not disable_meteo:
        meteo_data = precomposited_datacube_METEO(
            connection=connection,
            temporal_extent=temporal_extent,
            compositing_window=compositing_window,
        )

        # Explicitly resample meteo with bilinear interpolation and based on S2 grid
        # note: we use s2_data here as base to avoid issues at the edges because source
        # data is not in UTM projection.
        meteo_data = meteo_data.resample_cube_spatial(s2_data, method="bilinear")

        data = data.merge_cubes(meteo_data)

    return data
531
+
532
+
533
def _validate_temporal_context(temporal_context: TemporalContext) -> None:
    """Check that a temporal context is usable for WorldCereal processing.

    The requested range must start on the first day of a month, end on the
    last day of a month, and span at most one year.

    Parameters
    ----------
    temporal_context : TemporalContext
        temporal context to validate

    Raises
    ------
    InvalidTemporalContextError
        if start_date is not on the first day of a month or end_date
        is not on the last day of a month or the span is more than
        one year.
    """
    start_date, end_date = temporal_context.to_datetime()

    starts_on_first_day = start_date == start_date.replace(day=1)
    ends_on_last_day = end_date == end_date + pd.offsets.MonthEnd(0)

    if not (starts_on_first_day and ends_on_last_day):
        raise InvalidTemporalContextError(
            "WorldCereal uses monthly compositing. For this to work properly, "
            "requested temporal range should start and end at the first and last "
            "day of a month. Instead, got: "
            f"{temporal_context.start_date} - {temporal_context.end_date}. "
            "You may use `worldcereal.preprocessing.correct_temporal_context()` "
            "to correct the temporal context."
        )

    if pd.Timedelta(end_date - start_date).days > 365:
        raise InvalidTemporalContextError(
            "WorldCereal currently does not support temporal ranges spanning "
            "more than a year. Got: "
            f"{temporal_context.start_date} - {temporal_context.end_date}."
        )
574
+
575
+
576
def correct_temporal_context(temporal_context: TemporalContext) -> TemporalContext:
    """Snap a temporal context to full-month boundaries.

    The start date is moved back to the first day of its month and the end
    date forward to the last day of its month, as required by the WorldCereal
    processing.

    Parameters
    ----------
    temporal_context : TemporalContext
        temporal context to correct

    Returns
    -------
    TemporalContext
        corrected temporal context
    """
    start_date, end_date = temporal_context.to_datetime()

    snapped_start = start_date.replace(day=1)
    snapped_end = end_date + pd.offsets.MonthEnd(0)

    return TemporalContext(
        start_date=snapped_start.strftime("%Y-%m-%d"),
        end_date=snapped_end.strftime("%Y-%m-%d"),
    )
worldcereal/openeo/udf_distance_to_cloud.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "scikit-image",
4
+ # ]
5
+ # ///
6
+
7
+ import numpy as np
8
+ import xarray as xr
9
+ from openeo.udf import XarrayDataCube
10
+ from scipy.ndimage import distance_transform_cdt
11
+ from skimage.morphology import binary_erosion, footprints
12
+
13
+
14
def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
    """Compute a per-pixel distance-to-cloud score from an SCL datacube.

    Expects a single-band cube holding Sentinel-2 SCL values; returns a cube
    of the same (bands, y, x) layout holding the Manhattan (chessboard-free,
    city-block) distance to the nearest (eroded) cloudy pixel.
    """
    cube_array: xr.DataArray = cube.get_array()
    cube_array = cube_array.transpose("bands", "y", "x")

    # Cloud mask: SCL values 8, 9, 10 (8 <= val < 11) or 3.
    # NOTE(review): per the Sentinel-2 SCL convention these correspond to
    # cloud probability / cirrus classes and cloud shadow — confirm.
    clouds: xr.DataArray = np.logical_or(
        np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3
    ).isel(
        bands=0
    )  # type: ignore

    # Calculate the Distance To Cloud score
    # Erode
    er = footprints.disk(3)

    # Define a function to apply binary erosion
    # (returns the INVERTED eroded mask: non-cloud pixels become True, which
    # is the "background" the distance transform measures distance from)
    def erode(image, selem):
        return ~binary_erosion(image, selem)

    # Use apply_ufunc to apply the erosion operation
    eroded = xr.apply_ufunc(
        erode,  # function to apply
        clouds,  # input DataArray
        input_core_dims=[["y", "x"]],  # dimensions over which to apply function
        output_core_dims=[["y", "x"]],  # dimensions of the output
        vectorize=True,  # vectorize the function over non-core dimensions
        dask="parallelized",  # enable dask parallelization
        output_dtypes=[np.int32],  # data type of the output
        kwargs={"selem": er},  # additional keyword arguments to pass to erode
    )

    # Distance to cloud in manhattan distance measure
    distance = xr.apply_ufunc(
        distance_transform_cdt,
        eroded,
        input_core_dims=[["y", "x"]],
        output_core_dims=[["y", "x"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.int32],
    )

    # Re-wrap the raw array with the original spatial coordinates.
    distance_da = xr.DataArray(
        distance,
        coords={
            "y": cube_array.coords["y"],
            "x": cube_array.coords["x"],
        },
        dims=["y", "x"],
    )

    # Restore the bands dimension so the output shape matches the input cube.
    distance_da = distance_da.expand_dims(
        dim={
            "bands": cube_array.coords["bands"],
        },
    )

    distance_da = distance_da.transpose("bands", "y", "x")

    return XarrayDataCube(distance_da)
worldcereal/parameters.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from enum import Enum
3
+ from typing import Optional
4
+
5
+ from pydantic import BaseModel, Field, ValidationError, model_validator
6
+
7
+
8
class WorldCerealProductType(Enum):
    """Identifiers for the WorldCereal products that can be generated."""

    CROPLAND = "cropland"
    CROPTYPE = "croptype"
    EMBEDDINGS = "embeddings"
14
+
15
+
16
class FeaturesParameters(BaseModel):
    """Parameters for the feature extraction UDFs. Types are enforced by
    Pydantic.

    Attributes
    ----------
    rescale_s1 : bool (default=False)
        Whether to rescale Sentinel-1 bands before feature extraction. Should be
        left to False, as this is done in the Presto UDF itself.
    presto_model_url : str
        Public URL to the Presto model used for feature extraction. The file
        should be a PyTorch serialized model.
    compile_presto : bool (default=False)
        Whether to compile the Presto encoder for speeding up large-scale inference.
    temporal_prediction : bool (default=False)
        Whether to use temporal-explicit predictions. If True, the time dimension
        is preserved in Presto features and a specific timestep is selected later.
        If False, features are pooled across time (non-temporal prediction).
    target_date : str (default=None)
        Target date for temporal-explicit predictions in ISO format (YYYY-MM-DD).
        Only used when temporal_prediction=True. If None, the middle timestep is used.
    """

    rescale_s1: bool
    presto_model_url: str
    compile_presto: bool
    temporal_prediction: bool = Field(default=False)
    target_date: Optional[str] = Field(default=None)

    @model_validator(mode="after")
    def check_temporal_parameters(self):
        """Validates temporal prediction parameters.

        Raises
        ------
        ValueError
            If target_date is set without temporal_prediction, or is not a
            valid ISO date. Pydantic wraps this into a ValidationError for
            callers.
        """
        # Raise ValueError, not ValidationError: pydantic v2 validators must
        # raise ValueError/AssertionError — ValidationError cannot be
        # instantiated from a plain message, so the original `raise
        # ValidationError(...)` would itself fail with a TypeError.
        if self.target_date is not None and not self.temporal_prediction:
            raise ValueError(
                "target_date can only be specified when temporal_prediction=True"
            )

        if self.target_date is not None:
            try:
                datetime.fromisoformat(self.target_date)
            except ValueError as exc:
                raise ValueError(
                    "target_date must be in ISO format (YYYY-MM-DD)"
                ) from exc

        return self
60
+
61
+
62
class ClassifierParameters(BaseModel):
    """Parameters for the classifier. Types are enforced by Pydantic.

    Attributes
    ----------
    classifier_url : str
        Public URL to the classifier model. The file should be an ONNX model
        accepting a `features` field for input data and returning either two
        output probability arrays `true` and `false` in case of cropland
        mapping, or a probability array per-class in case of croptype mapping.
    """

    classifier_url: str
75
+
76
+
77
class PostprocessParameters(BaseModel):
    """Parameters controlling the postprocessing step. Types are enforced
    by Pydantic.

    Attributes
    ----------
    enable: bool (default=True)
        Whether to enable postprocessing.
    method: str (default="smooth_probabilities")
        Postprocessing method; one of ["smooth_probabilities", "majority_vote"].
    kernel_size: int (default=5)
        Kernel size for majority-vote postprocessing. Must be odd and in the
        range accepted by the validator (3 up to 25).
    save_intermediate: bool (default=False)
        Whether to save intermediate results (before applying the postprocessing)
        as GeoTiff.
    keep_class_probs: bool (default=True)
        Whether per-class probabilities are included in the final product.
    """

    enable: bool = Field(default=True)
    method: str = Field(default="smooth_probabilities")
    kernel_size: int = Field(default=5)
    save_intermediate: bool = Field(default=False)
    keep_class_probs: bool = Field(default=True)

    @model_validator(mode="after")
    def check_parameters(self):
        """Validate cross-field consistency of the postprocessing settings."""
        # Intermediate results only exist when postprocessing actually runs.
        if self.save_intermediate and not self.enable:
            raise ValueError(
                "Cannot save intermediate results if postprocessing is disabled."
            )

        if self.method not in ("smooth_probabilities", "majority_vote"):
            raise ValueError(
                f"Method must be one of ['smooth_probabilities', 'majority_vote'], got {self.method}"
            )

        # kernel_size is only meaningful for majority voting.
        if self.method == "majority_vote":
            if self.kernel_size % 2 == 0:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number, got {self.kernel_size}"
                )
            if self.kernel_size > 25:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number smaller than 25, got {self.kernel_size}"
                )
            if self.kernel_size < 3:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number larger than 1, got {self.kernel_size}"
                )

        return self
129
+
130
+
131
class BaseParameters(BaseModel):
    """Shared parameter logic for the WorldCereal parameter models."""

    # Each product pipeline carries its own postprocessing configuration;
    # a factory avoids sharing one default instance across models.
    postprocess_parameters: PostprocessParameters = Field(
        default_factory=PostprocessParameters
    )

    @staticmethod
    def create_feature_parameters(**kwargs):
        """Build a FeaturesParameters, filling unspecified fields with defaults."""
        params = {
            "rescale_s1": False,
            "presto_model_url": "",
            "compile_presto": False,
            "temporal_prediction": False,
            "target_date": None,
            **kwargs,
        }
        return FeaturesParameters(**params)

    @staticmethod
    def create_classifier_parameters(classifier_url: str):
        """Wrap a classifier URL into a ClassifierParameters instance."""
        return ClassifierParameters(classifier_url=classifier_url)
153
+
154
+
155
class CropLandParameters(BaseParameters):
    """Parameters for the cropland product inference pipeline. Types are
    enforced by Pydantic.

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    classifier_parameters : ClassifierParameters
        Parameters for the classifier UDF. Will be serialized into a dictionary
        and passed in the process graph.
    """

    @staticmethod
    def _default_feature_parameters() -> FeaturesParameters:
        """Single source of truth for default cropland feature parameters."""
        return BaseParameters.create_feature_parameters(
            rescale_s1=False,
            presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-landcover-MulticlassWithCroplandAuxBCELoss-labelsmoothing=0.05-month-LANDCOVER10-augment=True-balance=True-timeexplicit=False-masking=enabled-run=202510301004_encoder.pt",  # NOQA
            compile_presto=False,
            temporal_prediction=False,
            target_date=None,
        )

    @staticmethod
    def _default_classifier_parameters() -> ClassifierParameters:
        return BaseParameters.create_classifier_parameters(
            classifier_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/PrestoDownstreamCatBoost_temporary-crops_v201-prestorun=202510301004.onnx"  # NOQA
        )

    # Use default_factory (as CropTypeParameters and EmbeddingsParameters do)
    # instead of a single class-level FeaturesParameters instance, so every
    # model gets a fresh default and the three parameter classes stay
    # consistent.
    feature_parameters: FeaturesParameters = Field(
        default_factory=lambda: CropLandParameters._default_feature_parameters()
    )

    classifier_parameters: ClassifierParameters = Field(
        default_factory=lambda: CropLandParameters._default_classifier_parameters()
    )

    def __init__(self, classifier_url: Optional[str] = None, **kwargs):
        """Allow overriding only the classifier URL.

        An explicit `classifier_parameters` in kwargs takes precedence; in
        that case `classifier_url` is ignored.
        """
        if "classifier_parameters" not in kwargs and classifier_url is not None:
            kwargs["classifier_parameters"] = (
                BaseParameters.create_classifier_parameters(
                    classifier_url=classifier_url
                )
            )
        super().__init__(**kwargs)
196
+
197
+
198
class CropTypeParameters(BaseParameters):
    """Parameters for the croptype product inference pipeline. Types are
    enforced by Pydantic.

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    classifier_parameters : ClassifierParameters
        Parameters for the classifier UDF. Will be serialized into a dictionary
        and passed in the process graph.
    mask_cropland : bool (default=True)
        Whether or not to mask the cropland pixels before running crop type inference.
    save_mask : bool (default=False)
        Whether or not to save the cropland mask as an intermediate result.
    """

    @staticmethod
    def _default_feature_parameters() -> FeaturesParameters:
        """Single source of truth for default croptype feature parameters."""
        return BaseParameters.create_feature_parameters(
            rescale_s1=False,
            presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-croptype-with-nocrop-FocalLoss-labelsmoothing%3D0.05-month-CROPTYPE27-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-masking%3Denabled-run%3D202510301004_encoder.pt",  # NOQA
            compile_presto=False,
            temporal_prediction=False,
            target_date=None,  # By default take the middle date
        )

    @staticmethod
    def _default_classifier_parameters() -> ClassifierParameters:
        return BaseParameters.create_classifier_parameters(
            classifier_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/PrestoDownstreamCatBoost_croptype_v201-prestorun%3D202510301004.onnx"
        )

    feature_parameters: FeaturesParameters = Field(
        default_factory=lambda: CropTypeParameters._default_feature_parameters()
    )
    classifier_parameters: ClassifierParameters = Field(
        default_factory=lambda: CropTypeParameters._default_classifier_parameters()
    )
    mask_cropland: bool = Field(default=True)
    save_mask: bool = Field(default=False)

    def __init__(
        self,
        target_date: Optional[str] = None,
        classifier_url: Optional[str] = None,
        **kwargs,
    ):
        """Allow overriding the target date and/or classifier URL without
        rebuilding the full parameter objects.

        Explicit `feature_parameters` / `classifier_parameters` in kwargs
        take precedence over the shortcut arguments.
        """
        # Override feature target_date if feature_parameters not supplied.
        # NOTE(review): assigning target_date on a model_copy bypasses the
        # FeaturesParameters validator — confirm whether target_date without
        # temporal_prediction should be rejected here as well.
        if "feature_parameters" not in kwargs:
            fp = self._default_feature_parameters().model_copy()
            fp.target_date = target_date  # type: ignore[attr-defined]
            kwargs["feature_parameters"] = fp
        # Override classifier URL if classifier_parameters not supplied
        if "classifier_parameters" not in kwargs and classifier_url is not None:
            kwargs["classifier_parameters"] = (
                BaseParameters.create_classifier_parameters(
                    classifier_url=classifier_url
                )
            )
        super().__init__(**kwargs)

    @model_validator(mode="after")
    def check_mask_parameters(self):
        """Validates the mask-related parameters.

        Raises
        ------
        ValueError
            If save_mask is requested while cropland masking is disabled.
            Pydantic wraps this into a ValidationError for callers.
        """
        # Raise ValueError, not ValidationError: pydantic v2 validators must
        # raise ValueError/AssertionError — ValidationError cannot be built
        # from a plain message, so the original `raise ValidationError(...)`
        # would itself fail with a TypeError. This also matches
        # PostprocessParameters.check_parameters.
        if not self.mask_cropland and self.save_mask:
            raise ValueError("Cannot save mask if mask_cropland is disabled.")
        return self
268
+
269
+
270
class EmbeddingsParameters(BaseParameters):
    """Parameters for the embeddings product inference pipeline. Types are
    enforced by Pydantic.

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    """

    @staticmethod
    def _default_feature_parameters() -> FeaturesParameters:
        """Default feature-extraction parameters, declared in one place."""
        return BaseParameters.create_feature_parameters(
            rescale_s1=False,
            presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-landcover-month-LANDCOVER10-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-run%3D202507170930_encoder.pt",  # NOQA
            compile_presto=False,
            temporal_prediction=False,
            target_date=None,
        )

    feature_parameters: FeaturesParameters = Field(
        # Lambda wrapper so pydantic gets a genuine zero-argument callable.
        default_factory=lambda: EmbeddingsParameters._default_feature_parameters()
    )

    def __init__(self, presto_model_url: Optional[str] = None, **kwargs):
        """Optionally swap in a custom Presto model URL.

        An explicit `feature_parameters` in kwargs overrides everything,
        in which case `presto_model_url` is ignored.
        """
        if presto_model_url is not None and "feature_parameters" not in kwargs:
            custom_fp = self._default_feature_parameters().model_copy()
            custom_fp.presto_model_url = presto_model_url  # type: ignore[attr-defined]
            kwargs["feature_parameters"] = custom_fp
        super().__init__(**kwargs)
worldcereal/utils/models.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities around models for the WorldCereal package."""
2
+
3
+ import json
4
+ from functools import lru_cache
5
+
6
+ import onnxruntime as ort
7
+ import requests
8
+
9
+
10
@lru_cache(maxsize=2)
def load_model_onnx(model_url) -> ort.InferenceSession:
    """Load an ONNX model from a URL.

    Results are cached (two most recent URLs) so repeated lookups do not
    re-download the model.

    Parameters
    ----------
    model_url: str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    requests.HTTPError
        If the model could not be downloaded.
    """
    # Two minutes timeout to download the model
    response = requests.get(model_url, timeout=120)
    # Fail fast on HTTP errors: without this check, an error page or partial
    # payload would be handed to ONNX runtime, producing a confusing parse
    # error (and the failure would be cached by lru_cache).
    response.raise_for_status()

    return ort.InferenceSession(response.content)
29
+
30
+
31
def validate_cb_model(model_url: str) -> ort.InferenceSession:
    """Validate a catboost model by loading it and checking if the required
    metadata is present. Checks that the `class_names` and `class_to_label`
    fields are present in the `class_params` field of the custom metadata of
    the model. By default, the CatBoost module should include those fields
    when exporting a model to ONNX.

    Raises an exception if the model is not valid.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    ValueError
        If the required metadata fields are missing from the model.
    """
    model = load_model_onnx(model_url=model_url)

    metadata = model.get_modelmeta().custom_metadata_map

    # Fixed error message: this branch is about the whole `class_params`
    # blob being absent, not specifically about class names.
    if "class_params" not in metadata:
        raise ValueError("Could not find class parameters in the model metadata.")

    class_params = json.loads(metadata["class_params"])

    if "class_names" not in class_params:
        raise ValueError("Could not find class names in the model metadata.")

    if "class_to_label" not in class_params:
        raise ValueError("Could not find class to labels in the model metadata.")

    return model
66
+
67
+
68
def load_model_lut(model_url: str) -> dict:
    """Load the class names to labels mapping from a CatBoost model.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    dict
        Look-up table mapping class names to labels, ordered by label value.
    """
    # Validation guarantees the metadata fields used below are present.
    session = validate_cb_model(model_url=model_url)
    class_params = json.loads(
        session.get_modelmeta().custom_metadata_map["class_params"]
    )

    pairs = zip(class_params["class_names"], class_params["class_to_label"])
    # Sort by label value so downstream consumers see a deterministic order.
    return dict(sorted(pairs, key=lambda pair: pair[1]))