#!/usr/bin/env python
"""Dataset module for Bean Vision project."""

import gzip
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import albumentations as A
import cv2
import numpy as np
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset

from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ValidationError
from bean_vision.utils.paths import safe_load_image, validate_coco_json

# Type aliases
DatasetItem = Tuple[torch.Tensor, Dict[str, Any]]
ImageArray = np.ndarray
TensorDict = Dict[str, torch.Tensor]


class BeanDataset(Dataset):
    """COCO-format dataset with polygon-to-mask conversion and transforms.

    Yields ``(image_tensor, target)`` pairs in the torchvision detection
    format: a normalized CHW float image plus a dict with ``boxes``,
    ``labels``, ``masks``, ``image_id``, ``area`` and ``iscrowd`` tensors.
    Items that fail to load are replaced by an empty dummy item so a
    DataLoader worker never crashes mid-epoch.
    """

    def __init__(self, coco_json: str, data_dir: str, config: BeanVisionConfig,
                 transforms: Optional[A.Compose] = None, is_train: bool = True) -> None:
        """Initialize the dataset.

        Args:
            coco_json: Path to a COCO annotation file (optionally ``.gz``).
            data_dir: Root directory containing the image files.
            config: Project configuration (image sizes, normalization, etc.).
            transforms: Optional albumentations pipeline applied per item.
            is_train: Training-mode flag (stored for callers; not read here).

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValidationError: If the COCO file cannot be loaded.
        """
        self.logger = get_logger(self.__class__.__name__)
        self.config = config
        self.transforms = transforms
        self.is_train = is_train
        # Cache of Compose pipelines keyed by mask count, so we do not
        # rebuild the augmentation pipeline on every __getitem__ call.
        self._transform_cache: Dict[int, A.Compose] = {}

        # Validate inputs
        coco_path = validate_coco_json(coco_json)
        self.data_dir = Path(data_dir)
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        # Load COCO data
        try:
            self.coco = self._load_coco_data(coco_path)
            self.image_ids = list(self.coco.imgs.keys())
            self.logger.info(f"Loaded {len(self.image_ids)} images from {coco_path}")
        except Exception as e:
            # Chain the original cause so the stack trace stays useful.
            raise ValidationError(f"Failed to load COCO data: {e}") from e

    def _load_coco_data(self, coco_path: Path) -> COCO:
        """Load COCO data, handling compressed files.

        pycocotools only accepts a filesystem path, so a ``.gz`` file is
        decompressed to a temporary JSON file that is removed afterwards.
        """
        if coco_path.suffix == '.gz':
            self.logger.info("Loading compressed COCO file")
            with gzip.open(coco_path, 'rt') as f_in:
                with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f_out:
                    f_out.write(f_in.read())
                    temp_json = f_out.name
            try:
                return COCO(temp_json)
            finally:
                # Always clean up the temp file, even if COCO() raises.
                os.unlink(temp_json)
        else:
            return COCO(str(coco_path))

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> DatasetItem:
        """Get dataset item with error handling.

        Any failure is logged and replaced with an empty dummy item so
        DataLoader workers keep running.
        """
        try:
            return self._get_item_safe(idx)
        except Exception as e:
            self.logger.error(f"Error loading item {idx}: {e}")
            return self._get_dummy_item()

    def _get_item_safe(self, idx: int) -> DatasetItem:
        """Safely get dataset item."""
        img_id = self.image_ids[idx]
        img_info = self.coco.imgs[img_id]

        # Load image with fallback
        image, original_size = self._load_image(img_info)

        # Get annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        # Process masks and boxes
        masks, boxes, labels = self._process_annotations(anns, img_info)

        # Apply transforms (which also converts annotations to tensors),
        # or convert directly when no transforms are configured so the
        # target always contains tensors, never raw Python lists.
        if self.transforms:
            image, masks, boxes, labels = self._apply_transforms(
                image, masks, boxes, labels, img_info
            )
        else:
            masks, boxes, labels = self._annotations_to_tensors(
                masks, boxes, labels, image.shape[0], image.shape[1]
            )

        # Convert to tensors
        image_tensor = self._image_to_tensor(image)
        target = self._create_target_dict(masks, boxes, labels, img_id)

        return image_tensor, target

    def _load_image(self, img_info: dict) -> Tuple[ImageArray, Tuple[int, int]]:
        """Load image with proper error handling.

        Some annotation files prefix file names with ``data/``; that prefix
        is stripped before joining with ``self.data_dir``.
        """
        file_name = img_info['file_name']
        if file_name.startswith('data/'):
            file_name = file_name[5:]
        img_path = self.data_dir / file_name
        return safe_load_image(img_path)

    def _process_annotations(self, anns: List[dict],
                             img_info: dict) -> Tuple[List[np.ndarray], List[List[float]], List[int]]:
        """Process COCO annotations into masks, boxes, and labels.

        Annotations whose polygons rasterize to an empty mask, or whose
        bbox is degenerate (non-positive width/height), are skipped so all
        three returned lists stay aligned.
        """
        masks = []
        boxes = []
        labels = []

        for ann in anns:
            x, y, w, h = ann['bbox']
            # Degenerate boxes break detection losses downstream.
            if w <= 0 or h <= 0:
                continue
            mask = self._polygon_to_mask(ann['segmentation'],
                                         img_info['height'], img_info['width'])
            if mask.sum() == 0:
                continue
            masks.append(mask)
            # COCO bbox is [x, y, w, h]; convert to [x1, y1, x2, y2].
            boxes.append([x, y, x + w, y + h])
            labels.append(1)  # Bean class

        return masks, boxes, labels

    def _polygon_to_mask(self, segmentation: List[List[float]],
                         height: int, width: int) -> np.ndarray:
        """Convert polygon segmentation to binary mask.

        Only list-of-polygon segmentations are rasterized; RLE-encoded
        segmentations (dicts) yield an all-zero mask and are therefore
        skipped by the caller. Polygons with fewer than 3 points
        (< 6 coordinates) are ignored.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        if isinstance(segmentation, list):
            for poly in segmentation:
                if len(poly) >= 6:
                    poly_array = np.array(poly).reshape(-1, 2)
                    cv2.fillPoly(mask, [poly_array.astype(np.int32)], 1)
        return mask

    @staticmethod
    def _annotations_to_tensors(masks: List[np.ndarray], boxes: List[List[float]],
                                labels: List[int], height: int,
                                width: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert annotation lists to tensors, handling the empty case."""
        masks_t = (torch.as_tensor(np.stack(masks), dtype=torch.uint8)
                   if masks else torch.zeros((0, height, width), dtype=torch.uint8))
        boxes_t = (torch.as_tensor(boxes, dtype=torch.float32)
                   if boxes else torch.zeros((0, 4), dtype=torch.float32))
        labels_t = (torch.as_tensor(labels, dtype=torch.int64)
                    if labels else torch.zeros((0,), dtype=torch.int64))
        return masks_t, boxes_t, labels_t

    def _apply_transforms(self, image: ImageArray, masks: List[np.ndarray],
                          boxes: List[List[float]], labels: List[int],
                          img_info: dict) -> Tuple[ImageArray, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply augmentation transforms to image and masks.

        Boxes are recomputed from the transformed masks rather than
        transformed directly, so geometric augmentations stay consistent.
        NOTE(review): assumes the pipeline keeps the image as an HWC numpy
        array (no ToTensorV2) — confirm against the transform config.
        """
        if len(masks) == 0:
            transformed = self.transforms(image=image)
            image = transformed['image']
            return (image,
                    torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8),
                    torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64))

        # Reuse a cached Compose with the right number of mask targets
        # instead of rebuilding the pipeline on every call.
        n_masks = len(masks)
        transform_with_masks = self._transform_cache.get(n_masks)
        if transform_with_masks is None:
            transform_with_masks = A.Compose(
                self.transforms.transforms,
                additional_targets={f'mask{i}': 'mask' for i in range(n_masks)}
            )
            self._transform_cache[n_masks] = transform_with_masks

        transform_input = {'image': image}
        for i, mask in enumerate(masks):
            transform_input[f'mask{i}'] = mask.astype(np.uint8)

        transformed = transform_with_masks(**transform_input)
        image = transformed['image']

        # Extract transformed masks
        transformed_masks = [transformed[f'mask{i}'] for i in range(n_masks)]

        # Update bounding boxes based on transformed masks
        updated_boxes, updated_labels, updated_masks = self._update_boxes_from_masks(
            transformed_masks, labels
        )

        masks_tensor, boxes_tensor, labels_tensor = self._annotations_to_tensors(
            updated_masks, updated_boxes, updated_labels,
            image.shape[0], image.shape[1]
        )
        return image, masks_tensor, boxes_tensor, labels_tensor

    def _update_boxes_from_masks(self, masks: List[np.ndarray],
                                 labels: List[int]) -> Tuple[List[List[float]], List[int], List[np.ndarray]]:
        """Update bounding boxes based on mask contours after transformation.

        Masks whose largest contour falls below
        ``config.inference.min_contour_area`` are dropped entirely, keeping
        boxes, labels and masks aligned.
        """
        updated_boxes = []
        updated_labels = []
        updated_masks = []

        for i, mask in enumerate(masks):
            contours, _ = cv2.findContours(mask.astype(np.uint8),
                                           cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                largest_contour = max(contours, key=cv2.contourArea)
                if cv2.contourArea(largest_contour) > self.config.inference.min_contour_area:
                    x, y, w, h = cv2.boundingRect(largest_contour)
                    updated_boxes.append([x, y, x + w, y + h])
                    updated_labels.append(labels[i] if i < len(labels) else 1)
                    updated_masks.append(mask)

        return updated_boxes, updated_labels, updated_masks

    def _image_to_tensor(self, image: ImageArray) -> torch.Tensor:
        """Convert an HWC uint8 image to a normalized CHW float tensor."""
        image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        mean = torch.tensor(self.config.image.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.config.image.imagenet_std).view(3, 1, 1)
        return (image_tensor - mean) / std

    def _create_target_dict(self, masks: torch.Tensor, boxes: torch.Tensor,
                            labels: torch.Tensor, img_id: int) -> TensorDict:
        """Create target dictionary for model training.

        ``area`` and ``iscrowd`` are derived from the kept masks/labels so
        their lengths always match ``boxes``. Deriving them from the raw
        COCO annotations would desynchronize the target whenever an
        annotation was dropped during processing or transformation.
        ``iscrowd`` is all zeros; crowd annotations are not distinguished
        by this dataset.
        """
        # Segmentation area per instance (pixel count of each mask).
        area = masks.sum(dim=(1, 2)).to(torch.float32)
        return {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': torch.tensor([img_id]),
            'area': area,
            'iscrowd': torch.zeros((labels.shape[0],), dtype=torch.int64)
        }

    def _get_dummy_item(self) -> DatasetItem:
        """Create dummy item for error cases (empty target, zero image)."""
        dummy_image = torch.zeros((3, self.config.image.resize_height,
                                   self.config.image.resize_width))
        dummy_target = {
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
            'masks': torch.zeros((0, self.config.image.resize_height,
                                  self.config.image.resize_width), dtype=torch.uint8),
            'image_id': torch.tensor([0]),
            'area': torch.tensor([]),
            'iscrowd': torch.tensor([])
        }
        return dummy_image, dummy_target