File size: 5,038 Bytes
e648bfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from dataclasses import dataclass
from typing import Dict, Any, List, Callable

from torchvision import transforms

def build_transform(is_train: bool, config: Any):
    """
    Build the image transform(s) described by ``config``.

    Args:
        is_train (bool): When True, build the training pipeline (random crop,
            plus augmentation if enabled); otherwise the evaluation pipeline.
        config: Configuration object read via attribute access (not as a
            dict). This function reads ``config.is_two_transform``; the
            remaining settings are consumed by ``TransformBuilder``.

    Returns:
        A single ``transforms.Compose`` when ``config.is_two_transform`` is
        falsy; otherwise a two-element list ``[crop_transform, resize_transform]``.
    """
    builder = TransformBuilder(config)
    if not config.is_two_transform:
        return builder.build_transform(is_train)

    # Two-view mode: the builder returns a tuple, callers receive a list.
    crop_view, resize_view = builder.build_two_transform(is_train)
    return [crop_view, resize_view]

@dataclass
class TransformConfig:
    """
    Per-channel (R, G, B) normalization statistics for the transform builders.

    The values are plain, un-annotated class attributes, so ``@dataclass``
    generates no fields: ``TransformConfig()`` takes no arguments and every
    instance exposes the same constants.
    """

    # Statistics commonly used with ImageNet-pretrained CNNs.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # Symmetric statistics: (x - 0.5) / 0.5 maps [0, 1] inputs onto [-1, 1].
    SIMPLE_MEAN = (0.5,) * 3
    SIMPLE_STD = (0.5,) * 3


class TransformBuilder:
    """
    Builder for torchvision image transformation pipelines.

    Settings are read by attribute access from the ``config`` object passed
    to the constructor: ``transform_type``, ``augmentation``, ``image_size``
    and ``crop_size``.
    """

    def __init__(self, config: Any):
        """
        Initialize the transform builder.

        Args:
            config: Configuration object (read via attribute access, not as a
                dict) providing ``transform_type``, ``augmentation``,
                ``image_size`` and ``crop_size``.
        """
        self.config = config
        self.transform_config = TransformConfig()

    def _get_normalization_params(self) -> tuple:
        """
        Return ``(mean, std)`` for ``transforms.Normalize``.

        ImageNet statistics when ``config.transform_type == "cnn_transform"``,
        otherwise the symmetric 0.5 statistics.
        """
        if self.config.transform_type == "cnn_transform":
            return (self.transform_config.IMAGENET_MEAN,
                    self.transform_config.IMAGENET_STD)
        return (self.transform_config.SIMPLE_MEAN,
                self.transform_config.SIMPLE_STD)

    def _build_augmentation_transforms(self) -> list:
        """Return the training-time augmentations: random H/V flips and a +/-15 degree rotation."""
        return [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
        ]

    def _build_base_transforms(self, is_train: bool) -> list:
        """
        Return the resize -> crop -> ToTensor -> Normalize transforms.

        Both modes resize to ``config.image_size`` and crop to
        ``config.crop_size``; training uses a random crop, evaluation a
        deterministic center crop.
        """
        mean, std = self._get_normalization_params()
        # Train and eval pipelines differ only in which crop is applied.
        crop_cls = transforms.RandomCrop if is_train else transforms.CenterCrop
        return [
            transforms.Resize(self.config.image_size),
            crop_cls(size=self.config.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]

    def _build_resize_only_transforms(self) -> list:
        """Return the resize -> center crop -> ToTensor -> Normalize view at ``config.image_size``."""
        image_size = self.config.image_size
        return [
            transforms.Resize(image_size),
            # CenterCrop(int) makes a (size, size) square crop, identical to the
            # previous explicit (size, size) tuple; an (h, w) sequence is now
            # used as-is instead of being wrapped into an invalid nested tuple.
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # NOTE(review): this view always uses ImageNet statistics,
            # regardless of config.transform_type, whereas the crop view follows
            # _get_normalization_params(). Confirm the asymmetry is intended.
            transforms.Normalize(mean=self.transform_config.IMAGENET_MEAN,
                                 std=self.transform_config.IMAGENET_STD),
        ]

    def build_transform(self, is_train: bool = True) -> transforms.Compose:
        """
        Build the transformation pipeline described by the configuration.

        Args:
            is_train (bool): Whether to build transforms for training or
                evaluation. Augmentations are added only when this is True
                and ``config.augmentation`` is truthy.

        Returns:
            transforms.Compose: Composed transformation pipeline.
        """
        transform_list = []

        # Augmentations are placed before the resize/crop stage.
        if is_train and self.config.augmentation:
            transform_list.extend(self._build_augmentation_transforms())

        transform_list.extend(self._build_base_transforms(is_train))

        return transforms.Compose(transform_list)

    def build_two_transform(self, is_train: bool = True) -> tuple:
        """
        Build two different transforms for the same input.

        1. The regular crop pipeline, identical to ``build_transform(is_train)``.
        2. A resize-only view: resize plus center crop at ``config.image_size``,
           with no augmentation.

        Args:
            is_train (bool): Selects train/eval behaviour of the crop view.
                The resize view is the same in both modes.

        Returns:
            tuple: ``(crop_transform, resize_transform)``, both
            ``transforms.Compose``.
        """
        crop_transform = self.build_transform(is_train)
        resize_transform = transforms.Compose(self._build_resize_only_transforms())
        return crop_transform, resize_transform