# MineWatchAI / src / analysis.py
# Author: Ashkan Taghipour (The University of Western Australia)
# History: initial commit (f5648f5)
"""
Comprehensive vegetation and terrain analysis module for RehabWatch.
Performs multi-index analysis including:
- Vegetation indices (NDVI, SAVI, EVI)
- Soil/water indices (BSI, NDWI, NDMI, NBR)
- Terrain analysis (slope, aspect, erosion risk)
- Land cover classification and change
- Rehabilitation metrics and scoring
"""
import numpy as np
import xarray as xr
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, Tuple, List
from .stac_utils import (
get_sentinel_composite,
calculate_ndvi,
calculate_savi,
calculate_evi,
calculate_ndwi,
calculate_ndmi,
calculate_bsi,
calculate_nbr,
calculate_all_indices,
calculate_vegetation_heterogeneity,
get_dem_data,
calculate_slope,
calculate_aspect,
calculate_terrain_ruggedness,
calculate_erosion_risk,
get_land_cover,
get_worldcover,
calculate_land_cover_change,
calculate_vegetation_cover_percent,
calculate_bare_ground_percent,
search_sentinel2,
create_reference_bbox,
get_bbox_center,
LULC_CLASSES,
WORLDCOVER_CLASSES
)
def analyze_vegetation_change(
    bbox: Tuple[float, float, float, float],
    date_before: str,
    date_after: str,
    window_days: int = 15,
    cloud_threshold: int = 25
) -> Dict[str, Any]:
    """
    Compare vegetation condition between two dates using multiple spectral indices.

    Builds a cloud-filtered Sentinel-2 composite around each date, derives the
    full index suite (NDVI, SAVI, EVI, NDWI, NDMI, BSI, NBR), and reports the
    per-index change plus summary statistics.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        date_before: Start date (YYYY-MM-DD)
        date_after: End date (YYYY-MM-DD)
        window_days: Half-width in days of the compositing window around each date
        cloud_threshold: Maximum scene cloud cover percentage

    Returns:
        Dict containing composites, index arrays, change arrays, and statistics.
    """
    fmt = '%Y-%m-%d'

    def _window(date_str: str) -> Tuple[str, str]:
        # Symmetric window of +/- window_days around the target date.
        centre = datetime.strptime(date_str, fmt)
        delta = timedelta(days=window_days)
        return (centre - delta).strftime(fmt), (centre + delta).strftime(fmt)

    before_start, before_end = _window(date_before)
    after_start, after_end = _window(date_after)

    # Cloud-filtered composites for each period
    composite_before = get_sentinel_composite(bbox, before_start, before_end, cloud_threshold)
    composite_after = get_sentinel_composite(bbox, after_start, after_end, cloud_threshold)

    # Full index suite for both periods, then per-index difference (after - before)
    indices_before = calculate_all_indices(composite_before)
    indices_after = calculate_all_indices(composite_after)
    index_changes = {name: indices_after[name] - indices_before[name]
                     for name in indices_before}

    # NDVI spatial heterogeneity serves as a proxy for vegetation diversity
    heterogeneity_before = calculate_vegetation_heterogeneity(indices_before['ndvi'])
    heterogeneity_after = calculate_vegetation_heterogeneity(indices_after['ndvi'])

    stats = calculate_statistics(indices_before, indices_after, index_changes, bbox)

    return {
        'composite_before': composite_before,
        'composite_after': composite_after,
        'indices_before': indices_before,
        'indices_after': indices_after,
        'index_changes': index_changes,
        'ndvi_before': indices_before['ndvi'],
        'ndvi_after': indices_after['ndvi'],
        'ndvi_change': index_changes['ndvi'],
        'heterogeneity_before': heterogeneity_before,
        'heterogeneity_after': heterogeneity_after,
        'stats': stats,
        'date_before': date_before,
        'date_after': date_after,
        'bbox': bbox
    }
