Fahimeh Orvati Nia committed on
Commit
b4123b8
·
1 Parent(s): 4768cde

Add sorghum_pipeline code

Browse files
Files changed (39) hide show
  1. sorghum_pipeline/__init__.py +31 -0
  2. sorghum_pipeline/__pycache__/__init__.cpython-312.pyc +0 -0
  3. sorghum_pipeline/__pycache__/config.cpython-312.pyc +0 -0
  4. sorghum_pipeline/__pycache__/pipeline.cpython-312.pyc +0 -0
  5. sorghum_pipeline/config.py +249 -0
  6. sorghum_pipeline/data/__init__.py +15 -0
  7. sorghum_pipeline/data/__pycache__/__init__.cpython-312.pyc +0 -0
  8. sorghum_pipeline/data/__pycache__/loader.cpython-312.pyc +0 -0
  9. sorghum_pipeline/data/__pycache__/mask_handler.cpython-312.pyc +0 -0
  10. sorghum_pipeline/data/__pycache__/preprocessor.cpython-312.pyc +0 -0
  11. sorghum_pipeline/data/loader.py +444 -0
  12. sorghum_pipeline/data/mask_handler.py +296 -0
  13. sorghum_pipeline/data/preprocessor.py +279 -0
  14. sorghum_pipeline/features/__init__.py +21 -0
  15. sorghum_pipeline/features/__pycache__/__init__.cpython-312.pyc +0 -0
  16. sorghum_pipeline/features/__pycache__/morphology.cpython-312.pyc +0 -0
  17. sorghum_pipeline/features/__pycache__/spectral.cpython-312.pyc +0 -0
  18. sorghum_pipeline/features/__pycache__/texture.cpython-312.pyc +0 -0
  19. sorghum_pipeline/features/__pycache__/vegetation.cpython-312.pyc +0 -0
  20. sorghum_pipeline/features/morphology.py +380 -0
  21. sorghum_pipeline/features/spectral.py +383 -0
  22. sorghum_pipeline/features/texture.py +373 -0
  23. sorghum_pipeline/features/vegetation.py +308 -0
  24. sorghum_pipeline/models/__init__.py +10 -0
  25. sorghum_pipeline/models/__pycache__/__init__.cpython-312.pyc +0 -0
  26. sorghum_pipeline/models/__pycache__/dbc_lacunarity.cpython-312.pyc +0 -0
  27. sorghum_pipeline/models/dbc_lacunarity.py +90 -0
  28. sorghum_pipeline/output/__init__.py +13 -0
  29. sorghum_pipeline/output/__pycache__/__init__.cpython-312.pyc +0 -0
  30. sorghum_pipeline/output/__pycache__/manager.cpython-312.pyc +0 -0
  31. sorghum_pipeline/output/manager.py +688 -0
  32. sorghum_pipeline/pipeline.py +1377 -0
  33. sorghum_pipeline/segmentation/__init__.py +12 -0
  34. sorghum_pipeline/segmentation/__pycache__/__init__.cpython-312.pyc +0 -0
  35. sorghum_pipeline/segmentation/__pycache__/advanced_occlusion_handler.cpython-312.pyc +0 -0
  36. sorghum_pipeline/segmentation/__pycache__/leaf_occlusion_handler.cpython-312.pyc +0 -0
  37. sorghum_pipeline/segmentation/__pycache__/manager.cpython-312.pyc +0 -0
  38. sorghum_pipeline/segmentation/__pycache__/occlusion_handler.cpython-312.pyc +0 -0
  39. sorghum_pipeline/segmentation/manager.py +309 -0
sorghum_pipeline/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Sorghum Plant Phenotyping Pipeline.

End-to-end analysis of sorghum plant images, covering:

- data loading and preprocessing
- image segmentation and masking
- feature extraction (texture, morphology, vegetation indices)
- results visualization and export

Author: Fahime Horvatinia
Version: 2.0.0
"""

__version__ = "2.0.0"
__author__ = "Fahime Horvatinia"

# Re-export the public entry points so callers can simply do
# ``from sorghum_pipeline import SorghumPipeline``.
from .pipeline import SorghumPipeline
from .config import Config
from .data import DataLoader
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
from .output import OutputManager

__all__ = [
    "SorghumPipeline",
    "Config",
    "DataLoader",
    "TextureExtractor",
    "VegetationIndexExtractor",
    "MorphologyExtractor",
    "OutputManager"
]
sorghum_pipeline/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (943 Bytes). View file
 
sorghum_pipeline/__pycache__/config.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
sorghum_pipeline/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (66.9 kB). View file
 
sorghum_pipeline/config.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration management for the Sorghum Pipeline.
3
+
4
+ This module handles all configuration settings, paths, and parameters
5
+ used throughout the pipeline.
6
+ """
7
+
8
+ import os
9
+ import yaml
10
+ from pathlib import Path
11
+ from typing import Dict, Any, Optional
12
+ from dataclasses import dataclass, field
13
+
14
+
15
@dataclass
class Paths:
    """Filesystem locations used by the pipeline.

    Required: input and output folders. Optional: bounding-box and label
    directories. All provided paths are normalized to absolute paths on
    construction.
    """
    input_folder: str
    output_folder: str
    boundingbox_dir: Optional[str] = None
    labels_folder: Optional[str] = None

    def __post_init__(self):
        """Ensure all paths are absolute where provided."""
        self.input_folder = os.path.abspath(self.input_folder)
        self.output_folder = os.path.abspath(self.output_folder)
        # Optional directories are only normalized when set (non-empty).
        for attr in ("boundingbox_dir", "labels_folder"):
            value = getattr(self, attr)
            if value:
                setattr(self, attr, os.path.abspath(value))
+
32
+
33
@dataclass
class ProcessingParams:
    """Tunable parameters for image processing and feature extraction.

    Defaults reflect the values used for the sorghum dataset; all of them
    can be overridden from the YAML configuration file.
    """

    # --- Image processing ---
    target_size: tuple = (1024, 1024)        # images are resized to this (H, W)
    gaussian_blur_kernel: int = 5
    morphology_kernel_size: int = 7
    min_component_area: int = 1000           # smaller components are discarded

    # --- Segmentation ---
    segmentation_threshold: float = 0.5
    max_components: int = 10

    # --- Texture analysis ---
    lbp_points: int = 8                      # local binary pattern sample points
    lbp_radius: int = 1
    hog_orientations: int = 9
    hog_pixels_per_cell: tuple = (8, 8)
    hog_cells_per_block: tuple = (2, 2)
    lacunarity_window: int = 15
    ehd_threshold: float = 0.3               # edge histogram descriptor threshold
    angle_resolution: int = 45

    # --- Vegetation indices ---
    epsilon: float = 1e-10                   # guards against division by zero
    soil_factor: float = 0.16                # soil-adjustment factor (e.g. SAVI-style)

    # --- Morphology ---
    pixel_to_cm: float = 0.1099609375        # camera calibration: cm per pixel
    # default_factory keeps each instance's list independent (mutable default)
    prune_sizes: list = field(default_factory=lambda: [200, 100, 50, 30, 10])
63
+
64
+
65
@dataclass
class OutputSettings:
    """Controls what gets written to disk and where.

    Flags select which artifact kinds are saved; the *_dir fields name the
    per-plant subdirectories created by the output manager.
    """

    # What to save
    save_images: bool = True
    save_plots: bool = True
    save_metadata: bool = True

    # Rendering quality / format
    image_dpi: int = 150
    plot_dpi: int = 100
    image_format: str = "png"

    # Subdirectory names (relative to each plant's output folder)
    segmentation_dir: str = "segmentation"
    features_dir: str = "features"
    texture_dir: str = "texture"
    morphology_dir: str = "morphology"
    vegetation_dir: str = "vegetation_indices"
    analysis_dir: str = "analysis"
82
+
83
+
84
@dataclass
class ModelSettings:
    """Settings for the ML segmentation model (Hugging Face hosted)."""

    device: str = "auto"                      # "auto" picks cuda if available, else cpu
    model_name: str = "briaai/RMBG-2.0"       # background-removal model id
    batch_size: int = 1
    # Required for models whose code lives in the hub repo itself.
    trust_remote_code: bool = True
    cache_dir: str = ""                       # empty string -> library default cache
    local_files_only: bool = False            # True forces offline operation
