| | |
| | """Dataset module for Bean Vision project.""" |
| |
|
| | import gzip |
| | import os |
| | import tempfile |
| | from pathlib import Path |
| | from typing import List, Optional, Tuple |
| |
|
| | import albumentations as A |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from pycocotools.coco import COCO |
| | from torch.utils.data import Dataset |
| |
|
| | from bean_vision.config import BeanVisionConfig |
| | from bean_vision.utils.logging import get_logger |
| | from bean_vision.utils.misc import ValidationError |
| | from bean_vision.utils.paths import safe_load_image, validate_coco_json |
| | from typing import Dict, Any, Tuple |
| |
|
| | |
# Type aliases used throughout this module.
# One dataset sample: an image tensor plus its training target dict.
DatasetItem = Tuple[torch.Tensor, Dict[str, Any]]
# An image as a NumPy array (presumably HWC as loaded from disk — confirm
# against safe_load_image).
ImageArray = np.ndarray
# Mapping of target field names ('boxes', 'masks', ...) to tensors.
TensorDict = Dict[str, torch.Tensor]
| |
|
| |
|
class BeanDataset(Dataset):
    """COCO-format dataset with polygon-to-mask conversion and transforms.

    Loads images and COCO annotations from ``data_dir``, converts polygon
    segmentations into binary instance masks, optionally applies an
    albumentations pipeline, and yields ``(image_tensor, target_dict)``
    pairs in the torchvision instance-segmentation convention.
    """

    def __init__(self,
                 coco_json: str,
                 data_dir: str,
                 config: BeanVisionConfig,
                 transforms: Optional[A.Compose] = None,
                 is_train: bool = True) -> None:
        """Initialize the dataset.

        Args:
            coco_json: Path to a COCO annotation file (optionally ``.gz``).
            data_dir: Root directory containing the image files.
            config: Project configuration (image sizes, normalization, ...).
            transforms: Optional albumentations pipeline applied per item.
            is_train: Whether this dataset instance is used for training.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValidationError: If the COCO annotation file cannot be loaded.
        """
        self.logger = get_logger(self.__class__.__name__)
        self.config = config
        self.transforms = transforms
        self.is_train = is_train

        coco_path = validate_coco_json(coco_json)
        self.data_dir = Path(data_dir)

        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        try:
            self.coco = self._load_coco_data(coco_path)
            self.image_ids = list(self.coco.imgs.keys())
            self.logger.info(f"Loaded {len(self.image_ids)} images from {coco_path}")
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in tracebacks (was previously dropped).
            raise ValidationError(f"Failed to load COCO data: {e}") from e

    def _load_coco_data(self, coco_path: Path) -> COCO:
        """Load COCO data, handling compressed files.

        pycocotools only reads from a filesystem path, so ``.gz`` inputs are
        decompressed to a temporary JSON file that is removed once the COCO
        index has been built.
        """
        if coco_path.suffix == '.gz':
            self.logger.info("Loading compressed COCO file")
            with gzip.open(coco_path, 'rt') as f_in:
                with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f_out:
                    f_out.write(f_in.read())
                    temp_json = f_out.name

            try:
                return COCO(temp_json)
            finally:
                # Always remove the temp file, even if COCO() raises.
                os.unlink(temp_json)
        else:
            return COCO(str(coco_path))

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> DatasetItem:
        """Get dataset item with error handling.

        Any failure while loading or processing a sample is logged and
        replaced with an empty dummy item, so a single corrupt sample
        cannot abort a training epoch.
        """
        try:
            return self._get_item_safe(idx)
        except Exception as e:
            self.logger.error(f"Error loading item {idx}: {e}")
            return self._get_dummy_item()

    def _get_item_safe(self, idx: int) -> DatasetItem:
        """Safely get dataset item.

        Raises on any load/processing failure; ``__getitem__`` converts
        that into a dummy item.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.imgs[img_id]

        image, original_size = self._load_image(img_info)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        masks, boxes, labels = self._process_annotations(anns, img_info)

        if self.transforms:
            image, masks, boxes, labels = self._apply_transforms(
                image, masks, boxes, labels, img_info
            )
        else:
            # Bug fix: without transforms the annotations used to flow into
            # the target dict as plain Python lists. Convert them here so
            # the target is uniformly tensor-typed on both code paths.
            masks, boxes, labels = self._annotations_to_tensors(
                masks, boxes, labels, image.shape[0], image.shape[1]
            )

        image_tensor = self._image_to_tensor(image)
        target = self._create_target_dict(masks, boxes, labels, img_id, anns)

        return image_tensor, target

    def _load_image(self, img_info: dict) -> Tuple[ImageArray, Tuple[int, int]]:
        """Load image with proper error handling.

        Strips a leading ``data/`` prefix from the annotation's file name,
        since paths are resolved relative to ``self.data_dir``.
        """
        file_name = img_info['file_name']
        if file_name.startswith('data/'):
            file_name = file_name[5:]

        img_path = self.data_dir / file_name
        return safe_load_image(img_path)

    def _process_annotations(self, anns: List[dict], img_info: dict) -> Tuple[List[np.ndarray], List[List[float]], List[int]]:
        """Process COCO annotations into masks, boxes, and labels.

        Annotations whose rasterized mask is empty are skipped. Boxes are
        converted from COCO ``[x, y, w, h]`` to ``[x1, y1, x2, y2]``.
        All instances get label 1 (single foreground class).
        """
        masks = []
        boxes = []
        labels = []

        for ann in anns:
            mask = self._polygon_to_mask(ann['segmentation'], img_info['height'], img_info['width'])
            if mask.sum() == 0:
                # Degenerate polygon (e.g. collinear points) — skip it.
                continue

            masks.append(mask)
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(1)

        return masks, boxes, labels

    def _polygon_to_mask(self, segmentation: List[List[float]], height: int, width: int) -> np.ndarray:
        """Convert polygon segmentation to binary mask.

        NOTE(review): only polygon-list segmentations are rasterized; RLE
        segmentations fall through and yield an all-zero mask, which
        ``_process_annotations`` then drops — confirm RLE never occurs here.
        """
        mask = np.zeros((height, width), dtype=np.uint8)

        if isinstance(segmentation, list):
            for poly in segmentation:
                # A valid polygon needs at least 3 points (6 coordinates).
                if len(poly) >= 6:
                    poly_array = np.array(poly).reshape(-1, 2)
                    cv2.fillPoly(mask, [poly_array.astype(np.int32)], 1)

        return mask

    def _apply_transforms(self, image: ImageArray, masks: List[np.ndarray],
                          boxes: List[List[float]], labels: List[int],
                          img_info: dict) -> Tuple[ImageArray, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply augmentation transforms to image and masks.

        Boxes are recomputed from the transformed masks (rather than warped
        directly), so geometric augmentations stay consistent with the
        masks. ``img_info`` is currently unused but kept for interface
        stability.
        """
        if len(masks) == 0:
            transformed = self.transforms(image=image)
            image = transformed['image']
            # Assumes the pipeline returns an HWC numpy array, so
            # shape[0]/shape[1] are height/width — TODO confirm.
            return (image,
                    torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8),
                    torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64))

        # Rebuild the pipeline with one named mask target per instance so
        # every mask receives the same (random) geometric transform as the
        # image.
        transform_with_masks = A.Compose(
            self.transforms.transforms,
            additional_targets={f'mask{i}': 'mask' for i in range(len(masks))}
        )

        transform_input = {'image': image}
        for i, mask in enumerate(masks):
            transform_input[f'mask{i}'] = mask.astype(np.uint8)

        transformed = transform_with_masks(**transform_input)
        image = transformed['image']

        transformed_masks = [transformed[f'mask{i}'] for i in range(len(masks))]

        # Drop instances whose transformed mask became too small, and
        # recompute tight boxes from the surviving masks.
        updated_boxes, updated_labels, updated_masks = self._update_boxes_from_masks(
            transformed_masks, labels
        )

        masks_tensor, boxes_tensor, labels_tensor = self._annotations_to_tensors(
            updated_masks, updated_boxes, updated_labels, image.shape[0], image.shape[1]
        )

        return image, masks_tensor, boxes_tensor, labels_tensor

    def _annotations_to_tensors(self, masks: List[np.ndarray],
                                boxes: List[List[float]], labels: List[int],
                                height: int, width: int
                                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert annotation lists to tensors, with correctly-shaped empties.

        Shared by the transform and no-transform paths so both produce the
        same tensor dtypes and empty shapes.
        """
        masks_tensor = (torch.as_tensor(np.stack(masks), dtype=torch.uint8)
                        if masks else torch.zeros((0, height, width), dtype=torch.uint8))
        boxes_tensor = (torch.as_tensor(boxes, dtype=torch.float32)
                        if boxes else torch.zeros((0, 4), dtype=torch.float32))
        labels_tensor = (torch.as_tensor(labels, dtype=torch.int64)
                         if labels else torch.zeros((0,), dtype=torch.int64))
        return masks_tensor, boxes_tensor, labels_tensor

    def _update_boxes_from_masks(self, masks: List[np.ndarray],
                                 labels: List[int]) -> Tuple[List[List[float]], List[int], List[np.ndarray]]:
        """Update bounding boxes based on mask contours after transformation.

        Instances whose largest external contour falls below the configured
        minimum area are filtered out entirely.
        """
        updated_boxes = []
        updated_labels = []
        updated_masks = []

        for i, mask in enumerate(masks):
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if len(contours) > 0:
                largest_contour = max(contours, key=cv2.contourArea)
                if cv2.contourArea(largest_contour) > self.config.inference.min_contour_area:
                    x, y, w, h = cv2.boundingRect(largest_contour)
                    updated_boxes.append([x, y, x + w, y + h])
                    # Defensive: labels may be shorter than masks; default
                    # to the single foreground class.
                    updated_labels.append(labels[i] if i < len(labels) else 1)
                    updated_masks.append(mask)

        return updated_boxes, updated_labels, updated_masks

    def _image_to_tensor(self, image: ImageArray) -> torch.Tensor:
        """Convert image to normalized tensor.

        HWC uint8 -> CHW float in [0, 1], then ImageNet mean/std
        normalization from the project config.
        """
        image_tensor = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        mean = torch.tensor(self.config.image.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.config.image.imagenet_std).view(3, 1, 1)
        return (image_tensor - mean) / std

    def _create_target_dict(self, masks: torch.Tensor, boxes: torch.Tensor,
                            labels: torch.Tensor, img_id: int, anns: List[dict]) -> TensorDict:
        """Create target dictionary for model training.

        NOTE(review): 'area' and 'iscrowd' are taken from the raw
        annotations while boxes/masks may have been filtered during
        processing, so the lengths can disagree — confirm downstream
        consumers tolerate this.
        """
        return {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': torch.tensor([img_id]),
            'area': torch.tensor([ann['area'] for ann in anns]) if anns else torch.tensor([]),
            # 'iscrowd' is optional in some COCO exports; default to 0
            # instead of raising KeyError (which degraded the whole sample
            # to a dummy item).
            'iscrowd': torch.tensor([ann.get('iscrowd', 0) for ann in anns]) if anns else torch.tensor([])
        }

    def _get_dummy_item(self) -> DatasetItem:
        """Create dummy item for error cases.

        Shapes follow the configured resize dimensions so the dummy item
        batches cleanly with real samples.
        """
        dummy_image = torch.zeros((3, self.config.image.resize_height, self.config.image.resize_width))
        dummy_target = {
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
            'masks': torch.zeros((0, self.config.image.resize_height, self.config.image.resize_width), dtype=torch.uint8),
            'image_id': torch.tensor([0]),
            'area': torch.tensor([]),
            'iscrowd': torch.tensor([])
        }
        return dummy_image, dummy_target
| |
|
| |
|
| |
|