def analyze_terrain(
    bbox: Tuple[float, float, float, float],
    bsi: Optional[xr.DataArray] = None
) -> Dict[str, Any]:
    """
    Analyze terrain characteristics: elevation, slope, aspect, ruggedness,
    and (when a BSI array is supplied) erosion risk.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        bsi: Bare Soil Index array; when given, combined with slope to
             estimate erosion risk.

    Returns:
        Dict with the terrain arrays and a 'stats' summary. On failure the
        function does not raise; it returns {'error': ..., 'stats': {}}.
    """
    try:
        dem = get_dem_data(bbox)

        # Derivatives of the DEM
        slope = calculate_slope(dem)
        aspect = calculate_aspect(dem)
        ruggedness = calculate_terrain_ruggedness(dem)

        # Erosion risk requires both steepness and exposed soil information
        erosion_risk = calculate_erosion_risk(slope, bsi) if bsi is not None else None

        terrain_stats = {
            'elevation_min': float(np.nanmin(dem.values)),
            'elevation_max': float(np.nanmax(dem.values)),
            'elevation_mean': float(np.nanmean(dem.values)),
            'slope_mean': float(np.nanmean(slope.values)),
            'slope_max': float(np.nanmax(slope.values)),
            'ruggedness_mean': float(np.nanmean(ruggedness.values)),
        }

        # Slope classes in degrees: flat <5, gentle 5-15, moderate 15-30, steep >=30.
        # NOTE(review): NaN slope pixels match no class but are counted in the
        # total, so the four percentages may sum to less than 100.
        slopes = slope.values
        total_pixels = slope.size
        class_counts = {
            'percent_flat': np.sum(slopes < 5),
            'percent_gentle': np.sum((slopes >= 5) & (slopes < 15)),
            'percent_moderate': np.sum((slopes >= 15) & (slopes < 30)),
            'percent_steep': np.sum(slopes >= 30),
        }
        for label, count in class_counts.items():
            terrain_stats[label] = round((count / total_pixels) * 100, 1)

        if erosion_risk is not None:
            terrain_stats['erosion_risk_mean'] = float(np.nanmean(erosion_risk.values))
            # Risk values above 0.6 are treated as the high-risk class
            high_risk = np.sum(erosion_risk.values > 0.6)
            terrain_stats['percent_high_erosion_risk'] = round((high_risk / total_pixels) * 100, 1)

        return {
            'dem': dem,
            'slope': slope,
            'aspect': aspect,
            'ruggedness': ruggedness,
            'erosion_risk': erosion_risk,
            'stats': terrain_stats
        }
    except Exception as e:
        # Best-effort: surface the failure to the caller without raising
        return {
            'error': str(e),
            'stats': {}
        }
def analyze_land_cover(
    bbox: Tuple[float, float, float, float],
    year_before: int,
    year_after: int
) -> Dict[str, Any]:
    """
    Analyze land cover composition and its change between two years.

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        year_before: Earlier year (2017-2023)
        year_after: Later year (2017-2023)

    Returns:
        Dict with both land cover rasters, a 'stats' summary, and the class
        legend; on failure returns {'error': ..., 'stats': {}} without raising.
    """
    try:
        lulc_before = get_land_cover(bbox, year_before)
        lulc_after = get_land_cover(bbox, year_after)

        # Per-class transition statistics between the two years
        change_stats = calculate_land_cover_change(lulc_before, lulc_after)

        # Aggregate vegetation / bare-ground fractions for each year
        veg_before = calculate_vegetation_cover_percent(lulc_before)
        veg_after = calculate_vegetation_cover_percent(lulc_after)
        bare_before = calculate_bare_ground_percent(lulc_before)
        bare_after = calculate_bare_ground_percent(lulc_after)

        summary = {
            'vegetation_cover_before': round(veg_before, 1),
            'vegetation_cover_after': round(veg_after, 1),
            'vegetation_cover_change': round(veg_after - veg_before, 1),
            'bare_ground_before': round(bare_before, 1),
            'bare_ground_after': round(bare_after, 1),
            'bare_ground_change': round(bare_after - bare_before, 1),
            'year_before': year_before,
            'year_after': year_after,
            'class_changes': change_stats['changes']
        }
        return {
            'lulc_before': lulc_before,
            'lulc_after': lulc_after,
            'stats': summary,
            'classes': LULC_CLASSES
        }
    except Exception as e:
        # Best-effort: surface the failure to the caller without raising
        return {
            'error': str(e),
            'stats': {}
        }
