File size: 1,209 Bytes
196c526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Image transforms for Bean Vision dataset."""

import albumentations as A
from bean_vision.config import BeanVisionConfig


def get_transforms(config: BeanVisionConfig, is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline described by ``config``.

    Every pipeline starts by resizing to ``config.image.resize_height`` x
    ``config.image.resize_width``. When ``is_train`` is true, the random
    augmentations configured under ``config.augmentation`` are appended after
    the resize, in this order: 90-degree rotation, bounded rotation,
    horizontal flip, vertical flip, brightness/contrast jitter.

    Args:
        config: Project configuration supplying the target image size and the
            augmentation probabilities/limits.
        is_train: If true, include the random augmentations; otherwise return
            a resize-only pipeline.

    Returns:
        An ``A.Compose`` wrapping the selected transforms.
    """
    image_cfg = config.image
    steps = [A.Resize(image_cfg.resize_height, image_cfg.resize_width)]

    if not is_train:
        # Non-training pipeline: deterministic resize only, no randomness.
        return A.Compose(steps)

    aug = config.augmentation
    steps += [
        A.RandomRotate90(p=aug.random_rotate90_prob),
        A.Rotate(limit=aug.rotate_limit, p=aug.rotate_prob),
        A.HorizontalFlip(p=aug.horizontal_flip_prob),
        A.VerticalFlip(p=aug.vertical_flip_prob),
        A.RandomBrightnessContrast(
            brightness_limit=aug.brightness_limit,
            contrast_limit=aug.contrast_limit,
            p=aug.brightness_contrast_prob,
        ),
    ]
    return A.Compose(steps)


def collate_fn(batch):
    """Regroup a batch of samples field-by-field for a DataLoader.

    ``batch`` is an iterable of samples, each itself an iterable of fields
    (e.g. ``(image, target)`` pairs). The result is a tuple with one entry per
    field position. Each entry is a tuple gathering that field from every
    sample, in batch order.

    Nothing is stacked or converted, so samples may differ in size. Because
    this relies on ``zip``, samples with unequal lengths are truncated to the
    shortest one. An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)