File size: 14,553 Bytes
b4123b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 |
"""
Vegetation index extraction for the Sorghum Pipeline.
This module handles extraction of various vegetation indices
from multispectral data.
"""
import numpy as np
import cv2
from typing import Dict, Tuple, Optional, Any
import logging
logger = logging.getLogger(__name__)
class VegetationIndexExtractor:
    """Extracts vegetation indices from spectral data.

    Attributes:
        epsilon: Small constant added to denominators to avoid division by zero.
        soil_factor: Soil adjustment term used by the OSAVI formula.
        index_formulas: Maps an index name to a callable that takes the band
            arrays positionally, in the order listed in ``index_bands``.
        index_bands: Maps an index name to the list of band names (keys of the
            spectral stack) its formula needs.
    """

    def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
        """
        Initialize vegetation index extractor.

        Args:
            epsilon: Small value to avoid division by zero
            soil_factor: Soil factor for certain indices
        """
        # Coerce to float in case config passed strings like "1e-10";
        # unusable values fall back to the documented defaults.
        self.epsilon = self._coerce_float(epsilon, 1e-10, "epsilon")
        self.soil_factor = self._coerce_float(soil_factor, 0.16, "soil_factor")

        # Define vegetation index formulas.
        # The lambdas read self.epsilon / self.soil_factor when they are
        # called, so later changes to those attributes take effect.
        self.index_formulas = {
            "NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
            "GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
            "NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
            "GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
            "TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
            "MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
            "GRVI": lambda nir, green: nir / (green + self.epsilon),
            "NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
            "MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
            "OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
            "TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
            "GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
            # Requested additions and aliases
            # NOTE(review): GOSAVI and TCARIOSAVI hard-code 0.16 instead of
            # using self.soil_factor like OSAVI does -- confirm this is intended.
            "GOSAVI": lambda nir, green: (nir - green) / (nir + green + 0.16 + self.epsilon),
            "GDVI": lambda nir, green: nir - green,
            "NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
            "DSWI4": lambda green, red: green / (red + self.epsilon),
            "CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
            # NOTE(review): LCI is computed with exactly the same expression
            # as NDRE above -- confirm against the intended LCI definition.
            "LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
            "CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
            "MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
            "MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
            # NOTE(review): unlike MTVI2 below, the radicand here has no
            # "- 0.5" term (and no epsilon) -- confirm which form is intended.
            "MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon))),
            # MTVI variants per request
            "MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
            "MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
            "CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
            "ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
            "ARI2": lambda nir, green, red_edge: nir * (1.0 / (green + self.epsilon)) - nir * (1.0 / (red_edge + self.epsilon)),
            "DVI": lambda nir, red: nir - red,
            "WDVI": lambda nir, red, a=0.5: nir - a * red,
            "SR": lambda nir, red: nir / (red + self.epsilon),
            "MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
            "PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
            # GEMI = eta * (1 - 0.25 * eta) - (red - 0.125) / (1 - red), where
            # eta is the first parenthesised quotient (written out twice).
            "GEMI": lambda nir, red: (
                ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))
                * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)))
                - ((red - 0.125) / (1 - red + self.epsilon))
            ),
            "ExR": lambda red, green: 1.3 * red - green,
            "RI": lambda red, green: (red - green) / (red + green + self.epsilon),
            "RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
            "RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
            # RRI uses the same expression as RRI1.
            "RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
            "AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
            "SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
            "TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
            "TCARIOSAVI": lambda red_edge, red, green, nir: (3 * (red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))) / (1 + 0.16 * ((nir - red) / (nir + red + 0.16 + self.epsilon))),
            "CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
            # Additional indices
            "RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
            "NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
            "BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
            "IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
            "EVI2": lambda nir, red: 2.4 * (nir - red) / (nir + red + 1.0 + self.epsilon),
        }

        # Define required bands for each index. The order of each list must
        # match the positional parameter order of the corresponding formula.
        self.index_bands = {
            "NDVI": ["nir", "red"],
            "GNDVI": ["nir", "green"],
            "NDRE": ["nir", "red_edge"],
            "GRNDVI": ["nir", "green", "red"],
            "TNDVI": ["nir", "red"],
            "MGRVI": ["green", "red"],
            "GRVI": ["nir", "green"],
            "NGRDI": ["green", "red"],
            "MSAVI": ["nir", "red"],
            "OSAVI": ["nir", "red"],
            "TSAVI": ["nir", "red"],
            "GSAVI": ["nir", "green"],
            "GOSAVI": ["nir", "green"],
            "GDVI": ["nir", "green"],
            "NDWI": ["green", "nir"],
            "DSWI4": ["green", "red"],
            "CIRE": ["nir", "red_edge"],
            "LCI": ["nir", "red_edge"],
            "CIgreen": ["nir", "green"],
            "MCARI": ["red_edge", "red", "green"],
            "MCARI1": ["nir", "red", "green"],
            "MCARI2": ["nir", "red", "green"],
            "MTVI1": ["nir", "red", "green"],
            "MTVI2": ["nir", "red", "green"],
            "CVI": ["nir", "red", "green"],
            "ARI": ["green", "red_edge"],
            "ARI2": ["nir", "green", "red_edge"],
            "DVI": ["nir", "red"],
            "WDVI": ["nir", "red"],
            "SR": ["nir", "red"],
            "MSR": ["nir", "red"],
            "PVI": ["nir", "red"],
            "GEMI": ["nir", "red"],
            "ExR": ["red", "green"],
            "RI": ["red", "green"],
            "RRI1": ["nir", "red_edge"],
            "RRI2": ["red_edge", "red"],
            "RRI": ["nir", "red_edge"],
            "AVI": ["nir", "red"],
            "SIPI2": ["nir", "green", "red"],
            "TCARI": ["red_edge", "red", "green"],
            "TCARIOSAVI": ["red_edge", "red", "green", "nir"],
            "CCCI": ["nir", "red_edge", "red"],
            "RDVI": ["nir", "red"],
            "NLI": ["nir", "red"],
            "BIXS": ["green", "red"],
            "IPVI": ["nir", "red"],
            "EVI2": ["nir", "red"],
        }

    @staticmethod
    def _coerce_float(value: Any, default: float, name: str) -> float:
        """Return ``float(value)``, or ``default`` (with a warning) if that fails."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            logger.warning("Invalid %s=%r; falling back to %s", name, value, default)
            return default

    @staticmethod
    def _drop_trailing_singleton(arr: np.ndarray) -> np.ndarray:
        """Drop the last axis only when it is a length-1 axis of a >2-D array.

        An unconditional ``squeeze(-1)`` raises ValueError for plain 2-D
        images, which previously made every index fail for such input.
        """
        if arr.ndim > 2 and arr.shape[-1] == 1:
            return arr[..., 0]
        return arr

    @classmethod
    def _as_float_band(cls, band: Any) -> np.ndarray:
        """Convert one band to a float64 array without a trailing singleton axis."""
        return cls._drop_trailing_singleton(np.asarray(band, dtype=np.float64))

    @staticmethod
    def _summarize(masked_values: np.ndarray) -> Dict[str, float]:
        """Compute summary statistics over the non-NaN entries of an array."""
        nan_mask = np.isnan(masked_values)
        valid_values = masked_values[~nan_mask]
        if valid_values.size == 0:
            # Nothing to summarize: report zeros and a fully-NaN fraction.
            return {
                'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
                'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
            }
        q25, q75 = np.percentile(valid_values, [25, 75])
        return {
            'mean': float(np.mean(valid_values)),
            'std': float(np.std(valid_values)),
            'min': float(np.min(valid_values)),
            'max': float(np.max(valid_values)),
            'median': float(np.median(valid_values)),
            'q25': float(q25),
            'q75': float(q75),
            'nan_fraction': float(nan_mask.sum() / masked_values.size)
        }

    def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
                                   mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
        """
        Compute vegetation indices from spectral data.

        Indices whose required bands are missing are skipped with a warning;
        an index whose computation raises is logged and omitted.

        Args:
            spectral_stack: Dictionary of spectral bands keyed by band name.
                Each band may carry a trailing length-1 channel axis, which
                is dropped before computing.
            mask: Binary mask for the plant (values > 0 are kept), or None
                to keep every pixel.

        Returns:
            Dictionary mapping index name to ``{'values': array, 'statistics':
            dict}``. Pixels outside the mask are NaN in ``values``; statistics
            are computed over non-NaN pixels only.
        """
        indices: Dict[str, Dict[str, Any]] = {}

        # The mask does not depend on the index, so convert it only once.
        binary_mask = None
        if mask is not None:
            try:
                binary_mask = (
                    self._drop_trailing_singleton(np.asarray(mask)).astype(np.int32) > 0
                )
            except (TypeError, ValueError) as e:
                logger.error("Invalid mask, no indices computed: %s", e)
                return indices

        for index_name, formula in self.index_formulas.items():
            try:
                # Get required bands
                required_bands = self.index_bands.get(index_name, [])

                # Check if all required bands are available
                if not all(band in spectral_stack for band in required_bands):
                    logger.warning("Skipping %s: missing required bands", index_name)
                    continue

                # Extract band data as float arrays
                band_data = [self._as_float_band(spectral_stack[band])
                             for band in required_bands]

                # Compute index (ensure float math). NaN/inf from degenerate
                # pixels are expected and reported through 'nan_fraction',
                # so silence NumPy's floating-point warnings here.
                with np.errstate(divide='ignore', invalid='ignore'):
                    index_values = np.asarray(formula(*band_data), dtype=np.float64)

                # Apply mask
                if binary_mask is not None:
                    masked_values = np.where(binary_mask, index_values, np.nan)
                else:
                    masked_values = index_values

                indices[index_name] = {
                    'values': masked_values,
                    'statistics': self._summarize(masked_values)
                }
                logger.debug("Computed %s", index_name)
            except Exception as e:
                # Best effort: one failing index must not abort the others.
                logger.error("Failed to compute %s: %s", index_name, e)
                continue
        return indices

    def create_vegetation_index_image(self, index_values: np.ndarray,
                                      colormap: str = 'RdYlGn',
                                      vmin: Optional[float] = None,
                                      vmax: Optional[float] = None) -> np.ndarray:
        """
        Create visualization image for vegetation index.

        Args:
            index_values: Vegetation index values (NaN pixels are drawn white)
            colormap: Matplotlib colormap name
            vmin: Minimum value for normalization (default: data minimum)
            vmax: Maximum value for normalization (default: data maximum)

        Returns:
            RGB uint8 image array. An all-zero image is returned when there
            are no valid values or when rendering fails.
        """
        index_values = np.asarray(index_values, dtype=np.float64)
        try:
            import matplotlib
            from matplotlib.colors import Normalize

            # Determine value range
            nan_mask = np.isnan(index_values)
            valid_values = index_values[~nan_mask]
            if valid_values.size == 0:
                return np.zeros((*index_values.shape, 3), dtype=np.uint8)
            if vmin is None:
                vmin = np.min(valid_values)
            if vmax is None:
                vmax = np.max(valid_values)

            # Normalize values
            norm = Normalize(vmin=vmin, vmax=vmax)
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer the
            # colormap registry and fall back only on releases that lack it.
            try:
                cmap = matplotlib.colormaps[colormap]
            except AttributeError:
                from matplotlib import cm
                cmap = cm.get_cmap(colormap)

            # Apply colormap
            rgba_img = cmap(norm(index_values))
            rgba_img[nan_mask] = [1, 1, 1, 1]  # White for NaN

            # Convert to RGB uint8
            return (rgba_img[:, :, :3] * 255).astype(np.uint8)
        except Exception as e:
            logger.error("Failed to create vegetation index image: %s", e)
            return np.zeros((*index_values.shape, 3), dtype=np.uint8)

    def get_available_indices(self) -> list:
        """Get list of available vegetation indices."""
        return list(self.index_formulas)

    def get_index_requirements(self, index_name: str) -> list:
        """
        Get required bands for a specific index.

        Args:
            index_name: Name of the vegetation index

        Returns:
            List of required band names (empty for an unknown index)
        """
        return self.index_bands.get(index_name, [])

    def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
        """
        Validate spectral data for vegetation index computation.

        Args:
            spectral_stack: Dictionary of spectral bands

        Returns:
            True if the 'nir', 'red', 'green' and 'red_edge' bands are all
            present and every band in the stack has the same shape,
            False otherwise
        """
        if not spectral_stack:
            return False

        required_bands = ('nir', 'red', 'green', 'red_edge')
        if not all(band in spectral_stack for band in required_bands):
            logger.warning("Missing required spectral bands")
            return False

        # Check data shapes (np.shape also accepts non-ndarray sequences)
        shapes = {np.shape(arr) for arr in spectral_stack.values()}
        if len(shapes) > 1:
            logger.warning("Inconsistent spectral band shapes")
            return False
        return True
|