File size: 1,209 Bytes
196c526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
"""Image transforms for Bean Vision dataset."""
import albumentations as A
from bean_vision.config import BeanVisionConfig
def get_transforms(config: BeanVisionConfig, is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline described by ``config``.

    Every pipeline begins by resizing to ``config.image.resize_height`` x
    ``config.image.resize_width``. For training (``is_train`` true), the
    random augmentations configured under ``config.augmentation`` are applied
    after the resize, in this order: 90-degree rotation, arbitrary rotation
    limited by ``rotate_limit``, horizontal flip, vertical flip, and
    brightness/contrast jitter. Each augmentation uses its own probability
    from the config.

    Args:
        config: Project configuration supplying the target image size and the
            augmentation parameters.
        is_train: When true, include the random training augmentations;
            otherwise only the resize is applied.

    Returns:
        An ``A.Compose`` wrapping the selected transforms.
    """
    image_cfg = config.image
    resize = A.Resize(image_cfg.resize_height, image_cfg.resize_width)

    # Evaluation/inference: deterministic resize only.
    if not is_train:
        return A.Compose([resize])

    aug = config.augmentation
    pipeline = [
        resize,
        A.RandomRotate90(p=aug.random_rotate90_prob),
        A.Rotate(limit=aug.rotate_limit, p=aug.rotate_prob),
        A.HorizontalFlip(p=aug.horizontal_flip_prob),
        A.VerticalFlip(p=aug.vertical_flip_prob),
        A.RandomBrightnessContrast(
            brightness_limit=aug.brightness_limit,
            contrast_limit=aug.contrast_limit,
            p=aug.brightness_contrast_prob,
        ),
    ]
    return A.Compose(pipeline)
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples for a DataLoader.

    Given samples such as ``[(a1, b1), (a2, b2)]``, returns
    ``((a1, a2), (b1, b2))``: the i-th element of every sample is grouped
    together. Items are only grouped, not stacked or otherwise combined.
    An empty batch yields ``()``. Because this relies on ``zip``, samples
    of unequal length are truncated to the shortest one.
    """
    per_field = zip(*batch)
    return tuple(per_field)