Install dependencies
Browse files- .gitignore +4 -0
- Dockerfile +19 -0
- README.md +3 -0
- app.py +413 -0
- openeo_gfmap/__init__.py +23 -0
- openeo_gfmap/backend.py +122 -0
- openeo_gfmap/fetching.py +98 -0
- openeo_gfmap/metadata.py +24 -0
- openeo_gfmap/spatial.py +53 -0
- openeo_gfmap/temporal.py +22 -0
- pyproject.toml +16 -0
- uv.lock +0 -0
- worldcereal/__init__.py +39 -0
- worldcereal/_version.py +3 -0
- worldcereal/job.py +960 -0
- worldcereal/openeo/__init__.py +0 -0
- worldcereal/openeo/feature_extractor.py +582 -0
- worldcereal/openeo/inference.py +1191 -0
- worldcereal/openeo/mapping.py +250 -0
- worldcereal/openeo/preprocessing.py +599 -0
- worldcereal/openeo/udf_distance_to_cloud.py +72 -0
- worldcereal/parameters.py +314 -0
- worldcereal/utils/models.py +87 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
.idea/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim

# Bring in the uv package manager from its official distribution image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app

# Compile bytecode at install time; copy (not hardlink) out of the uv cache.
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy

# Install dependencies first so this layer is cached independently of app code.
COPY pyproject.toml uv.lock ./

RUN uv sync --frozen --no-cache --no-install-project

COPY . .

# FIX: `uv sync` installs into /app/.venv — put it on PATH so `streamlit`
# (and any other console script) resolves at container start.
ENV PATH="/app/.venv/bin:${PATH}"

EXPOSE 8501
# FIX: python:*-slim does not ship curl; probe the Streamlit health endpoint
# with the interpreter that is guaranteed to be present instead.
HEALTHCHECK CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8501/_stcore/health')"

# FIX: was ["steamlit", ...] — a typo that made the container fail to start.
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
CHANGED
|
@@ -9,3 +9,6 @@ short_description: Application for Automatic Crop Type Mapping
|
|
| 9 |
---
|
| 10 |
|
| 11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 12 |
+
|
| 13 |
+
# License
|
| 14 |
+
This project is licensed under the terms of the MIT License. See the LICENSE file for details.
|
app.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import folium
|
| 3 |
+
from folium import plugins
|
| 4 |
+
from streamlit_folium import st_folium
|
| 5 |
+
import rasterio
|
| 6 |
+
from rasterio.warp import calculate_default_transform, reproject, Resampling
|
| 7 |
+
import joblib
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import geopandas as gpd
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from matplotlib import colors as colors
|
| 13 |
+
import time
|
| 14 |
+
from rasterio.crs import CRS
|
| 15 |
+
|
| 16 |
+
from worldcereal.job import INFERENCE_JOB_OPTIONS, create_embeddings_process_graph
|
| 17 |
+
from openeo_gfmap import TemporalContext, BoundingBoxExtent
|
| 18 |
+
from worldcereal.parameters import EmbeddingsParameters
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Crop class -> display colour (hex). Keys are Polish crop names
# ("inne" = other); insertion order below defines the integer class ids.
crop_classes = {
    "tuz": "#32cd32",
    "burak": "#8b008b",
    "jęczmień": "#ffd700",
    "kukurydza": "#ffa500",
    "lucerna": "#9acd32",
    "mieszanka": "#daa520",
    "owies": "#f0e68c",
    "pszenica": "#f5deb3",
    "pszenżyto": "#bdb76b",
    "rzepak": "#ffff00",
    "sad": "#228b22",
    "słonecznik": "#ff4500",
    "ziemniak": "#a0522d",
    "łubin": "#9370db",
    "żyto": "#cd853f",
    "inne": "#808080"
}
# Stable id <-> name lookups derived from dict insertion order (3.7+ preserves it).
class_to_id = {name: i for i, name in enumerate(crop_classes.keys())}
id_to_class = {i: name for name, i in class_to_id.items()}

st.set_page_config(page_title="Crop Map", layout="wide")

# NOTE(review): paths are relative to the process working directory
# (/app in the Docker image) — confirm the bundled assets live there.
model_path = Path("app/crop_map_app/models/random_forest_crop_classifier_06.joblib")
demo_dir = Path("app/crop_map_app/embeddings/demo")
temp_dir = Path("embeddings/temp_analysis")  # for new files downloaded from openEO
temp_dir.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
def get_class_color_rgba(class_name, alpha=180):
    """Return the (R, G, B, A) tuple (0-255 channels) for a crop class.

    Unknown class names fall back to black; ``alpha`` is passed through
    unchanged as the fourth channel.
    """
    r, g, b = colors.hex2color(crop_classes.get(class_name, "#000000"))
    return (int(r * 255), int(g * 255), int(b * 255), alpha)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def create_legend_html(stats_legend):
    """Render an inline-HTML legend for the classified crops.

    ``stats_legend`` is a DataFrame with 'Crop' and 'Percentage' columns;
    each entry becomes a colour swatch, the crop name and its percentage.
    """
    pieces = [
        "<div style='background-color: rgba(255, 255, 255, 0.1); padding: 10px; border-radius: 5px; font-family: sans-serif;'>"
    ]
    for _, entry in stats_legend.iterrows():
        swatch = crop_classes.get(entry['Crop'], "#000000")
        pieces.append(
            f"<div style='display: flex; align-items: center; margin-bottom: 4px;'>"
            f"<div style='width: 15px; height: 15px; background-color: {swatch}; margin-right: 10px; border-radius: 3px;'></div>"
            f"<span style='font-size: 14px; flex-grow: 1;'>{entry['Crop']}</span>"
            f"<span style='font-weight: bold; font-size: 14px;'>{entry['Percentage']:.1f}%</span>"
            f"</div>"
        )
    pieces.append("</div>")
    return "".join(pieces)
|
| 76 |
+
|
| 77 |
+
@st.cache_resource
def load_model():
    """Load the crop classifier once per server process (Streamlit-cached).

    Returns None when the model file is missing so callers can surface an
    error instead of crashing.
    """
    if model_path.exists():
        return joblib.load(model_path)
    return None
|
| 81 |
+
|
| 82 |
+
@st.cache_data
def run_prediction(tif_path, _model):
    """Classify every pixel of an embedding GeoTIFF and reproject to EPSG:4326.

    Parameters
    ----------
    tif_path:
        Path to a multi-band embedding raster (bands = feature channels).
    _model:
        Fitted classifier with a ``predict`` method; the leading underscore
        tells Streamlit's cache not to hash it.

    Returns
    -------
    (class_map, folium_bounds): a (H, W) uint8 array of class ids in
    EPSG:4326, and [[south, west], [north, east]] bounds for folium.
    """
    with rasterio.open(tif_path) as src:
        embedding = src.read()
        src_transform = src.transform
        src_crs = src.crs
        h, w = src.height, src.width

    # Flatten (bands, H, W) -> (H*W, bands) so each row is one pixel sample.
    n_channels = embedding.shape[0]
    reshaped = embedding.transpose(1, 2, 0).reshape(-1, n_channels)

    # prediction — batched to bound peak memory; NaNs are zero-filled.
    # NOTE(review): assumes the model was trained with the same NaN handling — confirm.
    batch_size = 50000
    preds = []
    for i in range(0, reshaped.shape[0], batch_size):
        batch = reshaped[i:i + batch_size]
        batch = np.nan_to_num(batch)
        preds.append(_model.predict(batch))

    # The model predicts string labels; stitch the batches back into the grid.
    raw_class_map_str = np.concatenate(preds).reshape(h, w)

    # Map string labels to compact uint8 ids.
    # NOTE(review): pixels whose label is not in class_to_id stay 0, which is
    # also the id of the first crop class — confirm that collision is acceptable.
    raw_class_map_int = np.zeros((h, w), dtype=np.uint8)
    for class_name, class_id in class_to_id.items():
        raw_class_map_int[raw_class_map_str == class_name] = class_id

    src_crs_str = src_crs.to_string()
    dst_crs = CRS.from_string('EPSG:4326')

    # Reproject the id map to lat/lon so folium can overlay it directly.
    left, bottom, right, top = rasterio.transform.array_bounds(h, w, src_transform)
    transform, dst_width, dst_height = calculate_default_transform(
        src_crs_str, dst_crs, w, h, left=left, bottom=bottom, right=right, top=top
    )

    destination = np.zeros((dst_height, dst_width), dtype=np.uint8)
    reproject(
        source=raw_class_map_int,
        destination=destination,
        src_transform=src_transform,
        src_crs=src_crs_str,
        dst_transform=transform,
        dst_crs=dst_crs,
        resampling=Resampling.nearest  # nearest keeps the class ids categorical
    )

    # array_bounds returns (west, south, east, north);
    # folium wants [[south, west], [north, east]].
    bounds_orig = rasterio.transform.array_bounds(dst_height, dst_width, transform)
    folium_bounds = [[bounds_orig[1], bounds_orig[0]], [bounds_orig[3], bounds_orig[2]]]

    return destination, folium_bounds
|
| 130 |
+
|
| 131 |
+
def run_openeo_job(lat, lon, size_km=1.0):
    """
    Runs WorldCereal job for a small box around lat/lon.
    Returns path to downloaded tif or None.

    The box is ``size_km`` wide in degrees-approximation (1 deg ~ 111 km);
    progress/errors are reported through Streamlit widgets, so this must be
    called from within a running Streamlit session.
    """
    try:
        # Half-width of the box in degrees (rough equirectangular conversion).
        offset = (size_km / 111) / 2
        west, east = lon - offset, lon + offset
        south, north = lat - offset, lat + offset

        spatial_extent = BoundingBoxExtent(
            west=west, south=south, east=east, north=north, epsg=4326
        )

        # changing time range
        temporal_extent = TemporalContext("2025-01-01", "2025-12-31")

        st.info("Building OpenEO Process Graph...")
        embedding_params = EmbeddingsParameters()
        inference_result = create_embeddings_process_graph(
            spatial_extent=spatial_extent,
            temporal_extent=temporal_extent,
            embeddings_parameters=embedding_params,
            scale_uint16=True
        )

        job_title = f"thesis_demo_{lat}_{lon}"
        st.info(f"Submitting Job: {job_title}...")
        job = inference_result.create_job(
            title=job_title,
            job_options=INFERENCE_JOB_OPTIONS,
        )

        job.start()
        job_id = job.job_id
        st.success(f"Job started. ID: {job_id}")

        # Poll the backend until the batch job reaches a terminal state.
        # NOTE(review): no timeout — a stuck job blocks this session forever.
        status_box = st.empty()
        while True:
            metadata = job.describe_job()
            status = metadata.get("status")
            status_box.markdown(f"**Status:** `{status}` (refreshing every 5s...)")

            if status == "finished":
                break
            elif status in ["error", "canceled"]:
                st.error(f"Job failed with status: {status}")
                return None

            time.sleep(5)

        st.info("Downloading results...")
        results = job.get_results()
        output_path = temp_dir / f"embedding_{lat}_{lon}.tif"

        # Download the first GeoTIFF asset found in the job results.
        found = False
        for asset in results.get_assets():
            if asset.metadata.get("type", "").startswith("image/tiff"):
                asset.download(str(output_path))
                found = True
                break

        if found:
            return output_path
        else:
            st.error("No TIFF found in results.")
            return None

    # Broad catch is deliberate here: any backend/auth/network failure is
    # surfaced to the UI rather than crashing the Streamlit script.
    except Exception as e:
        st.error(f"OpenEO Error: {str(e)}")
        return None
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
st.title("Crop Map")

# Sidebar: pick one of the pre-computed demo embeddings and kick off analysis.
with st.sidebar:
    st.header("Control Panel")
    tif_files = list(demo_dir.glob("*.tif"))
    if not tif_files:
        st.error(f"No .tif files in {demo_dir}")
        st.stop()

    selected_tif = st.selectbox("Select Region", tif_files, format_func=lambda x: x.name)

    # A ground-truth GeoJSON is linked by naming convention:
    # "<name>_embedding.tif" -> "<name>.geojson" in the same folder.
    possible_name = selected_tif.stem.replace("_embedding", "") + ".geojson"
    geojson_path = selected_tif.parent / possible_name
    has_geojson = geojson_path.exists()

    if has_geojson:
        st.success(f"Linked: {geojson_path.name}")

    run_btn = st.button("Run Analysis", type="primary")

if run_btn:
    model = load_model()
    if not model:
        st.error("Model not found")
        st.stop()

    with st.spinner("Processing..."):  # type: ignore[arg-type]
        class_map, bounds = run_prediction(selected_tif, model)

        # Colourize the id map into an RGBA image for the folium overlay.
        h, w = class_map.shape
        rgba_img = np.zeros((h, w, 4), dtype=np.uint8)
        unique_ids = np.unique(class_map)

        for uid in unique_ids:
            if uid not in id_to_class: continue
            crop = id_to_class[uid]
            c = get_class_color_rgba(crop, alpha=255)
            rgba_img[class_map == uid] = c

        # Optional field boundaries; simplified to keep the map payload small.
        gdf = None
        if has_geojson:
            gdf = gpd.read_file(geojson_path)
            if gdf.crs != "EPSG:4326":
                gdf = gdf.to_crs("EPSG:4326")
            gdf['geometry'] = gdf['geometry'].simplify(tolerance=0.0001)

        # Per-class pixel counts -> percentage table for legend/statistics.
        total = class_map.size
        counts = {id_to_class[uid]: np.sum(class_map == uid) for uid in unique_ids if uid in id_to_class}
        stats_df = pd.DataFrame([
            {"Crop": k, "Pixels": v, "Percentage": v / total * 100} for k, v in counts.items()
        ]).sort_values("Percentage", ascending=False)

        # Stashed in session_state so tab1 can re-render across reruns.
        st.session_state['analysis_results'] = {
            "bounds": bounds,
            "rgba_img": rgba_img,
            "gdf": gdf,
            "stats_df": stats_df
        }
|
| 262 |
+
|
| 263 |
+
tab1, tab2 = st.tabs(["Pre-loaded Regions", "Analyze New Area"])

# Tab 1: render the most recent sidebar-triggered analysis, if any.
with tab1:
    if 'analysis_results' in st.session_state:
        data = st.session_state['analysis_results']
        bounds = data['bounds']
        rgba_img = data['rgba_img']
        gdf = data['gdf']
        stats_df = data['stats_df']

        c1, c2 = st.columns([3, 1])

        with c1:
            # Center the map on the middle of the classified extent.
            center_lat = (bounds[0][0] + bounds[1][0]) / 2
            center_lon = (bounds[0][1] + bounds[1][1]) / 2

            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab1")

            m = folium.Map(location=[center_lat, center_lon], zoom_start=14, control_scale=True)

            # Two switchable base layers: light map and satellite imagery.
            folium.TileLayer(
                tiles='CartoDB positron',
                name='Light Map',
                overlay=False
            ).add_to(m)

            folium.TileLayer(
                tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
                attr='Esri',
                name='Satellite',
                overlay=False
            ).add_to(m)

            # The classification result as a semi-transparent raster overlay.
            folium.raster_layers.ImageOverlay(
                image=rgba_img,
                bounds=bounds,
                opacity=overlay_opacity,
                name='Prediction',
                pixelated=True
            ).add_to(m)

            # Field outlines with a tooltip showing the declared crop
            # ('roslina' is the attribute name in the source GeoJSON).
            if gdf is not None:
                folium.GeoJson(
                    gdf,
                    name="Fields",
                    style_function=lambda x: {'color': 'white', 'weight': 1, 'fillOpacity': 0, 'dashArray': '5, 5'},
                    tooltip=folium.GeoJsonTooltip(fields=['roslina'], aliases=['Crop:'])
                ).add_to(m)

            folium.LayerControl().add_to(m)
            plugins.Fullscreen().add_to(m)

            st_folium(m, height=600, use_container_width=True)

        with c2:
            st.subheader("Legend")
            st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)
            st.dataframe(stats_df[["Crop", "Percentage"]], hide_index=True)
|
| 321 |
+
|
| 322 |
+
# Tab 2: generate a fresh embedding via openEO for a user-chosen point,
# classify it, and display the result.
with tab2:
    c1, c2 = st.columns([1, 2])

    if 'tab2_results' not in st.session_state:
        st.session_state['tab2_results'] = None

    with c1:
        st.markdown("### 1. Select Area")
        # Defaults point at the demo site; format only limits the display,
        # not the stored precision.
        lat = st.number_input("Latitude", value=50.93131691432723, format="%.4f")
        lon = st.number_input("Longitude", value=22.781513694631702, format="%.4f")

        if st.button("Generate the embedding and classify"):
            with st.spinner("Talking to Satellites... (This takes ~5 mins)"):  # type: ignore[arg-type]
                tif_path = run_openeo_job(lat, lon)

                if tif_path:
                    st.success("Embedding Generated!")

                    # NOTE(review): unlike the sidebar flow, load_model() is not
                    # checked for None here — confirm the model always exists.
                    model = load_model()
                    class_map, bounds = run_prediction(tif_path, model)

                    # Colourize the id map (same logic as the tab1 flow).
                    h, w = class_map.shape
                    rgba_img = np.zeros((h, w, 4), dtype=np.uint8)
                    unique_ids = np.unique(class_map)

                    for uid in unique_ids:
                        if uid not in id_to_class: continue
                        crop = id_to_class[uid]
                        c = get_class_color_rgba(crop, alpha=255)
                        rgba_img[class_map == uid] = c

                    # Per-class percentage table for the legend.
                    total = class_map.size
                    counts = {id_to_class[uid]: np.sum(class_map == uid) for uid in unique_ids if uid in id_to_class}
                    stats_df = pd.DataFrame([
                        {"Crop": k, "Pixels": v, "Percentage": v / total * 100} for k, v in counts.items()
                    ]).sort_values("Percentage", ascending=False)

                    st.session_state['tab2_results'] = {
                        "bounds": bounds,
                        "rgba_img": rgba_img,
                        "stats_df": stats_df
                    }

                    st.success("Classification Complete")

    with c2:
        if st.session_state['tab2_results']:
            data = st.session_state['tab2_results']
            bounds = data['bounds']
            rgba_img = data['rgba_img']
            stats_df = data['stats_df']

            st.markdown("### 2. Analysis Results")

            center_lat = (bounds[0][0] + bounds[1][0]) / 2
            center_lon = (bounds[0][1] + bounds[1][1]) / 2

            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab2")

            m = folium.Map(location=[center_lat, center_lon], zoom_start=14, control_scale=True)

            folium.TileLayer(
                tiles='CartoDB positron',
                name='Light Map',
                overlay=False
            ).add_to(m)

            folium.TileLayer(
                tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
                attr='Esri',
                name='Satellite',
                overlay=False
            ).add_to(m)

            folium.raster_layers.ImageOverlay(
                image=rgba_img,
                bounds=bounds,
                opacity=overlay_opacity,
                name='Prediction',
                pixelated=True
            ).add_to(m)

            folium.LayerControl().add_to(m)
            plugins.Fullscreen().add_to(m)

            st_folium(m, height=500, use_container_width=True)

            st.divider()
            # NOTE(review): col_df is created but never used below — the stats
            # table shown in tab1 is missing here; confirm whether intentional.
            col_leg, col_df = st.columns(2)
            with col_leg:
                st.subheader("Legend")
                st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)