def calculate_statistics(
    indices_before: Dict[str, "xr.DataArray"],
    indices_after: Dict[str, "xr.DataArray"],
    index_changes: Dict[str, "xr.DataArray"],
    bbox: Tuple[float, float, float, float]
) -> Dict[str, float]:
    """
    Calculate comprehensive vegetation and soil statistics.

    Args:
        indices_before: Dict of index arrays at the start date
        indices_after: Dict of index arrays at the end date
        index_changes: Dict of per-index change arrays (after - before)
        bbox: Bounding box; currently unused (area is derived from the pixel
              count at Sentinel-2's 10 m resolution) but kept for interface
              compatibility with existing callers.

    Returns:
        Dict with NDVI and secondary-index means/changes, improved/degraded/
        stable areas (ha) and percentages, water/bare-soil/moisture-stress
        extents, and a vegetation-density breakdown.
    """
    stats: Dict[str, float] = {}

    # Drop NaN (cloud / no-data) pixels once up front; every statistic below
    # operates on the filtered arrays, so plain mean/std suffice (the original
    # nanmean/nanstd on already-filtered data was redundant).
    valid_before = _valid_values(indices_before['ndvi'].values)
    valid_after = _valid_values(indices_after['ndvi'].values)
    valid_change = _valid_values(index_changes['ndvi'].values)

    # NDVI summary statistics
    stats['ndvi_before_mean'] = round(float(np.mean(valid_before)), 4) if len(valid_before) > 0 else 0
    stats['ndvi_after_mean'] = round(float(np.mean(valid_after)), 4) if len(valid_after) > 0 else 0
    stats['ndvi_change_mean'] = round(float(np.mean(valid_change)), 4) if len(valid_change) > 0 else 0
    stats['ndvi_change_std'] = round(float(np.std(valid_change)), 4) if len(valid_change) > 0 else 0

    # Relative NDVI change, guarded against a non-positive baseline
    if stats['ndvi_before_mean'] > 0:
        stats['percent_change'] = round(((stats['ndvi_after_mean'] - stats['ndvi_before_mean']) /
                                         stats['ndvi_before_mean']) * 100, 2)
    else:
        stats['percent_change'] = 0

    # Before/after means and change for the secondary indices
    for idx_name in ['savi', 'evi', 'ndwi', 'ndmi', 'bsi', 'nbr']:
        if idx_name in indices_before and idx_name in indices_after:
            valid_b = _valid_values(indices_before[idx_name].values)
            valid_a = _valid_values(indices_after[idx_name].values)
            stats[f'{idx_name}_before_mean'] = round(float(np.mean(valid_b)), 4) if len(valid_b) > 0 else 0
            stats[f'{idx_name}_after_mean'] = round(float(np.mean(valid_a)), 4) if len(valid_a) > 0 else 0
            stats[f'{idx_name}_change'] = round(stats[f'{idx_name}_after_mean'] - stats[f'{idx_name}_before_mean'], 4)

    # Area accounting: Sentinel-2 pixels are 10 m x 10 m = 0.01 ha.
    # |dNDVI| <= 0.05 is treated as "stable".
    pixel_area_ha = (10 * 10) / 10000
    total_pixels = len(valid_change)
    improved_pixels = np.sum(valid_change > 0.05)
    degraded_pixels = np.sum(valid_change < -0.05)
    stable_pixels = np.sum((valid_change >= -0.05) & (valid_change <= 0.05))
    stats['area_improved_ha'] = round(float(improved_pixels * pixel_area_ha), 2)
    stats['area_degraded_ha'] = round(float(degraded_pixels * pixel_area_ha), 2)
    stats['area_stable_ha'] = round(float(stable_pixels * pixel_area_ha), 2)
    stats['total_area_ha'] = round(float(total_pixels * pixel_area_ha), 2)

    # Percentages of valid pixels per change category
    if total_pixels > 0:
        stats['percent_improved'] = round((improved_pixels / total_pixels) * 100, 2)
        stats['percent_degraded'] = round((degraded_pixels / total_pixels) * 100, 2)
        stats['percent_stable'] = round((stable_pixels / total_pixels) * 100, 2)
    else:
        stats['percent_improved'] = 0
        stats['percent_degraded'] = 0
        stats['percent_stable'] = 0

    # Water presence: NDWI > 0 marks open water / saturated surfaces
    if 'ndwi' in indices_after:
        valid_ndwi = _valid_values(indices_after['ndwi'].values)
        water_pixels = np.sum(valid_ndwi > 0)
        stats['percent_water'] = round((water_pixels / len(valid_ndwi)) * 100, 2) if len(valid_ndwi) > 0 else 0

    # Bare soil extent: BSI > 0.1 marks exposed soil
    if 'bsi' in indices_after:
        valid_bsi = _valid_values(indices_after['bsi'].values)
        bare_pixels = np.sum(valid_bsi > 0.1)
        stats['percent_bare_soil'] = round((bare_pixels / len(valid_bsi)) * 100, 2) if len(valid_bsi) > 0 else 0

    # Moisture stress: NDMI < 0 marks water-stressed vegetation
    if 'ndmi' in indices_after:
        valid_ndmi = _valid_values(indices_after['ndmi'].values)
        stressed_pixels = np.sum(valid_ndmi < 0)
        stats['percent_moisture_stressed'] = round((stressed_pixels / len(valid_ndmi)) * 100, 2) if len(valid_ndmi) > 0 else 0

    # Vegetation density classes from the end-date NDVI
    if len(valid_after) > 0:
        sparse = np.sum((valid_after > 0) & (valid_after <= 0.2))
        low = np.sum((valid_after > 0.2) & (valid_after <= 0.4))
        moderate = np.sum((valid_after > 0.4) & (valid_after <= 0.6))
        dense = np.sum(valid_after > 0.6)
        stats['percent_sparse_veg'] = round((sparse / len(valid_after)) * 100, 2)
        stats['percent_low_veg'] = round((low / len(valid_after)) * 100, 2)
        stats['percent_moderate_veg'] = round((moderate / len(valid_after)) * 100, 2)
        stats['percent_dense_veg'] = round((dense / len(valid_after)) * 100, 2)

    return stats


