# Hugging Face Hub artifact header (page residue, kept as a comment):
#   Kunitomi's picture
#   Upload folder using huggingface_hub
#   196c526 verified
#!/usr/bin/env python
"""Dataset module for Bean Vision project."""
import gzip
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple
import albumentations as A
import cv2
import numpy as np
import torch
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ValidationError
from bean_vision.utils.paths import safe_load_image, validate_coco_json
from typing import Dict, Any, Tuple
# Type aliases
# DatasetItem: one sample as returned by BeanDataset.__getitem__ —
# a normalized image tensor plus a target dict (boxes/labels/masks/...).
DatasetItem = Tuple[torch.Tensor, Dict[str, Any]]
# ImageArray: an HWC numpy image as loaded from disk (see _image_to_tensor,
# which permutes HWC -> CHW).
ImageArray = np.ndarray
# TensorDict: string-keyed mapping of tensors (the training target).
TensorDict = Dict[str, torch.Tensor]
class BeanDataset(Dataset):
    """COCO-format dataset yielding ``(image_tensor, target)`` pairs.

    Loads a COCO annotation file (optionally gzip-compressed), converts
    polygon segmentations to binary masks, optionally applies an
    albumentations pipeline, and returns an ImageNet-normalized CHW image
    tensor together with a torchvision-style target dict containing
    ``boxes``, ``labels``, ``masks``, ``image_id``, ``area`` and
    ``iscrowd`` — all per-instance fields guaranteed to have equal length.
    """

    def __init__(self,
                 coco_json: str,
                 data_dir: str,
                 config: 'BeanVisionConfig',
                 transforms: 'Optional[A.Compose]' = None,
                 is_train: bool = True) -> None:
        """Initialize the dataset.

        Args:
            coco_json: Path to a COCO JSON (or ``.json.gz``) annotation file.
            data_dir: Root directory containing the image files.
            config: Project configuration (normalization stats, sizes, ...).
            transforms: Optional albumentations pipeline applied per item.
            is_train: Training-mode flag (stored for callers).

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValidationError: If the COCO file cannot be loaded.
        """
        self.logger = get_logger(self.__class__.__name__)
        self.config = config
        self.transforms = transforms
        self.is_train = is_train

        # Validate inputs before any expensive loading.
        coco_path = validate_coco_json(coco_json)
        self.data_dir = Path(data_dir)
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        try:
            self.coco = self._load_coco_data(coco_path)
            self.image_ids = list(self.coco.imgs.keys())
            self.logger.info(f"Loaded {len(self.image_ids)} images from {coco_path}")
        except Exception as e:
            # BUGFIX: chain the original exception so the root cause survives.
            raise ValidationError(f"Failed to load COCO data: {e}") from e

    def _load_coco_data(self, coco_path: Path) -> 'COCO':
        """Load COCO annotations, transparently handling ``.gz`` files."""
        if coco_path.suffix != '.gz':
            return COCO(str(coco_path))
        self.logger.info("Loading compressed COCO file")
        # pycocotools only accepts a file path, so decompress to a temp file
        # and guarantee its removal even if COCO parsing raises.
        with gzip.open(coco_path, 'rt') as f_in:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f_out:
                f_out.write(f_in.read())
                temp_json = f_out.name
        try:
            return COCO(temp_json)
        finally:
            os.unlink(temp_json)

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> 'DatasetItem':
        """Get one sample; fall back to a dummy item on any load error.

        The fallback keeps DataLoader workers alive when a single image or
        annotation is corrupt; the error is logged for later inspection.
        """
        try:
            return self._get_item_safe(idx)
        except Exception as e:
            self.logger.error(f"Error loading item {idx}: {e}")
            return self._get_dummy_item()

    def _get_item_safe(self, idx: int) -> 'DatasetItem':
        """Load, annotate, transform and tensorize one sample (may raise)."""
        img_id = self.image_ids[idx]
        img_info = self.coco.imgs[img_id]

        image, _original_size = self._load_image(img_info)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        masks, boxes, labels = self._process_annotations(anns, img_info)

        # BUGFIX: explicit None check — an empty A.Compose is falsy, and the
        # no-transform path must still yield tensors, not Python lists.
        if self.transforms is not None:
            image, masks, boxes, labels = self._apply_transforms(
                image, masks, boxes, labels, img_info
            )
        else:
            masks, boxes, labels = self._lists_to_tensors(image, masks, boxes, labels)

        image_tensor = self._image_to_tensor(image)
        target = self._create_target_dict(masks, boxes, labels, img_id, anns)
        return image_tensor, target

    def _load_image(self, img_info: dict) -> 'Tuple[ImageArray, Tuple[int, int]]':
        """Resolve the image path (stripping a legacy 'data/' prefix) and load it."""
        file_name = img_info['file_name']
        if file_name.startswith('data/'):
            file_name = file_name[5:]
        img_path = self.data_dir / file_name
        return safe_load_image(img_path)

    def _process_annotations(self, anns: List[dict], img_info: dict) -> 'Tuple[List[np.ndarray], List[List[float]], List[int]]':
        """Convert COCO annotations into (masks, xyxy boxes, labels).

        Annotations whose polygons rasterize to an empty mask are dropped,
        so the three returned lists always have equal length.
        """
        masks = []
        boxes = []
        labels = []
        for ann in anns:
            mask = self._polygon_to_mask(ann['segmentation'], img_info['height'], img_info['width'])
            if mask.sum() == 0:
                continue
            masks.append(mask)
            # COCO bbox is [x, y, w, h]; convert to [x1, y1, x2, y2].
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])
            labels.append(1)  # Bean class
        return masks, boxes, labels

    def _polygon_to_mask(self, segmentation: 'List[List[float]]', height: int, width: int) -> np.ndarray:
        """Rasterize polygon segmentation into a binary uint8 mask.

        Non-list segmentations (e.g. RLE) yield an all-zero mask, which the
        caller then filters out.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        if isinstance(segmentation, list):
            for poly in segmentation:
                # BUGFIX: also require an even coordinate count — an odd-length
                # polygon would crash reshape(-1, 2).
                if len(poly) >= 6 and len(poly) % 2 == 0:
                    poly_array = np.asarray(poly).reshape(-1, 2)
                    cv2.fillPoly(mask, [poly_array.astype(np.int32)], 1)
        return mask

    def _lists_to_tensors(self, image: 'ImageArray', masks: 'List[np.ndarray]',
                          boxes: 'List[List[float]]', labels: List[int]
                          ) -> 'Tuple[torch.Tensor, torch.Tensor, torch.Tensor]':
        """Convert raw annotation lists to tensors (no-transform path)."""
        h, w = image.shape[0], image.shape[1]
        masks_t = torch.as_tensor(np.stack(masks), dtype=torch.uint8) if masks else torch.zeros((0, h, w), dtype=torch.uint8)
        boxes_t = torch.as_tensor(boxes, dtype=torch.float32) if boxes else torch.zeros((0, 4), dtype=torch.float32)
        labels_t = torch.as_tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)
        return masks_t, boxes_t, labels_t

    def _apply_transforms(self, image: 'ImageArray', masks: 'List[np.ndarray]',
                          boxes: 'List[List[float]]', labels: List[int],
                          img_info: dict) -> 'Tuple[ImageArray, torch.Tensor, torch.Tensor, torch.Tensor]':
        """Apply augmentations jointly to the image and its instance masks.

        Boxes are recomputed from the transformed masks (via contours)
        rather than transformed directly, so ``boxes`` is only consumed
        indirectly here.
        """
        if len(masks) == 0:
            transformed = self.transforms(image=image)
            image = transformed['image']
            # NOTE(review): assumes the pipeline returns an HWC numpy image
            # (no ToTensorV2) so shape[0]/shape[1] are H/W — confirm.
            return (image,
                    torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8),
                    torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64))

        # Rebuild a Compose with one additional mask target per instance so
        # every mask receives exactly the same spatial transform as the image.
        # (Per-item rebuild is a known cost; kept for behavioral parity.)
        transform_with_masks = A.Compose(
            self.transforms.transforms,
            additional_targets={f'mask{i}': 'mask' for i in range(len(masks))}
        )
        transform_input = {'image': image}
        for i, mask in enumerate(masks):
            transform_input[f'mask{i}'] = mask.astype(np.uint8)
        transformed = transform_with_masks(**transform_input)
        image = transformed['image']
        transformed_masks = [transformed[f'mask{i}'] for i in range(len(masks))]

        # Recompute boxes from the transformed masks; instances whose masks
        # shrank below the area threshold are dropped.
        updated_boxes, updated_labels, updated_masks = self._update_boxes_from_masks(
            transformed_masks, labels
        )

        masks_tensor = torch.as_tensor(np.stack(updated_masks), dtype=torch.uint8) if updated_masks else torch.zeros((0, image.shape[0], image.shape[1]), dtype=torch.uint8)
        boxes_tensor = torch.as_tensor(updated_boxes, dtype=torch.float32) if updated_boxes else torch.zeros((0, 4), dtype=torch.float32)
        labels_tensor = torch.as_tensor(updated_labels, dtype=torch.int64) if updated_labels else torch.zeros((0,), dtype=torch.int64)
        return image, masks_tensor, boxes_tensor, labels_tensor

    def _update_boxes_from_masks(self, masks: 'List[np.ndarray]',
                                 labels: List[int]) -> 'Tuple[List[List[float]], List[int], List[np.ndarray]]':
        """Derive xyxy boxes from mask contours, dropping tiny instances.

        Only the largest external contour of each mask is considered; masks
        whose largest contour falls below ``config.inference.min_contour_area``
        are discarded along with their labels.
        """
        updated_boxes = []
        updated_labels = []
        updated_masks = []
        min_area = self.config.inference.min_contour_area
        for mask, label in zip(masks, labels):
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue
            largest_contour = max(contours, key=cv2.contourArea)
            if cv2.contourArea(largest_contour) > min_area:
                x, y, w, h = cv2.boundingRect(largest_contour)
                updated_boxes.append([x, y, x + w, y + h])
                updated_labels.append(label)
                updated_masks.append(mask)
        return updated_boxes, updated_labels, updated_masks

    def _image_to_tensor(self, image: 'ImageArray') -> torch.Tensor:
        """Convert an HWC uint8 image to a normalized CHW float tensor."""
        image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        mean = torch.tensor(self.config.image.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.config.image.imagenet_std).view(3, 1, 1)
        return (image_tensor - mean) / std

    def _create_target_dict(self, masks: torch.Tensor, boxes: torch.Tensor,
                            labels: torch.Tensor, img_id: int, anns: List[dict]) -> 'TensorDict':
        """Create the torchvision-style target dictionary for training.

        BUGFIX: ``area``/``iscrowd`` were previously taken from the raw
        annotation list, whose count can differ from the surviving boxes
        after empty-mask filtering and post-transform contour filtering.
        Both are now derived from the final boxes/labels so every
        per-instance field has the same length. ``anns`` is retained for
        signature compatibility.
        """
        if boxes.numel() > 0:
            area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        else:
            area = torch.zeros((0,), dtype=torch.float32)
        return {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': torch.tensor([img_id]),
            'area': area,
            # All polygon-derived instances are treated as non-crowd.
            'iscrowd': torch.zeros((labels.shape[0],), dtype=torch.int64)
        }

    def _get_dummy_item(self) -> 'DatasetItem':
        """Create an empty (zero-instance) item used when loading fails."""
        dummy_image = torch.zeros((3, self.config.image.resize_height, self.config.image.resize_width))
        dummy_target = {
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
            'masks': torch.zeros((0, self.config.image.resize_height, self.config.image.resize_width), dtype=torch.uint8),
            'image_id': torch.tensor([0]),
            'area': torch.zeros((0,), dtype=torch.float32),
            'iscrowd': torch.zeros((0,), dtype=torch.int64)
        }
        return dummy_image, dummy_target