|
openeo_gfmap/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEO General Framework for Mapping.
|
| 2 |
+
|
| 3 |
+
Simplify the development of mapping applications through Remote Sensing data
|
| 4 |
+
by leveraging the power of OpenEO (http://openeo.org/).
|
| 5 |
+
|
| 6 |
+
More information available in the README.md file.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .backend import Backend, BackendContext
|
| 10 |
+
from .fetching import FetchType
|
| 11 |
+
from .metadata import FakeMetadata
|
| 12 |
+
from .spatial import BoundingBoxExtent, SpatialContext
|
| 13 |
+
from .temporal import TemporalContext
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"Backend",
|
| 17 |
+
"BackendContext",
|
| 18 |
+
"SpatialContext",
|
| 19 |
+
"BoundingBoxExtent",
|
| 20 |
+
"TemporalContext",
|
| 21 |
+
"FakeMetadata",
|
| 22 |
+
"FetchType",
|
| 23 |
+
]
|
openeo_gfmap/backend.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backend Context.
|
| 2 |
+
|
| 3 |
+
Defines on which backend the pipeline is being currently used.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Callable, Dict, Optional
|
| 11 |
+
|
| 12 |
+
import openeo
|
| 13 |
+
|
| 14 |
+
_log = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Backend(Enum):
    """Enumerating the backends supported by the Mapping Framework."""

    TERRASCOPE = "terrascope"  # VITO Terrascope backend
    EODC = "eodc"  # Dask implementation. Do not test on this yet.
    CDSE = "cdse"  # Terrascope implementation (pyspark). URL: openeo.dataspace.copernicus.eu (need to register)
    CDSE_STAGING = "cdse-staging"  # staging flavour of CDSE
    LOCAL = "local"  # Based on the same components as EODC
    FED = "fed"  # Federation backend
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class BackendContext:
    """Backend context and information.

    Containing backend related information useful for the framework to
    adapt the process graph.
    """

    # Which of the supported backends this pipeline is currently running on.
    backend: Backend
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _create_connection(
    url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None
):
    """
    Generic helper to create an openEO connection
    with support for multiple client credential configurations from environment variables.

    Parameters
    ----------
    url:
        Backend URL handed to ``openeo.connect``.
    env_var_suffix:
        Suffix selecting the backend-specific credential variables, e.g.
        ``OPENEO_AUTH_CLIENT_ID_CDSE`` / ``OPENEO_AUTH_CLIENT_SECRET_CDSE``.
    connect_kwargs:
        Extra keyword arguments forwarded to ``openeo.connect``.
    """
    connection = openeo.connect(url, **(connect_kwargs or {}))

    if (
        os.environ.get("OPENEO_AUTH_METHOD") == "client_credentials"
        and f"OPENEO_AUTH_CLIENT_ID_{env_var_suffix}" in os.environ
    ):
        # Support for multiple client credentials configs from env vars.
        client_id = os.environ[f"OPENEO_AUTH_CLIENT_ID_{env_var_suffix}"]
        client_secret = os.environ[f"OPENEO_AUTH_CLIENT_SECRET_{env_var_suffix}"]
        provider_id = os.environ.get(f"OPENEO_AUTH_PROVIDER_ID_{env_var_suffix}")
        # Only the secret's length is logged, never the secret itself.
        _log.info(
            f"Doing client credentials from env var with {env_var_suffix=} {provider_id} {client_id=} {len(client_secret)=} "
        )

        connection.authenticate_oidc_client_credentials(
            client_id=client_id, client_secret=client_secret, provider_id=provider_id
        )
    else:
        # Standard authenticate_oidc procedure: refresh token, device code or default env var handling
        # See https://open-eo.github.io/openeo-python-client/auth.html#oidc-authentication-dynamic-method-selection

        # Use a shorter max poll time by default to alleviate the default impression that the test seem to hang
        # because of the OIDC device code poll loop.
        max_poll_time = int(
            os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30
        )
        connection.authenticate_oidc(max_poll_time=max_poll_time)
    return connection
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def vito_connection() -> openeo.Connection:
    """Connect to the VITO/Terrascope backend using OIDC authentication."""
    return _create_connection(url="openeo.vito.be", env_var_suffix="VITO")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def cdse_connection() -> openeo.Connection:
    """Connect to the CDSE backend using OIDC authentication."""
    return _create_connection(url="openeo.dataspace.copernicus.eu", env_var_suffix="CDSE")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def cdse_staging_connection() -> openeo.Connection:
    """Connect to the CDSE *staging* backend using OIDC authentication."""
    return _create_connection(
        url="openeo-staging.dataspace.copernicus.eu", env_var_suffix="CDSE_STAGING"
    )
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def eodc_connection() -> openeo.Connection:
    """Connect to the EODC backend using OIDC authentication."""
    return _create_connection(
        url="https://openeo.eodc.eu/openeo/1.1.0", env_var_suffix="EODC"
    )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def fed_connection() -> openeo.Connection:
    """Connect to the openEO federated backend using OIDC authentication."""
    return _create_connection(
        url="openeofed.dataspace.copernicus.eu/", env_var_suffix="FED"
    )
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Maps each supported Backend to its connection factory.
# NOTE(review): Backend.EODC and Backend.LOCAL are deliberately(?) absent —
# eodc_connection exists above but is not registered; confirm the omission.
BACKEND_CONNECTIONS: Dict[Backend, Callable] = {
    Backend.TERRASCOPE: vito_connection,
    Backend.CDSE: cdse_connection,
    Backend.CDSE_STAGING: cdse_staging_connection,
    Backend.FED: fed_connection,
}
|
openeo_gfmap/fetching.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Main file for extractions and pre-processing of data through OpenEO
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
import openeo
|
| 8 |
+
|
| 9 |
+
from openeo_gfmap import BackendContext
|
| 10 |
+
from openeo_gfmap.spatial import SpatialContext
|
| 11 |
+
from openeo_gfmap.temporal import TemporalContext
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FetchType(Enum):
    """Enumerates the different types of extraction. There are three types of
    enumerations.

    * TILE: Tile based extractions, getting the data for a dense part. The
    output of such fetching process is a dense DataCube.
    * POINT: Point based extractions. From a datasets of polygons, gets sparse
    extractions and performs spatial aggregation on the selected polygons. The
    output of such fetching process is a VectorCube, that can be used to get
    a pandas.DataFrame
    * POLYGON: Patch based extractions, returning a VectorCube of sparsed
    patches. This can be retrieved as multiple NetCDF files from one job.
    """

    TILE = "tile"
    POINT = "point"
    POLYGON = "polygon"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CollectionFetcher:
    """Base class to fetch a particular collection.

    Parameters
    ----------
    backend_context: BackendContext
        Information about the backend in use, useful in certain cases.
    bands: list
        List of band names to load from that collection.
    collection_fetch: Callable
        Function defining how to fetch a collection for a specific backend,
        the function accepts the following parameters: connection,
        spatial extent, temporal extent, bands and additional parameters.
    collection_preprocessing: Callable
        Function defining how to harmonize the data of a collection in a
        backend. For example, this function could rename the bands as they
        can be different for every backend/collection (SENTINEL2_L2A or
        SENTINEL2_L2A_SENTINELHUB). Accepts the following parameters:
        datacube (of pre-fetched collection) and additional parameters.
    collection_params: dict
        Additional parameters encoded within a dictionary that will be
        passed in the fetch and preprocessing function.
    """

    def __init__(
        self,
        backend_context: BackendContext,
        bands: list,
        collection_fetch: Callable,
        collection_preprocessing: Callable,
        **collection_params,
    ):
        # FIX: attribute was misspelled "backend_contect"; the correctly
        # spelled name is now authoritative, the old one is kept as a
        # deprecated alias so existing external references keep working.
        self.backend_context = backend_context
        self.backend_contect = backend_context  # deprecated alias (typo)
        self.bands = bands
        self.fetcher = collection_fetch
        self.processing = collection_preprocessing
        self.params = collection_params

    def get_cube(
        self,
        connection: openeo.Connection,
        spatial_context: SpatialContext,
        temporal_context: TemporalContext,
    ) -> openeo.DataCube:
        """Retrieve a data cube from the given spatial and temporal context.

        Parameters
        ----------
        connection: openeo.Connection
            A connection to an OpenEO backend. The backend provided must be the
            same as the one this extractor class is configured for.
        spatial_context: SpatialContext
            Either a GeoJSON collection on which spatial filtering will be
            applied or a bounding box with an EPSG code. If a bounding box is
            provided, no filtering is applied and the entirety of the data is
            fetched for that region.
        temporal_context: TemporalContext
            The begin and end date of the extraction.
        """
        # Fetch first, then harmonize — both steps share the same params dict.
        collection_data = self.fetcher(
            connection, spatial_context, temporal_context, self.bands, **self.params
        )

        preprocessed_data = self.processing(collection_data, **self.params)

        return preprocessed_data
|
openeo_gfmap/metadata.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Metadata utilities related to the usage of a DataCube. Used to interact
|
| 2 |
+
with the OpenEO backends and cover some shortcomings.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class FakeMetadata:
    """Stand-in metadata object for datacubes fetched from STAC catalogues.

    Temporary workaround for OpenEO backend shortcomings; expected to become
    unused over time.
    """

    # Ordered list of band labels carried by the cube.
    band_names: list

    def rename_labels(self, _, target, source):
        """Rename the labels of the band dimension, returning self.

        NOTE(review): the mapping is built target -> source and applied to
        names found in `target`, which looks reversed with respect to the
        openeo `rename_labels(dimension, target, source)` convention
        (source = current labels, target = new labels) -- confirm this is
        intentional before relying on it.
        """
        lookup = dict(zip(target, source))
        # Iterate over a snapshot so in-place updates don't disturb the loop.
        for position, label in enumerate(list(self.band_names)):
            if label in target:
                self.band_names[position] = lookup[label]
        return self
|
openeo_gfmap/spatial.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Definitions of spatial context, either point-based or spatial"""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
from geojson import GeoJSON
|
| 7 |
+
from shapely.geometry import Polygon, box
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class BoundingBoxExtent:
    """Bounding box definition as accepted by OpenEO.

    Holds the minx, miny, maxx, maxy coordinates expressed as west, south,
    east, north, together with the EPSG code of the coordinate system.
    """

    west: float
    south: float
    east: float
    north: float
    epsg: int = 4326

    def __dict__(self):
        # NOTE(review): defining __dict__ as a method shadows the usual
        # instance-dict attribute; kept as-is for compatibility with any
        # callers invoking it explicitly.
        crs = f"EPSG:{self.epsg}"
        return {
            "west": self.west,
            "south": self.south,
            "east": self.east,
            "north": self.north,
            "crs": crs,
            "srs": crs,
        }

    def __iter__(self):
        # Yielding (key, value) pairs lets `dict(bbox)` build the OpenEO
        # bounding-box mapping directly.
        crs = f"EPSG:{self.epsg}"
        pairs = [
            ("west", self.west),
            ("south", self.south),
            ("east", self.east),
            ("north", self.north),
            ("crs", crs),
            ("srs", crs),
        ]
        return iter(pairs)

    def to_geometry(self) -> "Polygon":
        """Return the extent as a shapely box polygon."""
        return box(self.west, self.south, self.east, self.north)

    def to_geojson(self) -> "GeoJSON":
        """Return the extent as a GeoJSON-like mapping."""
        return self.to_geometry().__geo_interface__
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Spatial context accepted by fetchers: a GeoJSON geometry/collection, a
# bounding box with an EPSG code, or a string (presumably a reference to a
# geometry, e.g. a URL or path -- confirm against the fetch implementations).
SpatialContext = Union[GeoJSON, BoundingBoxExtent, str]
|
openeo_gfmap/temporal.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Definitions of temporal context"""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class TemporalContext:
    """Temporal context defined by `start_date` and `end_date` values.

    Both values must be encoded in YYYY-mm-dd format, e.g. 2020-01-01.
    """

    start_date: str
    end_date: str

    def to_datetime(self):
        """Return (start, end) as a tuple of `datetime` objects."""
        fmt = "%Y-%m-%d"
        return tuple(
            datetime.strptime(stamp, fmt)
            for stamp in (self.start_date, self.end_date)
        )
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "crop-map"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Application for Crop Type Mapping"
|
| 5 |
+
requires-python = ">=3.12"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"folium>=0.20.0",
|
| 8 |
+
"geojson>=3.2.0",
|
| 9 |
+
"geopandas>=1.1.2",
|
| 10 |
+
"joblib>=1.5.3",
|
| 11 |
+
"matplotlib>=3.10.8",
|
| 12 |
+
"openeo>=0.47.0",
|
| 13 |
+
"rasterio>=1.5.0",
|
| 14 |
+
"streamlit>=1.53.1",
|
| 15 |
+
"streamlit-folium>=0.26.1",
|
| 16 |
+
]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
worldcereal/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
from ._version import __version__
|
| 4 |
+
|
| 5 |
+
__all__ = ["__version__"]
|
| 6 |
+
|
| 7 |
+
# Season identifiers supported by WorldCereal.
SUPPORTED_SEASONS = [
    "tc-s1",
    "tc-s2",
    "tc-annual",
    "custom",
]

# Short name corresponding to each supported season identifier.
# NOTE(review): "custom" maps to lowercase "custom" while the others map to
# uppercase short codes -- confirm this asymmetry is intentional.
SEASONAL_MAPPING = {
    "tc-s1": "S1",
    "tc-s2": "S2",
    "tc-annual": "ANNUAL",
    "custom": "custom",
}


# Default buffer (days) prior to season start; all zero by default.
SEASON_PRIOR_BUFFER = {
    "tc-s1": 0,
    "tc-s2": 0,
    "tc-annual": 0,
    "custom": 0,
}