93
+
94
+
95
class Config:
    """Main configuration class for the Sorghum Pipeline.

    Aggregates the four settings groups (paths, processing, output, model)
    and serializes them to / from YAML.
    """

    # Tuple-valued processing parameters. YAML has no tuple type: dumping a
    # tuple emits a `!!python/tuple` tag that yaml.safe_load refuses to read
    # back, so these are written as lists and restored to tuples on load.
    _TUPLE_PROCESSING_KEYS = ('target_size', 'hog_pixels_per_cell', 'hog_cells_per_block')

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize configuration.

        Args:
            config_path: Path to YAML configuration file. If None, uses defaults.
        """
        self.paths = Paths(
            input_folder="",
            output_folder="",
            boundingbox_dir=""
        )
        self.processing = ProcessingParams()
        self.output = OutputSettings()
        self.model = ModelSettings()

        if config_path:
            self.load_from_file(config_path)

    @staticmethod
    def _apply_section(target: Any, section: Optional[Dict[str, Any]],
                       tuple_keys: tuple = ()) -> None:
        """Copy known keys from a parsed YAML section onto a settings object.

        Unknown keys are silently ignored; keys named in *tuple_keys* that
        arrive as lists (YAML's only sequence type) are converted to tuples.
        """
        if not section:
            return
        for key, value in section.items():
            if hasattr(target, key):
                if key in tuple_keys and isinstance(value, list):
                    value = tuple(value)
                setattr(target, key, value)

    def load_from_file(self, config_path: str) -> None:
        """Load configuration from a YAML file, overriding current values.

        Raises:
            FileNotFoundError: If *config_path* does not exist.
        """
        config_path = Path(config_path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_path, 'r') as f:
            # safe_load returns None for an empty file; treat as "no overrides"
            # instead of crashing on `'paths' in None`.
            config_data = yaml.safe_load(f) or {}

        # Paths are rebuilt wholesale so __post_init__ re-normalizes them.
        if 'paths' in config_data:
            self.paths = Paths(**config_data['paths'])

        self._apply_section(self.processing, config_data.get('processing'),
                            tuple_keys=self._TUPLE_PROCESSING_KEYS)
        self._apply_section(self.output, config_data.get('output'))
        self._apply_section(self.model, config_data.get('model'))

    def save_to_file(self, config_path: str) -> None:
        """Save current configuration to a YAML file readable by load_from_file."""
        config_data = {
            'paths': {
                'input_folder': self.paths.input_folder,
                'output_folder': self.paths.output_folder,
                'boundingbox_dir': self.paths.boundingbox_dir,
                'labels_folder': self.paths.labels_folder
            },
            'processing': {
                # Tuples become lists so yaml.safe_load can read the file back.
                'target_size': list(self.processing.target_size),
                'gaussian_blur_kernel': self.processing.gaussian_blur_kernel,
                'morphology_kernel_size': self.processing.morphology_kernel_size,
                'min_component_area': self.processing.min_component_area,
                'segmentation_threshold': self.processing.segmentation_threshold,
                'max_components': self.processing.max_components,
                'lbp_points': self.processing.lbp_points,
                'lbp_radius': self.processing.lbp_radius,
                'hog_orientations': self.processing.hog_orientations,
                'hog_pixels_per_cell': list(self.processing.hog_pixels_per_cell),
                'hog_cells_per_block': list(self.processing.hog_cells_per_block),
                'lacunarity_window': self.processing.lacunarity_window,
                'ehd_threshold': self.processing.ehd_threshold,
                'angle_resolution': self.processing.angle_resolution,
                'epsilon': self.processing.epsilon,
                'soil_factor': self.processing.soil_factor,
                'pixel_to_cm': self.processing.pixel_to_cm,
                'prune_sizes': self.processing.prune_sizes
            },
            'output': {
                'save_images': self.output.save_images,
                'save_plots': self.output.save_plots,
                'save_metadata': self.output.save_metadata,
                'image_dpi': self.output.image_dpi,
                'plot_dpi': self.output.plot_dpi,
                'image_format': self.output.image_format,
                'segmentation_dir': self.output.segmentation_dir,
                'features_dir': self.output.features_dir,
                'texture_dir': self.output.texture_dir,
                'morphology_dir': self.output.morphology_dir,
                'vegetation_dir': self.output.vegetation_dir,
                'analysis_dir': self.output.analysis_dir
            },
            'model': {
                'device': self.model.device,
                'model_name': self.model.model_name,
                'batch_size': self.model.batch_size,
                'trust_remote_code': self.model.trust_remote_code,
                'cache_dir': self.model.cache_dir,
                'local_files_only': self.model.local_files_only,
            }
        }

        with open(config_path, 'w') as f:
            yaml.dump(config_data, f, default_flow_style=False, indent=2)

    def get_device(self) -> str:
        """Get the appropriate device for processing.

        Resolves "auto" to "cuda" when available, otherwise "cpu"; any other
        configured value is returned as-is. torch is imported lazily so this
        module stays importable without it until device resolution is needed.
        """
        if self.model.device == "auto":
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        return self.model.device

    def create_output_directories(self, base_path: str) -> None:
        """Ensure base output directory exists only.

        Subdirectories are created per plant in the output manager.
        """
        base_path = Path(base_path)
        base_path.mkdir(parents=True, exist_ok=True)

    def validate(self) -> bool:
        """Validate configuration settings.

        Returns:
            True when all checks pass.

        Raises:
            FileNotFoundError: If a configured directory does not exist.
            ValueError: If a processing parameter is out of range.
        """
        # Check if input directory exists
        if not os.path.exists(self.paths.input_folder):
            raise FileNotFoundError(f"Input folder does not exist: {self.paths.input_folder}")

        # Bounding box directory is optional; only validate when configured.
        if self.paths.boundingbox_dir and not os.path.exists(self.paths.boundingbox_dir):
            raise FileNotFoundError(f"Bounding box directory does not exist: {self.paths.boundingbox_dir}")

        # Validate processing parameters
        if self.processing.target_size[0] <= 0 or self.processing.target_size[1] <= 0:
            raise ValueError("Target size must be positive")

        if self.processing.segmentation_threshold < 0 or self.processing.segmentation_threshold > 1:
            raise ValueError("Segmentation threshold must be between 0 and 1")

        return True
237
+
238
+
239
def create_default_config(output_path: str) -> None:
    """Write a configuration file populated with the project defaults.

    Args:
        output_path: Destination path for the generated YAML file.
    """
    cfg = Config()
    # Default dataset layout relative to the working directory.
    cfg.paths = Paths(
        input_folder="Sorghum_dataset",
        output_folder="Sorghum_pipeline_Results",
        boundingbox_dir="boundingbox",
        labels_folder="labels",
    )
    cfg.save_to_file(output_path)
    print(f"Default configuration created at: {output_path}")
sorghum_pipeline/data/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Data loading and preprocessing modules.

Everything related to getting plant images into the pipeline lives here:

- raw image loading (DataLoader)
- image preprocessing (ImagePreprocessor)
- mask creation and handling (MaskHandler)
"""

from .loader import DataLoader
from .preprocessor import ImagePreprocessor
from .mask_handler import MaskHandler

__all__ = ["DataLoader", "ImagePreprocessor", "MaskHandler"]
sorghum_pipeline/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (577 Bytes). View file
 
sorghum_pipeline/data/__pycache__/loader.cpython-312.pyc ADDED
Binary file (21.9 kB). View file
 
sorghum_pipeline/data/__pycache__/mask_handler.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
sorghum_pipeline/data/__pycache__/preprocessor.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
sorghum_pipeline/data/loader.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading functionality for the Sorghum Pipeline.
3
+
4
+ This module handles loading raw images, managing plant data,
5
+ and organizing data according to the pipeline requirements.
6
+ """
7
+
8
+ import os
9
+ import glob
10
+ import json
11
+ from pathlib import Path
12
+ from typing import Dict, List, Tuple, Optional, Any
13
+ from PIL import Image
14
+ import numpy as np
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class DataLoader:
    """Handles loading and organizing plant image data.

    Supports two dataset layouts: a parent folder containing per-date
    subfolders, or a single date folder containing per-plant subfolders.
    Frame selection is driven by the class-level rule tables below.
    """

    # Plants to ignore completely (empty by default)
    IGNORE_PLANTS = set()

    # Plants where you want exactly one frame from their own folder
    EXACT_FRAME = {
        4: 7, 5: 5, 7: 5, 12: 5, 13: 5, 18: 7, 19: 2, 20: 3,
        24: 6, 25: 5, 26: 5, 30: 8, 37: 7
    }

    # Plants where you want to borrow a frame from a different plant folder:
    # plant_id -> (source_plant_id, frame_number)
    BORROW_FRAME = {
        14: (13, 5), 15: (14, 5), 16: (15, 5), 33: (34, 7),
        34: (35, 7), 35: (35, 8), 36: (36, 6)
    }

    # Overrides provided by user: preferred frame per target plant name.
    # These take priority over EXACT_FRAME / BORROW_FRAME.
    FRAME_OVERRIDE_BY_NAME = {
        'plant1': 9, 'plant2': 10, 'plant3': 9, 'plant5': 7, 'plant6': 9, 'plant8': 5,
        'plant7': 9, 'plant10': 9, 'plant11': 9, 'plant12': 9,
        'plant13': 10, 'plant14': 8, 'plant15': 11, 'plant19': 4, 'plant20': 7,
        'plant21': 9, 'plant22': 10, 'plant25': 4, 'plant26': 2, 'plant27': 10, 'plant28': 9, 'plant29': 2,
        'plant30': 9, 'plant31': 10, 'plant32': 9, 'plant33': 8,
        'plant35': 9, 'plant36': 4, 'plant38': 9, 'plant39': 9, 'plant41': 9,
        'plant42': 6, 'plant43': 10, 'plant44': 9, 'plant45': 7,
        'plant47': 10, 'plant48': 11,
    }

    # Substitutes provided by user: map target plant name -> source plant name
    PLANT_SUBSTITUTES_BY_NAME = {
        'plant16': 'plant15', 'plant15': 'plant14', 'plant14': 'plant13',
        'plant13': 'plant12', 'plant33': 'plant34', 'plant34': 'plant35',
        'plant24': 'plant25', 'plant25': 'plant25', 'plant35': 'plant36',
        'plant36': 'plant37', 'plant37': 'plant37', 'plant44': 'plant43',
        'plant45': 'plant44',
    }

    def __init__(self, input_folder: str, debug: bool = False, include_ignored: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
        """
        Initialize the data loader.

        Args:
            input_folder: Path to the input dataset folder
            debug: Enable debug logging
            include_ignored: If True, plants in IGNORE_PLANTS are loaded anyway
            strict_loader: If True, only honor explicit name-based overrides
                and substitutions (no legacy borrow/exact special-casing)
            excluded_dates: Date folder names (with dashes) to skip entirely

        Raises:
            FileNotFoundError: If input_folder does not exist.
        """
        self.input_folder = Path(input_folder)
        self.debug = debug
        self.include_ignored = include_ignored
        self.strict_loader = strict_loader

        if not self.input_folder.exists():
            raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
        # Normalize excluded dates as a set of folder names (with dashes)
        self.excluded_dates = set(excluded_dates or [])

    def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
        """
        Load selected frames according to predefined rules.
        If strict_loader is True, load only frame numbers from the plant's own folder (no borrowing/special picks).

        Returns:
            Dictionary with plant data organized by key format: "YYYY_MM_DD_plantX_frameY"
        """
        logger.info("Loading selected frames from dataset...")
        plants = {}

        # Detect if input folder is a direct date folder (contains plant folders)
        first_items = list(self.input_folder.iterdir())
        has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)

        def choose_frame_and_source(pid: int) -> Tuple[int, str]:
            if self.strict_loader:
                # In strict mode, honor explicit frame overrides AND substitution of source plant
                plant_name_local = f"plant{pid}"
                frame_num = self.FRAME_OVERRIDE_BY_NAME.get(
                    plant_name_local,
                    self.EXACT_FRAME.get(pid, 8)
                )
                source_plant = self.PLANT_SUBSTITUTES_BY_NAME.get(plant_name_local, plant_name_local)
                return frame_num, source_plant
            # Original behavior
            frame_num = self._get_frame_number(pid)
            source_plant = self._get_source_plant(pid)
            return frame_num, source_plant

        if has_plant_folders:
            # Direct date folder structure
            date_name = self.input_folder.name
            date_path = self.input_folder
            for plant_name in sorted(os.listdir(date_path)):
                plant_path = date_path / plant_name
                if not plant_path.is_dir():
                    continue
                try:
                    plant_id = int(plant_name.replace("plant", ""))
                except ValueError:
                    continue
                if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                    if self.debug:
                        logger.debug(f"Ignoring plant {plant_id}")
                    continue
                frame_num, source_plant = choose_frame_and_source(plant_id)
                frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
                if frame_data:
                    key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
                    plants[key] = frame_data
                    logger.debug(f"Loaded {key}")
        else:
            # Parent folder structure with date subfolders
            for date_name in sorted(os.listdir(self.input_folder)):
                date_path = self.input_folder / date_name
                if not date_path.is_dir():
                    continue
                if date_name in self.excluded_dates:
                    logger.info(f"Skipping excluded date: {date_name}")
                    continue
                for plant_name in sorted(os.listdir(date_path)):
                    plant_path = date_path / plant_name
                    if not plant_path.is_dir():
                        continue
                    try:
                        plant_id = int(plant_name.replace("plant", ""))
                    except ValueError:
                        continue
                    if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                        if self.debug:
                            logger.debug(f"Ignoring plant {plant_id}")
                        continue
                    frame_num, source_plant = choose_frame_and_source(plant_id)
                    frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
                    if frame_data:
                        key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
                        plants[key] = frame_data
                        logger.debug(f"Loaded {key}")

        logger.info(f"Successfully loaded {len(plants)} plant frames")
        return plants

    def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
        """
        Load all available frames for each plant.

        Returns:
            Dictionary with all plant frames
        """
        logger.info("Loading all frames from dataset...")
        plants = {}

        # Check if we're directly in a date folder (contains plant folders)
        # or in a parent folder (contains date folders)
        first_items = list(self.input_folder.iterdir())
        has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)

        if has_plant_folders:
            # We're directly in a date folder
            logger.info("Detected direct date folder structure")
            date_name = self.input_folder.name
            self._load_plants_from_date_folder(self.input_folder, date_name, plants)
        else:
            # We're in a parent folder with date subfolders
            logger.info("Detected parent folder structure")
            for date_name in sorted(os.listdir(self.input_folder)):
                date_path = self.input_folder / date_name
                if not date_path.is_dir():
                    continue
                if date_name in self.excluded_dates:
                    logger.info(f"Skipping excluded date: {date_name}")
                    continue

                logger.info(f"Processing date: {date_name}")
                self._load_plants_from_date_folder(date_path, date_name, plants)

        logger.info(f"Successfully loaded {len(plants)} plant frames")
        return plants

    def _load_plants_from_date_folder(self, date_path: Path, date_name: str, plants: Dict[str, Dict[str, Any]]) -> None:
        """Load every frame of every (non-ignored) plant in a date folder into *plants*."""
        for plant_name in sorted(os.listdir(date_path)):
            plant_path = date_path / plant_name
            if not plant_path.is_dir():
                continue

            # Extract plant ID
            try:
                plant_id = int(plant_name.replace("plant", ""))
            except ValueError:
                logger.warning(f"Could not extract plant ID from {plant_name}")
                continue

            # Skip ignored plants
            if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                logger.info(f"Skipping ignored plant {plant_id}")
                continue

            logger.info(f"Processing plant {plant_id}")

            # Load all frames for this plant
            pattern = str(plant_path / f"{plant_name}_frame*.tif")
            frame_files = sorted(glob.glob(pattern))
            logger.info(f"Found {len(frame_files)} frame files for {plant_name}")

            for frame_path in frame_files:
                frame_data = self._load_frame_from_path(frame_path, plant_name)
                if frame_data:
                    frame_id = Path(frame_path).stem.split("_frame")[-1]
                    key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_id}"
                    plants[key] = frame_data
                    logger.debug(f"Loaded frame: {key}")
                else:
                    logger.warning(f"Failed to load frame: {frame_path}")

    def load_single_plant(self, date: str, plant: str, frame: int) -> Optional[Dict[str, Any]]:
        """
        Load a specific plant frame.

        Args:
            date: Date string (e.g., "2025-02-05")
            plant: Plant name (e.g., "plant1")
            frame: Frame number

        Returns:
            Plant data dictionary or None if not found
        """
        date_path = self.input_folder / date
        if not date_path.exists():
            logger.error(f"Date folder not found: {date}")
            return None

        plant_path = date_path / plant
        if not plant_path.exists():
            logger.error(f"Plant folder not found: {plant}")
            return None

        filename = f"{plant}_frame{frame}.tif"
        frame_path = plant_path / filename

        return self._load_frame_from_path(str(frame_path), plant)

    def _get_frame_number(self, plant_id: int) -> int:
        """Get the frame number for a plant ID."""
        plant_name = f"plant{plant_id}"
        # Highest priority: explicit overrides by plant name
        if plant_name in self.FRAME_OVERRIDE_BY_NAME:
            return int(self.FRAME_OVERRIDE_BY_NAME[plant_name])
        # Next: original exact/borrow rules
        if plant_id in self.EXACT_FRAME:
            return self.EXACT_FRAME[plant_id]
        elif plant_id in self.BORROW_FRAME:
            return self.BORROW_FRAME[plant_id][1]
        else:
            return 8  # Default frame

    def _get_source_plant(self, plant_id: int) -> str:
        """Get the source plant name for a plant ID."""
        plant_name = f"plant{plant_id}"
        # Highest priority: explicit substitutes by plant name
        if plant_name in self.PLANT_SUBSTITUTES_BY_NAME:
            return self.PLANT_SUBSTITUTES_BY_NAME[plant_name]
        # Next: original borrow rules
        if plant_id in self.BORROW_FRAME:
            source_id = self.BORROW_FRAME[plant_id][0]
            return f"plant{source_id}"
        else:
            return f"plant{plant_id}"

    def _load_single_frame(self, date_path: Path, source_plant: str,
                           frame_num: int, target_plant: str) -> Optional[Dict[str, Any]]:
        """Load a single frame from the specified path.

        The frame file is read from *source_plant*'s folder but attributed
        to *target_plant* (supports the borrow/substitute rules).
        """
        filename = f"{source_plant}_frame{frame_num}.tif"
        frame_path = date_path / source_plant / filename

        if not frame_path.exists():
            if self.debug:
                logger.warning(f"Frame not found: {frame_path}")
            return None

        return self._load_frame_from_path(str(frame_path), target_plant)

    def _load_frame_from_path(self, frame_path: str, plant_name: str) -> Optional[Dict[str, Any]]:
        """Load frame data from a file path.

        Returns None (and logs) on any load failure. NOTE: PIL opens the
        image lazily; pixel data is only read when the image is first used.
        """
        try:
            logger.debug(f"Attempting to load: {frame_path}")
            image = Image.open(frame_path)
            filename = Path(frame_path).name
            # Fixed: previously logged a literal "(unknown)" placeholder
            # instead of the filename computed above.
            logger.debug(f"Successfully loaded image: {filename}, size: {image.size}")

            return {
                "raw_image": (image, filename),
                "plant_name": plant_name,
                "file_path": frame_path
            }
        except Exception as e:
            logger.error(f"Failed to load {frame_path}: {e}")
            return None

    def load_bounding_boxes(self, bbox_dir: str) -> Dict[str, Optional[Tuple[int, int, int, int]]]:
        """
        Load bounding box data from JSON files.

        Args:
            bbox_dir: Directory containing bounding box JSON files

        Returns:
            Dictionary mapping plant names to bounding box coordinates
            (x1, y1, x2, y2), or None when a file has no rectangle shape.

        Raises:
            FileNotFoundError: If bbox_dir does not exist.
        """
        bbox_path = Path(bbox_dir)
        if not bbox_path.exists():
            raise FileNotFoundError(f"Bounding box directory not found: {bbox_dir}")

        bbox_lookup = {}

        for json_file in bbox_path.glob("*.json"):
            stem = json_file.stem
            # Normalize stems like plant_33_new -> plant33
            if stem.startswith('plant_'):
                parts = stem.split('_')
                try:
                    idx = next(i for i, p in enumerate(parts) if p.isdigit())
                    plant_id = f"plant{parts[idx]}"
                except Exception:
                    plant_id = stem.replace('_', '')
            else:
                plant_id = stem
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)

                shapes = data.get('shapes', [])
                # Prefer rectangle labeled 'sorghum' (case-insensitive), else first rectangle
                def _is_sorghum_label(s: dict) -> bool:
                    for key in ('label', 'name', 'text'):
                        val = s.get(key)
                        if isinstance(val, str) and val.lower() == 'sorghum':
                            return True
                    return False
                rect = next((s for s in shapes if s.get('shape_type') == 'rectangle' and _is_sorghum_label(s)), None)
                if rect is None:
                    rect = next((s for s in shapes if s.get('shape_type') == 'rectangle'), None)

                if rect:
                    (x1, y1), (x2, y2) = rect['points']
                    # Clamp origin to >= 0; 1e9 is an arbitrary large upper
                    # bound (actual clamping to image size happens downstream).
                    bbox_lookup[plant_id] = (
                        int(max(0, x1)),
                        int(max(0, y1)),
                        int(min(1e9, x2)),
                        int(min(1e9, y2))
                    )
                else:
                    bbox_lookup[plant_id] = None

            except Exception as e:
                logger.error(f"Failed to load bounding box {json_file}: {e}")

        logger.info(f"Loaded {len(bbox_lookup)} bounding boxes")
        return bbox_lookup

    def load_hand_labels(self, labels_dir: str) -> Dict[str, np.ndarray]:
        """
        Load hand-labeled masks from JSON files.

        Args:
            labels_dir: Directory containing label JSON files

        Returns:
            Dictionary mapping plant names to mask arrays
        """
        labels_path = Path(labels_dir)
        if not labels_path.exists():
            logger.warning(f"Labels directory not found: {labels_dir}")
            return {}

        masks = {}

        for json_file in labels_path.glob("*.json"):
            plant_id = json_file.stem
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)

                # Create mask from shapes (assuming we have image dimensions)
                # This would need to be adapted based on your label format
                mask = self._create_mask_from_shapes(data)
                if mask is not None:
                    masks[plant_id] = mask

            except Exception as e:
                logger.error(f"Failed to load label {json_file}: {e}")

        logger.info(f"Loaded {len(masks)} hand labels")
        return masks

    def _create_mask_from_shapes(self, data: Dict) -> Optional[np.ndarray]:
        """Create a mask array from shape data."""
        # This is a placeholder - implement based on your label format
        # For now, return None
        return None

    def validate_data(self, plants: Dict[str, Dict[str, Any]]) -> bool:
        """
        Validate loaded plant data.

        Args:
            plants: Dictionary of plant data

        Returns:
            True if data is valid, False otherwise
        """
        if not plants:
            logger.error("No plant data loaded")
            return False

        for key, data in plants.items():
            if "raw_image" not in data:
                logger.error(f"Missing raw_image in {key}")
                return False

            image, filename = data["raw_image"]
            if not isinstance(image, Image.Image):
                logger.error(f"Invalid image type in {key}")
                return False

        logger.info("Data validation passed")
        return True
sorghum_pipeline/data/mask_handler.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask handling functionality for the Sorghum Pipeline.
3
+
4
+ This module handles mask creation, processing, and validation
5
+ for plant segmentation tasks.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ from typing import Dict, Tuple, Optional, List
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class MaskHandler:
    """Handles mask creation, processing, and validation.

    All masks are 2-D uint8 arrays with foreground encoded as 255 and
    background as 0 unless stated otherwise.
    """

    def __init__(self, min_area: int = 1000, kernel_size: int = 7):
        """
        Initialize the mask handler.

        Args:
            min_area: Minimum area (pixels) a connected component must have
                to survive `preprocess_mask`.
            kernel_size: Kernel size for morphological opening.
        """
        self.min_area = min_area
        self.kernel_size = kernel_size

    def create_bounding_box_mask(self, image_shape: Tuple[int, int],
                                 bbox: Tuple[int, int, int, int]) -> np.ndarray:
        """
        Create a binary mask from bounding box coordinates.

        Args:
            image_shape: Shape of the image (height, width).
            bbox: Bounding box coordinates (x1, y1, x2, y2).

        Returns:
            uint8 mask with 255 inside the (clamped) box, 0 elsewhere.
        """
        h, w = image_shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        x1, y1, x2, y2 = bbox
        # Clamp coordinates to image bounds so slicing cannot wrap around
        # with negative indices.
        x1 = max(0, min(w, x1))
        y1 = max(0, min(h, y1))
        x2 = max(0, min(w, x2))
        y2 = max(0, min(h, y2))

        mask[y1:y2, x1:x2] = 255
        return mask

    def preprocess_mask(self, mask: np.ndarray) -> np.ndarray:
        """
        Clean a mask: binarize, open morphologically, drop small components.

        Args:
            mask: Input mask (any non-zero value counts as foreground).
                May also be a tuple whose first element is the mask.

        Returns:
            Cleaned uint8 mask, or None if the input was None.
        """
        if mask is None:
            return None

        # Some callers pass (mask, extra) tuples; unwrap the array.
        if isinstance(mask, tuple):
            mask = mask[0]

        # Normalize to strict 0/255 binary format.
        mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255

        # Morphological opening to remove speckle noise.
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.kernel_size, self.kernel_size)
        )
        opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Keep only connected components at least `min_area` pixels large.
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            opened, connectivity=8
        )

        clean_mask = np.zeros_like(opened)
        for label in range(1, num_labels):  # label 0 is the background
            if stats[label, cv2.CC_STAT_AREA] >= self.min_area:
                clean_mask[labels == label] = 255

        return clean_mask

    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
        """
        Keep only the largest connected component in the mask.

        Args:
            mask: Input binary mask.

        Returns:
            Mask containing only the largest component (255/0), or the
            input unchanged when it has no foreground components.
        """
        if mask is None:
            return None

        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)

        if num_labels <= 1:
            return mask

        # Largest component by area, excluding the background label 0.
        areas = stats[1:, cv2.CC_STAT_AREA]
        largest_label = 1 + np.argmax(areas)

        return (labels == largest_label).astype(np.uint8) * 255

    def apply_mask_to_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Apply a binary mask to an image (background pixels become 0).

        Args:
            image: Input image.
            mask: Binary mask; None returns the image unchanged.

        Returns:
            Masked image.
        """
        if mask is None:
            return image

        return cv2.bitwise_and(image, image, mask=mask)

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Blend a solid-color rendering of the mask onto the image.

        Args:
            image: Base image (BGR).
            mask: Binary mask; only pixels equal to 255 are colored.
            color: Overlay color (B, G, R).
            alpha: Overlay transparency in [0, 1].

        Returns:
            Image with mask overlay.
        """
        overlay = image.copy()
        overlay[mask == 255] = color
        return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)

    def get_mask_properties(self, mask: np.ndarray) -> Dict[str, float]:
        """
        Compute basic geometric properties of the mask.

        Args:
            mask: Mask; values > 127 are treated as foreground.

        Returns:
            Dictionary with area, perimeter, bbox_area, aspect_ratio and
            coverage (fraction of image covered); empty dict for None input.
        """
        if mask is None:
            return {}

        binary_mask = (mask > 127).astype(np.uint8)
        area = np.sum(binary_mask)

        # Find contours once and derive perimeter/bbox from the LARGEST one.
        # (The previous implementation ran findContours three times and
        # measured contours[0], which is not necessarily the largest blob.)
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            largest = max(contours, key=cv2.contourArea)
            perimeter = cv2.arcLength(largest, True)
            x, y, w, h = cv2.boundingRect(largest)
            bbox_area = w * h
            aspect_ratio = w / h if h > 0 else 0
        else:
            perimeter = 0
            bbox_area = 0
            aspect_ratio = 0

        return {
            "area": float(area),
            "perimeter": float(perimeter),
            "bbox_area": float(bbox_area),
            "aspect_ratio": float(aspect_ratio),
            "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0
        }

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate mask format and content.

        A valid mask is a 2-D uint8/bool ndarray with at least one
        foreground pixel.

        Args:
            mask: Mask to validate.

        Returns:
            True if valid, False otherwise.
        """
        if mask is None:
            return False

        if not isinstance(mask, np.ndarray):
            return False

        if mask.ndim != 2:
            return False

        if mask.dtype not in [np.uint8, np.bool_]:
            return False

        # An all-background mask is considered invalid (nothing segmented).
        if np.sum(mask > 0) == 0:
            logger.warning("Mask has no foreground pixels")
            return False

        return True

    def resize_mask(self, mask: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
        """
        Resize mask to target size.

        Args:
            mask: Input mask.
            target_size: Target size (width, height).

        Returns:
            Resized mask (nearest-neighbor keeps values strictly binary).
        """
        if mask is None:
            return None

        return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

    def dilate_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
        """
        Dilate mask to expand foreground regions.

        Args:
            mask: Input mask.
            kernel_size: Size of elliptical dilation kernel.

        Returns:
            Dilated mask.
        """
        if mask is None:
            return None

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        return cv2.dilate(mask, kernel, iterations=1)

    def erode_mask(self, mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
        """
        Erode mask to shrink foreground regions.

        Args:
            mask: Input mask.
            kernel_size: Size of elliptical erosion kernel.

        Returns:
            Eroded mask.
        """
        if mask is None:
            return None

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        return cv2.erode(mask, kernel, iterations=1)

    def fill_holes(self, mask: np.ndarray) -> np.ndarray:
        """
        Fill holes in the mask by filling its external contours.

        Args:
            mask: Input mask.

        Returns:
            Mask with interior holes filled.
        """
        if mask is None:
            return None

        # External contours only: everything inside an outer boundary,
        # including holes, gets filled.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        filled_mask = np.zeros_like(mask)
        cv2.fillPoly(filled_mask, contours, 255)

        return filled_mask
sorghum_pipeline/data/preprocessor.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image preprocessing functionality for the Sorghum Pipeline.
3
+
4
+ This module handles image preprocessing, composite creation,
5
+ and basic image transformations.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ from PIL import Image
11
+ from typing import Dict, Tuple, Any, Optional
12
+ from itertools import product
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ImagePreprocessor:
    """Handles image preprocessing and composite creation."""

    def __init__(self, target_size: Optional[Tuple[int, int]] = None):
        """
        Initialize the image preprocessor.

        Args:
            target_size: Default target size for image resizing
                (width, height); None disables resizing.
        """
        self.target_size = target_size

    def convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
        """
        Convert array to uint8 with min-max normalization to 0-255.

        NaN and infinite values are replaced with 0 before normalization.

        Args:
            arr: Input array.

        Returns:
            Normalized uint8 array (all zeros for a constant input).
        """
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

        # Use np.ptp(arr) instead of the ndarray.ptp() method: the method
        # was removed in NumPy 2.0. Compute the range once instead of twice.
        value_range = np.ptp(arr)
        if value_range > 0:
            normalized = (arr - arr.min()) / (value_range + 1e-6) * 255
        else:
            normalized = np.zeros_like(arr)

        return np.clip(normalized, 0, 255).astype(np.uint8)

    def process_raw_image(self, pil_img: "Image.Image") -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """
        Process a raw 4-band image into a composite and spectral bands.

        The RAW frame is a 2x2 grid of equally sized tiles holding, in
        row-major order: green, red, red_edge, nir.

        Args:
            pil_img: PIL Image object containing 4-band tiled data.

        Returns:
            Tuple of (pseudo-RGB composite uint8 image, dict of float
            spectral bands keyed "green"/"red"/"red_edge"/"nir").
        """
        # Split the 4-band RAW into tiles and stack them.
        d = pil_img.size[0] // 2
        boxes = [
            (j, i, j + d, i + d)
            for i, j in product(
                range(0, pil_img.height, d),
                range(0, pil_img.width, d)
            )
        ]

        stack = np.stack([
            np.array(pil_img.crop(box), dtype=float)
            for box in boxes
        ], axis=-1)

        # Bands come in order: [green, red, red_edge, nir].
        green, red, red_edge, nir = np.split(stack, 4, axis=-1)

        # Build pseudo-RGB composite as (green, red_edge, red).
        composite = np.concatenate([green, red_edge, red], axis=-1)
        composite_uint8 = self.convert_to_uint8(composite)

        spectral_bands = {
            "green": green,
            "red": red,
            "red_edge": red_edge,
            "nir": nir
        }

        return composite_uint8, spectral_bands

    def create_composites(self, plants: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """
        Create composites for all plants in the dataset.

        Stores a "composite" image and "spectral_stack" dict in each plant
        entry; entries without raw images are skipped with a warning.

        Args:
            plants: Dictionary of plant data.

        Returns:
            Updated plant data with composites and spectral stacks.
        """
        logger.info("Creating composites for all plants...")

        for key, pdata in plants.items():
            try:
                # Find the PIL Image (single image, or first of a list).
                if "raw_image" in pdata:
                    image, _ = pdata["raw_image"]
                elif "raw_images" in pdata and pdata["raw_images"]:
                    image, _ = pdata["raw_images"][0]
                else:
                    logger.warning(f"No raw image found for {key}")
                    continue

                composite, spectral_stack = self.process_raw_image(image)

                pdata["composite"] = composite
                pdata["spectral_stack"] = spectral_stack

                logger.debug(f"Created composite for {key}")

            except Exception as e:
                logger.error(f"Failed to create composite for {key}: {e}")
                continue

        logger.info("Composite creation completed")
        return plants

    def resize_image(self, image: np.ndarray, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Resize image to target size.

        Args:
            image: Input image.
            target_size: Target size (width, height). If None, uses
                self.target_size; if that is also None, the image is
                returned unchanged.

        Returns:
            Resized image.
        """
        if target_size is None:
            target_size = self.target_size

        if target_size is None:
            return image

        return cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    def normalize_image(self, image: np.ndarray, method: str = "minmax") -> np.ndarray:
        """
        Normalize image using the specified method.

        Args:
            image: Input image.
            method: "minmax" (to [0, 1]), "zscore" (mean 0 / std 1), or
                "robust" (scaled by the interquartile range).

        Returns:
            Normalized image (zeros when the image is constant).

        Raises:
            ValueError: If `method` is not one of the supported names.
        """
        if method == "minmax":
            if image.dtype == np.uint8:
                return image.astype(np.float32) / 255.0
            img_min, img_max = image.min(), image.max()
            if img_max > img_min:
                return (image - img_min) / (img_max - img_min)
            return np.zeros_like(image, dtype=np.float32)

        elif method == "zscore":
            mean, std = image.mean(), image.std()
            if std > 0:
                return (image - mean) / std
            return np.zeros_like(image, dtype=np.float32)

        elif method == "robust":
            q25, q75 = np.percentile(image, [25, 75])
            if q75 > q25:
                return (image - q25) / (q75 - q25)
            return np.zeros_like(image, dtype=np.float32)

        else:
            raise ValueError(f"Unknown normalization method: {method}")

    def apply_gaussian_blur(self, image: np.ndarray, kernel_size: int = 5) -> np.ndarray:
        """
        Apply Gaussian blur to image.

        Args:
            image: Input image.
            kernel_size: Size of Gaussian kernel (bumped to the next odd
                number if even, as OpenCV requires odd sizes).

        Returns:
            Blurred image.
        """
        if kernel_size % 2 == 0:
            kernel_size += 1

        return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    def apply_sharpening(self, image: np.ndarray) -> np.ndarray:
        """
        Apply a 3x3 unsharp-style sharpening filter to the image.

        Args:
            image: Input image.

        Returns:
            Sharpened image.
        """
        kernel = np.array([
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0]
        ])

        return cv2.filter2D(image, -1, kernel)

    def enhance_contrast(self, image: np.ndarray, alpha: float = 1.2, beta: int = 15) -> np.ndarray:
        """
        Enhance image contrast/brightness via a linear transform.

        Args:
            image: Input image.
            alpha: Contrast control (1.0 = no change).
            beta: Brightness control (0 = no change).

        Returns:
            Enhanced image.
        """
        return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)

    def create_overlay(self, base_image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Blend a solid-color rendering of the mask onto the base image.

        Args:
            base_image: Base image (BGR).
            mask: Binary mask; only pixels equal to 255 are colored.
            color: Overlay color (B, G, R).
            alpha: Overlay transparency in [0, 1].

        Returns:
            Image with overlay.
        """
        overlay = base_image.copy()
        overlay[mask == 255] = color
        return cv2.addWeighted(base_image, 1.0 - alpha, overlay, alpha, 0)

    def validate_composite(self, composite: np.ndarray) -> bool:
        """
        Validate a composite image (must be a 3-channel uint8 ndarray).

        Args:
            composite: Composite image to validate.

        Returns:
            True if valid, False otherwise.
        """
        if composite is None:
            return False

        if not isinstance(composite, np.ndarray):
            return False

        if composite.ndim != 3 or composite.shape[2] != 3:
            return False

        if composite.dtype != np.uint8:
            return False

        return True
sorghum_pipeline/features/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feature extraction modules for the Sorghum Pipeline.
3
+
4
+ This package contains all feature extraction functionality including:
5
+ - Texture features (LBP, HOG, Lacunarity, EHD)
6
+ - Vegetation indices
7
+ - Morphological features
8
+ - Spectral features
9
+ """
10
+
11
+ from .texture import TextureExtractor
12
+ from .vegetation import VegetationIndexExtractor
13
+ from .morphology import MorphologyExtractor
14
+ from .spectral import SpectralExtractor
15
+
16
+ __all__ = [
17
+ "TextureExtractor",
18
+ "VegetationIndexExtractor",
19
+ "MorphologyExtractor",
20
+ "SpectralExtractor"
21
+ ]
sorghum_pipeline/features/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (714 Bytes). View file
 
sorghum_pipeline/features/__pycache__/morphology.cpython-312.pyc ADDED
Binary file (18.6 kB). View file
 
sorghum_pipeline/features/__pycache__/spectral.cpython-312.pyc ADDED
Binary file (18 kB). View file
 
sorghum_pipeline/features/__pycache__/texture.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
sorghum_pipeline/features/__pycache__/vegetation.cpython-312.pyc ADDED
Binary file (25.1 kB). View file
 
sorghum_pipeline/features/morphology.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Morphological feature extraction for the Sorghum Pipeline.
3
+
4
+ This module handles extraction of morphological features using PlantCV
5
+ and other computer vision techniques.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ import contextlib
11
+ import sys
12
+ from typing import Dict, Any, Optional, List, Tuple
13
+ import logging
14
+
15
+ # Try to import PlantCV, but don't fail if not available
16
+ try:
17
+ from plantcv import plantcv as pcv
18
+ PLANT_CV_AVAILABLE = True
19
+ except ImportError:
20
+ PLANT_CV_AVAILABLE = False
21
+ logger.warning("PlantCV not available. Morphological features will be limited.")
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class MorphologyExtractor:
    """Extracts morphological features from plant images.

    Uses PlantCV skeleton analysis when available, otherwise falls back
    to a pure-OpenCV implementation.
    """

    def __init__(self, pixel_to_cm: float = 0.1099609375, prune_sizes: List[int] = None):
        """
        Initialize morphology extractor.

        Args:
            pixel_to_cm: Conversion factor from pixels to centimeters.
            prune_sizes: Pruning sizes applied in order to the skeleton;
                defaults to [200, 100, 50, 30, 10].
        """
        self.pixel_to_cm = pixel_to_cm
        self.prune_sizes = prune_sizes or [200, 100, 50, 30, 10]

        if PLANT_CV_AVAILABLE:
            # Configure PlantCV global rendering/debug parameters once.
            pcv.params.debug = None
            pcv.params.text_size = 0.7
            pcv.params.text_thickness = 2
            pcv.params.line_thickness = 3
            pcv.params.dpi = 100

    def extract_morphology_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """
        Extract morphological features from plant image and mask.

        Args:
            image: Plant image (BGR format).
            mask: Binary mask of the plant.

        Returns:
            Dictionary with keys 'traits' (scalar measurements), 'images'
            (intermediate arrays) and 'success' (bool).
        """
        features = {
            'traits': {},
            'images': {},
            'success': False
        }

        try:
            clean_mask = self._preprocess_mask(mask)
            if clean_mask is None:
                logger.warning("Failed to preprocess mask")
                return features

            # Contour-based measurements always run.
            features['traits'].update(self._extract_basic_features(clean_mask))

            # Skeleton measurements use PlantCV when present, OpenCV otherwise.
            if PLANT_CV_AVAILABLE:
                extra = self._extract_skeleton_features(image, clean_mask)
            else:
                extra = self._extract_opencv_features(image, clean_mask)
            features['traits'].update(extra['traits'])
            features['images'].update(extra['images'])

            features['success'] = True
            logger.debug("Morphological features extracted successfully")

        except Exception as e:
            logger.error(f"Morphological feature extraction failed: {e}")

        return features

    def _preprocess_mask(self, mask: np.ndarray) -> Optional[np.ndarray]:
        """Binarize, open, and drop components < 1000 px for analysis."""
        if mask is None:
            return None

        # Some callers pass (mask, extra) tuples; unwrap the array.
        if isinstance(mask, tuple):
            mask = mask[0]

        # Normalize to strict 0/255 binary format.
        mask = ((mask.astype(np.int32) > 0).astype(np.uint8)) * 255

        # Morphological opening to remove speckle noise.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Remove small connected components (label 0 is the background).
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(opened, connectivity=8)
        clean_mask = np.zeros_like(opened)

        for label in range(1, num_labels):
            if stats[label, cv2.CC_STAT_AREA] >= 1000:
                clean_mask[labels == label] = 255

        return clean_mask

    def _extract_basic_features(self, mask: np.ndarray) -> Dict[str, float]:
        """Extract contour-based morphological features using OpenCV."""
        features = {}

        try:
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return features

            largest_contour = max(contours, key=cv2.contourArea)

            area = cv2.contourArea(largest_contour)
            perimeter = cv2.arcLength(largest_contour, True)

            x, y, w, h = cv2.boundingRect(largest_contour)
            bbox_area = w * h

            # cv2.fitEllipse needs at least 5 points; otherwise fall back
            # to the bounding-box sides as axis estimates.
            if len(largest_contour) >= 5:
                _center, axes, _angle = cv2.fitEllipse(largest_contour)
                major_axis = max(axes)
                minor_axis = min(axes)
            else:
                major_axis = max(w, h)
                minor_axis = min(w, h)

            # Convert pixel measurements to centimeters.
            features['area_cm2'] = area * (self.pixel_to_cm ** 2)
            features['perimeter_cm'] = perimeter * self.pixel_to_cm
            features['width_cm'] = w * self.pixel_to_cm
            features['height_cm'] = h * self.pixel_to_cm
            features['bbox_area_cm2'] = bbox_area * (self.pixel_to_cm ** 2)
            features['major_axis_cm'] = major_axis * self.pixel_to_cm
            features['minor_axis_cm'] = minor_axis * self.pixel_to_cm
            features['aspect_ratio'] = w / h if h > 0 else 0
            features['elongation'] = major_axis / minor_axis if minor_axis > 0 else 0
            features['circularity'] = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
            # NOTE(review): this 'solidity' is area/bbox_area (usually called
            # "extent"); 'convexity' below is the textbook solidity. Key names
            # are kept for downstream compatibility.
            features['solidity'] = area / bbox_area if bbox_area > 0 else 0

            hull = cv2.convexHull(largest_contour)
            hull_area = cv2.contourArea(hull)
            features['convexity'] = area / hull_area if hull_area > 0 else 0

        except Exception as e:
            logger.error(f"Basic feature extraction failed: {e}")

        return features

    def _extract_skeleton_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """Extract skeleton-based features using PlantCV."""
        features = {'traits': {}, 'images': {}}

        if not PLANT_CV_AVAILABLE:
            return features

        try:
            # Suppress noisy PlantCV console output during analysis.
            with contextlib.redirect_stdout(self._FilteredStream(sys.stdout)), \
                 contextlib.redirect_stderr(self._FilteredStream(sys.stderr)):

                skeleton = pcv.morphology.skeletonize(mask=mask)
                features['images']['skeleton'] = skeleton

                # Prune progressively with the configured sizes.
                pruned_skel = skeleton
                for size in self.prune_sizes:
                    pruned_skel, _, _ = pcv.morphology.prune(
                        skel_img=pruned_skel, size=size, mask=mask
                    )

                features['images']['pruned_skeleton'] = pruned_skel

                branch_pts = pcv.morphology.find_branch_pts(pruned_skel, mask)
                features['images']['branch_points'] = branch_pts

                try:
                    tip_pts = pcv.morphology.find_tips(pruned_skel, mask)
                    features['images']['tip_points'] = tip_pts
                except Exception as e:
                    logger.warning(f"Tip detection failed: {e}")

                # Sort skeleton segments into leaves and stems.
                try:
                    leaf_obj, stem_obj = pcv.morphology.segment_sort(
                        pruned_skel, [], mask
                    )
                    features['traits']['num_leaves'] = len(leaf_obj)
                    features['traits']['num_stems'] = len(stem_obj)
                except Exception as e:
                    logger.warning(f"Object segmentation failed: {e}")
                    features['traits']['num_leaves'] = 0
                    features['traits']['num_stems'] = 0

                # Whole-plant size analysis; lengths are converted to cm.
                try:
                    labeled_mask, n_labels = pcv.create_labels(mask)
                    size_analysis = pcv.analyze.size(image, labeled_mask, n_labels, label="default")
                    features['images']['size_analysis'] = size_analysis

                    obs = pcv.outputs.observations.get("default_1", {})
                    for trait, info in obs.items():
                        if trait not in ["in_bounds", "object_in_frame"]:
                            val = info.get("value", None)
                            if val is not None:
                                if trait == "area":
                                    val = val * (self.pixel_to_cm ** 2)
                                elif trait in ["perimeter", "width", "height", "longest_path",
                                               "ellipse_major_axis", "ellipse_minor_axis"]:
                                    val = val * self.pixel_to_cm
                                features['traits'][trait] = val

                except Exception as e:
                    logger.warning(f"Size analysis failed: {e}")

        except Exception as e:
            logger.error(f"Skeleton feature extraction failed: {e}")

        return features

    def _extract_opencv_features(self, image: np.ndarray, mask: np.ndarray) -> Dict[str, Any]:
        """Extract skeleton features using only OpenCV (PlantCV fallback)."""
        features = {'traits': {}, 'images': {}}

        try:
            skeleton = self._create_skeleton_opencv(mask)
            features['images']['skeleton'] = skeleton

            branch_points = self._find_branch_points_opencv(skeleton)
            features['images']['branch_points'] = branch_points
            features['traits']['num_branches'] = len(branch_points)

            endpoints = self._find_endpoints_opencv(skeleton)
            features['images']['endpoints'] = endpoints
            features['traits']['num_endpoints'] = len(endpoints)

            # Skeleton length = number of skeleton pixels.
            skeleton_length = np.sum(skeleton > 0)
            features['traits']['skeleton_length_pixels'] = skeleton_length
            features['traits']['skeleton_length_cm'] = skeleton_length * self.pixel_to_cm

        except Exception as e:
            logger.error(f"OpenCV feature extraction failed: {e}")

        return features

    def _create_skeleton_opencv(self, mask: np.ndarray) -> np.ndarray:
        """Create a 0/255 skeleton via iterative morphological thinning."""
        binary = (mask > 0).astype(np.uint8)

        skeleton = np.zeros_like(binary)
        element = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

        # Classic erode/open-subtract thinning loop; stops when the image
        # is fully eroded.
        while True:
            eroded = cv2.erode(binary, element)
            temp = cv2.dilate(eroded, element)
            temp = cv2.subtract(binary, temp)
            skeleton = cv2.bitwise_or(skeleton, temp)
            binary = eroded.copy()

            if cv2.countNonZero(binary) == 0:
                break

        return skeleton * 255

    def _find_branch_points_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
        """Find branch points (skeleton pixels with >= 3 skeleton neighbors)."""
        # Count neighbors on a 0/1 image with a 16-bit output. The original
        # filtered the 0/255 skeleton into uint8, which saturates at 255 and
        # made the neighbor-count thresholds meaningless.
        binary = (skeleton > 0).astype(np.uint8)
        kernel = np.ones((3, 3), dtype=np.uint8)
        kernel[1, 1] = 0  # Don't count the center pixel

        neighbor_count = cv2.filter2D(binary, cv2.CV_16S, kernel)

        ys, xs = np.where((binary > 0) & (neighbor_count >= 3))
        return list(zip(xs, ys))  # (x, y) format

    def _find_endpoints_opencv(self, skeleton: np.ndarray) -> List[Tuple[int, int]]:
        """Find endpoints (skeleton pixels with exactly 1 skeleton neighbor)."""
        # Same 0/1 + 16-bit counting fix as _find_branch_points_opencv;
        # with the old 0/255 input the == 1 test never matched anything.
        binary = (skeleton > 0).astype(np.uint8)
        kernel = np.ones((3, 3), dtype=np.uint8)
        kernel[1, 1] = 0  # Don't count the center pixel

        neighbor_count = cv2.filter2D(binary, cv2.CV_16S, kernel)

        ys, xs = np.where((binary > 0) & (neighbor_count == 1))
        return list(zip(xs, ys))  # (x, y) format

    class _FilteredStream:
        """Wrap a stream and drop known-noisy PlantCV messages."""

        def __init__(self, stream):
            self.stream = stream

        def write(self, msg):
            skip = ("got pruned", "Slope of contour", "cannot be plotted")
            if not any(s in msg for s in skip):
                self.stream.write(msg)

        def flush(self):
            # Best effort: the wrapped stream may already be closed.
            try:
                self.stream.flush()
            except Exception:
                pass

    def create_morphology_visualization(self, image: np.ndarray, mask: np.ndarray,
                                        features: Dict[str, Any]) -> np.ndarray:
        """
        Create a visualization of morphological features.

        Draws the mask outline (green), bounding box (blue), skeleton
        (red) and branch points (yellow) onto a copy of the image.

        Args:
            image: Original image.
            mask: Binary mask.
            features: Extracted features (from extract_morphology_features).

        Returns:
            Visualization image; the original image on failure.
        """
        try:
            vis = image.copy()

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(vis, contours, -1, (0, 255, 0), 2)

            if contours:
                x, y, w, h = cv2.boundingRect(contours[0])
                cv2.rectangle(vis, (x, y), (x + w, y + h), (255, 0, 0), 2)

            if 'skeleton' in features.get('images', {}):
                skeleton = features['images']['skeleton']
                vis[skeleton > 0] = [0, 0, 255]  # Red skeleton

            if 'branch_points' in features.get('images', {}):
                branch_img = features['images']['branch_points']
                vis[branch_img > 0] = [255, 255, 0]  # Yellow branch points

            return vis

        except Exception as e:
            logger.error(f"Visualization creation failed: {e}")
            return image
sorghum_pipeline/features/spectral.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spectral feature extraction for the Sorghum Pipeline.
3
+
4
+ This module handles extraction of spectral features and analysis
5
+ of multispectral data.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ from sklearn.decomposition import PCA
11
+ from typing import Dict, Any, Optional, List, Tuple
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class SpectralExtractor:
18
+ """Extracts spectral features from multispectral data."""
19
+
20
+ def __init__(self, n_components: int = 3):
21
+ """
22
+ Initialize spectral extractor.
23
+
24
+ Args:
25
+ n_components: Number of PCA components to extract
26
+ """
27
+ self.n_components = n_components
28
+
29
+ def extract_spectral_features(self, spectral_stack: Dict[str, np.ndarray],
30
+ mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
31
+ """
32
+ Extract spectral features from multispectral data.
33
+
34
+ Args:
35
+ spectral_stack: Dictionary of spectral bands
36
+ mask: Optional binary mask
37
+
38
+ Returns:
39
+ Dictionary containing spectral features
40
+ """
41
+ features = {}
42
+
43
+ try:
44
+ # Extract individual band features
45
+ features['band_features'] = self._extract_band_features(spectral_stack, mask)
46
+
47
+ # Extract PCA features
48
+ features['pca_features'] = self._extract_pca_features(spectral_stack, mask)
49
+
50
+ # Extract spectral indices
51
+ features['spectral_indices'] = self._extract_spectral_indices(spectral_stack, mask)
52
+
53
+ # Extract texture features from spectral bands
54
+ features['spectral_texture'] = self._extract_spectral_texture(spectral_stack, mask)
55
+
56
+ logger.debug("Spectral features extracted successfully")
57
+
58
+ except Exception as e:
59
+ logger.error(f"Spectral feature extraction failed: {e}")
60
+
61
+ return features
62
+
63
+ def _extract_band_features(self, spectral_stack: Dict[str, np.ndarray],
64
+ mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
65
+ """Extract features from individual spectral bands."""
66
+ band_features = {}
67
+
68
+ for band_name, band_data in spectral_stack.items():
69
+ try:
70
+ # Squeeze to 2D if needed
71
+ if band_data.ndim > 2:
72
+ band_data = band_data.squeeze()
73
+
74
+ # Apply mask if provided
75
+ if mask is not None and mask.shape == band_data.shape:
76
+ masked_data = np.where(mask > 0, band_data, np.nan)
77
+ else:
78
+ masked_data = band_data
79
+
80
+ # Compute statistics
81
+ valid_data = masked_data[~np.isnan(masked_data)]
82
+ if len(valid_data) > 0:
83
+ band_features[band_name] = {
84
+ 'mean': float(np.mean(valid_data)),
85
+ 'std': float(np.std(valid_data)),
86
+ 'min': float(np.min(valid_data)),
87
+ 'max': float(np.max(valid_data)),
88
+ 'median': float(np.median(valid_data)),
89
+ 'q25': float(np.percentile(valid_data, 25)),
90
+ 'q75': float(np.percentile(valid_data, 75)),
91
+ 'skewness': float(self._compute_skewness(valid_data)),
92
+ 'kurtosis': float(self._compute_kurtosis(valid_data)),
93
+ 'entropy': float(self._compute_entropy(valid_data))
94
+ }
95
+ else:
96
+ band_features[band_name] = {
97
+ 'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
98
+ 'median': 0.0, 'q25': 0.0, 'q75': 0.0,
99
+ 'skewness': 0.0, 'kurtosis': 0.0, 'entropy': 0.0
100
+ }
101
+
102
+ except Exception as e:
103
+ logger.error(f"Band feature extraction failed for {band_name}: {e}")
104
+ band_features[band_name] = {}
105
+
106
+ return band_features
107
+
108
+ def _extract_pca_features(self, spectral_stack: Dict[str, np.ndarray],
109
+ mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
110
+ """Extract PCA features from spectral data."""
111
+ try:
112
+ # Stack all bands
113
+ band_names = ['nir', 'red_edge', 'red', 'green']
114
+ band_data = []
115
+
116
+ for band_name in band_names:
117
+ if band_name in spectral_stack:
118
+ arr = spectral_stack[band_name].squeeze().astype(float)
119
+ if mask is not None and mask.shape == arr.shape:
120
+ arr = np.where(mask > 0, arr, np.nan)
121
+ band_data.append(arr)
122
+
123
+ if not band_data:
124
+ return {}
125
+
126
+ # Stack bands
127
+ full_stack = np.stack(band_data, axis=-1)
128
+ h, w, c = full_stack.shape
129
+
130
+ # Reshape for PCA
131
+ flat_data = full_stack.reshape(-1, c)
132
+ valid_mask = ~np.isnan(flat_data).any(axis=1)
133
+
134
+ if valid_mask.sum() == 0:
135
+ return {}
136
+
137
+ # Apply PCA
138
+ valid_data = flat_data[valid_mask]
139
+ pca = PCA(n_components=min(self.n_components, valid_data.shape[1]))
140
+ pca_result = pca.fit_transform(valid_data)
141
+
142
+ # Create full result array
143
+ full_result = np.full((h * w, self.n_components), np.nan)
144
+ full_result[valid_mask] = pca_result
145
+
146
+ # Reshape back to image dimensions
147
+ pca_components = {}
148
+ for i in range(self.n_components):
149
+ component = full_result[:, i].reshape(h, w)
150
+ pca_components[f'pca_{i+1}'] = component
151
+
152
+ # Compute statistics for this component
153
+ valid_component = component[~np.isnan(component)]
154
+ if len(valid_component) > 0:
155
+ pca_components[f'pca_{i+1}_stats'] = {
156
+ 'mean': float(np.mean(valid_component)),
157
+ 'std': float(np.std(valid_component)),
158
+ 'min': float(np.min(valid_component)),
159
+ 'max': float(np.max(valid_component))
160
+ }
161
+
162
+ # Add PCA metadata
163
+ pca_components['explained_variance_ratio'] = pca.explained_variance_ratio_.tolist()
164
+ pca_components['total_variance_explained'] = float(np.sum(pca.explained_variance_ratio_))
165
+
166
+ return pca_components
167
+
168
+ except Exception as e:
169
+ logger.error(f"PCA feature extraction failed: {e}")
170
+ return {}
171
+
172
+ def _extract_spectral_indices(self, spectral_stack: Dict[str, np.ndarray],
173
+ mask: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]:
174
+ """Extract basic spectral indices."""
175
+ indices = {}
176
+
177
+ try:
178
+ # Get required bands
179
+ nir = spectral_stack.get('nir', None)
180
+ red = spectral_stack.get('red', None)
181
+ green = spectral_stack.get('green', None)
182
+ red_edge = spectral_stack.get('red_edge', None)
183
+
184
+ if nir is not None:
185
+ nir = nir.squeeze().astype(float)
186
+ if red is not None:
187
+ red = red.squeeze().astype(float)
188
+ if green is not None:
189
+ green = green.squeeze().astype(float)
190
+ if red_edge is not None:
191
+ red_edge = red_edge.squeeze().astype(float)
192
+
193
+ # Apply mask
194
+ if mask is not None:
195
+ if nir is not None and mask.shape == nir.shape:
196
+ nir = np.where(mask > 0, nir, np.nan)
197
+ if red is not None and mask.shape == red.shape:
198
+ red = np.where(mask > 0, red, np.nan)
199
+ if green is not None and mask.shape == green.shape:
200
+ green = np.where(mask > 0, green, np.nan)
201
+ if red_edge is not None and mask.shape == red_edge.shape:
202
+ red_edge = np.where(mask > 0, red_edge, np.nan)
203
+
204
+ # Compute basic indices
205
+ if nir is not None and red is not None:
206
+ indices['nir_red_ratio'] = nir / (red + 1e-10)
207
+ indices['nir_red_diff'] = nir - red
208
+
209
+ if nir is not None and green is not None:
210
+ indices['nir_green_ratio'] = nir / (green + 1e-10)
211
+ indices['nir_green_diff'] = nir - green
212
+
213
+ if red is not None and green is not None:
214
+ indices['red_green_ratio'] = red / (green + 1e-10)
215
+ indices['red_green_diff'] = red - green
216
+
217
+ if nir is not None and red_edge is not None:
218
+ indices['nir_red_edge_ratio'] = nir / (red_edge + 1e-10)
219
+ indices['nir_red_edge_diff'] = nir - red_edge
220
+
221
+ # Compute band ratios
222
+ if nir is not None and red is not None and green is not None:
223
+ indices['nir_red_green_sum'] = nir + red + green
224
+ indices['nir_red_green_mean'] = (nir + red + green) / 3
225
+
226
+ except Exception as e:
227
+ logger.error(f"Spectral index extraction failed: {e}")
228
+
229
+ return indices
230
+
231
    def _extract_spectral_texture(self, spectral_stack: Dict[str, np.ndarray],
                                  mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """Extract texture features from spectral bands.

        Each band is squeezed to 2-D, masked (when the mask shape matches),
        min-max scaled to uint8 over its valid range, and handed to
        ``TextureExtractor.extract_all_texture_features``.  A band that fails
        yields an empty dict; if the texture module itself cannot be
        imported, the whole result is empty and a warning is logged.
        """
        texture_features = {}

        try:
            # Imported lazily so this module still loads without texture.py.
            from .texture import TextureExtractor

            texture_extractor = TextureExtractor()

            for band_name, band_data in spectral_stack.items():
                try:
                    # Prepare grayscale image
                    gray_data = band_data.squeeze().astype(float)

                    # Apply mask (NaN marks excluded pixels; only when shapes agree)
                    if mask is not None and mask.shape == gray_data.shape:
                        gray_data = np.where(mask > 0, gray_data, np.nan)

                    # Normalize to 0-255 over the valid (non-NaN) value range
                    valid_data = gray_data[~np.isnan(gray_data)]
                    if len(valid_data) > 0:
                        m, M = np.min(valid_data), np.max(valid_data)
                        if M > m:
                            # NOTE(review): NaN pixels are cast straight to
                            # uint8 here (typically ending up as 0) — confirm
                            # that is acceptable for downstream texture stats.
                            normalized = ((gray_data - m) / (M - m) * 255).astype(np.uint8)
                        else:
                            # Constant band: nothing to scale.
                            normalized = np.zeros_like(gray_data, dtype=np.uint8)
                    else:
                        # No valid pixels at all.
                        normalized = np.zeros_like(gray_data, dtype=np.uint8)

                    # Extract texture features
                    band_texture = texture_extractor.extract_all_texture_features(normalized)
                    texture_features[band_name] = band_texture

                except Exception as e:
                    logger.error(f"Spectral texture extraction failed for {band_name}: {e}")
                    texture_features[band_name] = {}

        except ImportError:
            logger.warning("TextureExtractor not available for spectral texture analysis")

        return texture_features
273
+
274
+ def _compute_skewness(self, data: np.ndarray) -> float:
275
+ """Compute skewness of data."""
276
+ if len(data) < 3:
277
+ return 0.0
278
+
279
+ mean = np.mean(data)
280
+ std = np.std(data)
281
+ if std == 0:
282
+ return 0.0
283
+
284
+ return np.mean(((data - mean) / std) ** 3)
285
+
286
+ def _compute_kurtosis(self, data: np.ndarray) -> float:
287
+ """Compute kurtosis of data."""
288
+ if len(data) < 4:
289
+ return 0.0
290
+
291
+ mean = np.mean(data)
292
+ std = np.std(data)
293
+ if std == 0:
294
+ return 0.0
295
+
296
+ return np.mean(((data - mean) / std) ** 4) - 3
297
+
298
+ def _compute_entropy(self, data: np.ndarray) -> float:
299
+ """Compute entropy of data."""
300
+ if len(data) == 0:
301
+ return 0.0
302
+
303
+ # Create histogram
304
+ hist, _ = np.histogram(data, bins=256, range=(0, 256))
305
+ hist = hist / np.sum(hist) # Normalize
306
+
307
+ # Remove zero probabilities
308
+ hist = hist[hist > 0]
309
+
310
+ # Compute entropy
311
+ return -np.sum(hist * np.log2(hist))
312
+
313
    def create_spectral_visualization(self, spectral_stack: Dict[str, np.ndarray],
                                      pca_features: Dict[str, Any]) -> np.ndarray:
        """
        Create visualization of spectral features.

        Tries, in order:
          1. an RGB composite of (red, red_edge, green),
          2. an RGB composite of (nir, red, green),
          3. a grayscale replication of the first PCA component,
          4. a 100x100 black image (also returned on any error).

        Args:
            spectral_stack: Original spectral data
            pca_features: PCA features

        Returns:
            Visualization image (uint8, H x W x 3)
        """
        try:
            # Preferred visualization: RGB = (Red, Red-Edge, Green)
            if 'red' in spectral_stack and 'red_edge' in spectral_stack and 'green' in spectral_stack:
                red = spectral_stack['red'].squeeze()
                red_edge = spectral_stack['red_edge'].squeeze()
                green = spectral_stack['green'].squeeze()

                # Normalize each band
                red_norm = self._normalize_band(red)
                red_edge_norm = self._normalize_band(red_edge)
                green_norm = self._normalize_band(green)

                # Create composite (Red, Red-Edge, Green)
                rgb_composite = np.stack([red_norm, red_edge_norm, green_norm], axis=-1)

                return rgb_composite.astype(np.uint8)

            # Fallback visualization: RGB = (NIR, Red, Green)
            if 'red' in spectral_stack and 'green' in spectral_stack and 'nir' in spectral_stack:
                red = spectral_stack['red'].squeeze()
                green = spectral_stack['green'].squeeze()
                nir = spectral_stack['nir'].squeeze()

                # Normalize each band
                red_norm = self._normalize_band(red)
                green_norm = self._normalize_band(green)
                nir_norm = self._normalize_band(nir)

                # Channel order: NIR -> R, Red -> G, Green -> B.
                rgb_composite = np.stack([nir_norm, red_norm, green_norm], axis=-1)

                return rgb_composite.astype(np.uint8)

            # Fallback to first PCA component (replicated across 3 channels)
            elif 'pca_1' in pca_features:
                pca1 = pca_features['pca_1']
                pca1_norm = self._normalize_band(pca1)
                return np.stack([pca1_norm, pca1_norm, pca1_norm], axis=-1).astype(np.uint8)

            else:
                # Return empty image
                return np.zeros((100, 100, 3), dtype=np.uint8)

        except Exception as e:
            logger.error(f"Spectral visualization creation failed: {e}")
            return np.zeros((100, 100, 3), dtype=np.uint8)
370
+
371
+ def _normalize_band(self, band: np.ndarray) -> np.ndarray:
372
+ """Normalize band to 0-255 range."""
373
+ valid_data = band[~np.isnan(band)]
374
+ if len(valid_data) == 0:
375
+ return np.zeros_like(band, dtype=np.uint8)
376
+
377
+ m, M = np.min(valid_data), np.max(valid_data)
378
+ if M > m:
379
+ normalized = ((band - m) / (M - m) * 255).astype(np.uint8)
380
+ else:
381
+ normalized = np.zeros_like(band, dtype=np.uint8)
382
+
383
+ return normalized
sorghum_pipeline/features/texture.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Texture feature extraction for the Sorghum Pipeline.
3
+
4
+ This module handles extraction of texture features including:
5
+ - Local Binary Patterns (LBP)
6
+ - Histogram of Oriented Gradients (HOG)
7
+ - Lacunarity features
8
+ - Edge Histogram Descriptor (EHD)
9
+ """
10
+
11
+ import numpy as np
12
+ import cv2
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from skimage.feature import local_binary_pattern, hog
16
+ from skimage import exposure
17
+ from scipy import ndimage, signal
18
+ from sklearn.decomposition import PCA
19
+ from typing import Dict, Tuple, Optional, Any
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class TextureExtractor:
26
+ """Extracts texture features from images."""
27
+
28
+ def __init__(self,
29
+ lbp_points: int = 8,
30
+ lbp_radius: int = 1,
31
+ hog_orientations: int = 9,
32
+ hog_pixels_per_cell: Tuple[int, int] = (8, 8),
33
+ hog_cells_per_block: Tuple[int, int] = (2, 2),
34
+ lacunarity_window: int = 15,
35
+ ehd_threshold: float = 0.3,
36
+ angle_resolution: int = 45):
37
+ """
38
+ Initialize texture extractor.
39
+
40
+ Args:
41
+ lbp_points: Number of points for LBP
42
+ lbp_radius: Radius for LBP
43
+ hog_orientations: Number of orientations for HOG
44
+ hog_pixels_per_cell: Pixels per cell for HOG
45
+ hog_cells_per_block: Cells per block for HOG
46
+ lacunarity_window: Window size for lacunarity
47
+ ehd_threshold: Threshold for EHD
48
+ angle_resolution: Angle resolution for EHD
49
+ """
50
+ self.lbp_points = lbp_points
51
+ self.lbp_radius = lbp_radius
52
+ self.hog_orientations = hog_orientations
53
+ self.hog_pixels_per_cell = hog_pixels_per_cell
54
+ self.hog_cells_per_block = hog_cells_per_block
55
+ self.lacunarity_window = lacunarity_window
56
+ self.ehd_threshold = ehd_threshold
57
+ self.angle_resolution = angle_resolution
58
+
59
+ def extract_lbp(self, gray_image: np.ndarray) -> np.ndarray:
60
+ """
61
+ Extract Local Binary Pattern features.
62
+
63
+ Args:
64
+ gray_image: Grayscale input image
65
+
66
+ Returns:
67
+ LBP feature map
68
+ """
69
+ try:
70
+ lbp = local_binary_pattern(
71
+ gray_image,
72
+ self.lbp_points,
73
+ self.lbp_radius,
74
+ method='uniform'
75
+ )
76
+ return self._convert_to_uint8(lbp)
77
+ except Exception as e:
78
+ logger.error(f"LBP extraction failed: {e}")
79
+ return np.zeros_like(gray_image, dtype=np.uint8)
80
+
81
+ def extract_hog(self, gray_image: np.ndarray) -> np.ndarray:
82
+ """
83
+ Extract Histogram of Oriented Gradients features.
84
+
85
+ Args:
86
+ gray_image: Grayscale input image
87
+
88
+ Returns:
89
+ HOG feature map
90
+ """
91
+ try:
92
+ _, vis = hog(
93
+ gray_image,
94
+ orientations=self.hog_orientations,
95
+ pixels_per_cell=self.hog_pixels_per_cell,
96
+ cells_per_block=self.hog_cells_per_block,
97
+ visualize=True,
98
+ feature_vector=True
99
+ )
100
+ return exposure.rescale_intensity(vis, out_range=(0, 255)).astype(np.uint8)
101
+ except Exception as e:
102
+ logger.error(f"HOG extraction failed: {e}")
103
+ return np.zeros_like(gray_image, dtype=np.uint8)
104
+
105
+ def compute_local_lacunarity(self, gray_image: np.ndarray, window_size: int) -> np.ndarray:
106
+ """
107
+ Compute local lacunarity.
108
+
109
+ Args:
110
+ gray_image: Grayscale input image
111
+ window_size: Size of the sliding window
112
+
113
+ Returns:
114
+ Local lacunarity map
115
+ """
116
+ try:
117
+ arr = gray_image.astype(np.float32)
118
+ m1 = ndimage.uniform_filter(arr, size=window_size)
119
+ m2 = ndimage.uniform_filter(arr * arr, size=window_size)
120
+ var = m2 - m1 * m1
121
+ eps = 1e-6
122
+ lac = var / (m1 * m1 + eps) + 1
123
+ lac[m1 <= eps] = 0
124
+ return lac
125
+ except Exception as e:
126
+ logger.error(f"Local lacunarity computation failed: {e}")
127
+ return np.zeros_like(gray_image, dtype=np.float32)
128
+
129
    def compute_lacunarity_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute three types of lacunarity features.

        Args:
            gray_image: Grayscale input image

        Returns:
            Tuple of (lac1, lac2, lac3) lacunarity maps, each rescaled to
            uint8 (all zeros on failure):
              - lac1: single-window lacunarity at ``self.lacunarity_window``
              - lac2: mean over three window scales (half, base, double)
              - lac3: differential-box-counting lacunarity from the torch
                model, falling back to a copy of lac2 when unavailable
        """
        try:
            # L1: Single window lacunarity
            lac1 = self.compute_local_lacunarity(gray_image, self.lacunarity_window)

            # L2: Multi-scale lacunarity (smallest window clamped to >= 3)
            scales = [max(3, self.lacunarity_window//2), self.lacunarity_window, self.lacunarity_window*2]
            lac2 = np.mean([
                self.compute_local_lacunarity(gray_image, s) for s in scales
            ], axis=0)

            # L3: DBC Lacunarity (if available)
            try:
                from ..models.dbc_lacunarity import DBC_Lacunarity
                # Scale to [0, 1] and add batch/channel dims for the torch layer.
                x = torch.from_numpy(gray_image.astype(np.float32)/255.0)[None, None]
                layer = DBC_Lacunarity(window_size=self.lacunarity_window).eval()
                with torch.no_grad():
                    lac3 = layer(x).squeeze().cpu().numpy()
            except ImportError:
                logger.warning("DBC Lacunarity not available, using L2 as L3")
                lac3 = lac2.copy()

            return (
                self._convert_to_uint8(lac1),
                self._convert_to_uint8(lac2),
                self._convert_to_uint8(lac3)
            )
        except Exception as e:
            logger.error(f"Lacunarity features computation failed: {e}")
            empty = np.zeros_like(gray_image, dtype=np.uint8)
            return empty, empty, empty
169
+
170
+ def generate_ehd_masks(self, mask_size: int = 3) -> np.ndarray:
171
+ """
172
+ Generate masks for Edge Histogram Descriptor.
173
+
174
+ Args:
175
+ mask_size: Size of the mask
176
+
177
+ Returns:
178
+ Array of EHD masks
179
+ """
180
+ if mask_size < 3:
181
+ mask_size = 3
182
+ if mask_size % 2 == 0:
183
+ mask_size += 1
184
+
185
+ # Base gradient mask
186
+ Gy = np.outer([1, 0, -1], [1, 2, 1])
187
+
188
+ # Expand if needed
189
+ if mask_size > 3:
190
+ expd = np.outer([1, 2, 1], [1, 2, 1])
191
+ for _ in range((mask_size - 3) // 2):
192
+ Gy = signal.convolve2d(expd, Gy, mode='full')
193
+
194
+ # Generate masks for different angles
195
+ angles = np.arange(0, 360, self.angle_resolution)
196
+ masks = np.zeros((len(angles), mask_size, mask_size), dtype=np.float32)
197
+
198
+ for i, angle in enumerate(angles):
199
+ masks[i] = ndimage.rotate(Gy, angle, reshape=False, mode='nearest')
200
+
201
+ return masks
202
+
203
    def extract_ehd_features(self, gray_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Edge Histogram Descriptor features.

        Convolves the image with rotated edge kernels (dilation 7, so the
        output is spatially smaller than the input), labels each pixel with
        the index of its strongest edge direction — or an extra "no edge"
        index when the best response is below ``self.ehd_threshold`` — and
        average-pools a one-hot map per direction.

        Args:
            gray_image: Grayscale input image

        Returns:
            Tuple of (ehd_features, ehd_map):
              - ehd_features: (n_directions + 1, H', W') pooled responses
              - ehd_map: per-pixel argmax direction index (uint8)
        """
        try:
            # Generate masks
            masks = self.generate_ehd_masks()

            # Scale to [0, 1] and add batch/channel dims for conv2d.
            X = torch.from_numpy(gray_image.astype(np.float32)/255.0).unsqueeze(0).unsqueeze(0)
            masks_tensor = torch.tensor(masks).unsqueeze(1).float()

            # Convolve with masks
            edge_responses = F.conv2d(X, masks_tensor, dilation=7)

            # Strongest direction per pixel; weak responses get the extra
            # "no edge" label (== number of masks).
            values, indices = torch.max(edge_responses, dim=1)
            indices[values < self.ehd_threshold] = masks.shape[0]

            # Average-pool a one-hot occupancy map for each direction label.
            feat_vect = []
            for edge in range(masks.shape[0] + 1):
                pooled = F.avg_pool2d(
                    (indices == edge).unsqueeze(1).float(),
                    kernel_size=5, stride=1, padding=2
                )
                feat_vect.append(pooled.squeeze(1))

            ehd_features = torch.stack(feat_vect, dim=1).squeeze(0).cpu().numpy()
            ehd_map = np.argmax(ehd_features, axis=0).astype(np.uint8)

            return ehd_features, ehd_map

        except Exception as e:
            logger.error(f"EHD features extraction failed: {e}")
            # NOTE(review): these fallback shapes assume 8 directions (+1 "no
            # edge") and a 2-pixel border, which does not match the success
            # path (a 3x3 kernel at dilation 7 trims 14 pixels, and the
            # direction count depends on angle_resolution) — confirm that
            # downstream consumers tolerate the discrepancy.
            empty_features = np.zeros((9, gray_image.shape[0]-4, gray_image.shape[1]-4), dtype=np.float32)
            empty_map = np.zeros_like(gray_image, dtype=np.uint8)
            return empty_features, empty_map
247
+
248
+ def extract_all_texture_features(self, gray_image: np.ndarray) -> Dict[str, np.ndarray]:
249
+ """
250
+ Extract all texture features from a grayscale image.
251
+
252
+ Args:
253
+ gray_image: Grayscale input image
254
+
255
+ Returns:
256
+ Dictionary of texture features
257
+ """
258
+ features = {}
259
+
260
+ try:
261
+ # LBP
262
+ features['lbp'] = self.extract_lbp(gray_image)
263
+
264
+ # HOG
265
+ features['hog'] = self.extract_hog(gray_image)
266
+
267
+ # Lacunarity
268
+ lac1, lac2, lac3 = self.compute_lacunarity_features(gray_image)
269
+ features['lac1'] = lac1
270
+ features['lac2'] = lac2
271
+ features['lac3'] = lac3
272
+
273
+ # EHD
274
+ ehd_features, ehd_map = self.extract_ehd_features(gray_image)
275
+ features['ehd_features'] = ehd_features
276
+ features['ehd_map'] = ehd_map
277
+
278
+ logger.debug("All texture features extracted successfully")
279
+
280
+ except Exception as e:
281
+ logger.error(f"Texture feature extraction failed: {e}")
282
+ # Return empty features
283
+ features = {
284
+ 'lbp': np.zeros_like(gray_image, dtype=np.uint8),
285
+ 'hog': np.zeros_like(gray_image, dtype=np.uint8),
286
+ 'lac1': np.zeros_like(gray_image, dtype=np.uint8),
287
+ 'lac2': np.zeros_like(gray_image, dtype=np.uint8),
288
+ 'lac3': np.zeros_like(gray_image, dtype=np.uint8),
289
+ 'ehd_features': np.zeros((9, gray_image.shape[0]-4, gray_image.shape[1]-4), dtype=np.float32),
290
+ 'ehd_map': np.zeros_like(gray_image, dtype=np.uint8)
291
+ }
292
+
293
+ return features
294
+
295
+ def _convert_to_uint8(self, arr: np.ndarray) -> np.ndarray:
296
+ """Convert array to uint8 with proper normalization."""
297
+ arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
298
+ if arr.ptp() > 0:
299
+ normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
300
+ else:
301
+ normalized = np.zeros_like(arr)
302
+ return np.clip(normalized, 0, 255).astype(np.uint8)
303
+
304
    def compute_texture_statistics(self, features: Dict[str, np.ndarray],
                                   mask: Optional[np.ndarray] = None) -> Dict[str, Dict[str, float]]:
        """
        Compute statistics for texture features.

        For the stacked 'ehd_features' entry, statistics are computed per
        channel (resizing the mask with nearest-neighbor interpolation when
        its shape differs from the channel's).  All other entries are treated
        as single 2-D maps; the mask is applied only when its shape matches,
        and masked-out pixels (NaN) are excluded from the statistics.

        Args:
            features: Dictionary of texture features
            mask: Optional mask to apply (non-zero = keep)

        Returns:
            Dictionary of feature statistics (per-channel dicts for
            'ehd_features', flat stat dicts otherwise; zeros when a feature
            has no valid pixels)
        """
        stats = {}

        for feature_name, feature_data in features.items():
            if feature_name == 'ehd_features':
                # Special handling for EHD features
                if mask is not None:
                    # Apply mask to each channel
                    masked_features = []
                    for i in range(feature_data.shape[0]):
                        channel = feature_data[i]
                        if mask.shape != channel.shape:
                            # Resize mask to match channel (EHD output is
                            # smaller than the input image).
                            mask_resized = cv2.resize(mask, (channel.shape[1], channel.shape[0]),
                                                      interpolation=cv2.INTER_NEAREST)
                            masked_channel = np.where(mask_resized > 0, channel, np.nan)
                        else:
                            masked_channel = np.where(mask > 0, channel, np.nan)
                        masked_features.append(masked_channel)
                    feature_data = np.stack(masked_features, axis=0)
                else:
                    # No mask: use the channels as-is (deliberate no-op branch).
                    feature_data = feature_data

                # Compute statistics for each EHD channel
                channel_stats = {}
                for i in range(feature_data.shape[0]):
                    channel = feature_data[i]
                    valid_data = channel[~np.isnan(channel)]
                    if len(valid_data) > 0:
                        channel_stats[f'channel_{i}'] = {
                            'mean': float(np.mean(valid_data)),
                            'std': float(np.std(valid_data)),
                            'min': float(np.min(valid_data)),
                            'max': float(np.max(valid_data)),
                            'median': float(np.median(valid_data))
                        }
                stats[feature_name] = channel_stats
            else:
                # Regular 2D features; mask applied only when shapes agree.
                if mask is not None and mask.shape == feature_data.shape:
                    masked_data = np.where(mask > 0, feature_data, np.nan)
                else:
                    masked_data = feature_data

                valid_data = masked_data[~np.isnan(masked_data)]
                if len(valid_data) > 0:
                    stats[feature_name] = {
                        'mean': float(np.mean(valid_data)),
                        'std': float(np.std(valid_data)),
                        'min': float(np.min(valid_data)),
                        'max': float(np.max(valid_data)),
                        'median': float(np.median(valid_data))
                    }
                else:
                    # No valid pixels: report zeroed statistics.
                    stats[feature_name] = {
                        'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0
                    }

        return stats
sorghum_pipeline/features/vegetation.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vegetation index extraction for the Sorghum Pipeline.
3
+
4
+ This module handles extraction of various vegetation indices
5
+ from multispectral data.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ from typing import Dict, Tuple, Optional, Any
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class VegetationIndexExtractor:
17
+ """Extracts vegetation indices from spectral data."""
18
+
19
+ def __init__(self, epsilon: float = 1e-10, soil_factor: float = 0.16):
20
+ """
21
+ Initialize vegetation index extractor.
22
+
23
+ Args:
24
+ epsilon: Small value to avoid division by zero
25
+ soil_factor: Soil factor for certain indices
26
+ """
27
+ # Coerce to float in case config passed strings like "1e-10"
28
+ try:
29
+ self.epsilon = float(epsilon)
30
+ except Exception:
31
+ self.epsilon = 1e-10
32
+ try:
33
+ self.soil_factor = float(soil_factor)
34
+ except Exception:
35
+ self.soil_factor = 0.16
36
+
37
+ # Define vegetation index formulas
38
+ self.index_formulas = {
39
+ "NDVI": lambda nir, red: (nir - red) / (nir + red + self.epsilon),
40
+ "GNDVI": lambda nir, green: (nir - green) / (nir + green + self.epsilon),
41
+ "NDRE": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
42
+ "GRNDVI": lambda nir, green, red: (nir - (green + red)) / (nir + (green + red) + self.epsilon),
43
+ "TNDVI": lambda nir, red: np.sqrt(np.clip(((nir - red) / (nir + red + self.epsilon)) + 0.5, 0, None)),
44
+ "MGRVI": lambda green, red: (green**2 - red**2) / (green**2 + red**2 + self.epsilon),
45
+ "GRVI": lambda nir, green: nir / (green + self.epsilon),
46
+ "NGRDI": lambda green, red: (green - red) / (green + red + self.epsilon),
47
+ "MSAVI": lambda nir, red: 0.5 * (2.0 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))),
48
+ "OSAVI": lambda nir, red: (nir - red) / (nir + red + self.soil_factor + self.epsilon),
49
+ "TSAVI": lambda nir, red, s=0.33, a=0.5, X=1.5: (s * (nir - s * red - a)) / (a * nir + red - a * s + X * (1 + s**2) + self.epsilon),
50
+ "GSAVI": lambda nir, green, l=0.5: (1 + l) * (nir - green) / (nir + green + l + self.epsilon),
51
+ # Requested additions and aliases
52
+ "GOSAVI": lambda nir, green: (nir - green) / (nir + green + 0.16 + self.epsilon),
53
+ "GDVI": lambda nir, green: nir - green,
54
+ "NDWI": lambda green, nir: (green - nir) / (green + nir + self.epsilon),
55
+ "DSWI4": lambda green, red: green / (red + self.epsilon),
56
+ "CIRE": lambda nir, red_edge: (nir / (red_edge + self.epsilon)) - 1.0,
57
+ "LCI": lambda nir, red_edge: (nir - red_edge) / (nir + red_edge + self.epsilon),
58
+ "CIgreen": lambda nir, green: (nir / (green + self.epsilon)) - 1,
59
+ "MCARI": lambda red_edge, red, green: ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + self.epsilon)),
60
+ "MCARI1": lambda nir, red, green: 1.2 * (2.5 * (nir - red) - 1.3 * (nir - green)),
61
+ "MCARI2": lambda nir, red, green: (1.5 * (2.5 * (nir - red) - 1.3 * (nir - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon))),
62
+ # MTVI variants per request
63
+ "MTVI1": lambda nir, red, green: 1.2 * (1.2 * (nir - green) - 2.5 * (red - green)),
64
+ "MTVI2": lambda nir, red, green: (1.5 * (1.2 * (nir - green) - 2.5 * (red - green))) / np.sqrt((2 * nir + 1)**2 - (6 * nir - 5 * np.sqrt(red + self.epsilon)) - 0.5 + self.epsilon),
65
+ "CVI": lambda nir, red, green: (nir * red) / (green**2 + self.epsilon),
66
+ "ARI": lambda green, red_edge: (1.0 / (green + self.epsilon)) - (1.0 / (red_edge + self.epsilon)),
67
+ "ARI2": lambda nir, green, red_edge: nir * (1.0 / (green + self.epsilon)) - nir * (1.0 / (red_edge + self.epsilon)),
68
+ "DVI": lambda nir, red: nir - red,
69
+ "WDVI": lambda nir, red, a=0.5: nir - a * red,
70
+ "SR": lambda nir, red: nir / (red + self.epsilon),
71
+ "MSR": lambda nir, red: (nir / (red + self.epsilon) - 1) / np.sqrt(nir / (red + self.epsilon) + 1),
72
+ "PVI": lambda nir, red, a=0.5, b=0.3: (nir - a * red - b) / (np.sqrt(1 + a**2) + self.epsilon),
73
+ "GEMI": lambda nir, red: ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon)) * (1 - 0.25 * ((2 * (nir**2 - red**2) + 1.5 * nir + 0.5 * red) / (nir + red + 0.5 + self.epsilon))) - ((red - 0.125) / (1 - red + self.epsilon)),
74
+ "ExR": lambda red, green: 1.3 * red - green,
75
+ "RI": lambda red, green: (red - green) / (red + green + self.epsilon),
76
+ "RRI1": lambda nir, red_edge: nir / (red_edge + self.epsilon),
77
+ "RRI2": lambda red_edge, red: red_edge / (red + self.epsilon),
78
+ "RRI": lambda nir, red_edge: nir / (red_edge + self.epsilon),
79
+ "AVI": lambda nir, red: np.cbrt(nir * (1.0 - red) * (nir - red + self.epsilon)),
80
+ "SIPI2": lambda nir, green, red: (nir - green) / (nir - red + self.epsilon),
81
+ "TCARI": lambda red_edge, red, green: 3 * ((red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))),
82
+ "TCARIOSAVI": lambda red_edge, red, green, nir: (3 * (red_edge - red) - 0.2 * (red_edge - green) * (red_edge / (red + self.epsilon))) / (1 + 0.16 * ((nir - red) / (nir + red + 0.16 + self.epsilon))),
83
+ "CCCI": lambda nir, red_edge, red: (((nir - red_edge) * (nir + red)) / ((nir + red_edge) * (nir - red) + self.epsilon)),
84
+ # Additional indices
85
+ "RDVI": lambda nir, red: (nir - red) / (np.sqrt(nir + red + self.epsilon)),
86
+ "NLI": lambda nir, red: ((nir**2) - red) / ((nir**2) + red + self.epsilon),
87
+ "BIXS": lambda green, red: np.sqrt(((green**2) + (red**2)) / 2.0),
88
+ "IPVI": lambda nir, red: nir / (nir + red + self.epsilon),
89
+ "EVI2": lambda nir, red: 2.4 * (nir - red) / (nir + red + 1.0 + self.epsilon)
90
+ }
91
+
92
+ # Define required bands for each index
93
+ self.index_bands = {
94
+ "NDVI": ["nir", "red"],
95
+ "GNDVI": ["nir", "green"],
96
+ "NDRE": ["nir", "red_edge"],
97
+ "GRNDVI": ["nir", "green", "red"],
98
+ "TNDVI": ["nir", "red"],
99
+ "MGRVI": ["green", "red"],
100
+ "GRVI": ["nir", "green"],
101
+ "NGRDI": ["green", "red"],
102
+ "MSAVI": ["nir", "red"],
103
+ "OSAVI": ["nir", "red"],
104
+ "TSAVI": ["nir", "red"],
105
+ "GSAVI": ["nir", "green"],
106
+ "GOSAVI": ["nir", "green"],
107
+ "GDVI": ["nir", "green"],
108
+ "NDWI": ["green", "nir"],
109
+ "DSWI4": ["green", "red"],
110
+ "CIRE": ["nir", "red_edge"],
111
+ "LCI": ["nir", "red_edge"],
112
+ "CIgreen": ["nir", "green"],
113
+ "MCARI": ["red_edge", "red", "green"],
114
+ "MCARI1": ["nir", "red", "green"],
115
+ "MCARI2": ["nir", "red", "green"],
116
+ "MTVI1": ["nir", "red", "green"],
117
+ "MTVI2": ["nir", "red", "green"],
118
+ "CVI": ["nir", "red", "green"],
119
+ "ARI": ["green", "red_edge"],
120
+ "ARI2": ["nir", "green", "red_edge"],
121
+ "DVI": ["nir", "red"],
122
+ "WDVI": ["nir", "red"],
123
+ "SR": ["nir", "red"],
124
+ "MSR": ["nir", "red"],
125
+ "PVI": ["nir", "red"],
126
+ "GEMI": ["nir", "red"],
127
+ "ExR": ["red", "green"],
128
+ "RI": ["red", "green"],
129
+ "RRI1": ["nir", "red_edge"],
130
+ "RRI2": ["red_edge", "red"],
131
+ "RRI": ["nir", "red_edge"],
132
+ "AVI": ["nir", "red"],
133
+ "SIPI2": ["nir", "green", "red"],
134
+ "TCARI": ["red_edge", "red", "green"],
135
+ "TCARIOSAVI": ["red_edge", "red", "green", "nir"],
136
+ "CCCI": ["nir", "red_edge", "red"],
137
+ "RDVI": ["nir", "red"],
138
+ "NLI": ["nir", "red"],
139
+ "BIXS": ["green", "red"],
140
+ "IPVI": ["nir", "red"],
141
+ "EVI2": ["nir", "red"]
142
+ }
143
+
144
+ def compute_vegetation_indices(self, spectral_stack: Dict[str, np.ndarray],
145
+ mask: np.ndarray) -> Dict[str, Dict[str, Any]]:
146
+ """
147
+ Compute vegetation indices from spectral data.
148
+
149
+ Args:
150
+ spectral_stack: Dictionary of spectral bands
151
+ mask: Binary mask for the plant
152
+
153
+ Returns:
154
+ Dictionary of vegetation indices with values and statistics
155
+ """
156
+ indices = {}
157
+
158
+ for index_name, formula in self.index_formulas.items():
159
+ try:
160
+ # Get required bands
161
+ required_bands = self.index_bands.get(index_name, [])
162
+
163
+ # Check if all required bands are available
164
+ if not all(band in spectral_stack for band in required_bands):
165
+ logger.warning(f"Skipping {index_name}: missing required bands")
166
+ continue
167
+
168
+ # Extract band data as float arrays
169
+ band_data = []
170
+ for band in required_bands:
171
+ arr = spectral_stack[band]
172
+ # Ensure numeric float np.ndarray
173
+ if isinstance(arr, np.ndarray):
174
+ arr = arr.squeeze(-1)
175
+ arr = np.asarray(arr, dtype=np.float64)
176
+ band_data.append(arr)
177
+
178
+ # Compute index (ensure float math)
179
+ index_values = formula(*band_data).astype(np.float64)
180
+
181
+ # Apply mask
182
+ if mask is not None:
183
+ binary_mask = (np.asarray(mask).astype(np.int32) > 0)
184
+ masked_values = np.where(binary_mask, index_values, np.nan)
185
+ else:
186
+ masked_values = index_values
187
+
188
+ # Compute statistics
189
+ valid_values = masked_values[~np.isnan(masked_values)]
190
+ if len(valid_values) > 0:
191
+ stats = {
192
+ 'mean': float(np.mean(valid_values)),
193
+ 'std': float(np.std(valid_values)),
194
+ 'min': float(np.min(valid_values)),
195
+ 'max': float(np.max(valid_values)),
196
+ 'median': float(np.median(valid_values)),
197
+ 'q25': float(np.percentile(valid_values, 25)),
198
+ 'q75': float(np.percentile(valid_values, 75)),
199
+ 'nan_fraction': float(np.isnan(masked_values).sum() / masked_values.size)
200
+ }
201
+ else:
202
+ stats = {
203
+ 'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
204
+ 'median': 0.0, 'q25': 0.0, 'q75': 0.0, 'nan_fraction': 1.0
205
+ }
206
+
207
+ indices[index_name] = {
208
+ 'values': masked_values,
209
+ 'statistics': stats
210
+ }
211
+
212
+ logger.debug(f"Computed {index_name}")
213
+
214
+ except Exception as e:
215
+ logger.error(f"Failed to compute {index_name}: {e}")
216
+ continue
217
+
218
+ return indices
219
+
220
+ def create_vegetation_index_image(self, index_values: np.ndarray,
221
+ colormap: str = 'RdYlGn',
222
+ vmin: Optional[float] = None,
223
+ vmax: Optional[float] = None) -> np.ndarray:
224
+ """
225
+ Create visualization image for vegetation index.
226
+
227
+ Args:
228
+ index_values: Vegetation index values
229
+ colormap: Matplotlib colormap name
230
+ vmin: Minimum value for normalization
231
+ vmax: Maximum value for normalization
232
+
233
+ Returns:
234
+ RGB image array
235
+ """
236
+ try:
237
+ import matplotlib.pyplot as plt
238
+ import matplotlib.cm as cm
239
+ from matplotlib.colors import Normalize
240
+
241
+ # Determine value range
242
+ valid_values = index_values[~np.isnan(index_values)]
243
+ if len(valid_values) == 0:
244
+ return np.zeros((*index_values.shape, 3), dtype=np.uint8)
245
+
246
+ if vmin is None:
247
+ vmin = np.min(valid_values)
248
+ if vmax is None:
249
+ vmax = np.max(valid_values)
250
+
251
+ # Normalize values
252
+ norm = Normalize(vmin=vmin, vmax=vmax)
253
+ cmap = cm.get_cmap(colormap)
254
+
255
+ # Apply colormap
256
+ rgba_img = cmap(norm(index_values))
257
+ rgba_img[np.isnan(index_values)] = [1, 1, 1, 1] # White for NaN
258
+
259
+ # Convert to RGB uint8
260
+ rgb_img = (rgba_img[:, :, :3] * 255).astype(np.uint8)
261
+
262
+ return rgb_img
263
+
264
+ except Exception as e:
265
+ logger.error(f"Failed to create vegetation index image: {e}")
266
+ return np.zeros((*index_values.shape, 3), dtype=np.uint8)
267
+
268
+ def get_available_indices(self) -> list:
269
+ """Get list of available vegetation indices."""
270
+ return list(self.index_formulas.keys())
271
+
272
+ def get_index_requirements(self, index_name: str) -> list:
273
+ """
274
+ Get required bands for a specific index.
275
+
276
+ Args:
277
+ index_name: Name of the vegetation index
278
+
279
+ Returns:
280
+ List of required band names
281
+ """
282
+ return self.index_bands.get(index_name, [])
283
+
284
+ def validate_spectral_data(self, spectral_stack: Dict[str, np.ndarray]) -> bool:
285
+ """
286
+ Validate spectral data for vegetation index computation.
287
+
288
+ Args:
289
+ spectral_stack: Dictionary of spectral bands
290
+
291
+ Returns:
292
+ True if valid, False otherwise
293
+ """
294
+ if not spectral_stack:
295
+ return False
296
+
297
+ required_bands = ['nir', 'red', 'green', 'red_edge']
298
+ if not all(band in spectral_stack for band in required_bands):
299
+ logger.warning("Missing required spectral bands")
300
+ return False
301
+
302
+ # Check data shapes
303
+ shapes = [arr.shape for arr in spectral_stack.values()]
304
+ if not all(shape == shapes[0] for shape in shapes):
305
+ logger.warning("Inconsistent spectral band shapes")
306
+ return False
307
+
308
+ return True
sorghum_pipeline/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model definitions for the Sorghum Pipeline.
3
+
4
+ This package contains neural network models and other
5
+ computational models used in the pipeline.
6
+ """
7
+
8
+ from .dbc_lacunarity import DBC_Lacunarity
9
+
10
+ __all__ = ["DBC_Lacunarity"]
sorghum_pipeline/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (438 Bytes). View file
 
sorghum_pipeline/models/__pycache__/dbc_lacunarity.cpython-312.pyc ADDED
Binary file (4.14 kB). View file
 
sorghum_pipeline/models/dbc_lacunarity.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DBC Lacunarity model for texture analysis.
3
+
4
+ This module implements the Differential Box Counting (DBC) method
5
+ for computing lacunarity features.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Optional
11
+
12
+
13
class DBC_Lacunarity(nn.Module):
    """
    Differential Box Counting (DBC) lacunarity as a torch module.

    Slides a window over the image, counts "boxes" spanned by the local
    intensity range, and normalizes the counts into a per-window
    lacunarity map — a texture descriptor used on plant images.
    """

    def __init__(self, model_name: str = 'Net', window_size: int = 3, eps: float = 1e-6):
        """
        Initialize DBC Lacunarity model.

        Args:
            model_name: Name of the model
            window_size: Size of the sliding window
            eps: Small value to avoid division by zero
        """
        super().__init__()
        self.window_size = window_size
        self.normalize = nn.Tanh()
        self.num_output_channels = 3
        self.eps = eps
        self.r = 1  # box height used by the differential box count
        self.model_name = model_name
        self.max_pool = nn.MaxPool2d(kernel_size=self.window_size, stride=1)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the DBC Lacunarity model.

        Args:
            image: Input image tensor [B, C, H, W]

        Returns:
            Lacunarity features tensor (one value per sliding window).
        """
        # Squash arbitrary-range input through tanh into [0, 255].
        scaled = ((self.normalize(image) + 1) / 2) * 255

        # Per-window max, and min via the negated-max-pool trick.
        box_max = self.max_pool(scaled)
        box_min = -self.max_pool(-scaled)

        # Differential box count per window, then normalized mass.
        nr = torch.ceil(box_max / (self.r + self.eps)) - torch.ceil(box_min / (self.r + self.eps)) - 1
        Mr = torch.sum(nr)
        Q_mr = nr / (self.window_size - self.r + 1)
        # NOTE(review): Mr is a global sum over the whole batch/tensor, not a
        # per-window quantity — presumably intended; verify against the
        # reference DBC lacunarity formulation.
        return (Mr ** 2) * Q_mr / (Mr * Q_mr + self.eps) ** 2

    def compute_lacunarity(self, image: torch.Tensor) -> torch.Tensor:
        """
        Compute lacunarity for a single image without tracking gradients.

        Args:
            image: Input image tensor [1, 1, H, W]

        Returns:
            Lacunarity tensor
        """
        with torch.no_grad():
            return self.forward(image)

    def get_model_info(self) -> dict:
        """
        Get model information.

        Returns:
            Dictionary containing model parameters
        """
        return {
            'model_name': self.model_name,
            'window_size': self.window_size,
            'eps': self.eps,
            'r': self.r,
            'num_output_channels': self.num_output_channels
        }