def _valid_values(values: np.ndarray) -> np.ndarray:
    """Return only the finite (non-NaN) entries of an index array."""
    return values[~np.isnan(values)]
def calculate_reference_ndvi(
    bbox: Tuple[float, float, float, float],
    date: str,
    window_days: int = 15,
    cloud_threshold: int = 25,
    buffer_deg: float = 0.01
) -> float:
    """
    Mean NDVI of the reference area (a buffer around the site bbox).

    Returns 0 when no usable imagery is found or processing fails, so callers
    always receive a number.
    """
    fmt = '%Y-%m-%d'
    centre = datetime.strptime(date, fmt)
    window = timedelta(days=window_days)
    start = (centre - window).strftime(fmt)
    end = (centre + window).strftime(fmt)

    ref_bbox = create_reference_bbox(bbox, buffer_deg)
    try:
        composite = get_sentinel_composite(ref_bbox, start, end, cloud_threshold)
        ndvi_values = calculate_ndvi(composite).values
        finite = ndvi_values[~np.isnan(ndvi_values)]
        return float(np.nanmean(finite)) if len(finite) > 0 else 0
    except Exception:
        # Best-effort: treat any failure as "no reference signal"
        return 0
def calculate_rehab_score(site_ndvi: float, reference_ndvi: float) -> int:
    """
    Score rehabilitation progress on a 0-100 scale.

    The score is the site's mean NDVI expressed as a percentage of the
    reference (undisturbed) area's mean NDVI, clamped to [0, 100]. A
    non-positive reference yields 0, since there is no meaningful baseline.
    """
    if reference_ndvi <= 0:
        return 0
    ratio_pct = round((site_ndvi / reference_ndvi) * 100)
    return max(0, min(100, ratio_pct))
