Fahimeh Orvati Nia
committed on
Commit
·
b4123b8
1
Parent(s):
4768cde
Add sorghum_pipeline code
Browse files- sorghum_pipeline/__init__.py +31 -0
- sorghum_pipeline/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/__pycache__/config.cpython-312.pyc +0 -0
- sorghum_pipeline/__pycache__/pipeline.cpython-312.pyc +0 -0
- sorghum_pipeline/config.py +249 -0
- sorghum_pipeline/data/__init__.py +15 -0
- sorghum_pipeline/data/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/data/__pycache__/loader.cpython-312.pyc +0 -0
- sorghum_pipeline/data/__pycache__/mask_handler.cpython-312.pyc +0 -0
- sorghum_pipeline/data/__pycache__/preprocessor.cpython-312.pyc +0 -0
- sorghum_pipeline/data/loader.py +444 -0
- sorghum_pipeline/data/mask_handler.py +296 -0
- sorghum_pipeline/data/preprocessor.py +279 -0
- sorghum_pipeline/features/__init__.py +21 -0
- sorghum_pipeline/features/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/features/__pycache__/morphology.cpython-312.pyc +0 -0
- sorghum_pipeline/features/__pycache__/spectral.cpython-312.pyc +0 -0
- sorghum_pipeline/features/__pycache__/texture.cpython-312.pyc +0 -0
- sorghum_pipeline/features/__pycache__/vegetation.cpython-312.pyc +0 -0
- sorghum_pipeline/features/morphology.py +380 -0
- sorghum_pipeline/features/spectral.py +383 -0
- sorghum_pipeline/features/texture.py +373 -0
- sorghum_pipeline/features/vegetation.py +308 -0
- sorghum_pipeline/models/__init__.py +10 -0
- sorghum_pipeline/models/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/models/__pycache__/dbc_lacunarity.cpython-312.pyc +0 -0
- sorghum_pipeline/models/dbc_lacunarity.py +90 -0
- sorghum_pipeline/output/__init__.py +13 -0
- sorghum_pipeline/output/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/output/__pycache__/manager.cpython-312.pyc +0 -0
- sorghum_pipeline/output/manager.py +688 -0
- sorghum_pipeline/pipeline.py +1377 -0
- sorghum_pipeline/segmentation/__init__.py +12 -0
- sorghum_pipeline/segmentation/__pycache__/__init__.cpython-312.pyc +0 -0
- sorghum_pipeline/segmentation/__pycache__/advanced_occlusion_handler.cpython-312.pyc +0 -0
- sorghum_pipeline/segmentation/__pycache__/leaf_occlusion_handler.cpython-312.pyc +0 -0
- sorghum_pipeline/segmentation/__pycache__/manager.cpython-312.pyc +0 -0
- sorghum_pipeline/segmentation/__pycache__/occlusion_handler.cpython-312.pyc +0 -0
- sorghum_pipeline/segmentation/manager.py +309 -0
sorghum_pipeline/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sorghum Plant Phenotyping Pipeline
|
| 3 |
+
|
| 4 |
+
A comprehensive pipeline for analyzing sorghum plant images including:
|
| 5 |
+
- Data loading and preprocessing
|
| 6 |
+
- Image segmentation and masking
|
| 7 |
+
- Feature extraction (texture, morphology, vegetation indices)
|
| 8 |
+
- Results visualization and export
|
| 9 |
+
|
| 10 |
+
Author: Fahime Horvatinia
|
| 11 |
+
Version: 2.0.0
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
__version__ = "2.0.0"
|
| 15 |
+
__author__ = "Fahime Horvatinia"
|
| 16 |
+
|
| 17 |
+
from .pipeline import SorghumPipeline
|
| 18 |
+
from .config import Config
|
| 19 |
+
from .data import DataLoader
|
| 20 |
+
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
|
| 21 |
+
from .output import OutputManager
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"SorghumPipeline",
|
| 25 |
+
"Config",
|
| 26 |
+
"DataLoader",
|
| 27 |
+
"TextureExtractor",
|
| 28 |
+
"VegetationIndexExtractor",
|
| 29 |
+
"MorphologyExtractor",
|
| 30 |
+
"OutputManager"
|
| 31 |
+
]
|
sorghum_pipeline/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (943 Bytes). View file
|
|
|
sorghum_pipeline/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
sorghum_pipeline/__pycache__/pipeline.cpython-312.pyc
ADDED
|
Binary file (66.9 kB). View file
|
|
|
sorghum_pipeline/config.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles all configuration settings, paths, and parameters
|
| 5 |
+
used throughout the pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import yaml
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Any, Optional
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class Paths:
    """Filesystem locations used by the pipeline.

    The two required folders are always converted to absolute paths; the
    optional ones are converted only when they hold a non-empty value.
    """
    input_folder: str
    output_folder: str
    boundingbox_dir: Optional[str] = None
    labels_folder: Optional[str] = None

    def __post_init__(self):
        """Normalize every provided path to its absolute form."""
        for required in ("input_folder", "output_folder"):
            setattr(self, required, os.path.abspath(getattr(self, required)))

        for optional in ("boundingbox_dir", "labels_folder"):
            current = getattr(self, optional)
            # None and "" are left untouched so "not configured" stays falsy.
            if current:
                setattr(self, optional, os.path.abspath(current))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class ProcessingParams:
    """Tunable numeric parameters for the image-processing stages.

    Field order is part of the positional constructor signature and must
    not be rearranged.
    """
    # --- general image processing ---
    target_size: tuple = (1024, 1024)
    gaussian_blur_kernel: int = 5
    morphology_kernel_size: int = 7
    min_component_area: int = 1000

    # --- segmentation ---
    segmentation_threshold: float = 0.5
    max_components: int = 10

    # --- texture analysis ---
    lbp_points: int = 8
    lbp_radius: int = 1
    hog_orientations: int = 9
    hog_pixels_per_cell: tuple = (8, 8)
    hog_cells_per_block: tuple = (2, 2)
    lacunarity_window: int = 15
    ehd_threshold: float = 0.3
    angle_resolution: int = 45

    # --- vegetation indices ---
    epsilon: float = 1e-10
    soil_factor: float = 0.16

    # --- morphology ---
    pixel_to_cm: float = 0.1099609375
    # default_factory gives every instance its own list (no shared mutable default).
    prune_sizes: list = field(default_factory=lambda: [200, 100, 50, 30, 10])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
class OutputSettings:
    """Switches and naming conventions for generated outputs."""
    # What to write and at which resolution/format.
    save_images: bool = True
    save_plots: bool = True
    save_metadata: bool = True
    image_dpi: int = 150
    plot_dpi: int = 100
    image_format: str = "png"

    # Names of the result subdirectories.
    segmentation_dir: str = "segmentation"
    features_dir: str = "features"
    texture_dir: str = "texture"
    morphology_dir: str = "morphology"
    vegetation_dir: str = "vegetation_indices"
    analysis_dir: str = "analysis"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