sorghum_pipeline/output/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Output management modules for the Sorghum Pipeline.
3
+
4
+ This package contains output functionality including:
5
+ - Result saving
6
+ - Visualization generation
7
+ - Report creation
8
+ - Data export
9
+ """
10
+
11
+ from .manager import OutputManager
12
+
13
+ __all__ = ["OutputManager"]
sorghum_pipeline/output/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (470 Bytes). View file
 
sorghum_pipeline/output/__pycache__/manager.cpython-312.pyc ADDED
Binary file (40.9 kB). View file
 
sorghum_pipeline/output/manager.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Output manager for the Sorghum Pipeline.
3
+
4
+ This module handles saving results, generating visualizations,
5
+ and creating reports.
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import numpy as np
11
+ import cv2
12
+
13
+ # Use a non-GUI backend to avoid segmentation faults in headless runs
14
+ try:
15
+ import matplotlib
16
+ if os.environ.get('MPLBACKEND') is None:
17
+ matplotlib.use('Agg')
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib.cm as cm
20
+ from matplotlib.colors import Normalize
21
+ except Exception:
22
+ # Fallback safe imports (should not happen normally)
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.cm as cm
25
+ from matplotlib.colors import Normalize
26
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
27
+ from pathlib import Path
28
+ from typing import Dict, Any, Optional, List, Tuple
29
+ from concurrent.futures import ThreadPoolExecutor, as_completed
30
+ import pandas as pd
31
+ import logging
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class OutputManager:
37
+ """Manages output generation and saving."""
38
+
39
+ def __init__(self, output_folder: str, settings: Any):
40
+ """
41
+ Initialize output manager.
42
+
43
+ Args:
44
+ output_folder: Base output folder
45
+ settings: Output settings from config
46
+ """
47
+ self.output_folder = Path(output_folder)
48
+ self.settings = settings
49
+ # Fast mode and parallel save controls
50
+ try:
51
+ self.fast_mode: bool = bool(int(os.environ.get('FAST_OUTPUT', '0'))) or bool(getattr(settings, 'fast_mode', False))
52
+ except Exception:
53
+ self.fast_mode = False
54
+ try:
55
+ self.max_workers: int = int(os.environ.get('FAST_SAVE_WORKERS', '4'))
56
+ except Exception:
57
+ self.max_workers = 4
58
+ try:
59
+ self.png_compression: int = int(os.environ.get('PNG_COMPRESSION', '1')) # 0-9; 1 is fast
60
+ except Exception:
61
+ self.png_compression = 1
62
+
63
+ # Reduce thread usage to lower risk of native library segfaults
64
+ try:
65
+ import os as _os
66
+ _os.environ.setdefault('OMP_NUM_THREADS', '1')
67
+ _os.environ.setdefault('OPENBLAS_NUM_THREADS', '1')
68
+ _os.environ.setdefault('MKL_NUM_THREADS', '1')
69
+ _os.environ.setdefault('NUMEXPR_NUM_THREADS', '1')
70
+ except Exception:
71
+ pass
72
+ try:
73
+ cv2.setNumThreads(1)
74
+ except Exception:
75
+ pass
76
+
77
+ # Create base directories
78
+ self.output_folder.mkdir(parents=True, exist_ok=True)
79
+
80
+ def _imwrite_fast(self, dest: Path, img: np.ndarray) -> None:
81
+ try:
82
+ cv2.imwrite(str(dest), img, [cv2.IMWRITE_PNG_COMPRESSION, int(self.png_compression)])
83
+ except Exception:
84
+ cv2.imwrite(str(dest), img)
85
+
86
+ def create_output_directories(self) -> None:
87
+ """Ensure base output directory exists.
88
+
89
+ Note: Do NOT create subdirectories at the root (e.g., 'analysis').
90
+ Subdirectories are created within each plant's directory only.
91
+ """
92
+ self.output_folder.mkdir(parents=True, exist_ok=True)
93
+
94
+ def save_plant_results(self, plant_key: str, plant_data: Dict[str, Any]) -> None:
95
+ """
96
+ Save all results for a single plant.
97
+
98
+ Args:
99
+ plant_key: Plant identifier (e.g., "2025_02_05_plant1_frame8")
100
+ plant_data: Plant data dictionary
101
+ """
102
+ try:
103
+ # Parse plant key
104
+ parts = plant_key.split('_')
105
+ date_key = "_".join(parts[:3])
106
+ plant_name = parts[3]
107
+ frame_key = parts[4] if len(parts) > 4 else "frame0"
108
+
109
+ # Create plant-specific directory
110
+ plant_dir = self.output_folder / date_key / plant_name
111
+ plant_dir.mkdir(parents=True, exist_ok=True)
112
+
113
+ # Save segmentation results
114
+ self._save_segmentation_results(plant_dir, plant_name, plant_data)
115
+
116
+ # Save texture features
117
+ self._save_texture_features(plant_dir, plant_data)
118
+
119
+ # Save vegetation indices
120
+ self._save_vegetation_indices(plant_dir, plant_data)
121
+
122
+ # Save morphology features
123
+ self._save_morphology_features(plant_dir, plant_data)
124
+
125
+ # Save analysis plots
126
+ self._save_analysis_plots(plant_dir, plant_data)
127
+
128
+ # Save metadata
129
+ self._save_metadata(plant_dir, plant_key, plant_data)
130
+
131
+ logger.debug(f"Results saved for {plant_key}")
132
+
133
+ except Exception as e:
134
+ logger.error(f"Failed to save results for {plant_key}: {e}")
135
+
136
    def _save_segmentation_results(self, plant_dir: Path, plant_name: str, plant_data: Dict[str, Any]) -> None:
        """Save segmentation results (original, masks, overlay, maskouts) as PNGs.

        Writes into <plant_dir>/<segmentation_dir>. No-op when image saving
        is disabled in settings. Individual writes may run in a thread pool
        (cv2.imwrite only — no matplotlib in this method).
        """
        if not self.settings.save_images:
            return

        seg_dir = plant_dir / self.settings.segmentation_dir
        seg_dir.mkdir(exist_ok=True)

        try:
            # (destination path, image) pairs collected first, written last.
            tasks: List[Tuple[Path, np.ndarray]] = []
            # Choose which base image to present in original/overlay
            use_feature_image = False
            try:
                # Allow env override, and special-case plants 13-16 per user requirement
                use_feature_image = bool(int(os.environ.get('OUTPUT_USE_FEATURE_IMAGE', '0'))) or plant_name in { 'plant13','plant14','plant15','plant16' }
            except Exception:
                # Malformed env value: keep only the plant-name special case.
                use_feature_image = plant_name in { 'plant13','plant14','plant15','plant16' }
            if use_feature_image:
                # Prefer the feature composite, falling back to the segmentation one.
                base_image = plant_data.get('composite', plant_data.get('segmentation_composite'))
            else:
                base_image = plant_data.get('segmentation_composite', plant_data.get('composite'))
            if base_image is not None:
                tasks.append((seg_dir / 'original.png', base_image))
            if 'mask' in plant_data:
                tasks.append((seg_dir / 'mask.png', plant_data['mask']))
            if 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
                tasks.append((seg_dir / 'mask3.png', plant_data['mask3']))
            # Save the BRIA-generated mask (if present before overrides) as mask2.png
            if 'original_mask' in plant_data and isinstance(plant_data['original_mask'], np.ndarray):
                tasks.append((seg_dir / 'mask2.png', plant_data['original_mask']))
            if base_image is not None and 'mask' in plant_data:
                # NOTE: despite the name, _create_overlay returns the strictly
                # masked image (background zeroed), not a colored blend.
                overlay = self._create_overlay(base_image, plant_data['mask'])
                tasks.append((seg_dir / 'overlay.png', overlay))
            if 'masked_composite' in plant_data:
                tasks.append((seg_dir / 'masked_composite.png', plant_data['masked_composite']))

            # Create white-background maskouts
            try:
                if base_image is not None and 'mask' in plant_data:
                    maskout_external = self._create_maskout_white_background(base_image, plant_data['mask'])
                    tasks.append((seg_dir / 'maskout_external.png', maskout_external))
                # BRIA-only maskout directly on original composite
                if base_image is not None and 'original_mask' in plant_data and isinstance(plant_data['original_mask'], np.ndarray):
                    maskout_bria = self._create_maskout_white_background(base_image, plant_data['original_mask'])
                    tasks.append((seg_dir / 'maskout_bria.png', maskout_bria))
                # mask3 maskout on original composite
                if base_image is not None and 'mask3' in plant_data and isinstance(plant_data['mask3'], np.ndarray):
                    maskout_mask3 = self._create_maskout_white_background(base_image, plant_data['mask3'])
                    tasks.append((seg_dir / 'maskout_mask3.png', maskout_mask3))
            except Exception as _e:
                # Maskouts are best-effort extras; the primary saves proceed.
                logger.debug(f"Failed to create double maskouts: {_e}")

            # Flush all queued writes, in parallel when configured.
            if self.max_workers > 1 and len(tasks) > 1:
                with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
                    futures = [ex.submit(self._imwrite_fast, p, img) for p, img in tasks]
                    for _ in as_completed(futures):
                        pass
            else:
                for p, img in tasks:
                    self._imwrite_fast(p, img)
        except Exception as e:
            logger.error(f"Failed to save segmentation results: {e}")
198
+
199
    def _save_texture_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
        """Save per-band texture feature maps as PNGs plus statistics JSON.

        Layout: <plant_dir>/<texture_dir>/<band>/<feature>.png. In fast mode
        features are saved as plain normalized grayscale; otherwise each map
        is rendered via matplotlib with a colorbar.
        """
        if not self.settings.save_images or 'texture_features' not in plant_data:
            return

        texture_dir = plant_dir / self.settings.texture_dir
        texture_dir.mkdir(exist_ok=True)

        def save_feature_png(feature_name: str, values: Any, dest: Path, cmap_name: str = 'viridis') -> None:
            # Render one feature map to disk; falls back to a plain
            # normalized grayscale write if the pretty path fails.
            try:
                arr = np.asarray(values)
                # Already an RGB image: write it directly (cv2 expects BGR).
                if arr.ndim == 3 and arr.shape[-1] == 3:
                    self._imwrite_fast(dest, cv2.cvtColor(arr.astype(np.uint8), cv2.COLOR_RGB2BGR))
                    return
                if self.fast_mode:
                    # Fast path: simple normalization, no matplotlib
                    normalized = self._normalize_to_uint8(np.nan_to_num(arr.astype(np.float64), nan=0.0))
                    self._imwrite_fast(dest, normalized)
                else:
                    arr = arr.astype(np.float64)
                    # Masked array so NaN/inf pixels render as background.
                    masked = np.ma.masked_invalid(arr)
                    fig, ax = plt.subplots(figsize=(5, 5))
                    ax.set_axis_off()
                    ax.set_facecolor('white')
                    im = ax.imshow(masked, cmap=cmap_name)
                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes("right", size="2%", pad=0.02)
                    cbar = plt.colorbar(im, cax=cax, orientation='vertical')
                    cbar.set_label(feature_name, fontsize=7)
                    cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
                    if hasattr(cbar, 'outline') and cbar.outline is not None:
                        cbar.outline.set_linewidth(0.5)
                    plt.tight_layout()
                    plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
                    plt.close(fig)
            except Exception as e:
                logger.error(f"Failed to save texture feature image for {feature_name}: {e}")
                try:
                    # Fallback: plain grayscale dump of whatever arr holds.
                    normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
                    self._imwrite_fast(dest, normalized)
                except Exception:
                    pass

        try:
            texture_features = plant_data['texture_features']

            for band, band_data in texture_features.items():
                if 'features' not in band_data:
                    continue

                band_dir = texture_dir / band
                band_dir.mkdir(exist_ok=True)

                features = band_data['features']

                # Save individual feature maps (optionally in parallel)
                items: List[Tuple[str, np.ndarray, Path, str]] = []
                for feature_name, feature_map in features.items():
                    if feature_name == 'ehd_features':
                        # EHD is a stack: one PNG per channel.
                        for i in range(feature_map.shape[0]):
                            channel = feature_map[i]
                            if isinstance(channel, np.ndarray) and channel.size > 0:
                                items.append((f'ehd_channel_{i}', channel, band_dir / f'ehd_channel_{i}.png', 'magma'))
                    else:
                        if isinstance(feature_map, np.ndarray) and feature_map.size > 0:
                            # Colormap choice by feature family.
                            cmap_choice = 'gray' if feature_name in ('lbp', 'hog') else 'plasma' if feature_name.startswith('lac') else 'viridis'
                            items.append((feature_name, feature_map, band_dir / f'{feature_name}.png', cmap_choice))

                # NOTE(review): in non-fast mode save_feature_png uses
                # matplotlib, which is not thread-safe; _save_vegetation_indices
                # deliberately saves sequentially for that reason — confirm
                # whether this parallel path should be fast-mode-only.
                if self.max_workers > 1 and len(items) > 1:
                    with ThreadPoolExecutor(max_workers=self.max_workers) as ex:
                        futures = [ex.submit(save_feature_png, n, m, p, c) for (n, m, p, c) in items]
                        for _ in as_completed(futures):
                            pass
                else:
                    for (n, m, p, c) in items:
                        save_feature_png(n, m, p, c)

                # Create feature summary plot
                self._create_texture_summary_plot(band_dir, features, band)

                # Save texture statistics if available
                if 'statistics' in band_data and isinstance(band_data['statistics'], dict):
                    try:
                        with open(band_dir / 'texture_statistics.json', 'w') as f:
                            json.dump(band_data['statistics'], f, indent=2)
                    except Exception as e:
                        logger.error(f"Failed to save texture statistics for {band}: {e}")

        except Exception as e:
            logger.error(f"Failed to save texture features: {e}")
289
+
290
    def _save_vegetation_indices(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
        """Save vegetation index maps as PNGs plus per-index and aggregate stats.

        Layout: <plant_dir>/<vegetation_dir>/<INDEX>.png (+ _stats.json).
        Fast mode writes plain normalized grayscale; otherwise each index is
        rendered with a per-index colormap and value range.
        """
        if not self.settings.save_images or 'vegetation_indices' not in plant_data:
            return

        veg_dir = plant_dir / self.settings.vegetation_dir
        veg_dir.mkdir(exist_ok=True)

        # Colormap and range settings per index
        # (cmap, vmin, vmax); None means "use the data's nanmin/nanmax".
        index_cmap_settings = {
            "NDVI": (cm.RdYlGn, -1, 1),
            "GNDVI": (cm.RdYlGn, -1, 1),
            "NDRE": (cm.RdYlGn, -1, 1),
            "GRNDVI": (cm.RdYlGn, -1, 1),
            "TNDVI": (cm.RdYlGn, -1, 1),
            "MGRVI": (cm.RdYlGn, -1, 1),
            "GRVI": (cm.RdYlGn, -1, 1),
            "NGRDI": (cm.RdYlGn, -1, 1),
            "MSAVI": (cm.YlGn, 0, 1),
            "OSAVI": (cm.YlGn, 0, 1),
            "TSAVI": (cm.YlGn, 0, 1),
            "GSAVI": (cm.YlGn, 0, 1),
            "NDWI": (cm.Blues, -1, 1),
            "DSWI4": (cm.Blues, -1, 1),
            "CIRE": (cm.viridis, 0, 10),
            "LCI": (cm.viridis, 0, 5),
            "CIgreen": (cm.viridis, 0, 5),
            "MCARI": (cm.viridis, 0, 1.5),
            "MCARI1": (cm.viridis, 0, 1.5),
            "MCARI2": (cm.viridis, 0, 1.5),
            "CVI": (cm.plasma, 0, 10),
            "TCARI": (cm.viridis, 0, 1),
            "TCARIOSAVI": (cm.viridis, 0, 1),
            "AVI": (cm.magma, 0, 1),
            "SIPI2": (cm.inferno, 0, 1),
            "ARI": (cm.magma, 0, 1),
            "ARI2": (cm.magma, 0, 1),
            "DVI": (cm.Greens, 0, None),
            "WDVI": (cm.Greens, 0, None),
            "SR": (cm.viridis, 0, 10),
            "MSR": (cm.viridis, 0, 10),
            "PVI": (cm.cividis, None, None),
            "GEMI": (cm.cividis, 0, 1),
            "ExR": (cm.Reds, -1, 1),
            "RI": (cm.Reds, 0, None),
            "RRI1": (cm.Reds, 0, 1)
        }

        def save_index_png(index_name: str, values: Any, dest: Path) -> None:
            # Render one index map to disk; plain-grayscale fallback on error.
            try:
                arr = values
                # Scalar index values cannot be rendered as an image.
                if not isinstance(arr, (list, tuple,)) and isinstance(arr, (float, int)):
                    return
                arr = np.asarray(arr, dtype=np.float64)
                if self.fast_mode:
                    normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
                    self._imwrite_fast(dest, normalized)
                else:
                    # Unknown indices fall back to viridis over the data range.
                    cmap, vmin, vmax = index_cmap_settings.get(index_name, (cm.viridis, np.nanmin(arr), np.nanmax(arr)))
                    if vmin is None:
                        vmin = np.nanmin(arr)
                    if vmax is None:
                        vmax = np.nanmax(arr)
                    # Guard against all-NaN or constant maps.
                    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
                        vmin, vmax = 0.0, 1.0
                    masked = np.ma.masked_invalid(arr)
                    fig, ax = plt.subplots(figsize=(5, 5))
                    ax.set_axis_off()
                    ax.set_facecolor('white')
                    im = ax.imshow(masked, cmap=cmap, vmin=vmin, vmax=vmax)
                    divider = make_axes_locatable(ax)
                    cax = divider.append_axes("right", size="2%", pad=0.02)
                    cbar = plt.colorbar(im, cax=cax, orientation='vertical')
                    cbar.set_label(index_name, fontsize=7)
                    cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
                    if hasattr(cbar, 'outline') and cbar.outline is not None:
                        cbar.outline.set_linewidth(0.5)
                    plt.tight_layout()
                    plt.savefig(dest, dpi=self.settings.plot_dpi, bbox_inches='tight')
                    plt.close(fig)
            except Exception as e:
                logger.error(f"Failed to save vegetation index image for {index_name}: {e}")
                try:
                    # Fallback simple normalization
                    normalized = self._normalize_to_uint8(np.nan_to_num(arr, nan=0.0))
                    self._imwrite_fast(dest, normalized)
                except Exception:
                    pass

        try:
            vegetation_indices = plant_data['vegetation_indices']

            # Collect work first: PNG renders and per-index stats dumps.
            items_png: List[Tuple[str, np.ndarray, Path]] = []
            items_stats: List[Tuple[Path, Dict[str, Any]]] = []
            for index_name, index_data in vegetation_indices.items():
                if isinstance(index_data, dict) and 'values' in index_data:
                    values = index_data['values']
                    if isinstance(values, np.ndarray) and values.size > 0:
                        items_png.append((index_name, values, veg_dir / f'{index_name}.png'))
                    stats = index_data.get('statistics')
                    if isinstance(stats, dict):
                        items_stats.append((veg_dir / f'{index_name}_stats.json', stats))

            # Save sequentially to avoid matplotlib thread-safety issues
            for (name, arr, dest) in items_png:
                save_index_png(name, arr, dest)
            for (path, stats) in items_stats:
                try:
                    with open(path, 'w') as f:
                        json.dump(stats, f, indent=2)
                except Exception as e:
                    logger.error(f"Failed to save stats for {path.name.split('.')[0]}: {e}")

            # Create vegetation index summary (skip in fast mode)
            if not self.fast_mode:
                self._create_vegetation_summary_plot(veg_dir, vegetation_indices)

            # Save aggregated vegetation statistics
            try:
                all_stats = {k: v.get('statistics', {}) for k, v in vegetation_indices.items() if isinstance(v, dict)}
                with open(veg_dir / 'vegetation_statistics.json', 'w') as f:
                    json.dump(all_stats, f, indent=2)
            except Exception as e:
                logger.error(f"Failed to save aggregated vegetation statistics: {e}")

        except Exception as e:
            logger.error(f"Failed to save vegetation indices: {e}")
417
+
418
+ def _save_morphology_features(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
419
+ """Save morphological features."""
420
+ if not self.settings.save_images or 'morphology_features' not in plant_data:
421
+ return
422
+
423
+ morph_dir = plant_dir / self.settings.morphology_dir
424
+ morph_dir.mkdir(exist_ok=True)
425
+
426
+ try:
427
+ morphology_features = plant_data['morphology_features']
428
+
429
+ # Save morphological images
430
+ if 'images' in morphology_features:
431
+ for image_name, image_data in morphology_features['images'].items():
432
+ if isinstance(image_data, np.ndarray) and image_data.size > 0:
433
+ cv2.imwrite(str(morph_dir / f'{image_name}.png'), image_data)
434
+
435
+ # Save morphological data
436
+ if 'traits' in morphology_features:
437
+ traits = morphology_features['traits']
438
+ with open(morph_dir / 'traits.json', 'w') as f:
439
+ json.dump(traits, f, indent=2)
440
+
441
+ except Exception as e:
442
+ logger.error(f"Failed to save morphology features: {e}")
443
+
444
+ def _save_analysis_plots(self, plant_dir: Path, plant_data: Dict[str, Any]) -> None:
445
+ """Save analysis plots."""
446
+ if not self.settings.save_plots or self.fast_mode:
447
+ return
448
+
449
+ analysis_dir = plant_dir / self.settings.analysis_dir
450
+ analysis_dir.mkdir(exist_ok=True)
451
+
452
+ try:
453
+ # Create comprehensive analysis plot
454
+ self._create_comprehensive_analysis_plot(analysis_dir, plant_data)
455
+
456
+ except Exception as e:
457
+ logger.error(f"Failed to save analysis plots: {e}")
458
+
459
+ def _save_metadata(self, plant_dir: Path, plant_key: str, plant_data: Dict[str, Any]) -> None:
460
+ """Save metadata for the plant."""
461
+ if not self.settings.save_metadata:
462
+ return
463
+
464
+ try:
465
+ metadata = {
466
+ 'plant_key': plant_key,
467
+ 'timestamp': pd.Timestamp.now().isoformat(),
468
+ 'image_shape': plant_data.get('composite', np.array([])).shape if 'composite' in plant_data else None,
469
+ 'has_mask': 'mask' in plant_data and plant_data['mask'] is not None,
470
+ 'features_available': {
471
+ 'texture': 'texture_features' in plant_data,
472
+ 'vegetation': 'vegetation_indices' in plant_data,
473
+ 'morphology': 'morphology_features' in plant_data
474
+ }
475
+ }
476
+
477
+ with open(plant_dir / 'metadata.json', 'w') as f:
478
+ json.dump(metadata, f, indent=2)
479
+
480
+ except Exception as e:
481
+ logger.error(f"Failed to save metadata: {e}")
482
+
483
+ def _create_overlay(self, image: np.ndarray, mask: np.ndarray,
484
+ color: Tuple[int, int, int] = (0, 255, 0),
485
+ alpha: float = 0.5) -> np.ndarray:
486
+ """Return a strictly masked image: pixels where mask>0 keep original; others set to 0."""
487
+ if mask is None:
488
+ return image
489
+ # Resize mask to image size if needed
490
+ if mask.shape[:2] != image.shape[:2]:
491
+ try:
492
+ mask = cv2.resize(mask.astype(np.uint8), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
493
+ except Exception:
494
+ pass
495
+ binary = (mask.astype(np.int32) > 0).astype(np.uint8) * 255
496
+ return cv2.bitwise_and(image, image, mask=binary)
497
+
498
+ def _create_maskout_white_background(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
499
+ """Create maskout image with white background."""
500
+ # Create white background
501
+ white_background = np.full_like(image, 255, dtype=np.uint8)
502
+
503
+ # Apply mask to original image (keep only masked regions)
504
+ masked_image = image.copy()
505
+ masked_image[mask == 0] = 0 # Set non-masked regions to black
506
+
507
+ # Combine: white background + masked image
508
+ result = white_background.copy()
509
+ result[mask > 0] = masked_image[mask > 0]
510
+
511
+ return result
512
+
513
+ def _normalize_to_uint8(self, arr: np.ndarray) -> np.ndarray:
514
+ """Normalize array to uint8 range."""
515
+ if arr.size == 0:
516
+ return arr.astype(np.uint8)
517
+
518
+ arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
519
+
520
+ if arr.ptp() > 0:
521
+ normalized = (arr - arr.min()) / (arr.ptp() + 1e-6) * 255
522
+ else:
523
+ normalized = np.zeros_like(arr)
524
+
525
+ return np.clip(normalized, 0, 255).astype(np.uint8)
526
+
527
+ def _create_texture_summary_plot(self, output_dir: Path, features: Dict[str, np.ndarray], band: str) -> None:
528
+ """Create texture feature summary plot."""
529
+ try:
530
+ # Get available features
531
+ available_features = [k for k, v in features.items()
532
+ if isinstance(v, np.ndarray) and v.size > 0 and k != 'ehd_features']
533
+
534
+ if not available_features:
535
+ return
536
+
537
+ # Create subplot
538
+ n_features = len(available_features)
539
+ cols = min(3, n_features)
540
+ rows = (n_features + cols - 1) // cols
541
+
542
+ fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
543
+ if n_features == 1:
544
+ axes = [axes]
545
+ elif rows == 1:
546
+ axes = axes.reshape(1, -1)
547
+
548
+ for i, feature_name in enumerate(available_features):
549
+ row, col = divmod(i, cols)
550
+ ax = axes[row, col] if rows > 1 else axes[col]
551
+
552
+ feature_map = features[feature_name]
553
+ ax.imshow(feature_map, cmap='viridis')
554
+ ax.set_title(f'{band.upper()} - {feature_name.upper()}')
555
+ ax.axis('off')
556
+
557
+ # Hide unused subplots
558
+ for i in range(n_features, rows * cols):
559
+ row, col = divmod(i, cols)
560
+ ax = axes[row, col] if rows > 1 else axes[col]
561
+ ax.axis('off')
562
+
563
+ plt.tight_layout()
564
+ plt.savefig(output_dir / f'{band}_texture_summary.png',
565
+ dpi=self.settings.plot_dpi, bbox_inches='tight')
566
+ plt.close()
567
+
568
+ except Exception as e:
569
+ logger.error(f"Failed to create texture summary plot: {e}")
570
+
571
+ def _create_vegetation_summary_plot(self, output_dir: Path, vegetation_indices: Dict[str, Any]) -> None:
572
+ """Create vegetation index summary plot."""
573
+ try:
574
+ # Get available indices
575
+ available_indices = [k for k, v in vegetation_indices.items()
576
+ if isinstance(v, dict) and 'values' in v and isinstance(v['values'], np.ndarray)]
577
+
578
+ if not available_indices:
579
+ return
580
+
581
+ # Create subplot
582
+ n_indices = len(available_indices)
583
+ cols = min(3, n_indices)
584
+ rows = (n_indices + cols - 1) // cols
585
+
586
+ fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
587
+ if n_indices == 1:
588
+ axes = [axes]
589
+ elif rows == 1:
590
+ axes = axes.reshape(1, -1)
591
+
592
+ for i, index_name in enumerate(available_indices):
593
+ row, col = divmod(i, cols)
594
+ ax = axes[row, col] if rows > 1 else axes[col]
595
+
596
+ values = vegetation_indices[index_name]['values']
597
+ im = ax.imshow(values, cmap='RdYlGn')
598
+ ax.set_title(f'{index_name}')
599
+ ax.axis('off')
600
+ divider = make_axes_locatable(ax)
601
+ cax = divider.append_axes("right", size="2%", pad=0.02)
602
+ cbar = plt.colorbar(im, cax=cax, orientation='vertical')
603
+ cbar.ax.tick_params(labelsize=6, width=0.5, length=2)
604
+ if hasattr(cbar, 'outline') and cbar.outline is not None:
605
+ cbar.outline.set_linewidth(0.5)
606
+
607
+ # Hide unused subplots
608
+ for i in range(n_indices, rows * cols):
609
+ row, col = divmod(i, cols)
610
+ ax = axes[row, col] if rows > 1 else axes[col]
611
+ ax.axis('off')
612
+
613
+ plt.tight_layout()
614
+ plt.savefig(output_dir / 'vegetation_indices_summary.png',
615
+ dpi=self.settings.plot_dpi, bbox_inches='tight')
616
+ plt.close()
617
+
618
+ except Exception as e:
619
+ logger.error(f"Failed to create vegetation summary plot: {e}")
620
+
621
    def _create_comprehensive_analysis_plot(self, output_dir: Path, plant_data: Dict[str, Any]) -> None:
        """Render a fixed 2x3 overview figure for one plant.

        Panels (filled only when the corresponding data is present in
        ``plant_data``): composite image, segmentation mask, masked overlay,
        LBP texture map (color band), NDVI map, and morphology skeleton.
        The figure is saved to ``<output_dir>/comprehensive_analysis.png``
        at a DPI capped at 100.  Errors are logged, not raised.

        Args:
            output_dir: Directory the PNG is written into.
            plant_data: Per-plant result dict; recognized keys are
                'composite', 'mask', 'texture_features', 'vegetation_indices'
                and 'morphology_features'.

        NOTE(review): panels whose data is missing keep default (empty but
        visible) axes; verify whether they should be hidden instead.
        """
        try:
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))

            # Original image (composite is BGR on disk; convert for display).
            if 'composite' in plant_data:
                axes[0, 0].imshow(cv2.cvtColor(plant_data['composite'], cv2.COLOR_BGR2RGB))
                axes[0, 0].set_title('Original Composite')
                axes[0, 0].axis('off')

            # Mask
            if 'mask' in plant_data:
                axes[0, 1].imshow(plant_data['mask'], cmap='gray')
                axes[0, 1].set_title('Segmentation Mask')
                axes[0, 1].axis('off')

            # Overlay (strictly masked composite; see _create_overlay)
            if 'composite' in plant_data and 'mask' in plant_data:
                overlay = self._create_overlay(plant_data['composite'], plant_data['mask'])
                axes[0, 2].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
                axes[0, 2].set_title('Overlay')
                axes[0, 2].axis('off')

            # Texture features (if available) — LBP map of the color band only
            if 'texture_features' in plant_data and 'color' in plant_data['texture_features']:
                color_features = plant_data['texture_features']['color'].get('features', {})
                if 'lbp' in color_features:
                    axes[1, 0].imshow(color_features['lbp'], cmap='viridis')
                    axes[1, 0].set_title('LBP Texture')
                    axes[1, 0].axis('off')

            # Vegetation indices (if available) — NDVI only
            if 'vegetation_indices' in plant_data:
                veg_indices = plant_data['vegetation_indices']
                if 'NDVI' in veg_indices and 'values' in veg_indices['NDVI']:
                    axes[1, 1].imshow(veg_indices['NDVI']['values'], cmap='RdYlGn')
                    axes[1, 1].set_title('NDVI')
                    axes[1, 1].axis('off')

            # Morphology (if available) — skeleton image only
            if 'morphology_features' in plant_data and 'images' in plant_data['morphology_features']:
                morph_images = plant_data['morphology_features']['images']
                if 'skeleton' in morph_images:
                    axes[1, 2].imshow(morph_images['skeleton'], cmap='gray')
                    axes[1, 2].set_title('Skeleton')
                    axes[1, 2].axis('off')

            plt.tight_layout()
            # DPI is capped at 100 here to keep the overview PNG small even
            # when settings.plot_dpi is higher.
            plt.savefig(output_dir / 'comprehensive_analysis.png',
                        dpi=min(getattr(self.settings, 'plot_dpi', 100), 100), bbox_inches='tight')
            plt.close()

        except Exception as e:
            logger.error(f"Failed to create comprehensive analysis plot: {e}")
676
+
677
+ def create_pipeline_summary(self, results: Dict[str, Any]) -> None:
678
+ """Create a summary of the entire pipeline run."""
679
+ try:
680
+ summary_file = self.output_folder / 'pipeline_summary.json'
681
+
682
+ with open(summary_file, 'w') as f:
683
+ json.dump(results['summary'], f, indent=2)
684
+
685
+ logger.info(f"Pipeline summary saved to {summary_file}")
686
+
687
+ except Exception as e:
688
+ logger.error(f"Failed to create pipeline summary: {e}")
sorghum_pipeline/pipeline.py ADDED
@@ -0,0 +1,1377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
3
+
4
+ This module orchestrates the entire pipeline from data loading
5
+ to feature extraction and result output.
6
+ """
7
+
8
+ import os
9
+ import subprocess
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional, List, Set
13
+ import numpy as np
14
+ import cv2
15
+ import torch
16
+ from torchvision import transforms
17
+ from transformers import AutoModelForImageSegmentation
18
+ from sklearn.decomposition import PCA
19
+ try:
20
+ from tqdm import tqdm
21
+ except Exception:
22
+ tqdm = None
23
+
24
+ from .config import Config
25
+ from .data import DataLoader, ImagePreprocessor, MaskHandler
26
+ from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
27
+ from .output import OutputManager
28
+ from .segmentation import SegmentationManager
29
+ # Make occlusion handling optional if the module is not present
30
+ try:
31
+ from .segmentation.occlusion_handler import OcclusionHandler # type: ignore
32
+ except Exception:
33
+ OcclusionHandler = None # type: ignore
34
+
35
+
36
+ class SorghumPipeline:
37
+ """
38
+ Main pipeline class for sorghum plant phenotyping.
39
+
40
+ This class orchestrates the entire pipeline from data loading
41
+ to feature extraction and result output.
42
+ """
43
+
44
    def __init__(self, config_path: Optional[str] = None, config: Optional[Config] = None, include_ignored: bool = False, enable_occlusion_handling: bool = False, enable_instance_integration: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
        """
        Initialize the pipeline.

        Args:
            config_path: Path to configuration file
            config: Configuration object (if not using file); takes
                precedence over config_path when both are given
            include_ignored: Whether to include ignored plants
            enable_occlusion_handling: Whether to enable SAM2Long occlusion handling
            enable_instance_integration: Whether to run instance-segmentation
                integration during run()
            strict_loader: Forwarded to DataLoader as its strict_loader flag
            excluded_dates: Date keys the DataLoader should skip (None -> [])

        Raises:
            ValueError: If neither config_path nor config is provided.
        """
        # Setup logging first: _setup_logging also binds the module-level
        # `logger` global that the rest of this class relies on.
        self._setup_logging()

        # Load configuration (explicit Config object wins over a path)
        if config is not None:
            self.config = config
        elif config_path is not None:
            self.config = Config(config_path)
        else:
            raise ValueError("Either config_path or config must be provided")

        # Validate configuration
        self.config.validate()

        # Store settings
        self.enable_occlusion_handling = enable_occlusion_handling
        self.enable_instance_integration = enable_instance_integration
        self.strict_loader = strict_loader
        self.excluded_dates = excluded_dates or []

        # Initialize components
        self._initialize_components(include_ignored)

        logger.info("Sorghum Pipeline initialized successfully")
78
+
79
    def _setup_logging(self):
        """Configure logging and bind the module-level ``logger`` global.

        Sends INFO-level records both to a stream handler and to the file
        'sorghum_pipeline.log' in the current working directory.  Note that
        ``logging.basicConfig`` has no effect if the root logger was already
        configured elsewhere in the process.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler('sorghum_pipeline.log')
            ]
        )
        # NOTE(review): assigning a global from inside a method is unusual —
        # a module-level `logger = logging.getLogger(__name__)` would be the
        # conventional form — but the other methods of this module read the
        # `logger` global bound here, so keep the assignment.
        global logger
        logger = logging.getLogger(__name__)
91
+
92
    def _initialize_components(self, include_ignored: bool = False):
        """Construct all pipeline components from the validated config.

        Creates, in order: the data loader/preprocessor/mask handler, the
        texture/vegetation/morphology feature extractors, the segmentation
        manager, the optional occlusion handler, and the output manager.
        All parameters come from ``self.config``; flags captured in
        ``__init__`` (strict_loader, excluded_dates,
        enable_occlusion_handling) are forwarded where relevant.

        Args:
            include_ignored: Whether the DataLoader should include plants
                otherwise marked as ignored.
        """
        # Data components
        self.data_loader = DataLoader(
            input_folder=self.config.paths.input_folder,
            debug=True,  # always-on verbose loading; TODO confirm intentional
            include_ignored=include_ignored,
            strict_loader=self.strict_loader,
            excluded_dates=self.excluded_dates,
        )
        self.preprocessor = ImagePreprocessor(
            target_size=self.config.processing.target_size
        )
        self.mask_handler = MaskHandler(
            min_area=self.config.processing.min_component_area,
            kernel_size=self.config.processing.morphology_kernel_size
        )

        # Feature extractors
        self.texture_extractor = TextureExtractor(
            lbp_points=self.config.processing.lbp_points,
            lbp_radius=self.config.processing.lbp_radius,
            hog_orientations=self.config.processing.hog_orientations,
            hog_pixels_per_cell=self.config.processing.hog_pixels_per_cell,
            hog_cells_per_block=self.config.processing.hog_cells_per_block,
            lacunarity_window=self.config.processing.lacunarity_window,
            ehd_threshold=self.config.processing.ehd_threshold,
            angle_resolution=self.config.processing.angle_resolution
        )

        self.vegetation_extractor = VegetationIndexExtractor(
            epsilon=self.config.processing.epsilon,
            soil_factor=self.config.processing.soil_factor
        )

        self.morphology_extractor = MorphologyExtractor(
            pixel_to_cm=self.config.processing.pixel_to_cm,
            prune_sizes=self.config.processing.prune_sizes
        )

        # Segmentation
        self.segmentation_manager = SegmentationManager(
            model_name=self.config.model.model_name,
            device=self.config.get_device(),
            threshold=self.config.processing.segmentation_threshold,
            trust_remote_code=self.config.model.trust_remote_code,
            # Empty-string cache_dir is treated the same as unset (None).
            cache_dir=self.config.model.cache_dir if getattr(self.config.model, 'cache_dir', '') else None,
            local_files_only=getattr(self.config.model, 'local_files_only', False),
        )

        # Occlusion handling (optional): best-effort construction — a
        # failure here degrades to running without occlusion handling.
        self.occlusion_handler = None
        if self.enable_occlusion_handling and OcclusionHandler is not None:
            try:
                self.occlusion_handler = OcclusionHandler(
                    device=self.config.get_device(),
                    model="tiny",  # Can be made configurable
                    confidence_threshold=0.5,
                    iou_threshold=0.1
                )
                logger.info("Occlusion handler initialized successfully")
            except Exception as e:
                logger.warning(f"Failed to initialize occlusion handler: {e}")
                logger.warning("Continuing without occlusion handling")
                self.occlusion_handler = None
        elif self.enable_occlusion_handling and OcclusionHandler is None:
            # The optional import at module top failed; warn rather than fail.
            logger.warning("Occlusion handler module not found; continuing without occlusion handling")

        # Output manager
        self.output_manager = OutputManager(
            output_folder=self.config.paths.output_folder,
            settings=self.config.output
        )
165
+
166
    def _free_gpu_memory_before_instance(self) -> None:
        """Attempt to free GPU memory prior to running SAM2Long in a subprocess.

        Best-effort, never raises:
        - Moves the BRIA segmentation model to CPU if present
        - Deletes the model reference to release VRAM (then re-binds the
          attribute to None so later checks still find it)
        - Calls torch.cuda.empty_cache()

        Every step is wrapped in its own try/except so that a failure in one
        does not prevent the others from running.
        """
        try:
            import torch as _torch  # type: ignore
            # Move BRIA model to CPU and drop reference
            try:
                if getattr(self, 'segmentation_manager', None) is not None:
                    mdl = getattr(self.segmentation_manager, 'model', None)
                    if mdl is not None:
                        try:
                            mdl.to('cpu')
                        except Exception:
                            pass
                    try:
                        # Drop the attribute so the CUDA tensors it held can
                        # be garbage-collected.
                        delattr(self.segmentation_manager, 'model')
                    except Exception:
                        pass
                    # Ensure attribute exists but is None for future checks
                    try:
                        self.segmentation_manager.model = None  # type: ignore
                    except Exception:
                        pass
            except Exception:
                pass
            # Free CUDA cache
            try:
                if _torch.cuda.is_available():
                    _torch.cuda.empty_cache()
            except Exception:
                pass
            logger.info("Freed GPU memory before SAM2Long invocation (moved BRIA to CPU and emptied cache)")
        except Exception as e:
            logger.warning(f"Failed to free GPU memory before instance segmentation: {e}")
204
+
205
+ def run(self, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None, filter_frames: Optional[List[str]] = None, run_instance_segmentation: bool = False, features_frame_only: Optional[int] = None, reuse_instance_results: bool = False, instance_mapping_path: Optional[str] = None, force_reprocess: bool = False, respect_instance_frame_rules_for_features: bool = False, substitute_feature_image_from_instance_src: bool = False) -> Dict[str, Any]:
206
+ """
207
+ Run the complete pipeline.
208
+
209
+ Args:
210
+ load_all_frames: Whether to load all frames or selected frames
211
+ segmentation_only: If True, run segmentation only and skip feature extraction
212
+
213
+ Returns:
214
+ Dictionary containing all results
215
+ """
216
+ logger.info("Starting Sorghum Pipeline...")
217
+
218
+ try:
219
+ import time
220
+ total_start = time.perf_counter()
221
+ # Step 1: Load data
222
+ logger.info("Step 1/6: Loading data...")
223
+ # In reuse mode we need all frames to select the mapped frame per plant
224
+ if reuse_instance_results:
225
+ plants = self.data_loader.load_all_frames()
226
+ else:
227
+ # If specific frames are requested, we must load all frames to filter correctly
228
+ if load_all_frames or (filter_frames is not None and len(filter_frames) > 0):
229
+ plants = self.data_loader.load_all_frames()
230
+ else:
231
+ plants = self.data_loader.load_selected_frames()
232
+
233
+ # Optional filter by specific plant names (e.g., ["plant1"])
234
+ if filter_plants:
235
+ allowed = set(filter_plants)
236
+ plants = {
237
+ key: pdata for key, pdata in plants.items()
238
+ if len(key.split('_')) > 3 and key.split('_')[3] in allowed
239
+ }
240
+
241
+ # Optional filter by specific frame numbers (e.g., ["9"] or ["frame9"])
242
+ if filter_frames:
243
+ # Normalize to 'frameX' tokens
244
+ wanted = set(
245
+ [f if str(f).startswith('frame') else f"frame{str(f)}" for f in filter_frames]
246
+ )
247
+ plants = {
248
+ key: pdata for key, pdata in plants.items()
249
+ if key.split('_')[-1] in wanted
250
+ }
251
+
252
+ if not plants:
253
+ raise ValueError("No plant data loaded")
254
+
255
+ logger.info(f"Loaded {len(plants)} plants")
256
+
257
+ # If reusing instance results with mapping, restrict to exactly the mapped frame per plant (default frame8)
258
+ if reuse_instance_results:
259
+ try:
260
+ import json as _json
261
+ if instance_mapping_path is None:
262
+ raise ValueError("instance_mapping_path is required in reuse mode")
263
+ _map = _json.load(open(instance_mapping_path, 'r'))
264
+ # Normalize mapping plant keys and compute target frame (default 8)
265
+ target_frame_by_plant = {}
266
+ for pk, pv in _map.items():
267
+ k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
268
+ try:
269
+ target_frame_by_plant[k_norm] = int(pv.get('frame', 8))
270
+ except Exception:
271
+ target_frame_by_plant[k_norm] = 8
272
+ before = len(plants)
273
+ plants = {
274
+ key: pdata for key, pdata in plants.items()
275
+ if (len(key.split('_')) > 3 and key.split('_')[3] in target_frame_by_plant
276
+ and key.split('_')[-1] == f"frame{target_frame_by_plant[key.split('_')[3]]}")
277
+ }
278
+ logger.info(f"Restricted loaded data by mapping frames: {before} -> {len(plants)} items")
279
+ except Exception as e:
280
+ logger.warning(f"Failed to restrict loaded data by mapping frames: {e}")
281
+
282
+ # Skip plants that already have saved results (unless force_reprocess)
283
+ if not force_reprocess:
284
+ try:
285
+ before = len(plants)
286
+ filtered = {}
287
+ for key, pdata in plants.items():
288
+ parts = key.split('_')
289
+ if len(parts) < 5:
290
+ filtered[key] = pdata
291
+ continue
292
+ date_key = "_".join(parts[:3])
293
+ plant_name = parts[3]
294
+ plant_dir = Path(self.config.paths.output_folder) / date_key / plant_name
295
+ meta_ok = (plant_dir / 'metadata.json').exists()
296
+ seg_mask_ok = (plant_dir / self.config.output.segmentation_dir / 'mask.png').exists()
297
+ if meta_ok or seg_mask_ok:
298
+ continue
299
+ filtered[key] = pdata
300
+ plants = filtered
301
+ logger.info(f"Skip-existing filter: {before} -> {len(plants)} items to process")
302
+ except Exception as e:
303
+ logger.warning(f"Skip-existing filter failed: {e}")
304
+
305
+ # Pre-segmentation borrowing: use plant12 images for plant13 from the start
306
+ try:
307
+ rewired = 0
308
+ borrow_map: Dict[str, str] = {
309
+ 'plant13': 'plant12',
310
+ 'plant14': 'plant13',
311
+ 'plant15': 'plant14',
312
+ 'plant16': 'plant15',
313
+ }
314
+ for _k in list(plants.keys()):
315
+ _parts = _k.split('_')
316
+ # Expect keys like YYYY_MM_DD_plantX_frameY
317
+ if len(_parts) < 5:
318
+ continue
319
+ _date_key = "_".join(_parts[:3])
320
+ _plant_name = _parts[3]
321
+ _frame_token = _parts[4]
322
+ # Do NOT borrow on 2025_05_08
323
+ if _date_key == '2025_05_08':
324
+ continue
325
+ if _plant_name not in borrow_map:
326
+ continue
327
+ _src_plant = borrow_map[_plant_name]
328
+ _src_key = f"{_date_key}_{_src_plant}_{_frame_token}"
329
+ _src = plants.get(_src_key)
330
+ if not _src:
331
+ # Fallback: load raw image for source plant directly from disk
332
+ try:
333
+ from PIL import Image as _Image
334
+ _date_folder = _date_key.replace('_', '-')
335
+ _frame_num = int(_frame_token.replace('frame', ''))
336
+ _date_dir = Path(self.config.paths.input_folder)
337
+ # If input folder is a parent of dates, append date folder
338
+ if _date_dir.name != _date_folder:
339
+ _date_dir = _date_dir / _date_folder
340
+ _frame_path = _date_dir / _src_plant / f"{_src_plant}_frame{_frame_num}.tif"
341
+ if _frame_path.exists():
342
+ _img = _Image.open(str(_frame_path))
343
+ _src = {"raw_image": (_img, _frame_path.name), "plant_name": _plant_name, "file_path": str(_frame_path)}
344
+ else:
345
+ _src = None
346
+ except Exception:
347
+ _src = None
348
+ if not _src:
349
+ continue
350
+ _tgt = plants[_k]
351
+ # Preserve original raw image once
352
+ if 'raw_image' in _tgt and 'raw_image_original' not in _tgt:
353
+ _tgt['raw_image_original'] = _tgt['raw_image']
354
+ if 'raw_image' in _src:
355
+ _tgt['raw_image'] = _src['raw_image']
356
+ _tgt['borrowed_from'] = _src_plant
357
+ rewired += 1
358
+ if rewired > 0:
359
+ logger.info(f"Pre-seg borrowing applied: rewired {rewired} frames for plants 13/14/15/16")
360
+ except Exception as e:
361
+ logger.warning(f"Pre-seg borrowing failed: {e}")
362
+
363
+ # Step 2: Create composites
364
+ logger.info("Step 2/6: Creating composites...")
365
+ step_start = time.perf_counter()
366
+ plants = self.preprocessor.create_composites(plants)
367
+ logger.info(f"Composites done in {(time.perf_counter()-step_start):.2f}s")
368
+
369
+ # Step 3: Segment plants (optionally with bounding boxes)
370
+ logger.info("Step 3/6: Segmenting plants...")
371
+ step_start = time.perf_counter()
372
+ bbox_lookup = None
373
+ try:
374
+ bbox_dir = getattr(self.config.paths, 'boundingbox_dir', None)
375
+ # Default to project BoundingBox dir if unset or falsy
376
+ if not bbox_dir:
377
+ try:
378
+ self.config.paths.boundingbox_dir = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/BoundingBox"
379
+ bbox_dir = self.config.paths.boundingbox_dir
380
+ except Exception:
381
+ bbox_dir = None
382
+ if bbox_dir:
383
+ bbox_lookup = self.data_loader.load_bounding_boxes(bbox_dir)
384
+ logger.info(f"Loaded bounding boxes from {bbox_dir}")
385
+ except Exception as e:
386
+ logger.warning(f"Failed to load bounding boxes: {e}")
387
+ bbox_lookup = None
388
+ plants = self._segment_plants(plants, bbox_lookup)
389
+ logger.info(f"Segmentation done in {(time.perf_counter()-step_start):.2f}s")
390
+
391
+ # Step 3.5: Handle occlusion if enabled
392
+ if self.enable_occlusion_handling and self.occlusion_handler is not None:
393
+ logger.info("Step 3.5/6: Handling occlusion with SAM2Long...")
394
+ step_start = time.perf_counter()
395
+ plants = self._handle_occlusion(plants)
396
+ logger.info(f"Occlusion handling done in {(time.perf_counter()-step_start):.2f}s")
397
+
398
+ # Optional: Export RMBG maskouts with white background and run instance segmentation
399
+ if (run_instance_segmentation or self.enable_instance_integration) and not reuse_instance_results:
400
+ if not load_all_frames:
401
+ logger.warning("Instance segmentation expects all 13 frames; consider running with load_all_frames=True.")
402
+ logger.info("Step 3.6: Exporting white-background RMBG images for instance segmentation...")
403
+ # Derive date-specific export/result directories when a single date is present
404
+ date_keys = set()
405
+ try:
406
+ for _k in plants.keys():
407
+ _p = _k.split('_')
408
+ if len(_p) >= 3:
409
+ date_keys.add("_".join(_p[:3]))
410
+ except Exception:
411
+ pass
412
+ if len(date_keys) == 1:
413
+ date_key = next(iter(date_keys))
414
+ base_dir = Path(self.config.paths.output_folder) / date_key
415
+ export_dir = base_dir / "instance_input_maskouts"
416
+ instance_results_dir = base_dir / "instance_results"
417
+ else:
418
+ export_dir = Path(self.config.paths.output_folder) / "instance_input_maskouts"
419
+ instance_results_dir = Path(self.config.paths.output_folder) / "instance_results"
420
+ export_dir.mkdir(parents=True, exist_ok=True)
421
+ instance_results_dir.mkdir(parents=True, exist_ok=True)
422
+ self._export_white_background_maskouts(plants, export_dir)
423
+
424
+ logger.info("Invoking final SAM2Long instance segmentation on exported images...")
425
+ # Free GPU memory before launching SAM2Long to avoid CUDA OOM
426
+ self._free_gpu_memory_before_instance()
427
+ env = os.environ.copy()
428
+ env["SAM2LONG_IMAGES_DIR"] = str(export_dir)
429
+ env["SAM2LONG_RESULTS_DIR"] = str(instance_results_dir)
430
+ # Ensure instance outputs include all frames for all dates
431
+ try:
432
+ env.pop("INSTANCE_OUTPUT_FRAMES", None)
433
+ except Exception:
434
+ pass
435
+ script_path = "/home/grads/f/fahimehorvatinia/Documents/my_full_project/Experiments3_code/sam2long_instance_integration.py"
436
+ try:
437
+ subprocess.run(["python", script_path], check=True, env=env)
438
+ except subprocess.CalledProcessError as e:
439
+ logger.error(f"Instance segmentation failed: {e}")
440
+ else:
441
+ # Integrate instance masks (track_0 as target) into pdata before feature extraction
442
+ try:
443
+ self._apply_instance_masks(plants, instance_results_dir)
444
+ logger.info("Applied instance segmentation masks to pipeline data")
445
+ except Exception as e:
446
+ logger.warning(f"Failed to apply instance masks: {e}")
447
+ elif reuse_instance_results:
448
+ # Reuse existing instance masks from mapping file
449
+ if instance_mapping_path is None:
450
+ raise ValueError("reuse_instance_results=True requires instance_mapping_path to be provided")
451
+ try:
452
+ self._apply_instance_masks_from_mapping(plants, Path(instance_mapping_path))
453
+ logger.info("Applied instance masks from mapping file")
454
+ except Exception as e:
455
+ logger.error(f"Failed to apply instance masks from mapping: {e}")
456
+
457
+ if not segmentation_only:
458
+ # If reusing instance results with a mapping, restrict features to mapped frames per plant
459
+ if reuse_instance_results and instance_mapping_path is not None:
460
+ try:
461
+ import json as _json
462
+ _map = _json.load(open(instance_mapping_path, 'r'))
463
+ # Normalize map
464
+ _norm = {}
465
+ for pk, pv in _map.items():
466
+ k_norm = pk if str(pk).startswith('plant') else f"plant{int(pk)}" if str(pk).isdigit() else str(pk)
467
+ _norm[k_norm] = int(pv.get('frame', 8))
468
+ before = len(plants)
469
+ plants = {
470
+ k: v for k, v in plants.items()
471
+ if len(k.split('_')) > 3 and k.split('_')[3] in _norm and k.split('_')[-1] == f"frame{_norm[k.split('_')[3]]}"
472
+ }
473
+ logger.info(f"Restricted feature extraction by mapping: {before} -> {len(plants)} items")
474
+ except Exception as e:
475
+ logger.warning(f"Failed to restrict by mapping frames: {e}")
476
+ # Optional: restrict features to per-plant preferred frame using internal frame rules
477
+ if respect_instance_frame_rules_for_features:
478
+ try:
479
+ # Keep this in sync with _apply_instance_masks frame_rules
480
+ frame_rules: Dict[str, int] = {
481
+ "plant33": 2,
482
+ "plant16": 4,
483
+ "plant19": 5,
484
+ "plant26": 8,
485
+ "plant27": 8,
486
+ "plant29": 8,
487
+ "plant35": 7,
488
+ "plant36": 6,
489
+ "plant37": 2,
490
+ "plant45": 5,
491
+ }
492
+ before = len(plants)
493
+ def _keep(k: str) -> bool:
494
+ parts = k.split('_')
495
+ if len(parts) < 2:
496
+ return False
497
+ plant_name = parts[-2]
498
+ frame_token = parts[-1]
499
+ if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
500
+ return False
501
+ desired = frame_rules.get(plant_name, 8)
502
+ return frame_token == f"frame{desired}"
503
+ plants = {k: v for k, v in plants.items() if _keep(k)}
504
+ logger.info(f"Restricted feature extraction by per-plant frame rules: {before} -> {len(plants)} items")
505
+ except Exception as e:
506
+ logger.warning(f"Failed to apply per-plant frame restriction for features: {e}")
507
+
508
+ # Optional: if features_frame_only set, keep only that frame's entries (global single frame)
509
+ if features_frame_only is not None:
510
+ frame_token = f"frame{features_frame_only}"
511
+ plants = {k: v for k, v in plants.items() if k.split('_')[-1] == frame_token}
512
+ logger.info(f"Restricted feature extraction to {len(plants)} items for {frame_token}")
513
+
514
+ # Optional: substitute feature input image from instance src_rules mapping (e.g., plant14 <- plant13)
515
+ if substitute_feature_image_from_instance_src:
516
+ try:
517
+ src_rules: Dict[str, str] = {
518
+ "plant13": "plant12",
519
+ "plant14": "plant13",
520
+ "plant15": "plant14",
521
+ "plant16": "plant15",
522
+ }
523
+ switched = 0
524
+ for key in list(plants.keys()):
525
+ parts = key.split('_')
526
+ if len(parts) < 5:
527
+ continue
528
+ date_key = "_".join(parts[:3])
529
+ plant_name = parts[3]
530
+ frame_token = parts[-1]
531
+ if plant_name not in src_rules:
532
+ continue
533
+ src_plant = src_rules[plant_name]
534
+ src_key = f"{date_key}_{src_plant}_{frame_token}"
535
+ if src_key not in plants:
536
+ continue
537
+ src_pdata = plants[src_key]
538
+ tgt_pdata = plants[key]
539
+ # Preserve the original composite used for segmentation for correct overlays later
540
+ try:
541
+ if 'composite' in tgt_pdata and 'segmentation_composite' not in tgt_pdata:
542
+ tgt_pdata['segmentation_composite'] = tgt_pdata['composite']
543
+ except Exception:
544
+ pass
545
+ # Swap feature inputs: composite and spectral bands
546
+ if 'composite' in src_pdata:
547
+ tgt_pdata['composite'] = src_pdata['composite']
548
+ if 'spectral_stack' in src_pdata:
549
+ tgt_pdata['spectral_stack'] = src_pdata['spectral_stack']
550
+ # Ensure mask aligns with substituted composite; resize if needed
551
+ try:
552
+ import cv2 as _cv2
553
+ import numpy as _np
554
+ comp = tgt_pdata.get('composite')
555
+ msk = tgt_pdata.get('mask')
556
+ if comp is not None and msk is not None:
557
+ ch, cw = comp.shape[:2]
558
+ mh, mw = msk.shape[:2]
559
+ if (mh, mw) != (ch, cw):
560
+ resized = _cv2.resize(msk.astype('uint8'), (cw, ch), interpolation=_cv2.INTER_NEAREST)
561
+ tgt_pdata['mask'] = resized
562
+ if 'soft_mask' in tgt_pdata and isinstance(tgt_pdata['soft_mask'], _np.ndarray):
563
+ tgt_pdata['soft_mask'] = (resized > 0).astype(_np.float32)
564
+ # Precompute masked composite with white background for saving
565
+ white = _np.full_like(comp, 255, dtype=_np.uint8)
566
+ result = white.copy()
567
+ result[tgt_pdata['mask'] > 0] = comp[tgt_pdata['mask'] > 0]
568
+ tgt_pdata['masked_composite'] = result
569
+ except Exception:
570
+ pass
571
+ switched += 1
572
+ if switched > 0:
573
+ logger.info(f"Substituted feature images from src_rules for {switched} items")
574
+ except Exception as e:
575
+ logger.warning(f"Failed feature-image substitution via src_rules: {e}")
576
+ # Step 4: Extract features
577
+ logger.info("Step 4/6: Extracting features...")
578
+ step_start = time.perf_counter()
579
+ # Stream-save mode: save outputs immediately after each plant's features when fast output is enabled
580
+ stream_save = False
581
+ try:
582
+ import os as _os
583
+ stream_save = bool(int(_os.environ.get('STREAM_SAVE', '0'))) or bool(getattr(self.output_manager, 'fast_mode', False))
584
+ except Exception:
585
+ stream_save = False
586
+
587
+ plants = self._extract_features(plants, stream_save=stream_save)
588
+ logger.info(f"Features done in {(time.perf_counter()-step_start):.2f}s")
589
+
590
+ # Step 5: Generate outputs (skip if already stream-saved)
591
+ if not stream_save:
592
+ logger.info("Step 5/6: Generating outputs...")
593
+ step_start = time.perf_counter()
594
+ self._generate_outputs(plants)
595
+ logger.info(f"Outputs done in {(time.perf_counter()-step_start):.2f}s")
596
+
597
+ # Step 6: Create summary
598
+ logger.info("Step 6/6: Creating summary...")
599
+ summary = self._create_summary(plants)
600
+ else:
601
+ logger.info("Segmentation-only mode: skipping texture/vegetation/morphology features and plots")
602
+ # Segmentation-only: generate only segmentation outputs and a minimal summary
603
+ logger.info("Step 4/4: Generating segmentation outputs (segmentation-only mode)...")
604
+ self._generate_outputs(plants)
605
+ summary = {
606
+ "total_plants": len(plants),
607
+ "successful_plants": len(plants),
608
+ "failed_plants": 0,
609
+ "features_extracted": {
610
+ "texture": 0,
611
+ "vegetation": 0,
612
+ "morphology": 0
613
+ }
614
+ }
615
+
616
+ total_time = time.perf_counter() - total_start
617
+ logger.info(f"Pipeline completed successfully in {total_time:.2f}s!")
618
+ return {
619
+ "plants": plants,
620
+ "summary": summary,
621
+ "config": self.config,
622
+ "timing_seconds": total_time
623
+ }
624
+
625
+ except Exception as e:
626
+ logger.error(f"Pipeline failed: {e}")
627
+ raise
628
+
629
def _export_white_background_maskouts(self, plants: Dict[str, Any], out_dir: Path) -> None:
    """Export RMBG composites with white background using the soft/binary masks.

    Filenames follow: plantX_plantX_frameY_maskout.png so the final instance script can detect plants.

    Args:
        plants: Mapping of "YYYY_MM_DD_plantX_frameY" keys to per-plant data dicts;
            entries are expected to carry 'composite' (image) and 'mask' (binary) arrays.
        out_dir: Directory that receives the *_maskout.png files. Existing maskouts
            in it are deleted first.
    """
    # Clear any previous maskouts to avoid processing stale plants.
    # Best-effort: deletion failures are ignored so a stale file never blocks export.
    try:
        if out_dir.exists():
            for p in out_dir.glob("*_maskout.png"):
                try:
                    p.unlink()
                except Exception:
                    pass
    except Exception:
        pass
    count = 0
    # Per-plant rule: use bbox-only (skip SAM2Long) for these plants on all dates except 2025_05_08
    bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
    date_exception = "2025_05_08"
    for key, pdata in plants.items():
        try:
            # key format: "YYYY_MM_DD_plantX_frameY"
            parts = key.split('_')
            if len(parts) < 3:
                continue
            plant_name = parts[-2]
            frame_token = parts[-1]  # e.g., frame8
            if not plant_name.startswith('plant') or not frame_token.startswith('frame'):
                continue
            date_key = "_".join(parts[:3])
            if (plant_name in bbox_only_plants) and (date_key != date_exception):
                # Skip exporting maskouts for bbox-only plants so SAM2Long does not run on them
                continue
            # Extract frame number
            # NOTE(review): frame_num is currently unused below (the filename reuses
            # frame_token); kept because int() also validates the token format.
            frame_num = int(frame_token.replace('frame', ''))
            composite = pdata.get('composite')
            mask = pdata.get('mask')
            if composite is None or mask is None:
                continue
            # Ensure 3-channel BGR
            if len(composite.shape) == 2:
                composite_bgr = cv2.cvtColor(composite, cv2.COLOR_GRAY2BGR)
            else:
                composite_bgr = composite
            out_img = composite_bgr.copy()
            # Set background to white where mask == 0
            out_img[mask == 0] = (255, 255, 255)
            out_path = out_dir / f"{plant_name}_{plant_name}_{frame_token}_maskout.png"
            cv2.imwrite(str(out_path), out_img)
            count += 1
        except Exception as e:
            # Per-item failures are logged and skipped so one bad plant cannot
            # abort the whole export.
            logger.warning(f"Failed to export maskout for {key}: {e}")
    logger.info(f"Exported {count} white-background maskouts to {out_dir}")
682
+
683
def _segment_plants(self, plants: Dict[str, Any],
                    bbox_lookup: Optional[Dict[str, tuple]]) -> Dict[str, Any]:
    """Segment plants using BRIA model.

    If bbox_lookup is provided and contains an entry for the plant (e.g., 'plant1'),
    the image is cropped/masked to the bounding box region before segmentation and the
    predicted mask is mapped back to the full image size. In bbox mode a largest
    connected component post-processing is applied to obtain a clean target mask.

    Args:
        plants: Mapping of "YYYY_MM_DD_plantX_frameY" keys to per-plant dicts;
            each entry must contain a 'composite' image. This dict is mutated
            in place: 'soft_mask' (float32 in [0, 1]) and 'mask' (uint8 0/255)
            are written for every entry.
        bbox_lookup: Optional mapping of plant name ('plantX') to an
            (x1, y1, x2, y2) bounding box.

    Returns:
        The same plants dict, mutated with segmentation results.
    """
    total = len(plants)
    iterator = plants.items()
    if tqdm is not None:
        iterator = tqdm(list(plants.items()), desc="Segmenting", total=total, unit="img", leave=False)
    for idx, (key, pdata) in enumerate(iterator):
        try:
            # Get composite image
            composite = pdata['composite']
            h, w = composite.shape[:2]

            # Determine bbox for this plant if available
            parts = key.split('_')
            plant_name = parts[-2] if len(parts) >= 2 else None
            date_key = "_".join(parts[:3]) if len(parts) >= 3 else None  # e.g., 2025_04_16
            bbox = None
            if bbox_lookup is not None and plant_name is not None:
                # keys in bbox_lookup are typically like 'plant1'
                bbox = bbox_lookup.get(plant_name)
            # For plant33, ignore any bbox and run full-image segmentation on all dates except the exception
            if plant_name == 'plant33' and date_key != '2025_05_08':
                bbox = None

            # Plants that should use the bounding box itself as the mask (skip model)
            bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant39", "plant42", "plant44", "plant46"}
            use_bbox_only = (plant_name in bbox_only_plants)

            # Do not use bounding boxes for date 2025_05_08
            if date_key == '2025_05_08':
                bbox = None

            if bbox is not None:
                # Clamp bbox to image; a degenerate (empty) box falls back to the full frame.
                x1, y1, x2, y2 = bbox
                x1 = max(0, min(w, int(x1)))
                x2 = max(0, min(w, int(x2)))
                y1 = max(0, min(h, int(y1)))
                y2 = max(0, min(h, int(y2)))
                if x2 <= x1 or y2 <= y1:
                    x1, y1, x2, y2 = 0, 0, w, h

                if use_bbox_only:
                    # Use the bbox as the mask directly (255 inside, 0 outside)
                    soft_full = np.zeros((h, w), dtype=np.float32)
                    soft_full[y1:y2, x1:x2] = 1.0
                    bin_full = np.zeros((h, w), dtype=np.uint8)
                    bin_full[y1:y2, x1:x2] = 255
                    pdata['soft_mask'] = soft_full
                    pdata['mask'] = bin_full
                else:
                    # Segment inside the bbox region and map back
                    crop = composite[y1:y2, x1:x2]
                    soft_mask_crop = self.segmentation_manager.segment_image_soft(crop)
                    soft_full = np.zeros((h, w), dtype=np.float32)
                    # Resize back to the exact bbox size in case the model changed resolution.
                    soft_resized = cv2.resize(soft_mask_crop, (x2 - x1, y2 - y1), interpolation=cv2.INTER_LINEAR)
                    soft_full[y1:y2, x1:x2] = soft_resized
                    bin_full = (soft_full > 0.5).astype(np.uint8) * 255
                    # Keep only the largest connected component to drop spurious blobs.
                    try:
                        n_lbl, labels, stats, _ = cv2.connectedComponentsWithStats(bin_full, 8)
                        if n_lbl > 1:
                            # Label 0 is background; offset argmax over the remaining labels.
                            largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
                            bin_full = (labels == largest).astype(np.uint8) * 255
                    except Exception:
                        pass
                    pdata['soft_mask'] = soft_full.astype(np.float32)
                    pdata['mask'] = bin_full.astype(np.uint8)
            else:
                # Full-image segmentation (no bbox)
                soft_mask = self.segmentation_manager.segment_image_soft(composite)
                pdata['soft_mask'] = soft_mask
                pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)

            # Progress log every 25 items and for first/last
            if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
                logger.info(f"Segmented {idx + 1}/{total}: {key}")

        except Exception as e:
            # NOTE(review): if pdata['composite'] itself raised (e.g. KeyError),
            # `composite` is unbound here and this handler would raise NameError —
            # confirm every entry carries a 'composite' before relying on this fallback.
            logger.error(f"Segmentation failed for {key}: {e}")
            pdata['soft_mask'] = np.zeros(composite.shape[:2], dtype=np.float32)
            pdata['mask'] = np.zeros(composite.shape[:2], dtype=np.uint8)

    return plants
773
+
774
def _handle_occlusion(self, plants: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle occlusion problems using SAM2Long.

    This method groups plants by their base plant ID and processes
    each plant's 13-frame sequence to differentiate target plant
    from neighboring plants.

    Args:
        plants: Dictionary of plant data keyed either by the date-prefixed
            form "YYYY_MM_DD_plantX_frameY" (used elsewhere in this module)
            or the legacy "plantX_plantX_frameY" form.

    Returns:
        Updated plant data with occlusion handling results
    """
    if self.occlusion_handler is None:
        logger.warning("Occlusion handler not available, skipping occlusion handling")
        return plants

    # Group plants by base plant ID (e.g., "plant1").
    plant_groups = {}
    for key, pdata in plants.items():
        parts = key.split('_')
        if len(parts) >= 3:
            # Fix: keys produced by this pipeline are date-prefixed
            # ("YYYY_MM_DD_plantX_frameY"), so parts[0] is the year and would
            # collapse every entry into a single group. Use the first token
            # that names a plant; fall back to parts[0] for legacy
            # "plantX_plantX_frameY" keys (where it equals the plant ID anyway).
            plant_id = next((p for p in parts if p.startswith('plant')), parts[0])
            plant_groups.setdefault(plant_id, []).append((key, pdata))

    logger.info(f"Processing {len(plant_groups)} plant groups for occlusion handling")

    # Process each plant group independently; one failed group does not stop the rest.
    for plant_id, plant_frames in plant_groups.items():
        try:
            # Sort frames by frame number (last key token is "frameN")
            plant_frames.sort(key=lambda x: int(x[0].split('_')[-1].replace('frame', '')))

            if len(plant_frames) < 2:
                # SAM2Long needs a sequence; a single frame carries no temporal signal.
                logger.warning(f"Plant {plant_id} has only {len(plant_frames)} frames, skipping")
                continue

            # Extract frames and keys
            frame_keys = [x[0] for x in plant_frames]
            frames = [x[1]['composite'] for x in plant_frames]

            logger.info(f"Processing plant {plant_id} with {len(frames)} frames")

            # Process with SAM2Long
            occlusion_results = self.occlusion_handler.segment_plant_sequence(
                frames=frames,
                target_plant_id=plant_id
            )

            # Update plant data with occlusion results
            target_masks = occlusion_results['target_masks']
            neighbor_masks = occlusion_results['neighbor_masks']

            for i, (key, pdata) in enumerate(plant_frames):
                if i < len(target_masks):
                    # Update mask with target plant only; preserve the previous
                    # mask so downstream code can compare/intersect.
                    pdata['original_mask'] = pdata.get('mask', np.zeros_like(target_masks[i]))
                    pdata['mask'] = target_masks[i]
                    pdata['neighbor_mask'] = neighbor_masks[i]
                    pdata['occlusion_handled'] = True

                    # Update soft mask as well (binary 0/255 -> float [0, 1])
                    pdata['original_soft_mask'] = pdata.get('soft_mask', np.zeros_like(target_masks[i], dtype=np.float32))
                    pdata['soft_mask'] = (target_masks[i] / 255.0).astype(np.float32)

            # Calculate and store occlusion metrics on every frame of the group
            metrics = self.occlusion_handler.get_occlusion_metrics(occlusion_results)
            for key, pdata in plant_frames:
                pdata['occlusion_metrics'] = metrics

            logger.info(f"Plant {plant_id} occlusion handling completed")
            logger.info(f"  - Average occlusion ratio: {metrics['average_occlusion_ratio']:.3f}")
            logger.info(f"  - Frames with occlusion: {metrics['frames_with_occlusion']}")

        except Exception as e:
            logger.error(f"Occlusion handling failed for plant {plant_id}: {e}")
            # Mark as failed but continue
            for key, pdata in plant_frames:
                pdata['occlusion_handled'] = False
                pdata['occlusion_error'] = str(e)

    return plants
860
+
861
def _extract_features(self, plants: Dict[str, Any], stream_save: bool = False) -> Dict[str, Any]:
    """Extract all features from plants.

    If stream_save is True, save outputs for each plant immediately after
    its features are computed to improve throughput and reduce peak memory.

    Args:
        plants: Mapping of plant key to per-plant data dict; mutated in place
            with 'texture_features', 'vegetation_indices' and
            'morphology_features' entries (empty dicts on failure).
        stream_save: When True, each plant's results are written via the
            output manager right after its features are computed.

    Returns:
        The same plants dict, mutated with feature results.
    """
    total = len(plants)
    logger.info(f"Extracting features for {total} plants...")
    iterator = plants.items()
    if tqdm is not None:
        iterator = tqdm(list(plants.items()), desc="Extracting features", total=total, unit="img", leave=False)

    # Prepare output directories once if we're streaming saves
    # (best-effort: directory creation failure falls through to per-plant save errors).
    if stream_save:
        try:
            self.output_manager.create_output_directories()
        except Exception:
            pass

    for idx, (key, pdata) in enumerate(iterator):
        try:
            logger.debug(f"Extracting features for {key}")

            # Extract texture features
            pdata['texture_features'] = self._extract_texture_features(pdata)

            # Extract vegetation indices
            pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)

            # Extract morphological features
            pdata['morphology_features'] = self._extract_morphology_features(pdata)

            # Immediately save outputs for this plant if streaming is enabled;
            # a save failure is logged but does not discard the computed features.
            if stream_save:
                try:
                    self.output_manager.save_plant_results(key, pdata)
                except Exception as _e:
                    logger.error(f"Stream-save failed for {key}: {_e}")

            logger.debug(f"Features extracted for {key}")
            # When tqdm is unavailable, log progress every 25 items and for first/last.
            if tqdm is None and (idx == 0 or (idx + 1) % 25 == 0 or (idx + 1) == total):
                logger.info(f"Extracted features for {idx + 1}/{total}: {key}")

        except Exception as e:
            logger.error(f"Feature extraction failed for {key}: {e}")
            # Add empty features so downstream consumers always find the keys.
            pdata['texture_features'] = {}
            pdata['vegetation_indices'] = {}
            pdata['morphology_features'] = {}

    return plants
