| from __future__ import annotations |
|
|
| import random |
| from typing import Any |
|
|
| from PIL import Image, ImageEnhance, ImageOps |
|
|
|
|
# Per-channel normalization statistics (the widely used ImageNet mean / std),
# passed to torchvision's ``transforms.Normalize`` by the transform builders.
IMAGENET_MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_STD: tuple[float, float, float] = (0.229, 0.224, 0.225)
|
|
|
|
def build_train_transform(config: dict[str, Any]):
    """Build the torchvision training pipeline: resize, optional augmentation, normalize.

    Args:
        config: Configuration mapping. ``config["preprocessing"]["image_size"]`` is
            required and gives the side length of the square resize. The optional
            ``config["augmentation"]`` section controls augmentation:

            - ``enabled`` (default True): master switch for every augmentation below.
            - ``horizontal_flip`` (default True): add a horizontal flip with p=0.5.
            - ``rotation_degrees`` (10), ``translate`` (0.03, used for both axes),
              ``scale_min`` (0.95), ``scale_max`` (1.05): ``RandomAffine`` parameters;
              pixels uncovered by the affine warp are filled with 255.
            - ``color_jitter``: sub-section with ``enabled`` (True), ``brightness``
              (0.12), ``contrast`` (0.12), ``saturation`` (0.08) and ``hue`` (0.02).

            A section that is missing or present with a null (``None``) value is
            treated as empty, so the defaults above apply.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` followed by
        ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``.
    """
    # Imported lazily: the module's top-level imports only require PIL.
    from torchvision import transforms

    size = int(config["preprocessing"]["image_size"])
    # ``dict.get(key, {})`` still returns None when the key exists with a null
    # value (e.g. an empty YAML section); normalize that case to an empty dict.
    aug = config.get("augmentation")
    aug = {} if aug is None else aug

    ops: list[Any] = [transforms.Resize((size, size))]
    if aug.get("enabled", True):
        if aug.get("horizontal_flip", True):
            ops.append(transforms.RandomHorizontalFlip(p=0.5))
        translate = float(aug.get("translate", 0.03))
        ops.append(
            transforms.RandomAffine(
                degrees=float(aug.get("rotation_degrees", 10)),
                translate=(translate, translate),
                scale=(float(aug.get("scale_min", 0.95)), float(aug.get("scale_max", 1.05))),
                # torchvision replicates a scalar fill across all channels.
                fill=255,
            )
        )
        jitter = aug.get("color_jitter")
        jitter = {} if jitter is None else jitter
        if jitter.get("enabled", True):
            ops.append(
                transforms.ColorJitter(
                    brightness=float(jitter.get("brightness", 0.12)),
                    contrast=float(jitter.get("contrast", 0.12)),
                    saturation=float(jitter.get("saturation", 0.08)),
                    hue=float(jitter.get("hue", 0.02)),
                )
            )
    ops.extend([transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return transforms.Compose(ops)
|
|
|
|
def build_eval_transform(config: dict[str, Any]):
    """Build the deterministic evaluation pipeline (no augmentation).

    The image is resized to a square of side ``config["preprocessing"]["image_size"]``,
    converted with ``ToTensor`` and normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        config: Configuration mapping; only ``["preprocessing"]["image_size"]`` is read.

    Returns:
        A ``transforms.Compose`` of resize -> to-tensor -> normalize.
    """
    # Imported lazily: the module's top-level imports only require PIL.
    from torchvision import transforms

    side = int(config["preprocessing"]["image_size"])
    resize = transforms.Resize((side, side))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    return transforms.Compose([resize, to_tensor, normalize])
|
|
|
|
def augment_pil_for_classical(image: Image.Image, augment_id: int, seed: int = 42) -> Image.Image:
    """Return a deterministic augmented copy of *image* for the classical pipeline.

    ``augment_id % 6`` selects the variant:

    - 0: unchanged copy
    - 1: horizontal mirror
    - 2: rotation by a random angle in [-9, 9] degrees
    - 3: rotation, then brightness (x0.9-1.1) and contrast (x0.9-1.12) changes
    - 4: mirror, then brightness (x0.9-1.1) and contrast (x0.9-1.12) changes
    - 5: rotation, then sharpness change (x0.9-1.2)

    Random magnitudes come from a private ``random.Random`` seeded with
    ``seed + augment_id * 9973``, so identical arguments always give the same
    result and the global ``random`` state is not touched.

    Args:
        image: Source image; it is copied and never modified in place.
        augment_id: Chooses the variant (mod 6) and perturbs the RNG seed.
        seed: Base seed for the per-call RNG.

    Returns:
        A new PIL image. Rotations keep the original size; uncovered corners are
        filled with 255 in every band.
    """
    rng = random.Random(seed + augment_id * 9973)
    img = image.copy()
    variant = augment_id % 6
    if variant in {1, 4}:
        img = ImageOps.mirror(img)
    if variant in {2, 3, 5}:
        # Pillow only documents an int colour for single-band images; for
        # multi-band modes it unpacks the int as a packed value, so a bare 255
        # on RGB would fill the corners with red (255, 0, 0). Give every band
        # 255 instead (white for L/RGB/RGBA), matching torchvision's per-channel
        # handling of a scalar fill.
        n_bands = len(img.getbands())
        fill = 255 if n_bands == 1 else (255,) * n_bands
        img = img.rotate(rng.uniform(-9.0, 9.0), resample=Image.Resampling.BILINEAR, fillcolor=fill)
    if variant in {3, 4}:
        img = ImageEnhance.Brightness(img).enhance(rng.uniform(0.9, 1.1))
        img = ImageEnhance.Contrast(img).enhance(rng.uniform(0.9, 1.12))
    if variant == 5:
        img = ImageEnhance.Sharpness(img).enhance(rng.uniform(0.9, 1.2))
    return img
|
|
|
|