# budijuarto — Upload src/egg_damage/augmentations.py (commit 7e5d57c, verified)
from __future__ import annotations
import random
from typing import Any
from PIL import Image, ImageEnhance, ImageOps
# Standard torchvision ImageNet normalization statistics (one value per RGB
# channel), consumed by transforms.Normalize in build_train_transform and
# build_eval_transform.
IMAGENET_MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)
IMAGENET_STD: tuple[float, float, float] = (0.229, 0.224, 0.225)
def build_train_transform(config: dict[str, Any]):
    """Build the torchvision training pipeline: resize, optional augmentation, normalize.

    Reads ``config["preprocessing"]["image_size"]`` (required) and the optional
    ``config["augmentation"]`` section. Missing augmentation keys fall back to
    the defaults used below. A section that is present but explicitly null
    (e.g. a bare ``augmentation:`` or ``color_jitter:`` key in YAML) is treated
    the same as an empty mapping instead of crashing.

    Args:
        config: Parsed configuration mapping.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` + ImageNet ``Normalize``.

    Raises:
        KeyError: If ``preprocessing`` / ``image_size`` is missing.
    """
    # Deferred import: torchvision is only needed when this builder is called.
    from torchvision import transforms

    size = int(config["preprocessing"]["image_size"])
    # `.get(key, {})` returns None when the key exists with a null value;
    # normalize that case to {} so the `.get` calls below still work.
    aug = config.get("augmentation", {})
    if aug is None:
        aug = {}

    ops: list[Any] = [transforms.Resize((size, size))]
    if aug.get("enabled", True):
        if aug.get("horizontal_flip", True):
            ops.append(transforms.RandomHorizontalFlip(p=0.5))

        # Same maximum shift fraction is used horizontally and vertically.
        translate = float(aug.get("translate", 0.03))
        ops.append(
            transforms.RandomAffine(
                degrees=float(aug.get("rotation_degrees", 10)),
                translate=(translate, translate),
                scale=(float(aug.get("scale_min", 0.95)), float(aug.get("scale_max", 1.05))),
                # Scalar fill is applied to every band: exposed borders become white.
                fill=255,
            )
        )

        jitter = aug.get("color_jitter", {})
        if jitter is None:
            jitter = {}
        if jitter.get("enabled", True):
            ops.append(
                transforms.ColorJitter(
                    brightness=float(jitter.get("brightness", 0.12)),
                    contrast=float(jitter.get("contrast", 0.12)),
                    saturation=float(jitter.get("saturation", 0.08)),
                    hue=float(jitter.get("hue", 0.02)),
                )
            )

    ops.extend([transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return transforms.Compose(ops)
def build_eval_transform(config: dict[str, Any]):
    """Return the deterministic evaluation pipeline: resize, tensorize, normalize.

    No augmentation is applied. Only ``config["preprocessing"]["image_size"]``
    is read; the output is normalized with IMAGENET_MEAN / IMAGENET_STD.
    """
    # Deferred import: torchvision is only needed when this builder is called.
    from torchvision import transforms

    side = int(config["preprocessing"]["image_size"])
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def augment_pil_for_classical(image: Image.Image, augment_id: int, seed: int = 42) -> Image.Image:
    """Return a deterministic, lightly augmented copy of *image*.

    The variant is selected by ``augment_id % 6``:

    - 0: unchanged copy
    - 1: horizontal mirror
    - 2: random rotation in [-9, 9] degrees
    - 3: rotation, then brightness and contrast jitter
    - 4: mirror, then brightness and contrast jitter
    - 5: rotation, then sharpness jitter

    Random parameters come from a private ``random.Random`` seeded with
    ``seed + augment_id * 9973``. The same (image, augment_id, seed) therefore
    always gives the same result, and the global ``random`` state is untouched.

    Args:
        image: Source image; it is copied, never modified in place.
        augment_id: Index of the augmented copy; picks the variant and RNG stream.
        seed: Base seed shared by all copies.

    Returns:
        A new PIL image.
    """
    rng = random.Random(seed + augment_id * 9973)
    img = image.copy()
    variant = augment_id % 6

    if variant in {1, 4}:
        img = ImageOps.mirror(img)
    if variant in {2, 3, 5}:
        # Pillow treats a bare int fillcolor on a multi-band image as a packed
        # pixel value, so 255 would paint RGB corners red (255, 0, 0). Supply
        # one value per band so exposed corners are white; single-band modes
        # keep the plain scalar 255.
        n_bands = len(img.getbands())
        white = 255 if n_bands == 1 else (255,) * n_bands
        img = img.rotate(
            rng.uniform(-9.0, 9.0),
            resample=Image.Resampling.BILINEAR,
            fillcolor=white,
        )
    if variant in {3, 4}:
        img = ImageEnhance.Brightness(img).enhance(rng.uniform(0.9, 1.1))
        img = ImageEnhance.Contrast(img).enhance(rng.uniform(0.9, 1.12))
    if variant == 5:
        img = ImageEnhance.Sharpness(img).enhance(rng.uniform(0.9, 1.2))
    return img