File size: 18,528 Bytes
b4123b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
"""
Data loading functionality for the Sorghum Pipeline.

This module handles loading raw images, managing plant data,
and organizing data according to the pipeline requirements.
"""

import os
import glob
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from PIL import Image
import numpy as np
import logging

# Module-level logger named after this module; no handlers or levels are configured here.
logger = logging.getLogger(__name__)


class DataLoader:
    """
    Handles loading and organizing plant image data.

    The class-level tables below drive frame selection. Plants are identified
    either by integer id (``EXACT_FRAME``, ``BORROW_FRAME``, ``IGNORE_PLANTS``)
    or by folder name such as ``"plant12"`` (the ``*_BY_NAME`` tables).
    """

    # Plant ids to skip entirely unless the loader is created with
    # include_ignored=True (empty by default).
    IGNORE_PLANTS = set()

    # plant id -> frame number to take from that plant's own folder.
    # Consulted only when the plant has no entry in FRAME_OVERRIDE_BY_NAME.
    EXACT_FRAME = {
        4: 7, 5: 5, 7: 5, 12: 5, 13: 5, 18: 7, 19: 2, 20: 3,
        24: 6, 25: 5, 26: 5, 30: 8, 37: 7
    }

    # plant id -> (source plant id, frame number): take the frame from the
    # source plant's folder. Some entries name the plant itself as the source
    # (e.g. 35 and 36), which only pins the frame number.
    BORROW_FRAME = {
        14: (13, 5), 15: (14, 5), 16: (15, 5), 33: (34, 7),
        34: (35, 7), 35: (35, 8), 36: (36, 6)
    }

    # Overrides provided by user: plant folder name -> preferred frame number.
    # Checked before EXACT_FRAME and BORROW_FRAME when choosing a frame.
    FRAME_OVERRIDE_BY_NAME = {
        'plant1': 9, 'plant2': 10, 'plant3': 9, 'plant5': 7, 'plant6': 9, 'plant8': 5,
        'plant7': 9, 'plant10': 9, 'plant11': 9, 'plant12': 9,
        'plant13': 10, 'plant14': 8, 'plant15': 11, 'plant19': 4, 'plant20': 7,
        'plant21': 9, 'plant22': 10, 'plant25': 4, 'plant26': 2, 'plant27': 10, 'plant28': 9, 'plant29': 2,
        'plant30': 9, 'plant31': 10, 'plant32': 9, 'plant33': 8,
        'plant35': 9, 'plant36': 4, 'plant38': 9, 'plant39': 9, 'plant41': 9,
        'plant42': 6, 'plant43': 10, 'plant44': 9, 'plant45': 7,
        'plant47': 10, 'plant48': 11,
    }

    # Substitutes provided by user: target plant folder name -> source plant
    # folder name. Checked before BORROW_FRAME when choosing a source folder.
    # The lookup is a single step: a source that is itself a key here is not
    # followed any further.
    PLANT_SUBSTITUTES_BY_NAME = {
        'plant16': 'plant15', 'plant15': 'plant14', 'plant14': 'plant13',
        'plant13': 'plant12', 'plant33': 'plant34', 'plant34': 'plant35',
        'plant24': 'plant25', 'plant25': 'plant25', 'plant35': 'plant36',
        'plant36': 'plant37', 'plant37': 'plant37', 'plant44': 'plant43',
        'plant45': 'plant44',
    }
    
    def __init__(self, input_folder: str, debug: bool = False, include_ignored: bool = False, strict_loader: bool = False, excluded_dates: Optional[List[str]] = None):
        """
        Initialize the data loader.
        
        Args:
            input_folder: Path to the input dataset folder
            debug: Enable debug logging
        """
        self.input_folder = Path(input_folder)
        self.debug = debug
        self.include_ignored = include_ignored
        self.strict_loader = strict_loader
        
        if not self.input_folder.exists():
            raise FileNotFoundError(f"Input folder does not exist: {input_folder}")
        # Normalize excluded dates as a set of folder names (with dashes)
        self.excluded_dates = set(excluded_dates or [])
    
    def load_selected_frames(self) -> Dict[str, Dict[str, Any]]:
        """
        Load one selected frame per plant folder.

        The input folder is treated as a single date folder when it directly
        contains at least one directory whose name starts with ``plant``;
        otherwise every sub-directory is treated as a date folder, and those
        listed in ``excluded_dates`` are skipped.

        Frame and source-folder selection:

        * Default mode: ``_get_frame_number`` and ``_get_source_plant``.
        * Strict mode (``strict_loader=True``): the frame comes from
          ``FRAME_OVERRIDE_BY_NAME``, then ``EXACT_FRAME``, then 8; the source
          folder comes from ``PLANT_SUBSTITUTES_BY_NAME`` or else the plant's
          own folder. ``BORROW_FRAME`` is not consulted.

        Plants whose selected frame file is missing or unreadable are omitted.

        Returns:
            Dictionary keyed by ``"YYYY_MM_DD_plantX_frameY"``: the date folder
            name with ``-`` replaced by ``_``, the target plant folder name, and
            the selected frame number (also when the image was taken from a
            different source folder).
        """
        logger.info("Loading selected frames from dataset...")
        plants: Dict[str, Dict[str, Any]] = {}

        # Detect if input folder is a direct date folder (contains plant folders)
        has_plant_folders = any(
            item.is_dir() and item.name.startswith('plant')
            for item in list(self.input_folder.iterdir())
        )

        if has_plant_folders:
            # Direct date folder structure (excluded_dates is not applied here)
            self._load_selected_from_date_folder(self.input_folder, self.input_folder.name, plants)
        else:
            # Parent folder structure with date subfolders
            for date_name in sorted(os.listdir(self.input_folder)):
                date_path = self.input_folder / date_name
                if not date_path.is_dir():
                    continue
                if date_name in self.excluded_dates:
                    logger.info(f"Skipping excluded date: {date_name}")
                    continue
                self._load_selected_from_date_folder(date_path, date_name, plants)

        logger.info(f"Successfully loaded {len(plants)} plant frames")
        return plants

    def _choose_frame_and_source(self, pid: int) -> Tuple[int, str]:
        """Return ``(frame number, source plant folder name)`` for plant ``pid``."""
        if self.strict_loader:
            # In strict mode, honor explicit frame overrides AND substitution of source plant
            plant_name_local = f"plant{pid}"
            frame_num = self.FRAME_OVERRIDE_BY_NAME.get(
                plant_name_local,
                self.EXACT_FRAME.get(pid, 8)
            )
            source_plant = self.PLANT_SUBSTITUTES_BY_NAME.get(plant_name_local, plant_name_local)
            return frame_num, source_plant
        # Original behavior
        return self._get_frame_number(pid), self._get_source_plant(pid)

    def _load_selected_from_date_folder(self, date_path: Path, date_name: str,
                                        plants: Dict[str, Dict[str, Any]]) -> None:
        """Add the selected frame of every plant folder in ``date_path`` to ``plants``."""
        date_key = date_name.replace('-', '_')
        for plant_name in sorted(os.listdir(date_path)):
            if not (date_path / plant_name).is_dir():
                continue
            try:
                plant_id = int(plant_name.replace("plant", ""))
            except ValueError:
                # Directory name carries no parsable plant id; skip it quietly
                continue
            if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                if self.debug:
                    logger.debug(f"Ignoring plant {plant_id}")
                continue
            frame_num, source_plant = self._choose_frame_and_source(plant_id)
            frame_data = self._load_single_frame(date_path, source_plant, frame_num, plant_name)
            if frame_data:
                # The key names the target plant, not the folder the image came from
                key = f"{date_key}_{plant_name}_frame{frame_num}"
                plants[key] = frame_data
                logger.debug(f"Loaded {key}")
    
    def load_all_frames(self) -> Dict[str, Dict[str, Any]]:
        """
        Load every available frame of every plant.

        The input folder counts as a single date folder when it directly holds
        a directory whose name starts with ``plant``; otherwise each of its
        sub-directories is processed as a date folder, except those named in
        ``excluded_dates``.

        Returns:
            Dictionary with all plant frames
        """
        logger.info("Loading all frames from dataset...")
        collected: Dict[str, Dict[str, Any]] = {}
        root = self.input_folder

        # Decide which of the two supported layouts we were given
        is_date_folder = False
        for entry in list(root.iterdir()):
            if entry.is_dir() and entry.name.startswith('plant'):
                is_date_folder = True
                break

        if not is_date_folder:
            # Parent folder: one sub-directory per date
            logger.info("Detected parent folder structure")
            for date_name in sorted(os.listdir(root)):
                date_path = root / date_name
                if date_path.is_dir():
                    if date_name in self.excluded_dates:
                        logger.info(f"Skipping excluded date: {date_name}")
                    else:
                        logger.info(f"Processing date: {date_name}")
                        self._load_plants_from_date_folder(date_path, date_name, collected)
        else:
            # The input folder itself is the date folder
            logger.info("Detected direct date folder structure")
            self._load_plants_from_date_folder(root, root.name, collected)

        logger.info(f"Successfully loaded {len(collected)} plant frames")
        return collected
    
    def _load_plants_from_date_folder(self, date_path: Path, date_name: str, plants: Dict[str, Dict[str, Any]]) -> None:
        """
        Load every frame of every plant folder inside one date folder.

        Args:
            date_path: Directory holding the plant folders.
            date_name: Name used for the date part of the result keys
                (``-`` is replaced by ``_``).
            plants: Dictionary updated in place; each loaded frame is stored
                under ``"<date>_<plant folder>_frame<id>"``.
        """
        date_key = date_name.replace('-', '_')
        for plant_name in sorted(os.listdir(date_path)):
            plant_path = date_path / plant_name
            if not plant_path.is_dir():
                continue

            # Extract plant ID
            try:
                plant_id = int(plant_name.replace("plant", ""))
            except ValueError:
                logger.warning(f"Could not extract plant ID from {plant_name}")
                continue

            # Skip ignored plants
            if (plant_id in self.IGNORE_PLANTS) and (not self.include_ignored):
                logger.info(f"Skipping ignored plant {plant_id}")
                continue

            logger.info(f"Processing plant {plant_id}")

            # Escape the directory part so glob metacharacters in the dataset
            # path ("[", "]", "*", "?") are matched literally; only the
            # "frame*" wildcard is meant to be a pattern.
            pattern = os.path.join(glob.escape(str(plant_path)), f"{plant_name}_frame*.tif")
            frame_files = sorted(glob.glob(pattern))
            logger.info(f"Found {len(frame_files)} frame files for {plant_name}")

            for frame_path in frame_files:
                frame_data = self._load_frame_from_path(frame_path, plant_name)
                if not frame_data:
                    logger.warning(f"Failed to load frame: {frame_path}")
                    continue
                # Everything after the last "_frame" in the file stem is the frame id
                frame_id = Path(frame_path).stem.split("_frame")[-1]
                key = f"{date_key}_{plant_name}_frame{frame_id}"
                plants[key] = frame_data
                logger.debug(f"Loaded frame: {key}")
    
    def load_single_plant(self, date: str, plant: str, frame: int) -> Optional[Dict[str, Any]]:
        """
        Load one specific frame of one plant.

        Args:
            date: Date folder name under the input folder (e.g., "2025-02-05")
            plant: Plant folder name (e.g., "plant1")
            frame: Frame number

        Returns:
            Plant data dictionary, or None when the date folder or plant folder
            is missing or the frame file could not be loaded
        """
        date_dir = self.input_folder / date
        if date_dir.exists():
            plant_dir = date_dir / plant
            if plant_dir.exists():
                # Frame files are named "<plant>_frame<frame>.tif" inside the plant folder
                wanted = plant_dir / f"{plant}_frame{frame}.tif"
                return self._load_frame_from_path(str(wanted), plant)
            logger.error(f"Plant folder not found: {plant}")
            return None
        logger.error(f"Date folder not found: {date}")
        return None
    
    def _get_frame_number(self, plant_id: int) -> int:
        """Get the frame number for a plant ID."""
        plant_name = f"plant{plant_id}"
        # Highest priority: explicit overrides by plant name
        if plant_name in self.FRAME_OVERRIDE_BY_NAME:
            return int(self.FRAME_OVERRIDE_BY_NAME[plant_name])
        # Next: original exact/borrrow rules
        if plant_id in self.EXACT_FRAME:
            return self.EXACT_FRAME[plant_id]
        elif plant_id in self.BORROW_FRAME:
            return self.BORROW_FRAME[plant_id][1]
        else:
            return 8  # Default frame
    
    def _get_source_plant(self, plant_id: int) -> str:
        """Get the source plant name for a plant ID."""
        plant_name = f"plant{plant_id}"
        # Highest priority: explicit substitutes by plant name
        if plant_name in self.PLANT_SUBSTITUTES_BY_NAME:
            return self.PLANT_SUBSTITUTES_BY_NAME[plant_name]
        # Next: original borrow rules
        if plant_id in self.BORROW_FRAME:
            source_id = self.BORROW_FRAME[plant_id][0]
            return f"plant{source_id}"
        else:
            return f"plant{plant_id}"
    
    def _load_single_frame(self, date_path: Path, source_plant: str, 
                          frame_num: int, target_plant: str) -> Optional[Dict[str, Any]]:
        """Load a single frame from the specified path."""
        filename = f"{source_plant}_frame{frame_num}.tif"
        frame_path = date_path / source_plant / filename
        
        if not frame_path.exists():
            if self.debug:
                logger.warning(f"Frame not found: {frame_path}")
            return None
        
        return self._load_frame_from_path(str(frame_path), target_plant)
    
    def _load_frame_from_path(self, frame_path: str, plant_name: str) -> Optional[Dict[str, Any]]:
        """
        Open an image file and wrap it in a plant data dictionary.

        Returns:
            ``{"raw_image": (image, file name), "plant_name": plant_name,
            "file_path": frame_path}``, or None when opening fails (the error
            is logged, not raised).
        """
        # NOTE(review): the opened image object is handed to the caller and is
        # never closed here; PIL documents Image.open as lazy, so presumably the
        # caller owns the handle's lifetime — confirm.
        try:
            logger.debug(f"Attempting to load: {frame_path}")
            image = Image.open(frame_path)
            filename = Path(frame_path).name
            logger.debug(f"Successfully loaded image: {filename}, size: {image.size}")
        except Exception as e:
            logger.error(f"Failed to load {frame_path}: {e}")
            return None

        return dict(
            raw_image=(image, filename),
            plant_name=plant_name,
            file_path=frame_path,
        )
    
    def load_bounding_boxes(self, bbox_dir: str) -> Dict[str, Tuple[int, int, int, int]]:
        """
        Load bounding box data from JSON files.
        
        Args:
            bbox_dir: Directory containing bounding box JSON files
            
        Returns:
            Dictionary mapping plant names to bounding box coordinates
        """
        bbox_path = Path(bbox_dir)
        if not bbox_path.exists():
            raise FileNotFoundError(f"Bounding box directory not found: {bbox_dir}")
        
        bbox_lookup = {}
        
        for json_file in bbox_path.glob("*.json"):
            stem = json_file.stem
            # Normalize stems like plant_33_new -> plant33
            if stem.startswith('plant_'):
                parts = stem.split('_')
                try:
                    idx = next(i for i,p in enumerate(parts) if p.isdigit())
                    plant_id = f"plant{parts[idx]}"
                except Exception:
                    plant_id = stem.replace('_', '')
            else:
                plant_id = stem
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)
                
                shapes = data.get('shapes', [])
                # Prefer rectangle labeled 'sorghum' (case-insensitive), else first rectangle
                def _is_sorghum_label(s: dict) -> bool:
                    for key in ('label', 'name', 'text'):
                        val = s.get(key)
                        if isinstance(val, str) and val.lower() == 'sorghum':
                            return True
                    return False
                rect = next((s for s in shapes if s.get('shape_type') == 'rectangle' and _is_sorghum_label(s)), None)
                if rect is None:
                    rect = next((s for s in shapes if s.get('shape_type') == 'rectangle'), None)
                
                if rect:
                    (x1, y1), (x2, y2) = rect['points']
                    bbox_lookup[plant_id] = (
                        int(max(0, x1)),
                        int(max(0, y1)),
                        int(min(1e9, x2)),
                        int(min(1e9, y2))
                    )
                else:
                    bbox_lookup[plant_id] = None
                    
            except Exception as e:
                logger.error(f"Failed to load bounding box {json_file}: {e}")
        
        logger.info(f"Loaded {len(bbox_lookup)} bounding boxes")
        return bbox_lookup
    
    def load_hand_labels(self, labels_dir: str) -> Dict[str, np.ndarray]:
        """
        Load hand-labeled masks from JSON files.

        Every ``*.json`` file in the directory is parsed and passed to
        ``_create_mask_from_shapes``; files for which no mask is produced, or
        which fail to load, are left out of the result.

        Args:
            labels_dir: Directory containing label JSON files

        Returns:
            Dictionary mapping file stems to mask arrays; empty when the
            directory does not exist
        """
        labels_path = Path(labels_dir)
        if labels_path.exists():
            masks = {}
            for json_file in labels_path.glob("*.json"):
                try:
                    data = json.loads(json_file.read_text())
                    # Mask construction depends on the label format (see helper)
                    mask = self._create_mask_from_shapes(data)
                except Exception as e:
                    logger.error(f"Failed to load label {json_file}: {e}")
                    continue
                if mask is not None:
                    masks[json_file.stem] = mask

            logger.info(f"Loaded {len(masks)} hand labels")
            return masks

        logger.warning(f"Labels directory not found: {labels_dir}")
        return {}
    
    def _create_mask_from_shapes(self, data: Dict) -> Optional[np.ndarray]:
        """Create a mask array from shape data."""
        # This is a placeholder - implement based on your label format
        # For now, return None
        return None
    
    def validate_data(self, plants: Dict[str, Dict[str, Any]]) -> bool:
        """
        Check that loaded plant data is usable.

        Validation stops at the first problem found, which is logged as an
        error.

        Args:
            plants: Dictionary of plant data

        Returns:
            True when ``plants`` is non-empty and every entry has a
            ``raw_image`` pair whose first element is a PIL image; False
            otherwise
        """
        if not plants:
            logger.error("No plant data loaded")
            return False

        for key in plants:
            data = plants[key]
            if "raw_image" in data:
                # raw_image is an (image, filename) pair; only the image is checked
                image, _filename = data["raw_image"]
                if isinstance(image, Image.Image):
                    continue
                logger.error(f"Invalid image type in {key}")
            else:
                logger.error(f"Missing raw_image in {key}")
            return False

        logger.info("Data validation passed")
        return True