# Default buffer (days) after season end; all zero by default.
SEASON_POST_BUFFER = {
    "tc-s1": 0,
    "tc-s2": 0,
    "tc-annual": 0,
    "custom": 0,
}
|
worldcereal/_version.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Package version, exposed as `worldcereal.__version__`.
__version__ = "2.4.1"
|
worldcereal/job.py
ADDED
|
@@ -0,0 +1,960 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executing inference jobs on the OpenEO backend.
|
| 2 |
+
|
| 3 |
+
Possible entry points for inference in this module:
|
| 4 |
+
- `generate_map`: This function is used to generate a map for a single patch.
|
| 5 |
+
It creates one OpenEO job and processes the inference for the specified
|
| 6 |
+
spatial and temporal extent.
|
| 7 |
+
- `collect_inputs`: This function is used to collect preprocessed inputs
|
| 8 |
+
without performing inference. It retrieves the required data for further
|
| 9 |
+
processing or analysis.
|
| 10 |
+
- `run_largescale_inference`: This function utilizes a job manager to
|
| 11 |
+
orchestrate and execute multiple inference jobs automatically, enabling
|
| 12 |
+
efficient large-scale processing.
|
| 13 |
+
- `setup_inference_job_manager`: This function prepares the job manager
|
| 14 |
+
and job database for large-scale inference jobs. It sets up the necessary
|
| 15 |
+
infrastructure to manage and track jobs in a notebook environment.
|
| 16 |
+
Used in the WorldCereal demo notebooks.
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
import shutil
|
| 22 |
+
from copy import deepcopy
|
| 23 |
+
from functools import partial
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Callable, Dict, List, Literal, Optional, Union
|
| 26 |
+
|
| 27 |
+
import geopandas as gpd
|
| 28 |
+
import openeo
|
| 29 |
+
import pandas as pd
|
| 30 |
+
from loguru import logger
|
| 31 |
+
from openeo import BatchJob
|
| 32 |
+
from openeo.extra.job_management import CsvJobDatabase, MultiBackendJobManager
|
| 33 |
+
from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext
|
| 34 |
+
from openeo_gfmap.backend import BACKEND_CONNECTIONS
|
| 35 |
+
from pydantic import BaseModel
|
| 36 |
+
from typing_extensions import TypedDict
|
| 37 |
+
|
| 38 |
+
from worldcereal.openeo.mapping import _cropland_map, _croptype_map, _embeddings_map
|
| 39 |
+
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs
|
| 40 |
+
from worldcereal.parameters import (
|
| 41 |
+
CropLandParameters,
|
| 42 |
+
CropTypeParameters,
|
| 43 |
+
EmbeddingsParameters,
|
| 44 |
+
WorldCerealProductType,
|
| 45 |
+
)
|
| 46 |
+
from worldcereal.utils.models import load_model_lut
|
| 47 |
+
|
| 48 |
+
# Archives with pre-built Python dependencies, unpacked on the OpenEO workers
# via the "udf-dependency-archives" job option below.
ONNX_DEPS_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_deps_python311.zip"
FEATURE_DEPS_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/torch_deps_python311.zip"
# Default OpenEO batch-job options for inference jobs.
INFERENCE_JOB_OPTIONS = {
    "driver-memory": "4g",
    "executor-memory": "2g",
    "executor-memoryOverhead": "3g",
    "max-executors": 20,
    "python-memory": "disable",
    # Fraction of failing tasks tolerated before the job is aborted --
    # presumably 10%; TODO confirm against backend documentation.
    "soft-errors": 0.1,
    "image-name": "python311",
    # Make the dependency archives available to UDFs at runtime, extracted
    # into the "onnx_deps" and "feature_deps" directories respectively.
    "udf-dependency-archives": [
        f"{ONNX_DEPS_URL}#onnx_deps",
        f"{FEATURE_DEPS_URL}#feature_deps",
    ],
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class WorldCerealProduct(TypedDict):
    """Typed dictionary describing a WorldCereal inference product.

    Attributes
    ----------
    url: str
        URL to the product.
    type: WorldCerealProductType
        Type of the product. Either cropland or croptype.
    temporal_extent: TemporalContext
        Period of time for which the product has been generated.
    path: Optional[Path]
        Path to the downloaded product, if it was downloaded locally.
    lut: Optional[Dict]
        Look-up table for the product.

    """

    url: str
    type: WorldCerealProductType
    temporal_extent: TemporalContext
    path: Optional[Path]
    lut: Optional[Dict]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InferenceResults(BaseModel):
    """Pydantic model storing the results of a WorldCereal inference job.

    Attributes
    ----------
    job_id : str
        Job ID of the finished OpenEO job.
    products: Dict[str, WorldCerealProduct]
        Dictionary with the different products.
    metadata: Optional[Path]
        Path to the metadata file, if it was downloaded locally.
    """

    job_id: str
    products: Dict[str, WorldCerealProduct]
    metadata: Optional[Path]
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class InferenceJobManager(MultiBackendJobManager):
    """A job manager for executing large-scale WorldCereal inference jobs on the OpenEO backend.

    Based on the official MultiBackendJobManager, extending how results are
    downloaded and named (one subfolder per tile, tile name appended to the
    downloaded asset filenames).
    """

    @classmethod
    def generate_output_path_inference(
        cls,
        root_folder: Path,
        geometry_index: int,
        row: pd.Series,
        asset_id: Optional[str] = None,
    ) -> Path:
        """Method to generate the output path for inference jobs.

        Creates (if needed) and returns a per-tile subfolder under
        `root_folder`, named after `row.tile_name`.

        Parameters
        ----------
        root_folder : Path
            root folder where the output parquet file will be saved
        geometry_index : int
            For point extractions, only one asset (a geoparquet file) is generated per job.
            Therefore geometry_index is always 0. It has to be included in the function signature
            to be compatible with the GFMapJobManager
        row : pd.Series
            the current job row from the GFMapJobManager
        asset_id : str, optional
            Needed for compatibility with GFMapJobManager but not used.

        Returns
        -------
        Path
            output path (directory) for this tile's results
        """

        tile_name = row.tile_name

        # Create the subfolder to store the output
        subfolder = root_folder / str(tile_name)
        subfolder.mkdir(parents=True, exist_ok=True)

        return subfolder

    def on_job_done(self, job: BatchJob, row):
        """Download all assets and metadata of a finished job into its tile folder."""
        logger.info(f"Job {job.job_id} completed")
        output_dir = self.generate_output_path_inference(self._root_dir, 0, row)

        # Get job results
        job_result = job.get_results()

        # Get the products
        assets = job_result.get_assets()
        for asset in assets:
            # Derive the product type from the asset filename, e.g.
            # "cropland-..._....tif" -> "cropland".
            asset_name = asset.name.split(".")[0].split("_")[0]
            asset_type = asset_name.split("-")[0]
            # NOTE(review): `asset_type` is resolved but never used below --
            # confirm whether it should drive the output naming.
            asset_type = getattr(WorldCerealProductType, asset_type.upper())
            filepath = asset.download(target=output_dir)

            # We want to add the tile name to the filename
            new_filepath = filepath.parent / f"{filepath.stem}_{row.tile_name}.tif"
            shutil.move(filepath, new_filepath)

        # Persist job and result metadata alongside the downloaded products.
        job_metadata = job.describe()
        result_metadata = job_result.get_metadata()
        job_metadata_path = output_dir / f"job_{job.job_id}.json"
        result_metadata_path = output_dir / f"result_{job.job_id}.json"

        with job_metadata_path.open("w", encoding="utf-8") as f:
            json.dump(job_metadata, f, ensure_ascii=False)
        with result_metadata_path.open("w", encoding="utf-8") as f:
            json.dump(result_metadata, f, ensure_ascii=False)

        # post_job_action(output_file)
        logger.success("Job completed")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def create_inference_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    connection: Optional[openeo.Connection] = None,
) -> List[openeo.DataCube]:
    """Wrapper function that creates the inference openEO process graph.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    product_type : WorldCerealProductType, optional
        product describer, by default WorldCerealProductType.CROPLAND
    cropland_parameters: CropLandParameters
        Parameters for the cropland product inference pipeline.
    croptype_parameters: Optional[CropTypeParameters]
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise.
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use for the inference. If not provided,
        the orbit state will be dynamically determined based on the spatial extent.
    out_format : str, optional
        Output format, by default "GTiff"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg: Optional[int] = None
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.
    connection: Optional[openeo.Connection] = None
        Optional OpenEO connection to use. If not provided, a new connection
        will be created based on the backend_context.

    Returns
    -------
    List[openeo.DataCube]
        A list with one or more result objects representing the inference
        process graph, to be executed on the OpenEO backend. The result will
        be a DataCube with the classification results.

    Raises
    ------
    ValueError
        if the product or the out_format is not supported
    """
    if product_type not in WorldCerealProductType:
        raise ValueError(f"Product {product_type.value} not supported.")

    # Bug fix: the original interpolated the builtin `format` instead of
    # `out_format`, producing a meaningless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend, unless one was supplied.
    if connection is None:
        connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
    )

    # Restrict to the exact bounding box of interest.
    inputs = inputs.filter_bbox(dict(spatial_extent))

    # Construct the feature extraction and model inference pipeline
    if product_type == WorldCerealProductType.CROPLAND:
        results = _cropland_map(
            inputs,
            temporal_extent,
            cropland_parameters=cropland_parameters,
        )
    elif product_type == WorldCerealProductType.CROPTYPE:
        if not isinstance(croptype_parameters, CropTypeParameters):
            raise ValueError(
                f"Please provide a valid `croptype_parameters` parameter."
                f" Received: {croptype_parameters}"
            )

        # Generate crop type map with optional cropland masking
        results = _croptype_map(
            inputs,
            temporal_extent,
            cropland_parameters=cropland_parameters,
            croptype_parameters=croptype_parameters,
        )
    else:
        # Bug fix: other members of WorldCerealProductType pass the membership
        # check above but were not handled, which made the original raise an
        # UnboundLocalError on `results`. Fail with an explicit error instead.
        raise ValueError(
            f"Product {product_type.value} is not supported by this function."
        )

    return results
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def create_embeddings_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    embeddings_parameters: EmbeddingsParameters = EmbeddingsParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    scale_uint16: bool = True,
) -> openeo.DataCube:
    """Create an OpenEO process graph for generating embeddings.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        Spatial extent of the map.
    temporal_extent : TemporalContext
        Temporal range to consider.
    embeddings_parameters : EmbeddingsParameters, optional
        Parameters for the embeddings product inference pipeline,
        by default EmbeddingsParameters().
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]], optional
        Sentinel-1 orbit state to use for the inference. If not provided, the
        orbit state will be dynamically determined based on the spatial
        extent, by default None.
    out_format : str, optional
        Output format, by default "GTiff".
    backend_context : BackendContext, optional
        Backend to run the job on, by default BackendContext(Backend.CDSE).
    tile_size : Optional[int], optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg : Optional[int], optional
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.
    scale_uint16 : bool, optional
        Whether to scale the embeddings to uint16 for memory optimization,
        by default True.

    Returns
    -------
    openeo.DataCube
        DataCube object representing the embeddings process graph, to be
        executed on the OpenEO backend.

    Raises
    ------
    ValueError
        If the output format is not supported.
    """
    # Bug fix: the original interpolated the builtin `format` instead of
    # `out_format`, producing a meaningless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
    )

    # Restrict to the exact bounding box of interest.
    inputs = inputs.filter_bbox(dict(spatial_extent))

    embeddings = _embeddings_map(
        inputs,
        temporal_extent,
        embeddings_parameters=embeddings_parameters,
        scale_uint16=scale_uint16,
    )

    # Save the final result
    embeddings = embeddings.save_result(
        format=out_format,
        options=dict(
            filename_prefix=f"WorldCereal_Embeddings_{temporal_extent.start_date}_{temporal_extent.end_date}",
        ),
    )

    return embeddings
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def create_inputs_process_graph(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    out_format: str = "NetCDF",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    target_epsg: Optional[int] = None,
    compositing_window: Literal["month", "dekad"] = "month",
) -> openeo.DataCube:
    """Wrapper function that creates the inputs openEO process graph.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use for the inference. If not provided,
        the orbit state will be dynamically determined based on the spatial extent.
    out_format : str, optional
        Output format, by default "NetCDF"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size: int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    target_epsg: Optional[int] = None
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.
    compositing_window: Literal["month", "dekad"]
        Compositing window to use for the data loading in OpenEO, by default
        "month".

    Returns
    -------
    openeo.DataCube
        DataCube object representing the inputs process graph, to be executed
        on the OpenEO backend. The result will be a DataCube with the
        preprocessed inputs.

    Raises
    ------
    ValueError
        if the out_format is not supported
    """
    # Bug fix: the original interpolated the builtin `format` instead of
    # `out_format`, producing a meaningless error message.
    if out_format not in ["GTiff", "NetCDF"]:
        raise ValueError(f"Format {out_format} not supported.")

    # Make a connection to the OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Preparing the input cube for inference
    inputs = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        compositing_window=compositing_window,
    )

    # Restrict to the exact bounding box of interest.
    inputs = inputs.filter_bbox(dict(spatial_extent))

    # Save the final result
    inputs = inputs.save_result(
        format=out_format,
        options=dict(
            filename_prefix=f"preprocessed-inputs_{temporal_extent.start_date}_{temporal_extent.end_date}",
        ),
    )

    return inputs
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def create_inference_job(
    row: pd.Series,
    connection: openeo.Connection,
    provider: str,
    connection_provider: str,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPTYPE,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    target_epsg: Optional[int] = None,
    job_options: Optional[dict] = None,
) -> BatchJob:
    """Create an OpenEO batch job for WorldCereal inference.

    Parameters
    ----------
    row : pd.Series
        Contains at least the following fields:
        - start_date: str, start date of the temporal extent
        - end_date: str, end date of the temporal extent
        - geometry: shapely.geometry, geometry of the spatial extent
        - tile_name: str, name of the tile
        - epsg: int, EPSG code of the spatial extent
        - bounds_epsg: str representation of tuple,
          bounds of the spatial extent in CRS as
          specified by epsg attribute
    connection : openeo.Connection
        openEO connection to the backend
    provider : str
        unused but required for compatibility with MultiBackendJobManager
    connection_provider : str
        unused but required for compatibility with MultiBackendJobManager
    product_type : WorldCerealProductType, optional
        Type of the WorldCereal product to generate, by default WorldCerealProductType.CROPTYPE
    cropland_parameters : CropLandParameters, optional
        Parameters for the cropland product inference pipeline.
    croptype_parameters : CropTypeParameters, optional
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise.
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]], optional
        Sentinel-1 orbit state to use for the inference. If not provided, the
        best orbit will be dynamically derived from the catalogue.
    target_epsg : Optional[int], optional
        EPSG code to reproject the data to. If not provided, the data will be
        left in the original epsg as mentioned in the row.
    job_options : Optional[dict], optional
        Additional job options to pass to the OpenEO backend, by default None

    Returns
    -------
    BatchJob
        Batch job created on openEO backend.
    """
    # Local import: only needed to parse the bounds tuple safely.
    import ast

    # Get temporal and spatial extents from the row
    temporal_extent = TemporalContext(start_date=row.start_date, end_date=row.end_date)
    epsg = int(row.epsg)
    # Safely parse the "(west, south, east, north)" string; literal_eval
    # replaces eval() so arbitrary code in the tracking CSV cannot execute.
    bounds = ast.literal_eval(row.bounds_epsg)
    spatial_extent = BoundingBoxExtent(
        west=bounds[0], south=bounds[1], east=bounds[2], north=bounds[3], epsg=epsg
    )

    if target_epsg is None:
        # If no target EPSG is provided, use the EPSG from the row
        target_epsg = epsg

    # Update default job options with the provided ones
    inference_job_options = deepcopy(INFERENCE_JOB_OPTIONS)
    if job_options is not None:
        inference_job_options.update(job_options)

    inference_result = create_inference_process_graph(
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        product_type=product_type,
        croptype_parameters=croptype_parameters,
        cropland_parameters=cropland_parameters,
        s1_orbit_state=s1_orbit_state,
        target_epsg=target_epsg,
        connection=connection,
    )

    # Submit the job
    return connection.create_job(
        inference_result,
        title=f"WorldCereal [{product_type.value}] job_{row.tile_name}",
        description="Job that performs end-to-end WorldCereal inference",
        additional=inference_job_options,  # TODO: once openeo-python-client supports job_options, use that
    )
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def generate_map(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    output_dir: Optional[Union[Path, str]] = None,
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    out_format: str = "GTiff",
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    job_options: Optional[dict] = None,
    target_epsg: Optional[int] = None,
) -> InferenceResults:
    """Main function to generate a WorldCereal product.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    output_dir : Optional[Union[Path, str]]
        path to directory where products should be downloaded to
    product_type : WorldCerealProductType, optional
        product describer, by default WorldCerealProductType.CROPLAND
    cropland_parameters : CropLandParameters
        Parameters for the cropland product inference pipeline.
    croptype_parameters : CropTypeParameters
        Parameters for the croptype product inference pipeline. Only required
        whenever `product_type` is set to `WorldCerealProductType.CROPTYPE`,
        will be ignored otherwise.
    out_format : str, optional
        Output format, by default "GTiff"
    backend_context : BackendContext
        backend to run the job on, by default CDSE.
    tile_size : int, optional
        Tile size to use for the data loading in OpenEO, by default 128.
    job_options : dict, optional
        Additional job options to pass to the OpenEO backend, by default None
    target_epsg : Optional[int]
        EPSG code to use for the output products. If not provided, the
        default EPSG will be used.

    Returns
    -------
    InferenceResults
        Results of the finished WorldCereal job.

    Raises
    ------
    ValueError
        if the product is not supported
    ValueError
        if the out_format is not supported
    """
    # Connect to the requested OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Build the full inference process graph
    process_graph = create_inference_process_graph(
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        out_format=out_format,
        backend_context=backend_context,
        tile_size=tile_size,
        target_epsg=target_epsg,
        connection=connection,
    )

    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Merge user-supplied job options into the defaults
    effective_job_options = deepcopy(INFERENCE_JOB_OPTIONS)
    if job_options is not None:
        effective_job_options.update(job_options)

    # Submit the batch job and block until it finishes
    job = connection.create_job(
        process_graph,
        additional=effective_job_options,  # TODO: once openeo-python-client supports job_options, use that
        title=f"WorldCereal [{product_type.value}] job",
        description="Job that performs end-to-end WorldCereal inference",
    ).start_and_wait()

    # Class look-up tables: cropland always, croptype only when requested
    luts = {
        WorldCerealProductType.CROPLAND.value: load_model_lut(
            cropland_parameters.classifier_parameters.classifier_url
        )
    }
    if product_type == WorldCerealProductType.CROPTYPE:
        luts[WorldCerealProductType.CROPTYPE.value] = load_model_lut(
            croptype_parameters.classifier_parameters.classifier_url
        )

    # Collect the job results and build the product dictionary
    job_result = job.get_results()

    products = {}
    for asset in job_result.get_assets():
        asset_name = asset.name.split(".")[0].split("_")[0]
        # Asset names are "<producttype>-..." -> map prefix onto the enum
        asset_type = getattr(WorldCerealProductType, asset_name.split("-")[0].upper())
        local_path = (
            asset.download(target=output_dir) if output_dir is not None else None
        )
        products[asset_name] = {
            "url": asset.href,
            "type": asset_type,
            "temporal_extent": temporal_extent,
            "path": local_path,
            "lut": luts[asset_type.value],
        }

    # Persist job metadata when an output directory was given
    metadata_file = None
    if output_dir is not None:
        metadata_file = output_dir / "job-results.json"
        metadata_file.write_text(json.dumps(job_result.get_metadata()))

    # Compile InferenceResults and return
    return InferenceResults(
        job_id=job.job_id, products=products, metadata=metadata_file
    )
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def collect_inputs(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    output_path: Union[Path, str],
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    tile_size: Optional[int] = 128,
    job_options: Optional[dict] = None,
    compositing_window: Literal["month", "dekad"] = "month",
):
    """Function to retrieve preprocessed inputs that are being
    used in the generation of WorldCereal products.

    Parameters
    ----------
    spatial_extent : BoundingBoxExtent
        spatial extent of the map
    temporal_extent : TemporalContext
        temporal range to consider
    output_path : Union[Path, str]
        output path to download the product to
    backend_context : BackendContext
        backend to run the job on, by default CDSE
    tile_size : int, optional
        Tile size to use for the data loading in OpenEO, by default 128
        so it uses the OpenEO default setting.
    job_options : dict, optional
        Additional job options to pass to the OpenEO backend, by default None
    compositing_window : Literal["month", "dekad"]
        Compositing window to use for the data loading in OpenEO, by default
        "month".
    """
    # Connect to the requested OpenEO backend
    connection = BACKEND_CONNECTIONS[backend_context.backend]()

    # Build the preprocessed-inputs cube (temporal context validation is
    # disabled here on purpose: this entry point collects raw inputs).
    cube = worldcereal_preprocessed_inputs(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
        tile_size=tile_size,
        validate_temporal_context=False,
        compositing_window=compositing_window,
    )

    # Restrict the cube to the requested bounding box
    cube = cube.filter_bbox(dict(spatial_extent))

    # Default resources for this job; caller-supplied options take precedence
    effective_job_options = {
        "driver-memory": "4g",
        "executor-memory": "1g",
        "executor-memoryOverhead": "1g",
        "python-memory": "3g",
        "soft-errors": 0.1,
    }
    if job_options is not None:
        effective_job_options.update(job_options)

    # Run the batch job and download the result to the requested location
    cube.execute_batch(
        outputfile=output_path,
        out_format="NetCDF",
        title="WorldCereal [collect_inputs] job",
        description="Job that collects inputs for WorldCereal inference",
        job_options=effective_job_options,
    )
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def run_largescale_inference(
    production_grid: Union[Path, gpd.GeoDataFrame],
    output_dir: Union[Path, str],
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    target_epsg: Optional[int] = None,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    job_options: Optional[dict] = None,
    parallel_jobs: int = 2,
):
    """
    Run large-scale inference jobs on the OpenEO backend.

    This function orchestrates the execution of large-scale inference jobs
    using a production grid (either a Parquet file or a GeoDataFrame) and
    specified parameters. It manages job creation, tracking, and execution
    on the OpenEO backend.

    Parameters
    ----------
    production_grid : Union[Path, gpd.GeoDataFrame]
        Path to the production grid file in Parquet format or a GeoDataFrame.
        The grid must contain the required attributes: 'start_date', 'end_date',
        'geometry', 'tile_name', 'epsg' and 'bounds_epsg'.
    output_dir : Union[Path, str]
        Directory where output files and job tracking information will be stored.
    product_type : WorldCerealProductType
        Type of product to generate. Defaults to WorldCerealProductType.CROPLAND.
    cropland_parameters : CropLandParameters
        Parameters for cropland inference.
    croptype_parameters : CropTypeParameters
        Parameters for crop type inference.
    backend_context : BackendContext
        Context for the backend to use. Defaults to BackendContext(Backend.CDSE).
    target_epsg : Optional[int]
        EPSG code for the target coordinate reference system.
        If None, no reprojection will be performed.
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use ('ASCENDING' or 'DESCENDING').
        If None, no specific orbit state is enforced.
    job_options : Optional[dict]
        Additional options for configuring the inference jobs. Defaults to None.
    parallel_jobs : int
        Number of parallel jobs to manage on the backend. Defaults to 2. Note
        that load balancing does not guarantee that all jobs will run in
        parallel.

    Returns
    -------
    None
    """
    # Delegate all setup (tracking CSV, job manager, job factory) to the
    # dedicated setup function.
    manager, job_db, start_job = setup_inference_job_manager(
        production_grid=production_grid,
        output_dir=output_dir,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        backend_context=backend_context,
        target_epsg=target_epsg,
        s1_orbit_state=s1_orbit_state,
        job_options=job_options,
        parallel_jobs=parallel_jobs,
    )

    # Launch and track all jobs until completion
    manager.run_jobs(
        df=job_db.df,
        start_job=start_job,
        job_db=job_db.path,
    )

    logger.info("Job manager finished.")
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def setup_inference_job_manager(
    production_grid: Union[Path, gpd.GeoDataFrame],
    output_dir: Union[Path, str],
    product_type: WorldCerealProductType = WorldCerealProductType.CROPLAND,
    cropland_parameters: CropLandParameters = CropLandParameters(),
    croptype_parameters: CropTypeParameters = CropTypeParameters(),
    backend_context: BackendContext = BackendContext(Backend.CDSE),
    target_epsg: Optional[int] = None,
    s1_orbit_state: Optional[Literal["ASCENDING", "DESCENDING"]] = None,
    job_options: Optional[dict] = None,
    parallel_jobs: int = 2,
) -> tuple[InferenceJobManager, CsvJobDatabase, Callable]:
    """
    Prepare large-scale inference jobs on the OpenEO backend.

    This function sets up the job manager, creates job tracking information,
    and defines the job creation function for WorldCereal inference jobs.

    Parameters
    ----------
    production_grid : Union[Path, gpd.GeoDataFrame]
        Path to the production grid file in Parquet format (a ``str`` path is
        accepted as well) or a GeoDataFrame. The grid must contain the
        required attributes: 'start_date', 'end_date', 'geometry',
        'tile_name', 'epsg' and 'bounds_epsg'.
    output_dir : Union[Path, str]
        Directory where output files and job tracking information will be stored.
    product_type : WorldCerealProductType
        Type of product to generate. Defaults to WorldCerealProductType.CROPLAND.
    cropland_parameters : CropLandParameters
        Parameters for cropland inference.
    croptype_parameters : CropTypeParameters
        Parameters for crop type inference.
    backend_context : BackendContext
        Context for the backend to use. Defaults to BackendContext(Backend.CDSE).
    target_epsg : Optional[int]
        EPSG code for the target coordinate reference system.
        If None, no reprojection will be performed.
    s1_orbit_state : Optional[Literal["ASCENDING", "DESCENDING"]]
        Sentinel-1 orbit state to use ('ASCENDING' or 'DESCENDING').
        If None, no specific orbit state is enforced.
    job_options : Optional[dict]
        Additional options for configuring the inference jobs. Defaults to None.
    parallel_jobs : int
        Number of parallel jobs to manage on the backend. Defaults to 2. Note
        that load balancing does not guarantee that all jobs will run in
        parallel.

    Returns
    -------
    tuple[InferenceJobManager, CsvJobDatabase, Callable]
        A tuple containing:
        - InferenceJobManager: The job manager for handling inference jobs.
        - CsvJobDatabase: The job database for tracking job information.
        - Callable: A function to create individual inference jobs.

    Raises
    ------
    ValueError
        If the production grid does not contain the required attributes,
        if `production_grid` is of an unsupported type, or if the job
        tracking CSV ends up empty.
    """

    # Setup output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Make a connection to the OpenEO backend
    backend = backend_context.backend
    connection = BACKEND_CONNECTIONS[backend]()

    # Setup the job manager
    logger.info("Setting up the job manager.")
    manager = InferenceJobManager(root_dir=output_dir)
    manager.add_backend(
        backend.value, connection=connection, parallel_jobs=parallel_jobs
    )

    # Configure job tracking CSV file
    job_tracking_csv = output_dir / "job_tracking.csv"

    job_db = CsvJobDatabase(path=job_tracking_csv)
    if not job_db.exists():
        logger.info("Job tracking file does not exist, creating new jobs.")

        # Accept str paths as well as Path objects for convenience
        if isinstance(production_grid, (str, Path)):
            production_gdf = gpd.read_parquet(production_grid)
        elif isinstance(production_grid, gpd.GeoDataFrame):
            production_gdf = production_grid
        else:
            raise ValueError("production_grid must be a Path or a GeoDataFrame.")

        REQUIRED_ATTRIBUTES = [
            "start_date",
            "end_date",
            "geometry",
            "tile_name",
            "epsg",
            "bounds_epsg",
        ]
        # Explicit validation instead of assert: asserts are stripped when
        # Python runs with -O, which would silently skip these checks.
        for attr in REQUIRED_ATTRIBUTES:
            if attr not in production_gdf.columns:
                raise ValueError(
                    f"The production grid must contain a '{attr}' column."
                )

        job_df = production_gdf[REQUIRED_ATTRIBUTES].copy()

        df = manager._normalize_df(job_df)
        # Save the job tracking DataFrame to the job database
        job_db.persist(df)

    else:
        logger.info("Job tracking file already exists, skipping job creation.")

    # Define the job creation function
    start_job = partial(
        create_inference_job,
        product_type=product_type,
        cropland_parameters=cropland_parameters,
        croptype_parameters=croptype_parameters,
        s1_orbit_state=s1_orbit_state,
        job_options=job_options,
        target_epsg=target_epsg,
    )

    # Check if there are jobs to run
    if job_db.df.empty:
        logger.warning("No jobs to run. The job tracking CSV is empty.")
        raise ValueError(
            "No jobs to run. The job tracking CSV is empty. "
            "Please check the production grid and ensure it contains valid data."
        )

    return manager, job_db, start_job