class ModelSettings:
    """Options controlling which ML model is loaded and where it runs."""
    # One of "auto", "cpu" or "cuda".
    device: str = "auto"
    model_name: str = "briaai/RMBG-2.0"
    batch_size: int = 1
    trust_remote_code: bool = True
    cache_dir: str = ""
    local_files_only: bool = False
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Config:
    """Main configuration class for the Sorghum Pipeline.

    Holds four settings sections (``paths``, ``processing``, ``output`` and
    ``model``) and can load them from / save them to a YAML file.
    """

    # Sections that are plain attribute bags updated key-by-key on load.
    _ATTRIBUTE_SECTIONS = ('processing', 'output', 'model')

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize configuration.

        Args:
            config_path: Path to YAML configuration file. If None, uses defaults.

        Raises:
            FileNotFoundError: If ``config_path`` is given but does not exist.
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        # NOTE: Paths.__post_init__ applies os.path.abspath to the two folder
        # values, so these empty placeholders resolve to the working directory.
        self.paths = Paths(
            input_folder="",
            output_folder="",
            boundingbox_dir=""
        )
        self.processing = ProcessingParams()
        self.output = OutputSettings()
        self.model = ModelSettings()

        if config_path:
            self.load_from_file(config_path)

    @staticmethod
    def _apply_section(target: Any, values: Optional[Dict[str, Any]]) -> None:
        """Copy known keys from ``values`` onto ``target``; ignore unknown keys."""
        if not values:
            return
        for key, value in values.items():
            if not hasattr(target, key):
                continue
            # YAML has no tuple type: restore tuple-valued settings
            # (e.g. target_size) that were parsed as lists.
            if isinstance(getattr(target, key), tuple) and isinstance(value, list):
                value = tuple(value)
            setattr(target, key, value)

    @staticmethod
    def _yaml_friendly(section: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``section`` with tuple values converted to plain lists."""
        return {
            key: list(value) if isinstance(value, tuple) else value
            for key, value in section.items()
        }

    def load_from_file(self, config_path: str) -> None:
        """Load configuration from YAML file.

        Sections missing from the file keep their current values. Keys that
        do not match an existing setting are ignored.

        Args:
            config_path: Location of the YAML file.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the top-level YAML value is not a mapping.
        """
        config_path = Path(config_path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)

        # safe_load returns None for an empty document; treat it as "no overrides".
        if config_data is None:
            config_data = {}
        if not isinstance(config_data, dict):
            raise ValueError(
                f"Configuration file must contain a YAML mapping: {config_path}"
            )

        # Paths is rebuilt so that __post_init__ normalizes the new values.
        if config_data.get('paths'):
            self.paths = Paths(**config_data['paths'])

        for section_name in self._ATTRIBUTE_SECTIONS:
            self._apply_section(
                getattr(self, section_name), config_data.get(section_name)
            )

    def save_to_file(self, config_path: str) -> None:
        """Save current configuration to YAML file.

        The output is written with ``yaml.safe_dump`` and tuples are stored
        as lists, so the file can be read back by :meth:`load_from_file`
        (which uses ``yaml.safe_load``).

        Args:
            config_path: Destination file; overwritten if it exists.
        """
        # Local import: the module-level import only brings in dataclass/field.
        from dataclasses import asdict

        config_data = {
            section_name: self._yaml_friendly(asdict(getattr(self, section_name)))
            for section_name in ('paths',) + self._ATTRIBUTE_SECTIONS
        }

        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(config_data, f, default_flow_style=False, indent=2)

    def get_device(self) -> str:
        """Get the appropriate device for processing.

        Returns:
            The configured device string, or, when it is ``"auto"``,
            ``"cuda"`` if torch reports CUDA as available and ``"cpu"``
            otherwise.
        """
        if self.model.device == "auto":
            # Imported lazily so torch is only required for auto-detection.
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        return self.model.device

    def create_output_directories(self, base_path: str) -> None:
        """Ensure base output directory exists only.

        Subdirectories are created per plant in the output manager.
        """
        base_path = Path(base_path)
        base_path.mkdir(parents=True, exist_ok=True)

    def validate(self) -> bool:
        """Validate configuration settings.

        Returns:
            True when every check passes.

        Raises:
            FileNotFoundError: If the input folder, or a configured bounding
                box directory, does not exist.
            ValueError: If ``target_size`` has a non-positive dimension or
                ``segmentation_threshold`` lies outside [0, 1].
        """
        if not os.path.exists(self.paths.input_folder):
            raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")

        # The bounding box directory is optional; only check it when set.
        bbox_dir = self.paths.boundingbox_dir
        if bbox_dir and not os.path.exists(bbox_dir):
            raise FileNotFoundError(f"Bounding box directory does not exist: {bbox_dir}")

        width, height = self.processing.target_size[0], self.processing.target_size[1]
        if width <= 0 or height <= 0:
            raise ValueError("Target size must be positive")

        if not 0 <= self.processing.segmentation_threshold <= 1:
            raise ValueError("Segmentation threshold must be between 0 and 1")

        return True
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def create_default_config(output_path: str) -> None:
    """Write a starter configuration file to ``output_path``.

    The file contains the default settings plus example dataset folder
    names, and a confirmation line is printed once it has been saved.
    """
    default_config = Config()
    example_paths = Paths(
        input_folder="Sorghum_dataset",
        output_folder="Sorghum_pipeline_Results",
        boundingbox_dir="boundingbox",
        labels_folder="labels",
    )
    default_config.paths = example_paths
    default_config.save_to_file(output_path)
    print(f"Default configuration created at: {output_path}")
|
sorghum_pipeline/data/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading and preprocessing modules.
|
| 3 |
+
|
| 4 |
+
This package contains all data-related functionality including:
|
| 5 |
+
- Raw image loading
|
| 6 |
+
- Data preprocessing
|
| 7 |
+
- Mask handling
|
| 8 |
+
- Data validation
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .loader import DataLoader
|
| 12 |
+
from .preprocessor import ImagePreprocessor
|
| 13 |
+
from .mask_handler import MaskHandler
|
| 14 |
+
|
| 15 |
+
__all__ = ["DataLoader", "ImagePreprocessor", "MaskHandler"]
|
sorghum_pipeline/data/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (577 Bytes). View file
|
|
|
sorghum_pipeline/data/__pycache__/loader.cpython-312.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
sorghum_pipeline/data/__pycache__/mask_handler.cpython-312.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
sorghum_pipeline/data/__pycache__/preprocessor.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
sorghum_pipeline/data/loader.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading functionality for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles loading raw images, managing plant data,
|
| 5 |
+
and organizing data according to the pipeline requirements.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import glob
|
| 10 |
+
import json
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import numpy as np
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DataLoader:
|
| 21 |
+
"""Handles loading and organizing plant image data."""
|
| 22 |
+
|
| 23 |
+
# Plants to ignore completely (empty by default)
|
| 24 |
+
IGNORE_PLANTS = set()
|
| 25 |
+
|
| 26 |
+
# Plants where you want exactly one frame from their own folder
|
| 27 |
+
EXACT_FRAME = {
|
| 28 |
+
4: 7, 5: 5, 7: 5, 12: 5, 13: 5, 18: 7, 19: 2, 20: 3,
|
| 29 |
+
24: 6, 25: 5, 26: 5, 30: 8, 37: 7
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
# Plants where you want to borrow a frame from a different plant folder
|
| 33 |
+
BORROW_FRAME = {
|
| 34 |
+
14: (13, 5), 15: (14, 5), 16: (15, 5), 33: (34, 7),
|
| 35 |
+
34: (35, 7), 35: (35, 8), 36: (36, 6)
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
# Overrides provided by user: preferred frame per target plant name
|
| 39 |
+
FRAME_OVERRIDE_BY_NAME = {
|
| 40 |
+
'plant1': 9, 'plant2': 10, 'plant3': 9, 'plant5': 7, 'plant6': 9, 'plant8': 5,
|
| 41 |
+
'plant7': 9, 'plant10': 9, 'plant11': 9, 'plant12': 9,
|
| 42 |
+
'plant13': 10, 'plant14': 8, 'plant15': 11, 'plant19': 4, 'plant20': 7,
|
| 43 |
+
'plant21': 9, 'plant22': 10, 'plant25': 4, 'plant26': 2, 'plant27': 10, 'plant28': 9, 'plant29': 2,
|
| 44 |
+
'plant30': 9, 'plant31': 10, 'plant32': 9, 'plant33': 8,
|
| 45 |
+
'plant35': 9, 'plant36': 4, 'plant38': 9, 'plant39': 9, 'plant41': 9,
|
| 46 |
+
'plant42': 6, 'plant43': 10, 'plant44': 9, 'plant45': 7,
|
| 47 |
+
'plant47': 10, 'plant48': 11,
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
# Substitutes provided by user: map target plant name -> source plant name
|
| 51 |
+
PLANT_SUBSTITUTES_BY_NAME = {
|
| 52 |
+
'plant16': 'plant15', 'plant15': 'plant14', 'plant14': 'plant13',
|
| 53 |
+
'plant13': 'plant12', 'plant33': 'plant34', 'plant34': 'plant35',
|
| 54 |
+
'plant24': 'plant25', 'plant25': 'plant25', 'plant35': 'plant36',
|
| 55 |
+
'plant36': 'plant37', 'plant37': 'plant37', 'plant44': 'plant43',
|
| 56 |
+
'plant45': 'plant44',
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
def __init__(self, input_folder: str, debug: bool = False, include_ignored: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
|
| 60 |
+
"""
|
| 61 |
+
Initialize the data loader.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
input_folder: Path to the input dataset folder
|
| 65 |
+
debug: Enable debug logging
|
| 66 |
+
"""
|
| 67 |
+
self.input_folder = Path(input_folder)
|
| 68 |
+
self.debug = debug
|
| 69 |
+
self.include_ignored = include_ignored
|
| 70 |
+
self.strict_loader = strict_loader
|
| 71 |
+
|
| 72 |
+
if not self.input_folder.exists():
|
| 73 |
+
raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
|
| 74 |
+
# Normalize excluded dates as a set of folder names (with dashes)
|
| 75 |
+
self.excluded_dates = set(excluded_dates or [])
|
| 76 |
+
|
| 77 |
+
def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
    """
    Load one selected frame per plant according to predefined rules.

    In strict mode (``strict_loader=True``) the frame comes from
    FRAME_OVERRIDE_BY_NAME, then EXACT_FRAME, then the default 8, and the
    source folder from PLANT_SUBSTITUTES_BY_NAME; BORROW_FRAME is not
    consulted. Otherwise selection is delegated to ``_get_frame_number``
    and ``_get_source_plant``.

    The input folder may be a single date folder (it directly contains
    ``plant*`` directories) or a parent folder of date folders. Excluded
    dates are only applied in the parent-folder layout.

    Returns:
        Dictionary with plant data organized by key format: "YYYY_MM_DD_plantX_frameY"
    """
    logger.info("Loading selected frames from dataset...")
    plants: Dict[str, Dict[str, Any]] = {}

    def choose_frame_and_source(pid: int) -> Tuple[int, str]:
        """Return (frame number, source plant folder name) for plant ``pid``."""
        if not self.strict_loader:
            return self._get_frame_number(pid), self._get_source_plant(pid)
        plant_name_local = f"plant{pid}"
        frame_num = self.FRAME_OVERRIDE_BY_NAME.get(
            plant_name_local,
            self.EXACT_FRAME.get(pid, 8)
        )
        source_plant = self.PLANT_SUBSTITUTES_BY_NAME.get(plant_name_local, plant_name_local)
        return frame_num, source_plant

    def load_date_folder(date_path: Path, date_name: str) -> None:
        """Load the selected frame of every plant folder under ``date_path``."""
        for plant_name in sorted(os.listdir(date_path)):
            if not (date_path / plant_name).is_dir():
                continue
            try:
                plant_id = int(plant_name.replace("plant", ""))
            except ValueError:
                # Not a "plant<N>" directory; silently skip it.
                continue
            if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                if self.debug:
                    logger.debug(f"Ignoring plant {plant_id}")
                continue
            frame_num, source_plant = choose_frame_and_source(plant_id)
            frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
            if frame_data:
                # The key names the target plant even when the image was
                # read from a substitute plant's folder.
                key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
                plants[key] = frame_data
                logger.debug(f"Loaded {key}")

    # Detect if input folder is a direct date folder (contains plant folders)
    has_plant_folders = any(
        item.is_dir() and item.name.startswith('plant')
        for item in self.input_folder.iterdir()
    )

    if has_plant_folders:
        load_date_folder(self.input_folder, self.input_folder.name)
    else:
        # Parent folder structure with date subfolders
        for date_name in sorted(os.listdir(self.input_folder)):
            date_path = self.input_folder / date_name
            if not date_path.is_dir():
                continue
            if date_name in self.excluded_dates:
                logger.info(f"Skipping excluded date: {date_name}")
                continue
            load_date_folder(date_path, date_name)

    logger.info(f"Successfully loaded {len(plants)} plant frames")
    return plants
|
| 159 |
+
|
| 160 |
+
def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
    """
    Load every available frame of every plant.

    Handles both layouts: an input folder that is itself a date folder
    (directly contains ``plant*`` directories) and a parent folder whose
    subdirectories are date folders. Excluded dates are skipped only in
    the parent-folder layout.

    Returns:
        Dictionary with all plant frames
    """
    logger.info("Loading all frames from dataset...")
    collected = {}

    top_level_entries = list(self.input_folder.iterdir())
    is_date_folder = False
    for entry in top_level_entries:
        if entry.is_dir() and entry.name.startswith('plant'):
            is_date_folder = True
            break

    if not is_date_folder:
        logger.info("Detected parent folder structure")
        for date_name in sorted(os.listdir(self.input_folder)):
            date_path = self.input_folder / date_name
            if not date_path.is_dir():
                continue
            if date_name in self.excluded_dates:
                logger.info(f"Skipping excluded date: {date_name}")
                continue

            logger.info(f"Processing date: {date_name}")
            self._load_plants_from_date_folder(date_path, date_name, collected)
    else:
        logger.info("Detected direct date folder structure")
        self._load_plants_from_date_folder(
            self.input_folder, self.input_folder.name, collected
        )

    logger.info(f"Successfully loaded {len(collected)} plant frames")
    return collected
|
| 196 |
+
|
| 197 |
+
def _load_plants_from_date_folder(self, date_path: Path, date_name: str, plants: Dict[str, Dict[str, Any]]) -> None:
    """Add every ``<plant>_frame*.tif`` under ``date_path`` to ``plants``.

    Args:
        date_path: Folder containing one subdirectory per plant.
        date_name: Date label used as key prefix (dashes become underscores).
        plants: Output mapping, updated in place with keys of the form
            "<date>_<plant>_frame<id>".
    """
    key_prefix = date_name.replace('-', '_')

    for plant_name in sorted(os.listdir(date_path)):
        plant_dir = date_path / plant_name
        if not plant_dir.is_dir():
            continue

        # The numeric id is whatever remains after removing "plant".
        try:
            plant_id = int(plant_name.replace("plant", ""))
        except ValueError:
            logger.warning(f"Could not extract plant ID from {plant_name}")
            continue

        is_ignored = plant_id in self.IGNORE_PLANTS
        if is_ignored and not self.include_ignored:
            logger.info(f"Skipping ignored plant {plant_id}")
            continue

        logger.info(f"Processing plant {plant_id}")

        frame_glob = str(plant_dir / f"{plant_name}_frame*.tif")
        matches = sorted(glob.glob(frame_glob))
        logger.info(f"Found {len(matches)} frame files for {plant_name}")

        for frame_path in matches:
            frame_data = self._load_frame_from_path(frame_path, plant_name)
            if not frame_data:
                logger.warning(f"Failed to load frame: {frame_path}")
                continue
            # Frame id is the text following "_frame" in the file stem.
            frame_id = Path(frame_path).stem.split("_frame")[-1]
            key = f"{key_prefix}_{plant_name}_frame{frame_id}"
            plants[key] = frame_data
            logger.debug(f"Loaded frame: {key}")
|
| 232 |
+
|
| 233 |
+
def load_single_plant(self, date: str, plant: str, frame: int) -> Optional[Dict[str, Any]]:
    """
    Load one specific frame of one plant.

    Args:
        date: Date string (e.g., "2025-02-05")
        plant: Plant name (e.g., "plant1")
        frame: Frame number

    Returns:
        Plant data dictionary, or None if the date folder, the plant
        folder, or the frame itself could not be loaded
    """
    date_dir = self.input_folder / date
    if date_dir.exists():
        plant_dir = date_dir / plant
        if plant_dir.exists():
            target = plant_dir / f"{plant}_frame{frame}.tif"
            return self._load_frame_from_path(str(target), plant)
        logger.error(f"Plant folder not found: {plant}")
        return None

    logger.error(f"Date folder not found: {date}")
    return None
|
| 259 |
+
|
| 260 |
+
def _get_frame_number(self, plant_id: int) -> int:
|
| 261 |
+
"""Get the frame number for a plant ID."""
|
| 262 |
+
plant_name = f"plant{plant_id}"
|
| 263 |
+
# Highest priority: explicit overrides by plant name
|
| 264 |
+
if plant_name in self.FRAME_OVERRIDE_BY_NAME:
|
| 265 |
+
return int(self.FRAME_OVERRIDE_BY_NAME[plant_name])
|
| 266 |
+
# Next: original exact/borrrow rules
|
| 267 |
+
if plant_id in self.EXACT_FRAME:
|
| 268 |
+
return self.EXACT_FRAME[plant_id]
|
| 269 |
+
elif plant_id in self.BORROW_FRAME:
|
| 270 |
+
return self.BORROW_FRAME[plant_id][1]
|
| 271 |
+
else:
|
| 272 |
+
return 8 # Default frame
|
| 273 |
+
|
| 274 |
+
def _get_source_plant(self, plant_id: int) -> str:
|
| 275 |
+
"""Get the source plant name for a plant ID."""
|
| 276 |
+
plant_name = f"plant{plant_id}"
|
| 277 |
+
# Highest priority: explicit substitutes by plant name
|
| 278 |
+
if plant_name in self.PLANT_SUBSTITUTES_BY_NAME:
|
| 279 |
+
return self.PLANT_SUBSTITUTES_BY_NAME[plant_name]
|
| 280 |
+
# Next: original borrow rules
|
| 281 |
+
if plant_id in self.BORROW_FRAME:
|
| 282 |
+
source_id = self.BORROW_FRAME[plant_id][0]
|
| 283 |
+
return f"plant{source_id}"
|
| 284 |
+
else:
|
| 285 |
+
return f"plant{plant_id}"
|
| 286 |
+
|
| 287 |
+
def _load_single_frame(self, date_path: Path, source_plant: str,
|
| 288 |
+
frame_num: int, target_plant: str) -> Optional[Dict[str, Any]]:
|
| 289 |
+
"""Load a single frame from the specified path."""
|
| 290 |
+
filename = f"{source_plant}_frame{frame_num}.tif"
|
| 291 |
+
frame_path = date_path / source_plant / filename
|
| 292 |
+
|
| 293 |
+
if not frame_path.exists():
|
| 294 |
+
if self.debug:
|
| 295 |
+
logger.warning(f"Frame not found: {frame_path}")
|
| 296 |
+
return None
|
| 297 |
+
|
| 298 |
+
return self._load_frame_from_path(str(frame_path), target_plant)
|
| 299 |
+
|
| 300 |
+
def _load_frame_from_path(self, frame_path: str, plant_name: str) -> Optional[Dict[str, Any]]:
    """
    Open the image at ``frame_path`` and wrap it in a frame dictionary.

    Returns:
        A dict with keys ``raw_image`` (tuple of the opened image and its
        file name), ``plant_name`` and ``file_path``; None if anything
        raised while opening (the error is logged)
    """
    frame_data: Optional[Dict[str, Any]] = None
    try:
        logger.debug(f"Attempting to load: {frame_path}")
        image = Image.open(frame_path)
        filename = Path(frame_path).name
        logger.debug(f"Successfully loaded image: {filename}, size: {image.size}")

        frame_data = dict(
            raw_image=(image, filename),
            plant_name=plant_name,
            file_path=frame_path,
        )
    except Exception as e:
        # Any failure is reported and turned into a None result.
        logger.error(f"Failed to load {frame_path}: {e}")

    return frame_data
|
| 316 |
+
|
| 317 |
+
def load_bounding_boxes(self, bbox_dir: str) -> Dict[str, Optional[Tuple[int, int, int, int]]]:
    """
    Load bounding box data from JSON files.

    Each ``*.json`` file in ``bbox_dir`` is expected to hold a top-level
    ``shapes`` list whose entries carry ``shape_type`` and ``points``
    (the layout looks labelme-style -- confirm against the annotation tool).
    The rectangle labeled 'sorghum' (case-insensitive, looked up under the
    ``label``, ``name`` or ``text`` key) is preferred; otherwise the first
    rectangle is used.

    Args:
        bbox_dir: Directory containing bounding box JSON files

    Returns:
        Dictionary mapping plant names to ``(x1, y1, x2, y2)`` with
        ``x1 <= x2`` and ``y1 <= y2``, or to None when the file contains no
        rectangle. Files that fail to parse are logged and left out.

    Raises:
        FileNotFoundError: If ``bbox_dir`` does not exist
    """
    bbox_path = Path(bbox_dir)
    if not bbox_path.exists():
        raise FileNotFoundError(f"Bounding box directory not found: {bbox_dir}")

    def _plant_id_from_stem(stem: str) -> str:
        # Normalize stems like plant_33_new -> plant33
        if not stem.startswith('plant_'):
            return stem
        number = next((p for p in stem.split('_') if p.isdigit()), None)
        if number is None:
            return stem.replace('_', '')
        return f"plant{number}"

    def _is_sorghum_label(s: dict) -> bool:
        for key in ('label', 'name', 'text'):
            val = s.get(key)
            if isinstance(val, str) and val.lower() == 'sorghum':
                return True
        return False

    bbox_lookup: Dict[str, Optional[Tuple[int, int, int, int]]] = {}

    # Sorted so that, when two files normalize to the same plant id, the
    # entry that wins does not depend on filesystem enumeration order.
    for json_file in sorted(bbox_path.glob("*.json")):
        plant_id = _plant_id_from_stem(json_file.stem)
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            shapes = data.get('shapes', [])
            rectangles = [s for s in shapes if s.get('shape_type') == 'rectangle']
            # Prefer rectangle labeled 'sorghum', else first rectangle
            rect = next((s for s in rectangles if _is_sorghum_label(s)), None)
            if rect is None and rectangles:
                rect = rectangles[0]

            if rect:
                (x1, y1), (x2, y2) = rect['points']
                # The two stored corners may come in any order; normalize
                # them so the box is always (left, top, right, bottom).
                left, right = sorted((x1, x2))
                top, bottom = sorted((y1, y2))
                bbox_lookup[plant_id] = (
                    int(max(0, left)),
                    int(max(0, top)),
                    int(min(1e9, right)),
                    int(min(1e9, bottom))
                )
            else:
                bbox_lookup[plant_id] = None

        except Exception as e:
            logger.error(f"Failed to load bounding box {json_file}: {e}")

    logger.info(f"Loaded {len(bbox_lookup)} bounding boxes")
    return bbox_lookup
|
| 377 |
+
|
| 378 |
+
def load_hand_labels(self, labels_dir: str) -> Dict[str, np.ndarray]:
    """
    Load hand-labeled masks from JSON files.

    Every ``*.json`` file in ``labels_dir`` is parsed and handed to
    ``_create_mask_from_shapes``; files for which that returns None, or
    which raise while being processed, contribute no entry.

    Args:
        labels_dir: Directory containing label JSON files

    Returns:
        Dictionary mapping file stems to mask arrays (empty when the
        directory does not exist)
    """
    labels_path = Path(labels_dir)
    if not labels_path.exists():
        logger.warning(f"Labels directory not found: {labels_dir}")
        return {}

    masks = {}
    for json_file in labels_path.glob("*.json"):
        plant_id = json_file.stem
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)

            # Mask creation depends on the label format and is delegated.
            mask = self._create_mask_from_shapes(data)
            if mask is None:
                continue
            masks[plant_id] = mask
        except Exception as e:
            logger.error(f"Failed to load label {json_file}: {e}")

    logger.info(f"Loaded {len(masks)} hand labels")
    return masks
|
| 412 |
+
|
| 413 |
+
def _create_mask_from_shapes(self, data: Dict) -> Optional[np.ndarray]:
|
| 414 |
+
"""Create a mask array from shape data."""
|
| 415 |
+
# This is a placeholder - implement based on your label format
|
| 416 |
+
# For now, return None
|
| 417 |
+
return None
|
| 418 |
+
|
| 419 |
+
def validate_data(self, plants: Dict[str, Dict[str, Any]]) -> bool:
    """
    Validate loaded plant data.

    The data is valid when the dictionary is non-empty and every entry has
    a ``raw_image`` pair whose first element is a PIL image. The first
    problem found is logged and ends the check.

    Args:
        plants: Dictionary of plant data

    Returns:
        True if data is valid, False otherwise
    """
    if not plants:
        logger.error("No plant data loaded")
        return False

    for key, data in plants.items():
        if "raw_image" not in data:
            logger.error(f"Missing raw_image in {key}")
            return False

        # raw_image is an (image, filename) pair; only the image is checked.
        image, _ = data["raw_image"]
        if isinstance(image, Image.Image):
            continue
        logger.error(f"Invalid image type in {key}")
        return False

    logger.info("Data validation passed")
    return True
|
sorghum_pipeline/data/mask_handler.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mask handling functionality for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles mask creation, processing, and validation
|
| 5 |
+
for plant segmentation tasks.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
from typing import Dict, Tuple, Optional, List
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MaskHandler:
    """Handles mask creation, processing, and validation."""

    def __init__(self, min_area: int = 1000, kernel_size: int = 7):
        """
        Initialize the mask handler.

        Args:
            min_area: Minimum area for connected components
            kernel_size: Kernel size for morphological operations
        """
        self.min_area = min_area
        self.kernel_size = kernel_size

    def create_bounding_box_mask(self, image_shape: Tuple[int, int],
                                 bbox: Tuple[int, int, int, int]) -> np.ndarray:
        """
        Create a mask from bounding box coordinates.

        Args:
            image_shape: Shape of the image (height, width); extra trailing
                dimensions are ignored
            bbox: Bounding box coordinates (x1, y1, x2, y2); the two corners
                may be given in either order

        Returns:
            Binary uint8 mask, 255 inside the (clamped) box and 0 elsewhere
        """
        h, w = image_shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        x1, y1, x2, y2 = bbox
        # Clamp coordinates to image bounds
        x1 = max(0, min(w, x1))
        y1 = max(0, min(h, y1))
        x2 = max(0, min(w, x2))
        y2 = max(0, min(h, y2))

        # A reversed corner pair would make the slices below empty and
        # silently yield an all-zero mask, so order the corners first.
        x1, x2 = sorted((x1, x2))
        y1, y2 = sorted((y1, y2))

        mask[y1:y2, x1:x2] = 255
        return mask

    def preprocess_mask(self, mask: np.ndarray) -> Optional[np.ndarray]:
        """
        Preprocess mask by cleaning and filtering.

        The mask is binarized (any value > 0 becomes 255), opened with an
        elliptical kernel of ``self.kernel_size``, and connected components
        smaller than ``self.min_area`` pixels are dropped.

        Args:
            mask: Input mask, or a tuple whose first element is the mask

        Returns:
            Cleaned mask (0/255, uint8), or None if ``mask`` is None
        """
        if mask is None:
            return None

        # Some callers hand over a (mask, ...) tuple; unwrap it
        if isinstance(mask, tuple):
            mask = mask[0]

        # Ensure binary format
        mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255

        # Morphological opening to remove noise
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.kernel_size, self.kernel_size)
        )
        opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Remove small connected components
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            opened, connectivity=8
        )

        # Labels start at 1 because label 0 is the background. Selecting
        # all surviving labels in one pass avoids a full-image comparison
        # per component.
        keep = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] >= self.min_area) + 1
        clean_mask = np.isin(labels, keep).astype(opened.dtype) * 255

        return clean_mask

    def keep_largest_component(self, mask: np.ndarray) -> Optional[np.ndarray]:
        """
        Keep only the largest connected component in the mask.

        Args:
            mask: Input mask

        Returns:
            Mask with only the largest component (0/255, uint8); the input
            itself when it has no foreground component; None if ``mask`` is
            None
        """
        if mask is None:
            return None

        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)

        # Only the background label is present: nothing to select
        if num_labels <= 1:
            return mask

        # Find the largest component (excluding background)
        areas = stats[1:, cv2.CC_STAT_AREA]
        largest_label = 1 + np.argmax(areas)

        # Create mask with only the largest component
        largest_mask = (labels == largest_label).astype(np.uint8) * 255

        return largest_mask

    def apply_mask_to_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Apply mask to image.

        Args:
            image: Input image
            mask: Binary mask

        Returns:
            Masked image, or the unchanged input image if ``mask`` is None
        """
        if mask is None:
            return image

        return cv2.bitwise_and(image, image, mask=mask)

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create overlay of mask on image.

        Args:
            image: Base image
            mask: Binary mask; only pixels equal to 255 are colored
            color: Overlay color (B, G, R)
            alpha: Overlay transparency

        Returns:
            Image with mask overlay
        """
        overlay = image.copy()
        overlay[mask == 255] = color
        return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)

    def get_mask_properties(self, mask: np.ndarray) -> Dict[str, float]:
        """
        Get properties of the mask.

        ``area`` and ``coverage`` are computed over all foreground pixels
        (values > 127). ``perimeter``, ``bbox_area`` and ``aspect_ratio``
        are measured on the first external contour returned by OpenCV only.

        Args:
            mask: Binary mask

        Returns:
            Dictionary of mask properties (empty if ``mask`` is None)
        """
        if mask is None:
            return {}

        # Convert to binary
        binary_mask = (mask > 127).astype(np.uint8)

        area = np.sum(binary_mask)

        # One contour extraction serves both the perimeter and the
        # bounding box measurements.
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) > 0:
            first_contour = contours[0]
            perimeter = cv2.arcLength(first_contour, True)
            x, y, w, h = cv2.boundingRect(first_contour)
            bbox_area = w * h
            aspect_ratio = w / h if h > 0 else 0
        else:
            perimeter = 0
            bbox_area = 0
            aspect_ratio = 0

        return {
            "area": float(area),
            "perimeter": float(perimeter),
            "bbox_area": float(bbox_area),
            "aspect_ratio": float(aspect_ratio),
            "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0
        }

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate mask format and content.

        A valid mask is a 2-D numpy array of dtype uint8 or bool with at
        least one non-zero pixel.

        Args:
            mask: Mask to validate

        Returns:
            True if valid, False otherwise
        """
        if mask is None:
            return False

        if not isinstance(mask, np.ndarray):
            return False

        if mask.ndim != 2:
            return False

        if mask.dtype not in [np.uint8, np.bool_]:
            return False

        # Check if mask has any foreground pixels
        if np.sum(mask > 0) == 0:
            logger.warning("Mask has no foreground pixels")
            return False

        return True

    def resize_mask(self, mask: np.ndarray, target_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """
        Resize mask to target size.

        Args:
            mask: Input mask
            target_size: Target size (width, height)

        Returns:
            Resized mask, or None if ``mask`` is None
        """
        if mask is None:
            return None

        # Nearest-neighbour keeps the mask values unblended
        return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

    def dilate_mask(self, mask: np.ndarray, kernel_size: int = 5) -> Optional[np.ndarray]:
        """
        Dilate mask to expand foreground regions.

        Args:
            mask: Input mask
            kernel_size: Size of dilation kernel

        Returns:
            Dilated mask, or None if ``mask`` is None
        """
        if mask is None:
            return None

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        return cv2.dilate(mask, kernel, iterations=1)

    def erode_mask(self, mask: np.ndarray, kernel_size: int = 5) -> Optional[np.ndarray]:
        """
        Erode mask to shrink foreground regions.

        Args:
            mask: Input mask
            kernel_size: Size of erosion kernel

        Returns:
            Eroded mask, or None if ``mask`` is None
        """
        if mask is None:
            return None

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        return cv2.erode(mask, kernel, iterations=1)

    def fill_holes(self, mask: np.ndarray) -> Optional[np.ndarray]:
        """
        Fill holes in the mask.

        Only external contours are extracted and each is filled solid, so
        anything enclosed by an outer boundary becomes foreground.

        Args:
            mask: Input mask

        Returns:
            Mask with filled holes, or None if ``mask`` is None
        """
        if mask is None:
            return None

        # Find contours
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Create filled mask
        filled_mask = np.zeros_like(mask)
        cv2.fillPoly(filled_mask, contours, 255)

        return filled_mask
|
sorghum_pipeline/data/preprocessor.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image preprocessing functionality for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles image preprocessing, composite creation,
|
| 5 |
+
and basic image transformations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from typing import Dict, Tuple, Any, Optional
|
| 12 |
+
from itertools import product
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ImagePreprocessor:
    """Handles image preprocessing and composite creation."""

    def __init__(self, target_size: Optional[Tuple[int, int]] = None):
        """
        Initialize the image preprocessor.

        Args:
            target_size: Target size for image resizing (width, height)
        """
        self.target_size = target_size

    def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
        """
        Convert array to uint8 format with proper normalization.

        NaN and infinite values are replaced by 0 first; the array is then
        min-max scaled over its whole extent to the 0-255 range. A constant
        array maps to all zeros.

        Args:
            arr: Input array

        Returns:
            Normalized uint8 array
        """
        # Handle NaN and infinite values
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

        # Nothing to scale; min()/ptp() would raise on an empty array
        if arr.size == 0:
            return arr.astype(np.uint8)

        # np.ptp() is used because the ndarray.ptp() method no longer
        # exists in NumPy 2.0.
        value_range = np.ptp(arr)

        # Normalize to 0-255 range
        if value_range > 0:
            normalized = (arr - arr.min()) / (value_range + 1e-6) * 255
        else:
            normalized = np.zeros_like(arr)

        return np.clip(normalized, 0, 255).astype(np.uint8)

    def process_raw_image(self, pil_img: "Image.Image") -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """
        Process raw 4-band image into composite and spectral bands.

        The image is cut into square tiles whose side is half the image
        width, scanned row by row; the four tiles are taken as
        green, red, red_edge and nir in that order.

        Args:
            pil_img: PIL Image object containing 4-band data laid out as a
                2x2 grid of tiles

        Returns:
            Tuple of (composite_image, spectral_bands_dict). The composite
            is a uint8 pseudo-RGB built from (green, red_edge, red); each
            band in the dict is a float array with a trailing axis of
            length 1.

        Raises:
            ValueError: If the image does not split into exactly four tiles
        """
        # Split the 4-band RAW into tiles and stack them
        d = pil_img.size[0] // 2
        if d == 0:
            raise ValueError(
                f"Image too small to split into 4 bands: size={pil_img.size}"
            )
        boxes = [
            (j, i, j + d, i + d)
            for i, j in product(
                range(0, pil_img.height, d),
                range(0, pil_img.width, d)
            )
        ]
        # Anything other than a 2x2 grid cannot be mapped onto the four
        # named bands below.
        if len(boxes) != 4:
            raise ValueError(
                f"Expected a 2x2 grid of band tiles, got {len(boxes)} tiles "
                f"for image size {pil_img.size}"
            )

        # Extract tiles and stack them
        stack = np.stack([
            np.array(pil_img.crop(box), dtype=float)
            for box in boxes
        ], axis=-1)

        # Bands come in order: [green, red, red_edge, nir]
        green, red, red_edge, nir = np.split(stack, 4, axis=-1)

        # Build pseudo-RGB composite as (green, red_edge, red)
        composite = np.concatenate([green, red_edge, red], axis=-1)
        composite_uint8 = self.convert_to_uint8(composite)

        # Prepare spectral stack
        spectral_bands = {
            "green": green,
            "red": red,
            "red_edge": red_edge,
            "nir": nir
        }

        return composite_uint8, spectral_bands

    def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """
        Create composites for all plants in the dataset.

        Each entry is updated in place with ``composite`` and
        ``spectral_stack``. Entries without a raw image, or whose
        processing raises, are logged and left without those keys.

        Args:
            plants: Dictionary of plant data

        Returns:
            Updated plant data with composites and spectral stacks
        """
        logger.info("Creating composites for all plants...")

        for key, pdata in plants.items():
            try:
                # Find the PIL Image
                if "raw_image" in pdata:
                    image, _ = pdata["raw_image"]
                elif "raw_images" in pdata and pdata["raw_images"]:
                    image, _ = pdata["raw_images"][0]
                else:
                    logger.warning(f"No raw image found for {key}")
                    continue

                # Process the image
                composite, spectral_stack = self.process_raw_image(image)

                # Store results
                pdata["composite"] = composite
                pdata["spectral_stack"] = spectral_stack

                logger.debug(f"Created composite for {key}")

            except Exception as e:
                logger.error(f"Failed to create composite for {key}: {e}")
                continue

        logger.info("Composite creation completed")
        return plants

    def resize_image(self, image: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Resize image to target size.

        Args:
            image: Input image
            target_size: Target size (width, height). If None, uses self.target_size

        Returns:
            Resized image, or the unchanged input when no target size is
            available
        """
        if target_size is None:
            target_size = self.target_size

        if target_size is None:
            return image

        return cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    def normalize_image(self, image: np.ndarray, method: str = "minmax") -> np.ndarray:
        """
        Normalize image using specified method.

        Args:
            image: Input image
            method: Normalization method ("minmax", "zscore", "robust").
                "minmax" divides uint8 input by 255 and otherwise scales by
                the array's own min/max; "zscore" uses mean and standard
                deviation; "robust" uses the 25th and 75th percentiles.

        Returns:
            Normalized image; all zeros (float32) when the spread used by
            the chosen method is zero

        Raises:
            ValueError: If ``method`` is not one of the supported names
        """
        if method == "minmax":
            if image.dtype == np.uint8:
                return image.astype(np.float32) / 255.0
            else:
                img_min, img_max = image.min(), image.max()
                if img_max > img_min:
                    return (image - img_min) / (img_max - img_min)
                else:
                    return np.zeros_like(image, dtype=np.float32)

        elif method == "zscore":
            mean, std = image.mean(), image.std()
            if std > 0:
                return (image - mean) / std
            else:
                return np.zeros_like(image, dtype=np.float32)

        elif method == "robust":
            q25, q75 = np.percentile(image, [25, 75])
            if q75 > q25:
                return (image - q25) / (q75 - q25)
            else:
                return np.zeros_like(image, dtype=np.float32)

        else:
            raise ValueError(f"Unknown normalization method: {method}")

    def apply_gaussian_blur(self, image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
        """
        Apply Gaussian blur to image.

        Args:
            image: Input image
            kernel_size: Size of Gaussian kernel; an even value is bumped
                to the next odd one

        Returns:
            Blurred image
        """
        if kernel_size % 2 == 0:
            kernel_size += 1

        return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    def apply_sharpening(self, image: np.ndarray) -> np.ndarray:
        """
        Apply sharpening filter to image.

        Args:
            image: Input image

        Returns:
            Sharpened image
        """
        # 3x3 kernel: centre weight 5, the four direct neighbours -1
        kernel = np.array([
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0]
        ])

        return cv2.filter2D(image, -1, kernel)

    def enhance_contrast(self, image: np.ndarray, alpha: float = 1.2, beta: int = 15) -> np.ndarray:
        """
        Enhance image contrast.

        Args:
            image: Input image
            alpha: Contrast control (1.0 = no change)
            beta: Brightness control (0 = no change)

        Returns:
            Enhanced image
        """
        return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)

    def create_overlay(self, base_image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create overlay of mask on base image.

        Args:
            base_image: Base image
            mask: Binary mask; only pixels equal to 255 are colored
            color: Overlay color (B, G, R)
            alpha: Overlay transparency

        Returns:
            Image with overlay
        """
        overlay = base_image.copy()
        overlay[mask == 255] = color
        return cv2.addWeighted(base_image, 1.0 - alpha, overlay, alpha, 0)

    def validate_composite(self, composite: np.ndarray) -> bool:
        """
        Validate composite image.

        A valid composite is a 3-D uint8 numpy array with exactly three
        channels on the last axis.

        Args:
            composite: Composite image to validate

        Returns:
            True if valid, False otherwise
        """
        if composite is None:
            return False

        if not isinstance(composite, np.ndarray):
            return False

        if composite.ndim != 3 or composite.shape[2] != 3:
            return False

        if composite.dtype != np.uint8:
            return False

        return True
|
sorghum_pipeline/features/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feature extraction modules for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This package contains all feature extraction functionality including:
|
| 5 |
+
- Texture features (LBP, HOG, Lacunarity, EHD)
|
| 6 |
+
- Vegetation indices
|
| 7 |
+
- Morphological features
|
| 8 |
+
- Spectral features
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .texture import TextureExtractor
|
| 12 |
+
from .vegetation import VegetationIndexExtractor
|
| 13 |
+
from .morphology import MorphologyExtractor
|
| 14 |
+
from .spectral import SpectralExtractor
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"TextureExtractor",
|
| 18 |
+
"VegetationIndexExtractor",
|
| 19 |
+
"MorphologyExtractor",
|
| 20 |
+
"SpectralExtractor"
|
| 21 |
+
]
|
sorghum_pipeline/features/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (714 Bytes). View file
|
|
|
sorghum_pipeline/features/__pycache__/morphology.cpython-312.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
sorghum_pipeline/features/__pycache__/spectral.cpython-312.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
sorghum_pipeline/features/__pycache__/texture.cpython-312.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
sorghum_pipeline/features/__pycache__/vegetation.cpython-312.pyc
ADDED
|
Binary file (25.1 kB). View file
|
|
|
sorghum_pipeline/features/morphology.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Morphological feature extraction for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles extraction of morphological features using PlantCV
|
| 5 |
+
and other computer vision techniques.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
import contextlib
|
| 11 |
+
import sys
|
| 12 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
# The logger must exist before the optional import below, because the
# ImportError branch logs a warning through it.
logger = logging.getLogger(__name__)

# PlantCV is an optional dependency: when it cannot be imported the module
# still loads and MorphologyExtractor falls back to OpenCV-only features.
try:
    from plantcv import plantcv as pcv
    PLANT_CV_AVAILABLE = True
except ImportError:
    pcv = None  # keep the name bound so accidental use fails clearly
    PLANT_CV_AVAILABLE = False
    logger.warning("PlantCV not available. Morphological features will be limited.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MorphologyExtractor:
    """Extract morphological (shape and skeleton) features from plant images.

    Basic contour measurements are always computed with OpenCV.  Skeleton
    features are computed with PlantCV when ``PLANT_CV_AVAILABLE`` is true and
    with a small OpenCV/NumPy fallback otherwise.
    """

    # Side length (pixels) of the elliptical kernel used to open the mask.
    _OPEN_KERNEL_SIZE = 7
    # Connected components with fewer pixels than this are dropped as noise.
    _MIN_COMPONENT_AREA = 1000

    def __init__(self, pixel_to_cm: float = 0.1099609375,
                 prune_sizes: Optional[List[int]] = None):
        """
        Initialize morphology extractor.

        Args:
            pixel_to_cm: Conversion factor from pixels to centimeters
            prune_sizes: List of pruning sizes for skeleton processing;
                a falsy value (``None`` or an empty list) selects the
                default ``[200, 100, 50, 30, 10]``
        """
        self.pixel_to_cm = pixel_to_cm
        self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]

        if PLANT_CV_AVAILABLE:
            # Configure PlantCV: no debug output, fixed annotation style.
            pcv.params.debug = None
            pcv.params.text_size = 0.7
            pcv.params.text_thickness = 2
            pcv.params.line_thickness = 3
            pcv.params.dpi = 100

    def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """
        Extract morphological features from plant image and mask.

        Args:
            image: Plant image (BGR format)
            mask: Binary mask of the plant

        Returns:
            Dictionary with keys ``'traits'`` (name -> value), ``'images'``
            (name -> intermediate result) and ``'success'`` (bool).  Errors
            are logged and leave ``'success'`` False instead of propagating.
        """
        features = {
            'traits': {},
            'images': {},
            'success': False
        }

        try:
            clean_mask = self._preprocess_mask(mask)
            if clean_mask is None:
                logger.warning("Failed to preprocess mask")
                return features

            # Contour-based measurements are available with or without PlantCV.
            features['traits'].update(self._extract_basic_features(clean_mask))

            if PLANT_CV_AVAILABLE:
                extra = self._extract_skeleton_features(image, clean_mask)
            else:
                # Fallback to basic OpenCV features
                extra = self._extract_opencv_features(image, clean_mask)
            features['traits'].update(extra['traits'])
            features['images'].update(extra['images'])

            features['success'] = True
            logger.debug("Morphological features extracted successfully")

        except Exception as e:
            logger.error(f"Morphological feature extraction failed: {e}")

        return features

    def _preprocess_mask(self, mask: np.ndarray) -> Optional[np.ndarray]:
        """Binarize and denoise a mask for morphological analysis.

        Returns a uint8 mask with values 0/255 in which small specks have been
        removed, or ``None`` when no mask was supplied.
        """
        if mask is None:
            return None

        # Some callers pass a tuple whose first element is the mask.
        if isinstance(mask, tuple):
            mask = mask[0]
            if mask is None:
                return None

        mask = np.asarray(mask)
        # Connected-component labelling needs a single channel; treat a pixel
        # as foreground when any channel is set.
        if mask.ndim == 3:
            mask = mask.max(axis=2)

        # Ensure binary 0/255 uint8 format
        binary = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255

        # Morphological opening to remove noise
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (self._OPEN_KERNEL_SIZE, self._OPEN_KERNEL_SIZE)
        )
        opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

        # Keep only sufficiently large connected components (label 0 is background).
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(opened, connectivity=8)
        keep = [
            label for label in range(1, num_labels)
            if stats[label, cv2.CC_STAT_AREA] >= self._MIN_COMPONENT_AREA
        ]
        clean_mask = np.zeros_like(opened)
        clean_mask[np.isin(labels, keep)] = 255

        return clean_mask

    def _extract_basic_features(self, mask: np.ndarray) -> Dict[str, float]:
        """Measure the largest external contour of ``mask`` using OpenCV.

        Returns an empty dict when the mask has no contour or when an error
        occurs (the error is logged).

        NOTE: the key names are kept for compatibility with existing
        consumers: ``'solidity'`` is area / bounding-box area and
        ``'convexity'`` is area / convex-hull area.
        """
        features = {}

        try:
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return features

            # All measurements refer to the largest contour only.
            largest_contour = max(contours, key=cv2.contourArea)

            area = cv2.contourArea(largest_contour)
            perimeter = cv2.arcLength(largest_contour, True)

            # Axis-aligned bounding box
            x, y, w, h = cv2.boundingRect(largest_contour)
            bbox_area = w * h

            # cv2.fitEllipse needs at least 5 points; otherwise fall back to
            # the bounding-box sides.
            if len(largest_contour) >= 5:
                (center, axes, angle) = cv2.fitEllipse(largest_contour)
                major_axis = max(axes)
                minor_axis = min(axes)
            else:
                major_axis = max(w, h)
                minor_axis = min(w, h)

            px = self.pixel_to_cm
            px2 = px ** 2

            # Lengths and areas converted to centimeters
            features['area_cm2'] = area * px2
            features['perimeter_cm'] = perimeter * px
            features['width_cm'] = w * px
            features['height_cm'] = h * px
            features['bbox_area_cm2'] = bbox_area * px2
            features['major_axis_cm'] = major_axis * px
            features['minor_axis_cm'] = minor_axis * px

            # Dimensionless shape descriptors (0 when the denominator is 0)
            features['aspect_ratio'] = w / h if h > 0 else 0
            features['elongation'] = major_axis / minor_axis if minor_axis > 0 else 0
            features['circularity'] = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
            features['solidity'] = area / bbox_area if bbox_area > 0 else 0

            # Convex hull
            hull_area = cv2.contourArea(cv2.convexHull(largest_contour))
            features['convexity'] = area / hull_area if hull_area > 0 else 0

        except Exception as e:
            logger.error(f"Basic feature extraction failed: {e}")

        return features

    def _extract_skeleton_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """Extract skeleton-based features using PlantCV.

        Returns ``{'traits': {...}, 'images': {...}}``; both are empty when
        PlantCV is unavailable.  Tip detection, segment sorting and size
        analysis each fail softly (logged as warnings) so one failing step
        does not discard the others.
        """
        features = {'traits': {}, 'images': {}}

        if not PLANT_CV_AVAILABLE:
            return features

        try:
            # Filter known-noisy PlantCV messages from stdout/stderr.
            with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
                    contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):

                # Skeletonize
                skeleton = pcv.morphology.skeletonize(mask=mask)
                features['images']['skeleton'] = skeleton

                # Prune repeatedly, once per configured size
                pruned_skel = skeleton
                for size in self.prune_sizes:
                    pruned_skel, _, _ = pcv.morphology.prune(
                        skel_img=pruned_skel, size=size, mask=mask
                    )
                features['images']['pruned_skeleton'] = pruned_skel

                # Find branch points and tips
                branch_pts = pcv.morphology.find_branch_pts(pruned_skel, mask)
                features['images']['branch_points'] = branch_pts

                try:
                    tip_pts = pcv.morphology.find_tips(pruned_skel, mask)
                    features['images']['tip_points'] = tip_pts
                except Exception as e:
                    logger.warning(f"Tip detection failed: {e}")

                # Segment objects into leaves and stems
                try:
                    leaf_obj, stem_obj = pcv.morphology.segment_sort(
                        pruned_skel, [], mask
                    )
                    features['traits']['num_leaves'] = len(leaf_obj)
                    features['traits']['num_stems'] = len(stem_obj)
                except Exception as e:
                    logger.warning(f"Object segmentation failed: {e}")
                    features['traits']['num_leaves'] = 0
                    features['traits']['num_stems'] = 0

                # Size analysis
                try:
                    labeled_mask, n_labels = pcv.create_labels(mask)
                    size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
                    features['images']['size_analysis'] = size_analysis

                    # Read the observations stored under "default_1" and
                    # convert pixel measurements to centimeters.
                    length_traits = ("perimeter", "width", "height", "longest_path",
                                     "ellipse_major_axis", "ellipse_minor_axis")
                    obs = pcv.outputs.observations.get("default_1", {})
                    for trait, info in obs.items():
                        if trait in ["in_bounds", "object_in_frame"]:
                            continue
                        val = info.get("value", None)
                        if val is None:
                            continue
                        if trait == "area":
                            val = val * (self.pixel_to_cm ** 2)
                        elif trait in length_traits:
                            val = val * self.pixel_to_cm
                        features['traits'][trait] = val

                except Exception as e:
                    logger.warning(f"Size analysis failed: {e}")

        except Exception as e:
            logger.error(f"Skeleton feature extraction failed: {e}")

        return features

    def _extract_opencv_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """Extract skeleton features without PlantCV (fallback path).

        Unlike the PlantCV path, ``images['branch_points']`` and
        ``images['endpoints']`` are lists of ``(x, y)`` coordinates rather
        than images.
        """
        features = {'traits': {}, 'images': {}}

        try:
            skeleton = self._create_skeleton_opencv(mask)
            features['images']['skeleton'] = skeleton

            # Find branch points
            branch_points = self._find_branch_points_opencv(skeleton)
            features['images']['branch_points'] = branch_points
            features['traits']['num_branches'] = len(branch_points)

            # Find endpoints
            endpoints = self._find_endpoints_opencv(skeleton)
            features['images']['endpoints'] = endpoints
            features['traits']['num_endpoints'] = len(endpoints)

            # Skeleton length is approximated by its pixel count; stored as a
            # plain int so the trait serializes cleanly.
            skeleton_length = int(np.sum(skeleton > 0))
            features['traits']['skeleton_length_pixels'] = skeleton_length
            features['traits']['skeleton_length_cm'] = skeleton_length * self.pixel_to_cm

        except Exception as e:
            logger.error(f"OpenCV feature extraction failed: {e}")

        return features

    def _create_skeleton_opencv(self, mask: np.ndarray) -> np.ndarray:
        """Create a 0/255 uint8 skeleton by iterated erosion and opening."""
        binary = (mask > 0).astype(np.uint8)

        skeleton = np.zeros_like(binary)
        element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

        while True:
            eroded = cv2.erode(binary, element)
            # Pixels removed by opening at this scale belong to the skeleton.
            opened = cv2.dilate(eroded, element)
            skeleton = cv2.bitwise_or(skeleton, cv2.subtract(binary, opened))
            binary = eroded.copy()

            if cv2.countNonZero(binary) == 0:
                break

        return skeleton * 255

    @staticmethod
    def _count_skeleton_neighbors(skeleton: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(foreground, counts)`` for a 2-D skeleton image.

        ``foreground`` is a boolean array of skeleton pixels and ``counts``
        holds, for every pixel, how many of its 8 neighbours are foreground.
        The skeleton is binarized to 0/1 first: summing the raw 0/255 values
        in uint8 would saturate at 255 and make every count meaningless.
        Pixels outside the image count as background.
        """
        binary = (np.asarray(skeleton) > 0).astype(np.uint8)
        rows, cols = binary.shape
        padded = np.pad(binary, 1, mode="constant")

        counts = np.zeros((rows, cols), dtype=np.uint8)  # max value is 8
        for dr in range(3):
            for dc in range(3):
                if dr == 1 and dc == 1:
                    continue  # don't count the center pixel
                counts += padded[dr:dr + rows, dc:dc + cols]

        return binary.astype(bool), counts

    def _find_branch_points_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
        """Return ``(x, y)`` of skeleton pixels with 3 or more neighbours."""
        foreground, neighbor_count = self._count_skeleton_neighbors(skeleton)
        ys, xs = np.where(foreground & (neighbor_count >= 3))
        return list(zip(xs, ys))  # (x, y) format

    def _find_endpoints_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
        """Return ``(x, y)`` of skeleton pixels with exactly 1 neighbour."""
        foreground, neighbor_count = self._count_skeleton_neighbors(skeleton)
        ys, xs = np.where(foreground & (neighbor_count == 1))
        return list(zip(xs, ys))  # (x, y) format

    class _FilteredStream:
        """Stream wrapper that drops known-noisy PlantCV messages."""

        # Messages containing any of these substrings are discarded.
        _SKIP = ("got pruned", "Slope of contour", "cannot be plotted")

        def __init__(self, stream):
            self.stream = stream

        def write(self, msg):
            """Forward ``msg`` unless it matches a skip pattern.

            Returns the length of ``msg``, as text streams conventionally do,
            whether or not the message was forwarded.
            """
            if not any(s in msg for s in self._SKIP):
                self.stream.write(msg)
            return len(msg)

        def flush(self):
            # Best effort: a failing flush on the wrapped stream must not
            # abort feature extraction.
            try:
                self.stream.flush()
            except Exception:
                pass

    @staticmethod
    def _overlay_points(vis: np.ndarray, points: Any, color: List[int]) -> None:
        """Paint ``points`` onto ``vis`` in place.

        ``points`` is either an image (non-zero pixels are painted), as
        produced by the PlantCV path, or an iterable of ``(x, y)``
        coordinates, as produced by the OpenCV fallback.  Coordinates outside
        the image are ignored.
        """
        if isinstance(points, np.ndarray):
            vis[points > 0] = color
            return

        height, width = vis.shape[:2]
        for x, y in points:
            if 0 <= y < height and 0 <= x < width:
                vis[int(y), int(x)] = color

    def create_morphology_visualization(self, image: np.ndarray, mask: np.ndarray,
                                        features: Dict[str, Any]) -> np.ndarray:
        """
        Create visualization of morphological features.

        Draws the mask outline, the bounding box of the largest contour and,
        when present in ``features['images']``, the skeleton and branch points.

        Args:
            image: Original image
            mask: Binary mask
            features: Extracted features

        Returns:
            Visualization image; the unmodified input image if drawing fails
        """
        try:
            vis = image.copy()

            # Draw mask outline
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(vis, contours, -1, (0, 255, 0), 2)

            # Bounding box of the largest contour, matching the contour that
            # the basic feature measurements describe.
            if contours:
                x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
                cv2.rectangle(vis, (x, y), (x + w, y + h), (255, 0, 0), 2)

            images = features.get('images', {})

            # Draw skeleton if available (red in BGR channel order)
            skeleton = images.get('skeleton')
            if skeleton is not None:
                vis[np.asarray(skeleton) > 0] = [0, 0, 255]

            # Draw branch points if available.  [255, 255, 0] renders cyan in
            # BGR order (yellow if the image is RGB).
            branch_points = images.get('branch_points')
            if branch_points is not None:
                self._overlay_points(vis, branch_points, [255, 255, 0])

            return vis

        except Exception as e:
            logger.error(f"Visualization creation failed: {e}")
            return image
|
sorghum_pipeline/features/spectral.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Spectral feature extraction for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles extraction of spectral features and analysis
|
| 5 |
+
of multispectral data.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
from sklearn.decomposition import PCA
|
| 11 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SpectralExtractor:
    """Extract per-band statistics, PCA components, simple band indices and
    texture features from a dictionary of spectral bands."""

    # Bands used for PCA, in stacking order.
    _PCA_BANDS = ('nir', 'red_edge', 'red', 'green')

    def __init__(self, n_components: int = 3):
        """
        Initialize spectral extractor.

        Args:
            n_components: Number of PCA components to extract
        """
        self.n_components = n_components

    def extract_spectral_features(self, spectral_stack: Dict[str, np.ndarray],
                                  mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Extract spectral features from multispectral data.

        Args:
            spectral_stack: Dictionary of spectral bands
            mask: Optional binary mask

        Returns:
            Dictionary with keys ``'band_features'``, ``'pca_features'``,
            ``'spectral_indices'`` and ``'spectral_texture'``.  If a step
            raises, the error is logged and only the keys computed so far
            are returned.
        """
        features = {}

        try:
            features['band_features'] = self._extract_band_features(spectral_stack, mask)
            features['pca_features'] = self._extract_pca_features(spectral_stack, mask)
            features['spectral_indices'] = self._extract_spectral_indices(spectral_stack, mask)
            features['spectral_texture'] = self._extract_spectral_texture(spectral_stack, mask)

            logger.debug("Spectral features extracted successfully")

        except Exception as e:
            logger.error(f"Spectral feature extraction failed: {e}")

        return features

    @staticmethod
    def _apply_mask(band: np.ndarray, mask: Optional[np.ndarray]) -> np.ndarray:
        """Return ``band`` with pixels outside ``mask`` set to NaN.

        The mask is applied only when it is given and its shape equals the
        band's shape; otherwise the band is returned unchanged.
        """
        if mask is not None and mask.shape == band.shape:
            return np.where(mask > 0, band, np.nan)
        return band

    def _extract_band_features(self, spectral_stack: Dict[str, np.ndarray],
                               mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
        """Compute summary statistics for every band in ``spectral_stack``.

        Statistics are taken over the non-NaN pixels (inside the mask when it
        applies).  A band with no valid pixel gets all-zero statistics; a
        band whose processing raises gets an empty dict (error logged).
        """
        stat_names = ('mean', 'std', 'min', 'max', 'median', 'q25', 'q75',
                      'skewness', 'kurtosis', 'entropy')
        band_features = {}

        for band_name, band_data in spectral_stack.items():
            try:
                # Squeeze to 2D if needed
                if band_data.ndim > 2:
                    band_data = band_data.squeeze()

                masked_data = self._apply_mask(band_data, mask)
                valid_data = masked_data[~np.isnan(masked_data)]

                if len(valid_data) == 0:
                    band_features[band_name] = {name: 0.0 for name in stat_names}
                    continue

                band_features[band_name] = {
                    'mean': float(np.mean(valid_data)),
                    'std': float(np.std(valid_data)),
                    'min': float(np.min(valid_data)),
                    'max': float(np.max(valid_data)),
                    'median': float(np.median(valid_data)),
                    'q25': float(np.percentile(valid_data, 25)),
                    'q75': float(np.percentile(valid_data, 75)),
                    'skewness': float(self._compute_skewness(valid_data)),
                    'kurtosis': float(self._compute_kurtosis(valid_data)),
                    'entropy': float(self._compute_entropy(valid_data))
                }

            except Exception as e:
                logger.error(f"Band feature extraction failed for {band_name}: {e}")
                band_features[band_name] = {}

        return band_features

    def _extract_pca_features(self, spectral_stack: Dict[str, np.ndarray],
                              mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Run PCA across the available bands, one sample per pixel.

        Returns a dict holding, for each fitted component ``i`` (1-based),
        the image ``'pca_i'`` (NaN where a pixel was invalid) and its
        ``'pca_i_stats'``, plus ``'explained_variance_ratio'`` and
        ``'total_variance_explained'``.  The number of components is capped
        by the number of bands present and by the number of valid pixels.
        Returns an empty dict when no band or no valid pixel is available,
        or when an error occurs (logged).
        """
        try:
            layers = []
            for band_name in self._PCA_BANDS:
                if band_name in spectral_stack:
                    arr = np.squeeze(spectral_stack[band_name]).astype(float)
                    layers.append(self._apply_mask(arr, mask))

            if not layers:
                return {}

            full_stack = np.stack(layers, axis=-1)
            h, w, c = full_stack.shape

            # One row per pixel; rows with any NaN band are excluded from PCA.
            flat_data = full_stack.reshape(-1, c)
            valid_mask = ~np.isnan(flat_data).any(axis=1)
            n_valid = int(valid_mask.sum())
            if n_valid == 0:
                return {}

            valid_data = flat_data[valid_mask]

            # PCA cannot produce more components than bands or samples, so
            # size everything from the number actually fitted.
            n_comp = min(self.n_components, c, n_valid)
            pca = PCA(n_components=n_comp)
            pca_result = pca.fit_transform(valid_data)

            full_result = np.full((h * w, n_comp), np.nan)
            full_result[valid_mask] = pca_result

            pca_components = {}
            for i in range(n_comp):
                # Reshape back to image dimensions
                component = full_result[:, i].reshape(h, w)
                pca_components[f'pca_{i+1}'] = component

                valid_component = component[~np.isnan(component)]
                if len(valid_component) > 0:
                    pca_components[f'pca_{i+1}_stats'] = {
                        'mean': float(np.mean(valid_component)),
                        'std': float(np.std(valid_component)),
                        'min': float(np.min(valid_component)),
                        'max': float(np.max(valid_component))
                    }

            # Add PCA metadata
            pca_components['explained_variance_ratio'] = pca.explained_variance_ratio_.tolist()
            pca_components['total_variance_explained'] = float(np.sum(pca.explained_variance_ratio_))

            return pca_components

        except Exception as e:
            logger.error(f"PCA feature extraction failed: {e}")
            return {}

    def _extract_spectral_indices(self, spectral_stack: Dict[str, np.ndarray],
                                  mask: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]:
        """Compute pairwise band ratios and differences.

        Only the indices whose bands are all present are returned.  Ratios
        add ``1e-10`` to the denominator to avoid division by zero.  On error
        the indices computed so far are returned and the error is logged.
        """
        indices = {}

        try:
            def prepared(name):
                # Missing bands stay None; present ones become masked floats.
                band = spectral_stack.get(name, None)
                if band is None:
                    return None
                return self._apply_mask(band.squeeze().astype(float), mask)

            nir = prepared('nir')
            red = prepared('red')
            green = prepared('green')
            red_edge = prepared('red_edge')

            # (prefix, numerator band, denominator band)
            pairs = (
                ('nir_red', nir, red),
                ('nir_green', nir, green),
                ('red_green', red, green),
                ('nir_red_edge', nir, red_edge),
            )
            for prefix, first, second in pairs:
                if first is not None and second is not None:
                    indices[f'{prefix}_ratio'] = first / (second + 1e-10)
                    indices[f'{prefix}_diff'] = first - second

            # Three-band aggregates
            if nir is not None and red is not None and green is not None:
                indices['nir_red_green_sum'] = nir + red + green
                indices['nir_red_green_mean'] = (nir + red + green) / 3

        except Exception as e:
            logger.error(f"Spectral index extraction failed: {e}")

        return indices

    def _extract_spectral_texture(self, spectral_stack: Dict[str, np.ndarray],
                                  mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Run ``TextureExtractor`` on each band after scaling it to uint8.

        Pixels outside the mask are set to 0 in the scaled image.  A band
        whose processing raises gets an empty dict (error logged).  Returns
        an empty dict when ``TextureExtractor`` cannot be imported.
        """
        texture_features = {}

        try:
            from .texture import TextureExtractor

            texture_extractor = TextureExtractor()

            for band_name, band_data in spectral_stack.items():
                try:
                    gray_data = self._apply_mask(band_data.squeeze().astype(float), mask)
                    # Shared normalization keeps this path consistent with
                    # the visualization and maps NaN to 0 explicitly.
                    normalized = self._normalize_band(gray_data)

                    texture_features[band_name] = \
                        texture_extractor.extract_all_texture_features(normalized)

                except Exception as e:
                    logger.error(f"Spectral texture extraction failed for {band_name}: {e}")
                    texture_features[band_name] = {}

        except ImportError:
            logger.warning("TextureExtractor not available for spectral texture analysis")

        return texture_features

    def _compute_skewness(self, data: np.ndarray) -> float:
        """Return the population skewness of ``data``.

        Returns 0.0 for fewer than 3 values or zero standard deviation.
        """
        if len(data) < 3:
            return 0.0

        std = np.std(data)
        if std == 0:
            return 0.0

        return float(np.mean(((data - np.mean(data)) / std) ** 3))

    def _compute_kurtosis(self, data: np.ndarray) -> float:
        """Return the excess kurtosis (normal distribution = 0) of ``data``.

        Returns 0.0 for fewer than 4 values or zero standard deviation.
        """
        if len(data) < 4:
            return 0.0

        std = np.std(data)
        if std == 0:
            return 0.0

        return float(np.mean(((data - np.mean(data)) / std) ** 4) - 3)

    def _compute_entropy(self, data: np.ndarray, bins: int = 256,
                         value_range: Tuple[float, float] = (0, 256)) -> float:
        """Return the Shannon entropy (bits) of the histogram of ``data``.

        Args:
            data: Values to histogram
            bins: Number of histogram bins
            value_range: ``(low, high)`` range of the histogram; values
                outside it are ignored

        Returns:
            Entropy in bits; 0.0 when ``data`` is empty or no value falls
            inside ``value_range``.
        """
        if len(data) == 0:
            return 0.0

        hist, _ = np.histogram(data, bins=bins, range=value_range)
        total = np.sum(hist)
        if total == 0:
            # Nothing landed in the range: normalizing would divide by zero.
            return 0.0

        probs = hist / total
        probs = probs[probs > 0]  # log2(0) is undefined

        return float(-np.sum(probs * np.log2(probs)))

    def _composite(self, spectral_stack: Dict[str, np.ndarray],
                   channels: Tuple[str, str, str]) -> np.ndarray:
        """Stack three normalized bands into an (H, W, 3) uint8 image."""
        planes = [self._normalize_band(spectral_stack[name].squeeze()) for name in channels]
        return np.stack(planes, axis=-1).astype(np.uint8)

    def create_spectral_visualization(self, spectral_stack: Dict[str, np.ndarray],
                                      pca_features: Dict[str, Any]) -> np.ndarray:
        """
        Create visualization of spectral features.

        Tries, in order: a (red, red_edge, green) composite, a
        (nir, red, green) composite, then a grayscale image of the first PCA
        component.

        Args:
            spectral_stack: Original spectral data
            pca_features: PCA features

        Returns:
            Visualization image; a black 100x100 image when no option
            applies or an error occurs
        """
        try:
            # Preferred visualization: channels = (Red, Red-Edge, Green)
            if all(name in spectral_stack for name in ('red', 'red_edge', 'green')):
                return self._composite(spectral_stack, ('red', 'red_edge', 'green'))

            # Fallback visualization: channels = (NIR, Red, Green)
            if all(name in spectral_stack for name in ('red', 'green', 'nir')):
                return self._composite(spectral_stack, ('nir', 'red', 'green'))

            # Fallback to first PCA component
            if 'pca_1' in pca_features:
                pca1_norm = self._normalize_band(pca_features['pca_1'])
                return np.stack([pca1_norm, pca1_norm, pca1_norm], axis=-1).astype(np.uint8)

            # Return empty image
            return np.zeros((100, 100, 3), dtype=np.uint8)

        except Exception as e:
            logger.error(f"Spectral visualization creation failed: {e}")
            return np.zeros((100, 100, 3), dtype=np.uint8)

    def _normalize_band(self, band: np.ndarray) -> np.ndarray:
        """Linearly scale ``band`` to uint8 0-255 using its non-NaN min/max.

        NaN pixels become 0.  A band with no valid pixel, or with all valid
        pixels equal, yields an all-zero image.
        """
        band = np.asarray(band, dtype=float)
        valid_data = band[~np.isnan(band)]
        if len(valid_data) == 0:
            return np.zeros_like(band, dtype=np.uint8)

        m, M = np.min(valid_data), np.max(valid_data)
        if M <= m:
            return np.zeros_like(band, dtype=np.uint8)

        scaled = (band - m) / (M - m) * 255
        # Casting NaN to uint8 is undefined, so replace it before the cast.
        scaled = np.where(np.isnan(scaled), 0.0, scaled)
        return scaled.astype(np.uint8)
|
sorghum_pipeline/features/texture.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Texture feature extraction for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles extraction of texture features including:
|
| 5 |
+
- Local Binary Patterns (LBP)
|
| 6 |
+
- Histogram of Oriented Gradients (HOG)
|
| 7 |
+
- Lacunarity features
|
| 8 |
+
- Edge Histogram Descriptor (EHD)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import cv2
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from skimage.feature import local_binary_pattern, hog
|
| 16 |
+
from skimage import exposure
|
| 17 |
+
from scipy import ndimage, signal
|
| 18 |
+
from sklearn.decomposition import PCA
|
| 19 |
+
from typing import Dict, Tuple, Optional, Any
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TextureExtractor:
    """Extract texture feature maps from grayscale images.

    Provides Local Binary Patterns (LBP), Histogram of Oriented Gradients
    (HOG), three lacunarity maps and an Edge Histogram Descriptor (EHD),
    plus per-feature summary statistics.
    """

    # Side length of the EHD edge masks and the dilation used when they are
    # convolved with the image. The convolution is un-padded, so it shrinks
    # each spatial dimension by ``dilation * (mask_size - 1)`` pixels.
    _EHD_MASK_SIZE = 3
    _EHD_DILATION = 7

    def __init__(self,
                 lbp_points: int = 8,
                 lbp_radius: int = 1,
                 hog_orientations: int = 9,
                 hog_pixels_per_cell: Tuple[int, int] = (8, 8),
                 hog_cells_per_block: Tuple[int, int] = (2, 2),
                 lacunarity_window: int = 15,
                 ehd_threshold: float = 0.3,
                 angle_resolution: int = 45):
        """
        Initialize texture extractor.

        Args:
            lbp_points: Number of points for LBP
            lbp_radius: Radius for LBP
            hog_orientations: Number of orientations for HOG
            hog_pixels_per_cell: Pixels per cell for HOG
            hog_cells_per_block: Cells per block for HOG
            lacunarity_window: Window size for lacunarity
            ehd_threshold: Threshold for EHD
            angle_resolution: Angle resolution for EHD
        """
        self.lbp_points = lbp_points
        self.lbp_radius = lbp_radius
        self.hog_orientations = hog_orientations
        self.hog_pixels_per_cell = hog_pixels_per_cell
        self.hog_cells_per_block = hog_cells_per_block
        self.lacunarity_window = lacunarity_window
        self.ehd_threshold = ehd_threshold
        self.angle_resolution = angle_resolution

    def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
        """
        Extract Local Binary Pattern features.

        Args:
            gray_image: Grayscale input image

        Returns:
            LBP feature map rescaled to uint8; all zeros if extraction fails
        """
        try:
            lbp = local_binary_pattern(
                gray_image,
                self.lbp_points,
                self.lbp_radius,
                method='uniform'
            )
            return self._convert_to_uint8(lbp)
        except Exception as e:
            logger.error(f"LBP extraction failed: {e}")
            return np.zeros_like(gray_image, dtype=np.uint8)

    def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
        """
        Extract Histogram of Oriented Gradients features.

        Only the HOG visualization image is kept; the descriptor vector
        returned alongside it is discarded.

        Args:
            gray_image: Grayscale input image

        Returns:
            HOG visualization rescaled to uint8; all zeros if extraction fails
        """
        try:
            _, vis = hog(
                gray_image,
                orientations=self.hog_orientations,
                pixels_per_cell=self.hog_pixels_per_cell,
                cells_per_block=self.hog_cells_per_block,
                visualize=True,
                feature_vector=True
            )
            return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
        except Exception as e:
            logger.error(f"HOG extraction failed: {e}")
            return np.zeros_like(gray_image, dtype=np.uint8)

    def compute_local_lacunarity(self, gray_image: np.ndarray, window_size: int) -> np.ndarray:
        """
        Compute local lacunarity as ``var / mean**2 + 1`` over a sliding window.

        Args:
            gray_image: Grayscale input image
            window_size: Size of the sliding window

        Returns:
            Local lacunarity map (float32); 0 where the local mean is ~0,
            and all zeros if the computation fails
        """
        try:
            arr = gray_image.astype(np.float32)
            m1 = ndimage.uniform_filter(arr, size=window_size)
            m2 = ndimage.uniform_filter(arr * arr, size=window_size)
            # E[x^2] - E[x]^2 can dip just below zero through float32
            # rounding; a variance is never negative, so clamp it.
            var = np.maximum(m2 - m1 * m1, 0.0)
            eps = 1e-6
            lac = var / (m1 * m1 + eps) + 1
            lac[m1 <= eps] = 0
            return lac
        except Exception as e:
            logger.error(f"Local lacunarity computation failed: {e}")
            return np.zeros_like(gray_image, dtype=np.float32)

    def compute_lacunarity_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute three types of lacunarity features.

        Args:
            gray_image: Grayscale input image

        Returns:
            Tuple of (lac1, lac2, lac3) uint8 lacunarity maps: single-window,
            mean over three window sizes, and DBC lacunarity (a copy of lac2
            when the DBC model cannot be imported)
        """
        try:
            # L1: Single window lacunarity
            lac1 = self.compute_local_lacunarity(gray_image, self.lacunarity_window)

            # L2: Multi-scale lacunarity (half, full and double window)
            scales = [max(3, self.lacunarity_window // 2), self.lacunarity_window, self.lacunarity_window * 2]
            lac2 = np.mean([
                self.compute_local_lacunarity(gray_image, s) for s in scales
            ], axis=0)

            # L3: DBC Lacunarity (if available)
            try:
                from ..models.dbc_lacunarity import DBC_Lacunarity
                # NOTE(review): the /255 scaling presumably expects 8-bit input — confirm.
                x = torch.from_numpy(gray_image.astype(np.float32) / 255.0)[None, None]
                layer = DBC_Lacunarity(window_size=self.lacunarity_window).eval()
                with torch.no_grad():
                    lac3 = layer(x).squeeze().cpu().numpy()
            except ImportError:
                logger.warning("DBC Lacunarity not available, using L2 as L3")
                lac3 = lac2.copy()

            return (
                self._convert_to_uint8(lac1),
                self._convert_to_uint8(lac2),
                self._convert_to_uint8(lac3)
            )
        except Exception as e:
            logger.error(f"Lacunarity features computation failed: {e}")
            # Three independent arrays so callers can modify one in place
            # without affecting the others.
            return tuple(np.zeros_like(gray_image, dtype=np.uint8) for _ in range(3))

    def generate_ehd_masks(self, mask_size: int = 3) -> np.ndarray:
        """
        Generate masks for Edge Histogram Descriptor.

        A Sobel-style gradient kernel is rotated in steps of
        ``self.angle_resolution`` degrees over [0, 360).

        Args:
            mask_size: Size of the mask; values below 3 become 3 and even
                values are bumped to the next odd size

        Returns:
            Float32 array of EHD masks with shape (n_angles, mask_size, mask_size)
        """
        if mask_size < 3:
            mask_size = 3
        if mask_size % 2 == 0:
            mask_size += 1

        # Base gradient mask
        Gy = np.outer([1, 0, -1], [1, 2, 1])

        # Expand if needed: each 'full' convolution with the 3x3 smoothing
        # kernel grows the mask by two pixels per side length.
        if mask_size > 3:
            expd = np.outer([1, 2, 1], [1, 2, 1])
            for _ in range((mask_size - 3) // 2):
                Gy = signal.convolve2d(expd, Gy, mode='full')

        # Generate masks for different angles
        angles = np.arange(0, 360, self.angle_resolution)
        masks = np.zeros((len(angles), mask_size, mask_size), dtype=np.float32)

        for i, angle in enumerate(angles):
            masks[i] = ndimage.rotate(Gy, angle, reshape=False, mode='nearest')

        return masks

    def _empty_ehd_features(self, gray_image: np.ndarray) -> np.ndarray:
        """Return an all-zero EHD feature stack shaped like a successful result.

        The channel count is the number of edge orientations plus one
        "no edge" channel. The spatial size is the image size shrunk by the
        un-padded dilated convolution, clamped at zero for tiny images.
        """
        try:
            n_channels = len(np.arange(0, 360, self.angle_resolution)) + 1
        except (ZeroDivisionError, ValueError, TypeError):
            # Unusable angle resolution: fall back to the default 8 angles + 1.
            n_channels = 9
        shrink = self._EHD_DILATION * (self._EHD_MASK_SIZE - 1)
        height, width = (tuple(np.shape(gray_image)) + (0, 0))[:2]
        return np.zeros(
            (n_channels, max(height - shrink, 0), max(width - shrink, 0)),
            dtype=np.float32
        )

    def extract_ehd_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Edge Histogram Descriptor features.

        Args:
            gray_image: Grayscale input image

        Returns:
            Tuple of (ehd_features, ehd_map). ``ehd_features`` has one channel
            per edge orientation plus a final "no edge" channel; ``ehd_map``
            is the per-pixel argmax over those channels. On failure both are
            all zeros.
        """
        try:
            # Generate masks
            masks = self.generate_ehd_masks(self._EHD_MASK_SIZE)

            # Convert to tensor
            # NOTE(review): the /255 scaling presumably expects 8-bit input — confirm.
            X = torch.from_numpy(gray_image.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0)
            masks_tensor = torch.tensor(masks).unsqueeze(1).float()

            # Convolve with masks (no padding, so the output is smaller
            # than the input image)
            edge_responses = F.conv2d(X, masks_tensor, dilation=self._EHD_DILATION)

            # Find maximum response; weak responses go to the extra
            # "no edge" bin whose index is the number of masks
            values, indices = torch.max(edge_responses, dim=1)
            indices[values < self.ehd_threshold] = masks.shape[0]

            # Pool features: local 5x5 frequency of each edge label
            feat_vect = []
            for edge in range(masks.shape[0] + 1):
                pooled = F.avg_pool2d(
                    (indices == edge).unsqueeze(1).float(),
                    kernel_size=5, stride=1, padding=2
                )
                feat_vect.append(pooled.squeeze(1))

            ehd_features = torch.stack(feat_vect, dim=1).squeeze(0).cpu().numpy()
            ehd_map = np.argmax(ehd_features, axis=0).astype(np.uint8)

            return ehd_features, ehd_map

        except Exception as e:
            logger.error(f"EHD features extraction failed: {e}")
            empty_features = self._empty_ehd_features(gray_image)
            # NOTE(review): the fallback map keeps the full image size, as
            # before, while a successful map has the shrunken size — confirm
            # which one callers expect.
            empty_map = np.zeros_like(gray_image, dtype=np.uint8)
            return empty_features, empty_map

    def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Extract all texture features from a grayscale image.

        Args:
            gray_image: Grayscale input image

        Returns:
            Dictionary with keys 'lbp', 'hog', 'lac1', 'lac2', 'lac3',
            'ehd_features' and 'ehd_map'; all-zero arrays if extraction fails
        """
        features = {}

        try:
            # LBP
            features['lbp'] = self.extract_lbp(gray_image)

            # HOG
            features['hog'] = self.extract_hog(gray_image)

            # Lacunarity
            lac1, lac2, lac3 = self.compute_lacunarity_features(gray_image)
            features['lac1'] = lac1
            features['lac2'] = lac2
            features['lac3'] = lac3

            # EHD
            ehd_features, ehd_map = self.extract_ehd_features(gray_image)
            features['ehd_features'] = ehd_features
            features['ehd_map'] = ehd_map

            logger.debug("All texture features extracted successfully")

        except Exception as e:
            logger.error(f"Texture feature extraction failed: {e}")
            # Return empty features
            features = {
                'lbp': np.zeros_like(gray_image, dtype=np.uint8),
                'hog': np.zeros_like(gray_image, dtype=np.uint8),
                'lac1': np.zeros_like(gray_image, dtype=np.uint8),
                'lac2': np.zeros_like(gray_image, dtype=np.uint8),
                'lac3': np.zeros_like(gray_image, dtype=np.uint8),
                'ehd_features': self._empty_ehd_features(gray_image),
                'ehd_map': np.zeros_like(gray_image, dtype=np.uint8)
            }

        return features

    def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
        """Min-max normalize an array to 0-255 and cast it to uint8.

        NaN and infinities are replaced by 0 first. A constant (or empty)
        array maps to all zeros.
        """
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        if arr.size == 0:
            return arr.astype(np.uint8)
        # np.ptp() is used because the ndarray.ptp() method no longer
        # exists in NumPy 2.
        span = np.ptp(arr)
        if span > 0:
            normalized = (arr - arr.min()) / (span + 1e-6) * 255
        else:
            normalized = np.zeros_like(arr)
        return np.clip(normalized, 0, 255).astype(np.uint8)

    @staticmethod
    def _fit_mask(mask: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
        """Return ``mask`` unchanged if it has ``shape``, else a resized binary copy.

        Resizing uses nearest-neighbour interpolation on a uint8 copy of
        ``mask > 0`` so that boolean or integer masks are handled as well.
        """
        if mask.shape == tuple(shape):
            return mask
        binary = (np.asarray(mask) > 0).astype(np.uint8)
        return cv2.resize(binary, (shape[1], shape[0]), interpolation=cv2.INTER_NEAREST)

    @staticmethod
    def _summarize(data: np.ndarray) -> Optional[Dict[str, float]]:
        """Return mean/std/min/max/median of the non-NaN values, or None if there are none."""
        valid_data = data[~np.isnan(data)]
        if len(valid_data) == 0:
            return None
        return {
            'mean': float(np.mean(valid_data)),
            'std': float(np.std(valid_data)),
            'min': float(np.min(valid_data)),
            'max': float(np.max(valid_data)),
            'median': float(np.median(valid_data))
        }

    def compute_texture_statistics(self, features: Dict[str, np.ndarray],
                                   mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
        """
        Compute statistics for texture features.

        Args:
            features: Dictionary of texture features
            mask: Optional mask to apply; pixels where ``mask > 0`` are kept.
                A mask whose shape differs from a 2-D feature map is resized
                to it with nearest-neighbour interpolation.

        Returns:
            Dictionary of feature statistics. 'ehd_features' maps to one
            statistics dict per channel ('channel_<i>'), omitting channels
            with no valid pixel; every other feature maps to a single
            statistics dict (all zeros when it has no valid pixel).
        """
        stats = {}

        for feature_name, feature_data in features.items():
            if feature_name == 'ehd_features':
                # Special handling for EHD features: statistics per channel
                keep = None
                if mask is not None and feature_data.shape[0] > 0:
                    # All channels share one shape, so fit the mask once.
                    keep = self._fit_mask(mask, feature_data.shape[1:]) > 0

                channel_stats = {}
                for i in range(feature_data.shape[0]):
                    channel = feature_data[i]
                    if keep is not None:
                        channel = np.where(keep, channel, np.nan)
                    summary = self._summarize(channel)
                    if summary is not None:
                        channel_stats[f'channel_{i}'] = summary
                stats[feature_name] = channel_stats
            else:
                # Regular 2D features
                masked_data = feature_data
                if mask is not None:
                    if mask.shape == feature_data.shape:
                        masked_data = np.where(mask > 0, feature_data, np.nan)
                    elif mask.ndim == 2 and feature_data.ndim == 2 and feature_data.size > 0:
                        # e.g. 'ehd_map', which is smaller than the image
                        keep = self._fit_mask(mask, feature_data.shape) > 0
                        masked_data = np.where(keep, feature_data, np.nan)

                summary = self._summarize(masked_data)
                if summary is None:
                    summary = {
                        'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0
                    }
                stats[feature_name] = summary

        return stats
|
sorghum_pipeline/features/vegetation.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vegetation index extraction for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles extraction of various vegetation indices
|
| 5 |
+
from multispectral data.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
from typing import Dict, Tuple, Optional, Any
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class VegetationIndexExtractor:
|
| 17 |
+
"""Extracts vegetation indices from spectral data."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
|
| 20 |
+
"""
|
| 21 |
+
Initialize vegetation index extractor.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
epsilon: Small value to avoid division by zero
|
| 25 |
+
soil_factor: Soil factor for certain indices
|
| 26 |
+
"""
|
| 27 |
+
# Coerce to float in case config passed strings like "1e-10"
|
| 28 |
+
try:
|
| 29 |
+
self.epsilon = float(epsilon)
|
| 30 |
+
except Exception:
|
| 31 |
+
self.epsilon = 1e-10
|
| 32 |
+
try:
|
| 33 |
+
self.soil_factor = float(soil_factor)
|
| 34 |
+
except Exception:
|
| 35 |
+
self.soil_factor = 0.16
|
| 36 |
+
|
| 37 |
+
# Define vegetation index formulas
|
| 38 |
+
self.index_formulas = {
|
| 39 |
+
"NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
|
| 40 |
+
"GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
|
| 41 |
+
"NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
|
| 42 |
+
"GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
|
| 43 |
+
"TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
|
| 44 |
+
"MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
|
| 45 |
+
"GRVI": lambda nir, green: nir / (green + self.epsilon),
|
| 46 |
+
"NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
|
| 47 |
+
"MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
|
| 48 |
+
"OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
|
| 49 |
+
"TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
|
| 50 |
+
"GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
|
| 51 |
+
# Requested additions and aliases
|
| 52 |
+
"GOSAVI": lambda nir, green: (nir - green) / (nir + green + 0.16 + self.epsilon),
|
| 53 |
+
"GDVI": lambda nir, green: nir - green,
|
| 54 |
+
"NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
|
| 55 |
+
"DSWI4": lambda green, red: green / (red + self.epsilon),
|
| 56 |
+
"CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
|
| 57 |
+
"LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
|
| 58 |
+
"CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
|
| 59 |
+
"MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
|
| 60 |
+
"MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
|
| 61 |
+
"MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon))),
|
| 62 |
+
# MTVI variants per request
|
| 63 |
+
"MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
|
| 64 |
+
"MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
|
| 65 |
+
"CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
|
| 66 |
+
"ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
|
| 67 |
+
"ARI2": lambda nir, green, red_edge: nir * (1.0 / (green + self.epsilon)) - nir * (1.0 / (red_edge + self.epsilon)),
|
| 68 |
+
"DVI": lambda nir, red: nir - red,
|
| 69 |
+
"WDVI": lambda nir, red, a=0.5: nir - a * red,
|
| 70 |
+
"SR": lambda nir, red: nir / (red + self.epsilon),
|
| 71 |
+
"MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
|
| 72 |
+
"PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
|
| 73 |
+
"GEMI": lambda nir, red: ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)) * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))) - ((red - 0.125) / (1 - red + self.epsilon)),
|
| 74 |
+
"ExR": lambda red, green: 1.3 * red - green,
|
| 75 |
+
"RI": lambda red, green: (red - green) / (red + green + self.epsilon),
|
| 76 |
+
"RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
|
| 77 |
+
"RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
|
| 78 |
+
"RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
|
| 79 |
+
"AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
|
| 80 |
+
"SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
|
| 81 |
+
"TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
|
| 82 |
+
"TCARIOSAVI": lambda red_edge, red, green, nir: (3 * (red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))) / (1 + 0.16 * ((nir - red) / (nir + red + 0.16 + self.epsilon))),
|
| 83 |
+
"CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
|
| 84 |
+
# Additional indices
|
| 85 |
+
"RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
|
| 86 |
+
"NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
|
| 87 |
+
"BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
|
| 88 |
+
"IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
|
| 89 |
+
"EVI2": lambda nir, red: 2.4 * (nir - red) / (nir + red + 1.0 + self.epsilon)
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# Define required bands for each index
|
| 93 |
+
self.index_bands = {
|
| 94 |
+
"NDVI": ["nir", "red"],
|
| 95 |
+
"GNDVI": ["nir", "green"],
|
| 96 |
+
"NDRE": ["nir", "red_edge"],
|
| 97 |
+
"GRNDVI": ["nir", "green", "red"],
|
| 98 |
+
"TNDVI": ["nir", "red"],
|
| 99 |
+
"MGRVI": ["green", "red"],
|
| 100 |
+
"GRVI": ["nir", "green"],
|
| 101 |
+
"NGRDI": ["green", "red"],
|
| 102 |
+
"MSAVI": ["nir", "red"],
|
| 103 |
+
"OSAVI": ["nir", "red"],
|
| 104 |
+
"TSAVI": ["nir", "red"],
|
| 105 |
+
"GSAVI": ["nir", "green"],
|
| 106 |
+
"GOSAVI": ["nir", "green"],
|
| 107 |
+
"GDVI": ["nir", "green"],
|
| 108 |
+
"NDWI": ["green", "nir"],
|
| 109 |
+
"DSWI4": ["green", "red"],
|
| 110 |
+
"CIRE": ["nir", "red_edge"],
|
| 111 |
+
"LCI": ["nir", "red_edge"],
|
| 112 |
+
"CIgreen": ["nir", "green"],
|
| 113 |
+
"MCARI": ["red_edge", "red", "green"],
|
| 114 |
+
"MCARI1": ["nir", "red", "green"],
|
| 115 |
+
"MCARI2": ["nir", "red", "green"],
|
| 116 |
+
"MTVI1": ["nir", "red", "green"],
|
| 117 |
+
"MTVI2": ["nir", "red", "green"],
|
| 118 |
+
"CVI": ["nir", "red", "green"],
|
| 119 |
+
"ARI": ["green", "red_edge"],
|
| 120 |
+
"ARI2": ["nir", "green", "red_edge"],
|
| 121 |
+
"DVI": ["nir", "red"],
|
| 122 |
+
"WDVI": ["nir", "red"],
|
| 123 |
+
"SR": ["nir", "red"],
|
| 124 |
+
"MSR": ["nir", "red"],
|
| 125 |
+
"PVI": ["nir", "red"],
|
| 126 |
+
"GEMI": ["nir", "red"],
|
| 127 |
+
"ExR": ["red", "green"],
|
| 128 |
+
"RI": ["red", "green"],
|
| 129 |
+
"RRI1": ["nir", "red_edge"],
|
| 130 |
+
"RRI2": ["red_edge", "red"],
|
| 131 |
+
"RRI": ["nir", "red_edge"],
|
| 132 |
+
"AVI": ["nir", "red"],
|
| 133 |
+
"SIPI2": ["nir", "green", "red"],
|
| 134 |
+
"TCARI": ["red_edge", "red", "green"],
|
| 135 |
+
"TCARIOSAVI": ["red_edge", "red", "green", "nir"],
|
| 136 |
+
"CCCI": ["nir", "red_edge", "red"],
|
| 137 |
+
"RDVI": ["nir", "red"],
|
| 138 |
+
"NLI": ["nir", "red"],
|
| 139 |
+
"BIXS": ["green", "red"],
|
| 140 |
+
"IPVI": ["nir", "red"],
|
| 141 |
+
"EVI2": ["nir", "red"]
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
|
| 145 |
+
mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
|
| 146 |
+
"""
|
| 147 |
+
Compute vegetation indices from spectral data.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
spectral_stack: Dictionary of spectral bands
|
| 151 |
+
mask: Binary mask for the plant
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Dictionary of vegetation indices with values and statistics
|
| 155 |
+
"""
|
| 156 |
+
indices = {}
|
| 157 |
+
|
| 158 |
+
for index_name, formula in self.index_formulas.items():
|
| 159 |
+
try:
|
| 160 |
+
# Get required bands
|
| 161 |
+
required_bands = self.index_bands.get(index_name, [])
|
| 162 |
+
|
| 163 |
+
# Check if all required bands are available
|
| 164 |
+
if not all(band in spectral_stack for band in required_bands):
|
| 165 |
+
logger.warning(f"Skipping {index_name}: missing required bands")
|
| 166 |
+
continue
|
| 167 |
+
|
| 168 |
+
# Extract band data as float arrays
|
| 169 |
+
band_data = []
|
| 170 |
+
for band in required_bands:
|
| 171 |
+
arr = spectral_stack[band]
|
| 172 |
+
# Ensure numeric float np.ndarray
|
| 173 |
+
if isinstance(arr, np.ndarray):
|
| 174 |
+
arr = arr.squeeze(-1)
|
| 175 |
+
arr = np.asarray(arr, dtype=np.float64)
|
| 176 |
+
band_data.append(arr)
|
| 177 |
+
|
| 178 |
+
# Compute index (ensure float math)
|
| 179 |
+
index_values = formula(*band_data).astype(np.float64)
|
| 180 |
+
|
| 181 |
+
# Apply mask
|
| 182 |
+
if mask is not None:
|
| 183 |
+
binary_mask = (np.asarray(mask).astype(np.int32) > 0)
|
| 184 |
+
masked_values = np.where(binary_mask, index_values, np.nan)
|
| 185 |
+
else:
|
| 186 |
+
masked_values = index_values
|
| 187 |
+
|
| 188 |
+
# Compute statistics
|
| 189 |
+
valid_values = masked_values[~np.isnan(masked_values)]
|
| 190 |
+
if len(valid_values) > 0:
|
| 191 |
+
stats = {
|
| 192 |
+
'mean': float(np.mean(valid_values)),
|
| 193 |
+
'std': float(np.std(valid_values)),
|
| 194 |
+
'min': float(np.min(valid_values)),
|
| 195 |
+
'max': float(np.max(valid_values)),
|
| 196 |
+
'median': float(np.median(valid_values)),
|
| 197 |
+
'q25': float(np.percentile(valid_values, 25)),
|
| 198 |
+
'q75': float(np.percentile(valid_values, 75)),
|
| 199 |
+
'nan_fraction': float(np.isnan(masked_values).sum() / masked_values.size)
|
| 200 |
+
}
|
| 201 |
+
else:
|
| 202 |
+
stats = {
|
| 203 |
+
'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
|
| 204 |
+
'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
indices[index_name] = {
|
| 208 |
+
'values': masked_values,
|
| 209 |
+
'statistics': stats
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
logger.debug(f"Computed {index_name}")
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
logger.error(f"Failed to compute {index_name}: {e}")
|
| 216 |
+
continue
|
| 217 |
+
|
| 218 |
+
return indices
|
| 219 |
+
|
| 220 |
+
def create_vegetation_index_image(self, index_values: np.ndarray,
|
| 221 |
+
colormap: str = 'RdYlGn',
|
| 222 |
+
vmin: Optional[float] = None,
|
| 223 |
+
vmax: Optional[float] = None) -> np.ndarray:
|
| 224 |
+
"""
|
| 225 |
+
Create visualization image for vegetation index.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
index_values: Vegetation index values
|
| 229 |
+
colormap: Matplotlib colormap name
|
| 230 |
+
vmin: Minimum value for normalization
|
| 231 |
+
vmax: Maximum value for normalization
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
RGB image array
|
| 235 |
+
"""
|
| 236 |
+
try:
|
| 237 |
+
import matplotlib.pyplot as plt
|
| 238 |
+
import matplotlib.cm as cm
|
| 239 |
+
from matplotlib.colors import Normalize
|
| 240 |
+
|
| 241 |
+
# Determine value range
|
| 242 |
+
valid_values = index_values[~np.isnan(index_values)]
|
| 243 |
+
if len(valid_values) == 0:
|
| 244 |
+
return np.zeros((*index_values.shape, 3), dtype=np.uint8)
|
| 245 |
+
|
| 246 |
+
if vmin is None:
|
| 247 |
+
vmin = np.min(valid_values)
|
| 248 |
+
if vmax is None:
|
| 249 |
+
vmax = np.max(valid_values)
|
| 250 |
+
|
| 251 |
+
# Normalize values
|
| 252 |
+
norm = Normalize(vmin=vmin, vmax=vmax)
|
| 253 |
+
cmap = cm.get_cmap(colormap)
|
| 254 |
+
|
| 255 |
+
# Apply colormap
|
| 256 |
+
rgba_img = cmap(norm(index_values))
|
| 257 |
+
rgba_img[np.isnan(index_values)] = [1, 1, 1, 1] # White for NaN
|
| 258 |
+
|
| 259 |
+
# Convert to RGB uint8
|
| 260 |
+
rgb_img = (rgba_img[:, :, :3] * 255).astype(np.uint8)
|
| 261 |
+
|
| 262 |
+
return rgb_img
|
| 263 |
+
|
| 264 |
+
except Exception as e:
|
| 265 |
+
logger.error(f"Failed to create vegetation index image: {e}")
|
| 266 |
+
return np.zeros((*index_values.shape, 3), dtype=np.uint8)
|
| 267 |
+
|
| 268 |
+
def get_available_indices(self) -> list:
    """Return the names registered in ``self.index_formulas`` as a new list."""
    return [index_name for index_name in self.index_formulas.keys()]
|
| 271 |
+
|
| 272 |
+
def get_index_requirements(self, index_name: str) -> list:
    """
    Look up the bands needed to compute one vegetation index.

    Args:
        index_name: Name of the vegetation index

    Returns:
        The entry stored in ``self.index_bands`` for ``index_name``,
        or an empty list when the index is not registered there
    """
    band_table = self.index_bands
    if index_name in band_table:
        return band_table[index_name]
    return []
|
| 283 |
+
|
| 284 |
+
def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
    """
    Check that a spectral stack can be used for vegetation index computation.

    The stack is accepted when it is non-empty, contains the 'nir', 'red',
    'green' and 'red_edge' bands, and every entry has the same shape.

    Args:
        spectral_stack: Dictionary of spectral bands

    Returns:
        True if valid, False otherwise
    """
    if not spectral_stack:
        return False

    for band in ('nir', 'red', 'green', 'red_edge'):
        if band not in spectral_stack:
            logger.warning("Missing required spectral bands")
            return False

    # Every entry (not only the required bands) is compared to the first one.
    band_shapes = [band_array.shape for band_array in spectral_stack.values()]
    reference_shape = band_shapes[0]
    if any(not (shape == reference_shape) for shape in band_shapes):
        logger.warning("Inconsistent spectral band shapes")
        return False

    return True
|
sorghum_pipeline/models/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model definitions for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This package contains neural network models and other
|
| 5 |
+
computational models used in the pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .dbc_lacunarity import DBC_Lacunarity
|
| 9 |
+
|
| 10 |
+
__all__ = ["DBC_Lacunarity"]
|
sorghum_pipeline/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (438 Bytes). View file
|
|
|
sorghum_pipeline/models/__pycache__/dbc_lacunarity.cpython-312.pyc
ADDED
|
Binary file (4.14 kB). View file
|
|
|
sorghum_pipeline/models/dbc_lacunarity.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DBC Lacunarity model for texture analysis.
|
| 3 |
+
|
| 4 |
+
This module implements the Differential Box Counting (DBC) method
|
| 5 |
+
for computing lacunarity features.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DBC_Lacunarity(nn.Module):
    """
    Differential Box Counting (DBC) lacunarity layer.

    The layer has no trainable parameters: it rescales the input with a
    tanh, takes sliding-window maxima and minima, and combines them into a
    per-window lacunarity-style response used for texture analysis.
    """

    def __init__(self, model_name: str = 'Net', window_size: int = 3, eps: float = 1e-6):
        """
        Initialize DBC Lacunarity model.

        Args:
            model_name: Name of the model (stored and reported only)
            window_size: Side length of the sliding pooling window
            eps: Small value added to denominators to avoid division by zero
        """
        super().__init__()
        self.model_name = model_name
        self.window_size = window_size
        self.eps = eps
        self.r = 1  # box size used when counting boxes
        self.num_output_channels = 3
        self.normalize = nn.Tanh()
        # Stride 1 without padding: each spatial dim shrinks by window_size - 1.
        self.max_pool = nn.MaxPool2d(kernel_size=self.window_size, stride=1)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the DBC Lacunarity model.

        Args:
            image: Input image tensor [B, C, H, W]

        Returns:
            Lacunarity features tensor, one value per pooling window
        """
        # tanh maps to (-1, 1); shift and scale into the 0-255 range.
        unit_range = (self.normalize(image) + 1) / 2
        scaled = unit_range * 255

        # Sliding-window maximum, and minimum via max-pooling the negation.
        window_max = self.max_pool(scaled)
        window_min = -self.max_pool(-scaled)

        # Box counts per window.
        box_size = self.r + self.eps
        nr = torch.ceil(window_max / box_size) - torch.ceil(window_min / box_size) - 1

        # Mr is a single scalar: the sum runs over every element of nr,
        # i.e. across the whole batch and all channels.
        Mr = nr.sum()
        Q_mr = nr / (self.window_size - self.r + 1)

        numerator = (Mr ** 2) * Q_mr
        denominator = (Mr * Q_mr + self.eps) ** 2
        return numerator / denominator

    def compute_lacunarity(self, image: torch.Tensor) -> torch.Tensor:
        """
        Run the forward pass with gradient tracking disabled.

        Args:
            image: Input image tensor [1, 1, H, W]

        Returns:
            Lacunarity tensor
        """
        with torch.no_grad():
            lacunarity = self.forward(image)
        return lacunarity

    def get_model_info(self) -> dict:
        """
        Get model information.

        Returns:
            Dictionary containing model parameters
        """
        info = {}
        info['model_name'] = self.model_name
        info['window_size'] = self.window_size
        info['eps'] = self.eps
        info['r'] = self.r
        info['num_output_channels'] = self.num_output_channels
        return info
|
sorghum_pipeline/output/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Output management modules for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This package contains output functionality including:
|
| 5 |
+
- Result saving
|
| 6 |
+
- Visualization generation
|
| 7 |
+
- Report creation
|
| 8 |
+
- Data export
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .manager import OutputManager
|
| 12 |
+
|
| 13 |
+
__all__ = ["OutputManager"]
|
sorghum_pipeline/output/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (470 Bytes). View file
|
|
|
sorghum_pipeline/output/__pycache__/manager.cpython-312.pyc
ADDED
|
Binary file (40.9 kB). View file
|
|
|
sorghum_pipeline/output/manager.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Output manager for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles saving results, generating visualizations,
|
| 5 |
+
and creating reports.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
# Use a non-GUI backend to avoid segmentation faults in headless runs
|
| 14 |
+
try:
|
| 15 |
+
import matplotlib
|
| 16 |
+
if os.environ.get('MPLBACKEND') is None:
|
| 17 |
+
matplotlib.use('Agg')
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import matplotlib.cm as cm
|
| 20 |
+
from matplotlib.colors import Normalize
|
| 21 |
+
except Exception:
|
| 22 |
+
# Fallback safe imports (should not happen normally)
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
import matplotlib.cm as cm
|
| 25 |
+
from matplotlib.colors import Normalize
|
| 26 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 29 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 30 |
+
import pandas as pd
|
| 31 |
+
import logging
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class OutputManager:
|
| 37 |
+
"""Manages output generation and saving."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, output_folder: str, settings: Any):
    """
    Initialize output manager.

    Besides storing its arguments, the constructor reads three optional
    environment variables (``FAST_OUTPUT``, ``FAST_SAVE_WORKERS``,
    ``PNG_COMPRESSION``), pins native thread pools to a single thread
    where not already configured, and creates the base output folder.

    Args:
        output_folder: Base output folder
        settings: Output settings from config
    """
    self.output_folder = Path(output_folder)
    self.settings = settings

    def env_int(name: str, default: int) -> int:
        """Read an integer environment variable; bad values yield ``default``."""
        try:
            return int(os.environ.get(name, default))
        except (TypeError, ValueError):
            return default

    # Fast mode and parallel save controls. Each source is parsed on its own
    # so a malformed FAST_OUTPUT value no longer discards settings.fast_mode.
    self.fast_mode: bool = bool(env_int('FAST_OUTPUT', 0)) or bool(getattr(settings, 'fast_mode', False))
    self.max_workers: int = env_int('FAST_SAVE_WORKERS', 4)
    # PNG compression is defined for 0-9 (1 is fast); clamp out-of-range input.
    self.png_compression: int = min(9, max(0, env_int('PNG_COMPRESSION', 1)))

    # Reduce thread usage to lower risk of native library segfaults.
    # setdefault keeps any value the user has already exported.
    for variable in ('OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS',
                     'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS'):
        os.environ.setdefault(variable, '1')
    try:
        cv2.setNumThreads(1)
    except Exception as e:
        # Best effort only: not being able to pin OpenCV threads is not fatal.
        logger.debug(f"Could not limit OpenCV threads: {e}")

    # Create base directories
    self.output_folder.mkdir(parents=True, exist_ok=True)
|
| 79 |
+
|
| 80 |
+
def _imwrite_fast(self, dest: Path, img: np.ndarray) -> None:
    """
    Write ``img`` to ``dest`` using the configured PNG compression level.

    If the call with explicit encoder parameters raises, the write is
    retried once with OpenCV's defaults. ``cv2.imwrite`` signals most
    failures by returning False rather than raising, so the return value
    is checked and a failed write is logged instead of passing unnoticed.

    Args:
        dest: Destination file path
        img: Image array to encode
    """
    target = str(dest)
    try:
        written = cv2.imwrite(target, img, [cv2.IMWRITE_PNG_COMPRESSION, int(self.png_compression)])
    except Exception:
        written = cv2.imwrite(target, img)
    if not written:
        logger.error(f"Failed to write image: {target}")
|
| 85 |
+
|
| 86 |
+
def create_output_directories(self) -> None:
    """Make sure the base output folder exists.

    Only the root folder is created here. Subdirectories such as
    'analysis' are deliberately NOT created at the root; they are
    created inside each plant's own directory when results are saved.
    """
    root = self.output_folder
    root.mkdir(exist_ok=True, parents=True)
|
| 93 |
+
|
| 94 |
+
def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
    """
    Save all results for a single plant.

    Results are written below ``<output_folder>/<date>/<plant>``. The
    optional frame component of the key is not part of that path.
    Failures are logged and never propagate to the caller.

    Args:
        plant_key: Plant identifier (e.g., "2025_02_05_plant1_frame8")
        plant_data: Plant data dictionary
    """
    try:
        # Parse plant key: three date components followed by the plant name.
        parts = plant_key.split('_')
        if len(parts) < 4:
            logger.error(
                f"Failed to save results for {plant_key}: "
                "expected a key like '<yyyy>_<mm>_<dd>_<plant>[_<frame>]'"
            )
            return
        date_key = "_".join(parts[:3])
        plant_name = parts[3]

        # Create plant-specific directory
        plant_dir = self.output_folder / date_key / plant_name
        plant_dir.mkdir(parents=True, exist_ok=True)

        # Each saver checks the relevant settings flag itself.
        self._save_segmentation_results(plant_dir, plant_name, plant_data)
        self._save_texture_features(plant_dir, plant_data)
        self._save_vegetation_indices(plant_dir, plant_data)
        self._save_morphology_features(plant_dir, plant_data)
        self._save_analysis_plots(plant_dir, plant_data)
        self._save_metadata(plant_dir, plant_key, plant_data)

        logger.debug(f"Results saved for {plant_key}")

    except Exception as e:
        logger.error(f"Failed to save results for {plant_key}: {e}")
|
| 135 |
+
|
| 136 |
+
def _save_segmentation_results(self, plant_dir: Path, plant_name: str, plant_data: Dict[str, Any]) -> None:
    """
    Save segmentation results (base image, masks, overlay and maskouts).

    Nothing is written when ``settings.save_images`` is false. Images are
    collected first and then written, in parallel when more than one
    worker is configured. Individual failures are logged and do not stop
    the remaining images from being written.

    Args:
        plant_dir: Directory of the plant being saved
        plant_name: Plant name; plants 13-16 always use the feature composite
        plant_data: Plant data dictionary
    """
    if not self.settings.save_images:
        return

    seg_dir = plant_dir / self.settings.segmentation_dir
    seg_dir.mkdir(exist_ok=True)

    try:
        tasks: List[Tuple[Path, np.ndarray]] = []

        # Choose which base image to present in original/overlay.
        # Plants 13-16 are special-cased per user requirement; everything
        # else follows the OUTPUT_USE_FEATURE_IMAGE env override.
        use_feature_image = plant_name in {'plant13', 'plant14', 'plant15', 'plant16'}
        if not use_feature_image:
            try:
                use_feature_image = bool(int(os.environ.get('OUTPUT_USE_FEATURE_IMAGE', '0')))
            except ValueError:
                use_feature_image = False

        if use_feature_image:
            base_image = plant_data.get('composite', plant_data.get('segmentation_composite'))
        else:
            base_image = plant_data.get('segmentation_composite', plant_data.get('composite'))

        mask = plant_data.get('mask')
        mask3 = plant_data.get('mask3')
        bria_mask = plant_data.get('original_mask')
        masked_composite = plant_data.get('masked_composite')

        if base_image is not None:
            tasks.append((seg_dir / 'original.png', base_image))
        # Only real arrays are queued: a None entry cannot be encoded.
        if isinstance(mask, np.ndarray):
            tasks.append((seg_dir / 'mask.png', mask))
        if isinstance(mask3, np.ndarray):
            tasks.append((seg_dir / 'mask3.png', mask3))
        # Save the BRIA-generated mask (if present before overrides) as mask2.png
        if isinstance(bria_mask, np.ndarray):
            tasks.append((seg_dir / 'mask2.png', bria_mask))
        if base_image is not None and 'mask' in plant_data:
            # _create_overlay returns the image unchanged for a None mask.
            tasks.append((seg_dir / 'overlay.png', self._create_overlay(base_image, mask)))
        if isinstance(masked_composite, np.ndarray):
            tasks.append((seg_dir / 'masked_composite.png', masked_composite))

        # Create white-background maskouts. Each one is guarded separately
        # so that a failure for one mask does not skip the others.
        if base_image is not None:
            maskout_sources = (
                ('maskout_external.png', mask),
                ('maskout_bria.png', bria_mask),
                ('maskout_mask3.png', mask3),
            )
            for filename, source_mask in maskout_sources:
                if not isinstance(source_mask, np.ndarray):
                    continue
                try:
                    maskout = self._create_maskout_white_background(base_image, source_mask)
                    tasks.append((seg_dir / filename, maskout))
                except Exception as _e:
                    logger.debug(f"Failed to create double maskouts: {_e}")

        if self.max_workers > 1 and len(tasks) > 1:
            with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
                futures = {ex.submit(self._imwrite_fast, p, img): p for p, img in tasks}
                for future in as_completed(futures):
                    # Retrieve worker errors; otherwise they vanish silently.
                    error = future.exception()
                    if error is not None:
                        logger.error(f"Failed to write {futures[future]}: {error}")
        else:
            for p, img in tasks:
                try:
                    self._imwrite_fast(p, img)
                except Exception as e:
                    logger.error(f"Failed to write {p}: {e}")
    except Exception as e:
        logger.error(f"Failed to save segmentation results: {e}")
|
| 198 |
+
|
| 199 |
+
def _save_texture_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
    """
    Save texture feature maps, a summary plot and statistics per band.

    Nothing is written when ``settings.save_images`` is false or the plant
    has no 'texture_features' entry. For every band with a 'features'
    mapping, one PNG per feature map is written into a band subdirectory
    ('ehd_features' is split into one PNG per channel).

    Args:
        plant_dir: Directory of the plant being saved
        plant_data: Plant data dictionary
    """
    if not self.settings.save_images or 'texture_features' not in plant_data:
        return

    texture_dir = plant_dir / self.settings.texture_dir
    texture_dir.mkdir(exist_ok=True)

    def save_feature_png(feature_name: str, values: Any, dest: Path, cmap_name: str = 'viridis') -> None:
        """Render one feature map to ``dest``; falls back to a plain grayscale PNG on error."""
        arr = None
        try:
            arr = np.asarray(values)
            if arr.ndim == 3 and arr.shape[-1] == 3:
                # Treated as an RGB image: convert to OpenCV's BGR order and write directly.
                self._imwrite_fast(dest, cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_RGB2BGR))
                return
            arr = arr.astype(np.float64)
            if self.fast_mode:
                # Fast path: simple normalization, no matplotlib
                self._imwrite_fast(dest, self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0)))
                return

            fig, ax = plt.subplots(figsize=(5, 5))
            try:
                ax.set_axis_off()
                ax.set_facecolor('white')
                # Invalid (NaN/inf) pixels are masked so they are not colored.
                im = ax.imshow(np.ma.masked_invalid(arr), cmap=cmap_name)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="2%", pad=0.02)
                cbar = fig.colorbar(im, cax=cax, orientation='vertical')
                cbar.set_label(feature_name, fontsize=7)
                cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
                if hasattr(cbar, 'outline') and cbar.outline is not None:
                    cbar.outline.set_linewidth(0.5)
                fig.tight_layout()
                fig.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
            finally:
                # Always release the figure, even when rendering fails.
                plt.close(fig)
        except Exception as e:
            logger.error(f"Failed to save texture feature image for {feature_name}: {e}")
            if arr is None:
                return
            try:
                self._imwrite_fast(dest, self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0)))
            except Exception as fallback_error:
                logger.debug(f"Fallback save failed for {feature_name}: {fallback_error}")

    try:
        texture_features = plant_data['texture_features']

        for band, band_data in texture_features.items():
            if 'features' not in band_data:
                continue

            band_dir = texture_dir / band
            band_dir.mkdir(exist_ok=True)

            features = band_data['features']

            # Collect individual feature maps: (name, array, destination, colormap)
            items: List[Tuple[str, np.ndarray, Path, str]] = []
            for feature_name, feature_map in features.items():
                if feature_name == 'ehd_features':
                    for i in range(feature_map.shape[0]):
                        channel = feature_map[i]
                        if isinstance(channel, np.ndarray) and channel.size > 0:
                            items.append((f'ehd_channel_{i}', channel, band_dir / f'ehd_channel_{i}.png', 'magma'))
                elif isinstance(feature_map, np.ndarray) and feature_map.size > 0:
                    if feature_name in ('lbp', 'hog'):
                        cmap_choice = 'gray'
                    elif feature_name.startswith('lac'):
                        cmap_choice = 'plasma'
                    else:
                        cmap_choice = 'viridis'
                    items.append((feature_name, feature_map, band_dir / f'{feature_name}.png', cmap_choice))

            # pyplot is not thread-safe, so only the matplotlib-free fast
            # path is parallelized; everything else is saved sequentially.
            if self.fast_mode and self.max_workers > 1 and len(items) > 1:
                with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
                    futures = [ex.submit(save_feature_png, n, m, p, c) for (n, m, p, c) in items]
                    for future in as_completed(futures):
                        future.result()
            else:
                for (n, m, p, c) in items:
                    save_feature_png(n, m, p, c)

            # Create feature summary plot
            self._create_texture_summary_plot(band_dir, features, band)

            # Save texture statistics if available
            if 'statistics' in band_data and isinstance(band_data['statistics'], dict):
                try:
                    with open(band_dir / 'texture_statistics.json', 'w') as f:
                        json.dump(band_data['statistics'], f, indent=2)
                except Exception as e:
                    logger.error(f"Failed to save texture statistics for {band}: {e}")

    except Exception as e:
        logger.error(f"Failed to save texture features: {e}")