912
+
913
def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute per-band texture features for one plant.

    For each band ('color', 'nir', 'red_edge', 'red', 'green', 'pca') a
    grayscale image is prepared, texture maps are extracted, and summary
    statistics are computed over the preferred mask
    (mask3 -> features_mask -> mask). A band that fails contributes empty
    'features'/'statistics' dicts instead of aborting the plant.
    """
    results: Dict[str, Any] = {}

    for band in ('color', 'nir', 'red_edge', 'red', 'green', 'pca'):
        try:
            # Grayscale input for this band, then the full texture map set.
            gray = self._prepare_band_image(pdata, band)
            texture_maps = self.texture_extractor.extract_all_texture_features(gray)

            # Statistics are restricted to the preferred mask chain.
            roi = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
            results[band] = {
                'features': texture_maps,
                'statistics': self.texture_extractor.compute_texture_statistics(texture_maps, roi),
            }
        except Exception as e:
            logger.error(f"Texture extraction failed for band {band}: {e}")
            results[band] = {'features': {}, 'statistics': {}}

    return results
942
+
943
def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute vegetation indices from the spectral stack and preferred mask.

    Returns an empty dict when the spectral stack or mask is missing, or
    when the extractor raises.
    """
    try:
        stack = pdata.get('spectral_stack', {})
        # Mask preference chain: mask3 -> features_mask -> mask
        roi = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))

        if roi is None or not stack:
            return {}

        return self.vegetation_extractor.compute_vegetation_indices(stack, roi)

    except Exception as e:
        logger.error(f"Vegetation index extraction failed: {e}")
        return {}