|
worldcereal/openeo/__init__.py
ADDED
|
File without changes
|
worldcereal/openeo/feature_extractor.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""openEO UDF to compute Presto/Prometheo features."""
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import functools
|
| 5 |
+
import logging
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import urllib.request
|
| 9 |
+
import zipfile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import xarray as xr
|
| 15 |
+
from openeo.metadata import CollectionMetadata
|
| 16 |
+
from openeo.udf import XarrayDataCube
|
| 17 |
+
from openeo.udf.udf_data import UdfData
|
| 18 |
+
from pyproj import Transformer
|
| 19 |
+
from pyproj.crs import CRS
|
| 20 |
+
from scipy.ndimage import (
|
| 21 |
+
convolve,
|
| 22 |
+
zoom,
|
| 23 |
+
)
|
| 24 |
+
from shapely.geometry import Point
|
| 25 |
+
from shapely.ops import transform
|
| 26 |
+
|
| 27 |
+
sys.path.append("feature_deps")
|
| 28 |
+
|
| 29 |
+
import torch # noqa: E402
|
| 30 |
+
|
| 31 |
+
PROMETHEO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/prometheo-0.0.3-py3-none-any.whl"
|
| 32 |
+
|
| 33 |
+
GFMAP_BAND_MAPPING = {
|
| 34 |
+
"S2-L2A-B02": "B2",
|
| 35 |
+
"S2-L2A-B03": "B3",
|
| 36 |
+
"S2-L2A-B04": "B4",
|
| 37 |
+
"S2-L2A-B05": "B5",
|
| 38 |
+
"S2-L2A-B06": "B6",
|
| 39 |
+
"S2-L2A-B07": "B7",
|
| 40 |
+
"S2-L2A-B08": "B8",
|
| 41 |
+
"S2-L2A-B8A": "B8A",
|
| 42 |
+
"S2-L2A-B11": "B11",
|
| 43 |
+
"S2-L2A-B12": "B12",
|
| 44 |
+
"S1-SIGMA0-VH": "VH",
|
| 45 |
+
"S1-SIGMA0-VV": "VV",
|
| 46 |
+
"AGERA5-TMEAN": "temperature_2m",
|
| 47 |
+
"AGERA5-PRECIP": "total_precipitation",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
LAT_HARMONIZED_NAME = "GEO-LAT"
|
| 51 |
+
LON_HARMONIZED_NAME = "GEO-LON"
|
| 52 |
+
EPSG_HARMONIZED_NAME = "GEO-EPSG"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@functools.lru_cache(maxsize=1)
def unpack_prometheo_wheel(wheel_url: str) -> Path:
    """Download and extract the Prometheo wheel into a local dependencies dir.

    The wheel at ``wheel_url`` is downloaded into the current working
    directory and its contents are extracted to
    ``<cwd>/dependencies/prometheo``, which is returned so callers can add
    it to ``sys.path``. The ``lru_cache`` decorator ensures the download and
    extraction happen at most once per process.

    Parameters
    ----------
    wheel_url : str
        URL of the wheel file to download and unpack.

    Returns
    -------
    Path
        Directory into which the wheel was extracted.
    """
    destination_dir = Path.cwd() / "dependencies" / "prometheo"
    destination_dir.mkdir(exist_ok=True, parents=True)

    # Downloads the wheel file
    # NOTE: the downloaded wheel is left in cwd; urlretrieve returns its path.
    modelfile, _ = urllib.request.urlretrieve(
        wheel_url, filename=Path.cwd() / Path(wheel_url).name
    )
    # A wheel is a zip archive, so it can be extracted directly.
    with zipfile.ZipFile(modelfile, "r") as zip_ref:
        zip_ref.extractall(destination_dir)
    return destination_dir
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@functools.lru_cache(maxsize=1)
def compile_encoder(presto_encoder):
    """Helper function that compiles the encoder of a Presto model
    and performs a warm-up on dummy data. The lru_cache decorator
    ensures caching on compute nodes to be able to actually benefit
    from the compilation process.

    Parameters
    ----------
    presto_encoder : nn.Module
        Encoder part of Presto model to compile

    Returns
    -------
    nn.Module
        The compiled (and warmed-up) encoder.
    """

    presto_encoder = torch.compile(presto_encoder)  # type: ignore

    # Warm-up: run a few forward passes on dummy tensors so the compilation
    # cost is paid here rather than on the first real inference call.
    # Shapes assume the encoder signature (eo, dynamic_world, latlons):
    # (batch, 12 timesteps, 17 bands), (batch, 12), (batch, 2) — TODO confirm
    # against the Presto encoder definition.
    for _ in range(3):
        presto_encoder(
            torch.rand((1, 12, 17)),
            torch.ones((1, 12)).long(),
            torch.rand(1, 2),
        )

    return presto_encoder
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def evaluate_resolution(inarr: xr.DataArray, epsg: int) -> int:
    """Helper function to get the resolution in meters for
    the input array.

    For projected CRSs the resolution is taken directly from the spacing of
    the x coordinates. For geographic coordinates (EPSG:4326) the points are
    first reprojected to Web Mercator (EPSG:3857) so the spacing can be
    expressed in meters.

    Parameters
    ----------
    inarr : xr.DataArray
        input array to determine resolution for.
    epsg : int
        EPSG code of the coordinates in `inarr`.

    Returns
    -------
    int
        resolution in meters.
    """

    if epsg != 4326:
        # Projected CRS: coordinate spacing is already in meters
        pixel_size = abs(inarr.x[1].values - inarr.x[0].values)
        logger.info(f"Resolution for computing slope: {pixel_size}")
        return pixel_size

    logger.info(
        "Converting WGS84 coordinates to EPSG:3857 to determine resolution."
    )

    # Reproject each (x, y) coordinate pair to Web Mercator and measure the
    # spacing between the first two projected points.
    to_mercator = Transformer.from_crs(epsg, 3857, always_xy=True)
    projected = [
        transform(to_mercator.transform, Point(x, y))
        for x, y in zip(inarr.x.values, inarr.y.values)
    ]
    pixel_size = abs(projected[1].x - projected[0].x)

    logger.info(f"Resolution for computing slope: {pixel_size}")

    return pixel_size
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_slope(inarr: xr.DataArray, resolution: int) -> xr.DataArray:
    """Computes the slope using the scipy library. The input array should
    have the following bands: 'elevation' And no time dimension. Returns a
    new DataArray containing the new `slope` band.

    Pipeline: 65535 sentinels -> NaN, NaNs filled by a randomized rolling
    fill, DEM averaged down to 20 m, Sobel-style masked gradients, arctan
    to degrees, bilinear upsample back, and invalid pixels restored to the
    65535 nodata value in the final uint16 output.

    Parameters
    ----------
    inarr : xr.DataArray
        input array containing a band 'elevation'.
    resolution : int
        resolution of the input array in meters.

    Returns
    -------
    xr.DataArray
        output array containing 'slope' band in degrees.
    """

    def _rolling_fill(darr, max_iter=2):
        """Helper function that also reflects values inside
        a patch with NaNs.

        Mutates `darr` in place: each iteration copies neighbor values
        (shifted in the four cardinal directions, in random order) into
        the NaN positions, then recurses up to `max_iter` times.
        """
        if max_iter == 0:
            return darr
        else:
            max_iter -= 1
            # arr of shape (rows, cols)
            mask = np.isnan(darr)

            # `np.any` returns a numpy bool, for which `~` is a proper
            # logical negation (unlike a plain Python bool).
            if ~np.any(mask):
                return darr

            roll_params = [(0, 1), (0, -1), (1, 0), (-1, 0)]
            # Random order avoids a directional bias in the fill.
            random.shuffle(roll_params)

            for roll_param in roll_params:
                rolled = np.roll(darr, roll_param, axis=(0, 1))
                darr[mask] = rolled[mask]

            return _rolling_fill(darr, max_iter=max_iter)

    def _downsample(arr: np.ndarray, factor: int) -> np.ndarray:
        """Downsamples a 2D NumPy array by a given factor with average resampling and reflect padding.

        Parameters
        ----------
        arr : np.ndarray
            The 2D input array.
        factor : int
            The factor by which to downsample. For example, factor=2 downsamples by 2x.

        Returns
        -------
        np.ndarray
            Downsampled array.
        """

        # Get the original shape of the array
        X, Y = arr.shape

        # Calculate how much padding is needed for each dimension
        pad_X = (
            factor - (X % factor)
        ) % factor  # Ensures padding is only applied if needed
        pad_Y = (
            factor - (Y % factor)
        ) % factor  # Ensures padding is only applied if needed

        # Pad the array using 'reflect' mode
        padded = np.pad(arr, ((0, pad_X), (0, pad_Y)), mode="reflect")

        # Reshape the array to form blocks of size 'factor' x 'factor'
        reshaped = padded.reshape(
            (X + pad_X) // factor, factor, (Y + pad_Y) // factor, factor
        )

        # Take the mean over the factor-sized blocks
        downsampled = np.nanmean(reshaped, axis=(1, 3))

        return downsampled

    # astype() returns a copy, so the caller's array is not mutated by
    # the in-place operations below.
    dem = inarr.sel(bands="elevation").values
    dem_arr = dem.astype(np.float32)

    # Invalid to NaN and keep track of these pixels
    dem_arr[dem_arr == 65535] = np.nan
    idx_invalid = np.isnan(dem_arr)

    # Fill NaNs with rolling fill
    dem_arr = _rolling_fill(dem_arr)

    # We make sure DEM is at 20m for slope computation
    # compatible with global slope collection
    factor = int(20 / resolution)
    # NOTE(review): this also rejects factor == 1 (i.e. a native 20 m
    # input), since 1 % 2 != 0 — confirm that is intended.
    if factor < 1 or factor % 2 != 0:
        raise NotImplementedError(
            f"Unsupported resolution for slope computation: {resolution}"
        )
    dem_arr_downsampled = _downsample(dem_arr, factor)
    # Remember odd dimensions: the reflect-padded downsample + zoom round
    # trip yields one extra row/column that is stripped at the end.
    x_odd, y_odd = dem_arr.shape[0] % 2 != 0, dem_arr.shape[1] % 2 != 0

    # Mask NaN values in the DEM data
    dem_masked = np.ma.masked_invalid(dem_arr_downsampled)

    # Define convolution kernels for x and y gradients (simple finite difference approximation)
    kernel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / (
        8.0 * 20  # array is now at 20m resolution
    )  # x-derivative kernel

    kernel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / (
        8.0 * 20  # array is now at 20m resolution
    )  # y-derivative kernel

    # Apply convolution to compute gradients
    dx = convolve(dem_masked, kernel_x)  # Gradient in the x-direction
    dy = convolve(dem_masked, kernel_y)  # Gradient in the y-direction

    # Reapply the mask to the gradients
    # (scipy's convolve does not preserve the masked-array mask)
    dx = np.ma.masked_where(dem_masked.mask, dx)
    dy = np.ma.masked_where(dem_masked.mask, dy)

    # Calculate the magnitude of the gradient (rise/run)
    gradient_magnitude = np.ma.sqrt(dx**2 + dy**2)

    # Convert gradient magnitude to slope (in degrees)
    slope = np.ma.arctan(gradient_magnitude) * (180 / np.pi)

    # Upsample to original resolution with bilinear interpolation
    # (order=0 keeps the mask crisp; order=1 interpolates the slope values)
    mask = slope.mask
    mask = zoom(mask, zoom=factor, order=0)
    slope = zoom(slope, zoom=factor, order=1)
    slope[mask] = 65535

    # Strip one row or column if original array was odd in that dimension
    if x_odd:
        slope = slope[:-1, :]
    if y_odd:
        slope = slope[:, :-1]

    # Fill slope values where the original DEM had NaNs
    slope[idx_invalid] = 65535
    slope[np.isnan(slope)] = 65535
    slope = slope.astype(np.uint16)

    return xr.DataArray(
        slope[None, :, :],
        dims=("bands", "y", "x"),
        coords={
            "bands": ["slope"],
            "y": inarr.y,
            "x": inarr.x,
        },
    )
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def select_timestep_from_temporal_features(
    features: xr.DataArray, target_date: Optional[str] = None
) -> xr.DataArray:
    """Pick one timestep out of a temporal feature stack.

    Parameters
    ----------
    features : xr.DataArray
        Temporal features with the time dimension (`t`) preserved.
    target_date : str, optional
        Target date in ISO format (YYYY-MM-DD). When omitted, the middle
        timestep is returned instead.

    Returns
    -------
    xr.DataArray
        Features for the selected timestep (time dimension dropped).

    Raises
    ------
    ValueError
        If `target_date` lies outside the temporal extent of `features`.
    """
    # No explicit date: fall back to the middle of the time axis.
    if target_date is None:
        middle = len(features.t) // 2
        return features.isel(t=middle)

    requested = np.datetime64(target_date)

    # Reject dates that fall outside the available time range.
    earliest = features.t.min().values
    latest = features.t.max().values
    if not (earliest <= requested <= latest):
        raise ValueError(
            f"Target date {target_date} is outside the temporal extent of features. "
            f"Available time range: {earliest} to {latest}"
        )

    # Snap to the nearest available timestep.
    return features.sel(t=requested, method="nearest")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def extract_presto_embeddings(
    inarr: xr.DataArray, parameters: dict, epsg: int
) -> xr.DataArray:
    """Executes the feature extraction process on the input array.

    Runs the Presto/Prometheo model over a 12-timestep input cube and
    returns a (bands, y, x) feature array. Recognized `parameters` keys
    (as read below): "presto_model_url" (required), "prometheo_wheel_url",
    "ignore_dependencies", "batch_size", "temporal_prediction",
    "target_date".

    Raises
    ------
    ValueError
        If `epsg` is None, "presto_model_url" is missing, or the input
        does not have exactly 12 timesteps.
    """

    if epsg is None:
        raise ValueError(
            "EPSG code is required for Presto feature extraction, but was "
            "not correctly initialized."
        )
    if "presto_model_url" not in parameters:
        raise ValueError('Missing required parameter "presto_model_url"')

    presto_model_url = parameters.get("presto_model_url")
    logger.info(f'Loading Presto model from "{presto_model_url}"')
    prometheo_wheel_url = parameters.get("prometheo_wheel_url", PROMETHEO_WHL_URL)
    logger.info(f'Loading Prometheo wheel from "{prometheo_wheel_url}"')

    # When set, the caller guarantees Prometheo is already importable and
    # the wheel-unpacking step below is skipped.
    ignore_dependencies = parameters.get("ignore_dependencies", False)
    if ignore_dependencies:
        logger.info(
            "`ignore_dependencies` flag is set to True. Make sure that "
            "Presto and its dependencies are available on the runtime "
            "environment"
        )

    # The below is required to avoid flipping of the result
    # when running on OpenEO backend!
    inarr = inarr.transpose(
        "bands", "t", "x", "y"
    )  # Presto/Prometheo expects xy dimension order

    # Change the band names
    # (map openEO/GFMAP band names to the short names Presto expects;
    # unknown bands are kept as-is)
    new_band_names = [GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands]
    inarr = inarr.assign_coords(bands=new_band_names)

    # Log pixel statistics
    total_pixels = inarr.size
    num_nan_pixels = np.isnan(inarr.values).sum()
    num_zero_pixels = (inarr.values == 0).sum()
    num_nodatavalue_pixels = (inarr.values == 65535).sum()
    logger.info("Band names: " + ", ".join(inarr.bands.values))
    logger.debug(
        f"Array dtype: {inarr.dtype}, "
        f"Array size: {inarr.shape}, total pixels: {total_pixels}, "
        f"Pixel statistics: NaN pixels = {num_nan_pixels} "
        f"({num_nan_pixels / total_pixels * 100:.2f}%), "
        f"0 pixels = {num_zero_pixels} "
        f"({num_zero_pixels / total_pixels * 100:.2f}%), "
        f"NODATAVALUE pixels = {num_nodatavalue_pixels} "
        f"({num_nodatavalue_pixels / total_pixels * 100:.2f}%)"
    )

    # Log mean value (ignoring NaNs) per band
    for band in inarr.bands.values:
        band_data = inarr.sel(bands=band).values
        mean_value = np.nanmean(band_data)
        logger.debug(f"Band '{band}': Mean value (ignoring NaNs) = {mean_value:.2f}")

    # Handle NaN values in Presto compatible way
    # (65535 is the nodata sentinel Presto is trained with — TODO confirm)
    inarr = inarr.fillna(65535)

    if not ignore_dependencies:
        # Unzip the Presto dependencies on the backend
        # (unpack_prometheo_wheel is defined elsewhere in this module)
        logger.info("Unpacking prometheo wheel")
        deps_dir = unpack_prometheo_wheel(prometheo_wheel_url)

        logger.info("Appending dependencies")
        sys.path.append(str(deps_dir))

    if "slope" not in inarr.bands:
        # If 'slope' is not present we need to compute it here
        logger.warning("`slope` not found in input array. Computing ...")
        resolution = evaluate_resolution(inarr.isel(t=0), epsg)
        slope = compute_slope(inarr.isel(t=0), resolution)
        # Slope is static in time: broadcast the single layer over `t`.
        slope = slope.expand_dims({"t": inarr.t}, axis=0).astype("float32")

        inarr = xr.concat([inarr.astype("float32"), slope], dim="bands")

    batch_size = parameters.get("batch_size", 256)
    temporal_prediction = parameters.get("temporal_prediction", False)
    target_date = parameters.get("target_date", None)
    logger.info(
        (
            f"Extracting Presto features with batch size {batch_size}, "
            f"temporal_prediction={temporal_prediction}, "
            f"target_date={target_date}"
        )
    )

    # TODO: compile_presto not used for now?
    # compile_presto = parameters.get("compile_presto", False)
    # self.logger.info(f"Compile presto: {compile_presto}")

    logger.info("Loading Presto model for inference")

    # TODO: try to take run_model_inference from worldcereal
    # Imported here (not at module top) because prometheo only becomes
    # importable after the wheel has been appended to sys.path above.
    from prometheo.datasets.worldcereal import run_model_inference
    from prometheo.models import Presto
    from prometheo.models.pooling import PoolingMethods
    from prometheo.models.presto.wrapper import load_presto_weights

    presto_model = Presto()
    presto_model = load_presto_weights(presto_model, presto_model_url)

    logger.info("Extracting presto features")
    # Check if we have the expected 12 timesteps
    if len(inarr.t) != 12:
        raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(inarr.t)}")

    # Determine pooling method based on temporal_prediction parameter
    # TIME keeps the time dimension; GLOBAL pools it away.
    pooling_method = (
        PoolingMethods.TIME if temporal_prediction else PoolingMethods.GLOBAL
    )
    logger.info(f"Using pooling method: {pooling_method}")

    features = run_model_inference(
        inarr,
        presto_model,
        epsg=epsg,
        batch_size=batch_size,
        pooling_method=pooling_method,
    )

    # If temporal prediction, select specific timestep based on target_date
    if temporal_prediction:
        features = select_timestep_from_temporal_features(features, target_date)

    features = features.transpose(
        "bands", "y", "x"
    )  # openEO expects yx order after the UDF

    return features
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def get_latlons(inarr: xr.DataArray, epsg: int) -> xr.DataArray:
    """Derive per-pixel latitude/longitude grids for the given array.

    Builds a DataArray with the same width/height as `inarr` but exactly
    two bands: latitude and longitude (named `LAT_HARMONIZED_NAME` and
    `LON_HARMONIZED_NAME`). The original x/y coordinates are kept — the
    array is not reprojected; new per-pixel features are simply computed.

    Raises
    ------
    Exception
        If `epsg` is None, since the CRS is then unknown.
    """

    # Expand the 1D coordinate axes into full 2D grids.
    lon_grid, lat_grid = np.meshgrid(inarr.coords["x"], inarr.coords["y"])

    if epsg is None:
        raise Exception(
            "EPSG code was not defined, cannot extract lat/lon array "
            "as the CRS is unknown."
        )

    # Reproject the grids to WGS84 when they are in any other CRS.
    if epsg != 4326:
        reprojector = Transformer.from_crs(
            crs_from=CRS.from_epsg(epsg),
            crs_to=CRS.from_epsg(4326),
            always_xy=True,
        )
        lon_grid, lat_grid = reprojector.transform(xx=lon_grid, yy=lat_grid)

    # Stack into a (2, y, x) array: latitude first, then longitude.
    return xr.DataArray(
        np.stack([lat_grid, lon_grid]),
        dims=["bands", "y", "x"],
        coords={
            "bands": [LAT_HARMONIZED_NAME, LON_HARMONIZED_NAME],
            "y": inarr.coords["y"],
            "x": inarr.coords["x"],
        },
    )
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def rescale_s1_backscatter(arr: xr.DataArray) -> xr.DataArray:
    """Convert Sentinel-1 backscatter bands from packed uint16 to float32 dB.

    The input is expected in uint16 format (the memory-friendly encoding
    used in the Open-EO processes). The conversion is applied in place on
    any S1 sigma0 bands present; arrays without S1 bands pass through
    untouched. Callers can opt out via the `rescale_s1` parameter of the
    feature extractor.
    """
    sigma0_bands = ["S1-SIGMA0-VV", "S1-SIGMA0-VH", "S1-SIGMA0-HV", "S1-SIGMA0-HH"]
    present = list(set(arr.bands.values) & set(sigma0_bands))

    # Nothing to do when no S1 band is in the cube.
    if not present:
        return arr

    values = arr.sel(bands=present).astype(np.float32).data

    # Guard: the packed encoding only produces values in [1, 65535].
    if values.min().item() < 1 or values.max().item() > 65535:
        raise ValueError(
            "The input array should be in uint16 format, with values between 1 and 65535. "
            "This restriction assures that the data was processed according to the S1 fetcher "
            "preprocessor. The user can disable this scaling manually by setting the "
            "`rescale_s1` parameter to False in the feature extractor."
        )

    # Undo the uint16 packing: first back to dB, then to linear power.
    decibels = 20.0 * np.log10(values) - 83.0
    power = np.power(10, decibels / 10.0)
    # Non-finite intermediates (e.g. from log10 of packed zeros) become NaN.
    power[~np.isfinite(power)] = np.nan

    # Express the sanitized power values in decibels and write them back
    # into the selected bands of the original array.
    arr.loc[dict(bands=present)] = 10.0 * np.log10(power)
    return arr
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# Below comes the actual UDF part
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# Apply the Feature Extraction UDF
|
| 549 |
+
def apply_udf_data(udf_data: UdfData) -> UdfData:
    """This is the actual openeo UDF that will be executed by the backend."""
    raw_cube = udf_data.datacube_list[0]
    params = copy.deepcopy(udf_data.user_context)

    # Resolve the EPSG code from the projection metadata, if present.
    proj_info = udf_data.proj
    params[EPSG_HARMONIZED_NAME] = (
        None if proj_info is None else proj_info["EPSG"]
    )

    data = raw_cube.get_array().transpose("bands", "t", "y", "x")

    epsg = params.pop(EPSG_HARMONIZED_NAME)
    logger.info(f"EPSG code determined for feature extraction: {epsg}")

    # S1 backscatter arrives uint16-packed; rescale unless explicitly disabled.
    if params.get("rescale_s1", True):
        data = rescale_s1_backscatter(data)

    data = extract_presto_embeddings(inarr=data, parameters=params, epsg=epsg)

    # Wrap the resulting feature array back into an openEO datacube.
    udf_data.datacube_list = [XarrayDataCube(data)]
    return udf_data
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# Change band names
|
| 579 |
+
def apply_metadata(metadata: CollectionMetadata, context: dict) -> CollectionMetadata:
    """Rename the output band labels to the 128 Presto feature names."""
    feature_names = [f"presto_ft_{i}" for i in range(128)]
    return metadata.rename_labels(dimension="bands", target=feature_names)
|
worldcereal/openeo/inference.py
ADDED
|
@@ -0,0 +1,1191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""openEO UDF to compute Presto/Prometheo features with clean code structure."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import xarray as xr
|
| 13 |
+
from openeo.udf import XarrayDataCube
|
| 14 |
+
from openeo.udf.udf_data import UdfData
|
| 15 |
+
from pyproj import Transformer
|
| 16 |
+
from scipy.ndimage import convolve, zoom
|
| 17 |
+
from shapely.geometry import Point
|
| 18 |
+
from shapely.ops import transform
|
| 19 |
+
|
| 20 |
+
# Prefer loguru for logging when it is installed; otherwise fall back to
# the standard library logger. Either way, the module exposes `logger`.
try:
    from loguru import logger

    # Reset loguru's default sink and emit INFO+ to stderr.
    logger.remove()
    logger.add(sys.stderr, level="INFO")

    class InterceptHandler(logging.Handler):
        """Bridges stdlib `logging` records into loguru."""

        def emit(self, record):
            level = record.levelname
            # depth=6 skips the logging-machinery frames so loguru
            # attributes the message to the original call site
            # (presumably tuned for this call chain — TODO confirm).
            logger.opt(depth=6).log(level, record.getMessage())

    # Replace existing handlers
    for h in logging.root.handlers[:]:
        logging.root.removeHandler(h)

    logging.root.setLevel(logging.INFO)
    logging.root.addHandler(InterceptHandler())

except ImportError:
    # loguru not available, use standard logging
    logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
_MODULE_CACHE_KEY = f"__model_cache_{__name__}"
|
| 43 |
+
|
| 44 |
+
# Constants
|
| 45 |
+
PROMETHEO_WHL_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/prometheo-0.0.3-py3-none-any.whl"
|
| 46 |
+
|
| 47 |
+
GFMAP_BAND_MAPPING = {
|
| 48 |
+
"S2-L2A-B02": "B2",
|
| 49 |
+
"S2-L2A-B03": "B3",
|
| 50 |
+
"S2-L2A-B04": "B4",
|
| 51 |
+
"S2-L2A-B05": "B5",
|
| 52 |
+
"S2-L2A-B06": "B6",
|
| 53 |
+
"S2-L2A-B07": "B7",
|
| 54 |
+
"S2-L2A-B08": "B8",
|
| 55 |
+
"S2-L2A-B8A": "B8A",
|
| 56 |
+
"S2-L2A-B11": "B11",
|
| 57 |
+
"S2-L2A-B12": "B12",
|
| 58 |
+
"S1-SIGMA0-VH": "VH",
|
| 59 |
+
"S1-SIGMA0-VV": "VV",
|
| 60 |
+
"AGERA5-TMEAN": "temperature_2m",
|
| 61 |
+
"AGERA5-PRECIP": "total_precipitation",
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
LAT_HARMONIZED_NAME = "GEO-LAT"
|
| 65 |
+
LON_HARMONIZED_NAME = "GEO-LON"
|
| 66 |
+
EPSG_HARMONIZED_NAME = "GEO-EPSG"
|
| 67 |
+
|
| 68 |
+
S1_BANDS = ["S1-SIGMA0-VV", "S1-SIGMA0-VH", "S1-SIGMA0-HV", "S1-SIGMA0-HH"]
|
| 69 |
+
NODATA_VALUE = 65535
|
| 70 |
+
|
| 71 |
+
POSTPROCESSING_EXCLUDED_VALUES = [254, 255, 65535]
|
| 72 |
+
POSTPROCESSING_NODATA = 255
|
| 73 |
+
|
| 74 |
+
NUM_THREADS = 2
|
| 75 |
+
|
| 76 |
+
# Make vendored dependency folders importable before importing onnxruntime
# (presumably these directories are provided on the openEO backend and are
# harmless no-ops elsewhere — TODO confirm).
sys.path.append("feature_deps")
sys.path.append("onnx_deps")
import onnxruntime as ort  # noqa: E402

# Tracks whether prometheo has already been imported/installed in this process.
_PROMETHEO_INSTALLED = False

# Global variables for Prometheo imports
# These are rebound by _ensure_prometheo_dependencies() once the prometheo
# package is importable; until then they are None placeholders.
Presto = None
load_presto_weights = None
run_model_inference = None
PoolingMethods = None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# =============================================================================
|
| 90 |
+
# STANDALONE FUNCTIONS (Work in both apply_udf_data and apply_metadata contexts)
|
| 91 |
+
# =============================================================================
|
| 92 |
+
def get_model_cache():
    """Return the module's model cache dict, creating it on first use.

    The dict is stored as an attribute on the ``sys`` module under
    ``_MODULE_CACHE_KEY`` (presumably so it outlives re-imports of this
    UDF module within the same worker process — TODO confirm).
    """
    try:
        return getattr(sys, _MODULE_CACHE_KEY)
    except AttributeError:
        cache = {}
        setattr(sys, _MODULE_CACHE_KEY, cache)
        return cache
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _ensure_prometheo_dependencies():
    """Non-cached dependency check.

    Ensures the prometheo package is importable, installing it from the
    bundled wheel when a plain import fails. On success the module-level
    names Presto, load_presto_weights, run_model_inference and
    PoolingMethods are (re)bound and _PROMETHEO_INSTALLED is set.
    """
    global _PROMETHEO_INSTALLED, Presto, load_presto_weights, run_model_inference, PoolingMethods

    # Fast path: a previous call already resolved everything.
    if _PROMETHEO_INSTALLED:
        return

    try:
        # Try to import first
        # NOTE: because of the `global` declaration above, these
        # `from ... import` statements bind the MODULE-level names,
        # not function locals.
        from prometheo.datasets.worldcereal import run_model_inference
        from prometheo.models import Presto
        from prometheo.models.pooling import PoolingMethods
        from prometheo.models.presto.wrapper import load_presto_weights

        # They're now available in the global scope
        _PROMETHEO_INSTALLED = True
        return
    except ImportError:
        pass

    # Installation required
    logger.info("Prometheo not available, installing...")
    _install_prometheo()

    # Import immediately after installation - these will be available globally
    from prometheo.datasets.worldcereal import run_model_inference
    from prometheo.models import Presto
    from prometheo.models.pooling import PoolingMethods
    from prometheo.models.presto.wrapper import load_presto_weights

    # Tune torch CPU thread usage; helper is defined elsewhere in this
    # module (not visible here) — presumably caps thread counts to
    # NUM_THREADS. TODO confirm.
    optimize_pytorch_cpu_performance(NUM_THREADS)
    _PROMETHEO_INSTALLED = True
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _install_prometheo():
    """Non-cached installation function.

    Downloads the prometheo wheel, extracts it into a fresh temporary
    directory and appends that directory to sys.path so the package
    becomes importable without pip. The directory is only removed on
    failure; on success it must stay alive because it is on sys.path.
    """
    import shutil
    import tempfile
    import urllib.request
    import zipfile

    temp_dir = Path(tempfile.mkdtemp())
    try:
        # Download wheel
        # NOTE(review): urlretrieve saves to a temp file that is never
        # deleted; consider removing wheel_path after extraction.
        wheel_path, _ = urllib.request.urlretrieve(PROMETHEO_WHL_URL)

        # Extract to temp directory
        # (a wheel is a plain zip archive)
        with zipfile.ZipFile(wheel_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir)

        # Add to Python path
        sys.path.append(str(temp_dir))
        logger.info(f"Prometheo installed to {temp_dir}.")

    except Exception as e:
        # Clean up the partially-populated directory before re-raising.
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        logger.error(f"Failed to install prometheo: {e}")
        raise
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def load_onnx_model_cached(model_url: str):
    """Download, build and cache an ONNX inference session and its class LUT.

    Parameters
    ----------
    model_url : str
        URL of the ONNX model to download.

    Returns
    -------
    tuple
        ``(model, sorted_lut)`` where `model` is an
        ``onnxruntime.InferenceSession`` and `sorted_lut` maps class names
        to labels, sorted by label value. Results are memoized per URL in
        the process-wide model cache.

    Raises
    ------
    requests.HTTPError
        If the model download fails with an HTTP error status.
    """

    cache = get_model_cache()
    if model_url in cache:
        logger.debug(f"ONNX model cache hit for {model_url}.")
        return cache[model_url]

    logger.info(f"Loading ONNX model from {model_url}")
    response = requests.get(model_url, timeout=120)
    # Fail fast on HTTP errors: without this, an error page body would be
    # handed to onnxruntime as model bytes, producing a cryptic failure.
    response.raise_for_status()

    session_options, providers = optimize_onnx_cpu_performance(NUM_THREADS)

    model = ort.InferenceSession(response.content, session_options, providers=providers)

    metadata = model.get_modelmeta().custom_metadata_map
    # SECURITY NOTE: eval() executes an expression taken from the
    # downloaded model's metadata. Builtins are disabled, but this should
    # only ever be used with trusted model URLs; if the payload is a plain
    # dict literal, ast.literal_eval would be the safe replacement.
    class_params = eval(metadata["class_params"], {"__builtins__": None}, {})

    # Build the class-name -> label lookup table, sorted by label value.
    lut = dict(zip(class_params["class_names"], class_params["class_to_label"]))
    sorted_lut = dict(sorted(lut.items(), key=lambda item: item[1]))

    result = (model, sorted_lut)
    cache[model_url] = result
    return result
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def load_presto_weights_cached(presto_model_url: str):
    """Load Presto weights from a URL, caching the resulting model.

    The prometheo dependency check runs on every cache miss because
    dependency installation itself is intentionally not cached.
    """
    cache = get_model_cache()

    if presto_model_url in cache:
        logger.debug(f"Presto model cache hit for {presto_model_url}")
        return cache[presto_model_url]

    # Dependencies must be present before Presto can be instantiated;
    # this step is deliberately outside the cache.
    _ensure_prometheo_dependencies()

    logger.info(f"Loading Presto weights from: {presto_model_url}")

    loaded_model = load_presto_weights(Presto(), presto_model_url)  # type: ignore

    cache[presto_model_url] = loaded_model
    return loaded_model
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_output_labels(
    lut_sorted: dict, postprocess_parameters: Optional[dict] = None
) -> list:
    """Generate output band names from LUT - works in both contexts.

    Parameters
    ----------
    lut_sorted : dict
        Sorted lookup table mapping class names to labels.
    postprocess_parameters : dict, optional
        Postprocessing parameters to determine whether to keep per-class
        probability bands. If not provided, we assume all probabilities
        are kept.

    Returns
    -------
    list
        Always starts with ["classification", "probability"]; one
        "probability_<class>" entry per LUT key is appended unless
        postprocessing is enabled with ``keep_class_probs`` False.
    """
    # Fix: avoid the mutable-default-argument pitfall ({} as default).
    if postprocess_parameters is None:
        postprocess_parameters = {}

    # Determine whether to remove per-class probability bands
    # based on postprocessing parameters.
    # NOTE(review): this reads the "enabled" key while other code in this
    # module checks "enable" on the postprocess parameters — confirm
    # which key callers actually supply.
    postprocessing_enabled = postprocess_parameters.get("enabled", True)
    keep_class_probs = postprocess_parameters.get("keep_class_probs", True)
    if postprocessing_enabled and not keep_class_probs:
        # Only classification and overall probability
        return ["classification", "probability"]
    # Include per-class probabilities
    return ["classification", "probability"] + [
        f"probability_{name}" for name in lut_sorted.keys()
    ]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def optimize_pytorch_cpu_performance(num_threads):
    """Apply CPU-specific PyTorch optimizations for Prometheo inference.

    Configures PyTorch and BLAS/OpenMP thread pools, enables mkldnn when
    available and disables gradient tracking. Returns the configured
    thread count.
    """
    import torch

    # Configure PyTorch's own thread pools ...
    torch.set_num_threads(num_threads)
    torch.set_num_interop_threads(
        num_threads
    )  # TODO test setting to 4 due to parallel slope cal ect

    # ... and the underlying BLAS/OpenMP libraries via environment.
    for env_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[env_var] = str(num_threads)

    logger.info(f"PyTorch CPU: using {num_threads} threads")

    # CPU-specific optimizations
    if hasattr(torch.backends, "mkldnn"):
        torch.backends.mkldnn.enabled = True

    torch.set_grad_enabled(False)  # Disable gradients for inference

    return num_threads
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def optimize_onnx_cpu_performance(num_threads):
    """Build CPU-optimized ONNX Runtime session options and providers.

    Returns
    -------
    tuple
        ``(session_options, providers)`` ready to pass to
        ``ort.InferenceSession``.
    """
    opts = ort.SessionOptions()

    opts.intra_op_num_threads = num_threads
    # TODO test setting to 1 due to sequential nature
    opts.inter_op_num_threads = num_threads

    # CPU-specific optimizations: memory arena + pattern reuse,
    # sequential execution, full graph optimization.
    opts.enable_cpu_mem_arena = True
    opts.enable_mem_pattern = True
    opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    return opts, ["CPUExecutionProvider"]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# =============================================================================
|
| 276 |
+
# POSTPROCESSING FUNCTIONS
|
| 277 |
+
# =============================================================================
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def majority_vote(
    base_labels: xr.DataArray,
    max_probabilities: xr.DataArray,
    kernel_size: int,
) -> xr.DataArray:
    """Majority vote is performed using a sliding local kernel.
    For each pixel, the voting of a final class is done by counting
    neighbours values.
    Pixels that have one of the specified excluded values are
    excluded in the voting process and are unchanged.

    The prediction probabilities are reevaluated by taking, for each pixel,
    the average of probabilities of the neighbors that belong to the winning class.
    (For example, if a pixel was voted to class 2 and there are three
    neighbors of that class, then the new probability is the sum of the
    old probabilities of each pixels divided by 3)

    Parameters
    ----------
    base_labels : xr.DataArray
        The original predicted classification labels.
    max_probabilities : xr.DataArray
        The original probabilities of the winning class (ranging between 0 and 100).
    kernel_size : int
        The size of the kernel used for the neighbour around the pixel.

    Returns
    -------
    xr.DataArray
        The cleaned classification labels and associated probabilities,
        stacked as two bands ("classification", "probability").
    """
    from scipy.signal import convolve2d

    prediction = base_labels.values
    probability = max_probabilities.values

    # As the probabilities are in integers between 0 and 100,
    # we use uint16 matrices to store the vote scores.
    # Worst case per pixel: 100 * kernel_size^2 must fit in uint16 (65535),
    # hence the 25 cap (100 * 25^2 = 62500).
    assert (
        kernel_size <= 25
    ), f"Kernel value cannot be larger than 25 (currently: {kernel_size}) because it might lead to scenarios where the 16-bit count matrix is overflown"

    # Build a class mapping, so classes are converted to indexes and vice-versa
    unique_values = set(np.unique(prediction))
    unique_values = sorted(unique_values - set(POSTPROCESSING_EXCLUDED_VALUES))  # type: ignore
    index_value_lut = [(k, v) for k, v in enumerate(unique_values)]

    # Per-class score (sum of neighbour probabilities) and per-class
    # averaged probability, one channel per class.
    counts = np.zeros(shape=(*prediction.shape, len(unique_values)), dtype=np.uint16)
    probabilities = np.zeros(
        shape=(*probability.shape, len(unique_values)), dtype=np.uint16
    )

    # Iterates for each classes
    for cls_idx, cls_value in index_value_lut:
        # Take the binary mask of the interest class, and multiply by the probabilities
        class_mask = ((prediction == cls_value) * probability).astype(np.uint16)

        # Set to 0 the class scores where the label is excluded
        for excluded_value in POSTPROCESSING_EXCLUDED_VALUES:
            class_mask[prediction == excluded_value] = 0

        # Binary class mask, used to count HOW MANY neighbours pixels are used for this class
        binary_class_mask = (class_mask > 0).astype(np.uint16)

        # Creates the kernel
        kernel = np.ones(shape=(kernel_size, kernel_size), dtype=np.uint16)

        # Counts around the window the sum of probabilities for that given class
        counts[:, :, cls_idx] = convolve2d(class_mask, kernel, mode="same")

        # Counts the number of neighbors pixels that voted for that given class
        class_voters = convolve2d(binary_class_mask, kernel, mode="same")
        # Remove the 0 values because might create divide by 0 issues
        class_voters[class_voters == 0] = 1

        # Average probability of the voters (integer division via the
        # uint16 destination array — fractions are truncated).
        probabilities[:, :, cls_idx] = np.divide(counts[:, :, cls_idx], class_voters)

    # Initializes output array
    aggregated_predictions = np.zeros(
        shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16
    )
    # Initializes probabilities output array
    aggregated_probabilities = np.zeros(
        shape=(counts.shape[0], counts.shape[1]), dtype=np.uint16
    )

    if len(unique_values) > 0:
        # Takes the indices that have the biggest scores
        aggregated_predictions_indices = np.argmax(counts, axis=2)

        # Get the new probabilities of the predictions
        aggregated_probabilities = np.take_along_axis(
            probabilities,
            aggregated_predictions_indices.reshape(
                *aggregated_predictions_indices.shape, 1
            ),
            axis=2,
        ).squeeze()

        # Check which pixels have a counts value equal to 0
        no_score_mask = np.sum(counts, axis=2) == 0

        # convert back to values from indices
        for cls_idx, cls_value in index_value_lut:
            aggregated_predictions[aggregated_predictions_indices == cls_idx] = (
                cls_value
            )
        aggregated_predictions = aggregated_predictions.astype(np.uint16)

        # Pixels with no votes at all are marked as nodata.
        aggregated_predictions[no_score_mask] = POSTPROCESSING_NODATA
        aggregated_probabilities[no_score_mask] = POSTPROCESSING_NODATA

    # Setting excluded values back to their original values
    # NOTE(review): the *probability* band is also set to the excluded
    # label value here, not to the original probability — confirm intended.
    for excluded_value in POSTPROCESSING_EXCLUDED_VALUES:
        aggregated_predictions[prediction == excluded_value] = excluded_value
        aggregated_probabilities[prediction == excluded_value] = excluded_value

    return xr.DataArray(
        np.stack((aggregated_predictions, aggregated_probabilities)),
        dims=["bands", "y", "x"],
        coords={
            "bands": ["classification", "probability"],
            "y": base_labels.y,
            "x": base_labels.x,
        },
    )
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def smooth_probabilities(
    base_labels: xr.DataArray, class_probabilities: xr.DataArray
) -> xr.DataArray:
    """Performs gaussian smoothing on the class probabilities. Requires the
    base labels to keep the pixels that are excluded away from smoothing.

    Parameters
    ----------
    base_labels : xr.DataArray
        Classification labels used only to locate excluded pixels.
    class_probabilities : xr.DataArray
        Per-class probability stack, first dimension is the class axis.

    Returns
    -------
    xr.DataArray
        Smoothed per-class probabilities, renormalized to sum to 100 per
        pixel and cast to uint16.
    """
    from scipy.signal import convolve2d

    base_labels_vals = base_labels.values
    # NOTE(review): `.values` is a view on the input's buffer; the loop
    # below writes into it, mutating `class_probabilities` in place —
    # confirm callers expect this side effect.
    probabilities_vals = class_probabilities.values

    # Boolean mask of pixels whose label is one of the excluded values.
    excluded_mask = np.in1d(
        base_labels_vals.reshape(-1),
        POSTPROCESSING_EXCLUDED_VALUES,
    ).reshape(*base_labels_vals.shape)

    # Small 3x3 smoothing kernel; normalized by its sum (15) below.
    conv_kernel = np.array([[1, 2, 1], [2, 3, 2], [1, 2, 1]], dtype=np.int16)

    for class_idx in range(probabilities_vals.shape[0]):
        probabilities_vals[class_idx] = (
            convolve2d(
                probabilities_vals[class_idx],
                conv_kernel,
                mode="same",
                boundary="symm",
            )
            / conv_kernel.sum()
        )
        # Excluded pixels take no part in the smoothed probabilities.
        probabilities_vals[class_idx][excluded_mask] = 0

    # Sum of probabilities should be 1, cast to uint16
    # NOTE(review): where the per-pixel sum is 0 (e.g. fully excluded
    # pixels) this divides by zero — confirm downstream handling.
    probabilities_vals = np.round(
        probabilities_vals / probabilities_vals.sum(axis=0) * 100.0
    ).astype("uint16")

    return xr.DataArray(
        probabilities_vals,
        coords=class_probabilities.coords,
        dims=class_probabilities.dims,
    )
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def reclassify(
    base_labels: xr.DataArray,
    base_max_probs: xr.DataArray,
    probabilities: xr.DataArray,
) -> xr.DataArray:
    """Re-derive the winning class index and its probability from a
    per-class probability stack, leaving excluded pixels untouched.

    Returns a two-band ("classification", "probability") DataArray; the
    classification band holds class *indices* along the probability
    stack's first axis, except for excluded pixels which keep their
    original label/probability.
    """
    orig_labels = base_labels.values
    orig_max_probs = base_max_probs.values

    # Pixels carrying one of the excluded values keep their original
    # label and probability.
    excluded = np.in1d(
        orig_labels.reshape(-1),
        POSTPROCESSING_EXCLUDED_VALUES,
    ).reshape(*orig_labels.shape)

    prob_stack = probabilities.values
    winner_idx = np.argmax(prob_stack, axis=0)
    winner_prob = np.max(prob_stack, axis=0)

    winner_idx[excluded] = orig_labels[excluded]
    winner_prob[excluded] = orig_max_probs[excluded]

    return xr.DataArray(
        np.stack((winner_idx, winner_prob)),
        dims=["bands", "y", "x"],
        coords={
            "bands": ["classification", "probability"],
            "y": base_labels.y,
            "x": base_labels.x,
        },
    )
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# =============================================================================
|
| 481 |
+
# ERROR HANDLING - SIMPLE VERSION
|
| 482 |
+
# =============================================================================
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def create_nan_output_array(
    inarr: xr.DataArray, num_outputs: int, error_info: str = ""
) -> xr.DataArray:
    """Creates a NaN-filled output array with proper dimensions and coordinates.

    Parameters
    ----------
    inarr : xr.DataArray
        Input array to derive dimensions from
    num_outputs : int
        Number of output bands/classes
    error_info : str
        Error information to include in attributes for debugging

    Returns
    -------
    xr.DataArray
        NaN-filled array with proper structure
    """
    # Log as much context as possible so the failing patch can be traced.
    logger.error(f"Creating NaN output array due to error: {error_info}")
    logger.error(f"Input array shape: {inarr.shape}, dims: {inarr.dims}")
    logger.error(
        f"Input array coords - bands: {inarr.bands.values}, t: {len(inarr.t)}, x: {len(inarr.x)}, y: {len(inarr.y)}"
    )

    # NaN filler matching the input's spatial extent, one band per output.
    filler_shape = (num_outputs, len(inarr.y), len(inarr.x))
    filler = np.full(filler_shape, np.nan, dtype=np.float32)

    # Band coordinate is a plain integer range; the error message rides
    # along in the attributes for downstream debugging.
    return xr.DataArray(
        filler,
        dims=["bands", "y", "x"],
        coords={
            "bands": list(range(num_outputs)),
            "y": inarr.y,
            "x": inarr.x,
        },
        attrs={"error": error_info},
    )
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# =============================================================================
|
| 531 |
+
# CLASSES (Main logic for apply_udf_data)
|
| 532 |
+
# =============================================================================
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class SlopeCalculator:
    """Handles slope computation from elevation data.

    Pipeline: fill nodata, downsample the DEM to 20m, compute slope
    there with Sobel gradients, then upsample back to the input grid.
    """

    @staticmethod
    def compute(resolution: float, elevation_data: np.ndarray) -> np.ndarray:
        """Compute slope (degrees, uint16) from a 2D elevation array.

        Parameters
        ----------
        resolution : float
            Pixel resolution of ``elevation_data`` in meters.
        elevation_data : np.ndarray
            2D elevation array; NODATA_VALUE marks missing pixels.
        """
        dem_arr = SlopeCalculator._prepare_dem_array(elevation_data)
        dem_downsampled = SlopeCalculator._downsample_to_20m(dem_arr, resolution)
        slope = SlopeCalculator._compute_slope_gradient(dem_downsampled)
        result = SlopeCalculator._upsample_to_original(slope, dem_arr.shape, resolution)
        return result

    @staticmethod
    def _prepare_dem_array(dem: np.ndarray) -> np.ndarray:
        """Prepare DEM array by handling NaNs and invalid values."""
        dem_arr = dem.astype(np.float32)
        # Nodata sentinel becomes NaN so nanmean/fill can handle it.
        dem_arr[dem_arr == NODATA_VALUE] = np.nan
        return SlopeCalculator._fill_nans(dem_arr)

    @staticmethod
    def _fill_nans(dem_arr: np.ndarray, max_iter: int = 2) -> np.ndarray:
        """Fill NaN values using rolling fill approach.

        Copies neighbouring values into NaN pixels by rolling the array
        in the four cardinal directions, up to ``max_iter`` passes.
        NOTE(review): random.shuffle makes the fill order — and hence the
        filled values — nondeterministic; confirm this is acceptable.
        """
        if max_iter == 0 or not np.any(np.isnan(dem_arr)):
            return dem_arr

        mask = np.isnan(dem_arr)
        roll_params = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        random.shuffle(roll_params)

        # Later rolls overwrite earlier ones on still-masked positions.
        for roll_param in roll_params:
            rolled = np.roll(dem_arr, roll_param, axis=(0, 1))
            dem_arr[mask] = rolled[mask]

        return SlopeCalculator._fill_nans(dem_arr, max_iter - 1)

    @staticmethod
    def _downsample_to_20m(dem_arr: np.ndarray, resolution: float) -> np.ndarray:
        """Downsample DEM to 20m resolution for slope computation.

        Pads with reflection so both dimensions divide evenly by the
        factor, then block-averages. Only even factors >= 2 are accepted
        (e.g. 10m input -> factor 2); 20m input itself (factor 1) raises.
        """
        factor = int(20 / resolution)
        if factor < 1 or factor % 2 != 0:
            raise ValueError(f"Unsupported resolution for slope: {resolution}")

        X, Y = dem_arr.shape
        pad_X, pad_Y = (
            (factor - (X % factor)) % factor,
            (factor - (Y % factor)) % factor,
        )
        padded = np.pad(dem_arr, ((0, pad_X), (0, pad_Y)), mode="reflect")

        # Block-average each factor x factor tile (NaN-aware).
        reshaped = padded.reshape(
            (X + pad_X) // factor, factor, (Y + pad_Y) // factor, factor
        )
        return np.nanmean(reshaped, axis=(1, 3))

    @staticmethod
    def _compute_slope_gradient(dem: np.ndarray) -> np.ndarray:
        """Compute slope gradient using Sobel operators.

        Kernels are scaled by 8 * cell size (20m); the gradient magnitude
        is converted to a slope angle in degrees.
        """
        kernel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / (8.0 * 20)
        kernel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / (8.0 * 20)

        dx = convolve(dem, kernel_x)
        dy = convolve(dem, kernel_y)
        gradient_magnitude = np.sqrt(dx**2 + dy**2)

        return np.arctan(gradient_magnitude) * (180 / np.pi)

    @staticmethod
    def _upsample_to_original(
        slope: np.ndarray, original_shape: Tuple[int, ...], resolution: float
    ) -> np.ndarray:
        """Upsample slope back to original resolution.

        Bilinear zoom by the same downsampling factor, then trim the
        reflection padding.
        NOTE(review): the odd-dimension trim removes exactly one row/col,
        which only matches the padding for factor == 2 — confirm for
        other factors. The final uint16 cast truncates fractional degrees.
        """
        factor = int(20 / resolution)
        slope_upsampled = zoom(slope, zoom=factor, order=1)

        # Handle odd dimensions
        if original_shape[0] % 2 != 0:
            slope_upsampled = slope_upsampled[:-1, :]
        if original_shape[1] % 2 != 0:
            slope_upsampled = slope_upsampled[:, :-1]

        return slope_upsampled.astype(np.uint16)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class CoordinateTransformer:
    """Handles coordinate transformations and spatial operations."""

    @staticmethod
    def get_resolution(inarr: xr.DataArray, epsg: int) -> float:
        """Calculate the pixel resolution in meters."""
        if epsg != 4326:
            # Projected CRS: x-coordinate spacing is already in meters.
            return abs(inarr.x[1].values - inarr.x[0].values)
        return CoordinateTransformer._get_wgs84_resolution(inarr)

    @staticmethod
    def _get_wgs84_resolution(inarr: xr.DataArray) -> float:
        """Convert WGS84 coordinates to meters for resolution calculation."""
        # Project the grid points to Web Mercator (EPSG:3857) and measure
        # the spacing between the first two projected points.
        to_meters = Transformer.from_crs(4326, 3857, always_xy=True)
        projected = [
            transform(to_meters.transform, Point(xv, yv))
            for xv, yv in zip(inarr.x.values, inarr.y.values)
        ]
        return abs(projected[1].x - projected[0].x)

    @staticmethod
    def get_lat_lon_array(inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Create a two-band latitude/longitude array on the input grid."""
        lon, lat = np.meshgrid(inarr.x.values, inarr.y.values)

        if epsg != 4326:
            # Reproject the grid coordinates to WGS84 first.
            to_wgs84 = Transformer.from_crs(epsg, 4326, always_xy=True)
            lon, lat = to_wgs84.transform(lon, lat)

        return xr.DataArray(
            np.stack([lat, lon]),
            dims=["bands", "y", "x"],
            coords={
                "bands": [LAT_HARMONIZED_NAME, LON_HARMONIZED_NAME],
                "y": inarr.y,
                "x": inarr.x,
            },
        )
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class DataPreprocessor:
    """Handles data preprocessing operations."""

    @staticmethod
    def rescale_s1_backscatter(arr: xr.DataArray) -> xr.DataArray:
        """Rescale Sentinel-1 backscatter from uint16 to dB values."""
        present = [band for band in S1_BANDS if band in arr.bands.values]
        if not present:
            # No S1 bands in this cube: nothing to rescale.
            return arr

        raw = arr.sel(bands=present).astype(np.float32)
        DataPreprocessor._validate_s1_data(raw.values)

        # Undo the uint16 compression (20*log10(x) - 83 dB), go through
        # linear power, mask non-finite results, then express as dB.
        compressed_db = 20.0 * np.log10(raw.values) - 83.0
        power = np.power(10, compressed_db / 10.0)
        power[~np.isfinite(power)] = np.nan

        arr.loc[dict(bands=present)] = 10.0 * np.log10(power)

        return arr

    @staticmethod
    def _validate_s1_data(data: np.ndarray) -> None:
        """Validate S1 data meets preprocessing requirements."""
        if data.min() < 1 or data.max() > NODATA_VALUE:
            raise ValueError(
                "S1 data should be uint16 format with values 1-65535. "
                "Set 'rescale_s1' to False to disable scaling."
            )
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class PrestoFeatureExtractor:
    """Handles Presto feature extraction pipeline.

    Wraps preprocessing (band harmonization, slope computation) and the
    actual Presto model inference.
    """

    def __init__(self, parameters: Dict[str, Any]):
        # Feature-extraction parameters; must contain "presto_model_url".
        # Optional keys used below: "num_outputs", "temporal_prediction",
        # "target_date", "batch_size".
        self.parameters = parameters

    def extract(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Extract Presto features from input array.

        Parameters
        ----------
        inarr : xr.DataArray
            Input cube with dims (bands, t, x, y); must have exactly 12
            timesteps.
        epsg : int
            EPSG code of the input grid (required).

        Returns
        -------
        xr.DataArray
            Feature cube (bands, y, x); on a timestep mismatch a
            NaN-filled array is returned instead of raising.
        """
        if epsg is None:
            raise ValueError("EPSG code required for Presto feature extraction")

        # ONLY check top level - no nested lookup
        presto_model_url = self.parameters.get("presto_model_url")
        if not presto_model_url:
            logger.error(
                f"Missing presto_model_url. Available keys: {list(self.parameters.keys())}"
            )
            raise ValueError('Missing required parameter "presto_model_url"')

        if len(inarr.t) != 12:
            error_msg = (
                f"Presto requires exactly 12 timesteps, but got {len(inarr.t)}. "
                f"Available timesteps: {inarr.t.values}. "
                f"Patch coordinates - x: {inarr.x.values.tolist()}, y: {inarr.y.values.tolist()}"
            )
            logger.error(error_msg)

            # Return NaN array instead of crashing
            return create_nan_output_array(
                inarr, self.parameters["num_outputs"], error_msg
            )

        inarr = self._preprocess_input(inarr)

        # Compute slope on the fly if it wasn't provided as a band.
        if "slope" not in inarr.bands:
            inarr = self._add_slope_band(inarr, epsg)

        return self._run_presto_inference(inarr, epsg)

    def _preprocess_input(self, inarr: xr.DataArray) -> xr.DataArray:
        """Preprocess input array for Presto: fix dimension order,
        harmonize band names and fill NaNs with the nodata sentinel."""
        inarr = inarr.transpose("bands", "t", "x", "y")

        # Harmonize band names (fall back to the original name when
        # there is no mapping entry).
        new_bands = [GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands]
        inarr = inarr.assign_coords(bands=new_bands)

        return inarr.fillna(NODATA_VALUE)

    def _add_slope_band(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Compute and add slope band to array.

        Slope is derived from the "COP-DEM" band at the first timestep
        and replicated over the time axis.
        """
        logger.warning("Slope band not found, computing...")
        resolution = CoordinateTransformer.get_resolution(inarr.isel(t=0), epsg)
        elevation_data = inarr.sel(bands="COP-DEM").isel(t=0).values

        slope_array = SlopeCalculator.compute(resolution, elevation_data)
        slope_da = (
            xr.DataArray(
                slope_array[None, :, :],
                dims=("bands", "y", "x"),
                coords={"bands": ["slope"], "y": inarr.y, "x": inarr.x},
            )
            .expand_dims({"t": inarr.t})
            .astype("float32")
        )

        return xr.concat([inarr.astype("float32"), slope_da], dim="bands")

    def _run_presto_inference(self, inarr: xr.DataArray, epsg: int) -> xr.DataArray:
        """Run Presto model inference with safe dependency handling.

        Runs under torch.inference_mode and always triggers a garbage
        collection afterwards to release large intermediates.
        """
        # Dependencies are now handled by load_presto_weights_cached
        import gc

        import torch

        _ensure_prometheo_dependencies()

        presto_model_url = self.parameters["presto_model_url"]

        model = load_presto_weights_cached(presto_model_url)

        # Import here to ensure dependencies are available
        # Temporal prediction pools per-timestep; otherwise a single
        # global pooling over time is used.
        pooling_method = (
            PoolingMethods.TIME  # type: ignore
            if self.parameters.get("temporal_prediction")
            else PoolingMethods.GLOBAL  # type: ignore
        )

        logger.info("Running presto inference ...")
        try:
            with torch.inference_mode():
                features = run_model_inference(
                    inarr,
                    model,
                    epsg=epsg,
                    batch_size=self.parameters.get("batch_size", 256),  # TODO optimize?
                    pooling_method=pooling_method,
                )  # type: ignore
            logger.info("Inference completed.")

            if self.parameters.get("temporal_prediction"):
                features = self._select_temporal_features(features)
            return features.transpose("bands", "y", "x")

        finally:
            gc.collect()

    def _select_temporal_features(self, features: xr.DataArray) -> xr.DataArray:
        """Select specific timestep from temporal features.

        Without a "target_date" parameter, the middle timestep is used;
        otherwise the nearest timestep within the feature time range.
        """
        target_date = self.parameters.get("target_date")

        if target_date is None:
            mid_idx = len(features.t) // 2
            return features.isel(t=mid_idx)

        target_dt = np.datetime64(target_date)
        min_time, max_time = features.t.min().values, features.t.max().values

        if target_dt < min_time or target_dt > max_time:
            raise ValueError(
                f"Target date {target_date} outside feature range: {min_time} to {max_time}"
            )

        return features.sel(t=target_dt, method="nearest")
|
| 815 |
+
|
| 816 |
+
class ONNXClassifier:
    """Handles ONNX model inference for classification."""

    def __init__(self, parameters: Dict[str, Any]):
        # Classifier parameters; must contain "classifier_url".
        self.parameters = parameters

    def predict(self, features: xr.DataArray) -> xr.DataArray:
        """Run classification prediction.

        Parameters
        ----------
        features : xr.DataArray
            Feature cube with dims including bands, x and y.

        Returns
        -------
        xr.DataArray
            (bands, y, x) cube of classification, winning probability and
            per-class probabilities (see get_output_labels).
        """
        classifier_url = self.parameters.get("classifier_url")
        if not classifier_url:
            logger.error(
                f"Missing classifier_url. Available keys: {list(self.parameters.keys())}"
            )
            raise ValueError('Missing required parameter "classifier_url"')

        session, lut = load_onnx_model_cached(classifier_url)
        features_flat = self._prepare_features(features)

        logger.info("Running ONNX model inference ...")
        predictions = self._run_inference(session, lut, features_flat)
        logger.info("ONNX inference completed.")

        return self._reshape_predictions(predictions, features, lut)

    def _prepare_features(self, features: xr.DataArray) -> np.ndarray:
        """Flatten the (bands, x, y) cube to a (pixels, bands) matrix for
        the ONNX session."""
        return (
            features.transpose("bands", "x", "y")
            .stack(xy=["x", "y"])
            .transpose()
            .values
        )

    def _run_inference(
        self, session: Any, lut: Dict, features: np.ndarray
    ) -> np.ndarray:
        """Run ONNX model inference.

        Assumes outputs[0] holds per-pixel predicted class names and
        outputs[1] per-pixel mappings from class name to probability
        (ZipMap-style output) — TODO confirm against the exported model.
        Returns a (2 + n_classes, pixels) array: mapped label, winning
        probability (0-100), then per-class probabilities (0-100).
        """
        outputs = session.run(None, {"features": features})

        labels = np.zeros(len(outputs[0]), dtype=np.uint16)
        probabilities = np.zeros(len(outputs[0]), dtype=np.uint8)

        # Map predicted class names to integer labels and scale the
        # winning probability to 0-100.
        for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])):
            labels[i] = lut[label]
            probabilities[i] = int(round(prob[label] * 100))

        # Per-class probabilities in LUT key order (sorted by label value).
        class_probs = np.array(
            [[prob[label] for label in lut.keys()] for prob in outputs[1]]
        )
        class_probs = (class_probs * 100).round().astype(np.uint8)

        return np.hstack([labels[:, None], probabilities[:, None], class_probs]).T

    def _reshape_predictions(
        self, predictions: np.ndarray, original_features: xr.DataArray, lut: Dict
    ) -> xr.DataArray:
        """Reshape predictions to match original spatial dimensions."""
        # Default postprocess parameters => all probability bands kept,
        # matching the rows produced by _run_inference.
        output_labels = get_output_labels(lut)
        x_coords, y_coords = original_features.x.values, original_features.y.values

        reshaped = predictions.reshape(
            (len(output_labels), len(x_coords), len(y_coords))
        )

        return xr.DataArray(
            reshaped,
            dims=["bands", "x", "y"],
            coords={"bands": output_labels, "x": x_coords, "y": y_coords},
        ).transpose("bands", "y", "x")
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
class Postprocessor:
    """Handles postprocessing of classification results."""

    def __init__(self, parameters: Dict[str, Any], classifier_url: str):
        # Postprocessing parameters: "method" ("smooth_probabilities" or
        # "majority_vote"), optional "kernel_size", "keep_class_probs".
        self.parameters = parameters
        # Needed to recover the class LUT from the cached ONNX model.
        self.classifier_url = classifier_url

    def apply(self, inarr: xr.DataArray) -> xr.DataArray:
        """Apply the configured postprocessing method to a classified
        cube whose first two bands are "classification" and "probability"
        followed by per-class probability bands.

        Raises
        ------
        ValueError
            If the "method" parameter is not a known method.
        """
        inarr = inarr.transpose(
            "bands", "y", "x"
        )  # Ensure correct dimension order for openEO backend

        _, lookup_table = load_onnx_model_cached(self.classifier_url)

        if self.parameters.get("method") == "smooth_probabilities":
            # Cast to float for more accurate gaussian smoothing
            class_probabilities = (
                inarr.isel(bands=slice(2, None)).astype("float32") / 100.0
            )

            # Peform probability smoothing
            class_probabilities = smooth_probabilities(
                inarr.sel(bands="classification"), class_probabilities
            )

            # Reclassify (returns class *indices* in the classification band)
            new_labels = reclassify(
                inarr.sel(bands="classification"),
                inarr.sel(bands="probability"),
                class_probabilities,
            )

            # Re-apply labels: map class indices back to LUT label values.
            class_labels = list(lookup_table.values())

            # Create a final labels array with same dimensions as new_labels
            final_labels = xr.full_like(new_labels, fill_value=65535)
            for idx, label in enumerate(class_labels):
                final_labels.loc[{"bands": "classification"}] = xr.where(
                    new_labels.sel(bands="classification") == idx,
                    label,
                    final_labels.sel(bands="classification"),
                )
            # NOTE(review): assigning to `.values` of a `.sel()` result
            # relies on the selection being a writable view of
            # `new_labels` — confirm this writes back as intended.
            new_labels.sel(bands="classification").values = final_labels.sel(
                bands="classification"
            ).values

            # Append the per-class probabalities if required
            if self.parameters.get("keep_class_probs", False):
                new_labels = xr.concat([new_labels, class_probabilities], dim="bands")

        elif self.parameters.get("method") == "majority_vote":
            kernel_size = self.parameters.get("kernel_size", 5)

            new_labels = majority_vote(
                inarr.sel(bands="classification"),
                inarr.sel(bands="probability"),
                kernel_size=kernel_size,
            )

            # Append the per-class probabalities if required
            if self.parameters.get("keep_class_probs", False):
                class_probabilities = inarr.isel(bands=slice(2, None))
                new_labels = xr.concat([new_labels, class_probabilities], dim="bands")

        else:
            raise ValueError(
                f"Unknown post-processing method: {self.parameters.get('method')}"
            )

        new_labels = new_labels.transpose(
            "bands", "y", "x"
        )  # Ensure correct dimension order for openEO backend

        return new_labels
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
# =============================================================================
|
| 965 |
+
# MAIN UDF FUNCTIONS
|
| 966 |
+
# =============================================================================
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def run_single_workflow(
    input_array: xr.DataArray,
    epsg: int,
    parameters: Dict[str, Any],
    mask: Optional[xr.DataArray] = None,
) -> xr.DataArray:
    """Run one end-to-end classification workflow with optional masking.

    Pipeline: optional S1 rescaling -> Presto embeddings -> ONNX inference ->
    optional postprocessing -> optional external masking.

    Parameters
    ----------
    input_array : xr.DataArray
        Input data cube (expected dims: bands, t, y, x).
    epsg : int
        EPSG code of the input projection.
    parameters : Dict[str, Any]
        Flat parameter dict containing ``feature_parameters``,
        ``classifier_parameters`` and optionally ``postprocess_parameters``.
    mask : Optional[xr.DataArray]
        Boolean mask; pixels where the mask is falsy are set to 254.

    Returns
    -------
    xr.DataArray
        Classification result cube.
    """
    feature_params = parameters["feature_parameters"]

    # Optional Sentinel-1 backscatter rescaling (enabled by default)
    if feature_params.get("rescale_s1", True):
        logger.info("Rescale s1 ...")
        input_array = DataPreprocessor.rescale_s1_backscatter(input_array)

    # Presto embedding extraction
    logger.info("Extract Presto embeddings ...")
    features = PrestoFeatureExtractor(feature_params).extract(input_array, epsg)
    logger.info("Presto embedding extraction done.")

    # ONNX inference
    logger.info("Onnx classification ...")
    classes = ONNXClassifier(parameters["classifier_parameters"]).predict(features)
    logger.info("Onnx classification done.")

    # Optional postprocessing step
    postprocess_parameters: Dict[str, Any] = parameters.get(
        "postprocess_parameters", {}
    )
    if postprocess_parameters.get("enable"):
        logger.info("Postprocessing classification results ...")
        keep_raw = postprocess_parameters.get("save_intermediate")
        if keep_raw:
            # Keep a prefixed copy of the unprocessed bands for later saving
            classes_raw = classes.assign_coords(
                bands=[f"raw_{b}" for b in list(classes.bands.values)]
            )
        postprocessor = Postprocessor(
            postprocess_parameters,
            classifier_url=parameters.get("classifier_parameters", {}).get(
                "classifier_url"
            ),
        )
        classes = postprocessor.apply(classes)
        if keep_raw:
            classes = xr.concat([classes, classes_raw], dim="bands")
        logger.info("Postprocessing done.")

    # Set masked areas to specific value
    if mask is not None:
        logger.info("`mask` provided, applying to classification results ...")
        classes = classes.where(mask, 254)  # 254 = non-cropland

    return classes
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def combine_results(
    croptype_result: xr.DataArray, cropland_result: xr.DataArray
) -> xr.DataArray:
    """Combine crop type results with ALL cropland classification bands.

    Bands of each input are prefixed (``croptype_`` / ``cropland_``) to avoid
    name clashes, then stacked along the ``bands`` dimension. The spatial
    coordinates are taken from the croptype result.
    """
    # Prefix cropland bands to avoid conflicts with croptype bands
    renamed_cropland = [f"cropland_{band}" for band in cropland_result.bands.values]
    cropland_result = cropland_result.assign_coords(bands=renamed_cropland)

    # Prefix croptype bands for clarity
    renamed_croptype = [f"croptype_{band}" for band in croptype_result.bands.values]
    croptype_result = croptype_result.assign_coords(bands=renamed_croptype)

    # Stack croptype bands first, then cropland bands, along axis 0 (bands)
    stacked_values = np.concatenate(
        [croptype_result.values, cropland_result.values], axis=0
    )

    return xr.DataArray(
        stacked_values,
        dims=["bands", "y", "x"],
        coords={
            "bands": list(renamed_croptype) + list(renamed_cropland),
            "y": croptype_result.y,
            "x": croptype_result.x,
        },
    )
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
def apply_udf_data(udf_data: UdfData) -> UdfData:
    """Main UDF entry point - expects cropland_params and croptype_params in context."""
    input_cube = udf_data.datacube_list[0]
    parameters = udf_data.user_context.copy()

    epsg = udf_data.proj["EPSG"] if udf_data.proj else None
    if epsg is None:
        raise ValueError("EPSG code not found in projection information")

    # Canonical dimension order expected by the workflows
    input_array = input_cube.get_array().transpose("bands", "t", "y", "x")

    # Parameter sets for the two possible workflows
    cropland_params = parameters.get("cropland_params", {})
    croptype_params = parameters.get("croptype_params", {})

    if cropland_params and croptype_params:
        # Dual workflow: cropland first, its mask then constrains croptype
        logger.info(
            "Running combined workflow: cropland masking + croptype mapping ..."
        )

        logger.info("Running cropland classification ...")
        cropland_result = run_single_workflow(input_array, epsg, cropland_params)
        logger.info("Cropland classification done.")

        # Any strictly positive class value is considered cropland
        cropland_mask = cropland_result.sel(bands="classification") > 0

        logger.info("Running crop type classification ...")
        croptype_result = run_single_workflow(
            input_array, epsg, croptype_params, cropland_mask
        )
        logger.info("Croptype classification done.")

        # Merge ALL bands of both products into one cube
        combined = combine_results(croptype_result, cropland_result)
        result_cube = XarrayDataCube(combined)
    else:
        # Single workflow (fallback to original behavior)
        logger.info("Running single workflow ...")
        single_result = run_single_workflow(input_array, epsg, parameters)
        result_cube = XarrayDataCube(single_result)

    udf_data.datacube_list = [result_cube]

    return udf_data
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def apply_metadata(metadata, context: Dict) -> Any:
    """Update collection metadata for combined output with ALL bands.

    Band naming logic summary (kept for mapping module resilience):
    - Single workflow (either cropland OR croptype parameters only):
      Base bands: classification, probability, probability_<class>
      If save_intermediate: raw_<band> duplicates are appended.
    - Combined workflow (both croptype_params & cropland_params):
      Prefixed bands: croptype_<band> and cropland_<band>
      If save_intermediate: croptype_raw_<band> and cropland_raw_<band> duplicates appended.

    No renaming occurs here beyond prefixing for the combined workflow; logic in
    mapping.py must therefore accept both prefixed and unprefixed forms.
    """
    try:
        # For dual workflow, combine band names from both models
        if "croptype_params" in context and "cropland_params" in context:
            # Get croptype band names
            croptype_classifier_url = context["croptype_params"][
                "classifier_parameters"
            ].get("classifier_url")
            if croptype_classifier_url:
                # LUT maps class indices to class labels; model load is cached
                _, croptype_lut = load_onnx_model_cached(croptype_classifier_url)
                postprocess_parameters = context["croptype_params"].get(
                    "postprocess_parameters", {}
                )
                croptype_bands = [
                    f"croptype_{band}"
                    for band in get_output_labels(croptype_lut, postprocess_parameters)
                ]
                if postprocess_parameters.get("save_intermediate", False):
                    # Raw duplicates replace the leading product prefix with
                    # <product>_raw_ (e.g. croptype_classification ->
                    # croptype_raw_classification)
                    croptype_bands += [
                        band.replace("croptype_", "croptype_raw_")
                        for band in croptype_bands
                    ]
            else:
                raise ValueError("No croptype LUT found")

            # Get cropland band names
            cropland_classifier_url = context["cropland_params"][
                "classifier_parameters"
            ].get("classifier_url")
            if cropland_classifier_url:
                _, cropland_lut = load_onnx_model_cached(cropland_classifier_url)
                postprocess_parameters = context["cropland_params"].get(
                    "postprocess_parameters", {}
                )
                cropland_bands = [
                    f"cropland_{band}"
                    for band in get_output_labels(cropland_lut, postprocess_parameters)
                ]
                if postprocess_parameters.get("save_intermediate", False):
                    cropland_bands += [
                        band.replace("cropland_", "cropland_raw_")
                        for band in cropland_bands
                    ]
            else:
                raise ValueError("No cropland LUT found")

            # Croptype bands first, cropland bands second — this order must
            # match the concatenation order in combine_results
            output_labels = croptype_bands + cropland_bands

        else:
            # Single workflow
            classifier_url = context["classifier_parameters"].get("classifier_url")
            if classifier_url:
                _, lut_sorted = load_onnx_model_cached(classifier_url)
                postprocess_parameters = context.get("postprocess_parameters", {})
                output_labels = get_output_labels(lut_sorted, postprocess_parameters)
                if postprocess_parameters.get("save_intermediate", False):
                    output_labels += [f"raw_{band}" for band in output_labels]
            else:
                raise ValueError("No classifier URL found in context")

        return metadata.rename_labels(dimension="bands", target=output_labels)

    except Exception as e:
        # Best-effort: if the model or LUT cannot be loaded here, keep the
        # original metadata rather than failing the whole job
        logger.warning(f"Could not load model in metadata context: {e}")
        return metadata
