| """Image transforms for Bean Vision dataset.""" | |
| import albumentations as A | |
| from bean_vision.config import BeanVisionConfig | |
def get_transforms(config: BeanVisionConfig, is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline described by ``config``.

    Args:
        config: Configuration object. ``config.image`` supplies the resize
            height and width; ``config.augmentation`` supplies the
            probabilities and limits of the training-only augmentations.
        is_train: When true, the resize step is followed by the random
            augmentations; when false, the pipeline only resizes.

    Returns:
        An ``A.Compose`` wrapping the selected transforms.
    """
    resize = A.Resize(config.image.resize_height, config.image.resize_width)

    # Evaluation data is only resized; no augmentation settings are read.
    if not is_train:
        return A.Compose([resize])

    aug = config.augmentation
    steps = [
        resize,
        A.RandomRotate90(p=aug.random_rotate90_prob),
        A.Rotate(limit=aug.rotate_limit, p=aug.rotate_prob),
        A.HorizontalFlip(p=aug.horizontal_flip_prob),
        A.VerticalFlip(p=aug.vertical_flip_prob),
        A.RandomBrightnessContrast(
            brightness_limit=aug.brightness_limit,
            contrast_limit=aug.contrast_limit,
            p=aug.brightness_contrast_prob,
        ),
    ]
    return A.Compose(steps)
def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    Intended for use as a ``DataLoader`` collate function. Each element of
    ``batch`` is an iterable sample; the i-th entry of the result gathers
    the i-th item of every sample.

    Args:
        batch: Iterable of samples.

    Returns:
        A tuple with one tuple per field. An empty batch yields ``()``.
        If samples differ in length, fields beyond the shortest sample are
        dropped, as ``zip`` stops at the shortest input.
    """
    columns = zip(*batch)
    return tuple(columns)