# MineWatchAI / src/stac_utils.py
# Author: Ashkan Taghipour (The University of Western Australia)
# History: initial commit (f5648f5)
"""
STAC/Planetary Computer utilities for RehabWatch.
Handles satellite data access via Microsoft Planetary Computer.
Data Sources:
- Sentinel-2 L2A: Multispectral imagery for vegetation indices
- Copernicus DEM GLO-30: Digital elevation model for terrain analysis
- IO-LULC: Land cover classification (2017-2023)
- ESA WorldCover: Land cover classification (2020-2021)
"""
import numpy as np
import xarray as xr
import rioxarray
import stackstac
import planetary_computer
from pystac_client import Client
from shapely.geometry import box, shape, mapping
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple
import warnings
warnings.filterwarnings('ignore')
# Planetary Computer STAC endpoint
STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
# Collection names
SENTINEL2_COLLECTION = "sentinel-2-l2a"
COPERNICUS_DEM_COLLECTION = "cop-dem-glo-30"
IO_LULC_COLLECTION = "io-lulc-annual-v02"
ESA_WORLDCOVER_COLLECTION = "esa-worldcover"
# Land cover class mappings for IO-LULC
LULC_CLASSES = {
1: "Water",
2: "Trees",
4: "Flooded Vegetation",
5: "Crops",
7: "Built Area",
8: "Bare Ground",
9: "Snow/Ice",
10: "Clouds",
11: "Rangeland"
}
# ESA WorldCover class mappings
WORLDCOVER_CLASSES = {
10: "Tree cover",
20: "Shrubland",
30: "Grassland",
40: "Cropland",
50: "Built-up",
60: "Bare / sparse vegetation",
70: "Snow and ice",
80: "Permanent water bodies",
90: "Herbaceous wetland",
95: "Mangroves",
100: "Moss and lichen"
}
def get_stac_client() -> Client:
    """
    Open a pystac-client Client for the Planetary Computer catalog.

    Returns:
        Client configured to sign asset hrefs automatically via
        planetary_computer.sign_inplace.
    """
    client = Client.open(
        STAC_URL,
        modifier=planetary_computer.sign_inplace,
    )
    return client
# =============================================================================
# SENTINEL-2 DATA ACCESS
# =============================================================================
def search_sentinel2(
    bbox: Tuple[float, float, float, float],
    start_date: str,
    end_date: str,
    cloud_cover: int = 20
) -> List[Any]:
    """
    Query the Planetary Computer catalog for Sentinel-2 L2A scenes.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)
        cloud_cover: Maximum scene-level cloud cover percentage

    Returns:
        List of matching STAC items (may be empty).
    """
    catalog = get_stac_client()
    date_range = f"{start_date}/{end_date}"
    results = catalog.search(
        collections=[SENTINEL2_COLLECTION],
        bbox=bbox,
        datetime=date_range,
        # Filter on the scene-level eo:cloud_cover property.
        query={"eo:cloud_cover": {"lt": cloud_cover}},
    )
    return list(results.items())