|
worldcereal/openeo/mapping.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mapping helpers for cropland, croptype and embeddings products.
|
| 2 |
+
|
| 3 |
+
Band naming conventions produced by the UDF (`inference.py`):
|
| 4 |
+
|
| 5 |
+
Single workflow (only cropland OR only croptype parameters passed to UDF):
|
| 6 |
+
classification, probability, probability_<class>
|
| 7 |
+
If save_intermediate: raw_<band> duplicates (e.g. raw_classification)
|
| 8 |
+
|
| 9 |
+
Combined workflow (croptype with cropland masking: both `croptype_params` &
|
| 10 |
+
`cropland_params` passed):
|
| 11 |
+
croptype_<band>, cropland_<band>
|
| 12 |
+
If save_intermediate: croptype_raw_<band>, cropland_raw_<band>
|
| 13 |
+
Example: croptype_classification -> croptype_raw_classification
|
| 14 |
+
|
| 15 |
+
Important: Raw bands in the combined workflow do NOT duplicate the base prefix;
|
| 16 |
+
they simply replace the leading product prefix with <product>_raw_.
|
| 17 |
+
|
| 18 |
+
Simplification: We ignore any *save_intermediate* flags. If raw bands are
|
| 19 |
+
present we save them; the UDF only emits them when intermediate results were
|
| 20 |
+
requested upstream.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import List
|
| 25 |
+
|
| 26 |
+
import openeo
|
| 27 |
+
from openeo import DataCube
|
| 28 |
+
from openeo_gfmap import TemporalContext
|
| 29 |
+
from openeo_gfmap.preprocessing.scaling import compress_uint16
|
| 30 |
+
|
| 31 |
+
from worldcereal.openeo.inference import apply_metadata
|
| 32 |
+
from worldcereal.parameters import (
|
| 33 |
+
CropLandParameters,
|
| 34 |
+
CropTypeParameters,
|
| 35 |
+
EmbeddingsParameters,
|
| 36 |
+
WorldCerealProductType,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Tiling specification for `apply_neighborhood`: process the cube in
# 128x128 pixel chunks with zero overlap between neighbouring chunks.
NEIGHBORHOOD_SPEC = dict(
    size=[
        {"dimension": "x", "unit": "px", "value": 128},
        {"dimension": "y", "unit": "px", "value": 128},
    ],
    overlap=[
        {"dimension": "x", "unit": "px", "value": 0},
        {"dimension": "y", "unit": "px", "value": 0},
    ],
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _run_udf(inputs: DataCube, udf: openeo.UDF) -> DataCube:
    """Apply *udf* over *inputs* using the shared neighborhood tiling spec."""
    result = inputs.apply_neighborhood(process=udf, **NEIGHBORHOOD_SPEC)
    return result
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _reduce_temporal_mean(cube: DataCube) -> DataCube:
    """Collapse the temporal dimension by taking the per-pixel mean."""
    reduced = cube.reduce_dimension(dimension="t", reducer="mean")
    return reduced
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _filename_prefix(
    product: WorldCerealProductType, temporal: TemporalContext, raw: bool = False
) -> str:
    """Build the output filename prefix: ``<product>[-raw]_<start>_<end>``."""
    product_tag = f"{product.value}-raw" if raw else product.value
    return f"{product_tag}_{temporal.start_date}_{temporal.end_date}"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _save_result(cube: DataCube, prefix: str) -> DataCube:
    """Attach a GTiff ``save_result`` node using *prefix* as filename prefix."""
    save_options = {"filename_prefix": prefix}
    return cube.save_result(format="GTiff", options=save_options)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _cropland_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,
    cropland_parameters: CropLandParameters,
) -> List[DataCube]:
    """Produce cropland product from preprocessed inputs (single workflow).

    Saves final bands and any raw_* bands purely based on presence.
    """
    udf_context = cropland_parameters.model_dump()
    inference_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "inference.py",
        context=udf_context,
    )
    classes = _run_udf(inputs, inference_udf)
    classes.metadata = apply_metadata(classes.metadata, udf_context)
    classes = compress_uint16(_reduce_temporal_mean(classes))

    band_names = classes.metadata.band_names
    result_cubes: List[DataCube] = []

    # Final (non-raw) bands
    final_bands = [name for name in band_names if not name.startswith("raw_")]
    if final_bands:
        result_cubes.append(
            _save_result(
                classes.filter_bands(final_bands),
                _filename_prefix(WorldCerealProductType.CROPLAND, temporal_extent),
            )
        )

    # Raw (intermediate) bands, saved only when present
    raw_bands = [name for name in band_names if name.startswith("raw_")]
    if raw_bands:
        result_cubes.append(
            _save_result(
                classes.filter_bands(raw_bands),
                _filename_prefix(
                    WorldCerealProductType.CROPLAND, temporal_extent, raw=True
                ),
            )
        )

    return result_cubes
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _croptype_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,
    croptype_parameters: CropTypeParameters,
    cropland_parameters: CropLandParameters,
) -> List[DataCube]:
    """Produce crop type product. Optionally includes cropland masking.

    When ``croptype_parameters.mask_cropland`` is True the UDF runs the
    combined workflow and emits prefixed band names (``croptype_<band>``,
    ``cropland_<band>``, with ``<product>_raw_<band>`` for intermediates);
    otherwise band names are unprefixed (``raw_<band>`` for intermediates).
    Cropland mask final bands saved only if `croptype_parameters.save_mask` is True.

    Parameters
    ----------
    inputs : DataCube
        Preprocessed input cube.
    temporal_extent : TemporalContext
        Temporal context, used for output filename prefixes.
    croptype_parameters : CropTypeParameters
        Crop type workflow parameters.
    cropland_parameters : CropLandParameters
        Cropland workflow parameters (used when masking is enabled).

    Returns
    -------
    List[DataCube]
        One save_result cube per saved product variant.
    """
    if croptype_parameters.mask_cropland:
        parameters = {
            "cropland_params": cropland_parameters.model_dump(),
            "croptype_params": croptype_parameters.model_dump(),
        }
    else:
        parameters = croptype_parameters.model_dump()

    inference_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "inference.py",
        context=parameters,
    )
    classes = _run_udf(inputs, inference_udf)
    classes.metadata = apply_metadata(classes.metadata, parameters)
    classes = _reduce_temporal_mean(classes)
    classes = compress_uint16(classes)

    bands = classes.metadata.band_names
    result_cubes: List[DataCube] = []

    if croptype_parameters.mask_cropland:
        # Prefixed croptype final and raw bands.
        # FIX: final bands were previously selected with `"raw" not in b`,
        # which would wrongly drop any class whose label contains the
        # substring "raw" (e.g. a crop class name). Use an explicit prefix
        # check, consistent with the cropland filtering below.
        croptype_final_bands = [
            b
            for b in bands
            if b.startswith("croptype_") and not b.startswith("croptype_raw_")
        ]
        # Raw croptype bands (presence-based)
        raw_croptype_bands = [b for b in bands if b.startswith("croptype_raw_")]
    else:
        # Single workflow: unprefixed croptype bands
        croptype_final_bands = [b for b in bands if not b.startswith("raw_")]
        raw_croptype_bands = [b for b in bands if b.startswith("raw_")]

    # Final croptype (product prefix stripped from band names)
    croptype_cube = classes.filter_bands(croptype_final_bands).rename_labels(
        dimension="bands",
        target=[
            b.replace("croptype_", "") for b in croptype_final_bands
        ],  # Remove prefix
    )
    result_cubes.append(
        _save_result(
            croptype_cube,
            _filename_prefix(WorldCerealProductType.CROPTYPE, temporal_extent),
        )
    )

    # Raw croptype if present
    if raw_croptype_bands:
        raw_croptype_cube = classes.filter_bands(raw_croptype_bands).rename_labels(
            dimension="bands",
            target=[
                b.replace("croptype_", "") for b in raw_croptype_bands
            ],  # Remove prefix
        )
        result_cubes.append(
            _save_result(
                raw_croptype_cube,
                _filename_prefix(
                    WorldCerealProductType.CROPTYPE, temporal_extent, raw=True
                ),
            )
        )

    # Optional cropland mask & raw cropland bands
    if croptype_parameters.save_mask:
        cropland_final_bands = [
            b
            for b in bands
            if b.startswith("cropland_") and not b.startswith("cropland_raw_")
        ]
        cropland_cube = classes.filter_bands(cropland_final_bands).rename_labels(
            dimension="bands",
            target=[
                b.replace("cropland_", "") for b in cropland_final_bands
            ],  # Remove prefix
        )
        result_cubes.append(
            _save_result(
                cropland_cube,
                _filename_prefix(WorldCerealProductType.CROPLAND, temporal_extent),
            )
        )
        raw_cropland_bands = [b for b in bands if b.startswith("cropland_raw_")]
        if raw_cropland_bands:
            raw_cropland_cube = classes.filter_bands(raw_cropland_bands).rename_labels(
                dimension="bands",
                target=[
                    b.replace("cropland_", "") for b in raw_cropland_bands
                ],  # Remove prefix
            )
            result_cubes.append(
                _save_result(
                    raw_cropland_cube,
                    _filename_prefix(
                        WorldCerealProductType.CROPLAND, temporal_extent, raw=True
                    ),
                )
            )

    return result_cubes
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _embeddings_map(
    inputs: DataCube,
    temporal_extent: TemporalContext,  # temporal extent unused but kept for signature consistency
    embeddings_parameters: EmbeddingsParameters,
    scale_uint16: bool = True,
) -> DataCube:
    """Produce embeddings map using Prometheo feature extractor."""
    feature_udf = openeo.UDF.from_file(
        path=Path(__file__).resolve().parent / "feature_extractor.py",
        context=embeddings_parameters.feature_parameters.model_dump(),
    )
    embeddings = _reduce_temporal_mean(_run_udf(inputs, feature_udf))

    if scale_uint16:
        # Affine rescale into the uint16 value range, then clamp to [0, 65534]
        offset = -6
        scale = 0.0002
        embeddings = (embeddings - offset) / scale
        embeddings = embeddings.linear_scale_range(0, 65534, 0, 65534)

    return embeddings