960
+
961
def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
    """Compute morphological features from the composite image and preferred mask.

    Returns an empty dict when the composite or mask is missing, or when the
    extractor raises.
    """
    try:
        image = pdata.get('composite')
        # Mask preference chain: mask3 -> features_mask -> mask
        roi = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))

        if image is None or roi is None:
            return {}

        return self.morphology_extractor.extract_morphology_features(image, roi)

    except Exception as e:
        logger.error(f"Morphology feature extraction failed: {e}")
        return {}
978
+
979
def _prepare_band_image(self, pdata: Dict[str, Any], band: str) -> np.ndarray:
    """Prepare grayscale image for a specific band.

    Args:
        pdata: Per-plant data dict with 'composite' and optional
            'spectral_stack' / mask entries.
        band: One of 'color', 'pca', or a spectral band name present in
            the spectral stack ('nir', 'red_edge', 'red', 'green').

    Returns:
        A uint8 grayscale image; a zero image (512x512 when no source data
        exists) on missing inputs. Normalization min/max are taken inside
        the mask when one is available.
    """
    if band == 'color':
        composite = pdata['composite']
        # Prefer mask3 → features_mask → mask
        mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))
        if mask is not None:
            masked = self.mask_handler.apply_mask_to_image(composite, mask)
            return cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            return cv2.cvtColor(composite, cv2.COLOR_BGR2GRAY)

    elif band == 'pca':
        # Create PCA from spectral bands
        spectral_stack = pdata.get('spectral_stack', {})
        # Prefer mask3 → features_mask → mask
        mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))

        if not spectral_stack:
            return np.zeros((512, 512), dtype=np.uint8)

        # Stack bands; background pixels become NaN so PCA only sees plant pixels.
        # assumes each spectral band is stored as an (H, W, 1) array — TODO confirm
        bands_data = []
        for b in ['nir', 'red_edge', 'red', 'green']:
            if b in spectral_stack:
                arr = spectral_stack[b].squeeze(-1).astype(float)
                if mask is not None:
                    arr = np.where(mask > 0, arr, np.nan)
                bands_data.append(arr)

        if not bands_data:
            return np.zeros((512, 512), dtype=np.uint8)

        # Create PCA: project the valid (non-NaN) pixels onto the first component.
        full_stack = np.stack(bands_data, axis=-1)
        h, w, c = full_stack.shape
        flat = full_stack.reshape(-1, c)
        valid = ~np.isnan(flat).any(axis=1)

        if valid.sum() == 0:
            return np.zeros((h, w), dtype=np.uint8)

        vec = np.zeros(h * w)
        vec[valid] = PCA(n_components=1, whiten=True).fit_transform(
            flat[valid]
        ).squeeze()

        # Min-max normalize to uint8 using in-mask extrema when a mask exists.
        gray_f = vec.reshape(h, w)
        if mask is not None:
            m, M = gray_f[mask > 0].min(), gray_f[mask > 0].max()
        else:
            m, M = gray_f.min(), gray_f.max()

        # Flat (constant) images map to all-zero to avoid division by zero.
        if M > m:
            gray = ((gray_f - m) / (M - m) * 255).astype(np.uint8)
        else:
            gray = np.zeros_like(gray_f, dtype=np.uint8)

        return gray

    else:
        # Individual spectral band
        spectral_stack = pdata.get('spectral_stack', {})
        # Prefer mask3 → features_mask → mask
        mask = pdata.get('mask3', pdata.get('features_mask', pdata.get('mask')))

        if band not in spectral_stack:
            return np.zeros((512, 512), dtype=np.uint8)

        arr = spectral_stack[band].squeeze(-1).astype(float)
        if mask is not None:
            # Background pixels become NaN and are excluded from min/max below.
            arr = np.where(mask > 0, arr, np.nan)

        if mask is not None:
            m, M = np.nanmin(arr), np.nanmax(arr)
        else:
            m, M = arr.min(), arr.max()

        # Normalize to uint8; NaN background maps to the minimum (0 after scaling).
        if M > m:
            gray = ((np.nan_to_num(arr, nan=m) - m) / (M - m) * 255).astype(np.uint8)
        else:
            gray = np.zeros_like(arr, dtype=np.uint8)

        return gray