def get_sentinel_composite(
    bbox: Tuple[float, float, float, float],
    start_date: str,
    end_date: str,
    cloud_threshold: int = 20,
    resolution: int = 20
) -> xr.DataArray:
    """
    Get a cloud-free Sentinel-2 composite for a given bbox and date range.
    Includes all bands needed for comprehensive vegetation analysis.
    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        start_date: Start date string (YYYY-MM-DD)
        end_date: End date string (YYYY-MM-DD)
        cloud_threshold: Maximum cloud cover percentage (0-100)
        resolution: Output resolution in meters (default 20m for memory efficiency)
    Returns:
        xarray DataArray with median composite, scaled to 0-1 reflectance.
        Note: the "SCL" band passes through the median/scaling too and is not
        meaningful in the output; use only the reflectance bands.
    Raises:
        ValueError: If no images found for the specified criteria
    """
    items = search_sentinel2(bbox, start_date, end_date, cloud_threshold)
    if not items:
        raise ValueError(
            f"No Sentinel-2 images found for the specified location and date range "
            f"({start_date} to {end_date}) with cloud cover below {cloud_threshold}%. "
            "Try expanding the date range or increasing the cloud threshold."
        )
    # Keep only the 5 least-cloudy scenes to bound memory usage.
    if len(items) > 5:
        items = sorted(items, key=lambda x: x.properties.get('eo:cloud_cover', 100))[:5]
    # Select all bands needed for indices:
    # B02 (Blue), B03 (Green), B04 (Red), B05 (Red Edge 1),
    # B06 (Red Edge 2), B07 (Red Edge 3), B08 (NIR),
    # B8A (NIR narrow), B11 (SWIR1), B12 (SWIR2), SCL (Scene Classification)
    bands = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12", "SCL"]
    stack = stackstac.stack(
        items,
        assets=bands,
        bounds_latlon=bbox,
        resolution=resolution,
        epsg=32750,  # UTM zone for Western Australia
        dtype="float64",
        rescale=False,
        fill_value=np.nan,
        chunksize=1024  # Smaller chunks for memory efficiency
    )
    # Cloud masking via the Scene Classification Layer (SCL):
    #   3 = cloud shadow, 7 = unclassified, 8 = cloud medium probability,
    #   9 = cloud high probability, 10 = thin cirrus.
    # BUG FIX: the previous mask (7 <= SCL <= 10) missed cloud shadows
    # (class 3), which would contaminate the composite with dark pixels.
    scl = stack.sel(band="SCL")
    invalid_mask = (scl == 3) | ((scl >= 7) & (scl <= 10))
    # Apply mask to all bands, then reduce over time.
    masked = stack.where(~invalid_mask)
    composite = masked.median(dim="time", skipna=True)
    # Scale to 0-1 reflectance (Sentinel-2 L2A is in 0-10000)
    composite = composite / 10000.0
    return composite.compute()