|
worldcereal/openeo/preprocessing.py
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from geojson import GeoJSON
|
| 6 |
+
from openeo import UDF, Connection, DataCube
|
| 7 |
+
from openeo_gfmap import (
|
| 8 |
+
Backend,
|
| 9 |
+
BackendContext,
|
| 10 |
+
BoundingBoxExtent,
|
| 11 |
+
FetchType,
|
| 12 |
+
SpatialContext,
|
| 13 |
+
TemporalContext,
|
| 14 |
+
)
|
| 15 |
+
from openeo_gfmap.fetching.generic import build_generic_extractor
|
| 16 |
+
from openeo_gfmap.fetching.s1 import build_sentinel1_grd_extractor
|
| 17 |
+
from openeo_gfmap.fetching.s2 import build_sentinel2_l2a_extractor
|
| 18 |
+
from openeo_gfmap.preprocessing.compositing import mean_compositing, median_compositing
|
| 19 |
+
from openeo_gfmap.preprocessing.sar import compress_backscatter_uint16
|
| 20 |
+
from openeo_gfmap.utils.catalogue import UncoveredS1Exception, select_s1_orbitstate_vvvh
|
| 21 |
+
|
| 22 |
+
# Sentinel-2 L2A spectral bands ingested by WorldCereal.
WORLDCEREAL_S2_BANDS = [
    "S2-L2A-B02",
    "S2-L2A-B03",
    "S2-L2A-B04",
    "S2-L2A-B05",
    "S2-L2A-B06",
    "S2-L2A-B07",
    "S2-L2A-B08",
    "S2-L2A-B8A",
    "S2-L2A-B11",
    "S2-L2A-B12",
]