def calculate_comprehensive_rehab_score(
    stats: Dict[str, float],
    terrain_stats: Optional[Dict[str, float]] = None,
    land_cover_stats: Optional[Dict[str, float]] = None,
    reference_ndvi: float = 0.5
) -> Dict[str, Any]:
    """
    Calculate a comprehensive rehabilitation score from multiple metrics.

    Component scores (each clamped to the integer range 0-100):
      - vegetation_score: site NDVI relative to the reference NDVI
      - improvement_score: 50 + % improved - % degraded
      - soil_stability_score: 100 - % bare soil
      - moisture_score: 100 - % moisture-stressed vegetation
      - terrain_score (only if terrain_stats given): 100 - % high erosion risk
      - land_cover_score (only if land_cover_stats given): % vegetation cover

    Args:
        stats: Vegetation/soil statistics (see calculate_statistics).
        terrain_stats: Optional terrain statistics (see analyze_terrain).
        land_cover_stats: Optional land cover statistics.
        reference_ndvi: Mean NDVI of the reference (undisturbed) area.

    Returns:
        Dict with the component scores plus a weighted 'overall_score'.
    """
    def _clamp(value: float) -> int:
        # All component scores live on an integer 0-100 scale.
        return min(100, max(0, round(value)))

    scores: Dict[str, Any] = {}

    # Vegetation score: site NDVI as a fraction of the reference.
    # Guard against a non-positive reference (calculate_reference_ndvi returns
    # 0 on failure), which previously caused a ZeroDivisionError; 0 matches
    # calculate_rehab_score's handling of the same case.
    site_ndvi = stats.get('ndvi_after_mean', 0)
    if reference_ndvi > 0:
        scores['vegetation_score'] = _clamp((site_ndvi / reference_ndvi) * 100)
    else:
        scores['vegetation_score'] = 0

    # Net improvement, centred at 50 so a fully stable site scores mid-range
    improvement = stats.get('percent_improved', 0)
    degradation = stats.get('percent_degraded', 0)
    scores['improvement_score'] = _clamp(50 + improvement - degradation)

    # Less exposed soil is better
    bare_soil = stats.get('percent_bare_soil', 50)
    scores['soil_stability_score'] = _clamp(100 - bare_soil)

    # Less moisture stress is better
    moisture_stressed = stats.get('percent_moisture_stressed', 50)
    scores['moisture_score'] = _clamp(100 - moisture_stressed)

    # Optional components, included only when their inputs are available
    if terrain_stats:
        erosion_risk = terrain_stats.get('percent_high_erosion_risk', 50)
        scores['terrain_score'] = _clamp(100 - erosion_risk)
    if land_cover_stats:
        veg_cover = land_cover_stats.get('vegetation_cover_after', 0)
        scores['land_cover_score'] = _clamp(veg_cover)

    # Weighted average over the components that are actually present;
    # weights are renormalized by the total weight of present components.
    weights = {
        'vegetation_score': 0.30,
        'improvement_score': 0.25,
        'soil_stability_score': 0.20,
        'moisture_score': 0.10,
        'terrain_score': 0.10,
        'land_cover_score': 0.05
    }
    total_weight = 0
    weighted_sum = 0
    for key, weight in weights.items():
        if key in scores:
            weighted_sum += scores[key] * weight
            total_weight += weight
    scores['overall_score'] = round(weighted_sum / total_weight) if total_weight > 0 else 0
    return scores
