# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Dict

from torchvision import transforms


def get_image_augmentation(
    color_jitter: Optional[Dict[str, float]] = None,
    gray_scale: bool = True,
    gau_blur: bool = False,
) -> Optional[transforms.Compose]:
    """Build the random image-augmentation pipeline.

    The pipeline always starts with a randomly applied ``ColorJitter`` and
    optionally continues with ``RandomGrayscale`` and a randomly applied
    ``GaussianBlur``, in that order.

    Args:
        color_jitter: Optional overrides for the color-jitter settings. Any
            of the following keys may be given; keys that are omitted fall
            back to the defaults shown:
                - brightness (0.5)
                - contrast (0.5)
                - saturation (0.5)
                - hue (0.1)
                - p (0.9), passed to ``RandomApply`` as the probability of
                  applying the jitter
            ``None`` selects the defaults for every key.
        gray_scale: If true, append ``RandomGrayscale(p=0.05)``.
        gau_blur: If true, append a ``GaussianBlur(5, sigma=(0.1, 1.0))``
            wrapped in ``RandomApply`` with ``p=0.05``.

    Returns:
        A ``transforms.Compose`` over the selected transforms. ``None`` is
        returned only for an empty transform list; since the color-jitter
        stage is added unconditionally, that case does not occur as the
        function is currently written.
    """
    # Start from the defaults and let caller-supplied entries take precedence.
    overrides = {} if color_jitter is None else color_jitter
    jitter_cfg = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9,
        **overrides,
    }

    jitter = transforms.ColorJitter(
        brightness=jitter_cfg["brightness"],
        contrast=jitter_cfg["contrast"],
        saturation=jitter_cfg["saturation"],
        hue=jitter_cfg["hue"],
    )
    pipeline = [transforms.RandomApply([jitter], p=jitter_cfg["p"])]

    if gray_scale:
        pipeline.append(transforms.RandomGrayscale(p=0.05))

    if gau_blur:
        blur = transforms.GaussianBlur(5, sigma=(0.1, 1.0))
        pipeline.append(transforms.RandomApply([blur], p=0.05))

    if not pipeline:
        return None
    return transforms.Compose(pipeline)