|
|
""" |
|
|
Main pipeline class for the Sorghum Plant Phenotyping Pipeline. |
|
|
|
|
|
Minimal single-image version for Hugging Face demo. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, Optional |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
from .config import Config |
|
|
from .data import ImagePreprocessor, MaskHandler |
|
|
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor |
|
|
from .output import OutputManager |
|
|
from .segmentation import SegmentationManager |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class SorghumPipeline:
    """Minimal pipeline for single-image plant phenotyping.

    Stages, in order: composite creation (preprocessing), segmentation,
    feature extraction (texture, vegetation indices, morphology), output
    generation and a summary. All per-plant state lives in a plain dict
    (``pdata``) that each stage reads from and writes into.
    """

    def __init__(self, config: "Config"):
        """Initialize the minimal pipeline.

        Args:
            config: Pipeline configuration. ``validate()`` is called on it
                before any component is built.
        """
        self._setup_logging()
        self.config = config
        self.config.validate()
        self._initialize_components()
        logger.info("Sorghum Pipeline initialized")

    def _setup_logging(self):
        """Configure root logging at INFO level with a stream handler.

        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so an application-level logging setup takes precedence.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()],
        )

    def _initialize_components(self):
        """Build the preprocessing, segmentation, feature and output components."""
        self.preprocessor = ImagePreprocessor(target_size=None)
        # NOTE(review): mask post-processing constants are hard-coded here
        # rather than read from the config - confirm that is intended.
        self.mask_handler = MaskHandler(min_area=1000, kernel_size=7)
        self.texture_extractor = TextureExtractor()
        self.vegetation_extractor = VegetationIndexExtractor()
        self.morphology_extractor = MorphologyExtractor()
        self.segmentation_manager = SegmentationManager(
            model_name="briaai/RMBG-2.0",
            device=self.config.get_device(),
            threshold=0.5,
            trust_remote_code=True,
        )
        self.output_manager = OutputManager(
            output_folder=self.config.paths.output_folder,
            settings=self.config.output,
        )

    def run(self, single_image_path: str) -> Dict[str, Any]:
        """
        Run minimal pipeline on single image.

        Args:
            single_image_path: Path to input image

        Returns:
            Dictionary with keys ``plants`` (per-plant data dicts),
            ``summary`` (see ``_create_summary``), ``config`` and
            ``timing_seconds`` (wall-clock duration of the run).

        Raises:
            Any exception raised while opening the image, or by a stage that
            does not handle its own errors. It is logged with its traceback
            and re-raised.
        """
        logger.info("Starting minimal single-image pipeline...")

        try:
            import time
            from PIL import Image as _Image

            total_start = time.perf_counter()

            image_path = Path(single_image_path)
            # NOTE(review): the PIL image is opened lazily and never closed
            # here. It is handed to the preprocessor as-is, which may still
            # need the live file handle (e.g. for multi-frame files). Confirm
            # before wrapping it in a ``with`` block.
            image = _Image.open(str(image_path))
            plants = {
                "demo_demo_frame1": {
                    "raw_image": (image, image_path.name),
                    "plant_name": "demo",
                    "file_path": str(image_path),
                }
            }

            plants = self.preprocessor.create_composites(plants)
            plants = self._segment_plants(plants)
            plants = self._extract_features(plants)
            self._generate_outputs(plants)
            summary = self._create_summary(plants)

            total_time = time.perf_counter() - total_start
            logger.info(f"Pipeline completed in {total_time:.2f}s")

            return {
                "plants": plants,
                "summary": summary,
                "config": self.config,
                "timing_seconds": total_time,
            }

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Pipeline failed: {e}")
            raise

    def _segment_plants(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Segment plants using BRIA model (full image).

        For every plant, stores ``soft_mask`` (the raw output of
        ``segment_image_soft``) and ``mask`` (that output clipped to [0, 1]
        and scaled to uint8 0-255).

        If segmentation fails, both masks become zero arrays matching the
        composite's height and width. If the plant has no composite at all,
        both are set to ``None``.
        """
        for key, pdata in plants.items():
            # Look the composite up before the try block, so the error
            # handler never sees an unbound or stale (previous-plant) value.
            composite = pdata.get('composite')
            if composite is None:
                logger.error(f"Segmentation failed for {key}: no composite image")
                pdata['soft_mask'] = None
                pdata['mask'] = None
                continue

            try:
                soft_mask = self.segmentation_manager.segment_image_soft(composite)
                pdata['soft_mask'] = soft_mask
                # Clip before scaling: out-of-range values would otherwise
                # wrap around (or be undefined) in the uint8 cast.
                pdata['mask'] = (np.clip(soft_mask, 0.0, 1.0) * 255.0).astype(np.uint8)
                logger.info(f"Segmented {key}")
            except Exception as e:
                logger.error(f"Segmentation failed for {key}: {e}")
                height_width = composite.shape[:2]
                pdata['soft_mask'] = np.zeros(height_width, dtype=np.float32)
                pdata['mask'] = np.zeros(height_width, dtype=np.uint8)
        return plants

    def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Extract texture, vegetation-index and morphology features per plant.

        The individual extractors handle their own errors and return empty
        results. This outer guard only catches anything they let through, and
        then resets all three feature entries to empty dicts.
        """
        for key, pdata in plants.items():
            try:
                pdata['texture_features'] = self._extract_texture_features(pdata)
                pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
                pdata['morphology_features'] = self._extract_morphology_features(pdata)
                logger.info(f"Features extracted for {key}")
            except Exception as e:
                logger.error(f"Feature extraction failed for {key}: {e}")
                pdata['texture_features'] = {}
                pdata['vegetation_indices'] = {}
                pdata['morphology_features'] = {}
        return plants

    def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract texture features from the pseudo-color composite only.

        Steps:
            1. Apply the mask to the composite, when a mask is present.
            2. Convert to grayscale, unless the image is already 2-D.
            3. Pass the result to the texture extractor.

        Returns:
            ``{'color': {'features': ..., 'statistics': ...}}``. On failure,
            both inner entries are empty dicts.
        """
        features = {}
        try:
            composite = pdata['composite']
            mask = pdata.get('mask')
            if mask is None:
                image = composite
            else:
                image = self.mask_handler.apply_mask_to_image(composite, mask)

            # Single-channel images are already grayscale, and cvtColor's
            # BGR2GRAY conversion rejects them.
            if np.ndim(image) == 2:
                gray_image = image
            else:
                gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

            band_features = self.texture_extractor.extract_all_texture_features(gray_image)
            stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
            features['color'] = {'features': band_features, 'statistics': stats}
        except Exception as e:
            logger.error(f"Texture extraction failed: {e}")
            features['color'] = {'features': {}, 'statistics': {}}

        return features

    @staticmethod
    def _band_as_2d(band: Any) -> np.ndarray:
        """Return *band* as a float64 array, dropping a trailing singleton axis.

        ``(H, W, 1)`` becomes ``(H, W)``. Arrays that are already 2-D are
        returned unchanged.
        """
        arr = np.asarray(band, dtype=np.float64)
        if arr.ndim > 2 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        return arr

    @staticmethod
    def _summary_stats(valid: np.ndarray) -> Dict[str, float]:
        """Return mean/std/min/max/median of *valid*, or all 0.0 when it is empty."""
        if valid.size == 0:
            return {'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0, 'median': 0.0}
        return {
            'mean': float(np.mean(valid)),
            'std': float(np.std(valid)),
            'min': float(np.min(valid)),
            'max': float(np.max(valid)),
            'median': float(np.median(valid)),
        }

    def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract vegetation indices (NDVI, ARI, GNDVI only).

        Each index is computed from the bands listed in
        ``self.vegetation_extractor.index_bands[name]``, using
        ``self.vegetation_extractor.index_formulas[name]``. The result is then
        restricted to pixels where ``mask > 0``.

        Indices are computed independently. One that is undeclared, has no
        formula, lacks a band, or fails does not affect the others.

        Returns:
            ``{name: {'values': array, 'statistics': {...}}}`` for each index
            that could be computed. ``values`` holds NaN outside the mask and
            wherever the formula result is not finite (e.g. 0/0). The
            statistics cover the remaining pixels and are 0.0 when none
            remain. Returns ``{}`` when there is no spectral stack or no mask.
        """
        try:
            spectral_stack = pdata.get('spectral_stack', {})
            mask = pdata.get('mask')
            if not spectral_stack or mask is None:
                return {}
            binary_mask = np.asarray(mask).astype(np.int32) > 0
            index_bands = self.vegetation_extractor.index_bands
            index_formulas = self.vegetation_extractor.index_formulas
        except Exception as e:
            logger.error(f"Vegetation index extraction failed: {e}")
            return {}

        out: Dict[str, Any] = {}
        for name in ("NDVI", "ARI", "GNDVI"):
            bands = index_bands.get(name) or []
            formula = index_formulas.get(name)
            # Skip indices that have no bands or formula, or that need a band
            # this image does not provide.
            if not bands or formula is None or not all(b in spectral_stack for b in bands):
                continue
            try:
                arrays = [self._band_as_2d(spectral_stack[b]) for b in bands]
                # Formulas may divide by zero on dark pixels. Suppress the
                # warnings here; non-finite results are masked to NaN below.
                with np.errstate(divide='ignore', invalid='ignore'):
                    values = np.asarray(formula(*arrays), dtype=np.float64)
                masked_values = np.where(binary_mask & np.isfinite(values), values, np.nan)
                valid = masked_values[~np.isnan(masked_values)]
                out[name] = {'values': masked_values, 'statistics': self._summary_stats(valid)}
            except Exception as e:
                logger.error(f"Vegetation index extraction failed for {name}: {e}")
        return out

    def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract morphological features from the composite and mask.

        Returns ``{}`` when either input is missing or extraction fails.
        """
        try:
            composite = pdata.get('composite')
            mask = pdata.get('mask')
            if composite is None or mask is None:
                return {}
            return self.morphology_extractor.extract_morphology_features(composite, mask)
        except Exception as e:
            logger.error(f"Morphology extraction failed: {e}")
            return {}

    def _generate_outputs(self, plants: Dict[str, Any]) -> None:
        """Create output directories and save each plant's results.

        A failure while saving one plant is logged and does not stop the
        remaining plants.
        """
        self.output_manager.create_output_directories()
        for key, pdata in plants.items():
            try:
                self.output_manager.save_plant_results(key, pdata)
            except Exception as e:
                logger.error(f"Output generation failed for {key}: {e}")

    @staticmethod
    def _has_texture(pdata: Dict[str, Any]) -> bool:
        """Return True when texture extraction produced at least one band of features.

        The failure path of ``_extract_texture_features`` still stores a
        (truthy) ``{'color': {'features': {}, ...}}``. This check therefore
        looks at the inner ``features`` entry rather than the outer dict.
        """
        texture = pdata.get('texture_features') or {}
        for band in texture.values():
            band_features = band.get('features') if isinstance(band, dict) else None
            if band_features is None:
                continue
            if not isinstance(band_features, dict) or band_features:
                return True
        return False

    def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Create summary of results.

        Counts how many plants produced texture features, vegetation indices
        and morphology features. ``successful_plants`` is the number of plants
        with usable texture features.
        """
        texture_ok = sum(1 for p in plants.values() if self._has_texture(p))
        return {
            "total_plants": len(plants),
            "successful_plants": texture_ok,
            "features_extracted": {
                "texture": texture_ok,
                "vegetation": sum(1 for p in plants.values() if p.get('vegetation_indices')),
                "morphology": sum(1 for p in plants.values() if p.get('morphology_features')),
            },
        }