def generate_interpretation(
    stats: Dict[str, float],
    rehab_score: int,
    terrain_stats: Optional[Dict] = None,
    land_cover_stats: Optional[Dict] = None
) -> str:
    """
    Generate a comprehensive plain-language interpretation of the analysis results.

    Args:
        stats: Vegetation/soil statistics (see calculate_statistics).
        rehab_score: Rehabilitation score, 0-100.
        terrain_stats: Optional terrain statistics (see analyze_terrain).
        land_cover_stats: Optional land cover statistics.

    Returns:
        A multi-sentence English summary string.
    """
    parts = []

    # --- Overall vegetation trend ---
    change = stats.get('percent_change', 0)
    if change > 10:
        trend = f"Vegetation cover has significantly improved by {change:.1f}%"
    elif change > 0:
        trend = f"Vegetation cover has moderately improved by {change:.1f}%"
    elif change > -10:
        trend = f"Vegetation cover has slightly declined by {abs(change):.1f}%"
    else:
        trend = f"Vegetation cover has significantly declined by {abs(change):.1f}%"
    parts.append(trend + " over the analysis period.")

    # --- Improved vs degraded area breakdown ---
    # Use .get throughout: the previous version indexed stats[...] directly
    # here and raised KeyError when the area statistics were missing.
    improved = stats.get('percent_improved', 0)
    degraded = stats.get('percent_degraded', 0)
    area_improved = stats.get('area_improved_ha', 0)
    area_degraded = stats.get('area_degraded_ha', 0)
    if improved > degraded:
        parts.append(f"Approximately {improved:.0f}% of the site "
                     f"({area_improved:.1f} ha) shows vegetation improvement, "
                     f"while {degraded:.0f}% ({area_degraded:.1f} ha) "
                     "shows decline.")
    else:
        parts.append(f"Approximately {degraded:.0f}% of the site "
                     f"({area_degraded:.1f} ha) shows vegetation decline, "
                     f"while {improved:.0f}% ({area_improved:.1f} ha) "
                     "shows improvement.")

    # --- Soil and moisture conditions ---
    bare_soil = stats.get('percent_bare_soil', 0)
    moisture_stress = stats.get('percent_moisture_stressed', 0)
    if bare_soil > 30:
        parts.append(f"Bare soil covers {bare_soil:.0f}% of the area, indicating potential erosion risk.")
    elif bare_soil > 10:
        parts.append(f"Moderate bare soil exposure ({bare_soil:.0f}%) is present.")
    if moisture_stress > 50:
        parts.append(f"Significant moisture stress detected in {moisture_stress:.0f}% of vegetation.")

    # --- Water presence ---
    water = stats.get('percent_water', 0)
    if water > 5:
        parts.append(f"Water bodies or saturated areas cover {water:.0f}% of the site.")

    # --- Terrain ---
    if terrain_stats:
        steep = terrain_stats.get('percent_steep', 0)
        erosion = terrain_stats.get('percent_high_erosion_risk', 0)
        if steep > 20:
            parts.append(f"The terrain includes {steep:.0f}% steep slopes (>30 degrees).")
        if erosion > 30:
            parts.append(f"High erosion risk identified in {erosion:.0f}% of the area.")

    # --- Land cover ---
    if land_cover_stats:
        veg_change = land_cover_stats.get('vegetation_cover_change', 0)
        bare_change = land_cover_stats.get('bare_ground_change', 0)
        if veg_change > 5:
            parts.append(f"Land cover analysis shows {veg_change:.0f}% increase in vegetated area.")
        elif veg_change < -5:
            parts.append(f"Land cover analysis shows {abs(veg_change):.0f}% decrease in vegetated area.")
        if bare_change < -5:
            parts.append(f"Bare ground has decreased by {abs(bare_change):.0f}%.")

    # --- Rehabilitation score banding ---
    if rehab_score >= 80:
        quality = "excellent rehabilitation progress"
    elif rehab_score >= 60:
        quality = "good rehabilitation progress"
    elif rehab_score >= 40:
        quality = "moderate rehabilitation progress"
    elif rehab_score >= 20:
        quality = "early-stage rehabilitation"
    else:
        quality = "limited rehabilitation progress to date"
    parts.append(f"The site has achieved {rehab_score}% of reference vegetation conditions, "
                 f"indicating {quality}.")

    return " ".join(parts)