1063
+
1064
def _generate_outputs(self, plants: Dict[str, Any]) -> None:
    """Generate all output files and visualizations.

    Creates the output directory tree once, then saves each plant's results
    via the output manager. Per-plant failures are logged and skipped so one
    bad plant never blocks the rest.

    Args:
        plants: Mapping of plant key to per-plant data dict.
    """
    self.output_manager.create_output_directories()

    for key, pdata in plants.items():
        try:
            logger.debug(f"Generating outputs for {key}")
            self.output_manager.save_plant_results(key, pdata)
        except Exception as e:
            logger.error(f"Output generation failed for {key}: {e}")
1074
+
1075
+ def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
1076
+ """Create summary of pipeline results."""
1077
+ summary = {
1078
+ "total_plants": len(plants),
1079
+ "successful_plants": 0,
1080
+ "failed_plants": 0,
1081
+ "features_extracted": {
1082
+ "texture": 0,
1083
+ "vegetation": 0,
1084
+ "morphology": 0
1085
+ }
1086
+ }
1087
+
1088
+ for key, pdata in plants.items():
1089
+ try:
1090
+ # Check if features were extracted
1091
+ if pdata.get('texture_features'):
1092
+ summary["features_extracted"]["texture"] += 1
1093
+ if pdata.get('vegetation_indices'):
1094
+ summary["features_extracted"]["vegetation"] += 1
1095
+ if pdata.get('morphology_features'):
1096
+ summary["features_extracted"]["morphology"] += 1
1097
+
1098
+ summary["successful_plants"] += 1
1099
+
1100
+ except Exception:
1101
+ summary["failed_plants"] += 1
1102
+
1103
+ return summary
1104
+
1105
def _apply_instance_masks(self, plants: Dict[str, Any], instance_results_dir: Path) -> None:
    """Replace segmentation masks with SAM2Long instance masks using track_1.

    Expects files under instance_results_dir/plantX/track_1/frame_YY_mask.png.

    Mutates each eligible entry in `plants` in place: the BRIA mask/soft mask
    are preserved under 'original_mask'/'original_soft_mask', replaced by the
    instance mask, and 'mask3' (instance AND BRIA intersection) is stored for
    feature extraction. For plants with a source-plant override the composite
    and spectral stack are also swapped to the source plant's image.
    """
    # Default and per-plant overrides for source plant, track and preferred frame
    default_track = "track_0"
    # Plants whose mask/image come from a *different* plant's results directory.
    src_rules: Dict[str, str] = {
        "plant13": "plant12",
        "plant14": "plant13",
        "plant15": "plant14",
        "plant16": "plant15",
    }
    track_rules: Dict[str, str] = {
        # explicit track rules
        "plant1": "track_0",
        "plant4": "track_0",
        "plant9": "track_3",
        "plant13": "track_1",
        "plant14": "track_0",
        "plant15": "track_0",
        "plant16": "track_0",
        "plant18": "track_0",
        "plant19": "track_0",
        "plant23": "track_1",
        "plant26": "track_0",
        "plant27": "track_0",
        "plant29": "track_0",
        "plant31": "track_1",
        "plant34": "track_1",
        "plant35": "track_1",
        "plant36": "track_0",
        "plant37": "track_1",
        "plant38": "track_0",
        "plant39": "track_1",
        "plant40": "track_0",
        "plant41": "track_1",
        "plant42": "track_0",
        "plant43": "track_0",
        "plant45": "track_0",
    }
    frame_rules: Dict[str, int] = {
        # preferred frame overrides (1-based)
        "plant13": 8,
        "plant14": 8,
        "plant15": 8,
        "plant33": 2,
        "plant16": 4,
        "plant19": 5,
        "plant26": 8,
        "plant27": 8,
        "plant29": 8,
        "plant35": 7,
        "plant36": 6,
        "plant37": 2,
        "plant45": 5,
    }
    # Per-plant rule: skip applying instance masks (keep bbox/BRIA mask) on all dates except 2025_05_08
    bbox_only_plants: Set[str] = {"plant19", "plant20", "plant27", "plant33", "plant39", "plant42", "plant44", "plant46"}
    date_exception = "2025_05_08"

    for key, pdata in plants.items():
        try:
            # key format: "YYYY_MM_DD_plantX_frameY"
            parts = key.split('_')
            if len(parts) < 3:
                continue
            plant_name = parts[-2]
            frame_token = parts[-1]  # frame8
            if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
                continue
            date_key = "_".join(parts[:3])
            if (plant_name in bbox_only_plants) and (date_key != date_exception):
                # Do not override masks for bbox-only plants
                continue
            frame_num = int(frame_token.replace('frame', ''))
            # Resolve source plant, track and desired frame
            src_plant = src_rules.get(plant_name, plant_name)
            track_name = track_rules.get(plant_name, default_track)
            desired_frame = frame_rules.get(plant_name, frame_num)
            plant_dir = Path(instance_results_dir) / src_plant / track_name
            mask_path = plant_dir / f"frame_{desired_frame:02d}_mask.png"
            if not mask_path.exists():
                # Fallback to current frame if override not found
                fallback = plant_dir / f"frame_{frame_num:02d}_mask.png"
                if fallback.exists():
                    mask_path = fallback
                else:
                    # Last-resort: pick any available frame mask in the track directory
                    try:
                        candidates = sorted(plant_dir.glob("frame_*_mask.png"))
                        if len(candidates) > 0:
                            mask_path = candidates[0]
                        else:
                            continue
                    except Exception:
                        continue
            inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if inst_mask is None:
                continue
            # Ensure binary uint8 0/255
            inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
            # Preserve BRIA results before overwriting with the instance mask.
            pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
            pdata['mask'] = inst_mask_bin
            pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
            pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
            pdata['instance_applied'] = True

            # Build mask3 = external(mask) AND BRIA(original_mask)
            try:
                _m1 = pdata.get('mask')
                _m2 = pdata.get('original_mask')
                if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
                    _m1b = (_m1.astype(np.uint8) > 0)
                    _m2b = (_m2.astype(np.uint8) > 0)
                    mask3 = (_m1b & _m2b).astype(np.uint8) * 255
                    pdata['mask3'] = mask3
                    pdata['features_mask'] = mask3
            except Exception:
                pass

            # After applying instance masks, also overwrite the composite and spectral stack
            # with the source plant's raw image (desired frame preferred) so that
            # feature extraction and saved originals/overlays are consistent with the mask source.
            try:
                if plant_name in src_rules:
                    date_key = "_".join(parts[:3])
                    src_key_desired = f"{date_key}_{src_plant}_frame{desired_frame}"
                    src_key_same = f"{date_key}_{src_plant}_{frame_token}"
                    copy_from = plants.get(src_key_desired) or plants.get(src_key_same)
                    if copy_from is None:
                        # Fallback: load source composite from filesystem if not present in plants dict
                        try:
                            from PIL import Image as _Image
                            _date_folder = date_key.replace('_', '-')
                            _date_dir = Path(self.config.paths.input_folder)
                            if _date_dir.name != _date_folder:
                                _date_dir = _date_dir / _date_folder
                            _frame_path = _date_dir / src_plant / f"{src_plant}_frame{desired_frame}.tif"
                            if not _frame_path.exists():
                                _frame_path = _date_dir / src_plant / f"{src_plant}_frame{frame_num}.tif"
                            if _frame_path.exists():
                                _img = _Image.open(str(_frame_path))
                                # Process to composite using preprocessor
                                comp, spec = self.preprocessor.process_raw_image(_img)
                                copy_from = {"composite": comp, "spectral_stack": spec}
                        except Exception:
                            copy_from = None
                    if copy_from is not None:
                        # Preserve the segmentation-time composite once
                        if 'composite' in pdata and 'segmentation_composite' not in pdata:
                            pdata['segmentation_composite'] = pdata['composite']
                        if 'composite' in copy_from:
                            pdata['composite'] = copy_from['composite']
                        if 'spectral_stack' in copy_from:
                            pdata['spectral_stack'] = copy_from['spectral_stack']
                        # Ensure mask size matches the copied composite
                        ch, cw = pdata['composite'].shape[:2]
                        mh, mw = pdata['mask'].shape[:2]
                        if (mh, mw) != (ch, cw):
                            pdata['mask'] = cv2.resize(pdata['mask'].astype('uint8'), (cw, ch), interpolation=cv2.INTER_NEAREST)
                            pdata['soft_mask'] = (pdata['mask'] > 0).astype(np.float32)
            except Exception:
                pass
        except Exception as e:
            # Per-entry failures are deliberately non-fatal (debug level).
            logger.debug(f"Instance mask apply failed for {key}: {e}")
