Fahimeh Orvati Nia
update sorghum for multiple plants
f8ac29e
raw
history blame
6 kB
from pathlib import Path
from typing import Dict, Callable, Optional, Generator, Any
import shutil
from PIL import Image
import glob
import os
from sorghum_pipeline.pipeline import SorghumPipeline
from sorghum_pipeline.config import Config, Paths
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
                          progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
                          single_plant_mode: bool = False) -> Generator[Dict[str, str], None, None]:
    """
    Run the sorghum pipeline on a single image (no instance segmentation).

    After every stage reported by ``SorghumPipeline.run_with_progress`` the
    working directory is re-scanned, and the artifacts found so far are
    yielded so a gallery can be filled in progressively.

    Args:
        input_image_path: Path to the input image. It is passed to the
            pipeline as-is (per the original note, app.py has already placed
            it inside ``work_dir``).
        work_dir: Working directory. It is created if missing and used as the
            pipeline's input, output and bounding-box folder.
        save_artifacts: Accepted for API compatibility; currently unused.
        progress_callback: Optional ``callback(stage_name, data)``, forwarded
            to ``run_with_progress``.
        single_plant_mode: Forwarded both to the ``SorghumPipeline``
            constructor and to ``run_with_progress``.

    Yields:
        A dict mapping a gallery label (e.g. 'NDVI', 'Overlay') to a file
        path. The optional 'StatsText' key holds summary text instead of a
        path.

    Note:
        Sets ``MINIMAL_DEMO=1`` and ``FAST_OUTPUT=1`` in ``os.environ`` for the
        whole process; the previous values are not restored.
    """
    work = Path(work_dir)
    work.mkdir(parents=True, exist_ok=True)
    input_path = Path(input_image_path)

    # Demo flags must be in the environment before the pipeline is constructed.
    os.environ['MINIMAL_DEMO'] = '1'
    os.environ['FAST_OUTPUT'] = '1'

    # In-memory config: input, output and bounding boxes all live in work_dir.
    cfg = Config()
    cfg.paths = Paths(
        input_folder=str(work),
        output_folder=str(work),
        boundingbox_dir=str(work),
    )
    pipeline = SorghumPipeline(config=cfg, single_plant_mode=single_plant_mode)

    stages = pipeline.run_with_progress(
        single_image_path=str(input_path),
        progress_callback=progress_callback,
        single_plant_mode=single_plant_mode,
    )
    # No trailing reference to the last stage result: if the pipeline yields
    # nothing, the generator simply ends instead of raising UnboundLocalError.
    for stage_result in stages:
        yield _collect_outputs(work, stage_result.get('plants', {}))
def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
"""Collect all available outputs from work directory and plants data."""
outputs: Dict[str, str] = {}
try:
# Log immediate output directory contents for debugging
for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
p = work / sub
if p.exists():
files = sorted([str(x.name) for x in p.iterdir() if x.is_file()])
print(f"Artifacts in {sub}: {files}")
except Exception:
pass
# Collect desired vegetation indices (replace ARI with SAVI)
wanted = [
work / 'Vegetation_indices_images/ndvi.png',
work / 'Vegetation_indices_images/gndvi.png',
work / 'Vegetation_indices_images/savi.png',
]
labels = [
'NDVI', 'GNDVI', 'SAVI',
]
for label, path in zip(labels, wanted):
if path.exists():
outputs[label] = str(path)
# Also include overlay and mask if present
overlay_path = work / 'results/overlay.png'
mask_path = work / 'results/mask.png'
composite_path = work / 'results/composite.png'
input_img_path = work / 'results/input_image.png'
if overlay_path.exists():
outputs['Overlay'] = str(overlay_path)
if mask_path.exists():
outputs['Mask'] = str(mask_path)
if composite_path.exists():
outputs['Composite'] = str(composite_path)
if input_img_path.exists():
outputs['InputImage'] = str(input_img_path)
# Extract simple stats for display if present in pipeline results
try:
if plants:
_, pdata = next(iter(plants.items()))
veg = pdata.get('vegetation_indices', {})
stats_lines = []
for name in ['NDVI', 'GNDVI', 'SAVI']:
entry = veg.get(name, {})
st = entry.get('statistics', {}) if isinstance(entry, dict) else {}
if st:
stats_lines.append(f"{name}: mean={st.get('mean', 0):.3f}, std={st.get('std', 0):.3f}")
# Morphology stats (height - always show as single plant)
morph = pdata.get('morphology_features', {}) if isinstance(pdata, dict) else {}
traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
# Get plant height (system now filters to largest plant only)
plant_heights = traits.get('plant_heights', {})
num_plants = traits.get('num_plants', 0)
# Display plant info based on mode
if num_plants > 0 and isinstance(plant_heights, dict):
if num_plants == 1 or len(plant_heights) == 1:
# Single plant display
height_cm = list(plant_heights.values())[0]
stats_lines.append(f"Number of plants: 1")
stats_lines.append(f"Plant height: {height_cm:.2f} cm")
else:
# Multiple plants display
stats_lines.append(f"Number of plants: {num_plants}")
sorted_plants = sorted(plant_heights.items(), key=lambda x: int(x[0].split('_')[1]))
for plant_name, height_cm in sorted_plants:
plant_num = plant_name.split('_')[1]
stats_lines.append(f" Plant {plant_num}: {height_cm:.2f} cm")
else:
# Fallback to old single height field
height_cm = traits.get('plant_height_cm')
if isinstance(height_cm, (int, float)) and height_cm > 0:
stats_lines.append(f"Number of plants: 1")
stats_lines.append(f"Plant height: {height_cm:.2f} cm")
if stats_lines:
outputs['StatsText'] = "\n".join(stats_lines)
except Exception:
pass
return outputs