def get_monthly_ndvi_timeseries(
    bbox: Tuple[float, float, float, float],
    start_year: int,
    end_year: int,
    cloud_threshold: int = 30
) -> List[Dict[str, Any]]:
    """
    Build a monthly mean-NDVI time series for a bounding box.

    For each calendar month a cloud-filtered composite is built and its mean
    NDVI recorded, stamped at mid-month. Months with no imagery, no valid
    pixels, or processing errors are silently skipped (best-effort series).

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        start_year: First year (inclusive)
        end_year: Last year (inclusive)
        cloud_threshold: Maximum scene cloud cover percentage

    Returns:
        Chronologically sorted list of {'date': 'YYYY-MM-15', 'ndvi': float}.
    """
    results = []
    now = datetime.now()  # hoisted: one clock read instead of one per month
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            # Skip months that have not happened yet
            if year > now.year or (year == now.year and month > now.month):
                continue
            start_date = f"{year}-{month:02d}-01"
            if month == 12:
                end_date = f"{year}-12-31"
            else:
                # Last day of the month = day before the 1st of the next month
                end_of_month = datetime(year, month + 1, 1) - timedelta(days=1)
                end_date = end_of_month.strftime('%Y-%m-%d')
            try:
                items = search_sentinel2(bbox, start_date, end_date, cloud_threshold)
                if len(items) > 0:
                    composite = get_sentinel_composite(
                        bbox, start_date, end_date, cloud_threshold
                    )
                    ndvi = calculate_ndvi(composite)
                    valid_ndvi = ndvi.values[~np.isnan(ndvi.values)]
                    if len(valid_ndvi) > 0:
                        results.append({
                            'date': f"{year}-{month:02d}-15",
                            'ndvi': float(np.nanmean(valid_ndvi))
                        })
            except Exception:
                # Best-effort: a failed month must not abort the whole series
                continue
    return sorted(results, key=lambda x: x['date'])
def get_multi_index_timeseries(
    bbox: Tuple[float, float, float, float],
    start_year: int,
    end_year: int,
    cloud_threshold: int = 30
) -> List[Dict[str, Any]]:
    """
    Build a monthly time series of mean values for the full index suite.

    Like get_monthly_ndvi_timeseries, but records the spatial mean of every
    index returned by calculate_all_indices. Months with no imagery, no valid
    pixels, or processing errors are silently skipped (best-effort series).

    Args:
        bbox: Bounding box (min_lon, min_lat, max_lon, max_lat)
        start_year: First year (inclusive)
        end_year: Last year (inclusive)
        cloud_threshold: Maximum scene cloud cover percentage

    Returns:
        Chronologically sorted list of records, each {'date': 'YYYY-MM-15',
        '<index>': float, ...} for the indices with valid pixels that month.
    """
    results = []
    now = datetime.now()  # hoisted: one clock read instead of one per month
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            # Skip months that have not happened yet
            if year > now.year or (year == now.year and month > now.month):
                continue
            start_date = f"{year}-{month:02d}-01"
            if month == 12:
                end_date = f"{year}-12-31"
            else:
                # Last day of the month = day before the 1st of the next month
                end_of_month = datetime(year, month + 1, 1) - timedelta(days=1)
                end_date = end_of_month.strftime('%Y-%m-%d')
            try:
                items = search_sentinel2(bbox, start_date, end_date, cloud_threshold)
                if len(items) > 0:
                    composite = get_sentinel_composite(
                        bbox, start_date, end_date, cloud_threshold
                    )
                    indices = calculate_all_indices(composite)
                    record = {'date': f"{year}-{month:02d}-15"}
                    for idx_name, idx_data in indices.items():
                        valid_vals = idx_data.values[~np.isnan(idx_data.values)]
                        if len(valid_vals) > 0:
                            record[idx_name] = float(np.nanmean(valid_vals))
                    # Keep only months where at least one index had valid data
                    if len(record) > 1:
                        results.append(record)
            except Exception:
                # Best-effort: a failed month must not abort the whole series
                continue
    return sorted(results, key=lambda x: x['date'])