1270
+
1271
def _apply_instance_masks_from_mapping(self, plants: Dict[str, Any], mapping_file: Path) -> None:
    """Apply instance masks using an explicit mapping file with absolute paths.

    mapping JSON structure:
        {
          "plant1": {"frame": 8, "mask_path": "/abs/path/to/plant1/track_X/frame_08_mask.png"},
          "plant2": {"frame": 8, "mask_path": "/abs/path/.../frame_08_mask.png"},
          ...
        }
    If a plant's mapping specifies a different frame, only entries matching that frame are updated.

    Raises:
        FileNotFoundError: If mapping_file does not exist.
    """
    import json
    if not mapping_file.exists():
        raise FileNotFoundError(f"Mapping file not found: {mapping_file}")
    with open(mapping_file, "r") as f:
        mapping = json.load(f)
    # Normalize mapping plant keys to names like 'plantX'
    # (accepts "plantX", a bare digit string like "3", or any other key verbatim).
    norm_map = {}
    for k, v in mapping.items():
        k_norm = k if str(k).startswith("plant") else f"plant{int(k)}" if str(k).isdigit() else str(k)
        norm_map[k_norm] = v

    for key, pdata in plants.items():
        try:
            # key format: "YYYY_MM_DD_plantX_frameY"
            parts = key.split('_')
            if len(parts) < 3:
                continue
            plant_name = parts[-2]
            frame_token = parts[-1]
            if not (plant_name.startswith('plant') and frame_token.startswith('frame')):
                continue
            frame_num = int(frame_token.replace('frame', ''))
            if plant_name not in norm_map:
                continue
            entry = norm_map[plant_name]
            target_frame = int(entry.get("frame", frame_num))
            if frame_num != target_frame:
                # Only update the designated frame for this plant
                continue
            mask_path_str = entry.get("mask_path")
            if not mask_path_str:
                continue
            mask_path = Path(mask_path_str)
            if not mask_path.exists():
                logger.warning(f"Mask path not found for {plant_name} {frame_token}: {mask_path}")
                continue
            inst_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if inst_mask is None:
                continue
            # Binarize to uint8 0/255 and preserve previous masks before overwriting.
            inst_mask_bin = (inst_mask > 0).astype(np.uint8) * 255
            pdata['original_mask'] = pdata.get('mask', inst_mask_bin.copy())
            pdata['mask'] = inst_mask_bin
            pdata['original_soft_mask'] = pdata.get('soft_mask', (inst_mask_bin / 255.0).astype(np.float32))
            pdata['soft_mask'] = (inst_mask_bin / 255.0).astype(np.float32)
            pdata['instance_applied'] = True

            # Build mask3 = external(mask) AND BRIA(original_mask)
            try:
                _m1 = pdata.get('mask')
                _m2 = pdata.get('original_mask')
                if isinstance(_m1, np.ndarray) and isinstance(_m2, np.ndarray):
                    _m1b = (_m1.astype(np.uint8) > 0)
                    _m2b = (_m2.astype(np.uint8) > 0)
                    mask3 = (_m1b & _m2b).astype(np.uint8) * 255
                    pdata['mask3'] = mask3
                    pdata['features_mask'] = mask3
            except Exception:
                pass
        except Exception as e:
            # Per-entry failures are deliberately non-fatal (debug level).
            logger.debug(f"Instance mapping apply failed for {key}: {e}")
