# Last commit: dd1d7f5 "make pipeline minimal" (Fahimeh Orvati Nia)
"""
Main pipeline class for the Sorghum Plant Phenotyping Pipeline.
Minimal single-image version for Hugging Face demo.
"""
import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional
import numpy as np
import cv2
from sklearn.decomposition import PCA
from .config import Config
from .data import ImagePreprocessor, MaskHandler
from .features import TextureExtractor, VegetationIndexExtractor, MorphologyExtractor
from .output import OutputManager
from .segmentation import SegmentationManager
logger = logging.getLogger(__name__)
class SorghumPipeline:
    """Minimal pipeline for single-image plant phenotyping.

    Wires together preprocessing, segmentation, feature extraction and
    output writing for exactly one input image (see ``run``).
    """

    def __init__(self, config: "Config"):
        """Initialize the minimal pipeline.

        Args:
            config: Pipeline configuration. ``validate()`` is called on it
                before any component is constructed.
        """
        self._setup_logging()
        self.config = config
        self.config.validate()
        self._initialize_components()
        logger.info("Sorghum Pipeline initialized")

    def _setup_logging(self):
        """Setup logging configuration (INFO level, stream handler)."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()]
        )

    def _initialize_components(self):
        """Initialize pipeline components.

        Builds the preprocessor, mask handler, the three feature extractors,
        the segmentation manager and the output manager, and stores each one
        on ``self``.
        """
        self.preprocessor = ImagePreprocessor(target_size=None)
        self.mask_handler = MaskHandler(min_area=1000, kernel_size=7)
        self.texture_extractor = TextureExtractor()
        self.vegetation_extractor = VegetationIndexExtractor()
        self.morphology_extractor = MorphologyExtractor()
        self.segmentation_manager = SegmentationManager(
            model_name="briaai/RMBG-2.0",
            device=self.config.get_device(),
            threshold=0.5,
            trust_remote_code=True
        )
        self.output_manager = OutputManager(
            output_folder=self.config.paths.output_folder,
            settings=self.config.output
        )

    def run(self, single_image_path: str) -> Dict[str, Any]:
        """
        Run minimal pipeline on single image.

        Stages, in order: load image, create composite, segment, extract
        features, write outputs, build summary.

        Args:
            single_image_path: Path to input image

        Returns:
            Dictionary with keys ``plants`` (per-plant data), ``summary``,
            ``config`` and ``timing_seconds`` (wall-clock duration).

        Raises:
            Exception: Any error from a stage is logged and re-raised.
        """
        logger.info("Starting minimal single-image pipeline...")
        try:
            import time
            from PIL import Image as _Image

            total_start = time.perf_counter()

            # Load single image.
            # NOTE(review): the PIL image is passed on without being closed
            # here; confirm the preprocessor (or a later stage) releases it.
            _p = Path(single_image_path)
            _img = _Image.open(str(_p))
            plants = {
                "demo_demo_frame1": {
                    "raw_image": (_img, _p.name),
                    "plant_name": "demo",
                    "file_path": str(_p)
                }
            }

            # Create composite
            plants = self.preprocessor.create_composites(plants)
            # Segment
            plants = self._segment_plants(plants)
            # Extract features
            plants = self._extract_features(plants)
            # Generate outputs
            self._generate_outputs(plants)
            # Summary
            summary = self._create_summary(plants)

            total_time = time.perf_counter() - total_start
            logger.info(f"Pipeline completed in {total_time:.2f}s")
            return {
                "plants": plants,
                "summary": summary,
                "config": self.config,
                "timing_seconds": total_time
            }
        except Exception as e:
            logger.error(f"Pipeline failed: {e}")
            raise

    def _segment_plants(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Segment plants using BRIA model (full image).

        For every plant entry, stores the soft mask under ``'soft_mask'`` and
        its 0-255 ``uint8`` scaling under ``'mask'``. On any failure both
        masks are replaced by all-zero arrays so later stages still find the
        keys.

        Args:
            plants: Mapping of plant key to its data dict (mutated in place).

        Returns:
            The same ``plants`` mapping.
        """
        for key, pdata in plants.items():
            # Reset per plant: the fallback below must never see an unbound
            # name or the previous plant's composite.
            composite = None
            try:
                composite = pdata['composite']
                soft_mask = self.segmentation_manager.segment_image_soft(composite)
                pdata['soft_mask'] = soft_mask
                pdata['mask'] = (soft_mask * 255.0).astype(np.uint8)
                logger.info(f"Segmented {key}")
            except Exception as e:
                logger.error(f"Segmentation failed for {key}: {e}")
                # Match the composite's height/width when it is known,
                # otherwise fall back to empty masks.
                shape = tuple(getattr(composite, 'shape', ())[:2])
                if len(shape) != 2:
                    shape = (0, 0)
                pdata['soft_mask'] = np.zeros(shape, dtype=np.float32)
                pdata['mask'] = np.zeros(shape, dtype=np.uint8)
        return plants

    def _extract_features(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Extract features from plants.

        Fills ``'texture_features'``, ``'vegetation_indices'`` and
        ``'morphology_features'`` on each plant dict; all three are set to
        empty dicts if an exception escapes the individual extractors.

        Args:
            plants: Mapping of plant key to its data dict (mutated in place).

        Returns:
            The same ``plants`` mapping.
        """
        for key, pdata in plants.items():
            try:
                pdata['texture_features'] = self._extract_texture_features(pdata)
                pdata['vegetation_indices'] = self._extract_vegetation_indices(pdata)
                pdata['morphology_features'] = self._extract_morphology_features(pdata)
                logger.info(f"Features extracted for {key}")
            except Exception as e:
                logger.error(f"Feature extraction failed for {key}: {e}")
                pdata['texture_features'] = {}
                pdata['vegetation_indices'] = {}
                pdata['morphology_features'] = {}
        return plants

    def _extract_texture_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract texture features from pseudo-color image only.

        Args:
            pdata: Plant data dict; reads ``'composite'`` and optional
                ``'mask'``.

        Returns:
            ``{'color': {'features': ..., 'statistics': ...}}``. On failure
            both inner values are empty dicts, so the returned dict itself
            is never empty.
        """
        features = {}
        try:
            # Only process pseudo-color composite
            composite = pdata['composite']
            mask = pdata.get('mask')
            if mask is not None:
                # Restrict the image to the plant mask before the conversion.
                source = self.mask_handler.apply_mask_to_image(composite, mask)
            else:
                source = composite
            gray_image = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)

            band_features = self.texture_extractor.extract_all_texture_features(gray_image)
            stats = self.texture_extractor.compute_texture_statistics(band_features, mask)
            features['color'] = {
                'features': band_features,
                'statistics': stats
            }
        except Exception as e:
            logger.error(f"Texture extraction failed: {e}")
            features['color'] = {'features': {}, 'statistics': {}}
        return features

    def _extract_vegetation_indices(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract vegetation indices (NDVI, ARI, GNDVI only).

        An index is computed only when the extractor registers bands for it
        and every one of those bands is present in the spectral stack.
        Pixels where the mask is not positive are set to NaN and excluded
        from the statistics.

        Args:
            pdata: Plant data dict; reads ``'spectral_stack'`` and ``'mask'``.

        Returns:
            Mapping of index name to ``{'values': masked array,
            'statistics': {mean, std, min, max, median}}``. Empty when the
            stack or mask is missing, or when an error occurs.
        """
        try:
            spectral_stack = pdata.get('spectral_stack', {})
            mask = pdata.get('mask')
            if not spectral_stack or mask is None:
                return {}

            # Loop-invariant: True where the mask marks a plant pixel.
            binary_mask = (np.asarray(mask).astype(np.int32) > 0)

            out: Dict[str, Any] = {}
            for name in ("NDVI", "ARI", "GNDVI"):
                bands = self.vegetation_extractor.index_bands.get(name, [])
                # Skip indices with no registered bands; otherwise the
                # formula lookup would fail and discard every other index.
                if len(bands) == 0 or not all(b in spectral_stack for b in bands):
                    continue

                arrays = []
                for b in bands:
                    arr = spectral_stack[b]
                    # Drop a trailing singleton axis only when there is one;
                    # squeeze(-1) raises on any other trailing size.
                    if isinstance(arr, np.ndarray) and arr.ndim and arr.shape[-1] == 1:
                        arr = arr.squeeze(-1)
                    arrays.append(np.asarray(arr, dtype=np.float64))

                values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
                masked_values = np.where(binary_mask, values, np.nan)
                valid = masked_values[~np.isnan(masked_values)]
                has_valid = bool(valid.size)
                stats = {
                    'mean': float(np.mean(valid)) if has_valid else 0.0,
                    'std': float(np.std(valid)) if has_valid else 0.0,
                    'min': float(np.min(valid)) if has_valid else 0.0,
                    'max': float(np.max(valid)) if has_valid else 0.0,
                    'median': float(np.median(valid)) if has_valid else 0.0,
                }
                out[name] = {'values': masked_values, 'statistics': stats}
            return out
        except Exception as e:
            logger.error(f"Vegetation index extraction failed: {e}")
            return {}

    def _extract_morphology_features(self, pdata: Dict[str, Any]) -> Dict[str, Any]:
        """Extract morphological features.

        Args:
            pdata: Plant data dict; reads ``'composite'`` and ``'mask'``.

        Returns:
            Whatever the morphology extractor returns, or an empty dict when
            either input is missing or extraction fails.
        """
        try:
            composite = pdata.get('composite')
            mask = pdata.get('mask')
            if composite is None or mask is None:
                return {}
            return self.morphology_extractor.extract_morphology_features(composite, mask)
        except Exception as e:
            logger.error(f"Morphology extraction failed: {e}")
            return {}

    def _generate_outputs(self, plants: Dict[str, Any]) -> None:
        """Generate output files.

        Creates the output directories once, then saves each plant's
        results. A failure for one plant is logged and does not stop the
        others.
        """
        self.output_manager.create_output_directories()
        for key, pdata in plants.items():
            try:
                self.output_manager.save_plant_results(key, pdata)
            except Exception as e:
                logger.error(f"Output generation failed for {key}: {e}")

    def _create_summary(self, plants: Dict[str, Any]) -> Dict[str, Any]:
        """Create summary of results.

        Counts plants whose feature entries are truthy.

        NOTE: ``_extract_texture_features`` returns a non-empty dict even
        when extraction fails, so the texture count (and therefore
        ``successful_plants``) does not reflect texture failures.

        Args:
            plants: Mapping of plant key to its data dict.

        Returns:
            Dict with ``total_plants``, ``successful_plants`` and a
            ``features_extracted`` breakdown per feature family.
        """
        def _count(field: str) -> int:
            return sum(1 for p in plants.values() if p.get(field))

        texture_count = _count('texture_features')
        return {
            "total_plants": len(plants),
            "successful_plants": texture_count,
            "features_extracted": {
                "texture": texture_count,
                "vegetation": _count('vegetation_indices'),
                "morphology": _count('morphology_features')
            }
        }