|
| 289 |
+
|
| 290 |
+
def _save_vegetation_indices(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
    """
    Save vegetation index images and statistics.

    Nothing is written when ``settings.save_images`` is false or the plant
    has no 'vegetation_indices' entry. For each index that carries a
    non-empty 'values' array a PNG is written; per-index statistics and an
    aggregated ``vegetation_statistics.json`` are written as JSON. The
    summary plot is skipped in fast mode.

    Args:
        plant_dir: Directory of the plant being saved
        plant_data: Plant data dictionary
    """
    if not self.settings.save_images or 'vegetation_indices' not in plant_data:
        return

    veg_dir = plant_dir / self.settings.vegetation_dir
    veg_dir.mkdir(exist_ok=True)

    # Colormap and range settings per index: (colormap, vmin, vmax).
    # A None bound is taken from the data.
    index_cmap_settings = {
        "NDVI": (cm.RdYlGn, -1, 1),
        "GNDVI": (cm.RdYlGn, -1, 1),
        "NDRE": (cm.RdYlGn, -1, 1),
        "GRNDVI": (cm.RdYlGn, -1, 1),
        "TNDVI": (cm.RdYlGn, -1, 1),
        "MGRVI": (cm.RdYlGn, -1, 1),
        "GRVI": (cm.RdYlGn, -1, 1),
        "NGRDI": (cm.RdYlGn, -1, 1),
        "MSAVI": (cm.YlGn, 0, 1),
        "OSAVI": (cm.YlGn, 0, 1),
        "TSAVI": (cm.YlGn, 0, 1),
        "GSAVI": (cm.YlGn, 0, 1),
        "NDWI": (cm.Blues, -1, 1),
        "DSWI4": (cm.Blues, -1, 1),
        "CIRE": (cm.viridis, 0, 10),
        "LCI": (cm.viridis, 0, 5),
        "CIgreen": (cm.viridis, 0, 5),
        "MCARI": (cm.viridis, 0, 1.5),
        "MCARI1": (cm.viridis, 0, 1.5),
        "MCARI2": (cm.viridis, 0, 1.5),
        "CVI": (cm.plasma, 0, 10),
        "TCARI": (cm.viridis, 0, 1),
        "TCARIOSAVI": (cm.viridis, 0, 1),
        "AVI": (cm.magma, 0, 1),
        "SIPI2": (cm.inferno, 0, 1),
        "ARI": (cm.magma, 0, 1),
        "ARI2": (cm.magma, 0, 1),
        "DVI": (cm.Greens, 0, None),
        "WDVI": (cm.Greens, 0, None),
        "SR": (cm.viridis, 0, 10),
        "MSR": (cm.viridis, 0, 10),
        "PVI": (cm.cividis, None, None),
        "GEMI": (cm.cividis, 0, 1),
        "ExR": (cm.Reds, -1, 1),
        "RI": (cm.Reds, 0, None),
        "RRI1": (cm.Reds, 0, 1),
    }

    def save_index_png(index_name: str, values: Any, dest: Path) -> None:
        """Render one index to ``dest``; falls back to a plain grayscale PNG on error."""
        arr = None
        try:
            if isinstance(values, (float, int)):
                # Scalars carry no spatial information to draw.
                return
            arr = np.asarray(values, dtype=np.float64)
            if self.fast_mode:
                self._imwrite_fast(dest, self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0)))
                return

            # Unknown indices use viridis with a data-driven range. The data
            # range is only computed when a bound is actually missing.
            cmap, vmin, vmax = index_cmap_settings.get(index_name, (cm.viridis, None, None))
            if vmin is None:
                vmin = np.nanmin(arr)
            if vmax is None:
                vmax = np.nanmax(arr)
            # Matplotlib rejects vmin > vmax, and a zero-width or non-finite
            # range cannot be normalized; use a neutral range instead.
            if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin >= vmax:
                vmin, vmax = 0.0, 1.0

            fig, ax = plt.subplots(figsize=(5, 5))
            try:
                ax.set_axis_off()
                ax.set_facecolor('white')
                im = ax.imshow(np.ma.masked_invalid(arr), cmap=cmap, vmin=vmin, vmax=vmax)
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="2%", pad=0.02)
                cbar = fig.colorbar(im, cax=cax, orientation='vertical')
                cbar.set_label(index_name, fontsize=7)
                cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
                if hasattr(cbar, 'outline') and cbar.outline is not None:
                    cbar.outline.set_linewidth(0.5)
                fig.tight_layout()
                fig.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
            finally:
                # Always release the figure, even when rendering fails.
                plt.close(fig)
        except Exception as e:
            logger.error(f"Failed to save vegetation index image for {index_name}: {e}")
            if arr is None:
                return
            try:
                # Fallback simple normalization
                self._imwrite_fast(dest, self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0)))
            except Exception as fallback_error:
                logger.debug(f"Fallback save failed for {index_name}: {fallback_error}")

    try:
        vegetation_indices = plant_data['vegetation_indices']

        items_png: List[Tuple[str, np.ndarray, Path]] = []
        items_stats: List[Tuple[Path, Dict[str, Any]]] = []
        for index_name, index_data in vegetation_indices.items():
            if isinstance(index_data, dict) and 'values' in index_data:
                values = index_data['values']
                if isinstance(values, np.ndarray) and values.size > 0:
                    items_png.append((index_name, values, veg_dir / f'{index_name}.png'))
                stats = index_data.get('statistics')
                if isinstance(stats, dict):
                    items_stats.append((veg_dir / f'{index_name}_stats.json', stats))

        # Save sequentially to avoid matplotlib thread-safety issues
        for (name, arr, dest) in items_png:
            save_index_png(name, arr, dest)
        for (path, stats) in items_stats:
            try:
                with open(path, 'w') as f:
                    json.dump(stats, f, indent=2)
            except Exception as e:
                logger.error(f"Failed to save stats for {path.name.split('.')[0]}: {e}")

        # Create vegetation index summary (skip in fast mode)
        if not self.fast_mode:
            self._create_vegetation_summary_plot(veg_dir, vegetation_indices)

        # Save aggregated vegetation statistics
        try:
            all_stats = {k: v.get('statistics', {}) for k, v in vegetation_indices.items() if isinstance(v, dict)}
            with open(veg_dir / 'vegetation_statistics.json', 'w') as f:
                json.dump(all_stats, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save aggregated vegetation statistics: {e}")

    except Exception as e:
        logger.error(f"Failed to save vegetation indices: {e}")
|
| 417 |
+
|
| 418 |
+
def _save_morphology_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
    """
    Save morphological images and traits.

    Nothing is written when ``settings.save_images`` is false or the plant
    has no 'morphology_features' entry. Non-empty arrays under 'images'
    are written as ``<name>.png``; 'traits' is dumped to ``traits.json``.

    Args:
        plant_dir: Directory of the plant being saved
        plant_data: Plant data dictionary
    """
    if not self.settings.save_images or 'morphology_features' not in plant_data:
        return

    morph_dir = plant_dir / self.settings.morphology_dir
    morph_dir.mkdir(exist_ok=True)

    try:
        morphology_features = plant_data['morphology_features']

        # Save morphological images through the shared writer so they use
        # the same PNG compression setting as every other image output.
        if 'images' in morphology_features:
            for image_name, image_data in morphology_features['images'].items():
                if isinstance(image_data, np.ndarray) and image_data.size > 0:
                    self._imwrite_fast(morph_dir / f'{image_name}.png', image_data)

        # Save morphological data
        if 'traits' in morphology_features:
            traits = morphology_features['traits']
            with open(morph_dir / 'traits.json', 'w') as f:
                json.dump(traits, f, indent=2)

    except Exception as e:
        logger.error(f"Failed to save morphology features: {e}")
|
| 443 |
+
|
| 444 |
+
def _save_analysis_plots(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
    """
    Write the comprehensive analysis plot for one plant.

    Plots are produced only when ``settings.save_plots`` is enabled and
    fast mode is off. Failures are logged, not raised.
    """
    plots_enabled = self.settings.save_plots and not self.fast_mode
    if plots_enabled:
        analysis_dir = plant_dir / self.settings.analysis_dir
        analysis_dir.mkdir(exist_ok=True)

        try:
            # Single figure combining the available results for this plant.
            self._create_comprehensive_analysis_plot(analysis_dir, plant_data)
        except Exception as e:
            logger.error(f"Failed to save analysis plots: {e}")
|
| 458 |
+
|
| 459 |
+
def _save_metadata(self, plant_dir: Path, plant_key: str, plant_data: Dict[str, Any]) -> None:
    """Write ``metadata.json`` describing what was produced for one plant.

    The file records the plant key, a creation timestamp, the composite image
    shape (``null`` when no composite is available), whether a mask exists and
    which feature groups are present in ``plant_data``.  Skipped when
    ``settings.save_metadata`` is false.

    Args:
        plant_dir: Output directory of the plant.
        plant_key: Identifier of the plant/frame being saved.
        plant_data: Per-plant result dictionary.

    Errors raised while writing are logged, not propagated.
    """
    if not self.settings.save_metadata:
        return

    try:
        # A 'composite' key holding None (or any object without a shape) must
        # not prevent the rest of the metadata from being written.
        shape = getattr(plant_data.get('composite'), 'shape', None)

        metadata = {
            'plant_key': plant_key,
            'timestamp': pd.Timestamp.now().isoformat(),
            'image_shape': [int(dim) for dim in shape] if shape is not None else None,
            'has_mask': plant_data.get('mask') is not None,
            'features_available': {
                'texture': 'texture_features' in plant_data,
                'vegetation': 'vegetation_indices' in plant_data,
                'morphology': 'morphology_features' in plant_data,
            },
        }

        # Serialize first so a failure cannot leave a truncated file behind.
        payload = json.dumps(metadata, indent=2)
        with open(plant_dir / 'metadata.json', 'w') as f:
            f.write(payload)

    except Exception as e:
        logger.error(f"Failed to save metadata: {e}")
|
| 482 |
+
|
| 483 |
+
def _create_overlay(self, image: np.ndarray, mask: np.ndarray,
                    color: Tuple[int, int, int] = (0, 255, 0),
                    alpha: float = 0.5) -> np.ndarray:
    """Return a strictly masked image: pixels where mask>0 keep original; others set to 0.

    Despite the name, no colour blending is performed; ``color`` and ``alpha``
    are accepted for backward compatibility and ignored.

    Args:
        image: Image to mask (2-D or with a trailing channel axis).
        mask: Mask whose integer-truncated values > 0 mark the pixels to keep.
            ``None`` returns ``image`` unchanged.  A mask with a channel axis
            keeps a pixel when any channel is set.  A mask whose height/width
            differ from the image is resized with nearest-neighbour
            interpolation.
        color: Unused.
        alpha: Unused.

    Returns:
        ``image`` itself when ``mask`` is None, otherwise a new array with the
        shape and dtype of ``image``.
    """
    if mask is None:
        return image

    # Binarize once, before any resize, so both code paths agree (a uint8 cast
    # ahead of the threshold would wrap values such as 256 or -1).
    keep = np.asarray(mask).astype(np.int32) > 0
    if keep.ndim == 3:
        keep = keep.any(axis=2)

    # Resize mask to image size if needed; a resize failure is allowed to
    # propagate instead of being swallowed and resurfacing as a size mismatch.
    if keep.shape[:2] != image.shape[:2]:
        resized = cv2.resize(keep.astype(np.uint8), (image.shape[1], image.shape[0]),
                             interpolation=cv2.INTER_NEAREST)
        keep = resized > 0

    masked = np.zeros_like(image)
    masked[keep] = image[keep]
    return masked
|
| 497 |
+
|
| 498 |
+
def _create_maskout_white_background(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return a uint8 copy of ``image`` whose pixels outside the mask are white.

    Pixels where ``mask > 0`` keep their value from ``image``; every other
    pixel is set to 255.
    """
    inside = mask > 0

    # Start from an all-white canvas shaped like the image ...
    canvas = np.full_like(image, 255, dtype=np.uint8)
    # ... and paste the original pixels back wherever the mask is set.
    canvas[inside] = image[inside]
    return canvas
|
| 512 |
+
|
| 513 |
+
def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
    """Min-max scale ``arr`` into the 0-255 range and return it as ``uint8``.

    NaN and +/-inf are replaced by 0 before scaling.  An empty array is only
    cast to ``uint8``; an array without any spread maps to all zeros.

    Args:
        arr: Array to normalize.

    Returns:
        ``uint8`` array with the same shape as ``arr``.
    """
    if arr.size == 0:
        return arr.astype(np.uint8)

    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    # The ndarray.ptp() method was removed in NumPy 2.0; the np.ptp() function
    # works on every NumPy version.
    span = np.ptp(arr)
    if span > 0:
        # The small epsilon is kept so outputs match earlier runs exactly.
        normalized = (arr - arr.min()) / (span + 1e-6) * 255
    else:
        normalized = np.zeros_like(arr)

    return np.clip(normalized, 0, 255).astype(np.uint8)
|
| 526 |
+
|
| 527 |
+
def _create_texture_summary_plot(self, output_dir: Path, features: Dict[str, np.ndarray], band: str) -> None:
    """Create texture feature summary plot.

    Every non-empty ndarray in ``features`` except ``'ehd_features'`` is drawn
    as one panel (at most three per row) and the grid is saved as
    ``<band>_texture_summary.png`` in ``output_dir``.  Nothing is written when
    no feature qualifies.

    Args:
        output_dir: Folder that receives the PNG.
        features: Mapping of feature name to feature map.
        band: Band label used in panel titles and in the file name.

    Errors raised while plotting are logged, not propagated.
    """
    fig = None
    try:
        # Get available features
        available_features = [k for k, v in features.items()
                              if isinstance(v, np.ndarray) and v.size > 0 and k != 'ehd_features']

        if not available_features:
            return

        # Create subplot grid
        n_features = len(available_features)
        cols = min(3, n_features)
        rows = (n_features + cols - 1) // cols

        # squeeze=False always returns a 2-D (rows, cols) array of Axes, so the
        # single-panel and single-row layouts need no special-casing (indexing
        # the reshaped single-row array with one index returned an array
        # instead of an Axes).
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
        panels = axes.ravel()  # row-major, i.e. the same order as divmod(i, cols)

        for ax, feature_name in zip(panels, available_features):
            ax.imshow(features[feature_name], cmap='viridis')
            ax.set_title(f'{band.upper()} - {feature_name.upper()}')
            ax.axis('off')

        # Hide unused subplots
        for ax in panels[n_features:]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_dir / f'{band}_texture_summary.png',
                    dpi=self.settings.plot_dpi, bbox_inches='tight')

    except Exception as e:
        logger.error(f"Failed to create texture summary plot: {e}")
    finally:
        # Release the figure on failure too, otherwise figures accumulate.
        if fig is not None:
            plt.close(fig)
|
| 570 |
+
|
| 571 |
+
def _create_vegetation_summary_plot(self, output_dir: Path, vegetation_indices: Dict[str, Any]) -> None:
    """Create vegetation index summary plot.

    Every entry of ``vegetation_indices`` that is a dict holding an ndarray
    under ``'values'`` is drawn as one panel with its own slim colorbar (at
    most three panels per row).  The grid is saved as
    ``vegetation_indices_summary.png`` in ``output_dir``.  Nothing is written
    when no entry qualifies.

    Args:
        output_dir: Folder that receives the PNG.
        vegetation_indices: Mapping of index name to its result dictionary.

    Errors raised while plotting are logged, not propagated.
    """
    fig = None
    try:
        # Get available indices
        available_indices = [k for k, v in vegetation_indices.items()
                             if isinstance(v, dict) and 'values' in v and isinstance(v['values'], np.ndarray)]

        if not available_indices:
            return

        # Create subplot grid
        n_indices = len(available_indices)
        cols = min(3, n_indices)
        rows = (n_indices + cols - 1) // cols

        # squeeze=False always returns a 2-D (rows, cols) array of Axes, so the
        # single-panel and single-row layouts need no special-casing (indexing
        # the reshaped single-row array with one index returned an array
        # instead of an Axes).
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
        panels = axes.ravel()  # row-major, i.e. the same order as divmod(i, cols)

        for ax, index_name in zip(panels, available_indices):
            values = vegetation_indices[index_name]['values']
            im = ax.imshow(values, cmap='RdYlGn')
            ax.set_title(f'{index_name}')
            ax.axis('off')
            # Attach a narrow colorbar axis to the right of the panel.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="2%", pad=0.02)
            cbar = fig.colorbar(im, cax=cax, orientation='vertical')
            cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
            if hasattr(cbar, 'outline') and cbar.outline is not None:
                cbar.outline.set_linewidth(0.5)

        # Hide unused subplots
        for ax in panels[n_indices:]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_dir / 'vegetation_indices_summary.png',
                    dpi=self.settings.plot_dpi, bbox_inches='tight')

    except Exception as e:
        logger.error(f"Failed to create vegetation summary plot: {e}")
    finally:
        # Release the figure on failure too, otherwise figures accumulate.
        if fig is not None:
            plt.close(fig)
|
| 620 |
+
|
| 621 |
+
def _create_comprehensive_analysis_plot(self, output_dir: Path, plant_data: Dict[str, Any]) -> None:
    """Create comprehensive analysis plot.

    Draws a 2x3 grid saved as ``comprehensive_analysis.png`` in ``output_dir``:
    composite, segmentation mask, masked composite, LBP texture of the
    ``'color'`` band, NDVI and the morphology skeleton.  Panels whose data is
    missing or ``None`` are left blank.  The output resolution is capped at
    100 dpi.

    Args:
        output_dir: Folder that receives the PNG.
        plant_data: Per-plant result dictionary.

    Errors raised while plotting are logged, not propagated.
    """
    fig = None
    try:
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Start with every panel hidden so panels without data stay blank,
        # matching how the other summary plots hide unused subplots.
        for ax in axes.ravel():
            ax.axis('off')

        composite = plant_data.get('composite')
        mask = plant_data.get('mask')

        # Original image (converted from BGR for display)
        if composite is not None:
            axes[0, 0].imshow(cv2.cvtColor(composite, cv2.COLOR_BGR2RGB))
            axes[0, 0].set_title('Original Composite')

        # Mask
        if mask is not None:
            axes[0, 1].imshow(mask, cmap='gray')
            axes[0, 1].set_title('Segmentation Mask')

        # Overlay
        if composite is not None and mask is not None:
            overlay = self._create_overlay(composite, mask)
            axes[0, 2].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
            axes[0, 2].set_title('Overlay')

        # Texture features (if available)
        texture_features = plant_data.get('texture_features') or {}
        if 'color' in texture_features:
            color_features = texture_features['color'].get('features', {})
            if 'lbp' in color_features:
                axes[1, 0].imshow(color_features['lbp'], cmap='viridis')
                axes[1, 0].set_title('LBP Texture')

        # Vegetation indices (if available)
        veg_indices = plant_data.get('vegetation_indices') or {}
        if 'NDVI' in veg_indices and 'values' in veg_indices['NDVI']:
            axes[1, 1].imshow(veg_indices['NDVI']['values'], cmap='RdYlGn')
            axes[1, 1].set_title('NDVI')

        # Morphology (if available)
        morphology_features = plant_data.get('morphology_features') or {}
        if 'images' in morphology_features:
            morph_images = morphology_features['images']
            if 'skeleton' in morph_images:
                axes[1, 2].imshow(morph_images['skeleton'], cmap='gray')
                axes[1, 2].set_title('Skeleton')

        fig.tight_layout()
        fig.savefig(output_dir / 'comprehensive_analysis.png',
                    dpi=min(getattr(self.settings, 'plot_dpi', 100), 100), bbox_inches='tight')

    except Exception as e:
        logger.error(f"Failed to create comprehensive analysis plot: {e}")
    finally:
        # Release the figure on failure too, otherwise figures accumulate.
        if fig is not None:
            plt.close(fig)
|
| 676 |
+
|
| 677 |
+
def create_pipeline_summary(self, results: Dict[str, Any]) -> None:
    """Dump ``results['summary']`` to ``pipeline_summary.json`` in the output folder.

    Args:
        results: Pipeline result dictionary; only its ``'summary'`` entry is
            written.

    Any failure (missing key, unserializable value, I/O error) is logged
    instead of being raised.
    """
    try:
        summary_path = self.output_folder / 'pipeline_summary.json'

        with open(summary_path, 'w') as stream:
            json.dump(results['summary'], stream, indent=2)

        logger.info(f"Pipeline summary saved to {summary_path}")
    except Exception as exc:
        logger.error(f"Failed to create pipeline summary: {exc}")
|
sorghum_pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,1377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
|
| 3 |
+
|
| 4 |
+
This module orchestrates the entire pipeline from data loading
|
| 5 |
+
to feature extraction and result output.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import subprocess
|
| 10 |
+
import logging
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, Any, Optional, List, Set
|
| 13 |
+
import numpy as np
|
| 14 |
+
import cv2
|
| 15 |
+
import torch
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from transformers import AutoModelForImageSegmentation
|
| 18 |
+
from sklearn.decomposition import PCA
|
| 19 |
+
try:
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
except Exception:
|
| 22 |
+
tqdm = None
|
| 23 |
+
|
| 24 |
+
from .config import Config
|
| 25 |
+
from .data import DataLoader, ImagePreprocessor, MaskHandler
|
| 26 |
+
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
|
| 27 |
+
from .output import OutputManager
|
| 28 |
+
from .segmentation import SegmentationManager
|
| 29 |
+
# Make occlusion handling optional if the module is not present
|
| 30 |
+
try:
|
| 31 |
+
from .segmentation.occlusion_handler import OcclusionHandler # type: ignore
|
| 32 |
+
except Exception:
|
| 33 |
+
OcclusionHandler = None # type: ignore
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SorghumPipeline:
|
| 37 |
+
"""
|
| 38 |
+
Main pipeline class for sorghum plant phenotyping.
|
| 39 |
+
|
| 40 |
+
This class orchestrates the entire pipeline from data loading
|
| 41 |
+
to feature extraction and result output.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, config_path: Optional[str] = None, config: Optional[Config] = None, include_ignored: bool = False, enable_occlusion_handling: bool = False, enable_instance_integration: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
    """
    Build a ready-to-run pipeline from a configuration.

    Args:
        config_path: Path to configuration file; used only when ``config`` is None
        config: Configuration object; takes precedence over ``config_path``
        include_ignored: Whether to include ignored plants
        enable_occlusion_handling: Whether to enable SAM2Long occlusion handling
        enable_instance_integration: Flag stored on the instance as given
        strict_loader: Flag stored on the instance as given
        excluded_dates: Dates to exclude; None is stored as an empty list

    Raises:
        ValueError: If neither ``config`` nor ``config_path`` is provided
    """
    # Logging is configured first; it binds the module-level ``logger``
    # that the final message below relies on.
    self._setup_logging()

    # Resolve the configuration: explicit object first, file path second
    if config is None and config_path is None:
        raise ValueError("Either config_path or config must be provided")
    self.config = config if config is not None else Config(config_path)

    self.config.validate()

    # Remember the run-time switches
    self.enable_occlusion_handling = enable_occlusion_handling
    self.enable_instance_integration = enable_instance_integration
    self.strict_loader = strict_loader
    self.excluded_dates = excluded_dates if excluded_dates else []

    # Build loaders, extractors, segmentation and output helpers
    self._initialize_components(include_ignored)

    logger.info("Sorghum Pipeline initialized successfully")
|
| 78 |
+
|
| 79 |
+
def _setup_logging(self):
    """Configure root logging once and bind the module-level ``logger``.

    When the root logger has no handlers yet, it is configured at INFO level
    with a console handler and a ``sorghum_pipeline.log`` file handler in the
    current working directory.  When logging is already configured (by a host
    application or an earlier pipeline instance) the existing setup is left
    untouched and no new handler is created.
    """
    global logger

    # basicConfig() ignores its arguments once the root logger has handlers,
    # but the handler objects would still be constructed first, so every call
    # would open the log file again and never close it.  Build them only when
    # they will actually be installed.
    if not logging.getLogger().handlers:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler('sorghum_pipeline.log')
            ]
        )

    logger = logging.getLogger(__name__)
|
| 91 |
+
|
| 92 |
+
def _initialize_components(self, include_ignored: bool = False):
    """Create the loader, preprocessing, feature, segmentation and output helpers.

    All parameters come from ``self.config``.  The occlusion handler is only
    created when occlusion handling was requested and its module could be
    imported; otherwise ``self.occlusion_handler`` stays ``None``.

    Args:
        include_ignored: Forwarded to the data loader.
    """
    cfg = self.config

    # --- Data components -------------------------------------------------
    self.data_loader = DataLoader(
        input_folder=cfg.paths.input_folder,
        debug=True,
        include_ignored=include_ignored,
        strict_loader=self.strict_loader,
        excluded_dates=self.excluded_dates,
    )

    proc = cfg.processing
    self.preprocessor = ImagePreprocessor(target_size=proc.target_size)
    self.mask_handler = MaskHandler(
        min_area=proc.min_component_area,
        kernel_size=proc.morphology_kernel_size,
    )

    # --- Feature extractors ----------------------------------------------
    self.texture_extractor = TextureExtractor(
        lbp_points=proc.lbp_points,
        lbp_radius=proc.lbp_radius,
        hog_orientations=proc.hog_orientations,
        hog_pixels_per_cell=proc.hog_pixels_per_cell,
        hog_cells_per_block=proc.hog_cells_per_block,
        lacunarity_window=proc.lacunarity_window,
        ehd_threshold=proc.ehd_threshold,
        angle_resolution=proc.angle_resolution,
    )
    self.vegetation_extractor = VegetationIndexExtractor(
        epsilon=proc.epsilon,
        soil_factor=proc.soil_factor,
    )
    self.morphology_extractor = MorphologyExtractor(
        pixel_to_cm=proc.pixel_to_cm,
        prune_sizes=proc.prune_sizes,
    )

    # --- Segmentation ----------------------------------------------------
    model_cfg = cfg.model
    self.segmentation_manager = SegmentationManager(
        model_name=model_cfg.model_name,
        device=cfg.get_device(),
        threshold=proc.segmentation_threshold,
        trust_remote_code=model_cfg.trust_remote_code,
        # An empty or missing cache_dir is passed on as None
        cache_dir=getattr(model_cfg, 'cache_dir', '') or None,
        local_files_only=getattr(model_cfg, 'local_files_only', False),
    )

    # --- Occlusion handling (optional) -----------------------------------
    self.occlusion_handler = None
    if self.enable_occlusion_handling:
        if OcclusionHandler is None:
            logger.warning("Occlusion handler module not found; continuing without occlusion handling")
        else:
            try:
                self.occlusion_handler = OcclusionHandler(
                    device=cfg.get_device(),
                    model="tiny",  # Can be made configurable
                    confidence_threshold=0.5,
                    iou_threshold=0.1,
                )
                logger.info("Occlusion handler initialized successfully")
            except Exception as exc:
                logger.warning(f"Failed to initialize occlusion handler: {exc}")
                logger.warning("Continuing without occlusion handling")
                self.occlusion_handler = None

    # --- Output manager --------------------------------------------------
    self.output_manager = OutputManager(
        output_folder=cfg.paths.output_folder,
        settings=cfg.output,
    )
|
| 165 |
+
|
| 166 |
+
def _free_gpu_memory_before_instance(self) -> None:
    """Best-effort release of GPU memory before SAM2Long runs in a subprocess.

    Steps, each of which may fail silently:

    - move the BRIA segmentation model (if any) to the CPU
    - drop the model reference held by the segmentation manager
    - call ``torch.cuda.empty_cache()`` when CUDA is available
    """
    try:
        import torch as _torch  # type: ignore

        # Detach the BRIA model from the segmentation manager
        try:
            manager = getattr(self, 'segmentation_manager', None)
            bria_model = None if manager is None else getattr(manager, 'model', None)
            if bria_model is not None:
                release_steps = (
                    lambda: bria_model.to('cpu'),
                    lambda: delattr(manager, 'model'),
                    # Keep the attribute present (as None) for later checks
                    lambda: setattr(manager, 'model', None),
                )
                for step in release_steps:
                    try:
                        step()
                    except Exception:
                        pass
        except Exception:
            pass

        # Return cached CUDA blocks to the driver
        try:
            if _torch.cuda.is_available():
                _torch.cuda.empty_cache()
        except Exception:
            pass

        logger.info("Freed GPU memory before SAM2Long invocation (moved BRIA to CPU and emptied cache)")
    except Exception as exc:
        logger.warning(f"Failed to free GPU memory before instance segmentation: {exc}")
|
| 204 |
+
|
| 205 |
+
def run(self, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None, filter_frames: Optional[List[str]] = None, run_instance_segmentation: bool = False, features_frame_only: Optional[int] = None, reuse_instance_results: bool = False, instance_mapping_path: Optional[str] = None, force_reprocess: bool = False, respect_instance_frame_rules_for_features: bool = False, substitute_feature_image_from_instance_src: bool = False) -> Dict[str, Any]:
|
| 206 |
+
"""
|
| 207 |
+
Run the complete pipeline.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
load_all_frames: Whether to load all frames or selected frames
|
| 211 |
+
segmentation_only: If True, run segmentation only and skip feature extraction
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Dictionary containing all results
|
| 215 |
+
"""
|
| 216 |
+
logger.info("Starting Sorghum Pipeline...")
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
import time
|
| 220 |
+
total_start = time.perf_counter()
|
| 221 |
+
# Step 1: Load data
|
| 222 |
+
logger.info("Step 1/6: Loading data...")
|
| 223 |
+
# In reuse mode we need all frames to select the mapped frame per plant
|
| 224 |
+
if reuse_instance_results:
|
| 225 |
+
plants = self.data_loader.load_all_frames()
|
| 226 |
+
else:
|
| 227 |
+
# If specific frames are requested, we must load all frames to filter correctly
|
| 228 |
+
if load_all_frames or (filter_frames is not None and len(filter_frames) > 0):
|
| 229 |
+
plants = self.data_loader.load_all_frames()
|
| 230 |
+
else:
|
| 231 |
+
plants = self.data_loader.load_selected_frames()
|
| 232 |
+
|
| 233 |
+
# Optional filter by specific plant names (e.g., ["plant1"])
|
| 234 |
+
if filter_plants:
|
| 235 |
+
allowed = set(filter_plants)
|
| 236 |
+
plants = {
|
| 237 |
+
key: pdata for key, pdata in plants.items()
|
| 238 |
+
if len(key.split('_')) > 3 and key.split('_')[3] in allowed
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
# Optional filter by specific frame numbers (e.g., ["9"] or ["frame9"])
|
| 242 |
+
if filter_frames:
|
| 243 |
+
# Normalize to 'frameX' tokens
|
| 244 |
+
wanted = set(
|
| 245 |
+
[f if str(f).startswith('frame') else f"frame{str(f)}" for f in filter_frames]
|
| 246 |
+
)
|
| 247 |
+
plants = {
|
| 248 |
+
key: pdata for key, pdata in plants.items()
|
| 249 |
+
if key.split('_')[-1] in wanted
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
if not plants:
|
| 253 |
+
raise ValueError("No plant data loaded")
|
| 254 |
+
|
| 255 |
+
logger.info(f"Loaded {len(plants)} plants")
|
| 256 |
+
|
| 257 |
+
# If reusing instance results with mapping, restrict to exactly the mapped frame per plant (default frame8)
|
| 258 |
+
if reuse_instance_results:
|
| 259 |
+
try:
|
| 260 |
+
import json as _json
|
| 261 |
+
if instance_mapping_path is None:
|
| 262 |
+
raise ValueError("instance_mapping_path is required in reuse mode")
|
| 263 |
+
_map = _json.load(open(instance_mapping_path, 'r'))
|
| 264 |
+
# Normalize mapping plant keys and compute target frame (default 8)
|
| 265 |
+
target_frame_by_plant = {}
|
| 266 |
+
for pk, pv in _map.items():
|
| 267 |
+
k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
|
| 268 |
+
try:
|
| 269 |
+
target_frame_by_plant[k_norm] = int(pv.get('frame', 8))
|
| 270 |
+
except Exception:
|
| 271 |
+
target_frame_by_plant[k_norm] = 8
|
| 272 |
+
before = len(plants)
|
| 273 |
+
plants = {
|
| 274 |
+
key: pdata for key, pdata in plants.items()
|
| 275 |
+
if (len(key.split('_')) > 3 and key.split('_')[3] in target_frame_by_plant
|
| 276 |
+
and key.split('_')[-1] == f"frame{target_frame_by_plant[key.split('_')[3]]}")
|
| 277 |
+
}
|
| 278 |
+
logger.info(f"Restricted loaded data by mapping frames: {before} -> {len(plants)} items")
|
| 279 |
+
except Exception as e:
|
| 280 |
+
logger.warning(f"Failed to restrict loaded data by mapping frames: {e}")
|
| 281 |
+
|
| 282 |
+
# Skip plants that already have saved results (unless force_reprocess)
|
| 283 |
+
if not force_reprocess:
|
| 284 |
+
try:
|
| 285 |
+
before = len(plants)
|
| 286 |
+
filtered = {}
|
| 287 |
+
for key, pdata in plants.items():
|
| 288 |
+
parts = key.split('_')
|
| 289 |
+
if len(parts) < 5:
|
| 290 |
+
filtered[key] = pdata
|
| 291 |
+
continue
|
| 292 |
+
date_key = "_".join(parts[:3])
|
| 293 |
+
plant_name = parts[3]
|
| 294 |
+
plant_dir = Path(self.config.paths.output_folder) / date_key / plant_name
|
| 295 |
+
meta_ok = (plant_dir / 'metadata.json').exists()
|
| 296 |
+
seg_mask_ok = (plant_dir / self.config.output.segmentation_dir / 'mask.png').exists()
|
| 297 |
+
if meta_ok or seg_mask_ok:
|
| 298 |
+
continue
|
| 299 |
+
filtered[key] = pdata
|
| 300 |
+
plants = filtered
|
| 301 |
+
logger.info(f"Skip-existing filter: {before} -> {len(plants)} items to process")
|
| 302 |
+
except Exception as e:
|
| 303 |
+
logger.warning(f"Skip-existing filter failed: {e}")
|
| 304 |
+
|
| 305 |
+
# Pre-segmentation borrowing: use plant12 images for plant13 from the start
|
| 306 |
+
try:
|
| 307 |
+
rewired = 0
|
| 308 |
+
borrow_map: Dict[str, str] = {
|
| 309 |
+
'plant13': 'plant12',
|
| 310 |
+
'plant14': 'plant13',
|
| 311 |
+
'plant15': 'plant14',
|
| 312 |
+
'plant16': 'plant15',
|
| 313 |
+
}
|
| 314 |
+
for _k in list(plants.keys()):
|
| 315 |
+
_parts = _k.split('_')
|
| 316 |
+
# Expect keys like YYYY_MM_DD_plantX_frameY
|
| 317 |
+
if len(_parts) < 5:
|
| 318 |
+
continue
|
| 319 |
+
_date_key = "_".join(_parts[:3])
|
| 320 |
+
_plant_name = _parts[3]
|
| 321 |
+
_frame_token = _parts[4]
|
| 322 |
+
# Do NOT borrow on 2025_05_08
|
| 323 |
+
if _date_key == '2025_05_08':
|
| 324 |
+
continue
|
| 325 |
+
if _plant_name not in borrow_map:
|
| 326 |
+
continue
|
| 327 |
+
_src_plant = borrow_map[_plant_name]
|
| 328 |
+
_src_key = f"{_date_key}_{_src_plant}_{_frame_token}"
|
| 329 |
+
_src = plants.get(_src_key)
|
| 330 |
+
if not _src:
|
| 331 |
+
# Fallback: load raw image for source plant directly from disk
|
| 332 |
+
try:
|
| 333 |
+
from PIL import Image as _Image
|
| 334 |
+
_date_folder = _date_key.replace('_', '-')
|
| 335 |
+
_frame_num = int(_frame_token.replace('frame', ''))
|
| 336 |
+
_date_dir = Path(self.config.paths.input_folder)
|
| 337 |
+
# If input folder is a parent of dates, append date folder
|
| 338 |
+
if _date_dir.name != _date_folder:
|
| 339 |
+
_date_dir = _date_dir / _date_folder
|
| 340 |
+
_frame_path = _date_dir / _src_plant / f"{_src_plant}_frame{_frame_num}.tif"
|
| 341 |
+
if _frame_path.exists():
|
| 342 |
+
_img = _Image.open(str(_frame_path))
|
| 343 |
+
_src = {"raw_image": (_img, _frame_path.name), "plant_name": _plant_name, "file_path": str(_frame_path)}
|
| 344 |
+
else:
|
| 345 |
+
_src = None
|
| 346 |
+
except Exception:
|
| 347 |
+
_src = None
|
| 348 |
+
if not _src:
|
| 349 |
+
continue
|
| 350 |
+
_tgt = plants[_k]
|
| 351 |
+
# Preserve original raw image once
|
| 352 |
+
if 'raw_image' in _tgt and 'raw_image_original' not in _tgt:
|
| 353 |
+
_tgt['raw_image_original'] = _tgt['raw_image']
|
| 354 |
+
if 'raw_image' in _src:
|
| 355 |
+
_tgt['raw_image'] = _src['raw_image']
|
| 356 |
+
_tgt['borrowed_from'] = _src_plant
|
| 357 |
+
rewired += 1
|
| 358 |
+
if rewired > 0:
|
| 359 |
+
logger.info(f"Pre-seg borrowing applied: rewired {rewired} frames for plants 13/14/15/16")
|
| 360 |
+
except Exception as e:
|
| 361 |
+
logger.warning(f"Pre-seg borrowing failed: {e}")
|
| 362 |
+
|
| 363 |
+
# Step 2: Create composites
|
| 364 |
+
logger.info("Step 2/6: Creating composites...")
|
| 365 |
+
step_start = time.perf_counter()
|
| 366 |
+
plants = self.preprocessor.create_composites(plants)
|
| 367 |
+
logger.info(f"Composites done in {(time.perf_counter()-step_start):.2f}s")
|
| 368 |
+
|
| 369 |
+
# Step 3: Segment plants (optionally with bounding boxes)
|
| 370 |
+
logger.info("Step 3/6: Segmenting plants...")
|
| 371 |
+
step_start = time.perf_counter()
|
| 372 |
+
bbox_lookup = None
|
| 373 |
+
try:
|
| 374 |
+
bbox_dir = getattr(self.config.paths, 'boundingbox_dir', None)
|
| 375 |
+
# Default to project BoundingBox dir if unset or falsy
|
| 376 |
+
if not bbox_dir:
|
| 377 |
+
try:
|
| 378 |
+
self.config.paths.boundingbox_dir = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/BoundingBox"
|
| 379 |
+
bbox_dir = self.config.paths.boundingbox_dir
|
| 380 |
+
except Exception:
|
| 381 |
+
bbox_dir = None
|
| 382 |
+
if bbox_dir:
|
| 383 |
+
bbox_lookup = self.data_loader.load_bounding_boxes(bbox_dir)
|
| 384 |
+
logger.info(f"Loaded bounding boxes from {bbox_dir}")
|
| 385 |
+
except Exception as e:
|
| 386 |
+
logger.warning(f"Failed to load bounding boxes: {e}")
|
| 387 |
+
bbox_lookup = None
|
| 388 |
+
plants = self._segment_plants(plants, bbox_lookup)
|
| 389 |
+
logger.info(f"Segmentation done in {(time.perf_counter()-step_start):.2f}s")
|
| 390 |
+
|
| 391 |
+
# Step 3.5: Handle occlusion if enabled
|
| 392 |
+
if self.enable_occlusion_handling and self.occlusion_handler is not None:
|
| 393 |
+
logger.info("Step 3.5/6: Handling occlusion with SAM2Long...")
|
| 394 |
+
step_start = time.perf_counter()
|
| 395 |
+
plants = self._handle_occlusion(plants)
|
| 396 |
+
logger.info(f"Occlusion handling done in {(time.perf_counter()-step_start):.2f}s")
|
| 397 |
+
|
| 398 |
+
# Optional: Export RMBG maskouts with white background and run instance segmentation
|
| 399 |
+
if (run_instance_segmentation or self.enable_instance_integration) and not reuse_instance_results:
|
| 400 |
+
if not load_all_frames:
|
| 401 |
+
logger.warning("Instance segmentation expects all 13 frames; consider running with load_all_frames=True.")
|
| 402 |
+
logger.info("Step 3.6: Exporting white-background RMBG images for instance segmentation...")
|
| 403 |
+
# Derive date-specific export/result directories when a single date is present
|
| 404 |
+
date_keys = set()
|
| 405 |
+
try:
|
| 406 |
+
for _k in plants.keys():
|
| 407 |
+
_p = _k.split('_')
|
| 408 |
+
if len(_p) >= 3:
|
| 409 |
+
date_keys.add("_".join(_p[:3]))
|
| 410 |
+
except Exception:
|
| 411 |
+
pass
|
| 412 |
+
if len(date_keys) == 1:
|
| 413 |
+
date_key = next(iter(date_keys))
|
| 414 |
+
base_dir = Path(self.config.paths.output_folder) / date_key
|
| 415 |
+
export_dir = base_dir / "instance_input_maskouts"
|
| 416 |
+
instance_results_dir = base_dir / "instance_results"
|
| 417 |
+
else:
|
| 418 |
+
export_dir = Path(self.config.paths.output_folder) / "instance_input_maskouts"
|
| 419 |
+
instance_results_dir = Path(self.config.paths.output_folder) / "instance_results"
|
| 420 |
+
export_dir.mkdir(parents=True, exist_ok=True)
|
| 421 |
+
instance_results_dir.mkdir(parents=True, exist_ok=True)
|
| 422 |
+
self._export_white_background_maskouts(plants, export_dir)
|
| 423 |
+
|
| 424 |
+
logger.info("Invoking final SAM2Long instance segmentation on exported images...")
|
| 425 |
+
# Free GPU memory before launching SAM2Long to avoid CUDA OOM
|
| 426 |
+
self._free_gpu_memory_before_instance()
|
| 427 |
+
env = os.environ.copy()
|
| 428 |
+
env["SAM2LONG_IMAGES_DIR"] = str(export_dir)
|
| 429 |
+
env["SAM2LONG_RESULTS_DIR"] = str(instance_results_dir)
|
| 430 |
+
# Ensure instance outputs include all frames for all dates
|
| 431 |
+
try:
|
| 432 |
+
env.pop("INSTANCE_OUTPUT_FRAMES", None)
|
| 433 |
+
except Exception:
|
| 434 |
+
pass
|
| 435 |
+
script_path = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/Experiments3_code/sam2long_instance_integration.py"
|
| 436 |
+
try:
|
| 437 |
+
subprocess.run(["python", script_path], check=True, env=env)
|
| 438 |
+
except subprocess.CalledProcessError as e:
|
| 439 |
+
logger.error(f"Instance segmentation failed: {e}")
|
| 440 |
+
else:
|
| 441 |
+
# Integrate instance masks (track_0 as target) into pdata before feature extraction
|
| 442 |
+
try:
|
| 443 |
+
self._apply_instance_masks(plants, instance_results_dir)
|
| 444 |
+
logger.info("Applied instance segmentation masks to pipeline data")
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.warning(f"Failed to apply instance masks: {e}")
|
| 447 |
+
elif reuse_instance_results:
|
| 448 |
+
# Reuse existing instance masks from mapping file
|
| 449 |
+
if instance_mapping_path is None:
|
| 450 |
+
raise ValueError("reuse_instance_results=True requires instance_mapping_path to be provided")
|
| 451 |
+
try:
|
| 452 |
+
self._apply_instance_masks_from_mapping(plants, Path(instance_mapping_path))
|
| 453 |
+
logger.info("Applied instance masks from mapping file")
|
| 454 |
+
except Exception as e:
|
| 455 |
+
logger.error(f"Failed to apply instance masks from mapping: {e}")
|
| 456 |
+
|
| 457 |
+
if not segmentation_only:
|
| 458 |
+
# If reusing instance results with a mapping, restrict features to mapped frames per plant
|
| 459 |
+
if reuse_instance_results and instance_mapping_path is not None:
|
| 460 |
+
try:
|
| 461 |
+
import json as _json
|
| 462 |
+
_map = _json.load(open(instance_mapping_path, 'r'))
|
| 463 |
+
# Normalize map
|
| 464 |
+
_norm = {}
|
| 465 |
+
for pk, pv in _map.items():
|
| 466 |
+
k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
|
| 467 |
+
_norm[k_norm] = int(pv.get('frame', 8))
|
| 468 |
+
before = len(plants)
|
| 469 |
+
plants = {
|
| 470 |
+
k: v for k, v in plants.items()
|
| 471 |
+
if len(k.split('_')) > 3 and k.split('_')[3] in _norm and k.split('_')[-1] == f"frame{_norm[k.split('_')[3]]}"
|
| 472 |
+
}
|
| 473 |
+
logger.info(f"Restricted feature extraction by mapping: {before} -> {len(plants)} items")
|
| 474 |
+
except Exception as e:
|
| 475 |
+
logger.warning(f"Failed to restrict by mapping frames: {e}")
|
| 476 |
+
# Optional: restrict features to per-plant preferred frame using internal frame rules
|
| 477 |
+
if respect_instance_frame_rules_for_features:
|
| 478 |
+
try:
|
| 479 |
+
# Keep this in sync with _apply_instance_masks frame_rules
|
| 480 |
+
frame_rules: Dict[str, int] = {
|
| 481 |
+
"plant33": 2,
|
| 482 |
+
"plant16": 4,
|
| 483 |
+
"plant19": 5,
|
| 484 |
+
"plant26": 8,
|
| 485 |
+
"plant27": 8,
|
| 486 |
+
"plant29": 8,
|
| 487 |
+
"plant35": 7,
|
| 488 |
+
"plant36": 6,
|
| 489 |
+
"plant37": 2,
|
| 490 |
+
"plant45": 5,
|
| 491 |
+
}
|
| 492 |
+
before = len(plants)
|
| 493 |
+
def _keep(k: str) -> bool:
|
| 494 |
+
parts = k.split('_')
|
| 495 |
+
if len(parts) < 2:
|
| 496 |
+
return False
|
| 497 |
+
plant_name = parts[-2]
|
| 498 |
+
frame_token = parts[-1]
|
| 499 |
+
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 500 |
+
return False
|
| 501 |
+
desired = frame_rules.get(plant_name, 8)
|
| 502 |
+
return frame_token == f"frame{desired}"
|
| 503 |
+
plants = {k: v for k, v in plants.items() if _keep(k)}
|
| 504 |
+
logger.info(f"Restricted feature extraction by per-plant frame rules: {before} -> {len(plants)} items")
|
| 505 |
+
except Exception as e:
|
| 506 |
+
logger.warning(f"Failed to apply per-plant frame restriction for features: {e}")
|
| 507 |
+
|
| 508 |
+
# Optional: if features_frame_only set, keep only that frame's entries (global single frame)
|
| 509 |
+
if features_frame_only is not None:
|
| 510 |
+
frame_token = f"frame{features_frame_only}"
|
| 511 |
+
plants = {k: v for k, v in plants.items() if k.split('_')[-1] == frame_token}
|
| 512 |
+
logger.info(f"Restricted feature extraction to {len(plants)} items for {frame_token}")
|
| 513 |
+
|
| 514 |
+
# Optional: substitute feature input image from instance src_rules mapping (e.g., plant14 <- plant13)
|
| 515 |
+
if substitute_feature_image_from_instance_src:
|
| 516 |
+
try:
|
| 517 |
+
src_rules: Dict[str, str] = {
|
| 518 |
+
"plant13": "plant12",
|
| 519 |
+
"plant14": "plant13",
|
| 520 |
+
"plant15": "plant14",
|
| 521 |
+
"plant16": "plant15",
|
| 522 |
+
}
|
| 523 |
+
switched = 0
|
| 524 |
+
for key in list(plants.keys()):
|
| 525 |
+
parts = key.split('_')
|
| 526 |
+
if len(parts) < 5:
|
| 527 |
+
continue
|
| 528 |
+
date_key = "_".join(parts[:3])
|
| 529 |
+
plant_name = parts[3]
|
| 530 |
+
frame_token = parts[-1]
|
| 531 |
+
if plant_name not in src_rules:
|
| 532 |
+
continue
|
| 533 |
+
src_plant = src_rules[plant_name]
|
| 534 |
+
src_key = f"{date_key}_{src_plant}_{frame_token}"
|
| 535 |
+
if src_key not in plants:
|
| 536 |
+
continue
|
| 537 |
+
src_pdata = plants[src_key]
|
| 538 |
+
tgt_pdata = plants[key]
|
| 539 |
+
# Preserve the original composite used for segmentation for correct overlays later
|
| 540 |
+
try:
|
| 541 |
+
if 'composite' in tgt_pdata and 'segmentation_composite' not in tgt_pdata:
|
| 542 |
+
tgt_pdata['segmentation_composite'] = tgt_pdata['composite']
|
| 543 |
+
except Exception:
|
| 544 |
+
pass
|
| 545 |
+
# Swap feature inputs: composite and spectral bands
|
| 546 |
+
if 'composite' in src_pdata:
|
| 547 |
+
tgt_pdata['composite'] = src_pdata['composite']
|
| 548 |
+
if 'spectral_stack' in src_pdata:
|
| 549 |
+
tgt_pdata['spectral_stack'] = src_pdata['spectral_stack']
|
| 550 |
+
# Ensure mask aligns with substituted composite; resize if needed
|
| 551 |
+
try:
|
| 552 |
+
import cv2 as _cv2
|
| 553 |
+
import numpy as _np
|
| 554 |
+
comp = tgt_pdata.get('composite')
|
| 555 |
+
msk = tgt_pdata.get('mask')
|
| 556 |
+
if comp is not None and msk is not None:
|
| 557 |
+
ch, cw = comp.shape[:2]
|
| 558 |
+
mh, mw = msk.shape[:2]
|
| 559 |
+
if (mh, mw) != (ch, cw):
|
| 560 |
+
resized = _cv2.resize(msk.astype('uint8'), (cw, ch), interpolation=_cv2.INTER_NEAREST)
|
| 561 |
+
tgt_pdata['mask'] = resized
|
| 562 |
+
if 'soft_mask' in tgt_pdata and isinstance(tgt_pdata['soft_mask'], _np.ndarray):
|
| 563 |
+
tgt_pdata['soft_mask'] = (resized > 0).astype(_np.float32)
|
| 564 |
+
# Precompute masked composite with white background for saving
|
| 565 |
+
white = _np.full_like(comp, 255, dtype=_np.uint8)
|
| 566 |
+
result = white.copy()
|
| 567 |
+
result[tgt_pdata['mask'] > 0] = comp[tgt_pdata['mask'] > 0]
|
| 568 |
+
tgt_pdata['masked_composite'] = result
|
| 569 |
+
except Exception:
|
| 570 |
+
pass
|
| 571 |
+
switched += 1
|
| 572 |
+
if switched > 0:
|
| 573 |
+
logger.info(f"Substituted feature images from src_rules for {switched} items")
|
| 574 |
+
except Exception as e:
|
| 575 |
+
logger.warning(f"Failed feature-image substitution via src_rules: {e}")
|
| 576 |
+
# Step 4: Extract features
|
| 577 |
+
logger.info("Step 4/6: Extracting features...")
|
| 578 |
+
step_start = time.perf_counter()
|
| 579 |
+
# Stream-save mode: save outputs immediately after each plant's features when fast output is enabled
|
| 580 |
+
stream_save = False
|
| 581 |
+
try:
|
| 582 |
+
import os as _os
|
| 583 |
+
stream_save = bool(int(_os.environ.get('STREAM_SAVE', '0'))) or bool(getattr(self.output_manager, 'fast_mode', False))
|
| 584 |
+
except Exception:
|
| 585 |
+
stream_save = False
|
| 586 |
+
|
| 587 |
+
plants = self._extract_features(plants, stream_save=stream_save)
|
| 588 |
+
logger.info(f"Features done in {(time.perf_counter()-step_start):.2f}s")
|
| 589 |
+
|
| 590 |
+
# Step 5: Generate outputs (skip if already stream-saved)
|
| 591 |
+
if not stream_save:
|
| 592 |
+
logger.info("Step 5/6: Generating outputs...")
|
| 593 |
+
step_start = time.perf_counter()
|
| 594 |
+
self._generate_outputs(plants)
|
| 595 |
+
logger.info(f"Outputs done in {(time.perf_counter()-step_start):.2f}s")
|
| 596 |
+
|
| 597 |
+
# Step 6: Create summary
|
| 598 |
+
logger.info("Step 6/6: Creating summary...")
|
| 599 |
+
summary = self._create_summary(plants)
|
| 600 |
+
else:
|
| 601 |
+
logger.info("Segmentation-only mode: skipping texture/vegetation/morphology features and plots")
|
| 602 |
+
# Segmentation-only: generate only segmentation outputs and a minimal summary
|
| 603 |
+
logger.info("Step 4/4: Generating segmentation outputs (segmentation-only mode)...")
|
| 604 |
+
self._generate_outputs(plants)
|
| 605 |
+
summary = {
|
| 606 |
+
"total_plants": len(plants),
|
| 607 |
+
"successful_plants": len(plants),
|
| 608 |
+
"failed_plants": 0,
|
| 609 |
+
"features_extracted": {
|
| 610 |
+
"texture": 0,
|
| 611 |
+
"vegetation": 0,
|
| 612 |
+
"morphology": 0
|
| 613 |
+
}
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
total_time = time.perf_counter() - total_start
|
| 617 |
+
logger.info(f"Pipeline completed successfully in {total_time:.2f}s!")
|
| 618 |
+
return {
|
| 619 |
+
"plants": plants,
|
| 620 |
+
"summary": summary,
|
| 621 |
+
"config": self.config,
|
| 622 |
+
"timing_seconds": total_time
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
except Exception as e:
|
| 626 |
+
logger.error(f"Pipeline failed: {e}")
|
| 627 |
+
raise
|
| 628 |
+
|
| 629 |
+
def _export_white_background_maskouts(self, plants: Dict[str, Any], out_dir: Path) -> None:
    """Write one white-background "maskout" PNG per eligible frame into ``out_dir``.

    Existing ``*_maskout.png`` files in ``out_dir`` are removed first so that
    stale plants are not processed later. For every eligible entry the
    composite is copied, pixels where the mask equals 0 are painted white, and
    the result is saved as ``plantX_plantX_frameY_maskout.png`` so the final
    instance script can detect plants.

    Args:
        plants: Mapping of keys shaped like ``YYYY_MM_DD_plantX_frameY`` to
            per-frame dicts; only the ``'composite'`` and ``'mask'`` entries
            are read here.
        out_dir: Directory that receives the PNG files. It is not created by
            this method.

    Entries are skipped when the key does not end in ``plantX_frameY``, when
    the composite or mask is missing, or when the plant belongs to the
    bbox-only set (on every date except 2025_05_08). Per-entry failures are
    logged as warnings and never abort the export.
    """
    # Per-plant rule: use bbox-only (skip SAM2Long) for these plants on all
    # dates except 2025_05_08.
    bbox_only_plants = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
    date_exception = "2025_05_08"

    # Clear any previous maskouts to avoid processing stale plants. This stays
    # best-effort, but failures are reported because a leftover file would be
    # picked up by whatever consumes the directory next.
    try:
        stale_files = list(out_dir.glob("*_maskout.png")) if out_dir.exists() else []
    except Exception as e:
        logger.warning(f"Could not list stale maskouts in {out_dir}: {e}")
        stale_files = []
    for stale in stale_files:
        try:
            stale.unlink()
        except OSError as e:
            logger.warning(f"Could not remove stale maskout {stale}: {e}")

    count = 0
    for key, pdata in plants.items():
        try:
            # key format: "YYYY_MM_DD_plantX_frameY"
            parts = key.split('_')
            if len(parts) < 3:
                continue
            plant_name = parts[-2]
            frame_token = parts[-1]  # e.g., frame8
            if not plant_name.startswith('plant') or not frame_token.startswith('frame'):
                continue
            date_key = "_".join(parts[:3])
            if plant_name in bbox_only_plants and date_key != date_exception:
                # Skip exporting maskouts for bbox-only plants so SAM2Long does not run on them
                continue

            # The numeric value is not used, but a non-numeric frame token must
            # still be rejected (ValueError -> warning below), as before.
            int(frame_token.replace('frame', ''))

            composite = pdata.get('composite')
            mask = pdata.get('mask')
            if composite is None or mask is None:
                continue

            # Ensure 3-channel BGR
            if len(composite.shape) == 2:
                composite_bgr = cv2.cvtColor(composite, cv2.COLOR_GRAY2BGR)
            else:
                composite_bgr = composite
            out_img = composite_bgr.copy()
            # Set background to white where mask == 0
            out_img[mask == 0] = (255, 255, 255)

            # NOTE(review): the filename carries no date, so the same
            # plant/frame from two dates written to one directory overwrite
            # each other.
            out_path = out_dir / f"{plant_name}_{plant_name}_{frame_token}_maskout.png"
            # cv2.imwrite signals failure through its return value rather than
            # an exception; only count files that were actually written.
            if not cv2.imwrite(str(out_path), out_img):
                logger.warning(f"Failed to export maskout for {key}: could not write {out_path}")
                continue
            count += 1
        except Exception as e:
            logger.warning(f"Failed to export maskout for {key}: {e}")
    logger.info(f"Exported {count} white-background maskouts to {out_dir}")
|
| 682 |
+
|
| 683 |
+
def _segment_plants(self, plants: Dict[str, Any],
                    bbox_lookup: Optional[Dict[str, tuple]]) -> Dict[str, Any]:
    """Segment every plant composite and store ``'soft_mask'`` / ``'mask'`` on its dict.

    Segmentation is delegated to ``self.segmentation_manager.segment_image_soft``
    (the BRIA model, per the original docstring).

    If ``bbox_lookup`` is provided and contains an entry for the plant (e.g.
    ``'plant1'``), the composite is cropped to the clamped bounding box before
    segmentation and the predicted mask is placed back into a full-size array.
    In bbox mode the binary mask is reduced to its largest connected
    component. Special cases visible in the code:

    * plant33 never uses a bounding box;
    * no bounding boxes are used on date 2025_05_08;
    * plants in the bbox-only set use the box itself as the mask and skip the
      model, provided a box is available.

    Args:
        plants: Mapping of keys shaped like ``YYYY_MM_DD_plantX_frameY`` to
            per-frame dicts holding at least ``'composite'``.
        bbox_lookup: Optional mapping from plant name to ``(x1, y1, x2, y2)``.

    Returns:
        The same ``plants`` mapping, updated in place.

    A failure for one entry is logged and replaced by all-zero masks of the
    composite's size. When the composite itself is unavailable both masks are
    set to ``None`` instead of reusing the previous plant's shape.
    """
    # Plants that should use the bounding box itself as the mask (skip model).
    bbox_only_plants = {"plant19", "plant20", "plant27", "plant39", "plant42", "plant44", "plant46"}
    # Date on which bounding boxes are never applied.
    no_bbox_date = '2025_05_08'

    total = len(plants)
    iterator = plants.items()
    if tqdm is not None:
        iterator = tqdm(list(plants.items()), desc="Segmenting", total=total, unit="img", leave=False)

    for idx, (key, pdata) in enumerate(iterator):
        # Reset per iteration so the except branch never sees an unbound name
        # or the composite of the previous plant.
        composite = None
        try:
            # Get composite image
            composite = pdata['composite']
            h, w = composite.shape[:2]

            # Determine bbox for this plant if available
            parts = key.split('_')
            plant_name = parts[-2] if len(parts) >= 2 else None
            date_key = "_".join(parts[:3]) if len(parts) >= 3 else None  # e.g., 2025_04_16
            bbox = None
            if bbox_lookup is not None and plant_name is not None:
                # keys in bbox_lookup are typically like 'plant1'
                bbox = bbox_lookup.get(plant_name)
            # For plant33, ignore any bbox and run full-image segmentation on
            # all dates except the exception date.
            if plant_name == 'plant33' and date_key != no_bbox_date:
                bbox = None

            use_bbox_only = plant_name in bbox_only_plants

            # Do not use bounding boxes for the exception date
            if date_key == no_bbox_date:
                bbox = None

            if bbox is not None:
                # Clamp bbox to image; fall back to the full frame if it collapses
                x1, y1, x2, y2 = bbox
                x1 = max(0, min(w, int(x1)))
                x2 = max(0, min(w, int(x2)))
                y1 = max(0, min(h, int(y1)))
                y2 = max(0, min(h, int(y2)))
                if x2 <= x1 or y2 <= y1:
                    x1, y1, x2, y2 = 0, 0, w, h

                if use_bbox_only:
                    # Use the bbox as the mask directly (255 inside, 0 outside)
                    soft_full = np.zeros((h, w), dtype=np.float32)
                    soft_full[y1:y2, x1:x2] = 1.0
                    bin_full = np.zeros((h, w), dtype=np.uint8)
                    bin_full[y1:y2, x1:x2] = 255
                    pdata['soft_mask'] = soft_full
                    pdata['mask'] = bin_full
                else:
                    # Segment inside the bbox region and map back
                    crop = composite[y1:y2, x1:x2]
                    soft_mask_crop = self.segmentation_manager.segment_image_soft(crop)
                    soft_full = np.zeros((h, w), dtype=np.float32)
                    soft_resized = cv2.resize(soft_mask_crop, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)
                    soft_full[y1:y2, x1:x2] = soft_resized
                    bin_full = (soft_full > 0.5).astype(np.uint8) * 255
                    try:
                        # Keep only the largest foreground component (label 0 is background)
                        n_lbl, labels, stats, _ = cv2.connectedComponentsWithStats(bin_full, 8)
                        if n_lbl > 1:
                            largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
                            bin_full = (labels == largest).astype(np.uint8) * 255
                    except Exception as cc_err:
                        # Best-effort clean-up: keep the thresholded mask, but say so.
                        logger.warning(f"Largest-component filtering failed for {key}: {cc_err}")
                    pdata['soft_mask'] = soft_full.astype(np.float32)
                    pdata['mask'] = bin_full.astype(np.uint8)
            else:
                # Full-image segmentation (no bbox)
                soft_mask = self.segmentation_manager.segment_image_soft(composite)
                pdata['soft_mask'] = soft_mask
                pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)

            # Progress log every 25 items and for first/last
            if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
                logger.info(f"Segmented {idx + 1}/{total}: {key}")

        except Exception as e:
            logger.error(f"Segmentation failed for {key}: {e}")
            shape = getattr(composite, 'shape', None)
            if shape is not None and len(shape) >= 2:
                pdata['soft_mask'] = np.zeros(shape[:2], dtype=np.float32)
                pdata['mask'] = np.zeros(shape[:2], dtype=np.uint8)
            else:
                # No usable composite for this entry: there is no valid shape
                # to build empty masks from.
                pdata['soft_mask'] = None
                pdata['mask'] = None

    return plants
|
| 773 |
+
|
| 774 |
+
def _handle_occlusion(self, plants: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle occlusion by running ``self.occlusion_handler`` on each plant's frame sequence.

    Frames are grouped per plant sequence, sorted by frame number, and passed
    to ``occlusion_handler.segment_plant_sequence`` (SAM2Long, per the
    pipeline's log messages). The returned target masks replace each frame's
    ``'mask'`` / ``'soft_mask'``; the previous values are kept under
    ``'original_mask'`` / ``'original_soft_mask'``.

    Grouping uses the key with its trailing frame token removed, so keys of
    the form ``YYYY_MM_DD_plantX_frameY`` form one group per date and plant.
    Grouping on the first key segment, as before, put every plant of a year
    into one sequence. The plant name (second-to-last key segment) is passed
    as ``target_plant_id``.

    Args:
        plants: Dictionary of plant data; each value must hold ``'composite'``.

    Returns:
        The same ``plants`` dictionary, updated in place. Frames of a group
        that failed get ``'occlusion_handled' = False`` and an
        ``'occlusion_error'`` message.
    """
    if self.occlusion_handler is None:
        logger.warning("Occlusion handler not available, skipping occlusion handling")
        return plants

    # group id (key without the frame token) -> list of (key, pdata)
    plant_groups = {}
    for key, pdata in plants.items():
        parts = key.split('_')
        if len(parts) >= 3:
            group_id = "_".join(parts[:-1])
            plant_groups.setdefault(group_id, []).append((key, pdata))

    logger.info(f"Processing {len(plant_groups)} plant groups for occlusion handling")

    # Process each plant group
    for group_id, plant_frames in plant_groups.items():
        try:
            # Sort frames by frame number
            plant_frames.sort(key=lambda item: int(item[0].split('_')[-1].replace('frame', '')))

            if len(plant_frames) < 2:
                logger.warning(f"Plant {group_id} has only {len(plant_frames)} frames, skipping")
                continue

            # Same convention as the other pipeline steps: plant name is the
            # second-to-last key segment.
            plant_id = plant_frames[0][0].split('_')[-2]
            frames = [pdata['composite'] for _, pdata in plant_frames]

            logger.info(f"Processing plant {group_id} with {len(frames)} frames")

            # Process with SAM2Long
            occlusion_results = self.occlusion_handler.segment_plant_sequence(
                frames=frames,
                target_plant_id=plant_id
            )

            # Update plant data with occlusion results
            target_masks = occlusion_results['target_masks']
            neighbor_masks = occlusion_results['neighbor_masks']

            for i, (key, pdata) in enumerate(plant_frames):
                if i < len(target_masks):
                    # Update mask with target plant only
                    pdata['original_mask'] = pdata.get('mask', np.zeros_like(target_masks[i]))
                    pdata['mask'] = target_masks[i]
                    pdata['neighbor_mask'] = neighbor_masks[i]
                    pdata['occlusion_handled'] = True

                    # Update soft mask as well
                    pdata['original_soft_mask'] = pdata.get('soft_mask', np.zeros_like(target_masks[i], dtype=np.float32))
                    pdata['soft_mask'] = (target_masks[i] / 255.0).astype(np.float32)

            # Calculate and store occlusion metrics (shared by all frames of the group)
            metrics = self.occlusion_handler.get_occlusion_metrics(occlusion_results)
            for key, pdata in plant_frames:
                pdata['occlusion_metrics'] = metrics

            logger.info(f"Plant {group_id} occlusion handling completed")
            logger.info(f" - Average occlusion ratio: {metrics['average_occlusion_ratio']:.3f}")
            logger.info(f" - Frames with occlusion: {metrics['frames_with_occlusion']}")

        except Exception as e:
            logger.error(f"Occlusion handling failed for plant {group_id}: {e}")
            # Mark as failed but continue
            for key, pdata in plant_frames:
                pdata['occlusion_handled'] = False
                pdata['occlusion_error'] = str(e)

    return plants
|
| 860 |
+
|
| 861 |
+
def _extract_features(self, plants: Dict[str, Any], stream_save: bool = False) -> Dict[str, Any]:
    """Extract texture, vegetation and morphology features for every plant.

    For each entry the results are stored under ``'texture_features'``,
    ``'vegetation_indices'`` and ``'morphology_features'``. If any step raises,
    all three are set to empty dicts and processing continues with the next
    entry.

    Args:
        plants: Mapping of plant key to its per-frame data dict.
        stream_save: If True, save outputs for each plant immediately after
            its features are computed (via
            ``self.output_manager.save_plant_results``) to improve throughput
            and reduce peak memory.

    Returns:
        The same ``plants`` mapping, updated in place.
    """
    total = len(plants)
    logger.info(f"Extracting features for {total} plants...")
    iterator = plants.items()
    if tqdm is not None:
        iterator = tqdm(list(plants.items()), desc="Extracting features", total=total, unit="img", leave=False)

    # Prepare output directories once if we're streaming saves. Still
    # best-effort, but a failure is reported: the per-plant saves below are
    # likely to fail for the same reason.
    if stream_save:
        try:
            self.output_manager.create_output_directories()
        except Exception as dir_err:
            logger.warning(f"Could not prepare output directories for stream-save: {dir_err}")

    for idx, (key, pdata) in enumerate(iterator):
        try:
            logger.debug(f"Extracting features for {key}")

            # Extract texture features
            pdata['texture_features'] = self._extract_texture_features(pdata)

            # Extract vegetation indices
            pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)

            # Extract morphological features
            pdata['morphology_features'] = self._extract_morphology_features(pdata)

            # Immediately save outputs for this plant if streaming is enabled
            if stream_save:
                try:
                    self.output_manager.save_plant_results(key, pdata)
                except Exception as _e:
                    logger.error(f"Stream-save failed for {key}: {_e}")

            logger.debug(f"Features extracted for {key}")
            # Without tqdm, log progress for the first item, every 25th and the last
            if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
                logger.info(f"Extracted features for {idx + 1}/{total}: {key}")

        except Exception as e:
            logger.error(f"Feature extraction failed for {key}: {e}")
            # Add empty features
            pdata['texture_features'] = {}
            pdata['vegetation_indices'] = {}
            pdata['morphology_features'] = {}

    return plants
|
| 912 |
+
|
| 913 |
+
def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute texture features and their statistics per band for one plant.

    Each of the bands color, nir, red_edge, red, green and pca is turned into
    a grayscale image by ``self._prepare_band_image`` and passed to
    ``self.texture_extractor``. Statistics use the first mask present among
    ``'mask3'``, ``'features_mask'`` and ``'mask'``.

    Returns:
        Dict mapping band name to ``{'features': ..., 'statistics': ...}``.
        A band that fails is logged and gets two empty dicts.
    """
    results = {}

    for band_name in ('color', 'nir', 'red_edge', 'red', 'green', 'pca'):
        try:
            # Grayscale input for this band
            band_gray = self._prepare_band_image(pdata, band_name)

            texture_maps = self.texture_extractor.extract_all_texture_features(band_gray)

            # Mask preference: mask3, then features_mask, then mask
            fallback_mask = pdata.get('features_mask', pdata.get('mask'))
            region_mask = pdata.get('mask3', fallback_mask)
            summary = self.texture_extractor.compute_texture_statistics(texture_maps, region_mask)
        except Exception as e:
            logger.error(f"Texture extraction failed for band {band_name}: {e}")
            results[band_name] = {'features': {}, 'statistics': {}}
        else:
            results[band_name] = {
                'features': texture_maps,
                'statistics': summary,
            }

    return results
|
| 942 |
+
|
| 943 |
+
def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute vegetation indices for one plant from its spectral stack.

    The mask is the first one present among ``'mask3'``, ``'features_mask'``
    and ``'mask'``. An empty dict is returned when the spectral stack is
    missing or empty, when no mask is available, or when the computation
    raises (the error is logged).
    """
    try:
        bands = pdata.get('spectral_stack', {})
        # Mask preference: mask3, then features_mask, then mask
        fallback_mask = pdata.get('features_mask', pdata.get('mask'))
        plant_mask = pdata.get('mask3', fallback_mask)

        if bands and plant_mask is not None:
            return self.vegetation_extractor.compute_vegetation_indices(bands, plant_mask)
        return {}

    except Exception as e:
        logger.error(f"Vegetation index extraction failed: {e}")
        return {}
|
| 960 |
+
|
| 961 |
+
def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute morphological features for one plant from its composite and mask.

    The mask is the first one present among ``'mask3'``, ``'features_mask'``
    and ``'mask'``. An empty dict is returned when the composite or the mask
    is missing, or when the extractor raises (the error is logged).
    """
    try:
        image = pdata.get('composite')
        # Mask preference: mask3, then features_mask, then mask
        fallback_mask = pdata.get('features_mask', pdata.get('mask'))
        plant_mask = pdata.get('mask3', fallback_mask)

        if image is not None and plant_mask is not None:
            return self.morphology_extractor.extract_morphology_features(image, plant_mask)
        return {}

    except Exception as e:
        logger.error(f"Morphology feature extraction failed: {e}")
        return {}
|
| 978 |
+
|
| 979 |
+
def _prepare_band_image(self, pdata: Dict[str, Any], band: str) -> np.ndarray:
|
| 980 |
+
"""Prepare grayscale image for a specific band."""
|
| 981 |
+
if band == 'color':
|
| 982 |
+
composite = pdata['composite']
|
| 983 |
+
# Prefer mask3 → features_mask → mask
|
| 984 |
+
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 985 |
+
if mask is not None:
|
| 986 |
+
masked = self.mask_handler.apply_mask_to_image(composite, mask)
|
| 987 |
+
return cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
|
| 988 |
+
else:
|
| 989 |
+
return cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)
|
| 990 |
+
|
| 991 |
+
elif band == 'pca':
|
| 992 |
+
# Create PCA from spectral bands
|
| 993 |
+
spectral_stack = pdata.get('spectral_stack', {})
|
| 994 |
+
# Prefer mask3 → features_mask → mask
|
| 995 |
+
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 996 |
+
|
| 997 |
+
if not spectral_stack:
|
| 998 |
+
return np.zeros((512, 512), dtype=np.uint8)
|
| 999 |
+
|
| 1000 |
+
# Stack bands
|
| 1001 |
+
bands_data = []
|
| 1002 |
+
for b in ['nir', 'red_edge', 'red', 'green']:
|
| 1003 |
+
if b in spectral_stack:
|
| 1004 |
+
arr = spectral_stack[b].squeeze(-1).astype(float)
|
| 1005 |
+
if mask is not None:
|
| 1006 |
+
arr = np.where(mask > 0, arr, np.nan)
|
| 1007 |
+
bands_data.append(arr)
|
| 1008 |
+
|
| 1009 |
+
if not bands_data:
|
| 1010 |
+
return np.zeros((512, 512), dtype=np.uint8)
|
| 1011 |
+
|
| 1012 |
+
# Create PCA
|
| 1013 |
+
full_stack = np.stack(bands_data, axis=-1)
|
| 1014 |
+
h, w, c = full_stack.shape
|
| 1015 |
+
flat = full_stack.reshape(-1, c)
|
| 1016 |
+
valid = ~np.isnan(flat).any(axis=1)
|
| 1017 |
+
|
| 1018 |
+
if valid.sum() == 0:
|
| 1019 |
+
return np.zeros((h, w), dtype=np.uint8)
|
| 1020 |
+
|
| 1021 |
+
vec = np.zeros(h * w)
|
| 1022 |
+
vec[valid] = PCA(n_components=1, whiten=True).fit_transform(
|
| 1023 |
+
flat[valid]
|
| 1024 |
+
).squeeze()
|
| 1025 |
+
|
| 1026 |
+
gray_f = vec.reshape(h, w)
|
| 1027 |
+
if mask is not None:
|
| 1028 |
+
m, M = gray_f[mask > 0].min(), gray_f[mask > 0].max()
|
| 1029 |
+
else:
|
| 1030 |
+
m, M = gray_f.min(), gray_f.max()
|
| 1031 |
+
|
| 1032 |
+
if M > m:
|
| 1033 |
+
gray = ((gray_f - m) / (M - m) * 255).astype(np.uint8)
|
| 1034 |
+
else:
|
| 1035 |
+
gray = np.zeros_like(gray_f, dtype=np.uint8)
|
| 1036 |
+
|
| 1037 |
+
return gray
|
| 1038 |
+
|
| 1039 |
+
else:
|
| 1040 |
+
# Individual spectral band
|
| 1041 |
+
spectral_stack = pdata.get('spectral_stack', {})
|
| 1042 |
+
# Prefer mask3 → features_mask → mask
|
| 1043 |
+
mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
|
| 1044 |
+
|
| 1045 |
+
if band not in spectral_stack:
|
| 1046 |
+
return np.zeros((512, 512), dtype=np.uint8)
|
| 1047 |
+
|
| 1048 |
+
arr = spectral_stack[band].squeeze(-1).astype(float)
|
| 1049 |
+
if mask is not None:
|
| 1050 |
+
arr = np.where(mask > 0, arr, np.nan)
|
| 1051 |
+
|
| 1052 |
+
if mask is not None:
|
| 1053 |
+
m, M = np.nanmin(arr), np.nanmax(arr)
|
| 1054 |
+
else:
|
| 1055 |
+
m, M = arr.min(), arr.max()
|
| 1056 |
+
|
| 1057 |
+
if M > m:
|
| 1058 |
+
gray = ((np.nan_to_num(arr, nan=m) - m) / (M - m) * 255).astype(np.uint8)
|
| 1059 |
+
else:
|
| 1060 |
+
gray = np.zeros_like(arr, dtype=np.uint8)
|
| 1061 |
+
|
| 1062 |
+
return gray
|
| 1063 |
+
|
| 1064 |
+
def _generate_outputs(self, plants: Dict[str, Any]) -> None:
|
| 1065 |
+
"""Generate all output files and visualizations."""
|
| 1066 |
+
self.output_manager.create_output_directories()
|
| 1067 |
+
|
| 1068 |
+
for key, pdata in plants.items():
|
| 1069 |
+
try:
|
| 1070 |
+
logger.debug(f"Generating outputs for {key}")
|
| 1071 |
+
self.output_manager.save_plant_results(key, pdata)
|
| 1072 |
+
except Exception as e:
|
| 1073 |
+
logger.error(f"Output generation failed for {key}: {e}")
|
| 1074 |
+
|
| 1075 |
+
def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
|
| 1076 |
+
"""Create summary of pipeline results."""
|
| 1077 |
+
summary = {
|
| 1078 |
+
"total_plants": len(plants),
|
| 1079 |
+
"successful_plants": 0,
|
| 1080 |
+
"failed_plants": 0,
|
| 1081 |
+
"features_extracted": {
|
| 1082 |
+
"texture": 0,
|
| 1083 |
+
"vegetation": 0,
|
| 1084 |
+
"morphology": 0
|
| 1085 |
+
}
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
for key, pdata in plants.items():
|
| 1089 |
+
try:
|
| 1090 |
+
# Check if features were extracted
|
| 1091 |
+
if pdata.get('texture_features'):
|
| 1092 |
+
summary["features_extracted"]["texture"] += 1
|
| 1093 |
+
if pdata.get('vegetation_indices'):
|
| 1094 |
+
summary["features_extracted"]["vegetation"] += 1
|
| 1095 |
+
if pdata.get('morphology_features'):
|
| 1096 |
+
summary["features_extracted"]["morphology"] += 1
|
| 1097 |
+
|
| 1098 |
+
summary["successful_plants"] += 1
|
| 1099 |
+
|
| 1100 |
+
except Exception:
|
| 1101 |
+
summary["failed_plants"] += 1
|
| 1102 |
+
|
| 1103 |
+
return summary
|
| 1104 |
+
|
| 1105 |
+
def _apply_instance_masks(self, plants: Dict[str, Any], instance_results_dir: Path) -> None:
|
| 1106 |
+
"""Replace segmentation masks with SAM2Long instance masks using track_1.
|
| 1107 |
+
|
| 1108 |
+
Expects files under instance_results_dir/plantX/track_1/frame_YY_mask.png.
|
| 1109 |
+
"""
|
| 1110 |
+
# Default and per-plant overrides for source plant, track and preferred frame
|
| 1111 |
+
default_track = "track_0"
|
| 1112 |
+
src_rules: Dict[str, str] = {
|
| 1113 |
+
"plant13": "plant12",
|
| 1114 |
+
"plant14": "plant13",
|
| 1115 |
+
"plant15": "plant14",
|
| 1116 |
+
"plant16": "plant15",
|
| 1117 |
+
}
|
| 1118 |
+
track_rules: Dict[str, str] = {
|
| 1119 |
+
# explicit track rules
|
| 1120 |
+
"plant1": "track_0",
|
| 1121 |
+
"plant4": "track_0",
|
| 1122 |
+
"plant9": "track_3",
|
| 1123 |
+
"plant13": "track_1",
|
| 1124 |
+
"plant14": "track_0",
|
| 1125 |
+
"plant15": "track_0",
|
| 1126 |
+
"plant16": "track_0",
|
| 1127 |
+
"plant18": "track_0",
|
| 1128 |
+
"plant19": "track_0",
|
| 1129 |
+
"plant23": "track_1",
|
| 1130 |
+
"plant26": "track_0",
|
| 1131 |
+
"plant27": "track_0",
|
| 1132 |
+
"plant29": "track_0",
|
| 1133 |
+
"plant31": "track_1",
|
| 1134 |
+
"plant34": "track_1",
|
| 1135 |
+
"plant35": "track_1",
|
| 1136 |
+
"plant36": "track_0",
|
| 1137 |
+
"plant37": "track_1",
|
| 1138 |
+
"plant38": "track_0",
|
| 1139 |
+
"plant39": "track_1",
|
| 1140 |
+
"plant40": "track_0",
|
| 1141 |
+
"plant41": "track_1",
|
| 1142 |
+
"plant42": "track_0",
|
| 1143 |
+
"plant43": "track_0",
|
| 1144 |
+
"plant45": "track_0",
|
| 1145 |
+
}
|
| 1146 |
+
frame_rules: Dict[str, int] = {
|
| 1147 |
+
# preferred frame overrides (1-based)
|
| 1148 |
+
"plant13": 8,
|
| 1149 |
+
"plant14": 8,
|
| 1150 |
+
"plant15": 8,
|
| 1151 |
+
"plant33": 2,
|
| 1152 |
+
"plant16": 4,
|
| 1153 |
+
"plant19": 5,
|
| 1154 |
+
"plant26": 8,
|
| 1155 |
+
"plant27": 8,
|
| 1156 |
+
"plant29": 8,
|
| 1157 |
+
"plant35": 7,
|
| 1158 |
+
"plant36": 6,
|
| 1159 |
+
"plant37": 2,
|
| 1160 |
+
"plant45": 5,
|
| 1161 |
+
}
|
| 1162 |
+
# Per-plant rule: skip applying instance masks (keep bbox/BRIA mask) on all dates except 2025_05_08
|
| 1163 |
+
bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
|
| 1164 |
+
date_exception = "2025_05_08"
|
| 1165 |
+
|
| 1166 |
+
for key, pdata in plants.items():
|
| 1167 |
+
try:
|
| 1168 |
+
parts = key.split('_')
|
| 1169 |
+
if len(parts) < 3:
|
| 1170 |
+
continue
|
| 1171 |
+
plant_name = parts[-2]
|
| 1172 |
+
frame_token = parts[-1] # frame8
|
| 1173 |
+
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 1174 |
+
continue
|
| 1175 |
+
date_key = "_".join(parts[:3])
|
| 1176 |
+
if (plant_name in bbox_only_plants) and (date_key != date_exception):
|
| 1177 |
+
# Do not override masks for bbox-only plants
|
| 1178 |
+
continue
|
| 1179 |
+
frame_num = int(frame_token.replace('frame', ''))
|
| 1180 |
+
# Resolve source plant, track and desired frame
|
| 1181 |
+
src_plant = src_rules.get(plant_name, plant_name)
|
| 1182 |
+
track_name = track_rules.get(plant_name, default_track)
|
| 1183 |
+
desired_frame = frame_rules.get(plant_name, frame_num)
|
| 1184 |
+
plant_dir = Path(instance_results_dir) / src_plant / track_name
|
| 1185 |
+
mask_path = plant_dir / f"frame_{desired_frame:02d}_mask.png"
|
| 1186 |
+
if not mask_path.exists():
|
| 1187 |
+
# Fallback to current frame if override not found
|
| 1188 |
+
fallback = plant_dir / f"frame_{frame_num:02d}_mask.png"
|
| 1189 |
+
if fallback.exists():
|
| 1190 |
+
mask_path = fallback
|
| 1191 |
+
else:
|
| 1192 |
+
# Last-resort: pick any available frame mask in the track directory
|
| 1193 |
+
try:
|
| 1194 |
+
candidates = sorted(plant_dir.glob("frame_*_mask.png"))
|
| 1195 |
+
if len(candidates) > 0:
|
| 1196 |
+
mask_path = candidates[0]
|
| 1197 |
+
else:
|
| 1198 |
+
continue
|
| 1199 |
+
except Exception:
|
| 1200 |
+
continue
|
| 1201 |
+
inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
|
| 1202 |
+
if inst_mask is None:
|
| 1203 |
+
continue
|
| 1204 |
+
# Ensure binary uint8 0/255
|
| 1205 |
+
inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
|
| 1206 |
+
pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
|
| 1207 |
+
pdata['mask'] = inst_mask_bin
|
| 1208 |
+
pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
|
| 1209 |
+
pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
|
| 1210 |
+
pdata['instance_applied'] = True
|
| 1211 |
+
|
| 1212 |
+
# Build mask3 = external(mask) AND BRIA(original_mask)
|
| 1213 |
+
try:
|
| 1214 |
+
_m1 = pdata.get('mask')
|
| 1215 |
+
_m2 = pdata.get('original_mask')
|
| 1216 |
+
if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
|
| 1217 |
+
_m1b = (_m1.astype(np.uint8) > 0)
|
| 1218 |
+
_m2b = (_m2.astype(np.uint8) > 0)
|
| 1219 |
+
mask3 = (_m1b & _m2b).astype(np.uint8) * 255
|
| 1220 |
+
pdata['mask3'] = mask3
|
| 1221 |
+
pdata['features_mask'] = mask3
|
| 1222 |
+
except Exception:
|
| 1223 |
+
pass
|
| 1224 |
+
|
| 1225 |
+
# After applying instance masks, also overwrite the composite and spectral stack
|
| 1226 |
+
# with the source plant's raw image (desired frame preferred) so that
|
| 1227 |
+
# feature extraction and saved originals/overlays are consistent with the mask source.
|
| 1228 |
+
try:
|
| 1229 |
+
if plant_name in src_rules:
|
| 1230 |
+
date_key = "_".join(parts[:3])
|
| 1231 |
+
src_key_desired = f"{date_key}_{src_plant}_frame{desired_frame}"
|
| 1232 |
+
src_key_same = f"{date_key}_{src_plant}_{frame_token}"
|
| 1233 |
+
copy_from = plants.get(src_key_desired) or plants.get(src_key_same)
|
| 1234 |
+
if copy_from is None:
|
| 1235 |
+
# Fallback: load source composite from filesystem if not present in plants dict
|
| 1236 |
+
try:
|
| 1237 |
+
from PIL import Image as _Image
|
| 1238 |
+
_date_folder = date_key.replace('_', '-')
|
| 1239 |
+
_date_dir = Path(self.config.paths.input_folder)
|
| 1240 |
+
if _date_dir.name != _date_folder:
|
| 1241 |
+
_date_dir = _date_dir / _date_folder
|
| 1242 |
+
_frame_path = _date_dir / src_plant / f"{src_plant}_frame{desired_frame}.tif"
|
| 1243 |
+
if not _frame_path.exists():
|
| 1244 |
+
_frame_path = _date_dir / src_plant / f"{src_plant}_frame{frame_num}.tif"
|
| 1245 |
+
if _frame_path.exists():
|
| 1246 |
+
_img = _Image.open(str(_frame_path))
|
| 1247 |
+
# Process to composite using preprocessor
|
| 1248 |
+
comp, spec = self.preprocessor.process_raw_image(_img)
|
| 1249 |
+
copy_from = {"composite": comp, "spectral_stack": spec}
|
| 1250 |
+
except Exception:
|
| 1251 |
+
copy_from = None
|
| 1252 |
+
if copy_from is not None:
|
| 1253 |
+
# Preserve the segmentation-time composite once
|
| 1254 |
+
if 'composite' in pdata and 'segmentation_composite' not in pdata:
|
| 1255 |
+
pdata['segmentation_composite'] = pdata['composite']
|
| 1256 |
+
if 'composite' in copy_from:
|
| 1257 |
+
pdata['composite'] = copy_from['composite']
|
| 1258 |
+
if 'spectral_stack' in copy_from:
|
| 1259 |
+
pdata['spectral_stack'] = copy_from['spectral_stack']
|
| 1260 |
+
# Ensure mask size matches the copied composite
|
| 1261 |
+
ch, cw = pdata['composite'].shape[:2]
|
| 1262 |
+
mh, mw = pdata['mask'].shape[:2]
|
| 1263 |
+
if (mh, mw) != (ch, cw):
|
| 1264 |
+
pdata['mask'] = cv2.resize(pdata['mask'].astype('uint8'), (cw, ch), interpolation=cv2.INTER_NEAREST)
|
| 1265 |
+
pdata['soft_mask'] = (pdata['mask'] > 0).astype(np.float32)
|
| 1266 |
+
except Exception:
|
| 1267 |
+
pass
|
| 1268 |
+
except Exception as e:
|
| 1269 |
+
logger.debug(f"Instance mask apply failed for {key}: {e}")
|
| 1270 |
+
|
| 1271 |
+
def _apply_instance_masks_from_mapping(self, plants: Dict[str, Any], mapping_file: Path) -> None:
|
| 1272 |
+
"""Apply instance masks using an explicit mapping file with absolute paths.
|
| 1273 |
+
|
| 1274 |
+
mapping JSON structure:
|
| 1275 |
+
{
|
| 1276 |
+
"plant1": {"frame": 8, "mask_path": "/abs/path/to/plant1/track_X/frame_08_mask.png"},
|
| 1277 |
+
"plant2": {"frame": 8, "mask_path": "/abs/path/.../frame_08_mask.png"},
|
| 1278 |
+
...
|
| 1279 |
+
}
|
| 1280 |
+
If a plant's mapping specifies a different frame, only entries matching that frame are updated.
|
| 1281 |
+
"""
|
| 1282 |
+
import json
|
| 1283 |
+
if not mapping_file.exists():
|
| 1284 |
+
raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
|
| 1285 |
+
with open(mapping_file, "r") as f:
|
| 1286 |
+
mapping = json.load(f)
|
| 1287 |
+
# Normalize mapping plant keys to names like 'plantX'
|
| 1288 |
+
norm_map = {}
|
| 1289 |
+
for k, v in mapping.items():
|
| 1290 |
+
k_norm = k if str(k).startswith("plant") else f"plant{int(k)}" if str(k).isdigit() else str(k)
|
| 1291 |
+
norm_map[k_norm] = v
|
| 1292 |
+
|
| 1293 |
+
for key, pdata in plants.items():
|
| 1294 |
+
try:
|
| 1295 |
+
parts = key.split('_')
|
| 1296 |
+
if len(parts) < 3:
|
| 1297 |
+
continue
|
| 1298 |
+
plant_name = parts[-2]
|
| 1299 |
+
frame_token = parts[-1]
|
| 1300 |
+
if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
|
| 1301 |
+
continue
|
| 1302 |
+
frame_num = int(frame_token.replace('frame', ''))
|
| 1303 |
+
if plant_name not in norm_map:
|
| 1304 |
+
continue
|
| 1305 |
+
entry = norm_map[plant_name]
|
| 1306 |
+
target_frame = int(entry.get("frame", frame_num))
|
| 1307 |
+
if frame_num != target_frame:
|
| 1308 |
+
# Only update the designated frame for this plant
|
| 1309 |
+
continue
|
| 1310 |
+
mask_path_str = entry.get("mask_path")
|
| 1311 |
+
if not mask_path_str:
|
| 1312 |
+
continue
|
| 1313 |
+
mask_path = Path(mask_path_str)
|
| 1314 |
+
if not mask_path.exists():
|
| 1315 |
+
logger.warning(f"Mask path not found for {plant_name} {frame_token}: {mask_path}")
|
| 1316 |
+
continue
|
| 1317 |
+
inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
|
| 1318 |
+
if inst_mask is None:
|
| 1319 |
+
continue
|
| 1320 |
+
inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
|
| 1321 |
+
pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
|
| 1322 |
+
pdata['mask'] = inst_mask_bin
|
| 1323 |
+
pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
|
| 1324 |
+
pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
|
| 1325 |
+
pdata['instance_applied'] = True
|
| 1326 |
+
|
| 1327 |
+
# Build mask3 = external(mask) AND BRIA(original_mask)
|
| 1328 |
+
try:
|
| 1329 |
+
_m1 = pdata.get('mask')
|
| 1330 |
+
_m2 = pdata.get('original_mask')
|
| 1331 |
+
if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
|
| 1332 |
+
_m1b = (_m1.astype(np.uint8) > 0)
|
| 1333 |
+
_m2b = (_m2.astype(np.uint8) > 0)
|
| 1334 |
+
mask3 = (_m1b & _m2b).astype(np.uint8) * 255
|
| 1335 |
+
pdata['mask3'] = mask3
|
| 1336 |
+
pdata['features_mask'] = mask3
|
| 1337 |
+
except Exception:
|
| 1338 |
+
pass
|
| 1339 |
+
except Exception as e:
|
| 1340 |
+
logger.debug(f"Instance mapping apply failed for {key}: {e}")
|
| 1341 |
+
|
| 1342 |
+
|
| 1343 |
+
def run_pipeline(config_path: str, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Build a ``SorghumPipeline`` from a config file and run it once.

    Args:
        config_path: Path to configuration file
        load_all_frames: Whether to load all frames or selected frames
        segmentation_only: If True, run segmentation only and skip feature extraction
        filter_plants: Optional list of plant names, passed through unchanged
            as the third argument of ``SorghumPipeline.run``

    Returns:
        Pipeline results, exactly as returned by ``SorghumPipeline.run``
    """
    return SorghumPipeline(config_path).run(load_all_frames, segmentation_only, filter_plants)
|
| 1357 |
+
|
| 1358 |
+
|
| 1359 |
+
if __name__ == "__main__":
    # Usage: pipeline.py [config.yml] [--all] [--seg-only] [--plant=<name> ...]
    import sys

    cli_args = sys.argv[1:]

    # The config path is the first argument that is not a "--" flag. Using
    # sys.argv[1] unconditionally made `pipeline.py --all` try to load a
    # configuration file literally named "--all".
    positional = [arg for arg in cli_args if not arg.startswith("--")]
    config_path = positional[0] if positional else "config.yml"
    load_all = "--all" in cli_args
    seg_only = "--seg-only" in cli_args

    # Basic arg parse for --plant=<name>; the flag may be repeated to select
    # several plants. None means "no filter".
    plant_filter = [
        arg.split("=", 1)[1] for arg in cli_args if arg.startswith("--plant=")
    ] or None

    try:
        results = run_pipeline(config_path, load_all, seg_only, plant_filter)
        print("Pipeline completed successfully!")
        print(f"Processed {results['summary']['total_plants']} plants")
    except Exception as e:
        # Top-level boundary: report the failure and exit with a non-zero status.
        print(f"Pipeline failed: {e}")
        sys.exit(1)
|
sorghum_pipeline/segmentation/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Segmentation modules for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This package contains segmentation functionality including:
|
| 5 |
+
- BRIA model integration
|
| 6 |
+
- Mask post-processing
|
| 7 |
+
- Segmentation validation
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .manager import SegmentationManager
|
| 11 |
+
|
| 12 |
+
__all__ = ["SegmentationManager"]
|
sorghum_pipeline/segmentation/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (482 Bytes). View file
|
|
|
sorghum_pipeline/segmentation/__pycache__/advanced_occlusion_handler.cpython-312.pyc
ADDED
|
Binary file (26.3 kB). View file
|
|
|
sorghum_pipeline/segmentation/__pycache__/leaf_occlusion_handler.cpython-312.pyc
ADDED
|
Binary file (27 kB). View file
|
|
|
sorghum_pipeline/segmentation/__pycache__/manager.cpython-312.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
sorghum_pipeline/segmentation/__pycache__/occlusion_handler.cpython-312.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
sorghum_pipeline/segmentation/manager.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Segmentation manager for the Sorghum Pipeline.
|
| 3 |
+
|
| 4 |
+
This module handles image segmentation using the BRIA model
|
| 5 |
+
and provides post-processing capabilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
from transformers import AutoModelForImageSegmentation
|
| 14 |
+
from typing import Optional, Tuple
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SegmentationManager:
    """Manages image segmentation using BRIA model."""

    def __init__(self,
                 model_name: str = "briaai/RMBG-2.0",
                 device: str = "auto",
                 threshold: float = 0.5,
                 trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None,
                 local_files_only: bool = False):
        """
        Initialize segmentation manager and load the model immediately.

        Args:
            model_name: Name of the BRIA model
            device: Device to run model on ("auto", "cpu", "cuda"); "auto"
                picks "cuda" when ``torch.cuda.is_available()``, else "cpu"
            threshold: Segmentation threshold applied to the model's sigmoid output
            trust_remote_code: Whether to trust remote code
            cache_dir: Hugging Face cache directory for model weights
            local_files_only: If True, only load from local cache

        Raises:
            Exception: Whatever the model loading raises (logged, then re-raised).
        """
        self.model_name = model_name
        self.threshold = threshold
        self.trust_remote_code = trust_remote_code
        self.cache_dir = cache_dir
        self.local_files_only = local_files_only

        # Determine device
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Initialize model
        self.model = None
        self.transform = None
        self._load_model()

    def _load_model(self):
        """Load the BRIA segmentation model and build the input transform.

        Raises:
            Exception: Re-raised after logging when the model cannot be loaded.
        """
        try:
            logger.info(f"Loading BRIA model: {self.model_name}")

            self.model = AutoModelForImageSegmentation.from_pretrained(
                self.model_name,
                trust_remote_code=self.trust_remote_code,
                cache_dir=self.cache_dir if self.cache_dir else None,
                local_files_only=self.local_files_only,
            ).eval().to(self.device)

            # Define image transform: fixed 1024x1024 input, normalised with
            # the mean/std constants below.
            self.transform = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            logger.info("BRIA model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load BRIA model: {e}")
            raise

    def _predict_probabilities(self, image: np.ndarray) -> np.ndarray:
        """Run the model on a BGR image and return its sigmoid output.

        The result is the last model output, at whatever resolution the model
        emits; callers resize it back to the image size.
        """
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_image)
        input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()

    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image using the BRIA model.

        Args:
            image: Input image (BGR format)

        Returns:
            Binary mask (0/255) with the image's height and width. An all-zero
            mask is returned (and the error logged) when inference fails.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        try:
            predictions = self._predict_probabilities(image)

            # Threshold at model resolution, then resize with nearest-neighbour
            # so the mask stays strictly 0/255.
            mask = (predictions > self.threshold).astype(np.uint8) * 255
            original_size = (image.shape[1], image.shape[0])  # (width, height)
            return cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
            # Return empty mask
            return np.zeros(image.shape[:2], dtype=np.uint8)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image and return a soft mask in [0, 1] resized to original size.
        No thresholding or post-processing is applied.

        Args:
            image: Input image (BGR format)

        Returns:
            Float mask in [0,1] with shape (H, W). An all-zero mask is returned
            (and the error logged) when inference fails.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        try:
            preds = self._predict_probabilities(image)
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.error(f"Soft segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.float32)

    def post_process_mask(self, mask: np.ndarray,
                          min_area: int = 1000,
                          kernel_size: int = 5) -> np.ndarray:
        """
        Post-process segmentation mask.

        Applies a morphological opening with an elliptical kernel, then keeps
        only the 8-connected components whose area is at least ``min_area``.

        Args:
            mask: Input mask
            min_area: Minimum area for connected components
            kernel_size: Kernel size for morphological operations

        Returns:
            Post-processed mask (kept components set to 255), or the input
            mask unchanged when processing fails (the error is logged).
        """
        try:
            # Morphological opening to remove noise
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
            opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            # Remove small connected components
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
                opened, connectivity=8
            )

            processed_mask = np.zeros_like(opened)
            for label in range(1, num_labels):  # Skip background
                if stats[label, cv2.CC_STAT_AREA] >= min_area:
                    processed_mask[labels == label] = 255

            return processed_mask

        except Exception as e:
            logger.error(f"Mask post-processing failed: {e}")
            return mask

    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
        """
        Keep only the largest connected component.

        Args:
            mask: Input mask

        Returns:
            Mask (0/255) with only the largest 8-connected component. The
            input is returned unchanged when it has no foreground component
            or when processing fails (the error is logged).
        """
        try:
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)

            if num_labels <= 1:
                return mask

            # Find the largest component (excluding background, label 0)
            areas = stats[1:, cv2.CC_STAT_AREA]
            largest_label = 1 + np.argmax(areas)

            return (labels == largest_label).astype(np.uint8) * 255

        except Exception as e:
            logger.error(f"Largest component extraction failed: {e}")
            return mask

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate segmentation mask.

        A mask is valid when it is a 2-D numpy array of dtype uint8 or bool
        with at least one pixel greater than zero.

        Args:
            mask: Mask to validate

        Returns:
            True if valid, False otherwise
        """
        if mask is None:
            return False

        if not isinstance(mask, np.ndarray):
            return False

        if mask.ndim != 2:
            return False

        if mask.dtype not in [np.uint8, np.bool_]:
            return False

        # Check if mask has any foreground pixels
        if np.sum(mask > 0) == 0:
            logger.warning("Mask has no foreground pixels")
            return False

        return True

    def get_mask_properties(self, mask: np.ndarray) -> dict:
        """
        Get properties of the segmentation mask.

        All values describe the whole foreground: ``perimeter`` is the summed
        length of every external contour and the bounding box encloses all of
        them, consistent with ``area``, ``coverage`` and ``num_components``.

        Args:
            mask: Binary mask (uint8, foreground > 127, or bool)

        Returns:
            Dictionary with ``area``, ``perimeter``, ``bbox_area``,
            ``aspect_ratio``, ``coverage`` and ``num_components``; an empty
            dict when the mask is invalid or the computation fails.
        """
        if not self.validate_mask(mask):
            return {}

        try:
            # Convert to binary. Boolean masks pass validate_mask but can never
            # exceed 127, so they are binarised by truth value instead.
            if mask.dtype == np.bool_:
                binary_mask = mask.astype(np.uint8)
            else:
                binary_mask = (mask > 127).astype(np.uint8)

            area = np.sum(binary_mask)

            # Find contours
            contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # Measure every component, not just whichever contour happens
                # to be listed first.
                perimeter = sum(cv2.arcLength(contour, True) for contour in contours)
                _, _, w, h = cv2.boundingRect(np.vstack(contours))
                bbox_area = w * h
                aspect_ratio = w / h if h > 0 else 0
            else:
                perimeter = 0
                bbox_area = 0
                aspect_ratio = 0

            return {
                "area": int(area),
                "perimeter": float(perimeter),
                "bbox_area": int(bbox_area),
                "aspect_ratio": float(aspect_ratio),
                "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0,
                "num_components": len(contours)
            }

        except Exception as e:
            logger.error(f"Mask property calculation failed: {e}")
            return {}

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create overlay of mask on image.

        Args:
            image: Base image
            mask: Binary mask; every pixel greater than zero is painted
            color: Overlay color (B, G, R)
            alpha: Overlay transparency

        Returns:
            Image with mask overlay, or the input image unchanged when the
            overlay cannot be built (the error is logged).
        """
        try:
            overlay = image.copy()
            # Use "> 0" (as validate_mask does) so bool and 0/1 masks are
            # painted too, not only pixels that are exactly 255.
            overlay[mask > 0] = color
            return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
        except Exception as e:
            logger.error(f"Overlay creation failed: {e}")
            return image
|