from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List

from torchvision import transforms
def build_transform(is_train: bool, config: Any):
    """
    Factory function to build image transformation pipeline(s).

    Args:
        is_train (bool): Whether to build transforms for training or evaluation.
        config: Configuration object; must expose ``is_two_transform`` plus the
            fields read by ``TransformBuilder`` (image_size, crop_size, ...).

    Returns:
        transforms.Compose: A single composed pipeline when
        ``config.is_two_transform`` is falsy; otherwise a two-element list
        ``[crop_pipeline, resize_pipeline]``.
    """
    builder = TransformBuilder(config)
    if not config.is_two_transform:
        return builder.build_transform(is_train)
    crop_pipeline, resize_pipeline = builder.build_two_transform(is_train)
    return [crop_pipeline, resize_pipeline]
@dataclass
class TransformConfig:
    """Normalization constants for image transformations.

    All attributes are class-level constants (``ClassVar``), so the
    dataclass has no instance fields: ``TransformConfig()`` takes no
    arguments and every instance compares equal.
    """

    # Standard ImageNet channel statistics (RGB), used by "cnn_transform".
    IMAGENET_MEAN: ClassVar[tuple] = (0.485, 0.456, 0.406)
    IMAGENET_STD: ClassVar[tuple] = (0.229, 0.224, 0.225)
    # Simple [-1, 1] scaling statistics used otherwise.
    SIMPLE_MEAN: ClassVar[tuple] = (0.5, 0.5, 0.5)
    SIMPLE_STD: ClassVar[tuple] = (0.5, 0.5, 0.5)
class TransformBuilder:
    """Builder class for creating image transformation pipelines."""

    def __init__(self, config: Any):
        """
        Initialize transform builder with configuration.

        Args:
            config: Configuration object with transformation parameters
                (image_size, crop_size, transform_type, augmentation, ...).
        """
        self.config = config
        self.transform_config = TransformConfig()

    def _get_normalization_params(self) -> tuple:
        """Return (mean, std) normalization statistics.

        "cnn_transform" selects the ImageNet statistics; any other
        ``transform_type`` selects the simple 0.5/0.5 statistics.
        """
        if self.config.transform_type == "cnn_transform":
            return (self.transform_config.IMAGENET_MEAN,
                    self.transform_config.IMAGENET_STD)
        return (self.transform_config.SIMPLE_MEAN,
                self.transform_config.SIMPLE_STD)

    def _build_augmentation_transforms(self) -> list:
        """Build the list of training-time augmentation transforms."""
        return [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
        ]

    def _build_base_transforms(self, is_train: bool) -> list:
        """Build resize -> crop -> tensor -> normalize transforms.

        Training uses a random crop; evaluation uses a deterministic
        center crop. Resize, ToTensor and Normalize are shared, so only
        the crop step depends on ``is_train``.
        """
        mean, std = self._get_normalization_params()
        crop = (transforms.RandomCrop(size=self.config.crop_size) if is_train
                else transforms.CenterCrop(size=self.config.crop_size))
        return [
            transforms.Resize(self.config.image_size),
            crop,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]

    def build_transform(self, is_train: bool = True) -> transforms.Compose:
        """
        Build transformation pipeline based on configuration.

        Args:
            is_train (bool): Whether to build transforms for training or
                evaluation. Augmentation is applied only when training AND
                ``config.augmentation`` is truthy.

        Returns:
            transforms.Compose: Composed transformation pipeline.
        """
        transform_list = []
        if is_train and self.config.augmentation:
            transform_list.extend(self._build_augmentation_transforms())
        transform_list.extend(self._build_base_transforms(is_train))
        return transforms.Compose(transform_list)

    def build_two_transform(self, is_train: bool = True):
        """
        Build two different transforms:
          1. One with crop (same as the regular ``build_transform``).
          2. One with just a resize plus square center crop.

        Args:
            is_train (bool): Controls augmentation/crop for the first
                pipeline only; the second pipeline is identical either way
                (the original if/else branches were byte-identical).

        Returns:
            tuple[transforms.Compose, transforms.Compose]: The crop
            pipeline and the resize pipeline.
        """
        # First pipeline: same construction as build_transform.
        transform_list_1 = []
        if is_train and self.config.augmentation:
            transform_list_1.extend(self._build_augmentation_transforms())
        transform_list_1.extend(self._build_base_transforms(is_train))
        # Second pipeline: resize to image_size, then a square center crop.
        # NOTE(review): normalization here is hard-coded to ImageNet stats
        # and does NOT follow transform_type — confirm this is intentional.
        transform_list_2 = [
            transforms.Resize(self.config.image_size),
            transforms.CenterCrop((self.config.image_size, self.config.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.transform_config.IMAGENET_MEAN,
                                 std=self.transform_config.IMAGENET_STD),
        ]
        return transforms.Compose(transform_list_1), transforms.Compose(transform_list_2)