def calculate_seasonal_stability(timeseries: List[Dict[str, Any]]) -> Dict[str, float]:
    """
    Calculate seasonal stability metrics from an NDVI time series.

    A lower coefficient of variation indicates more stable ecosystem function
    across seasons.

    Args:
        timeseries: Records as produced by get_monthly_ndvi_timeseries
                    (each relevant record carries an 'ndvi' key).

    Returns:
        Dict with mean, std, CV (%), min, max, and range of NDVI, or {} when
        fewer than 4 NDVI observations are available.
    """
    if len(timeseries) < 4:
        return {}
    # Records lacking an 'ndvi' key are ignored
    ndvi_values = [r['ndvi'] for r in timeseries if 'ndvi' in r]
    if len(ndvi_values) < 4:
        return {}
    mean = float(np.mean(ndvi_values))
    std = float(np.std(ndvi_values))
    # Guard: a zero mean makes the CV undefined (previously produced NaN)
    cv = (std / mean) * 100 if mean != 0 else 0.0
    return {
        'ndvi_mean': round(mean, 4),
        'ndvi_std': round(std, 4),
        'ndvi_cv': round(cv, 2),
        'ndvi_min': round(float(np.min(ndvi_values)), 4),
        'ndvi_max': round(float(np.max(ndvi_values)), 4),
        'ndvi_range': round(float(np.max(ndvi_values) - np.min(ndvi_values)), 4)
    }
def ndvi_to_image_array(ndvi: xr.DataArray) -> np.ndarray:
    """
    Convert an NDVI array to an RGB uint8 image for visualization.

    NDVI is linearly rescaled from the display range [-0.1, 0.8] to [0, 1]
    (clipped) and mapped through a brown-to-dark-green colormap.

    Args:
        ndvi: 2-D NDVI array.

    Returns:
        (H, W, 3) uint8 RGB array.
    """
    # Removed unused `import matplotlib.pyplot as plt`: only the colormap
    # class is needed here.
    from matplotlib.colors import LinearSegmentedColormap

    # Brown -> tan -> yellow -> greens: low to high vegetation vigour
    colors = ['#8B4513', '#D2B48C', '#FFFF00', '#90EE90', '#228B22', '#006400']
    cmap = LinearSegmentedColormap.from_list('ndvi', colors)
    ndvi_normalized = (ndvi.values - (-0.1)) / (0.8 - (-0.1))
    ndvi_normalized = np.clip(ndvi_normalized, 0, 1)
    rgba = cmap(ndvi_normalized)
    rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
    return rgb
def change_to_image_array(change: xr.DataArray) -> np.ndarray:
    """
    Convert an NDVI-change array to an RGB uint8 image for visualization.

    Change values are linearly rescaled from [-0.3, 0.3] to [0, 1] (clipped)
    and mapped through a red-white-green diverging colormap (red = decline,
    green = improvement).

    Args:
        change: 2-D NDVI change array (after - before).

    Returns:
        (H, W, 3) uint8 RGB array.
    """
    # Removed unused `import matplotlib.pyplot as plt`: only the colormap
    # class is needed here.
    from matplotlib.colors import LinearSegmentedColormap

    # Diverging red (loss) -> white (stable) -> green (gain)
    colors = ['#B71C1C', '#EF9A9A', '#FFFFFF', '#A5D6A7', '#1B5E20']
    cmap = LinearSegmentedColormap.from_list('change', colors)
    change_normalized = (change.values - (-0.3)) / (0.3 - (-0.3))
    change_normalized = np.clip(change_normalized, 0, 1)
    rgba = cmap(change_normalized)
    rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
    return rgb
def index_to_image_array(
    data: xr.DataArray,
    colormap: str = 'viridis',
    vmin: float = -1,
    vmax: float = 1
) -> np.ndarray:
    """
    Render any index array as an RGB uint8 image using a named colormap.

    Values are linearly rescaled from [vmin, vmax] to [0, 1] (clipped)
    before colour mapping.
    """
    import matplotlib.pyplot as plt

    cmap = plt.get_cmap(colormap)
    # Normalize into [0, 1] for the colormap, clipping out-of-range values
    scaled = np.clip((data.values - vmin) / (vmax - vmin), 0, 1)
    rgba = cmap(scaled)
    return (rgba[:, :, :3] * 255).astype(np.uint8)