""" Common internal operations within collection extraction logic, such as reprojection. """ from functools import partial from typing import Dict, Optional, Sequence, Union import openeo from geojson import GeoJSON from openeo.api.process import Parameter from openeo.rest.connection import InputDate from pyproj.crs import CRS from pyproj.exceptions import CRSError from openeo_gfmap.spatial import BoundingBoxExtent, SpatialContext from openeo_gfmap.temporal import TemporalContext from .fetching import FetchType def convert_band_names(desired_bands: list, band_dict: dict) -> list: """Renames the desired bands to the band names of the collection specified in the backend. Parameters ---------- desired_bands: list List of bands that are desired by the user, written in the OpenEO-GFMAP harmonized names convention. band_dict: dict Dictionnary mapping for a backend the collection band names to the OpenEO-GFMAP harmonized names. This dictionnary will be reversed to be used within this function. Returns ------- backend_band_list: list List of band names within the backend collection names. """ # Reverse the dictionarry band_dict = {v: k for k, v in band_dict.items()} return [band_dict[band] for band in desired_bands] def resample_reproject( datacube: openeo.DataCube, resolution: float, epsg_code: Optional[Union[str, int]] = None, method: str = "near", ) -> openeo.DataCube: """Reprojects the given datacube to the target epsg code, if the provided epsg code is not None. Also performs checks on the give code to check its validity. """ if epsg_code is not None: # Checks that the code is valid try: CRS.from_epsg(int(epsg_code)) except (CRSError, ValueError) as exc: raise ValueError( f"Specified target_crs: {epsg_code} is not a valid EPSG code." ) from exc return datacube.resample_spatial( resolution=resolution, projection=epsg_code, method=method ) return datacube.resample_spatial(resolution=resolution, method=method) def rename_bands(datacube: openeo.DataCube, mapping: dict) -> openeo.DataCube: """Rename the bands from the given mapping scheme""" # Filter out bands that are not part of the datacube def filter_condition(band_name, _): return band_name in datacube.metadata.band_names mapping = {k: v for k, v in mapping.items() if filter_condition(k, v)} return datacube.rename_labels( dimension="bands", target=list(mapping.values()), source=list(mapping.keys()) ) def _load_collection_hybrid( connection: openeo.Connection, is_stac: bool, collection_id_or_url: str, bands: list, spatial_extent: Union[Dict[str, float], Parameter, None] = None, temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, properties: Optional[dict] = None, ): """Wrapper around the load_collection, or load_stac method of the openeo connection.""" if not is_stac: return connection.load_collection( collection_id=collection_id_or_url, spatial_extent=spatial_extent, temporal_extent=temporal_extent, bands=bands, properties=properties, ) cube = connection.load_stac( url=collection_id_or_url, spatial_extent=spatial_extent, temporal_extent=temporal_extent, bands=bands, properties=properties, ) cube = cube.rename_labels(dimension="bands", target=bands) return cube def _load_collection( connection: openeo.Connection, bands: list, collection_name: str, spatial_extent: SpatialContext, temporal_extent: Optional[TemporalContext], fetch_type: FetchType, is_stac: bool = False, **params, ): """Loads a collection from the openeo backend, acting differently depending on the fetch type. """ load_collection_parameters = params.get("load_collection", {}) load_collection_method = partial( _load_collection_hybrid, is_stac=is_stac, collection_id_or_url=collection_name ) if ( temporal_extent is not None ): # Can be ignored for intemporal collections such as DEM temporal_extent = [temporal_extent.start_date, temporal_extent.end_date] if fetch_type == FetchType.TILE: if isinstance(spatial_extent, BoundingBoxExtent): spatial_extent = dict(spatial_extent) elif spatial_extent is not None: raise ValueError( "`spatial_extent` should be either None or an instance of BoundingBoxExtent for tile-based fetching." ) cube = load_collection_method( connection=connection, bands=bands, spatial_extent=spatial_extent, temporal_extent=temporal_extent, properties=load_collection_parameters, ) elif fetch_type == FetchType.POINT or fetch_type == FetchType.POLYGON: cube = load_collection_method( connection=connection, bands=bands, temporal_extent=temporal_extent, properties=load_collection_parameters, ) # Adding the process graph updates for experimental features if params.get("update_arguments") is not None: cube.result_node().update_arguments(**params["update_arguments"]) # Peforming pre-mask optimization pre_mask = params.get("pre_mask", None) if pre_mask is not None: assert isinstance(pre_mask, openeo.DataCube), ( f"The 'pre_mask' parameter must be an openeo datacube, got {pre_mask}." ) cube = cube.mask(pre_mask) # Merges additional bands continuing the operations. pre_merge_cube = params.get("pre_merge", None) if pre_merge_cube is not None: assert isinstance(pre_merge_cube, openeo.DataCube), ( f"The 'pre_merge' parameter value must be an openeo datacube, " f"got {pre_merge_cube}." ) if pre_mask is not None: pre_merge_cube = pre_merge_cube.mask(pre_mask) cube = cube.merge_cubes(pre_merge_cube) if fetch_type == FetchType.POLYGON and spatial_extent is not None: if isinstance(spatial_extent, str): geometry = connection.load_url( spatial_extent, format="Parquet" if ".parquet" in spatial_extent else "GeoJSON", ) cube = cube.filter_spatial(geometry) else: cube = cube.filter_spatial(spatial_extent) return cube