""" STAC/Planetary Computer utilities for RehabWatch. Handles satellite data access via Microsoft Planetary Computer. Data Sources: - Sentinel-2 L2A: Multispectral imagery for vegetation indices - Copernicus DEM GLO-30: Digital elevation model for terrain analysis - IO-LULC: Land cover classification (2017-2023) - ESA WorldCover: Land cover classification (2020-2021) """ import numpy as np import xarray as xr import rioxarray import stackstac import planetary_computer from pystac_client import Client from shapely.geometry import box, shape, mapping from datetime import datetime, timedelta from typing import Optional, List, Dict, Any, Tuple import warnings warnings.filterwarnings('ignore') # Planetary Computer STAC endpoint STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1" # Collection names SENTINEL2_COLLECTION = "sentinel-2-l2a" COPERNICUS_DEM_COLLECTION = "cop-dem-glo-30" IO_LULC_COLLECTION = "io-lulc-annual-v02" ESA_WORLDCOVER_COLLECTION = "esa-worldcover" # Land cover class mappings for IO-LULC LULC_CLASSES = { 1: "Water", 2: "Trees", 4: "Flooded Vegetation", 5: "Crops", 7: "Built Area", 8: "Bare Ground", 9: "Snow/Ice", 10: "Clouds", 11: "Rangeland" } # ESA WorldCover class mappings WORLDCOVER_CLASSES = { 10: "Tree cover", 20: "Shrubland", 30: "Grassland", 40: "Cropland", 50: "Built-up", 60: "Bare / sparse vegetation", 70: "Snow and ice", 80: "Permanent water bodies", 90: "Herbaceous wetland", 95: "Mangroves", 100: "Moss and lichen" } def get_stac_client() -> Client: """ Get a STAC client for Planetary Computer. Returns: pystac_client.Client instance """ return Client.open(STAC_URL, modifier=planetary_computer.sign_inplace) # ============================================================================= # SENTINEL-2 DATA ACCESS # ============================================================================= def search_sentinel2( bbox: Tuple[float, float, float, float], start_date: str, end_date: str, cloud_cover: int = 20 ) -> List[Any]: """ Search for Sentinel-2 scenes in the Planetary Computer catalog. Args: bbox: Bounding box (min_lon, min_lat, max_lon, max_lat) start_date: Start date (YYYY-MM-DD) end_date: End date (YYYY-MM-DD) cloud_cover: Maximum cloud cover percentage Returns: List of STAC items """ client = get_stac_client() search = client.search( collections=[SENTINEL2_COLLECTION], bbox=bbox, datetime=f"{start_date}/{end_date}", query={"eo:cloud_cover": {"lt": cloud_cover}} ) items = list(search.items()) return items def get_sentinel_composite( bbox: Tuple[float, float, float, float], start_date: str, end_date: str, cloud_threshold: int = 20, resolution: int = 20 ) -> xr.DataArray: """ Get a cloud-free Sentinel-2 composite for a given bbox and date range. Includes all bands needed for comprehensive vegetation analysis. Args: bbox: Bounding box (min_lon, min_lat, max_lon, max_lat) start_date: Start date string (YYYY-MM-DD) end_date: End date string (YYYY-MM-DD) cloud_threshold: Maximum cloud cover percentage (0-100) resolution: Output resolution in meters (default 20m for memory efficiency) Returns: xarray DataArray with median composite Raises: ValueError: If no images found for the specified criteria """ items = search_sentinel2(bbox, start_date, end_date, cloud_threshold) if len(items) == 0: raise ValueError( f"No Sentinel-2 images found for the specified location and date range " f"({start_date} to {end_date}) with cloud cover below {cloud_threshold}%. " "Try expanding the date range or increasing the cloud threshold." ) # Limit number of items to reduce memory usage if len(items) > 5: items = sorted(items, key=lambda x: x.properties.get('eo:cloud_cover', 100))[:5] # Select all bands needed for indices: # B02 (Blue), B03 (Green), B04 (Red), B05 (Red Edge 1), # B06 (Red Edge 2), B07 (Red Edge 3), B08 (NIR), # B8A (NIR narrow), B11 (SWIR1), B12 (SWIR2), SCL (Scene Classification) bands = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12", "SCL"] stack = stackstac.stack( items, assets=bands, bounds_latlon=bbox, resolution=resolution, epsg=32750, # UTM zone for Western Australia dtype="float64", rescale=False, fill_value=np.nan, chunksize=1024 # Smaller chunks for memory efficiency ) # Apply cloud masking using SCL (Scene Classification Layer) scl = stack.sel(band="SCL") cloud_mask = (scl >= 7) & (scl <= 10) # Apply mask to reflectance bands masked = stack.where(~cloud_mask) # Calculate median composite composite = masked.median(dim="time", skipna=True) # Scale to 0-1 reflectance (Sentinel-2 L2A is in 0-10000) composite = composite / 10000.0 return composite.compute() # ============================================================================= # VEGETATION INDICES # ============================================================================= def calculate_ndvi(data: xr.DataArray) -> xr.DataArray: """ Calculate NDVI (Normalized Difference Vegetation Index). NDVI = (NIR - Red) / (NIR + Red) Range: -1 to 1 (higher = more vegetation) """ red = data.sel(band="B04") nir = data.sel(band="B08") ndvi = (nir - red) / (nir + red + 1e-10) return ndvi.clip(-1, 1) def calculate_savi(data: xr.DataArray, L: float = 0.5) -> xr.DataArray: """ Calculate SAVI (Soil Adjusted Vegetation Index). SAVI = ((NIR - Red) / (NIR + Red + L)) * (1 + L) Better than NDVI for areas with sparse vegetation. L = 0.5 works well for most conditions. Range: -1 to 1 """ red = data.sel(band="B04") nir = data.sel(band="B08") savi = ((nir - red) / (nir + red + L + 1e-10)) * (1 + L) return savi.clip(-1, 1) def calculate_evi(data: xr.DataArray) -> xr.DataArray: """ Calculate EVI (Enhanced Vegetation Index). EVI = 2.5 * ((NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1)) More sensitive in high biomass regions, corrects for atmospheric influences. Range: approximately -1 to 1 """ blue = data.sel(band="B02") red = data.sel(band="B04") nir = data.sel(band="B08") evi = 2.5 * ((nir - red) / (nir + 6 * red - 7.5 * blue + 1 + 1e-10)) return evi.clip(-1, 1) def calculate_ndwi(data: xr.DataArray) -> xr.DataArray: """ Calculate NDWI (Normalized Difference Water Index). NDWI = (Green - NIR) / (Green + NIR) Detects water bodies. Higher values indicate water presence. Range: -1 to 1 """ green = data.sel(band="B03") nir = data.sel(band="B08") ndwi = (green - nir) / (green + nir + 1e-10) return ndwi.clip(-1, 1) def calculate_ndmi(data: xr.DataArray) -> xr.DataArray: """ Calculate NDMI (Normalized Difference Moisture Index). NDMI = (NIR - SWIR1) / (NIR + SWIR1) Measures vegetation water content/moisture stress. Range: -1 to 1 (higher = more moisture) """ nir = data.sel(band="B08") swir1 = data.sel(band="B11") ndmi = (nir - swir1) / (nir + swir1 + 1e-10) return ndmi.clip(-1, 1) def calculate_bsi(data: xr.DataArray) -> xr.DataArray: """ Calculate BSI (Bare Soil Index). BSI = ((SWIR1 + Red) - (NIR + Blue)) / ((SWIR1 + Red) + (NIR + Blue)) Identifies bare soil areas. Higher values indicate more bare soil. Range: -1 to 1 """ blue = data.sel(band="B02") red = data.sel(band="B04") nir = data.sel(band="B08") swir1 = data.sel(band="B11") bsi = ((swir1 + red) - (nir + blue)) / ((swir1 + red) + (nir + blue) + 1e-10) return bsi.clip(-1, 1) def calculate_nbr(data: xr.DataArray) -> xr.DataArray: """ Calculate NBR (Normalized Burn Ratio). NBR = (NIR - SWIR2) / (NIR + SWIR2) Useful for detecting burned areas and vegetation disturbance. Range: -1 to 1 """ nir = data.sel(band="B08") swir2 = data.sel(band="B12") nbr = (nir - swir2) / (nir + swir2 + 1e-10) return nbr.clip(-1, 1) def calculate_all_indices(data: xr.DataArray) -> Dict[str, xr.DataArray]: """ Calculate all vegetation and soil indices from Sentinel-2 data. Returns: Dictionary with index names as keys and DataArrays as values """ return { 'ndvi': calculate_ndvi(data), 'savi': calculate_savi(data), 'evi': calculate_evi(data), 'ndwi': calculate_ndwi(data), 'ndmi': calculate_ndmi(data), 'bsi': calculate_bsi(data), 'nbr': calculate_nbr(data) } def calculate_vegetation_heterogeneity(ndvi: xr.DataArray, window_size: int = 5) -> xr.DataArray: """ Calculate vegetation heterogeneity as local standard deviation of NDVI. Higher values indicate more diverse/heterogeneous vegetation. This serves as a proxy for species diversity. Args: ndvi: NDVI DataArray window_size: Size of the moving window (default 5 = 50m at 10m resolution) Returns: DataArray with heterogeneity values """ # Use rolling window to calculate local std heterogeneity = ndvi.rolling(x=window_size, y=window_size, center=True).std() return heterogeneity # ============================================================================= # COPERNICUS DEM DATA ACCESS # ============================================================================= def get_dem_data( bbox: Tuple[float, float, float, float], resolution: int = 30 ) -> xr.DataArray: """ Get Copernicus DEM GLO-30 elevation data. Args: bbox: Bounding box (min_lon, min_lat, max_lon, max_lat) resolution: Output resolution in meters (default 30m) Returns: xarray DataArray with elevation values in meters """ client = get_stac_client() search = client.search( collections=[COPERNICUS_DEM_COLLECTION], bbox=bbox ) items = list(search.items()) if len(items) == 0: raise ValueError("No DEM data found for the specified location.") stack = stackstac.stack( items, assets=["data"], bounds_latlon=bbox, resolution=resolution, epsg=32750, dtype="float32", fill_value=np.nan, chunksize=2048 ) # Take the first (or merge if multiple tiles) dem = stack.median(dim="time", skipna=True).squeeze() return dem.compute() def calculate_slope(dem: xr.DataArray, resolution: float = 30.0) -> xr.DataArray: """ Calculate slope from DEM in degrees. Args: dem: Elevation DataArray resolution: Pixel resolution in meters Returns: Slope in degrees (0-90) """ # Calculate gradients dy, dx = np.gradient(dem.values, resolution) # Calculate slope in degrees slope = np.degrees(np.arctan(np.sqrt(dx**2 + dy**2))) # Create DataArray with same coordinates slope_da = xr.DataArray( slope, dims=dem.dims, coords=dem.coords, name='slope' ) return slope_da def calculate_aspect(dem: xr.DataArray, resolution: float = 30.0) -> xr.DataArray: """ Calculate aspect from DEM in degrees. Args: dem: Elevation DataArray resolution: Pixel resolution in meters Returns: Aspect in degrees (0-360, 0=North, 90=East) """ dy, dx = np.gradient(dem.values, resolution) # Calculate aspect aspect = np.degrees(np.arctan2(-dx, dy)) aspect = np.where(aspect < 0, aspect + 360, aspect) aspect_da = xr.DataArray( aspect, dims=dem.dims, coords=dem.coords, name='aspect' ) return aspect_da def calculate_terrain_ruggedness(dem: xr.DataArray, window_size: int = 3) -> xr.DataArray: """ Calculate Terrain Ruggedness Index (TRI). TRI is the mean of the absolute differences between the center cell and its surrounding cells. Args: dem: Elevation DataArray window_size: Size of the moving window Returns: TRI values (higher = more rugged terrain) """ # Calculate local range as a proxy for ruggedness rolling = dem.rolling(x=window_size, y=window_size, center=True) tri = rolling.max() - rolling.min() return tri def calculate_erosion_risk( slope: xr.DataArray, bsi: xr.DataArray, slope_weight: float = 0.6, bare_soil_weight: float = 0.4 ) -> xr.DataArray: """ Calculate erosion risk index combining slope and bare soil. Higher values indicate greater erosion risk. Args: slope: Slope in degrees bsi: Bare Soil Index slope_weight: Weight for slope component bare_soil_weight: Weight for bare soil component Returns: Erosion risk index (0-1) """ # Normalize slope to 0-1 (assuming max slope of 45 degrees) slope_norm = (slope / 45.0).clip(0, 1) # Normalize BSI to 0-1 bsi_norm = ((bsi + 1) / 2).clip(0, 1) # Combined erosion risk erosion_risk = slope_weight * slope_norm + bare_soil_weight * bsi_norm return erosion_risk.clip(0, 1) # ============================================================================= # LAND COVER DATA ACCESS # ============================================================================= def get_land_cover( bbox: Tuple[float, float, float, float], year: int = 2023, resolution: int = 10 ) -> xr.DataArray: """ Get IO-LULC annual land cover data. Args: bbox: Bounding box (min_lon, min_lat, max_lon, max_lat) year: Year of land cover data (2017-2023) resolution: Output resolution in meters Returns: xarray DataArray with land cover classes """ client = get_stac_client() search = client.search( collections=[IO_LULC_COLLECTION], bbox=bbox, datetime=f"{year}-01-01/{year}-12-31" ) items = list(search.items()) if len(items) == 0: raise ValueError(f"No land cover data found for year {year}.") stack = stackstac.stack( items, assets=["data"], bounds_latlon=bbox, resolution=resolution, epsg=32750, dtype="uint8", fill_value=0, chunksize=2048 ) lulc = stack.max(dim="time").squeeze() return lulc.compute() def get_worldcover( bbox: Tuple[float, float, float, float], year: int = 2021, resolution: int = 10 ) -> xr.DataArray: """ Get ESA WorldCover land cover data. Args: bbox: Bounding box (min_lon, min_lat, max_lon, max_lat) year: Year (2020 or 2021) resolution: Output resolution in meters Returns: xarray DataArray with land cover classes """ client = get_stac_client() search = client.search( collections=[ESA_WORLDCOVER_COLLECTION], bbox=bbox, datetime=f"{year}-01-01/{year}-12-31" ) items = list(search.items()) if len(items) == 0: raise ValueError(f"No WorldCover data found for year {year}.") stack = stackstac.stack( items, assets=["map"], bounds_latlon=bbox, resolution=resolution, epsg=32750, dtype="uint8", fill_value=0, chunksize=2048 ) worldcover = stack.max(dim="time").squeeze() return worldcover.compute() def calculate_land_cover_change( lulc_before: xr.DataArray, lulc_after: xr.DataArray ) -> Dict[str, Any]: """ Calculate land cover change statistics between two periods. Args: lulc_before: Land cover data for earlier period lulc_after: Land cover data for later period Returns: Dictionary with change statistics """ # Calculate pixel counts for each class before_counts = {} after_counts = {} for class_id, class_name in LULC_CLASSES.items(): before_counts[class_name] = int((lulc_before == class_id).sum().values) after_counts[class_name] = int((lulc_after == class_id).sum().values) # Calculate changes changes = {} for class_name in LULC_CLASSES.values(): before = before_counts.get(class_name, 0) after = after_counts.get(class_name, 0) changes[class_name] = { 'before': before, 'after': after, 'change': after - before, 'percent_change': ((after - before) / (before + 1)) * 100 } return { 'before': before_counts, 'after': after_counts, 'changes': changes } def calculate_vegetation_cover_percent( lulc: xr.DataArray, source: str = 'io-lulc' ) -> float: """ Calculate percentage of area covered by vegetation. Args: lulc: Land cover DataArray source: 'io-lulc' or 'worldcover' Returns: Vegetation cover percentage (0-100) """ total_pixels = lulc.size if source == 'io-lulc': # Vegetation classes: Trees (2), Flooded Vegetation (4), Crops (5), Rangeland (11) veg_classes = [2, 4, 5, 11] else: # worldcover # Vegetation classes: Tree cover (10), Shrubland (20), Grassland (30), # Cropland (40), Herbaceous wetland (90), Mangroves (95) veg_classes = [10, 20, 30, 40, 90, 95] veg_pixels = sum(int((lulc == c).sum().values) for c in veg_classes) return (veg_pixels / total_pixels) * 100 def calculate_bare_ground_percent( lulc: xr.DataArray, source: str = 'io-lulc' ) -> float: """ Calculate percentage of area that is bare ground. Args: lulc: Land cover DataArray source: 'io-lulc' or 'worldcover' Returns: Bare ground percentage (0-100) """ total_pixels = lulc.size if source == 'io-lulc': bare_classes = [8] # Bare Ground else: # worldcover bare_classes = [60] # Bare / sparse vegetation bare_pixels = sum(int((lulc == c).sum().values) for c in bare_classes) return (bare_pixels / total_pixels) * 100 # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= def get_image_count( bbox: Tuple[float, float, float, float], start_date: str, end_date: str, cloud_threshold: int = 20 ) -> int: """Get count of available Sentinel-2 images for a location.""" items = search_sentinel2(bbox, start_date, end_date, cloud_threshold) return len(items) def get_image_dates( bbox: Tuple[float, float, float, float], start_date: str, end_date: str, cloud_threshold: int = 30 ) -> List[str]: """Get list of available Sentinel-2 image dates for a location.""" items = search_sentinel2(bbox, start_date, end_date, cloud_threshold) dates = [item.datetime.strftime("%Y-%m-%d") for item in items if item.datetime] return sorted(list(set(dates))) def geometry_to_bbox(geometry: Dict[str, Any]) -> Tuple[float, float, float, float]: """Convert a GeoJSON geometry to a bounding box.""" geom = shape(geometry) bounds = geom.bounds return bounds def bbox_to_geometry(bbox: Tuple[float, float, float, float]) -> Dict[str, Any]: """Convert a bounding box to GeoJSON geometry.""" return mapping(box(*bbox)) def get_bbox_center(bbox: Tuple[float, float, float, float]) -> Tuple[float, float]: """Get the center point of a bounding box.""" min_lon, min_lat, max_lon, max_lat = bbox center_lat = (min_lat + max_lat) / 2 center_lon = (min_lon + max_lon) / 2 return (center_lat, center_lon) def expand_bbox( bbox: Tuple[float, float, float, float], buffer_deg: float = 0.01 ) -> Tuple[float, float, float, float]: """Expand a bounding box by a buffer in degrees.""" min_lon, min_lat, max_lon, max_lat = bbox return ( min_lon - buffer_deg, min_lat - buffer_deg, max_lon + buffer_deg, max_lat + buffer_deg ) def create_reference_bbox( bbox: Tuple[float, float, float, float], buffer_deg: float = 0.01 ) -> Tuple[float, float, float, float]: """Create a reference bounding box around the site.""" return expand_bbox(bbox, buffer_deg)