"""Image transform pipelines and batching helpers for the Bean Vision dataset."""

import albumentations as A

from bean_vision.config import BeanVisionConfig


def get_transforms(config: BeanVisionConfig, is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline described by ``config``.

    Every pipeline starts by resizing to
    ``(config.image.resize_height, config.image.resize_width)``.

    Args:
        config: Project configuration. The ``image`` section supplies the
            resize target. The ``augmentation`` section supplies the
            probabilities and limits for the training-time augmentations.
        is_train: When true, append the random augmentations after the
            resize (rot90, rotate, h/v flip, brightness/contrast). When
            false, return a resize-only pipeline.

    Returns:
        An ``A.Compose`` wrapping the selected transforms, in order.
    """
    resize = A.Resize(config.image.resize_height, config.image.resize_width)

    # Evaluation / inference: deterministic resize only.
    if not is_train:
        return A.Compose([resize])

    aug = config.augmentation
    return A.Compose([
        resize,
        A.RandomRotate90(p=aug.random_rotate90_prob),
        A.Rotate(limit=aug.rotate_limit, p=aug.rotate_prob),
        A.HorizontalFlip(p=aug.horizontal_flip_prob),
        A.VerticalFlip(p=aug.vertical_flip_prob),
        A.RandomBrightnessContrast(
            brightness_limit=aug.brightness_limit,
            contrast_limit=aug.contrast_limit,
            p=aug.brightness_contrast_prob,
        ),
    ])


def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``. Nothing is
    stacked into a single tensor. This presumably keeps per-sample items
    whose sizes differ; confirm against the dataset's ``__getitem__``.
    """
    per_field = zip(*batch)
    return tuple(per_field)