Fahimeh Orvati Nia
update the morphology, remove yolo, and correct the display
60e6efb
raw
history blame
5.68 kB
from pathlib import Path
from typing import Dict, Callable, Optional, Generator, Any
import shutil
from PIL import Image
import glob
import os
from sorghum_pipeline.pipeline import SorghumPipeline
from sorghum_pipeline.config import Config, Paths
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
                          progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> Generator[Dict[str, str], None, None]:
    """
    Run the sorghum pipeline on a single image (no instance segmentation).

    Yields ``dict[label -> image_path]`` after every pipeline stage so a
    gallery can be refreshed progressively.

    Args:
        input_image_path: Path to the input image. The caller is expected to
            have placed it inside ``work_dir`` already (app.py does this,
            per the original note).
        work_dir: Working directory used as the pipeline's input, output and
            bounding-box folder; created if it does not exist.
        save_artifacts: Accepted for API compatibility; not currently used by
            this function.
        progress_callback: Optional ``callback(stage_name, data)`` forwarded to
            ``SorghumPipeline.run_with_progress``.

    Yields:
        One output dictionary per pipeline stage, built by
        ``_collect_outputs`` from the working directory and the stage's
        ``'plants'`` entry.

    Note:
        Sets the process-wide environment variables ``MINIMAL_DEMO`` and
        ``FAST_OUTPUT`` to ``'1'`` before the pipeline is constructed.
    """
    work = Path(work_dir)
    work.mkdir(parents=True, exist_ok=True)
    input_path = Path(input_image_path)

    # Demo flags must be in the environment before the pipeline is built.
    os.environ['MINIMAL_DEMO'] = '1'
    os.environ['FAST_OUTPUT'] = '1'

    # In-memory config: everything (input, output, bounding boxes) lives in work_dir.
    cfg = Config()
    cfg.paths = Paths(
        input_folder=str(work),
        output_folder=str(work),
        boundingbox_dir=str(work),
    )
    pipeline = SorghumPipeline(config=cfg)

    # Re-scan the work directory after each stage so newly written images show up.
    # (No post-loop reference to the last stage: that raised UnboundLocalError
    # when the pipeline produced no stages.)
    for stage_result in pipeline.run_with_progress(single_image_path=str(input_path),
                                                   progress_callback=progress_callback):
        yield _collect_outputs(work, stage_result.get('plants', {}))
def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
"""Collect all available outputs from work directory and plants data."""
outputs: Dict[str, str] = {}
try:
# Log immediate output directory contents for debugging
for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
p = work / sub
if p.exists():
files = sorted([str(x.name) for x in p.iterdir() if x.is_file()])
print(f"Artifacts in {sub}: {files}")
except Exception:
pass
# Collect desired vegetation indices (replace ARI with SAVI)
wanted = [
work / 'Vegetation_indices_images/ndvi.png',
work / 'Vegetation_indices_images/gndvi.png',
work / 'Vegetation_indices_images/savi.png',
]
labels = [
'NDVI', 'GNDVI', 'SAVI',
]
for label, path in zip(labels, wanted):
if path.exists():
outputs[label] = str(path)
# Also include overlay and mask if present
overlay_path = work / 'results/overlay.png'
mask_path = work / 'results/mask.png'
composite_path = work / 'results/composite.png'
input_img_path = work / 'results/input_image.png'
if overlay_path.exists():
outputs['Overlay'] = str(overlay_path)
if mask_path.exists():
outputs['Mask'] = str(mask_path)
if composite_path.exists():
outputs['Composite'] = str(composite_path)
if input_img_path.exists():
outputs['InputImage'] = str(input_img_path)
# Extract simple stats for display if present in pipeline results
try:
if plants:
_, pdata = next(iter(plants.items()))
veg = pdata.get('vegetation_indices', {})
stats_lines = []
for name in ['NDVI', 'GNDVI', 'SAVI']:
entry = veg.get(name, {})
st = entry.get('statistics', {}) if isinstance(entry, dict) else {}
if st:
stats_lines.append(f"{name}: mean={st.get('mean', 0):.3f}, std={st.get('std', 0):.3f}")
# Morphology stats (height for multiple plants)
morph = pdata.get('morphology_features', {}) if isinstance(pdata, dict) else {}
traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
# Check if we have multiple plants
plant_heights = traits.get('plant_heights', {})
num_plants = traits.get('num_plants', 0)
if isinstance(plant_heights, dict) and len(plant_heights) > 1:
# Multiple plants detected
stats_lines.append(f"Number of plants: {num_plants}")
# Sort by plant index for consistent display
sorted_plants = sorted(plant_heights.items(), key=lambda x: int(x[0].split('_')[1]))
for plant_name, height_cm in sorted_plants:
plant_num = plant_name.split('_')[1]
stats_lines.append(f" Plant {plant_num}: {height_cm:.2f} cm")
elif isinstance(plant_heights, dict) and len(plant_heights) == 1:
# Single plant
height_cm = list(plant_heights.values())[0]
stats_lines.append(f"Plant height: {height_cm:.2f} cm")
else:
# Fallback to old single height field
height_cm = traits.get('plant_height_cm')
if isinstance(height_cm, (int, float)) and height_cm > 0:
stats_lines.append(f"Plant height: {height_cm:.2f} cm")
if stats_lines:
outputs['StatsText'] = "\n".join(stats_lines)
except Exception:
pass
return outputs