# Sentinel-1 sigma0 backscatter polarizations.
WORLDCEREAL_S1_BANDS = [
    "S1-SIGMA0-VH",
    "S1-SIGMA0-VV",
]

# DEM-derived layers.
WORLDCEREAL_DEM_BANDS = ["elevation", "slope"]

# AgERA5 meteorological variables (precipitation and mean temperature).
WORLDCEREAL_METEO_BANDS = ["AGERA5-PRECIP", "AGERA5-TMEAN"]

# Collection name -> band list lookup used by the extractor builders.
WORLDCEREAL_BANDS = {
    "SENTINEL2": WORLDCEREAL_S2_BANDS,
    "SENTINEL1": WORLDCEREAL_S1_BANDS,
    "DEM": WORLDCEREAL_DEM_BANDS,
    "METEO": WORLDCEREAL_METEO_BANDS,
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class InvalidTemporalContextError(Exception):
    """Raised when a supplied temporal extent cannot be used for processing."""

    pass
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def spatially_filter_cube(
    connection: Connection, cube: DataCube, spatial_extent: Optional[SpatialContext]
) -> DataCube:
    """
    Apply spatial filtering to a data cube based on the given spatial extent.


    Parameters
    ----------
    connection : Connection
        The connection object used to interact with the openEO backend.
    cube : DataCube
        The input data cube to be spatially filtered.
    spatial_extent : Optional[SpatialContext]
        The spatial extent used for filtering the data cube. It can be a BoundingBoxExtent,
        a GeoJSON object, or a URL to a GeoJSON or Parquet file. If set to `None`,
        no spatial filtering will be applied.

    Returns
    -------
    DataCube
        The spatially filtered data cube.

    Raises
    ------
    ValueError
        If the spatial_extent parameter is not of type BoundingBoxExtent, GeoJSON, or str.

    """
    # `None` explicitly means "no spatial filtering"
    if spatial_extent is None:
        return cube

    if isinstance(spatial_extent, BoundingBoxExtent):
        cube = cube.filter_bbox(dict(spatial_extent))
    elif isinstance(spatial_extent, GeoJSON):
        cube = cube.filter_spatial(spatial_extent)
    elif isinstance(spatial_extent, str):
        # URL: load as GeoParquet when the path hints at parquet, else GeoJSON
        geometry = connection.load_url(
            spatial_extent,
            format=(
                "Parquet"
                if ".parquet" in spatial_extent or ".geoparquet" in spatial_extent
                else "GeoJSON"
            ),
        )
        cube = cube.filter_spatial(geometry)
    else:
        # FIX: the docstring promises a ValueError for unsupported types, but
        # previously such inputs silently skipped filtering altogether.
        raise ValueError(
            "spatial_extent must be a BoundingBoxExtent, GeoJSON or str, "
            f"got {type(spatial_extent)!r}"
        )

    return cube
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def select_best_s1_orbit_direction(
    backend_context: BackendContext,
    spatial_extent: SpatialContext,
    temporal_extent: TemporalContext,
) -> str:
    """Selects the best Sentinel-1 orbit direction based on the given spatio-temporal context.

    Falls back to "ASCENDING" when the catalogue lookup finds no coverage.

    Parameters
    ----------
    backend_context : BackendContext
        The backend context for accessing the data.
    spatial_extent : SpatialContext
        The spatial extent of the data.
    temporal_extent : TemporalContext
        The temporal extent of the data.

    Returns
    -------
    str
        The selected orbit direction (either "ASCENDING" or "DESCENDING").
    """
    try:
        return select_s1_orbitstate_vvvh(
            backend_context, spatial_extent, temporal_extent
        )
    except UncoveredS1Exception as exc:
        # Catalogue lookup failed: fall back to a fixed default direction
        print(
            f"Could not find any Sentinel-1 data for the given spatio-temporal context. "
            f"Using ASCENDING orbit direction as a last resort. Error: {exc}"
        )
        return "ASCENDING"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def raw_datacube_S2(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    bands: List[str],
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
    filter_tile: Optional[str] = None,
    distance_to_cloud_flag: Optional[bool] = True,
    additional_masks_flag: Optional[bool] = True,
    apply_mask_flag: Optional[bool] = False,
    tile_size: Optional[int] = None,
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Extract Sentinel-2 datacube from OpenEO using GFMAP routines.
    Raw data is extracted with no cloud masking applied by default (can be
    enabled by setting `apply_mask_flag=True`). In addition to the raw band
    values a cloud-mask computed from the dilation of the SCL layer, as well
    as a rank mask from the BAP compositing are added.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context to use for extraction.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    bands : List[str]
        List of Sentinel-2 bands to extract.
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        Spatial context to extract data from, can be a GFMAP BoundingBoxExtent,
        a GeoJSON dict or an URL to a publicly accessible GeoParquet file.
        NOTE(review): this parameter is not used anywhere in this function
        body — spatial filtering appears to be handled elsewhere; confirm
        with callers.
    filter_tile : Optional[str], optional
        Filter by tile ID, by default disabled. This forces the process to only
        one tile ID from the Sentinel-2 collection.
    distance_to_cloud_flag : Optional[bool], optional
        Whether to compute a distance-to-cloud band from the SCL layer,
        by default True.
    additional_masks_flag : Optional[bool], optional
        Whether to pre-merge the computed mask bands into the extracted cube,
        by default True.
    apply_mask_flag : Optional[bool], optional
        Apply cloud masking, by default False. Can be enabled for high
        optimization of memory usage.
    tile_size : Optional[int], optional
        Tile size (in pixels) forwarded to the backend as a feature flag,
        by default None.
    target_epsg : Optional[int], optional
        Target EPSG to resample the data, by default None.
    """
    # Extract the SCL collection only; restrict to scenes with <= 95% cloud
    # cover, and optionally to a single Sentinel-2 tile.
    scl_cube_properties = {"eo:cloud_cover": lambda val: val <= 95.0}
    if filter_tile:
        scl_cube_properties["tileId"] = lambda val: val == filter_tile

    # Create the job to extract S2
    extraction_parameters: dict[str, Any] = {
        "target_resolution": 10,
        "target_crs": target_epsg,
        "load_collection": {
            "eo:cloud_cover": lambda val: val <= 95.0,
        },
    }

    scl_cube = connection.load_collection(
        collection_id="SENTINEL2_L2A",
        bands=["SCL"],
        temporal_extent=[temporal_extent.start_date, temporal_extent.end_date],
        properties=scl_cube_properties,
    )

    # Resample to 10m resolution for the SCL layer, using optional target_epsg
    scl_cube = scl_cube.resample_spatial(projection=target_epsg, resolution=10)

    # Compute the SCL dilation mask
    scl_dilated_mask = scl_cube.process(
        "to_scl_dilation_mask",
        data=scl_cube,
        scl_band_name="SCL",
        kernel1_size=17,  # 17px dilation on a 10m layer
        kernel2_size=77,  # 77px dilation on a 10m layer
        mask1_values=[2, 4, 5, 6, 7],
        mask2_values=[3, 8, 9, 10, 11],
        erosion_kernel_size=3,
    ).rename_labels("bands", ["S2-L2A-SCL_DILATED_MASK"])

    # By default only the dilated SCL mask is carried along as additional mask.
    additional_masks = scl_dilated_mask

    if distance_to_cloud_flag:
        # Compute the distance to cloud and add it to the cube
        distance_to_cloud = scl_cube.apply_neighborhood(
            process=UDF.from_file(Path(__file__).parent / "udf_distance_to_cloud.py"),
            size=[
                {"dimension": "x", "unit": "px", "value": 256},
                {"dimension": "y", "unit": "px", "value": 256},
                {"dimension": "t", "unit": "null", "value": "P1D"},
            ],
            overlap=[
                {"dimension": "x", "unit": "px", "value": 16},
                {"dimension": "y", "unit": "px", "value": 16},
            ],
        ).rename_labels("bands", ["S2-L2A-DISTANCE-TO-CLOUD"])

        additional_masks = scl_dilated_mask.merge_cubes(distance_to_cloud)

    if additional_masks_flag:
        extraction_parameters["pre_merge"] = additional_masks

    if filter_tile:
        extraction_parameters["load_collection"]["tileId"] = (
            lambda val: val == filter_tile
        )

    if tile_size is not None:
        extraction_parameters["update_arguments"] = {
            "featureflags": {"tilesize": tile_size}
        }

    s2_cube = build_sentinel2_l2a_extractor(
        backend_context,
        bands=bands,
        fetch_type=fetch_type,
        **extraction_parameters,
    ).get_cube(connection, None, temporal_extent)

    if apply_mask_flag:
        # Mask out the dilated-SCL cloudy pixels directly in the raw cube.
        s2_cube = s2_cube.mask(scl_dilated_mask)

    return s2_cube
def raw_datacube_S1(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    bands: List[str],
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
    target_resolution: float = 20.0,
    orbit_direction: Optional[str] = None,
    tile_size: Optional[int] = None,
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Extract Sentinel-1 datacube from OpenEO using GFMAP routines.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context to use for extraction.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    bands : List[str]
        List of Sentinel-1 bands to extract.
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        Spatial context to extract data from, can be a GFMAP BoundingBoxExtent,
        a GeoJSON dict or an URL to a publicly accessible GeoParquet file.
    target_resolution : float, optional
        Target resolution to resample the data to, by default 20.0.
    orbit_direction : Optional[str], optional
        Orbit direction to filter the data, by default None.
    tile_size : Optional[int], optional
        Tile size (in pixels) forwarded to the backend as a feature flag.
    target_epsg : Optional[int], optional
        Target EPSG to resample the data to, by default None.
    """
    extractor_parameters: Dict[str, Any] = {
        "target_resolution": target_resolution,
        "target_crs": target_epsg,
    }

    # Always restrict to dual-polarisation (VV & VH) products; additionally
    # filter on the orbit state when one was explicitly requested.
    if orbit_direction is None:
        extractor_parameters["load_collection"] = {
            "polarisation": lambda pol: pol == "VV&VH",
        }
    else:
        extractor_parameters["load_collection"] = {
            "sat:orbit_state": lambda orbit: orbit == orbit_direction,
            "polarisation": lambda pol: pol == "VV&VH",
        }

    if tile_size is not None:
        extractor_parameters["update_arguments"] = {
            "featureflags": {"tilesize": tile_size}
        }

    return build_sentinel1_grd_extractor(
        backend_context, bands=bands, fetch_type=fetch_type, **extractor_parameters
    ).get_cube(connection, None, temporal_extent)
def raw_datacube_DEM(
    connection: Connection,
    backend_context: BackendContext,
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
) -> DataCube:
    """Method to get the DEM datacube from the backend.
    If running on CDSE backend, the slope is also loaded from the global
    slope collection and merged with the DEM cube.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP Backend context to use for extraction.
    fetch_type : FetchType
        GFMAP Fetch type to use for extraction.
    spatial_extent : Optional[SpatialContext], optional
        Spatial context to extract data from.
        NOTE(review): not used in this function body — confirm whether
        spatial filtering is handled by the extractor/caller.

    Returns
    -------
    DataCube
        openEO datacube with the DEM data (and slope if available).
    """

    extractor = build_generic_extractor(
        backend_context=backend_context,
        bands=["COP-DEM"],
        fetch_type=fetch_type,
        collection_name="COPERNICUS_30",
    )

    # DEM is a static collection: no temporal extent is passed.
    cube = extractor.get_cube(connection, None, None)
    cube = cube.rename_labels(dimension="bands", target=["elevation"])

    if backend_context.backend in [Backend.CDSE, Backend.CDSE_STAGING]:
        # On CDSE we can load the slope from a global slope collection
        slope = connection.load_stac(
            "https://stac.openeo.vito.be/collections/COPERNICUS30_DEM_SLOPE",
            bands=["Slope"],
        ).rename_labels(dimension="bands", target=["slope"])
        # Client fix for CDSE, the openeo client might be unsynchronized with
        # the backend.
        if "t" not in slope.metadata.dimension_names():
            slope.metadata = slope.metadata.add_dimension("t", "2020-01-01", "temporal")
        # Collapse the (synthetic) time dimension to get a static slope layer.
        slope = slope.min_time()

        # Note that when slope is available we use it as the base cube
        # to merge DEM with, as it comes at 20m resolution.
        cube = slope.merge_cubes(cube)

    return cube
def raw_datacube_METEO(
    connection: Connection,
    backend_context: BackendContext,
    temporal_extent: TemporalContext,
    fetch_type: FetchType,
    spatial_extent: Optional[SpatialContext] = None,
) -> DataCube:
    """Extract the raw AGERA5 meteo datacube (mean temperature and
    precipitation bands) for the requested temporal context."""
    return build_generic_extractor(
        backend_context=backend_context,
        bands=["AGERA5-TMEAN", "AGERA5-PRECIP"],
        fetch_type=fetch_type,
        collection_name="AGERA5",
    ).get_cube(connection, None, temporal_extent)
def precomposited_datacube_METEO(
    connection: Connection,
    temporal_extent: TemporalContext,
    compositing_window: Literal["month", "dekad"] = "month",
) -> DataCube:
    """Extract the precipitation and temperature AGERA5 data from a
    pre-composited and pre-processed collection. The data is stored in the
    CloudFerro S3 storage, allowing faster access and processing from the CDSE
    backend.

    Limitations:
        - Only monthly or dekadal composited data is available.
        - Only two bands are available: precipitation-flux and temperature-mean.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    compositing_window : Literal["month", "dekad"], optional
        Compositing window of the pre-composited collection, by default "month".

    Raises
    ------
    ValueError
        If `compositing_window` is not one of "month" or "dekad". (Previously
        an unsupported value fell through both branches and raised a confusing
        ``NameError`` on the undefined ``cube`` variable.)
    """
    # Pre-composited AGERA5 STAC collections, one per supported window.
    stac_urls = {
        "month": "https://stac.openeo.vito.be/collections/agera5_monthly",
        "dekad": "https://stac.openeo.vito.be/collections/agera5_dekad",
    }
    if compositing_window not in stac_urls:
        raise ValueError(
            f"Compositing window must be one of {sorted(stac_urls)}, "
            f"got: {compositing_window}"
        )

    temporal_extent = [temporal_extent.start_date, temporal_extent.end_date]

    # Load the precomposited meteo data for the requested window.
    cube = connection.load_stac(
        url=stac_urls[compositing_window],
        temporal_extent=temporal_extent,
        bands=["precipitation-flux", "temperature-mean"],
    )

    # cube.result_node().update_arguments(featureflags={"tilesize": 1})
    # Rename the collection band names to the WorldCereal convention.
    cube = cube.rename_labels(
        dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"]
    )

    return cube
def worldcereal_preprocessed_inputs(
    connection: Connection,
    backend_context: BackendContext,
    spatial_extent: Union[GeoJSON, BoundingBoxExtent, str],
    temporal_extent: TemporalContext,
    fetch_type: Optional[FetchType] = FetchType.TILE,
    disable_meteo: bool = False,
    validate_temporal_context: bool = True,
    s1_orbit_state: Optional[str] = None,
    tile_size: Optional[int] = None,
    s2_tile: Optional[str] = None,
    compositing_window: Literal["month", "dekad"] = "month",
    target_epsg: Optional[int] = None,
) -> DataCube:
    """Build the preprocessed WorldCereal input datacube.

    Extracts Sentinel-2, Sentinel-1 and DEM data (and, unless disabled,
    pre-composited AGERA5 meteo data), composites the optical/SAR time series
    to the requested window, rescales bands to the uint16 range and merges
    everything into a single multi-band datacube.

    Parameters
    ----------
    connection : Connection
        OpenEO connection instance.
    backend_context : BackendContext
        GFMAP backend context to use for extraction.
    spatial_extent : Union[GeoJSON, BoundingBoxExtent, str]
        Spatial context. In this function it is only passed on to the
        Sentinel-1 orbit-direction selection; the per-sensor extractors do
        not receive it here.
    temporal_extent : TemporalContext
        Temporal context to extract data from.
    fetch_type : Optional[FetchType], optional
        GFMAP fetch type, by default FetchType.TILE.
    disable_meteo : bool, optional
        Skip the AGERA5 meteo cube entirely, by default False.
    validate_temporal_context : bool, optional
        Validate that the temporal range spans full months and at most one
        year, by default True.
    s1_orbit_state : Optional[str], optional
        Force a Sentinel-1 orbit direction; when None, the best direction is
        selected from the catalogue on supported backends.
    tile_size : Optional[int], optional
        Tile size feature flag forwarded to the backend.
    s2_tile : Optional[str], optional
        Restrict Sentinel-2 extraction to a single tile ID.
    compositing_window : Literal["month", "dekad"], optional
        Compositing period, by default "month".
    target_epsg : Optional[int], optional
        Target EPSG to resample the data to, by default None.

    Returns
    -------
    DataCube
        Merged, preprocessed input datacube (S2 + S1 + DEM [+ meteo]).
    """
    # First validate the temporal context
    if validate_temporal_context:
        _validate_temporal_context(temporal_extent)

    # See if requested compositing method is supported
    assert compositing_window in [
        "month",
        "dekad",
    ], 'Compositing window must be either "month" or "dekad"'

    # Extraction of S2 from GFMAP
    s2_data = raw_datacube_S2(
        connection=connection,
        backend_context=backend_context,
        temporal_extent=temporal_extent,
        bands=WORLDCEREAL_S2_BANDS,
        fetch_type=fetch_type,
        filter_tile=s2_tile,
        # Distance-to-cloud is not needed for point extractions.
        distance_to_cloud_flag=False if fetch_type == FetchType.POINT else True,
        additional_masks_flag=False,
        apply_mask_flag=True,
        tile_size=tile_size,
        target_epsg=target_epsg,
    )

    s2_data = median_compositing(s2_data, period=compositing_window)

    # Cast to uint16
    s2_data = s2_data.linear_scale_range(0, 65534, 0, 65534)

    # Extraction of the S1 data
    # Decides on the orbit direction from the maximum overlapping area of
    # available products.
    if s1_orbit_state is None and backend_context.backend in [
        Backend.CDSE,
        Backend.CDSE_STAGING,
        Backend.FED,
    ]:
        s1_orbit_state = select_best_s1_orbit_direction(
            backend_context, spatial_extent, temporal_extent
        )
    s1_data = raw_datacube_S1(
        connection=connection,
        backend_context=backend_context,
        temporal_extent=temporal_extent,
        bands=WORLDCEREAL_S1_BANDS,
        fetch_type=fetch_type,
        target_resolution=20.0,  # Compute the backscatter at 20m resolution, then upsample nearest neighbor when merging cubes
        orbit_direction=s1_orbit_state,  # If None, make the query on the catalogue for the best orbit
        tile_size=tile_size,
        target_epsg=target_epsg,
    )

    s1_data = mean_compositing(s1_data, period=compositing_window)
    s1_data = compress_backscatter_uint16(backend_context, s1_data)

    dem_data = raw_datacube_DEM(
        connection=connection,
        backend_context=backend_context,
        fetch_type=fetch_type,
    )

    # Explicitly resample DEM with bilinear interpolation and based on S2 grid
    # note: we use s2_data here as base to avoid issues at the edges because source
    # data is not in UTM projection.
    dem_data = dem_data.resample_cube_spatial(s2_data, method="bilinear")

    # Cast DEM to UINT16
    dem_data = dem_data.linear_scale_range(0, 65534, 0, 65534)

    data = s2_data.merge_cubes(s1_data)
    data = data.merge_cubes(dem_data)

    if not disable_meteo:
        meteo_data = precomposited_datacube_METEO(
            connection=connection,
            temporal_extent=temporal_extent,
            compositing_window=compositing_window,
        )

        # Explicitly resample meteo with bilinear interpolation and based on S2 grid
        # note: we use s2_data here as base to avoid issues at the edges because source
        # data is not in UTM projection.
        meteo_data = meteo_data.resample_cube_spatial(s2_data, method="bilinear")

        data = data.merge_cubes(meteo_data)

    return data
def _validate_temporal_context(temporal_context: TemporalContext) -> None:
    """Check that a temporal context is usable for WorldCereal processing.

    The start date must fall on the first day of a month and the end date on
    the last day of a month (monthly compositing requirement). Ranges spanning
    more than one year are currently not supported.

    Parameters
    ----------
    temporal_context : TemporalContext
        temporal context to validate

    Raises
    ------
    InvalidTemporalContextError
        if start_date is not on the first day of a month, end_date is not on
        the last day of a month, or the span exceeds one year.
    """

    start_date, end_date = temporal_context.to_datetime()

    # MonthEnd(0) rolls a date forward to its month end (no-op if already there).
    starts_on_month_start = start_date == start_date.replace(day=1)
    ends_on_month_end = end_date == end_date + pd.offsets.MonthEnd(0)

    if not (starts_on_month_start and ends_on_month_end):
        raise InvalidTemporalContextError(
            "WorldCereal uses monthly compositing. For this to work properly, "
            "requested temporal range should start and end at the first and last "
            "day of a month. Instead, got: "
            f"{temporal_context.start_date} - {temporal_context.end_date}. "
            "You may use `worldcereal.preprocessing.correct_temporal_context()` "
            "to correct the temporal context."
        )

    if pd.Timedelta(end_date - start_date).days > 365:
        raise InvalidTemporalContextError(
            "WorldCereal currently does not support temporal ranges spanning "
            "more than a year. Got: "
            f"{temporal_context.start_date} - {temporal_context.end_date}."
        )
def correct_temporal_context(temporal_context: TemporalContext) -> TemporalContext:
    """Snap a temporal context to full-month boundaries, as required by the
    WorldCereal processing (first day of the start month, last day of the
    end month).

    Parameters
    ----------
    temporal_context : TemporalContext
        temporal context to correct

    Returns
    -------
    TemporalContext
        corrected temporal context
    """

    start_date, end_date = temporal_context.to_datetime()

    # Snap start to the first day of its month; roll end forward to month end.
    corrected_start = start_date.replace(day=1)
    corrected_end = end_date + pd.offsets.MonthEnd(0)

    return TemporalContext(
        start_date=corrected_start.strftime("%Y-%m-%d"),
        end_date=corrected_end.strftime("%Y-%m-%d"),
    )
worldcereal/openeo/udf_distance_to_cloud.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "scikit-image",
|
| 4 |
+
# ]
|
| 5 |
+
# ///
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import xarray as xr
|
| 9 |
+
from openeo.udf import XarrayDataCube
|
| 10 |
+
from scipy.ndimage import distance_transform_cdt
|
| 11 |
+
from skimage.morphology import binary_erosion, footprints
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
    """openEO UDF: per-pixel distance to the nearest cloudy SCL pixel.

    SCL classes 8, 9, 10 (clouds/cirrus) and 3 (cloud shadow) are treated as
    cloudy. The cloud mask is eroded with a disk footprint, then a distance
    transform to the nearest (eroded) cloudy pixel is computed per timestep.

    NOTE(review): the comment below says "manhattan distance", but
    `scipy.ndimage.distance_transform_cdt` defaults to the 'chessboard'
    metric ('taxicab' would be Manhattan) — confirm the intended metric.

    Parameters
    ----------
    cube : XarrayDataCube
        Input datacube carrying an SCL band; transposed to ("bands", "y", "x").
    context : dict
        UDF context (unused).

    Returns
    -------
    XarrayDataCube
        Single-band cube of integer distance-to-cloud values on the same
        x/y grid as the input.
    """
    cube_array: xr.DataArray = cube.get_array()
    cube_array = cube_array.transpose("bands", "y", "x")

    # Cloudy pixels: SCL in {8, 9, 10} (clouds/cirrus) or SCL == 3 (shadow).
    clouds: xr.DataArray = np.logical_or(
        np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3
    ).isel(
        bands=0
    )  # type: ignore

    # Calculate the Distance To Cloud score
    # Erode
    er = footprints.disk(3)

    # Define a function to apply binary erosion
    def erode(image, selem):
        # Invert after erosion: the distance transform below measures the
        # distance of non-zero (clear) pixels to the nearest zero (cloud) pixel.
        return ~binary_erosion(image, selem)

    # Use apply_ufunc to apply the erosion operation
    eroded = xr.apply_ufunc(
        erode,  # function to apply
        clouds,  # input DataArray
        input_core_dims=[["y", "x"]],  # dimensions over which to apply function
        output_core_dims=[["y", "x"]],  # dimensions of the output
        vectorize=True,  # vectorize the function over non-core dimensions
        dask="parallelized",  # enable dask parallelization
        output_dtypes=[np.int32],  # data type of the output
        kwargs={"selem": er},  # additional keyword arguments to pass to erode
    )

    # Distance to cloud in manhattan distance measure
    distance = xr.apply_ufunc(
        distance_transform_cdt,
        eroded,
        input_core_dims=[["y", "x"]],
        output_core_dims=[["y", "x"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.int32],
    )

    # Rewrap the raw result with explicit spatial coordinates from the input.
    distance_da = xr.DataArray(
        distance,
        coords={
            "y": cube_array.coords["y"],
            "x": cube_array.coords["x"],
        },
        dims=["y", "x"],
    )

    # Reattach the (single) band dimension taken from the input cube.
    distance_da = distance_da.expand_dims(
        dim={
            "bands": cube_array.coords["bands"],
        },
    )

    distance_da = distance_da.transpose("bands", "y", "x")

    return XarrayDataCube(distance_da)
worldcereal/parameters.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field, ValidationError, model_validator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class WorldCerealProductType(Enum):
    """Enum to define the different WorldCereal products."""

    # Cropland extent product.
    CROPLAND = "cropland"
    # Crop type classification product.
    CROPTYPE = "croptype"
    # Embeddings product — presumably raw feature embeddings without a
    # classification head; confirm against the inference pipeline.
    EMBEDDINGS = "embeddings"
class FeaturesParameters(BaseModel):
    """Parameters for the feature extraction UDFs. Types are enforced by
    Pydantic.

    Attributes
    ----------
    rescale_s1 : bool (default=False)
        Whether to rescale Sentinel-1 bands before feature extraction. Should be
        left to False, as this is done in the Presto UDF itself.
    presto_model_url : str
        Public URL to the Presto model used for feature extraction. The file
        should be a PyTorch serialized model.
    compile_presto : bool (default=False)
        Whether to compile the Presto encoder for speeding up large-scale inference.
    temporal_prediction : bool (default=False)
        Whether to use temporal-explicit predictions. If True, the time dimension
        is preserved in Presto features and a specific timestep is selected later.
        If False, features are pooled across time (non-temporal prediction).
    target_date : str (default=None)
        Target date for temporal-explicit predictions in ISO format (YYYY-MM-DD).
        Only used when temporal_prediction=True. If None, the middle timestep is used.
    """

    rescale_s1: bool
    presto_model_url: str
    compile_presto: bool
    temporal_prediction: bool = Field(default=False)
    target_date: Optional[str] = Field(default=None)

    @model_validator(mode="after")
    def check_temporal_parameters(self):
        """Validates temporal prediction parameters.

        Raises ``ValueError`` on invalid combinations: Pydantic wraps it into
        a ``ValidationError`` for the caller. (Raising ``ValidationError``
        directly with a message string — as this validator previously did —
        is not supported by Pydantic v2 and would fail with a ``TypeError``
        instead of reporting the intended error.)
        """
        if self.target_date is not None and not self.temporal_prediction:
            raise ValueError(
                "target_date can only be specified when temporal_prediction=True"
            )

        if self.target_date is not None:
            try:
                datetime.fromisoformat(self.target_date)
            except ValueError:
                # Re-raise with a user-facing message; the parsing error
                # itself adds no information here.
                raise ValueError(
                    "target_date must be in ISO format (YYYY-MM-DD)"
                ) from None

        return self
class ClassifierParameters(BaseModel):
    """Parameters for the classifier. Types are enforced by Pydantic.

    Attributes
    ----------
    classifier_url : str
        Public URL to the classifier model. The file should be an ONNX accepting
        a `features` field for input data and returning either two output
        probability arrays `true` and `false` in case of cropland mapping, or
        a probability array per-class in case of croptype mapping.
    """

    classifier_url: str
class PostprocessParameters(BaseModel):
    """Parameters for postprocessing. Types are enforced by Pydantic.

    Attributes
    ----------
    enable: bool (default=True)
        Whether to enable postprocessing.
    method: str (default="smooth_probabilities")
        The method to use for postprocessing. Must be one of ["smooth_probabilities", "majority_vote"]
    kernel_size: int (default=5)
        Used for majority vote postprocessing. Must be an odd number, larger than 1 and smaller than 25.
    save_intermediate: bool (default=False)
        Whether to save intermediate results (before applying the postprocessing).
        The intermediate results will be saved in the GeoTiff format.
    keep_class_probs: bool (default=True)
        If the per-class probabilities should be outputted in the final product.
    """

    enable: bool = Field(default=True)
    method: str = Field(default="smooth_probabilities")
    kernel_size: int = Field(default=5)
    save_intermediate: bool = Field(default=False)
    keep_class_probs: bool = Field(default=True)

    @model_validator(mode="after")
    def check_parameters(self):
        """Validates parameters."""
        # Intermediate results only exist when postprocessing actually runs.
        if self.save_intermediate and not self.enable:
            raise ValueError(
                "Cannot save intermediate results if postprocessing is disabled."
            )

        if self.method not in ("smooth_probabilities", "majority_vote"):
            raise ValueError(
                f"Method must be one of ['smooth_probabilities', 'majority_vote'], got {self.method}"
            )

        # Kernel-size constraints only apply to the majority-vote method.
        if self.method == "majority_vote":
            kernel = self.kernel_size
            if kernel % 2 == 0:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number, got {kernel}"
                )
            if kernel > 25:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number smaller than 25, got {kernel}"
                )
            if kernel < 3:
                raise ValueError(
                    f"Kernel size for majority filtering should be an odd number larger than 1, got {kernel}"
                )

        return self
class BaseParameters(BaseModel):
    """Base class for shared parameter logic."""

    # Postprocessing configuration; a fresh default instance per model.
    postprocess_parameters: PostprocessParameters = Field(
        default_factory=PostprocessParameters
    )

    @staticmethod
    def create_feature_parameters(**kwargs):
        """Build a FeaturesParameters, filling unspecified fields with defaults."""
        parameters = {
            "rescale_s1": False,
            "presto_model_url": "",
            "compile_presto": False,
            "temporal_prediction": False,
            "target_date": None,
            **kwargs,
        }
        return FeaturesParameters(**parameters)

    @staticmethod
    def create_classifier_parameters(classifier_url: str):
        """Build a ClassifierParameters for the given model URL."""
        return ClassifierParameters(classifier_url=classifier_url)
class CropLandParameters(BaseParameters):
    """Parameters for the cropland product inference pipeline. Types are
    enforced by Pydantic.

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    classifier_parameters : ClassifierParameters
        Parameters for the classifier UDF. Will be serialized into a dictionary
        and passed in the process graph.
    """

    # Default Presto feature-extractor configuration for cropland mapping.
    feature_parameters: FeaturesParameters = BaseParameters.create_feature_parameters(
        rescale_s1=False,
        presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-landcover-MulticlassWithCroplandAuxBCELoss-labelsmoothing=0.05-month-LANDCOVER10-augment=True-balance=True-timeexplicit=False-masking=enabled-run=202510301004_encoder.pt",  # NOQA
        compile_presto=False,
        temporal_prediction=False,
        target_date=None,
    )

    @staticmethod
    def _default_classifier_parameters() -> ClassifierParameters:
        # Default ONNX classifier matching the Presto encoder above.
        return BaseParameters.create_classifier_parameters(
            classifier_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/PrestoDownstreamCatBoost_temporary-crops_v201-prestorun=202510301004.onnx"  # NOQA
        )

    classifier_parameters: ClassifierParameters = Field(
        default_factory=lambda: CropLandParameters._default_classifier_parameters()
    )

    def __init__(self, classifier_url: Optional[str] = None, **kwargs):
        """Optionally override the classifier model URL.

        Parameters
        ----------
        classifier_url : Optional[str]
            When given (and no explicit ``classifier_parameters`` is passed in
            ``kwargs``), a ClassifierParameters built from this URL replaces
            the default classifier.
        """
        # Allow overriding classifier URL unless explicit classifier_parameters provided
        if "classifier_parameters" not in kwargs and classifier_url is not None:
            kwargs["classifier_parameters"] = (
                BaseParameters.create_classifier_parameters(
                    classifier_url=classifier_url
                )
            )
        super().__init__(**kwargs)
class CropTypeParameters(BaseParameters):
    """Parameters for the croptype product inference pipeline. Types are
    enforced by Pydantic.

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    classifier_parameters : ClassifierParameters
        Parameters for the classifier UDF. Will be serialized into a dictionary
        and passed in the process graph.
    mask_cropland : bool (default=True)
        Whether or not to mask the cropland pixels before running crop type inference.
    save_mask : bool (default=False)
        Whether or not to save the cropland mask as an intermediate result.
    """

    @staticmethod
    def _default_feature_parameters() -> FeaturesParameters:
        """Single source of truth for default croptype feature parameters."""
        return BaseParameters.create_feature_parameters(
            rescale_s1=False,
            presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-croptype-with-nocrop-FocalLoss-labelsmoothing%3D0.05-month-CROPTYPE27-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-masking%3Denabled-run%3D202510301004_encoder.pt",  # NOQA
            compile_presto=False,
            temporal_prediction=False,
            target_date=None,  # By default take the middle date
        )

    @staticmethod
    def _default_classifier_parameters() -> ClassifierParameters:
        """Single source of truth for the default croptype classifier."""
        return BaseParameters.create_classifier_parameters(
            classifier_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/PrestoDownstreamCatBoost_croptype_v201-prestorun%3D202510301004.onnx"
        )

    feature_parameters: FeaturesParameters = Field(
        default_factory=lambda: CropTypeParameters._default_feature_parameters()
    )
    classifier_parameters: ClassifierParameters = Field(
        default_factory=lambda: CropTypeParameters._default_classifier_parameters()
    )
    mask_cropland: bool = Field(default=True)
    save_mask: bool = Field(default=False)

    def __init__(
        self,
        target_date: Optional[str] = None,
        classifier_url: Optional[str] = None,
        **kwargs,
    ):
        """Allow overriding the target date and classifier URL without
        redeclaring the full parameter sets.

        Parameters
        ----------
        target_date : str, optional
            Target date for the prediction. Ignored when an explicit
            ``feature_parameters`` is supplied in ``kwargs``.
        classifier_url : str, optional
            Custom classifier model URL. Ignored when an explicit
            ``classifier_parameters`` is supplied in ``kwargs``.
        """
        # Override feature target_date if feature_parameters not supplied
        if "feature_parameters" not in kwargs:
            fp = self._default_feature_parameters().model_copy()
            fp.target_date = target_date  # type: ignore[attr-defined]
            kwargs["feature_parameters"] = fp
        # Override classifier URL if classifier_parameters not supplied
        if "classifier_parameters" not in kwargs and classifier_url is not None:
            kwargs["classifier_parameters"] = (
                BaseParameters.create_classifier_parameters(
                    classifier_url=classifier_url
                )
            )
        super().__init__(**kwargs)

    @model_validator(mode="after")
    def check_mask_parameters(self):
        """Validates the mask-related parameters.

        Raises
        ------
        ValueError
            If ``save_mask`` is enabled while ``mask_cropland`` is disabled.
        """
        if not self.mask_cropland and self.save_mask:
            # Raise ValueError, not pydantic.ValidationError: ValidationError
            # cannot be constructed from a plain message string, and pydantic
            # expects validators to raise ValueError, which it then wraps into
            # a ValidationError for the caller.
            raise ValueError("Cannot save mask if mask_cropland is disabled.")
        return self
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class EmbeddingsParameters(BaseParameters):
    """Parameters for the embeddings product inference pipeline. Types are
    enforced by Pydantic.

    Unlike the crop product parameter classes, this class declares no
    classifier parameters of its own — presumably embeddings are produced
    without a downstream classifier (NOTE(review): confirm against the
    inference pipeline).

    Attributes
    ----------
    feature_parameters : FeaturesParameters
        Parameters for the feature extraction UDF. Will be serialized into a
        dictionary and passed in the process graph.
    """

    @staticmethod
    def _default_feature_parameters() -> FeaturesParameters:
        """Internal helper returning the default feature parameters instance.

        Centralizes the defaults so they are declared only once.
        """
        return BaseParameters.create_feature_parameters(
            rescale_s1=False,
            presto_model_url="https://s3.waw3-1.cloudferro.com/swift/v1/openeo-ml-models-prod/worldcereal/presto-prometheo-landcover-month-LANDCOVER10-augment%3DTrue-balance%3DTrue-timeexplicit%3DFalse-run%3D202507170930_encoder.pt",  # NOQA
            compile_presto=False,
            temporal_prediction=False,
            target_date=None,
        )

    feature_parameters: FeaturesParameters = Field(
        # Wrap staticmethod call so pydantic receives a true zero-arg callable
        default_factory=lambda: EmbeddingsParameters._default_feature_parameters()
    )

    def __init__(self, presto_model_url: Optional[str] = None, **kwargs):
        """Allow initialization with a custom Presto model URL without
        duplicating the default argument list.

        Users may still pass an explicit `feature_parameters` to override all
        aspects; in that case `presto_model_url` is ignored.
        """
        if "feature_parameters" not in kwargs and presto_model_url is not None:
            fp = self._default_feature_parameters().model_copy()
            fp.presto_model_url = presto_model_url  # type: ignore[attr-defined]
            kwargs["feature_parameters"] = fp
        super().__init__(**kwargs)
|
worldcereal/utils/models.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities around models for the WorldCereal package."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import onnxruntime as ort
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache(maxsize=2)
def load_model_onnx(model_url: str) -> ort.InferenceSession:
    """Load an ONNX model from a URL.

    The result is cached for the two most recently used URLs, so repeated
    calls for the same model do not re-download it.

    Parameters
    ----------
    model_url: str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    requests.HTTPError
        If the model could not be downloaded (non-2xx HTTP status).
    """
    # Two minutes timeout to download the model
    response = requests.get(model_url, timeout=120)
    # Fail fast on a bad HTTP status; otherwise error-page bytes would be
    # handed to onnxruntime, which raises a confusing parse error instead.
    response.raise_for_status()

    return ort.InferenceSession(response.content)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def validate_cb_model(model_url: str) -> ort.InferenceSession:
    """Validate a catboost model by loading it and checking if the required
    metadata is present. Checks that the `class_names` and `class_to_label`
    fields are present in the `class_params` field of the custom metadata of
    the model. By default, the CatBoost module should include those fields
    when exporting a model to ONNX.

    Raises an exception if the model is not valid.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    ValueError
        If the `class_params` metadata entry or one of its required fields
        (`class_names`, `class_to_label`) is missing.
    """
    model = load_model_onnx(model_url=model_url)

    metadata = model.get_modelmeta().custom_metadata_map

    # Each check carries a distinct message so a failure pinpoints exactly
    # which piece of metadata is missing.
    if "class_params" not in metadata:
        raise ValueError("Could not find class parameters in the model metadata.")

    class_params = json.loads(metadata["class_params"])

    if "class_names" not in class_params:
        raise ValueError("Could not find class names in the model metadata.")

    if "class_to_label" not in class_params:
        raise ValueError("Could not find class to labels in the model metadata.")

    return model
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_model_lut(model_url: str) -> dict:
    """Load the class names to labels mapping from a CatBoost model.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    dict
        Look-up table with class names and labels, ordered by label value.
    """
    validated = validate_cb_model(model_url=model_url)
    class_params = json.loads(
        validated.get_modelmeta().custom_metadata_map["class_params"]
    )

    pairs = zip(class_params["class_names"], class_params["class_to_label"])
    # Sort entries by their numeric label so iteration order is deterministic.
    return dict(sorted(pairs, key=lambda pair: pair[1]))
|