Fahimeh Orvati Nia
Add sorghum_pipeline code
b4123b8
raw
history blame
18.5 kB
"""
Data loading functionality for the Sorghum Pipeline.
This module handles loading raw images, managing plant data,
and organizing data according to the pipeline requirements.
"""
import os
import glob
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from PIL import Image
import numpy as np
import logging
logger = logging.getLogger(__name__)
class DataLoader:
"""Handles loading and organizing plant image data."""
# Plants to ignore completely (empty by default)
IGNORE_PLANTS = set()
# Plants where you want exactly one frame from their own folder
EXACT_FRAME = {
4: 7, 5: 5, 7: 5, 12: 5, 13: 5, 18: 7, 19: 2, 20: 3,
24: 6, 25: 5, 26: 5, 30: 8, 37: 7
}
# Plants where you want to borrow a frame from a different plant folder
BORROW_FRAME = {
14: (13, 5), 15: (14, 5), 16: (15, 5), 33: (34, 7),
34: (35, 7), 35: (35, 8), 36: (36, 6)
}
# Overrides provided by user: preferred frame per target plant name
FRAME_OVERRIDE_BY_NAME = {
'plant1': 9, 'plant2': 10, 'plant3': 9, 'plant5': 7, 'plant6': 9, 'plant8': 5,
'plant7': 9, 'plant10': 9, 'plant11': 9, 'plant12': 9,
'plant13': 10, 'plant14': 8, 'plant15': 11, 'plant19': 4, 'plant20': 7,
'plant21': 9, 'plant22': 10, 'plant25': 4, 'plant26': 2, 'plant27': 10, 'plant28': 9, 'plant29': 2,
'plant30': 9, 'plant31': 10, 'plant32': 9, 'plant33': 8,
'plant35': 9, 'plant36': 4, 'plant38': 9, 'plant39': 9, 'plant41': 9,
'plant42': 6, 'plant43': 10, 'plant44': 9, 'plant45': 7,
'plant47': 10, 'plant48': 11,
}
# Substitutes provided by user: map target plant name -> source plant name
PLANT_SUBSTITUTES_BY_NAME = {
'plant16': 'plant15', 'plant15': 'plant14', 'plant14': 'plant13',
'plant13': 'plant12', 'plant33': 'plant34', 'plant34': 'plant35',
'plant24': 'plant25', 'plant25': 'plant25', 'plant35': 'plant36',
'plant36': 'plant37', 'plant37': 'plant37', 'plant44': 'plant43',
'plant45': 'plant44',
}
def __init__(self, input_folder: str, debug: bool = False, include_ignored: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
"""
Initialize the data loader.
Args:
input_folder: Path to the input dataset folder
debug: Enable debug logging
"""
self.input_folder = Path(input_folder)
self.debug = debug
self.include_ignored = include_ignored
self.strict_loader = strict_loader
if not self.input_folder.exists():
raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
# Normalize excluded dates as a set of folder names (with dashes)
self.excluded_dates = set(excluded_dates or [])
def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
"""
Load selected frames according to predefined rules.
If strict_loader is True, load only frame numbers from the plant's own folder (no borrowing/special picks).
Returns:
Dictionary with plant data organized by key format: "YYYY_MM_DD_plantX_frameY"
"""
logger.info("Loading selected frames from dataset...")
plants = {}
# Detect if input folder is a direct date folder (contains plant folders)
first_items = list(self.input_folder.iterdir())
has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)
def choose_frame_and_source(pid: int) -> Tuple[int, str]:
if self.strict_loader:
# In strict mode, honor explicit frame overrides AND substitution of source plant
plant_name_local = f"plant{pid}"
frame_num = self.FRAME_OVERRIDE_BY_NAME.get(
plant_name_local,
self.EXACT_FRAME.get(pid, 8)
)
source_plant = self.PLANT_SUBSTITUTES_BY_NAME.get(plant_name_local, plant_name_local)
return frame_num, source_plant
# Original behavior
frame_num = self._get_frame_number(pid)
source_plant = self._get_source_plant(pid)
return frame_num, source_plant
if has_plant_folders:
# Direct date folder structure
date_name = self.input_folder.name
date_path = self.input_folder
for plant_name in sorted(os.listdir(date_path)):
plant_path = date_path / plant_name
if not plant_path.is_dir():
continue
try:
plant_id = int(plant_name.replace("plant", ""))
except ValueError:
continue
if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
if self.debug:
logger.debug(f"Ignoring plant {plant_id}")
continue
frame_num, source_plant = choose_frame_and_source(plant_id)
frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
if frame_data:
key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
plants[key] = frame_data
logger.debug(f"Loaded {key}")
else:
# Parent folder structure with date subfolders
for date_name in sorted(os.listdir(self.input_folder)):
date_path = self.input_folder / date_name
if not date_path.is_dir():
continue
if date_name in self.excluded_dates:
logger.info(f"Skipping excluded date: {date_name}")
continue
for plant_name in sorted(os.listdir(date_path)):
plant_path = date_path / plant_name
if not plant_path.is_dir():
continue
try:
plant_id = int(plant_name.replace("plant", ""))
except ValueError:
continue
if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
if self.debug:
logger.debug(f"Ignoring plant {plant_id}")
continue
frame_num, source_plant = choose_frame_and_source(plant_id)
frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
if frame_data:
key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_num}"
plants[key] = frame_data
logger.debug(f"Loaded {key}")
logger.info(f"Successfully loaded {len(plants)} plant frames")
return plants
def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
"""
Load all available frames for each plant.
Returns:
Dictionary with all plant frames
"""
logger.info("Loading all frames from dataset...")
plants = {}
# Check if we're directly in a date folder (contains plant folders)
# or in a parent folder (contains date folders)
first_items = list(self.input_folder.iterdir())
has_plant_folders = any(item.is_dir() and item.name.startswith('plant') for item in first_items)
if has_plant_folders:
# We're directly in a date folder
logger.info("Detected direct date folder structure")
date_name = self.input_folder.name
self._load_plants_from_date_folder(self.input_folder, date_name, plants)
else:
# We're in a parent folder with date subfolders
logger.info("Detected parent folder structure")
for date_name in sorted(os.listdir(self.input_folder)):
date_path = self.input_folder / date_name
if not date_path.is_dir():
continue
if date_name in self.excluded_dates:
logger.info(f"Skipping excluded date: {date_name}")
continue
logger.info(f"Processing date: {date_name}")
self._load_plants_from_date_folder(date_path, date_name, plants)
logger.info(f"Successfully loaded {len(plants)} plant frames")
return plants
def _load_plants_from_date_folder(self, date_path: Path, date_name: str, plants: Dict[str, Dict[str, Any]]) -> None:
"""Load plants from a date folder."""
for plant_name in sorted(os.listdir(date_path)):
plant_path = date_path / plant_name
if not plant_path.is_dir():
continue
# Extract plant ID
try:
plant_id = int(plant_name.replace("plant", ""))
except ValueError:
logger.warning(f"Could not extract plant ID from {plant_name}")
continue
# Skip ignored plants
if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
logger.info(f"Skipping ignored plant {plant_id}")
continue
logger.info(f"Processing plant {plant_id}")
# Load all frames for this plant
pattern = str(plant_path / f"{plant_name}_frame*.tif")
frame_files = sorted(glob.glob(pattern))
logger.info(f"Found {len(frame_files)} frame files for {plant_name}")
for frame_path in frame_files:
frame_data = self._load_frame_from_path(frame_path, plant_name)
if frame_data:
frame_id = Path(frame_path).stem.split("_frame")[-1]
key = f"{date_name.replace('-', '_')}_{plant_name}_frame{frame_id}"
plants[key] = frame_data
logger.debug(f"Loaded frame: {key}")
else:
logger.warning(f"Failed to load frame: {frame_path}")
def load_single_plant(self, date: str, plant: str, frame: int) -> Optional[Dict[str, Any]]:
"""
Load a specific plant frame.
Args:
date: Date string (e.g., "2025-02-05")
plant: Plant name (e.g., "plant1")
frame: Frame number
Returns:
Plant data dictionary or None if not found
"""
date_path = self.input_folder / date
if not date_path.exists():
logger.error(f"Date folder not found: {date}")
return None
plant_path = date_path / plant
if not plant_path.exists():
logger.error(f"Plant folder not found: {plant}")
return None
filename = f"{plant}_frame{frame}.tif"
frame_path = plant_path / filename
return self._load_frame_from_path(str(frame_path), plant)
def _get_frame_number(self, plant_id: int) -> int:
"""Get the frame number for a plant ID."""
plant_name = f"plant{plant_id}"
# Highest priority: explicit overrides by plant name
if plant_name in self.FRAME_OVERRIDE_BY_NAME:
return int(self.FRAME_OVERRIDE_BY_NAME[plant_name])
# Next: original exact/borrrow rules
if plant_id in self.EXACT_FRAME:
return self.EXACT_FRAME[plant_id]
elif plant_id in self.BORROW_FRAME:
return self.BORROW_FRAME[plant_id][1]
else:
return 8 # Default frame
def _get_source_plant(self, plant_id: int) -> str:
"""Get the source plant name for a plant ID."""
plant_name = f"plant{plant_id}"
# Highest priority: explicit substitutes by plant name
if plant_name in self.PLANT_SUBSTITUTES_BY_NAME:
return self.PLANT_SUBSTITUTES_BY_NAME[plant_name]
# Next: original borrow rules
if plant_id in self.BORROW_FRAME:
source_id = self.BORROW_FRAME[plant_id][0]
return f"plant{source_id}"
else:
return f"plant{plant_id}"
def _load_single_frame(self, date_path: Path, source_plant: str,
frame_num: int, target_plant: str) -> Optional[Dict[str, Any]]:
"""Load a single frame from the specified path."""
filename = f"{source_plant}_frame{frame_num}.tif"
frame_path = date_path / source_plant / filename
if not frame_path.exists():
if self.debug:
logger.warning(f"Frame not found: {frame_path}")
return None
return self._load_frame_from_path(str(frame_path), target_plant)
def _load_frame_from_path(self, frame_path: str, plant_name: str) -> Optional[Dict[str, Any]]:
"""Load frame data from a file path."""
try:
logger.debug(f"Attempting to load: {frame_path}")
image = Image.open(frame_path)
filename = Path(frame_path).name
logger.debug(f"Successfully loaded image: {filename}, size: {image.size}")
return {
"raw_image": (image, filename),
"plant_name": plant_name,
"file_path": frame_path
}
except Exception as e:
logger.error(f"Failed to load {frame_path}: {e}")
return None
def load_bounding_boxes(self, bbox_dir: str) -> Dict[str, Tuple[int, int, int, int]]:
"""
Load bounding box data from JSON files.
Args:
bbox_dir: Directory containing bounding box JSON files
Returns:
Dictionary mapping plant names to bounding box coordinates
"""
bbox_path = Path(bbox_dir)
if not bbox_path.exists():
raise FileNotFoundError(f"Bounding box directory not found: {bbox_dir}")
bbox_lookup = {}
for json_file in bbox_path.glob("*.json"):
stem = json_file.stem
# Normalize stems like plant_33_new -> plant33
if stem.startswith('plant_'):
parts = stem.split('_')
try:
idx = next(i for i,p in enumerate(parts) if p.isdigit())
plant_id = f"plant{parts[idx]}"
except Exception:
plant_id = stem.replace('_', '')
else:
plant_id = stem
try:
with open(json_file, 'r') as f:
data = json.load(f)
shapes = data.get('shapes', [])
# Prefer rectangle labeled 'sorghum' (case-insensitive), else first rectangle
def _is_sorghum_label(s: dict) -> bool:
for key in ('label', 'name', 'text'):
val = s.get(key)
if isinstance(val, str) and val.lower() == 'sorghum':
return True
return False
rect = next((s for s in shapes if s.get('shape_type') == 'rectangle' and _is_sorghum_label(s)), None)
if rect is None:
rect = next((s for s in shapes if s.get('shape_type') == 'rectangle'), None)
if rect:
(x1, y1), (x2, y2) = rect['points']
bbox_lookup[plant_id] = (
int(max(0, x1)),
int(max(0, y1)),
int(min(1e9, x2)),
int(min(1e9, y2))
)
else:
bbox_lookup[plant_id] = None
except Exception as e:
logger.error(f"Failed to load bounding box {json_file}: {e}")
logger.info(f"Loaded {len(bbox_lookup)} bounding boxes")
return bbox_lookup
def load_hand_labels(self, labels_dir: str) -> Dict[str, np.ndarray]:
"""
Load hand-labeled masks from JSON files.
Args:
labels_dir: Directory containing label JSON files
Returns:
Dictionary mapping plant names to mask arrays
"""
labels_path = Path(labels_dir)
if not labels_path.exists():
logger.warning(f"Labels directory not found: {labels_dir}")
return {}
masks = {}
for json_file in labels_path.glob("*.json"):
plant_id = json_file.stem
try:
with open(json_file, 'r') as f:
data = json.load(f)
# Create mask from shapes (assuming we have image dimensions)
# This would need to be adapted based on your label format
mask = self._create_mask_from_shapes(data)
if mask is not None:
masks[plant_id] = mask
except Exception as e:
logger.error(f"Failed to load label {json_file}: {e}")
logger.info(f"Loaded {len(masks)} hand labels")
return masks
def _create_mask_from_shapes(self, data: Dict) -> Optional[np.ndarray]:
"""Create a mask array from shape data."""
# This is a placeholder - implement based on your label format
# For now, return None
return None
def validate_data(self, plants: Dict[str, Dict[str, Any]]) -> bool:
"""
Validate loaded plant data.
Args:
plants: Dictionary of plant data
Returns:
True if data is valid, False otherwise
"""
if not plants:
logger.error("No plant data loaded")
return False
for key, data in plants.items():
if "raw_image" not in data:
logger.error(f"Missing raw_image in {key}")
return False
image, filename = data["raw_image"]
if not isinstance(image, Image.Image):
logger.error(f"Invalid image type in {key}")
return False
logger.info("Data validation passed")
return True