1341
+
1342
+
1343
def run_pipeline(config_path: str, load_all_frames: bool = False, segmentation_only: bool = False, filter_plants: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Convenience function to run the pipeline.

    Args:
        config_path: Path to configuration file
        load_all_frames: Whether to load all frames or selected frames
        segmentation_only: If True, run segmentation only and skip feature extraction
        filter_plants: Optional list of plant names (e.g., ["plant1"]) to restrict
            processing to; None processes all plants

    Returns:
        Pipeline results
    """
    pipeline = SorghumPipeline(config_path)
    return pipeline.run(load_all_frames, segmentation_only, filter_plants)
1357
+
1358
+
1359
if __name__ == "__main__":
    import sys

    # CLI: [config_path] [--all] [--seg-only] [--plant=<name>]
    # Fix: the config path is the first NON-flag argument. Previously
    # `sys.argv[1]` was used unconditionally, so invocations like
    # `python pipeline.py --all` treated "--all" as the config path.
    positional = [a for a in sys.argv[1:] if not a.startswith("--")]
    config_path = positional[0] if positional else "config.yml"
    load_all = "--all" in sys.argv
    seg_only = "--seg-only" in sys.argv
    # Basic arg parse for --plant=<name>
    plant_filter = None
    for arg in sys.argv[1:]:
        if arg.startswith("--plant="):
            plant_filter = [arg.split("=", 1)[1]]

    try:
        results = run_pipeline(config_path, load_all, seg_only, plant_filter)
        print("Pipeline completed successfully!")
        print(f"Processed {results['summary']['total_plants']} plants")
    except Exception as e:
        print(f"Pipeline failed: {e}")
        sys.exit(1)
sorghum_pipeline/segmentation/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Segmentation modules for the Sorghum Pipeline.
3
+
4
+ This package contains segmentation functionality including:
5
+ - BRIA model integration
6
+ - Mask post-processing
7
+ - Segmentation validation
8
+ """
9
+
10
+ from .manager import SegmentationManager
11
+
12
+ __all__ = ["SegmentationManager"]
sorghum_pipeline/segmentation/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (482 Bytes). View file
 
sorghum_pipeline/segmentation/__pycache__/advanced_occlusion_handler.cpython-312.pyc ADDED
Binary file (26.3 kB). View file
 
sorghum_pipeline/segmentation/__pycache__/leaf_occlusion_handler.cpython-312.pyc ADDED
Binary file (27 kB). View file
 
sorghum_pipeline/segmentation/__pycache__/manager.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
sorghum_pipeline/segmentation/__pycache__/occlusion_handler.cpython-312.pyc ADDED
Binary file (20.2 kB). View file
 
sorghum_pipeline/segmentation/manager.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Segmentation manager for the Sorghum Pipeline.
3
+
4
+ This module handles image segmentation using the BRIA model
5
+ and provides post-processing capabilities.
6
+ """
7
+
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from transformers import AutoModelForImageSegmentation
14
+ from typing import Optional, Tuple
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class SegmentationManager:
    """Manages image segmentation using the BRIA background-removal model."""

    def __init__(self,
                 model_name: str = "briaai/RMBG-2.0",
                 device: str = "auto",
                 threshold: float = 0.5,
                 trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None,
                 local_files_only: bool = False):
        """
        Initialize segmentation manager and load the model immediately.

        Args:
            model_name: Name of the BRIA model on the Hugging Face hub
            device: Device to run model on ("auto", "cpu", "cuda");
                "auto" picks CUDA when available
            threshold: Probability threshold for binarizing the soft mask
            trust_remote_code: Whether to trust remote code from the hub
            cache_dir: Hugging Face cache directory for model weights
            local_files_only: If True, only load from local cache

        Raises:
            Exception: re-raised from _load_model if the model cannot be
                downloaded or initialized.
        """
        self.model_name = model_name
        self.threshold = threshold
        self.trust_remote_code = trust_remote_code
        self.cache_dir = cache_dir
        self.local_files_only = local_files_only

        # Resolve "auto" to a concrete device at construction time.
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Populated by _load_model(); kept as attributes so helpers can
        # check readiness via `self.model is None`.
        self.model = None
        self.transform = None
        self._load_model()

    def _load_model(self):
        """Load the BRIA segmentation model and build its input transform."""
        try:
            logger.info(f"Loading BRIA model: {self.model_name}")

            self.model = AutoModelForImageSegmentation.from_pretrained(
                self.model_name,
                trust_remote_code=self.trust_remote_code,
                cache_dir=self.cache_dir if self.cache_dir else None,
                local_files_only=self.local_files_only,
            ).eval().to(self.device)

            # BRIA expects 1024x1024 RGB inputs normalized with ImageNet
            # statistics.
            self.transform = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            logger.info("BRIA model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load BRIA model: {e}")
            raise

    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image using the BRIA model.

        Args:
            image: Input image (BGR format, as read by OpenCV)

        Returns:
            Binary uint8 mask (0/255) with the same H x W as the input.
            On failure an all-zero mask is returned instead of raising.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        try:
            # Model expects RGB; OpenCV delivers BGR.
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)

            # Resize/normalize and add a batch dimension.
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

            # The model returns a list of side outputs; the last one is the
            # final prediction.
            with torch.no_grad():
                predictions = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()

            # Binarize at the configured probability threshold.
            mask = (predictions > self.threshold).astype(np.uint8) * 255

            # Resize back to the original resolution; nearest-neighbor keeps
            # the mask strictly binary.
            original_size = (image.shape[1], image.shape[0])  # (width, height)
            mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

            return mask_resized

        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
            # Best-effort contract: return an empty mask rather than raising.
            return np.zeros(image.shape[:2], dtype=np.uint8)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image and return a soft mask in [0, 1] resized to original size.
        No thresholding or post-processing is applied.

        Args:
            image: Input image (BGR format)

        Returns:
            Float32 mask in [0, 1] with shape (H, W); all-zero on failure.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        try:
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            original_size = (image.shape[1], image.shape[0])
            # Bilinear interpolation keeps the probabilities smooth; clip to
            # guard against interpolation overshoot.
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.error(f"Soft segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.float32)

    def post_process_mask(self, mask: np.ndarray,
                          min_area: int = 1000,
                          kernel_size: int = 5) -> np.ndarray:
        """
        Post-process segmentation mask: denoise, then drop small blobs.

        Args:
            mask: Input binary mask (0/255)
            min_area: Minimum pixel area for connected components to keep
            kernel_size: Kernel size for the morphological opening

        Returns:
            Post-processed mask; the input mask unchanged on failure.
        """
        try:
            # Morphological opening removes speckle noise.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
            opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            # Keep only connected components at or above min_area.
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
                opened, connectivity=8
            )

            processed_mask = np.zeros_like(opened)
            for label in range(1, num_labels):  # label 0 is the background
                if stats[label, cv2.CC_STAT_AREA] >= min_area:
                    processed_mask[labels == label] = 255

            return processed_mask

        except Exception as e:
            logger.error(f"Mask post-processing failed: {e}")
            return mask

    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
        """
        Keep only the largest connected component.

        Args:
            mask: Input binary mask

        Returns:
            Mask (0/255) containing only the largest component; the input
            mask unchanged when empty or on failure.
        """
        try:
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, 8)

            # Only background present: nothing to select.
            if num_labels <= 1:
                return mask

            # Largest foreground component (stats row 0 is the background).
            areas = stats[1:, cv2.CC_STAT_AREA]
            largest_label = 1 + np.argmax(areas)

            largest_mask = (labels == largest_label).astype(np.uint8) * 255

            return largest_mask

        except Exception as e:
            logger.error(f"Largest component extraction failed: {e}")
            return mask

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate segmentation mask.

        A valid mask is a 2-D numpy array of dtype uint8 or bool with at
        least one foreground (non-zero) pixel.

        Args:
            mask: Mask to validate

        Returns:
            True if valid, False otherwise
        """
        if mask is None:
            return False

        if not isinstance(mask, np.ndarray):
            return False

        if mask.ndim != 2:
            return False

        if mask.dtype not in [np.uint8, np.bool_]:
            return False

        # An all-background mask is treated as invalid.
        if np.sum(mask > 0) == 0:
            logger.warning("Mask has no foreground pixels")
            return False

        return True

    def get_mask_properties(self, mask: np.ndarray) -> dict:
        """
        Get properties of the segmentation mask.

        Args:
            mask: Binary mask (uint8 0/255 or bool)

        Returns:
            Dictionary with keys area, perimeter, bbox_area, aspect_ratio,
            coverage, num_components; empty dict on invalid input or error.
            Perimeter and bounding box are computed from the largest
            external contour.
        """
        if not self.validate_mask(mask):
            return {}

        try:
            # Binarize. Bool masks must be handled explicitly: comparing a
            # bool array against 127 is always False, which previously made
            # every valid bool mask report zero area.
            if mask.dtype == np.bool_:
                binary_mask = mask.astype(np.uint8)
            else:
                binary_mask = (mask > 127).astype(np.uint8)

            area = np.sum(binary_mask)
            perimeter = 0.0

            contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # Use the largest contour; contours[0] is an arbitrary one
                # when several components exist.
                largest = max(contours, key=cv2.contourArea)
                perimeter = cv2.arcLength(largest, True)

                x, y, w, h = cv2.boundingRect(largest)
                bbox_area = w * h
                aspect_ratio = w / h if h > 0 else 0
            else:
                bbox_area = 0
                aspect_ratio = 0

            return {
                "area": int(area),
                "perimeter": float(perimeter),
                "bbox_area": int(bbox_area),
                "aspect_ratio": float(aspect_ratio),
                "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0,
                "num_components": len(contours)
            }

        except Exception as e:
            logger.error(f"Mask property calculation failed: {e}")
            return {}

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create overlay of mask on image.

        Args:
            image: Base image
            mask: Binary mask (pixels equal to 255 are colored)
            color: Overlay color (B, G, R)
            alpha: Overlay transparency in [0, 1]

        Returns:
            Image with mask overlay; the unmodified image on failure.
        """
        try:
            overlay = image.copy()
            overlay[mask == 255] = color
            return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
        except Exception as e:
            logger.error(f"Overlay creation failed: {e}")
            return image