Kunitomi's picture
Upload folder using huggingface_hub
196c526 verified
"""Image transforms for Bean Vision dataset."""
import albumentations as A
from bean_vision.config import BeanVisionConfig
def get_transforms(config: BeanVisionConfig, is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline described by ``config``.

    Every pipeline begins with a resize to
    ``(config.image.resize_height, config.image.resize_width)``. When
    ``is_train`` is true, the random augmentations configured under
    ``config.augmentation`` are appended after the resize, in this order:
    90-degree rotation, arbitrary rotation, horizontal flip, vertical flip,
    and brightness/contrast jitter.

    Args:
        config: Project configuration supplying the target image size and
            the augmentation limits/probabilities.
        is_train: If true, include the random training augmentations;
            otherwise return a resize-only (deterministic) pipeline.

    Returns:
        An ``A.Compose`` wrapping the selected transforms.
    """
    pipeline = [A.Resize(config.image.resize_height, config.image.resize_width)]

    # Evaluation/inference: resize only, no randomness.
    if not is_train:
        return A.Compose(pipeline)

    aug = config.augmentation
    pipeline += [
        A.RandomRotate90(p=aug.random_rotate90_prob),
        A.Rotate(limit=aug.rotate_limit, p=aug.rotate_prob),
        A.HorizontalFlip(p=aug.horizontal_flip_prob),
        A.VerticalFlip(p=aug.vertical_flip_prob),
        A.RandomBrightnessContrast(
            brightness_limit=aug.brightness_limit,
            contrast_limit=aug.contrast_limit,
            p=aug.brightness_contrast_prob,
        ),
    ]
    return A.Compose(pipeline)
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    Intended as a ``DataLoader`` ``collate_fn``. Given samples such as
    ``[(img0, tgt0), (img1, tgt1)]``, it returns
    ``((img0, img1), (tgt0, tgt1))``. No stacking or tensor conversion is
    performed, so samples of different sizes can share a batch.

    Because it is built on ``zip``, samples with differing numbers of fields
    are truncated to the shortest one. An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)