# =============================================================================
# VEGETATION INDICES
# =============================================================================
def calculate_ndvi(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Normalized Difference Vegetation Index.

    NDVI = (NIR - Red) / (NIR + Red), clipped to [-1, 1].
    Higher values indicate denser/greener vegetation.
    """
    b04 = data.sel(band="B04")  # Red
    b08 = data.sel(band="B08")  # NIR
    # Tiny epsilon guards against division by zero on dark pixels.
    denom = b08 + b04 + 1e-10
    return ((b08 - b04) / denom).clip(-1, 1)
def calculate_savi(data: xr.DataArray, L: float = 0.5) -> xr.DataArray:
    """
    Compute the Soil Adjusted Vegetation Index.

    SAVI = ((NIR - Red) / (NIR + Red + L)) * (1 + L), clipped to [-1, 1].
    More robust than NDVI over sparse vegetation; L = 0.5 suits most
    conditions.
    """
    b04 = data.sel(band="B04")  # Red
    b08 = data.sel(band="B08")  # NIR
    scale = 1 + L
    adjusted = (b08 - b04) / (b08 + b04 + L + 1e-10)
    return (adjusted * scale).clip(-1, 1)
def calculate_evi(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Enhanced Vegetation Index.

    EVI = 2.5 * (NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1), clipped to
    [-1, 1]. More sensitive than NDVI in high-biomass regions and corrects
    for some atmospheric effects.
    """
    b02 = data.sel(band="B02")  # Blue
    b04 = data.sel(band="B04")  # Red
    b08 = data.sel(band="B08")  # NIR
    denom = b08 + 6 * b04 - 7.5 * b02 + 1 + 1e-10
    evi = 2.5 * ((b08 - b04) / denom)
    return evi.clip(-1, 1)
def calculate_ndwi(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Normalized Difference Water Index.

    NDWI = (Green - NIR) / (Green + NIR), clipped to [-1, 1].
    Higher values indicate open water.
    """
    b03 = data.sel(band="B03")  # Green
    b08 = data.sel(band="B08")  # NIR
    denom = b03 + b08 + 1e-10
    return ((b03 - b08) / denom).clip(-1, 1)
def calculate_ndmi(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Normalized Difference Moisture Index.

    NDMI = (NIR - SWIR1) / (NIR + SWIR1), clipped to [-1, 1].
    Tracks vegetation water content; higher means moister canopy.
    """
    b08 = data.sel(band="B08")   # NIR
    b11 = data.sel(band="B11")   # SWIR1
    denom = b08 + b11 + 1e-10
    return ((b08 - b11) / denom).clip(-1, 1)
def calculate_bsi(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Bare Soil Index.

    BSI = ((SWIR1 + Red) - (NIR + Blue)) / ((SWIR1 + Red) + (NIR + Blue)),
    clipped to [-1, 1]. Higher values indicate more exposed soil.
    """
    b02 = data.sel(band="B02")   # Blue
    b04 = data.sel(band="B04")   # Red
    b08 = data.sel(band="B08")   # NIR
    b11 = data.sel(band="B11")   # SWIR1
    soil_term = b11 + b04
    veg_term = b08 + b02
    bsi = (soil_term - veg_term) / (soil_term + veg_term + 1e-10)
    return bsi.clip(-1, 1)
def calculate_nbr(data: xr.DataArray) -> xr.DataArray:
    """
    Compute the Normalized Burn Ratio.

    NBR = (NIR - SWIR2) / (NIR + SWIR2), clipped to [-1, 1].
    Useful for burned-area detection and vegetation disturbance.
    """
    b08 = data.sel(band="B08")   # NIR
    b12 = data.sel(band="B12")   # SWIR2
    denom = b08 + b12 + 1e-10
    return ((b08 - b12) / denom).clip(-1, 1)
def calculate_all_indices(data: xr.DataArray) -> Dict[str, xr.DataArray]:
    """
    Compute every vegetation/soil index supported by this module.

    Returns:
        Mapping from index name ('ndvi', 'savi', ...) to its DataArray.
    """
    index_funcs = {
        'ndvi': calculate_ndvi,
        'savi': calculate_savi,
        'evi': calculate_evi,
        'ndwi': calculate_ndwi,
        'ndmi': calculate_ndmi,
        'bsi': calculate_bsi,
        'nbr': calculate_nbr,
    }
    return {name: func(data) for name, func in index_funcs.items()}
def calculate_vegetation_heterogeneity(ndvi: xr.DataArray, window_size: int = 5) -> xr.DataArray:
    """
    Local standard deviation of NDVI as a vegetation-heterogeneity proxy.

    Higher values indicate more spatially diverse vegetation, which serves
    as a rough proxy for species diversity.

    Args:
        ndvi: NDVI DataArray
        window_size: Moving-window width in pixels (5 px = 50 m at 10 m
            resolution)

    Returns:
        DataArray of local NDVI standard deviation.
    """
    windowed = ndvi.rolling(x=window_size, y=window_size, center=True)
    return windowed.std()
# =============================================================================
# COPERNICUS DEM DATA ACCESS
# =============================================================================
def get_dem_data(
    bbox: Tuple[float, float, float, float],
    resolution: int = 30
) -> xr.DataArray:
    """
    Fetch Copernicus DEM GLO-30 elevation for a bounding box.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        resolution: Output resolution in meters (default 30m)

    Returns:
        DataArray of elevation values in meters.

    Raises:
        ValueError: If the catalog returns no DEM tiles for the bbox.
    """
    catalog = get_stac_client()
    found = list(
        catalog.search(
            collections=[COPERNICUS_DEM_COLLECTION],
            bbox=bbox
        ).items()
    )
    if not found:
        raise ValueError("No DEM data found for the specified location.")
    tiles = stackstac.stack(
        found,
        assets=["data"],
        bounds_latlon=bbox,
        resolution=resolution,
        epsg=32750,
        dtype="float32",
        fill_value=np.nan,
        chunksize=2048
    )
    # Median over time merges overlapping tiles into a single surface.
    elevation = tiles.median(dim="time", skipna=True).squeeze()
    return elevation.compute()
def calculate_slope(dem: xr.DataArray, resolution: float = 30.0) -> xr.DataArray:
    """
    Derive slope (degrees) from an elevation grid.

    Args:
        dem: Elevation DataArray
        resolution: Pixel resolution in meters (spacing for the gradients)

    Returns:
        DataArray named 'slope' with values in degrees (0-90), on the same
        grid as the input.
    """
    # Finite-difference gradients along (rows, cols).
    grad_y, grad_x = np.gradient(dem.values, resolution)
    # Gradient magnitude -> slope angle.
    rise = np.hypot(grad_x, grad_y)
    slope_deg = np.degrees(np.arctan(rise))
    return xr.DataArray(
        slope_deg,
        dims=dem.dims,
        coords=dem.coords,
        name='slope'
    )
def calculate_aspect(dem: xr.DataArray, resolution: float = 30.0) -> xr.DataArray:
    """
    Calculate aspect from DEM in degrees.
    Args:
        dem: Elevation DataArray
        resolution: Pixel resolution in meters
    Returns:
        Aspect in degrees (0-360, 0=North, 90=East)
    """
    # np.gradient returns derivatives along (rows, cols) = (dy, dx), with the
    # same spacing applied to both axes.
    dy, dx = np.gradient(dem.values, resolution)
    # Calculate aspect
    # NOTE(review): arctan2(-dx, dy) assumes a north-up raster whose row index
    # increases southward; if the y coordinate ascends instead, the N/S
    # orientation flips — TODO confirm against the DEM's y ordering.
    aspect = np.degrees(np.arctan2(-dx, dy))
    # Wrap negative angles into the 0-360 compass range.
    aspect = np.where(aspect < 0, aspect + 360, aspect)
    aspect_da = xr.DataArray(
        aspect,
        dims=dem.dims,
        coords=dem.coords,
        name='aspect'
    )
    return aspect_da
def calculate_terrain_ruggedness(dem: xr.DataArray, window_size: int = 3) -> xr.DataArray:
    """
    Approximate the Terrain Ruggedness Index (TRI).

    Classic TRI is the mean absolute difference between a cell and its
    neighbours; here the local elevation range (max - min) over a moving
    window is used as a cheaper proxy with the same qualitative meaning.

    Args:
        dem: Elevation DataArray
        window_size: Moving-window width in pixels

    Returns:
        DataArray where higher values indicate more rugged terrain.
    """
    neighbourhood = dem.rolling(x=window_size, y=window_size, center=True)
    return neighbourhood.max() - neighbourhood.min()
def calculate_erosion_risk(
    slope: xr.DataArray,
    bsi: xr.DataArray,
    slope_weight: float = 0.6,
    bare_soil_weight: float = 0.4
) -> xr.DataArray:
    """
    Combine slope and bare-soil signals into a 0-1 erosion risk index.

    Args:
        slope: Slope in degrees
        bsi: Bare Soil Index (-1 to 1)
        slope_weight: Weight of the slope component
        bare_soil_weight: Weight of the bare-soil component

    Returns:
        Erosion risk index clipped to [0, 1]; higher means greater risk.
    """
    # Slopes at or above 45 degrees saturate at maximum risk.
    steepness = (slope / 45.0).clip(0, 1)
    # Map BSI from [-1, 1] onto [0, 1].
    bareness = ((bsi + 1) / 2).clip(0, 1)
    combined = steepness * slope_weight + bareness * bare_soil_weight
    return combined.clip(0, 1)
# =============================================================================
# LAND COVER DATA ACCESS
# =============================================================================
def get_land_cover(
    bbox: Tuple[float, float, float, float],
    year: int = 2023,
    resolution: int = 10
) -> xr.DataArray:
    """
    Fetch IO-LULC annual land cover for a bounding box.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        year: Year of land cover data (2017-2023)
        resolution: Output resolution in meters

    Returns:
        DataArray of land cover class codes (see LULC_CLASSES).

    Raises:
        ValueError: If no items exist for the requested year.
    """
    catalog = get_stac_client()
    year_range = f"{year}-01-01/{year}-12-31"
    found = list(
        catalog.search(
            collections=[IO_LULC_COLLECTION],
            bbox=bbox,
            datetime=year_range
        ).items()
    )
    if not found:
        raise ValueError(f"No land cover data found for year {year}.")
    mosaic = stackstac.stack(
        found,
        assets=["data"],
        bounds_latlon=bbox,
        resolution=resolution,
        epsg=32750,
        dtype="uint8",
        fill_value=0,
        chunksize=2048
    )
    # Max over time collapses overlapping tiles (class codes are positive,
    # fill is 0).
    lulc = mosaic.max(dim="time").squeeze()
    return lulc.compute()
def get_worldcover(
    bbox: Tuple[float, float, float, float],
    year: int = 2021,
    resolution: int = 10
) -> xr.DataArray:
    """
    Fetch ESA WorldCover land cover for a bounding box.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        year: Year (2020 or 2021)
        resolution: Output resolution in meters

    Returns:
        DataArray of land cover class codes (see WORLDCOVER_CLASSES).

    Raises:
        ValueError: If no items exist for the requested year.
    """
    catalog = get_stac_client()
    year_range = f"{year}-01-01/{year}-12-31"
    found = list(
        catalog.search(
            collections=[ESA_WORLDCOVER_COLLECTION],
            bbox=bbox,
            datetime=year_range
        ).items()
    )
    if not found:
        raise ValueError(f"No WorldCover data found for year {year}.")
    mosaic = stackstac.stack(
        found,
        assets=["map"],
        bounds_latlon=bbox,
        resolution=resolution,
        epsg=32750,
        dtype="uint8",
        fill_value=0,
        chunksize=2048
    )
    # Max over time collapses overlapping tiles (class codes are positive,
    # fill is 0).
    worldcover = mosaic.max(dim="time").squeeze()
    return worldcover.compute()
def calculate_land_cover_change(
    lulc_before: xr.DataArray,
    lulc_after: xr.DataArray
) -> Dict[str, Any]:
    """
    Calculate land cover change statistics between two periods.

    Args:
        lulc_before: Land cover data for earlier period
        lulc_after: Land cover data for later period

    Returns:
        Dictionary with per-class pixel counts ('before', 'after') and a
        'changes' dict holding absolute and percent change per class.
    """
    before_counts = {}
    after_counts = {}
    for class_id, class_name in LULC_CLASSES.items():
        # int(...) on the 0-d sum works for both xarray DataArrays and plain
        # numpy arrays (the previous `.sum().values` form required xarray).
        before_counts[class_name] = int((lulc_before == class_id).sum())
        after_counts[class_name] = int((lulc_after == class_id).sum())
    changes = {}
    for class_name in LULC_CLASSES.values():
        before = before_counts.get(class_name, 0)
        after = after_counts.get(class_name, 0)
        changes[class_name] = {
            'before': before,
            'after': after,
            'change': after - before,
            # +1 in the denominator deliberately avoids division by zero for
            # classes absent before; it slightly biases small counts.
            'percent_change': ((after - before) / (before + 1)) * 100
        }
    return {
        'before': before_counts,
        'after': after_counts,
        'changes': changes
    }
def calculate_vegetation_cover_percent(
    lulc: xr.DataArray,
    source: str = 'io-lulc'
) -> float:
    """
    Calculate percentage of area covered by vegetation.

    Args:
        lulc: Land cover DataArray (plain numpy arrays also work)
        source: 'io-lulc' or 'worldcover'

    Returns:
        Vegetation cover percentage (0-100); 0.0 for empty input.
    """
    total_pixels = lulc.size
    # Guard against empty rasters (previously raised ZeroDivisionError).
    if total_pixels == 0:
        return 0.0
    if source == 'io-lulc':
        # Vegetation classes: Trees (2), Flooded Vegetation (4), Crops (5), Rangeland (11)
        veg_classes = [2, 4, 5, 11]
    else:  # worldcover
        # Vegetation classes: Tree cover (10), Shrubland (20), Grassland (30),
        # Cropland (40), Herbaceous wetland (90), Mangroves (95)
        veg_classes = [10, 20, 30, 40, 90, 95]
    # int(...) on the 0-d sum works for both xarray and numpy inputs
    # (the previous `.sum().values` form required xarray).
    veg_pixels = sum(int((lulc == c).sum()) for c in veg_classes)
    return (veg_pixels / total_pixels) * 100
def calculate_bare_ground_percent(
    lulc: xr.DataArray,
    source: str = 'io-lulc'
) -> float:
    """
    Calculate percentage of area that is bare ground.

    Args:
        lulc: Land cover DataArray (plain numpy arrays also work)
        source: 'io-lulc' or 'worldcover'

    Returns:
        Bare ground percentage (0-100); 0.0 for empty input.
    """
    total_pixels = lulc.size
    # Guard against empty rasters (previously raised ZeroDivisionError).
    if total_pixels == 0:
        return 0.0
    if source == 'io-lulc':
        bare_classes = [8]  # Bare Ground
    else:  # worldcover
        bare_classes = [60]  # Bare / sparse vegetation
    # int(...) on the 0-d sum works for both xarray and numpy inputs
    # (the previous `.sum().values` form required xarray).
    bare_pixels = sum(int((lulc == c).sum()) for c in bare_classes)
    return (bare_pixels / total_pixels) * 100
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def get_image_count(
    bbox: Tuple[float, float, float, float],
    start_date: str,
    end_date: str,
    cloud_threshold: int = 20
) -> int:
    """Return how many Sentinel-2 scenes match the location/date/cloud filter."""
    matching = search_sentinel2(bbox, start_date, end_date, cloud_threshold)
    return len(matching)
def get_image_dates(
    bbox: Tuple[float, float, float, float],
    start_date: str,
    end_date: str,
    cloud_threshold: int = 30
) -> List[str]:
    """Return sorted, de-duplicated acquisition dates (YYYY-MM-DD) of matching Sentinel-2 scenes."""
    scenes = search_sentinel2(bbox, start_date, end_date, cloud_threshold)
    unique_dates = {s.datetime.strftime("%Y-%m-%d") for s in scenes if s.datetime}
    return sorted(unique_dates)
def geometry_to_bbox(geometry: Dict[str, Any]) -> Tuple[float, float, float, float]:
    """Return the (min_lon, min_lat, max_lon, max_lat) bounds of a GeoJSON geometry."""
    return shape(geometry).bounds
def bbox_to_geometry(bbox: Tuple[float, float, float, float]) -> Dict[str, Any]:
    """Return a GeoJSON polygon geometry covering the bounding box."""
    west, south, east, north = bbox
    return mapping(box(west, south, east, north))
def get_bbox_center(bbox: Tuple[float, float, float, float]) -> Tuple[float, float]:
    """Return the center of a bounding box as (center_lat, center_lon)."""
    west, south, east, north = bbox
    # Note the (lat, lon) ordering of the return value.
    return ((south + north) / 2, (west + east) / 2)
def expand_bbox(
    bbox: Tuple[float, float, float, float],
    buffer_deg: float = 0.01
) -> Tuple[float, float, float, float]:
    """Grow a bounding box outward by a buffer, in degrees, on every side."""
    west, south, east, north = bbox
    return (
        west - buffer_deg,
        south - buffer_deg,
        east + buffer_deg,
        north + buffer_deg,
    )
def create_reference_bbox(
    bbox: Tuple[float, float, float, float],
    buffer_deg: float = 0.01
) -> Tuple[float, float, float, float]:
    """Build a reference bounding box by buffering the site bbox outward by buffer_deg on each side."""
    west, south, east, north = bbox
    return (
        west - buffer_deg,
        south - buffer_deg,
        east + buffer_deg,
        north + buffer_deg,
    )