# Author: Fahimeh Orvati Nia
# Commit b4123b8: "Add sorghum_pipeline code"
"""
Vegetation index extraction for the Sorghum Pipeline.
This module handles extraction of various vegetation indices
from multispectral data.
"""
import numpy as np
import cv2
from typing import Dict, Tuple, Optional, Any
import logging
logger = logging.getLogger(__name__)
class VegetationIndexExtractor:
    """Extracts vegetation indices from spectral data.

    Every index is described by two tables built in ``__init__``:
    ``index_formulas`` maps the index name to a callable, and ``index_bands``
    lists the band names whose arrays are passed to that callable,
    positionally and in that order.
    """

    def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
        """
        Initialize vegetation index extractor.

        Args:
            epsilon: Small value added to denominators to avoid division by
                zero. Numeric strings such as "1e-10" are accepted; values
                that cannot be converted fall back to 1e-10.
            soil_factor: Soil adjustment term used by OSAVI, GOSAVI and the
                OSAVI part of TCARIOSAVI. Values that cannot be converted
                fall back to 0.16.
        """
        # Coerce to float in case config passed strings like "1e-10".
        try:
            self.epsilon = float(epsilon)
        except (TypeError, ValueError):
            logger.warning("Invalid epsilon %r; using default 1e-10", epsilon)
            self.epsilon = 1e-10
        try:
            self.soil_factor = float(soil_factor)
        except (TypeError, ValueError):
            logger.warning("Invalid soil_factor %r; using default 0.16", soil_factor)
            self.soil_factor = 0.16

        # The lambdas read self.epsilon / self.soil_factor at call time, so
        # changing those attributes later affects subsequent computations.
        # Extra keyword parameters (TSAVI, GSAVI, WDVI, PVI) hold fixed
        # coefficients and are never passed by compute_vegetation_indices.
        self.index_formulas = {
            "NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
            "GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
            "NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
            "GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
            "TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
            "MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
            "GRVI": lambda nir, green: nir / (green + self.epsilon),
            "NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
            "MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
            "OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
            "TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
            "GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
            # Requested additions and aliases
            # GOSAVI is the green-band analogue of OSAVI, so it shares soil_factor.
            "GOSAVI": lambda nir, green: (nir - green) / (nir + green + self.soil_factor + self.epsilon),
            "GDVI": lambda nir, green: nir - green,
            "NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
            "DSWI4": lambda green, red: green / (red + self.epsilon),
            "CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
            # NOTE(review): identical to NDRE. Datt's LCI uses red in the
            # denominator, (nir - red_edge) / (nir + red) — confirm intent.
            "LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
            "CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
            "MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
            "MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
            # MCARI2 shares MTVI2's denominator, including the "- 0.5" term.
            "MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
            # MTVI variants per request
            "MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
            "MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
            "CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
            "ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
            "ARI2": lambda nir, green, red_edge: nir * (1.0 / (green + self.epsilon)) - nir * (1.0 / (red_edge + self.epsilon)),
            "DVI": lambda nir, red: nir - red,
            "WDVI": lambda nir, red, a=0.5: nir - a * red,
            "SR": lambda nir, red: nir / (red + self.epsilon),
            "MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
            "PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
            "GEMI": lambda nir, red: ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)) * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))) - ((red - 0.125) / (1 - red + self.epsilon)),
            "ExR": lambda red, green: 1.3 * red - green,
            "RI": lambda red, green: (red - green) / (red + green + self.epsilon),
            "RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
            "RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
            # Alias of RRI1.
            "RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
            "AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
            "SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
            "TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
            # TCARI / OSAVI, with OSAVI = (1 + s)(nir - red) / (nir + red + s).
            "TCARIOSAVI": lambda red_edge, red, green, nir: (
                3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon)))
            ) / (
                (1 + self.soil_factor) * (nir - red) / (nir + red + self.soil_factor + self.epsilon)
                + self.epsilon
            ),
            "CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
            # Additional indices
            "RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
            "NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
            "BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
            "IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
            # Two-band EVI: 2.5 (NIR - Red) / (NIR + 2.4 Red + 1).
            "EVI2": lambda nir, red: 2.5 * (nir - red) / (nir + 2.4 * red + 1.0 + self.epsilon),
        }

        # Required bands for each index, in the positional order the
        # corresponding formula expects them.
        self.index_bands = {
            "NDVI": ["nir", "red"],
            "GNDVI": ["nir", "green"],
            "NDRE": ["nir", "red_edge"],
            "GRNDVI": ["nir", "green", "red"],
            "TNDVI": ["nir", "red"],
            "MGRVI": ["green", "red"],
            "GRVI": ["nir", "green"],
            "NGRDI": ["green", "red"],
            "MSAVI": ["nir", "red"],
            "OSAVI": ["nir", "red"],
            "TSAVI": ["nir", "red"],
            "GSAVI": ["nir", "green"],
            "GOSAVI": ["nir", "green"],
            "GDVI": ["nir", "green"],
            "NDWI": ["green", "nir"],
            "DSWI4": ["green", "red"],
            "CIRE": ["nir", "red_edge"],
            "LCI": ["nir", "red_edge"],
            "CIgreen": ["nir", "green"],
            "MCARI": ["red_edge", "red", "green"],
            "MCARI1": ["nir", "red", "green"],
            "MCARI2": ["nir", "red", "green"],
            "MTVI1": ["nir", "red", "green"],
            "MTVI2": ["nir", "red", "green"],
            "CVI": ["nir", "red", "green"],
            "ARI": ["green", "red_edge"],
            "ARI2": ["nir", "green", "red_edge"],
            "DVI": ["nir", "red"],
            "WDVI": ["nir", "red"],
            "SR": ["nir", "red"],
            "MSR": ["nir", "red"],
            "PVI": ["nir", "red"],
            "GEMI": ["nir", "red"],
            "ExR": ["red", "green"],
            "RI": ["red", "green"],
            "RRI1": ["nir", "red_edge"],
            "RRI2": ["red_edge", "red"],
            "RRI": ["nir", "red_edge"],
            "AVI": ["nir", "red"],
            "SIPI2": ["nir", "green", "red"],
            "TCARI": ["red_edge", "red", "green"],
            "TCARIOSAVI": ["red_edge", "red", "green", "nir"],
            "CCCI": ["nir", "red_edge", "red"],
            "RDVI": ["nir", "red"],
            "NLI": ["nir", "red"],
            "BIXS": ["green", "red"],
            "IPVI": ["nir", "red"],
            "EVI2": ["nir", "red"],
        }

    @staticmethod
    def _as_float_band(arr: Any) -> np.ndarray:
        """Return *arr* as a float64 array, dropping a trailing size-1 channel axis."""
        band = np.asarray(arr, dtype=np.float64)
        # Bands may arrive as (H, W, 1). Only strip that axis when it really is
        # size one; an unconditional squeeze(-1) raises for plain (H, W) bands.
        if band.ndim > 2 and band.shape[-1] == 1:
            band = band[..., 0]
        return band

    @staticmethod
    def _summarize(masked_values: np.ndarray) -> Dict[str, float]:
        """Summary statistics over the finite entries of *masked_values*.

        NaN marks masked-out pixels. Infinite values (e.g. from a zero
        denominator) are also excluded so they cannot poison mean/std.
        'nan_fraction' is the share of entries excluded for either reason.
        """
        valid_values = masked_values[np.isfinite(masked_values)]
        if valid_values.size == 0:
            return {
                'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
                'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
            }
        q25, q75 = np.percentile(valid_values, [25, 75])
        excluded = masked_values.size - valid_values.size
        return {
            'mean': float(np.mean(valid_values)),
            'std': float(np.std(valid_values)),
            'min': float(np.min(valid_values)),
            'max': float(np.max(valid_values)),
            'median': float(np.median(valid_values)),
            'q25': float(q25),
            'q75': float(q75),
            'nan_fraction': float(excluded / masked_values.size),
        }

    def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
                                   mask: Optional[np.ndarray]) -> Dict[str, Dict[str, Any]]:
        """
        Compute vegetation indices from spectral data.

        Args:
            spectral_stack: Mapping of band name ("nir", "red", "green",
                "red_edge", ...) to band array. Both (H, W) and (H, W, 1)
                arrays are accepted.
            mask: Plant mask. Pixels whose value is > 0 after an int32 cast
                are kept; the rest become NaN. ``None`` keeps every pixel.

        Returns:
            Mapping of index name to ``{'values': float64 array,
            'statistics': dict}``. Indices whose bands are missing are skipped
            with a warning; indices whose computation raises are logged and
            omitted.
        """
        # Binarize the mask once instead of once per index.
        binary_mask = None
        if mask is not None:
            binary_mask = np.asarray(mask).astype(np.int32) > 0

        indices: Dict[str, Dict[str, Any]] = {}
        for index_name, formula in self.index_formulas.items():
            required_bands = self.index_bands.get(index_name, [])
            if not all(band in spectral_stack for band in required_bands):
                logger.warning("Skipping %s: missing required bands", index_name)
                continue
            try:
                band_data = [self._as_float_band(spectral_stack[band]) for band in required_bands]
                # Background pixels routinely produce zero denominators and
                # negative radicands; the resulting NaN/inf are handled below,
                # so silence NumPy's per-call RuntimeWarnings.
                with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                    index_values = np.asarray(formula(*band_data), dtype=np.float64)
                if binary_mask is not None:
                    masked_values = np.where(binary_mask, index_values, np.nan)
                else:
                    masked_values = index_values
                indices[index_name] = {
                    'values': masked_values,
                    'statistics': self._summarize(masked_values),
                }
                logger.debug("Computed %s", index_name)
            except Exception as e:  # best-effort: one bad index must not abort the rest
                logger.error("Failed to compute %s: %s", index_name, e)
        return indices

    def create_vegetation_index_image(self, index_values: np.ndarray,
                                      colormap: str = 'RdYlGn',
                                      vmin: Optional[float] = None,
                                      vmax: Optional[float] = None) -> np.ndarray:
        """
        Create visualization image for vegetation index.

        Args:
            index_values: Vegetation index values; NaN/inf pixels render white.
            colormap: Matplotlib colormap name.
            vmin: Minimum value for normalization (default: finite minimum).
            vmax: Maximum value for normalization (default: finite maximum).

        Returns:
            RGB uint8 image with shape ``index_values.shape + (3,)``. On any
            failure (including Matplotlib being unavailable) an all-zero image
            of that shape is returned and the error is logged.
        """
        try:
            import matplotlib
            from matplotlib.colors import Normalize

            values = np.asarray(index_values, dtype=np.float64)
            finite = np.isfinite(values)
            if not finite.any():
                return np.zeros((*values.shape, 3), dtype=np.uint8)
            if vmin is None:
                vmin = float(values[finite].min())
            if vmax is None:
                vmax = float(values[finite].max())

            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the
            # colormap registry (3.5+) is the supported lookup.
            registry = getattr(matplotlib, "colormaps", None)
            if registry is not None:
                cmap = registry[colormap]
            else:
                import matplotlib.cm as cm
                cmap = cm.get_cmap(colormap)

            rgba_img = cmap(Normalize(vmin=vmin, vmax=vmax)(values))
            rgba_img[~finite] = (1.0, 1.0, 1.0, 1.0)  # White for missing/invalid pixels
            return (rgba_img[..., :3] * 255).astype(np.uint8)
        except Exception as e:
            logger.error(f"Failed to create vegetation index image: {e}")
            return np.zeros((*np.shape(index_values), 3), dtype=np.uint8)

    def get_available_indices(self) -> list:
        """Return the names of all supported vegetation indices (a new list)."""
        return list(self.index_formulas)

    def get_index_requirements(self, index_name: str) -> list:
        """
        Get required bands for a specific index.

        Args:
            index_name: Name of the vegetation index.

        Returns:
            A copy of the required band names (empty for unknown indices), so
            callers cannot mutate the extractor's internal table.
        """
        return list(self.index_bands.get(index_name, []))

    def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
        """
        Validate spectral data for vegetation index computation.

        Args:
            spectral_stack: Mapping of band name to band array.

        Returns:
            True if the nir, red, green and red_edge bands are present and all
            bands share one shape; False otherwise.
        """
        if not spectral_stack:
            return False

        required_bands = ['nir', 'red', 'green', 'red_edge']
        if not all(band in spectral_stack for band in required_bands):
            logger.warning("Missing required spectral bands")
            return False

        # np.shape also works for array-likes that lack a .shape attribute.
        shapes = [np.shape(arr) for arr in spectral_stack.values()]
        if any(shape != shapes[0] for shape in shapes):
            logger.warning("Inconsistent spectral band shapes")
            return False

        return True