from __future__ import annotations

import random
from typing import Any

from PIL import Image, ImageEnhance, ImageOps

# Per-channel statistics handed to ``transforms.Normalize`` by both builders.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_train_transform(config: dict[str, Any]):
    """Build the torchvision transform pipeline used for training.

    The pipeline always resizes to a square of
    ``config["preprocessing"]["image_size"]`` pixels and ends with
    ``ToTensor`` followed by ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``.
    Between those steps, the optional ``config["augmentation"]`` section
    controls:

    * ``enabled`` (default ``True``): master switch for all augmentation.
    * ``horizontal_flip`` (default ``True``): random flip with ``p=0.5``.
    * ``rotation_degrees`` (10), ``translate`` (0.03), ``scale_min`` (0.95),
      ``scale_max`` (1.05): parameters of a ``RandomAffine`` filled with 255.
    * ``color_jitter``: sub-section with ``enabled`` (``True``),
      ``brightness`` (0.12), ``contrast`` (0.12), ``saturation`` (0.08) and
      ``hue`` (0.02).

    A missing or null ``augmentation`` / ``color_jitter`` section is treated
    as empty, so the defaults above apply.

    Args:
        config: Configuration mapping; ``preprocessing.image_size`` is
            required.

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        KeyError: If ``preprocessing`` or ``image_size`` is missing.
    """
    # Imported lazily so the module can be imported without torchvision.
    from torchvision import transforms

    size = int(config["preprocessing"]["image_size"])

    # A key that is present but null (e.g. an empty YAML section) would make
    # ``.get(key, {})`` return None; fall back to an empty mapping instead.
    aug = config.get("augmentation")
    if aug is None:
        aug = {}

    ops: list[Any] = [transforms.Resize((size, size))]

    if aug.get("enabled", True):
        if aug.get("horizontal_flip", True):
            ops.append(transforms.RandomHorizontalFlip(p=0.5))

        # The same fraction is used for horizontal and vertical translation.
        translate = float(aug.get("translate", 0.03))
        ops.append(
            transforms.RandomAffine(
                degrees=float(aug.get("rotation_degrees", 10)),
                translate=(translate, translate),
                scale=(
                    float(aug.get("scale_min", 0.95)),
                    float(aug.get("scale_max", 1.05)),
                ),
                fill=255,
            )
        )

        jitter = aug.get("color_jitter")
        if jitter is None:
            jitter = {}
        if jitter.get("enabled", True):
            ops.append(
                transforms.ColorJitter(
                    brightness=float(jitter.get("brightness", 0.12)),
                    contrast=float(jitter.get("contrast", 0.12)),
                    saturation=float(jitter.get("saturation", 0.08)),
                    hue=float(jitter.get("hue", 0.02)),
                )
            )

    ops.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]
    )
    return transforms.Compose(ops)


def build_eval_transform(config: dict[str, Any]):
    """Build the torchvision transform pipeline used for evaluation.

    Applies no random augmentation: resize to a square of
    ``config["preprocessing"]["image_size"]`` pixels, ``ToTensor``, then
    ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``.

    Args:
        config: Configuration mapping; ``preprocessing.image_size`` is
            required.

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        KeyError: If ``preprocessing`` or ``image_size`` is missing.
    """
    # Imported lazily so the module can be imported without torchvision.
    from torchvision import transforms

    size = int(config["preprocessing"]["image_size"])
    return transforms.Compose(
        [
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]
    )


def augment_pil_for_classical(image: Image.Image, augment_id: int, seed: int = 42) -> Image.Image:
    """Return a reproducibly augmented copy of ``image``.

    The recipe is selected by ``augment_id % 6``:

    * 0: unchanged copy
    * 1: horizontal mirror
    * 2: rotation within +/-9 degrees
    * 3: rotation, then brightness and contrast changes
    * 4: horizontal mirror, then brightness and contrast changes
    * 5: rotation, then sharpness change

    Random magnitudes come from a private ``random.Random`` seeded with
    ``seed + augment_id * 9973``, so the same ``(augment_id, seed)`` pair
    always yields the same parameters and the global random state is not
    touched.

    Args:
        image: Source image; it is copied, never modified.
        augment_id: Index of the augmented variant to produce.
        seed: Base seed combined with ``augment_id``.

    Returns:
        A new ``PIL.Image.Image``.
    """
    rng = random.Random(seed + augment_id * 9973)
    img = image.copy()
    variant = augment_id % 6

    if variant in {1, 4}:
        img = ImageOps.mirror(img)

    if variant in {2, 3, 5}:
        # Use a named colour rather than the integer 255: Pillow treats a
        # bare integer as white only for single-band images, while on RGB it
        # is read as a packed colour and the exposed corners come out red.
        # "white" is resolved for the image's own mode.
        img = img.rotate(
            rng.uniform(-9.0, 9.0),
            resample=Image.Resampling.BILINEAR,
            fillcolor="white",
        )

    if variant in {3, 4}:
        # Draw order (brightness, then contrast) is part of the seeded
        # behaviour and must not be swapped.
        img = ImageEnhance.Brightness(img).enhance(rng.uniform(0.9, 1.1))
        img = ImageEnhance.Contrast(img).enhance(rng.uniform(0.9, 1.12))

    if variant == 5:
        img = ImageEnhance.Sharpness(img).enhance(rng.uniform(0.9, 1.2))

    return img